rig/providers/
groq.rs

1//! Groq API client and Rig integration
2//!
3//! # Example
4//! ```
5//! use rig::providers::groq;
6//!
7//! let client = groq::Client::new("YOUR_API_KEY");
8//!
9//! let llama = client.completion_model(groq::LLAMA_3_70B_8192);
10//! ```
11use super::openai::{send_compatible_streaming_request, CompletionResponse, TranscriptionResponse};
12use crate::json_utils::merge;
13use crate::streaming::{StreamingCompletionModel, StreamingResult};
14use crate::{
15    agent::AgentBuilder,
16    completion::{self, CompletionError, CompletionRequest},
17    extractor::ExtractorBuilder,
18    json_utils,
19    message::{self, MessageError},
20    providers::openai::ToolDefinition,
21    transcription::{self, TranscriptionError},
22    OneOrMany,
23};
24use reqwest::multipart::Part;
25use schemars::JsonSchema;
26use serde::{Deserialize, Serialize};
27use serde_json::{json, Value};
28
29// ================================================================
30// Main Groq Client
31// ================================================================
32const GROQ_API_BASE_URL: &str = "https://api.groq.com/openai/v1";
33
/// Client for the Groq API: the base URL plus a pre-configured `reqwest`
/// HTTP client whose default headers carry the `Authorization` bearer token.
#[derive(Clone)]
pub struct Client {
    // API endpoint prefix; defaults to `GROQ_API_BASE_URL` (see `Client::new`).
    base_url: String,
    // HTTP client with the bearer-token header pre-installed.
    http_client: reqwest::Client,
}
39
40impl Client {
41    /// Create a new Groq client with the given API key.
42    pub fn new(api_key: &str) -> Self {
43        Self::from_url(api_key, GROQ_API_BASE_URL)
44    }
45
46    /// Create a new Groq client with the given API key and base API URL.
47    pub fn from_url(api_key: &str, base_url: &str) -> Self {
48        Self {
49            base_url: base_url.to_string(),
50            http_client: reqwest::Client::builder()
51                .default_headers({
52                    let mut headers = reqwest::header::HeaderMap::new();
53                    headers.insert(
54                        "Authorization",
55                        format!("Bearer {}", api_key)
56                            .parse()
57                            .expect("Bearer token should parse"),
58                    );
59                    headers
60                })
61                .build()
62                .expect("Groq reqwest client should build"),
63        }
64    }
65
66    /// Create a new Groq client from the `GROQ_API_KEY` environment variable.
67    /// Panics if the environment variable is not set.
68    pub fn from_env() -> Self {
69        let api_key = std::env::var("GROQ_API_KEY").expect("GROQ_API_KEY not set");
70        Self::new(&api_key)
71    }
72
73    fn post(&self, path: &str) -> reqwest::RequestBuilder {
74        let url = format!("{}/{}", self.base_url, path).replace("//", "/");
75        self.http_client.post(url)
76    }
77
78    /// Create a completion model with the given name.
79    ///
80    /// # Example
81    /// ```
82    /// use rig::providers::groq::{Client, self};
83    ///
84    /// // Initialize the Groq client
85    /// let groq = Client::new("your-groq-api-key");
86    ///
87    /// let gpt4 = groq.completion_model(groq::GPT_4);
88    /// ```
89    pub fn completion_model(&self, model: &str) -> CompletionModel {
90        CompletionModel::new(self.clone(), model)
91    }
92
93    /// Create a transcription model with the given name.
94    ///
95    /// # Example
96    /// ```
97    /// use rig::providers::groq::{Client, self};
98    ///
99    /// // Initialize the Groq client
100    /// let groq = Client::new("your-groq-api-key");
101    ///
102    /// let gpt4 = groq.transcription_model(groq::WHISPER_LARGE_V3);
103    /// ```
104    pub fn transcription_model(&self, model: &str) -> TranscriptionModel {
105        TranscriptionModel::new(self.clone(), model)
106    }
107
108    /// Create an agent builder with the given completion model.
109    ///
110    /// # Example
111    /// ```
112    /// use rig::providers::groq::{Client, self};
113    ///
114    /// // Initialize the Groq client
115    /// let groq = Client::new("your-groq-api-key");
116    ///
117    /// let agent = groq.agent(groq::GPT_4)
118    ///    .preamble("You are comedian AI with a mission to make people laugh.")
119    ///    .temperature(0.0)
120    ///    .build();
121    /// ```
122    pub fn agent(&self, model: &str) -> AgentBuilder<CompletionModel> {
123        AgentBuilder::new(self.completion_model(model))
124    }
125
126    /// Create an extractor builder with the given completion model.
127    pub fn extractor<T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync>(
128        &self,
129        model: &str,
130    ) -> ExtractorBuilder<T, CompletionModel> {
131        ExtractorBuilder::new(self.completion_model(model))
132    }
133}
134
/// Error payload returned by the Groq API on failure.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    message: String,
}
139
/// Untagged envelope: a successful payload `T`, or a provider error.
/// Deserialization tries `Ok(T)` first, falling back to the error shape.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    Ok(T),
    Err(ApiErrorResponse),
}
146
/// Wire-format chat message for the Groq API: a role string
/// ("system" | "user" | "assistant") plus optional text content.
#[derive(Debug, Serialize, Deserialize)]
pub struct Message {
    pub role: String,
    pub content: Option<String>,
}
152
153impl TryFrom<Message> for message::Message {
154    type Error = message::MessageError;
155
156    fn try_from(message: Message) -> Result<Self, Self::Error> {
157        match message.role.as_str() {
158            "user" => Ok(Self::User {
159                content: OneOrMany::one(
160                    message
161                        .content
162                        .map(|content| message::UserContent::text(&content))
163                        .ok_or_else(|| {
164                            message::MessageError::ConversionError("Empty user message".to_string())
165                        })?,
166                ),
167            }),
168            "assistant" => Ok(Self::Assistant {
169                content: OneOrMany::one(
170                    message
171                        .content
172                        .map(|content| message::AssistantContent::text(&content))
173                        .ok_or_else(|| {
174                            message::MessageError::ConversionError(
175                                "Empty assistant message".to_string(),
176                            )
177                        })?,
178                ),
179            }),
180            _ => Err(message::MessageError::ConversionError(format!(
181                "Unknown role: {}",
182                message.role
183            ))),
184        }
185    }
186}
187
188impl TryFrom<message::Message> for Message {
189    type Error = message::MessageError;
190
191    fn try_from(message: message::Message) -> Result<Self, Self::Error> {
192        match message {
193            message::Message::User { content } => Ok(Self {
194                role: "user".to_string(),
195                content: content.iter().find_map(|c| match c {
196                    message::UserContent::Text(text) => Some(text.text.clone()),
197                    _ => None,
198                }),
199            }),
200            message::Message::Assistant { content } => {
201                let mut text_content: Option<String> = None;
202
203                for c in content.iter() {
204                    match c {
205                        message::AssistantContent::Text(text) => {
206                            text_content = Some(
207                                text_content
208                                    .map(|mut existing| {
209                                        existing.push('\n');
210                                        existing.push_str(&text.text);
211                                        existing
212                                    })
213                                    .unwrap_or_else(|| text.text.clone()),
214                            );
215                        }
216                        message::AssistantContent::ToolCall(_tool_call) => {
217                            return Err(MessageError::ConversionError(
218                                "Tool calls do not exist on this message".into(),
219                            ))
220                        }
221                    }
222                }
223
224                Ok(Self {
225                    role: "assistant".to_string(),
226                    content: text_content,
227                })
228            }
229        }
230    }
231}
232
233// ================================================================
234// Groq Completion API
235// ================================================================
// Chat-completion model identifiers accepted by the Groq API.
/// The `deepseek-r1-distill-llama-70b` model. Used for chat completion.
pub const DEEPSEEK_R1_DISTILL_LLAMA_70B: &str = "deepseek-r1-distill-llama-70b";
/// The `gemma2-9b-it` model. Used for chat completion.
pub const GEMMA2_9B_IT: &str = "gemma2-9b-it";
/// The `llama-3.1-8b-instant` model. Used for chat completion.
pub const LLAMA_3_1_8B_INSTANT: &str = "llama-3.1-8b-instant";
/// The `llama-3.2-11b-vision-preview` model. Used for chat completion.
pub const LLAMA_3_2_11B_VISION_PREVIEW: &str = "llama-3.2-11b-vision-preview";
/// The `llama-3.2-1b-preview` model. Used for chat completion.
pub const LLAMA_3_2_1B_PREVIEW: &str = "llama-3.2-1b-preview";
/// The `llama-3.2-3b-preview` model. Used for chat completion.
pub const LLAMA_3_2_3B_PREVIEW: &str = "llama-3.2-3b-preview";
/// The `llama-3.2-90b-vision-preview` model. Used for chat completion.
pub const LLAMA_3_2_90B_VISION_PREVIEW: &str = "llama-3.2-90b-vision-preview";
// NOTE(review): Groq's published catalog lists `llama-3.3-*` for the
// specdec/versatile variants — verify the `3.2` ids below are still live.
/// The `llama-3.2-70b-specdec` model. Used for chat completion.
pub const LLAMA_3_2_70B_SPECDEC: &str = "llama-3.2-70b-specdec";
/// The `llama-3.2-70b-versatile` model. Used for chat completion.
pub const LLAMA_3_2_70B_VERSATILE: &str = "llama-3.2-70b-versatile";
/// The `llama-guard-3-8b` model. Used for chat completion.
pub const LLAMA_GUARD_3_8B: &str = "llama-guard-3-8b";
/// The `llama3-70b-8192` model. Used for chat completion.
pub const LLAMA_3_70B_8192: &str = "llama3-70b-8192";
/// The `llama3-8b-8192` model. Used for chat completion.
pub const LLAMA_3_8B_8192: &str = "llama3-8b-8192";
/// The `mixtral-8x7b-32768` model. Used for chat completion.
pub const MIXTRAL_8X7B_32768: &str = "mixtral-8x7b-32768";
262
/// A Groq chat-completion model bound to a [`Client`].
#[derive(Clone)]
pub struct CompletionModel {
    // Client used to send requests (cheap to clone; shares the reqwest pool).
    client: Client,
    /// Name of the model (e.g.: deepseek-r1-distill-llama-70b)
    pub model: String,
}
269
270impl CompletionModel {
271    pub fn new(client: Client, model: &str) -> Self {
272        Self {
273            client,
274            model: model.to_string(),
275        }
276    }
277
278    fn create_completion_request(
279        &self,
280        completion_request: CompletionRequest,
281    ) -> Result<Value, CompletionError> {
282        // Build up the order of messages (context, chat_history, prompt)
283        let mut partial_history = vec![];
284        if let Some(docs) = completion_request.normalized_documents() {
285            partial_history.push(docs);
286        }
287        partial_history.extend(completion_request.chat_history);
288
289        // Initialize full history with preamble (or empty if non-existent)
290        let mut full_history: Vec<Message> =
291            completion_request
292                .preamble
293                .map_or_else(Vec::new, |preamble| {
294                    vec![Message {
295                        role: "system".to_string(),
296                        content: Some(preamble),
297                    }]
298                });
299
300        // Convert and extend the rest of the history
301        full_history.extend(
302            partial_history
303                .into_iter()
304                .map(message::Message::try_into)
305                .collect::<Result<Vec<Message>, _>>()?,
306        );
307
308        let request = if completion_request.tools.is_empty() {
309            json!({
310                "model": self.model,
311                "messages": full_history,
312                "temperature": completion_request.temperature,
313            })
314        } else {
315            json!({
316                "model": self.model,
317                "messages": full_history,
318                "temperature": completion_request.temperature,
319                "tools": completion_request.tools.into_iter().map(ToolDefinition::from).collect::<Vec<_>>(),
320                "tool_choice": "auto",
321            })
322        };
323
324        let request = if let Some(params) = completion_request.additional_params {
325            json_utils::merge(request, params)
326        } else {
327            request
328        };
329
330        Ok(request)
331    }
332}
333
334impl completion::CompletionModel for CompletionModel {
335    type Response = CompletionResponse;
336
337    #[cfg_attr(feature = "worker", worker::send)]
338    async fn completion(
339        &self,
340        completion_request: CompletionRequest,
341    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
342        let request = self.create_completion_request(completion_request)?;
343
344        let response = self
345            .client
346            .post("/chat/completions")
347            .json(&request)
348            .send()
349            .await?;
350
351        if response.status().is_success() {
352            match response.json::<ApiResponse<CompletionResponse>>().await? {
353                ApiResponse::Ok(response) => {
354                    tracing::info!(target: "rig",
355                        "groq completion token usage: {:?}",
356                        response.usage.clone().map(|usage| format!("{usage}")).unwrap_or("N/A".to_string())
357                    );
358                    response.try_into()
359                }
360                ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
361            }
362        } else {
363            Err(CompletionError::ProviderError(response.text().await?))
364        }
365    }
366}
367
368impl StreamingCompletionModel for CompletionModel {
369    async fn stream(&self, request: CompletionRequest) -> Result<StreamingResult, CompletionError> {
370        let mut request = self.create_completion_request(request)?;
371
372        request = merge(request, json!({"stream": true}));
373
374        let builder = self.client.post("/chat/completions").json(&request);
375
376        send_compatible_streaming_request(builder).await
377    }
378}
379
380// ================================================================
381// Groq Transcription API
382// ================================================================
383pub const WHISPER_LARGE_V3: &str = "whisper-large-v3";
384pub const WHISPER_LARGE_V3_TURBO: &str = "whisper-large-v3-turbo";
385pub const DISTIL_WHISPER_LARGE_V3: &str = "distil-whisper-large-v3-en";
386
/// A Groq audio-transcription model bound to a [`Client`].
#[derive(Clone)]
pub struct TranscriptionModel {
    // Client used to send requests (cheap to clone; shares the reqwest pool).
    client: Client,
    /// Name of the model (e.g.: whisper-large-v3)
    pub model: String,
}
393
394impl TranscriptionModel {
395    pub fn new(client: Client, model: &str) -> Self {
396        Self {
397            client,
398            model: model.to_string(),
399        }
400    }
401}
402impl transcription::TranscriptionModel for TranscriptionModel {
403    type Response = TranscriptionResponse;
404
405    #[cfg_attr(feature = "worker", worker::send)]
406    async fn transcription(
407        &self,
408        request: transcription::TranscriptionRequest,
409    ) -> Result<
410        transcription::TranscriptionResponse<Self::Response>,
411        transcription::TranscriptionError,
412    > {
413        let data = request.data;
414
415        let mut body = reqwest::multipart::Form::new()
416            .text("model", self.model.clone())
417            .text("language", request.language)
418            .part(
419                "file",
420                Part::bytes(data).file_name(request.filename.clone()),
421            );
422
423        if let Some(prompt) = request.prompt {
424            body = body.text("prompt", prompt.clone());
425        }
426
427        if let Some(ref temperature) = request.temperature {
428            body = body.text("temperature", temperature.to_string());
429        }
430
431        if let Some(ref additional_params) = request.additional_params {
432            for (key, value) in additional_params
433                .as_object()
434                .expect("Additional Parameters to OpenAI Transcription should be a map")
435            {
436                body = body.text(key.to_owned(), value.to_string());
437            }
438        }
439
440        let response = self
441            .client
442            .post("audio/transcriptions")
443            .multipart(body)
444            .send()
445            .await?;
446
447        if response.status().is_success() {
448            match response
449                .json::<ApiResponse<TranscriptionResponse>>()
450                .await?
451            {
452                ApiResponse::Ok(response) => response.try_into(),
453                ApiResponse::Err(api_error_response) => Err(TranscriptionError::ProviderError(
454                    api_error_response.message,
455                )),
456            }
457        } else {
458            Err(TranscriptionError::ProviderError(response.text().await?))
459        }
460    }
461}