rig/providers/groq.rs

//! Groq API client and Rig integration
//!
//! # Example
//! ```
//! use rig::providers::groq;
//!
//! let client = groq::Client::new("YOUR_API_KEY");
//!
//! let model = client.completion_model(groq::DEEPSEEK_R1_DISTILL_LLAMA_70B);
//! ```
11use super::openai::{CompletionResponse, TranscriptionResponse, send_compatible_streaming_request};
12use crate::client::{CompletionClient, TranscriptionClient};
13use crate::json_utils::merge;
14use crate::providers::openai;
15use crate::streaming::StreamingCompletionResponse;
16use crate::{
17    OneOrMany,
18    completion::{self, CompletionError, CompletionRequest},
19    json_utils,
20    message::{self, MessageError},
21    providers::openai::ToolDefinition,
22    transcription::{self, TranscriptionError},
23};
24use reqwest::multipart::Part;
25use rig::client::ProviderClient;
26use rig::impl_conversion_traits;
27use serde::{Deserialize, Serialize};
28use serde_json::{Value, json};
29
30// ================================================================
31// Main Groq Client
32// ================================================================
33const GROQ_API_BASE_URL: &str = "https://api.groq.com/openai/v1";
34
/// Client for the Groq API (OpenAI-compatible endpoints).
#[derive(Clone)]
pub struct Client {
    // Base URL for all requests; defaults to `GROQ_API_BASE_URL`.
    base_url: String,
    // API key, attached as a bearer token on every request.
    // Kept private and redacted in the `Debug` impl below.
    api_key: String,
    // Underlying HTTP client shared by completion and transcription calls.
    http_client: reqwest::Client,
}
41
// Manual `Debug` impl so the API key is redacted rather than leaked into logs.
impl std::fmt::Debug for Client {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Client")
            .field("base_url", &self.base_url)
            .field("http_client", &self.http_client)
            .field("api_key", &"<REDACTED>")
            .finish()
    }
}
51
52impl Client {
53    /// Create a new Groq client with the given API key.
54    pub fn new(api_key: &str) -> Self {
55        Self::from_url(api_key, GROQ_API_BASE_URL)
56    }
57
58    /// Create a new Groq client with the given API key and base API URL.
59    pub fn from_url(api_key: &str, base_url: &str) -> Self {
60        Self {
61            base_url: base_url.to_string(),
62            api_key: api_key.to_string(),
63            http_client: reqwest::Client::builder()
64                .build()
65                .expect("Groq reqwest client should build"),
66        }
67    }
68
69    /// Use your own `reqwest::Client`.
70    /// The required headers will be automatically attached upon trying to make a request.
71    pub fn with_custom_client(mut self, client: reqwest::Client) -> Self {
72        self.http_client = client;
73
74        self
75    }
76
77    fn post(&self, path: &str) -> reqwest::RequestBuilder {
78        let url = format!("{}/{}", self.base_url, path).replace("//", "/");
79        self.http_client.post(url).bearer_auth(&self.api_key)
80    }
81}
82
83impl ProviderClient for Client {
84    /// Create a new Groq client from the `GROQ_API_KEY` environment variable.
85    /// Panics if the environment variable is not set.
86    fn from_env() -> Self {
87        let api_key = std::env::var("GROQ_API_KEY").expect("GROQ_API_KEY not set");
88        Self::new(&api_key)
89    }
90
91    fn from_val(input: crate::client::ProviderValue) -> Self {
92        let crate::client::ProviderValue::Simple(api_key) = input else {
93            panic!("Incorrect provider value type")
94        };
95        Self::new(&api_key)
96    }
97}
98
impl CompletionClient for Client {
    type CompletionModel = CompletionModel;

    /// Create a completion model with the given name.
    ///
    /// # Example
    /// ```
    /// use rig::providers::groq::{Client, self};
    ///
    /// // Initialize the Groq client
    /// let groq = Client::new("your-groq-api-key");
    ///
    /// let model = groq.completion_model(groq::DEEPSEEK_R1_DISTILL_LLAMA_70B);
    /// ```
    fn completion_model(&self, model: &str) -> CompletionModel {
        CompletionModel::new(self.clone(), model)
    }
}
117
impl TranscriptionClient for Client {
    type TranscriptionModel = TranscriptionModel;

    /// Create a transcription model with the given name.
    ///
    /// # Example
    /// ```
    /// use rig::providers::groq::{Client, self};
    ///
    /// // Initialize the Groq client
    /// let groq = Client::new("your-groq-api-key");
    ///
    /// let whisper = groq.transcription_model(groq::WHISPER_LARGE_V3);
    /// ```
    fn transcription_model(&self, model: &str) -> TranscriptionModel {
        TranscriptionModel::new(self.clone(), model)
    }
}
136
// NOTE(review): presumably generates stub impls marking these capabilities as
// unsupported for this provider — macro definition not visible here; confirm.
impl_conversion_traits!(
    AsEmbeddings,
    AsImageGeneration,
    AsAudioGeneration for Client
);
142
/// Error payload returned by the Groq API.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    message: String,
}
147
// Untagged: deserialization tries the success shape first, then falls back to
// the error shape.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    Ok(T),
    Err(ApiErrorResponse),
}
154
/// Wire-format chat message (role + optional text content) sent to/from Groq.
#[derive(Debug, Serialize, Deserialize)]
pub struct Message {
    // One of "system", "user", or "assistant" in this module.
    pub role: String,
    pub content: Option<String>,
}
160
161impl TryFrom<Message> for message::Message {
162    type Error = message::MessageError;
163
164    fn try_from(message: Message) -> Result<Self, Self::Error> {
165        match message.role.as_str() {
166            "user" => Ok(Self::User {
167                content: OneOrMany::one(
168                    message
169                        .content
170                        .map(|content| message::UserContent::text(&content))
171                        .ok_or_else(|| {
172                            message::MessageError::ConversionError("Empty user message".to_string())
173                        })?,
174                ),
175            }),
176            "assistant" => Ok(Self::Assistant {
177                id: None,
178                content: OneOrMany::one(
179                    message
180                        .content
181                        .map(|content| message::AssistantContent::text(&content))
182                        .ok_or_else(|| {
183                            message::MessageError::ConversionError(
184                                "Empty assistant message".to_string(),
185                            )
186                        })?,
187                ),
188            }),
189            _ => Err(message::MessageError::ConversionError(format!(
190                "Unknown role: {}",
191                message.role
192            ))),
193        }
194    }
195}
196
197impl TryFrom<message::Message> for Message {
198    type Error = message::MessageError;
199
200    fn try_from(message: message::Message) -> Result<Self, Self::Error> {
201        match message {
202            message::Message::User { content } => Ok(Self {
203                role: "user".to_string(),
204                content: content.iter().find_map(|c| match c {
205                    message::UserContent::Text(text) => Some(text.text.clone()),
206                    _ => None,
207                }),
208            }),
209            message::Message::Assistant { content, .. } => {
210                let mut text_content: Option<String> = None;
211
212                for c in content.iter() {
213                    match c {
214                        message::AssistantContent::Text(text) => {
215                            text_content = Some(
216                                text_content
217                                    .map(|mut existing| {
218                                        existing.push('\n');
219                                        existing.push_str(&text.text);
220                                        existing
221                                    })
222                                    .unwrap_or_else(|| text.text.clone()),
223                            );
224                        }
225                        message::AssistantContent::ToolCall(_tool_call) => {
226                            return Err(MessageError::ConversionError(
227                                "Tool calls do not exist on this message".into(),
228                            ));
229                        }
230                    }
231                }
232
233                Ok(Self {
234                    role: "assistant".to_string(),
235                    content: text_content,
236                })
237            }
238        }
239    }
240}
241
// ================================================================
// Groq Completion API
// ================================================================
// NOTE(review): several `-preview`/`specdec` identifiers below may have been
// retired by Groq — verify against the current model list before relying on them.
/// The `deepseek-r1-distill-llama-70b` model. Used for chat completion.
pub const DEEPSEEK_R1_DISTILL_LLAMA_70B: &str = "deepseek-r1-distill-llama-70b";
/// The `gemma2-9b-it` model. Used for chat completion.
pub const GEMMA2_9B_IT: &str = "gemma2-9b-it";
/// The `llama-3.1-8b-instant` model. Used for chat completion.
pub const LLAMA_3_1_8B_INSTANT: &str = "llama-3.1-8b-instant";
/// The `llama-3.2-11b-vision-preview` model. Used for chat completion.
pub const LLAMA_3_2_11B_VISION_PREVIEW: &str = "llama-3.2-11b-vision-preview";
/// The `llama-3.2-1b-preview` model. Used for chat completion.
pub const LLAMA_3_2_1B_PREVIEW: &str = "llama-3.2-1b-preview";
/// The `llama-3.2-3b-preview` model. Used for chat completion.
pub const LLAMA_3_2_3B_PREVIEW: &str = "llama-3.2-3b-preview";
/// The `llama-3.2-90b-vision-preview` model. Used for chat completion.
pub const LLAMA_3_2_90B_VISION_PREVIEW: &str = "llama-3.2-90b-vision-preview";
/// The `llama-3.2-70b-specdec` model. Used for chat completion.
pub const LLAMA_3_2_70B_SPECDEC: &str = "llama-3.2-70b-specdec";
/// The `llama-3.2-70b-versatile` model. Used for chat completion.
pub const LLAMA_3_2_70B_VERSATILE: &str = "llama-3.2-70b-versatile";
/// The `llama-guard-3-8b` model. Used for chat completion.
pub const LLAMA_GUARD_3_8B: &str = "llama-guard-3-8b";
/// The `llama3-70b-8192` model. Used for chat completion.
pub const LLAMA_3_70B_8192: &str = "llama3-70b-8192";
/// The `llama3-8b-8192` model. Used for chat completion.
pub const LLAMA_3_8B_8192: &str = "llama3-8b-8192";
/// The `mixtral-8x7b-32768` model. Used for chat completion.
pub const MIXTRAL_8X7B_32768: &str = "mixtral-8x7b-32768";
271
/// A Groq chat-completion model bound to a [`Client`].
#[derive(Clone, Debug)]
pub struct CompletionModel {
    client: Client,
    /// Name of the model (e.g.: deepseek-r1-distill-llama-70b)
    pub model: String,
}
278
279impl CompletionModel {
280    pub fn new(client: Client, model: &str) -> Self {
281        Self {
282            client,
283            model: model.to_string(),
284        }
285    }
286
287    fn create_completion_request(
288        &self,
289        completion_request: CompletionRequest,
290    ) -> Result<Value, CompletionError> {
291        // Build up the order of messages (context, chat_history, prompt)
292        let mut partial_history = vec![];
293        if let Some(docs) = completion_request.normalized_documents() {
294            partial_history.push(docs);
295        }
296        partial_history.extend(completion_request.chat_history);
297
298        // Initialize full history with preamble (or empty if non-existent)
299        let mut full_history: Vec<Message> =
300            completion_request
301                .preamble
302                .map_or_else(Vec::new, |preamble| {
303                    vec![Message {
304                        role: "system".to_string(),
305                        content: Some(preamble),
306                    }]
307                });
308
309        // Convert and extend the rest of the history
310        full_history.extend(
311            partial_history
312                .into_iter()
313                .map(message::Message::try_into)
314                .collect::<Result<Vec<Message>, _>>()?,
315        );
316
317        let request = if completion_request.tools.is_empty() {
318            json!({
319                "model": self.model,
320                "messages": full_history,
321                "temperature": completion_request.temperature,
322            })
323        } else {
324            json!({
325                "model": self.model,
326                "messages": full_history,
327                "temperature": completion_request.temperature,
328                "tools": completion_request.tools.into_iter().map(ToolDefinition::from).collect::<Vec<_>>(),
329                "tool_choice": "auto",
330            })
331        };
332
333        let request = if let Some(params) = completion_request.additional_params {
334            json_utils::merge(request, params)
335        } else {
336            request
337        };
338
339        Ok(request)
340    }
341}
342
343impl completion::CompletionModel for CompletionModel {
344    type Response = CompletionResponse;
345    type StreamingResponse = openai::StreamingCompletionResponse;
346
347    #[cfg_attr(feature = "worker", worker::send)]
348    async fn completion(
349        &self,
350        completion_request: CompletionRequest,
351    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
352        let request = self.create_completion_request(completion_request)?;
353
354        let response = self
355            .client
356            .post("/chat/completions")
357            .json(&request)
358            .send()
359            .await?;
360
361        if response.status().is_success() {
362            match response.json::<ApiResponse<CompletionResponse>>().await? {
363                ApiResponse::Ok(response) => {
364                    tracing::info!(target: "rig",
365                        "groq completion token usage: {:?}",
366                        response.usage.clone().map(|usage| format!("{usage}")).unwrap_or("N/A".to_string())
367                    );
368                    response.try_into()
369                }
370                ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
371            }
372        } else {
373            Err(CompletionError::ProviderError(response.text().await?))
374        }
375    }
376
377    #[cfg_attr(feature = "worker", worker::send)]
378    async fn stream(
379        &self,
380        request: CompletionRequest,
381    ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
382        let mut request = self.create_completion_request(request)?;
383
384        request = merge(
385            request,
386            json!({"stream": true, "stream_options": {"include_usage": true}}),
387        );
388
389        let builder = self.client.post("/chat/completions").json(&request);
390
391        send_compatible_streaming_request(builder).await
392    }
393}
394
// ================================================================
// Groq Transcription API
// ================================================================
/// The `whisper-large-v3` model. Used for transcription.
pub const WHISPER_LARGE_V3: &str = "whisper-large-v3";
/// The `whisper-large-v3-turbo` model. Used for transcription.
pub const WHISPER_LARGE_V3_TURBO: &str = "whisper-large-v3-turbo";
/// The `distil-whisper-large-v3-en` model. Used for transcription.
pub const DISTIL_WHISPER_LARGE_V3: &str = "distil-whisper-large-v3-en";
401
/// A Groq audio-transcription model bound to a [`Client`].
#[derive(Clone)]
pub struct TranscriptionModel {
    client: Client,
    /// Name of the model (e.g.: whisper-large-v3)
    pub model: String,
}
408
impl TranscriptionModel {
    /// Bind a transcription model name to the given client.
    pub fn new(client: Client, model: &str) -> Self {
        Self {
            client,
            model: model.to_string(),
        }
    }
}
417impl transcription::TranscriptionModel for TranscriptionModel {
418    type Response = TranscriptionResponse;
419
420    #[cfg_attr(feature = "worker", worker::send)]
421    async fn transcription(
422        &self,
423        request: transcription::TranscriptionRequest,
424    ) -> Result<
425        transcription::TranscriptionResponse<Self::Response>,
426        transcription::TranscriptionError,
427    > {
428        let data = request.data;
429
430        let mut body = reqwest::multipart::Form::new()
431            .text("model", self.model.clone())
432            .text("language", request.language)
433            .part(
434                "file",
435                Part::bytes(data).file_name(request.filename.clone()),
436            );
437
438        if let Some(prompt) = request.prompt {
439            body = body.text("prompt", prompt.clone());
440        }
441
442        if let Some(ref temperature) = request.temperature {
443            body = body.text("temperature", temperature.to_string());
444        }
445
446        if let Some(ref additional_params) = request.additional_params {
447            for (key, value) in additional_params
448                .as_object()
449                .expect("Additional Parameters to OpenAI Transcription should be a map")
450            {
451                body = body.text(key.to_owned(), value.to_string());
452            }
453        }
454
455        let response = self
456            .client
457            .post("audio/transcriptions")
458            .multipart(body)
459            .send()
460            .await?;
461
462        if response.status().is_success() {
463            match response
464                .json::<ApiResponse<TranscriptionResponse>>()
465                .await?
466            {
467                ApiResponse::Ok(response) => response.try_into(),
468                ApiResponse::Err(api_error_response) => Err(TranscriptionError::ProviderError(
469                    api_error_response.message,
470                )),
471            }
472        } else {
473            Err(TranscriptionError::ProviderError(response.text().await?))
474        }
475    }
476}