rig/providers/
groq.rs

//! Groq API client and Rig integration
//!
//! # Example
//! ```
//! use rig::providers::groq;
//!
//! let client = groq::Client::new("YOUR_API_KEY");
//!
//! let llama = client.completion_model(groq::LLAMA_3_70B_8192);
//! ```
use super::openai::{CompletionResponse, TranscriptionResponse, send_compatible_streaming_request};
use crate::client::{CompletionClient, TranscriptionClient};
use crate::json_utils::merge;
use crate::providers::openai;
use crate::streaming::StreamingCompletionResponse;
use crate::{
    OneOrMany,
    completion::{self, CompletionError, CompletionRequest},
    json_utils,
    message::{self, MessageError},
    providers::openai::ToolDefinition,
    transcription::{self, TranscriptionError},
};
use reqwest::multipart::Part;
use rig::client::ProviderClient;
use rig::impl_conversion_traits;
use serde::{Deserialize, Serialize};
use serde_json::{Value, json};
29
// ================================================================
// Main Groq Client
// ================================================================
/// Base URL of Groq's OpenAI-compatible REST API.
const GROQ_API_BASE_URL: &str = "https://api.groq.com/openai/v1";

/// HTTP client for the Groq API; all models created from it share its
/// connection pool and credentials.
#[derive(Clone)]
pub struct Client {
    // Root URL request paths are joined onto (defaults to GROQ_API_BASE_URL).
    base_url: String,
    // Bearer token sent with every request; redacted from Debug output.
    api_key: String,
    // Underlying reqwest client used for all HTTP calls.
    http_client: reqwest::Client,
}
41
42impl std::fmt::Debug for Client {
43    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
44        f.debug_struct("Client")
45            .field("base_url", &self.base_url)
46            .field("http_client", &self.http_client)
47            .field("api_key", &"<REDACTED>")
48            .finish()
49    }
50}
51
52impl Client {
53    /// Create a new Groq client with the given API key.
54    pub fn new(api_key: &str) -> Self {
55        Self::from_url(api_key, GROQ_API_BASE_URL)
56    }
57
58    /// Create a new Groq client with the given API key and base API URL.
59    pub fn from_url(api_key: &str, base_url: &str) -> Self {
60        Self {
61            base_url: base_url.to_string(),
62            api_key: api_key.to_string(),
63            http_client: reqwest::Client::builder()
64                .build()
65                .expect("Groq reqwest client should build"),
66        }
67    }
68
69    /// Use your own `reqwest::Client`.
70    /// The required headers will be automatically attached upon trying to make a request.
71    pub fn with_custom_client(mut self, client: reqwest::Client) -> Self {
72        self.http_client = client;
73
74        self
75    }
76
77    fn post(&self, path: &str) -> reqwest::RequestBuilder {
78        let url = format!("{}/{}", self.base_url, path).replace("//", "/");
79        self.http_client.post(url).bearer_auth(&self.api_key)
80    }
81}
82
83impl ProviderClient for Client {
84    /// Create a new Groq client from the `GROQ_API_KEY` environment variable.
85    /// Panics if the environment variable is not set.
86    fn from_env() -> Self {
87        let api_key = std::env::var("GROQ_API_KEY").expect("GROQ_API_KEY not set");
88        Self::new(&api_key)
89    }
90}
91
impl CompletionClient for Client {
    type CompletionModel = CompletionModel;

    /// Create a completion model with the given name.
    ///
    /// # Example
    /// ```
    /// use rig::providers::groq::{Client, self};
    ///
    /// // Initialize the Groq client
    /// let groq = Client::new("your-groq-api-key");
    ///
    /// let llama = groq.completion_model(groq::LLAMA_3_70B_8192);
    /// ```
    fn completion_model(&self, model: &str) -> CompletionModel {
        CompletionModel::new(self.clone(), model)
    }
}
110
impl TranscriptionClient for Client {
    type TranscriptionModel = TranscriptionModel;

    /// Create a transcription model with the given name.
    ///
    /// # Example
    /// ```
    /// use rig::providers::groq::{Client, self};
    ///
    /// // Initialize the Groq client
    /// let groq = Client::new("your-groq-api-key");
    ///
    /// let whisper = groq.transcription_model(groq::WHISPER_LARGE_V3);
    /// ```
    fn transcription_model(&self, model: &str) -> TranscriptionModel {
        TranscriptionModel::new(self.clone(), model)
    }
}
129
// Generate the remaining provider-capability conversion traits for `Client`.
// This client exposes no embeddings, image-generation, or audio-generation
// models, so these conversions presumably report the capability as
// unavailable — see the `impl_conversion_traits` macro for the exact impls.
impl_conversion_traits!(
    AsEmbeddings,
    AsImageGeneration,
    AsAudioGeneration for Client
);
135
/// Error payload returned by the Groq API.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    message: String,
}
140
/// A Groq API response body: either the expected payload or an error.
///
/// `untagged` makes serde try the variants in declaration order, so a body
/// is first parsed as `T` and only then as an error envelope.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    Ok(T),
    Err(ApiErrorResponse),
}
147
/// Wire-format chat message for the Groq (OpenAI-compatible) API.
#[derive(Debug, Serialize, Deserialize)]
pub struct Message {
    // Author of the message: "system", "user", or "assistant".
    pub role: String,
    // Plain-text content; `None` when the message carries no text.
    pub content: Option<String>,
}
153
154impl TryFrom<Message> for message::Message {
155    type Error = message::MessageError;
156
157    fn try_from(message: Message) -> Result<Self, Self::Error> {
158        match message.role.as_str() {
159            "user" => Ok(Self::User {
160                content: OneOrMany::one(
161                    message
162                        .content
163                        .map(|content| message::UserContent::text(&content))
164                        .ok_or_else(|| {
165                            message::MessageError::ConversionError("Empty user message".to_string())
166                        })?,
167                ),
168            }),
169            "assistant" => Ok(Self::Assistant {
170                id: None,
171                content: OneOrMany::one(
172                    message
173                        .content
174                        .map(|content| message::AssistantContent::text(&content))
175                        .ok_or_else(|| {
176                            message::MessageError::ConversionError(
177                                "Empty assistant message".to_string(),
178                            )
179                        })?,
180                ),
181            }),
182            _ => Err(message::MessageError::ConversionError(format!(
183                "Unknown role: {}",
184                message.role
185            ))),
186        }
187    }
188}
189
190impl TryFrom<message::Message> for Message {
191    type Error = message::MessageError;
192
193    fn try_from(message: message::Message) -> Result<Self, Self::Error> {
194        match message {
195            message::Message::User { content } => Ok(Self {
196                role: "user".to_string(),
197                content: content.iter().find_map(|c| match c {
198                    message::UserContent::Text(text) => Some(text.text.clone()),
199                    _ => None,
200                }),
201            }),
202            message::Message::Assistant { content, .. } => {
203                let mut text_content: Option<String> = None;
204
205                for c in content.iter() {
206                    match c {
207                        message::AssistantContent::Text(text) => {
208                            text_content = Some(
209                                text_content
210                                    .map(|mut existing| {
211                                        existing.push('\n');
212                                        existing.push_str(&text.text);
213                                        existing
214                                    })
215                                    .unwrap_or_else(|| text.text.clone()),
216                            );
217                        }
218                        message::AssistantContent::ToolCall(_tool_call) => {
219                            return Err(MessageError::ConversionError(
220                                "Tool calls do not exist on this message".into(),
221                            ));
222                        }
223                    }
224                }
225
226                Ok(Self {
227                    role: "assistant".to_string(),
228                    content: text_content,
229                })
230            }
231        }
232    }
233}
234
// ================================================================
// Groq Completion API
// ================================================================
// NOTE(review): Groq's published catalog lists `llama-3.3-70b-versatile`
// and `llama-3.3-70b-specdec`; confirm the 3.2 ids below are still served.
/// The `deepseek-r1-distill-llama-70b` model. Used for chat completion.
pub const DEEPSEEK_R1_DISTILL_LLAMA_70B: &str = "deepseek-r1-distill-llama-70b";
/// The `gemma2-9b-it` model. Used for chat completion.
pub const GEMMA2_9B_IT: &str = "gemma2-9b-it";
/// The `llama-3.1-8b-instant` model. Used for chat completion.
pub const LLAMA_3_1_8B_INSTANT: &str = "llama-3.1-8b-instant";
/// The `llama-3.2-11b-vision-preview` model. Used for chat completion.
pub const LLAMA_3_2_11B_VISION_PREVIEW: &str = "llama-3.2-11b-vision-preview";
/// The `llama-3.2-1b-preview` model. Used for chat completion.
pub const LLAMA_3_2_1B_PREVIEW: &str = "llama-3.2-1b-preview";
/// The `llama-3.2-3b-preview` model. Used for chat completion.
pub const LLAMA_3_2_3B_PREVIEW: &str = "llama-3.2-3b-preview";
/// The `llama-3.2-90b-vision-preview` model. Used for chat completion.
pub const LLAMA_3_2_90B_VISION_PREVIEW: &str = "llama-3.2-90b-vision-preview";
/// The `llama-3.2-70b-specdec` model. Used for chat completion.
pub const LLAMA_3_2_70B_SPECDEC: &str = "llama-3.2-70b-specdec";
/// The `llama-3.2-70b-versatile` model. Used for chat completion.
pub const LLAMA_3_2_70B_VERSATILE: &str = "llama-3.2-70b-versatile";
/// The `llama-guard-3-8b` model. Used for chat completion.
pub const LLAMA_GUARD_3_8B: &str = "llama-guard-3-8b";
/// The `llama3-70b-8192` model. Used for chat completion.
pub const LLAMA_3_70B_8192: &str = "llama3-70b-8192";
/// The `llama3-8b-8192` model. Used for chat completion.
pub const LLAMA_3_8B_8192: &str = "llama3-8b-8192";
/// The `mixtral-8x7b-32768` model. Used for chat completion.
pub const MIXTRAL_8X7B_32768: &str = "mixtral-8x7b-32768";
264
/// A Groq chat-completion model bound to a [`Client`].
#[derive(Clone, Debug)]
pub struct CompletionModel {
    // Client that carries credentials and the HTTP connection pool.
    client: Client,
    /// Name of the model (e.g.: deepseek-r1-distill-llama-70b)
    pub model: String,
}
271
272impl CompletionModel {
273    pub fn new(client: Client, model: &str) -> Self {
274        Self {
275            client,
276            model: model.to_string(),
277        }
278    }
279
280    fn create_completion_request(
281        &self,
282        completion_request: CompletionRequest,
283    ) -> Result<Value, CompletionError> {
284        // Build up the order of messages (context, chat_history, prompt)
285        let mut partial_history = vec![];
286        if let Some(docs) = completion_request.normalized_documents() {
287            partial_history.push(docs);
288        }
289        partial_history.extend(completion_request.chat_history);
290
291        // Initialize full history with preamble (or empty if non-existent)
292        let mut full_history: Vec<Message> =
293            completion_request
294                .preamble
295                .map_or_else(Vec::new, |preamble| {
296                    vec![Message {
297                        role: "system".to_string(),
298                        content: Some(preamble),
299                    }]
300                });
301
302        // Convert and extend the rest of the history
303        full_history.extend(
304            partial_history
305                .into_iter()
306                .map(message::Message::try_into)
307                .collect::<Result<Vec<Message>, _>>()?,
308        );
309
310        let request = if completion_request.tools.is_empty() {
311            json!({
312                "model": self.model,
313                "messages": full_history,
314                "temperature": completion_request.temperature,
315            })
316        } else {
317            json!({
318                "model": self.model,
319                "messages": full_history,
320                "temperature": completion_request.temperature,
321                "tools": completion_request.tools.into_iter().map(ToolDefinition::from).collect::<Vec<_>>(),
322                "tool_choice": "auto",
323            })
324        };
325
326        let request = if let Some(params) = completion_request.additional_params {
327            json_utils::merge(request, params)
328        } else {
329            request
330        };
331
332        Ok(request)
333    }
334}
335
336impl completion::CompletionModel for CompletionModel {
337    type Response = CompletionResponse;
338    type StreamingResponse = openai::StreamingCompletionResponse;
339
340    #[cfg_attr(feature = "worker", worker::send)]
341    async fn completion(
342        &self,
343        completion_request: CompletionRequest,
344    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
345        let request = self.create_completion_request(completion_request)?;
346
347        let response = self
348            .client
349            .post("/chat/completions")
350            .json(&request)
351            .send()
352            .await?;
353
354        if response.status().is_success() {
355            match response.json::<ApiResponse<CompletionResponse>>().await? {
356                ApiResponse::Ok(response) => {
357                    tracing::info!(target: "rig",
358                        "groq completion token usage: {:?}",
359                        response.usage.clone().map(|usage| format!("{usage}")).unwrap_or("N/A".to_string())
360                    );
361                    response.try_into()
362                }
363                ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
364            }
365        } else {
366            Err(CompletionError::ProviderError(response.text().await?))
367        }
368    }
369
370    #[cfg_attr(feature = "worker", worker::send)]
371    async fn stream(
372        &self,
373        request: CompletionRequest,
374    ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
375        let mut request = self.create_completion_request(request)?;
376
377        request = merge(
378            request,
379            json!({"stream": true, "stream_options": {"include_usage": true}}),
380        );
381
382        let builder = self.client.post("/chat/completions").json(&request);
383
384        send_compatible_streaming_request(builder).await
385    }
386}
387
// ================================================================
// Groq Transcription API
// ================================================================
/// The `whisper-large-v3` model. Used for audio transcription.
pub const WHISPER_LARGE_V3: &str = "whisper-large-v3";
/// The `whisper-large-v3-turbo` model. Used for audio transcription.
pub const WHISPER_LARGE_V3_TURBO: &str = "whisper-large-v3-turbo";
/// The `distil-whisper-large-v3-en` model. Used for audio transcription.
pub const DISTIL_WHISPER_LARGE_V3: &str = "distil-whisper-large-v3-en";
394
/// A Groq audio-transcription model bound to a [`Client`].
#[derive(Clone)]
pub struct TranscriptionModel {
    // Client that carries credentials and the HTTP connection pool.
    client: Client,
    /// Name of the model (e.g.: whisper-large-v3)
    pub model: String,
}
401
402impl TranscriptionModel {
403    pub fn new(client: Client, model: &str) -> Self {
404        Self {
405            client,
406            model: model.to_string(),
407        }
408    }
409}
410impl transcription::TranscriptionModel for TranscriptionModel {
411    type Response = TranscriptionResponse;
412
413    #[cfg_attr(feature = "worker", worker::send)]
414    async fn transcription(
415        &self,
416        request: transcription::TranscriptionRequest,
417    ) -> Result<
418        transcription::TranscriptionResponse<Self::Response>,
419        transcription::TranscriptionError,
420    > {
421        let data = request.data;
422
423        let mut body = reqwest::multipart::Form::new()
424            .text("model", self.model.clone())
425            .text("language", request.language)
426            .part(
427                "file",
428                Part::bytes(data).file_name(request.filename.clone()),
429            );
430
431        if let Some(prompt) = request.prompt {
432            body = body.text("prompt", prompt.clone());
433        }
434
435        if let Some(ref temperature) = request.temperature {
436            body = body.text("temperature", temperature.to_string());
437        }
438
439        if let Some(ref additional_params) = request.additional_params {
440            for (key, value) in additional_params
441                .as_object()
442                .expect("Additional Parameters to OpenAI Transcription should be a map")
443            {
444                body = body.text(key.to_owned(), value.to_string());
445            }
446        }
447
448        let response = self
449            .client
450            .post("audio/transcriptions")
451            .multipart(body)
452            .send()
453            .await?;
454
455        if response.status().is_success() {
456            match response
457                .json::<ApiResponse<TranscriptionResponse>>()
458                .await?
459            {
460                ApiResponse::Ok(response) => response.try_into(),
461                ApiResponse::Err(api_error_response) => Err(TranscriptionError::ProviderError(
462                    api_error_response.message,
463                )),
464            }
465        } else {
466            Err(TranscriptionError::ProviderError(response.text().await?))
467        }
468    }
469}