rig/providers/
groq.rs

//! Groq API client and Rig integration
//!
//! # Example
//! ```
//! use rig::providers::groq;
//!
//! let client = groq::Client::new("YOUR_API_KEY");
//!
//! let llama = client.completion_model(groq::LLAMA_3_1_8B_INSTANT);
//! ```
11use super::openai::{send_compatible_streaming_request, CompletionResponse, TranscriptionResponse};
12use crate::json_utils::merge;
13use crate::streaming::{StreamingCompletionModel, StreamingResult};
14use crate::{
15    agent::AgentBuilder,
16    completion::{self, CompletionError, CompletionRequest},
17    extractor::ExtractorBuilder,
18    json_utils,
19    message::{self, MessageError},
20    providers::openai::ToolDefinition,
21    transcription::{self, TranscriptionError},
22    OneOrMany,
23};
24use reqwest::multipart::Part;
25use schemars::JsonSchema;
26use serde::{Deserialize, Serialize};
27use serde_json::{json, Value};
28
// ================================================================
// Main Groq Client
// ================================================================
/// Base URL of Groq's OpenAI-compatible REST API.
const GROQ_API_BASE_URL: &str = "https://api.groq.com/openai/v1";

/// Thin HTTP client for the Groq API.
///
/// Holds the API base URL and a `reqwest` client that is pre-configured
/// with the `Authorization: Bearer <key>` header in `Client::from_url`.
#[derive(Clone)]
pub struct Client {
    base_url: String,
    http_client: reqwest::Client,
}
39
40impl Client {
41    /// Create a new Groq client with the given API key.
42    pub fn new(api_key: &str) -> Self {
43        Self::from_url(api_key, GROQ_API_BASE_URL)
44    }
45
46    /// Create a new Groq client with the given API key and base API URL.
47    pub fn from_url(api_key: &str, base_url: &str) -> Self {
48        Self {
49            base_url: base_url.to_string(),
50            http_client: reqwest::Client::builder()
51                .default_headers({
52                    let mut headers = reqwest::header::HeaderMap::new();
53                    headers.insert(
54                        "Authorization",
55                        format!("Bearer {}", api_key)
56                            .parse()
57                            .expect("Bearer token should parse"),
58                    );
59                    headers
60                })
61                .build()
62                .expect("Groq reqwest client should build"),
63        }
64    }
65
66    /// Create a new Groq client from the `GROQ_API_KEY` environment variable.
67    /// Panics if the environment variable is not set.
68    pub fn from_env() -> Self {
69        let api_key = std::env::var("GROQ_API_KEY").expect("GROQ_API_KEY not set");
70        Self::new(&api_key)
71    }
72
73    fn post(&self, path: &str) -> reqwest::RequestBuilder {
74        let url = format!("{}/{}", self.base_url, path).replace("//", "/");
75        self.http_client.post(url)
76    }
77
78    /// Create a completion model with the given name.
79    ///
80    /// # Example
81    /// ```
82    /// use rig::providers::groq::{Client, self};
83    ///
84    /// // Initialize the Groq client
85    /// let groq = Client::new("your-groq-api-key");
86    ///
87    /// let gpt4 = groq.completion_model(groq::GPT_4);
88    /// ```
89    pub fn completion_model(&self, model: &str) -> CompletionModel {
90        CompletionModel::new(self.clone(), model)
91    }
92
93    /// Create a transcription model with the given name.
94    ///
95    /// # Example
96    /// ```
97    /// use rig::providers::groq::{Client, self};
98    ///
99    /// // Initialize the Groq client
100    /// let groq = Client::new("your-groq-api-key");
101    ///
102    /// let gpt4 = groq.transcription_model(groq::WHISPER_LARGE_V3);
103    /// ```
104    pub fn transcription_model(&self, model: &str) -> TranscriptionModel {
105        TranscriptionModel::new(self.clone(), model)
106    }
107
108    /// Create an agent builder with the given completion model.
109    ///
110    /// # Example
111    /// ```
112    /// use rig::providers::groq::{Client, self};
113    ///
114    /// // Initialize the Groq client
115    /// let groq = Client::new("your-groq-api-key");
116    ///
117    /// let agent = groq.agent(groq::GPT_4)
118    ///    .preamble("You are comedian AI with a mission to make people laugh.")
119    ///    .temperature(0.0)
120    ///    .build();
121    /// ```
122    pub fn agent(&self, model: &str) -> AgentBuilder<CompletionModel> {
123        AgentBuilder::new(self.completion_model(model))
124    }
125
126    /// Create an extractor builder with the given completion model.
127    pub fn extractor<T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync>(
128        &self,
129        model: &str,
130    ) -> ExtractorBuilder<T, CompletionModel> {
131        ExtractorBuilder::new(self.completion_model(model))
132    }
133}
134
/// Error payload returned by the Groq API.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    message: String,
}

/// Either a successful payload of type `T` or a Groq API error.
/// `untagged` makes serde try the variants in order: `Ok(T)` first,
/// falling back to the error shape when `T` does not deserialize.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    Ok(T),
    Err(ApiErrorResponse),
}
146
/// Wire-format chat message exchanged with the Groq API.
#[derive(Debug, Serialize, Deserialize)]
pub struct Message {
    /// Sender role: "system", "user" or "assistant" (see the TryFrom impls below).
    pub role: String,
    /// Plain-text content; `None` when the message carries no text.
    pub content: Option<String>,
}
152
153impl TryFrom<Message> for message::Message {
154    type Error = message::MessageError;
155
156    fn try_from(message: Message) -> Result<Self, Self::Error> {
157        match message.role.as_str() {
158            "user" => Ok(Self::User {
159                content: OneOrMany::one(
160                    message
161                        .content
162                        .map(|content| message::UserContent::text(&content))
163                        .ok_or_else(|| {
164                            message::MessageError::ConversionError("Empty user message".to_string())
165                        })?,
166                ),
167            }),
168            "assistant" => Ok(Self::Assistant {
169                content: OneOrMany::one(
170                    message
171                        .content
172                        .map(|content| message::AssistantContent::text(&content))
173                        .ok_or_else(|| {
174                            message::MessageError::ConversionError(
175                                "Empty assistant message".to_string(),
176                            )
177                        })?,
178                ),
179            }),
180            _ => Err(message::MessageError::ConversionError(format!(
181                "Unknown role: {}",
182                message.role
183            ))),
184        }
185    }
186}
187
188impl TryFrom<message::Message> for Message {
189    type Error = message::MessageError;
190
191    fn try_from(message: message::Message) -> Result<Self, Self::Error> {
192        match message {
193            message::Message::User { content } => Ok(Self {
194                role: "user".to_string(),
195                content: content.iter().find_map(|c| match c {
196                    message::UserContent::Text(text) => Some(text.text.clone()),
197                    _ => None,
198                }),
199            }),
200            message::Message::Assistant { content } => {
201                let mut text_content: Option<String> = None;
202
203                for c in content.iter() {
204                    match c {
205                        message::AssistantContent::Text(text) => {
206                            text_content = Some(
207                                text_content
208                                    .map(|mut existing| {
209                                        existing.push('\n');
210                                        existing.push_str(&text.text);
211                                        existing
212                                    })
213                                    .unwrap_or_else(|| text.text.clone()),
214                            );
215                        }
216                        message::AssistantContent::ToolCall(_tool_call) => {
217                            return Err(MessageError::ConversionError(
218                                "Tool calls do not exist on this message".into(),
219                            ))
220                        }
221                    }
222                }
223
224                Ok(Self {
225                    role: "assistant".to_string(),
226                    content: text_content,
227                })
228            }
229        }
230    }
231}
232
// ================================================================
// Groq Completion API
// ================================================================
/// The `deepseek-r1-distill-llama-70b` model. Used for chat completion.
pub const DEEPSEEK_R1_DISTILL_LLAMA_70B: &str = "deepseek-r1-distill-llama-70b";
/// The `gemma2-9b-it` model. Used for chat completion.
pub const GEMMA2_9B_IT: &str = "gemma2-9b-it";
/// The `llama-3.1-8b-instant` model. Used for chat completion.
pub const LLAMA_3_1_8B_INSTANT: &str = "llama-3.1-8b-instant";
/// The `llama-3.2-11b-vision-preview` model. Used for chat completion.
pub const LLAMA_3_2_11B_VISION_PREVIEW: &str = "llama-3.2-11b-vision-preview";
/// The `llama-3.2-1b-preview` model. Used for chat completion.
pub const LLAMA_3_2_1B_PREVIEW: &str = "llama-3.2-1b-preview";
/// The `llama-3.2-3b-preview` model. Used for chat completion.
pub const LLAMA_3_2_3B_PREVIEW: &str = "llama-3.2-3b-preview";
/// The `llama-3.2-90b-vision-preview` model. Used for chat completion.
pub const LLAMA_3_2_90B_VISION_PREVIEW: &str = "llama-3.2-90b-vision-preview";
/// The `llama-3.3-70b-specdec` model. Used for chat completion.
pub const LLAMA_3_3_70B_SPECDEC: &str = "llama-3.3-70b-specdec";
/// The `llama-3.3-70b-versatile` model. Used for chat completion.
pub const LLAMA_3_3_70B_VERSATILE: &str = "llama-3.3-70b-versatile";
/// Deprecated alias for [`LLAMA_3_3_70B_SPECDEC`]. Groq's 70B chat models
/// belong to the Llama 3.3 line; `llama-3.2-70b-specdec` is not a valid
/// Groq model id. Kept so existing callers keep compiling.
pub const LLAMA_3_2_70B_SPECDEC: &str = LLAMA_3_3_70B_SPECDEC;
/// Deprecated alias for [`LLAMA_3_3_70B_VERSATILE`]; see note above.
pub const LLAMA_3_2_70B_VERSATILE: &str = LLAMA_3_3_70B_VERSATILE;
/// The `llama-guard-3-8b` model. Used for chat completion.
pub const LLAMA_GUARD_3_8B: &str = "llama-guard-3-8b";
/// The `llama3-70b-8192` model. Used for chat completion.
pub const LLAMA_3_70B_8192: &str = "llama3-70b-8192";
/// The `llama3-8b-8192` model. Used for chat completion.
pub const LLAMA_3_8B_8192: &str = "llama3-8b-8192";
/// The `mixtral-8x7b-32768` model. Used for chat completion.
pub const MIXTRAL_8X7B_32768: &str = "mixtral-8x7b-32768";
262
/// A Groq chat-completion model bound to a [`Client`].
#[derive(Clone)]
pub struct CompletionModel {
    client: Client,
    /// Name of the model (e.g.: deepseek-r1-distill-llama-70b)
    pub model: String,
}
269
270impl CompletionModel {
271    pub fn new(client: Client, model: &str) -> Self {
272        Self {
273            client,
274            model: model.to_string(),
275        }
276    }
277
278    fn create_completion_request(
279        &self,
280        completion_request: CompletionRequest,
281    ) -> Result<Value, CompletionError> {
282        // Add preamble to chat history (if available)
283        let mut full_history: Vec<Message> = match &completion_request.preamble {
284            Some(preamble) => vec![Message {
285                role: "system".to_string(),
286                content: Some(preamble.to_string()),
287            }],
288            None => vec![],
289        };
290
291        // Convert prompt to user message
292        let prompt: Message = completion_request.prompt_with_context().try_into()?;
293
294        // Convert existing chat history
295        let chat_history: Vec<Message> = completion_request
296            .chat_history
297            .into_iter()
298            .map(|message| message.try_into())
299            .collect::<Result<Vec<Message>, _>>()?;
300
301        // Combine all messages into a single history
302        full_history.extend(chat_history);
303        full_history.push(prompt);
304
305        let request = if completion_request.tools.is_empty() {
306            json!({
307                "model": self.model,
308                "messages": full_history,
309                "temperature": completion_request.temperature,
310            })
311        } else {
312            json!({
313                "model": self.model,
314                "messages": full_history,
315                "temperature": completion_request.temperature,
316                "tools": completion_request.tools.into_iter().map(ToolDefinition::from).collect::<Vec<_>>(),
317                "tool_choice": "auto",
318            })
319        };
320
321        let request = if let Some(params) = completion_request.additional_params {
322            json_utils::merge(request, params)
323        } else {
324            request
325        };
326
327        Ok(request)
328    }
329}
330
331impl completion::CompletionModel for CompletionModel {
332    type Response = CompletionResponse;
333
334    #[cfg_attr(feature = "worker", worker::send)]
335    async fn completion(
336        &self,
337        completion_request: CompletionRequest,
338    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
339        let request = self.create_completion_request(completion_request)?;
340
341        let response = self
342            .client
343            .post("/chat/completions")
344            .json(&request)
345            .send()
346            .await?;
347
348        if response.status().is_success() {
349            match response.json::<ApiResponse<CompletionResponse>>().await? {
350                ApiResponse::Ok(response) => {
351                    tracing::info!(target: "rig",
352                        "groq completion token usage: {:?}",
353                        response.usage.clone().map(|usage| format!("{usage}")).unwrap_or("N/A".to_string())
354                    );
355                    response.try_into()
356                }
357                ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
358            }
359        } else {
360            Err(CompletionError::ProviderError(response.text().await?))
361        }
362    }
363}
364
365impl StreamingCompletionModel for CompletionModel {
366    async fn stream(&self, request: CompletionRequest) -> Result<StreamingResult, CompletionError> {
367        let mut request = self.create_completion_request(request)?;
368
369        request = merge(request, json!({"stream": true}));
370
371        let builder = self.client.post("/chat/completions").json(&request);
372
373        send_compatible_streaming_request(builder).await
374    }
375}
376
// ================================================================
// Groq Transcription API
// ================================================================
/// The `whisper-large-v3` transcription model.
pub const WHISPER_LARGE_V3: &str = "whisper-large-v3";
/// The `whisper-large-v3-turbo` transcription model.
pub const WHISPER_LARGE_V3_TURBO: &str = "whisper-large-v3-turbo";
/// The `distil-whisper-large-v3-en` transcription model (note the `-en` suffix in the model id).
pub const DISTIL_WHISPER_LARGE_V3: &str = "distil-whisper-large-v3-en";

/// A Groq audio-transcription model bound to a [`Client`].
#[derive(Clone)]
pub struct TranscriptionModel {
    client: Client,
    /// Name of the model (e.g.: whisper-large-v3)
    pub model: String,
}
390
391impl TranscriptionModel {
392    pub fn new(client: Client, model: &str) -> Self {
393        Self {
394            client,
395            model: model.to_string(),
396        }
397    }
398}
399impl transcription::TranscriptionModel for TranscriptionModel {
400    type Response = TranscriptionResponse;
401
402    #[cfg_attr(feature = "worker", worker::send)]
403    async fn transcription(
404        &self,
405        request: transcription::TranscriptionRequest,
406    ) -> Result<
407        transcription::TranscriptionResponse<Self::Response>,
408        transcription::TranscriptionError,
409    > {
410        let data = request.data;
411
412        let mut body = reqwest::multipart::Form::new()
413            .text("model", self.model.clone())
414            .text("language", request.language)
415            .part(
416                "file",
417                Part::bytes(data).file_name(request.filename.clone()),
418            );
419
420        if let Some(prompt) = request.prompt {
421            body = body.text("prompt", prompt.clone());
422        }
423
424        if let Some(ref temperature) = request.temperature {
425            body = body.text("temperature", temperature.to_string());
426        }
427
428        if let Some(ref additional_params) = request.additional_params {
429            for (key, value) in additional_params
430                .as_object()
431                .expect("Additional Parameters to OpenAI Transcription should be a map")
432            {
433                body = body.text(key.to_owned(), value.to_string());
434            }
435        }
436
437        let response = self
438            .client
439            .post("audio/transcriptions")
440            .multipart(body)
441            .send()
442            .await?;
443
444        if response.status().is_success() {
445            match response
446                .json::<ApiResponse<TranscriptionResponse>>()
447                .await?
448            {
449                ApiResponse::Ok(response) => response.try_into(),
450                ApiResponse::Err(api_error_response) => Err(TranscriptionError::ProviderError(
451                    api_error_response.message,
452                )),
453            }
454        } else {
455            Err(TranscriptionError::ProviderError(response.text().await?))
456        }
457    }
458}