rig/providers/
groq.rs

//! Groq API client and Rig integration
//!
//! # Example
//! ```
//! use rig::providers::groq;
//!
//! let client = groq::Client::new("YOUR_API_KEY");
//!
//! let llama = client.completion_model(groq::LLAMA_3_70B_8192);
//! ```
11use crate::{
12    agent::AgentBuilder,
13    completion::{self, CompletionError, CompletionRequest},
14    extractor::ExtractorBuilder,
15    json_utils,
16    message::{self, MessageError},
17    providers::openai::ToolDefinition,
18    OneOrMany,
19};
20use schemars::JsonSchema;
21use serde::{Deserialize, Serialize};
22use serde_json::json;
23
24use super::openai::CompletionResponse;
25
26// ================================================================
27// Main Groq Client
28// ================================================================
29const GROQ_API_BASE_URL: &str = "https://api.groq.com/openai/v1";
30
31#[derive(Clone)]
32pub struct Client {
33    base_url: String,
34    http_client: reqwest::Client,
35}
36
37impl Client {
38    /// Create a new Groq client with the given API key.
39    pub fn new(api_key: &str) -> Self {
40        Self::from_url(api_key, GROQ_API_BASE_URL)
41    }
42
43    /// Create a new Groq client with the given API key and base API URL.
44    pub fn from_url(api_key: &str, base_url: &str) -> Self {
45        Self {
46            base_url: base_url.to_string(),
47            http_client: reqwest::Client::builder()
48                .default_headers({
49                    let mut headers = reqwest::header::HeaderMap::new();
50                    headers.insert(
51                        "Authorization",
52                        format!("Bearer {}", api_key)
53                            .parse()
54                            .expect("Bearer token should parse"),
55                    );
56                    headers
57                })
58                .build()
59                .expect("Groq reqwest client should build"),
60        }
61    }
62
63    /// Create a new Groq client from the `GROQ_API_KEY` environment variable.
64    /// Panics if the environment variable is not set.
65    pub fn from_env() -> Self {
66        let api_key = std::env::var("GROQ_API_KEY").expect("GROQ_API_KEY not set");
67        Self::new(&api_key)
68    }
69
70    fn post(&self, path: &str) -> reqwest::RequestBuilder {
71        let url = format!("{}/{}", self.base_url, path).replace("//", "/");
72        self.http_client.post(url)
73    }
74
75    /// Create a completion model with the given name.
76    ///
77    /// # Example
78    /// ```
79    /// use rig::providers::groq::{Client, self};
80    ///
81    /// // Initialize the Groq client
82    /// let groq = Client::new("your-groq-api-key");
83    ///
84    /// let gpt4 = groq.completion_model(groq::GPT_4);
85    /// ```
86    pub fn completion_model(&self, model: &str) -> CompletionModel {
87        CompletionModel::new(self.clone(), model)
88    }
89
90    /// Create an agent builder with the given completion model.
91    ///
92    /// # Example
93    /// ```
94    /// use rig::providers::groq::{Client, self};
95    ///
96    /// // Initialize the Groq client
97    /// let groq = Client::new("your-groq-api-key");
98    ///
99    /// let agent = groq.agent(groq::GPT_4)
100    ///    .preamble("You are comedian AI with a mission to make people laugh.")
101    ///    .temperature(0.0)
102    ///    .build();
103    /// ```
104    pub fn agent(&self, model: &str) -> AgentBuilder<CompletionModel> {
105        AgentBuilder::new(self.completion_model(model))
106    }
107
108    /// Create an extractor builder with the given completion model.
109    pub fn extractor<T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync>(
110        &self,
111        model: &str,
112    ) -> ExtractorBuilder<T, CompletionModel> {
113        ExtractorBuilder::new(self.completion_model(model))
114    }
115}
116
117#[derive(Debug, Deserialize)]
118struct ApiErrorResponse {
119    message: String,
120}
121
122#[derive(Debug, Deserialize)]
123#[serde(untagged)]
124enum ApiResponse<T> {
125    Ok(T),
126    Err(ApiErrorResponse),
127}
128
129#[derive(Debug, Serialize, Deserialize)]
130pub struct Message {
131    pub role: String,
132    pub content: Option<String>,
133}
134
135impl TryFrom<Message> for message::Message {
136    type Error = message::MessageError;
137
138    fn try_from(message: Message) -> Result<Self, Self::Error> {
139        match message.role.as_str() {
140            "user" => Ok(Self::User {
141                content: OneOrMany::one(
142                    message
143                        .content
144                        .map(|content| message::UserContent::text(&content))
145                        .ok_or_else(|| {
146                            message::MessageError::ConversionError("Empty user message".to_string())
147                        })?,
148                ),
149            }),
150            "assistant" => Ok(Self::Assistant {
151                content: OneOrMany::one(
152                    message
153                        .content
154                        .map(|content| message::AssistantContent::text(&content))
155                        .ok_or_else(|| {
156                            message::MessageError::ConversionError(
157                                "Empty assistant message".to_string(),
158                            )
159                        })?,
160                ),
161            }),
162            _ => Err(message::MessageError::ConversionError(format!(
163                "Unknown role: {}",
164                message.role
165            ))),
166        }
167    }
168}
169
170impl TryFrom<message::Message> for Message {
171    type Error = message::MessageError;
172
173    fn try_from(message: message::Message) -> Result<Self, Self::Error> {
174        match message {
175            message::Message::User { content } => Ok(Self {
176                role: "user".to_string(),
177                content: content.iter().find_map(|c| match c {
178                    message::UserContent::Text(text) => Some(text.text.clone()),
179                    _ => None,
180                }),
181            }),
182            message::Message::Assistant { content } => {
183                let mut text_content: Option<String> = None;
184
185                for c in content.iter() {
186                    match c {
187                        message::AssistantContent::Text(text) => {
188                            text_content = Some(
189                                text_content
190                                    .map(|mut existing| {
191                                        existing.push('\n');
192                                        existing.push_str(&text.text);
193                                        existing
194                                    })
195                                    .unwrap_or_else(|| text.text.clone()),
196                            );
197                        }
198                        message::AssistantContent::ToolCall(_tool_call) => {
199                            return Err(MessageError::ConversionError(
200                                "Tool calls do not exist on this message".into(),
201                            ))
202                        }
203                    }
204                }
205
206                Ok(Self {
207                    role: "assistant".to_string(),
208                    content: text_content,
209                })
210            }
211        }
212    }
213}
214
215// ================================================================
216// Groq Completion API
217// ================================================================
/// The `deepseek-r1-distill-llama-70b` model. Used for chat completion.
pub const DEEPSEEK_R1_DISTILL_LLAMA_70B: &str = "deepseek-r1-distill-llama-70b";
/// The `gemma2-9b-it` model. Used for chat completion.
pub const GEMMA2_9B_IT: &str = "gemma2-9b-it";
/// The `llama-3.1-8b-instant` model. Used for chat completion.
pub const LLAMA_3_1_8B_INSTANT: &str = "llama-3.1-8b-instant";
/// The `llama-3.2-11b-vision-preview` model. Used for chat completion.
pub const LLAMA_3_2_11B_VISION_PREVIEW: &str = "llama-3.2-11b-vision-preview";
/// The `llama-3.2-1b-preview` model. Used for chat completion.
pub const LLAMA_3_2_1B_PREVIEW: &str = "llama-3.2-1b-preview";
/// The `llama-3.2-3b-preview` model. Used for chat completion.
pub const LLAMA_3_2_3B_PREVIEW: &str = "llama-3.2-3b-preview";
/// The `llama-3.2-90b-vision-preview` model. Used for chat completion.
pub const LLAMA_3_2_90B_VISION_PREVIEW: &str = "llama-3.2-90b-vision-preview";
/// The `llama-3.3-70b-specdec` model. Used for chat completion.
pub const LLAMA_3_3_70B_SPECDEC: &str = "llama-3.3-70b-specdec";
/// The `llama-3.3-70b-versatile` model. Used for chat completion.
pub const LLAMA_3_3_70B_VERSATILE: &str = "llama-3.3-70b-versatile";
/// Alias of [`LLAMA_3_3_70B_SPECDEC`], kept for backward compatibility.
///
/// There is no Llama 3.2 70B model; the previous value `llama-3.2-70b-specdec` was not
/// a valid Groq model ID. Prefer [`LLAMA_3_3_70B_SPECDEC`].
pub const LLAMA_3_2_70B_SPECDEC: &str = LLAMA_3_3_70B_SPECDEC;
/// Alias of [`LLAMA_3_3_70B_VERSATILE`], kept for backward compatibility.
///
/// There is no Llama 3.2 70B model; the previous value `llama-3.2-70b-versatile` was not
/// a valid Groq model ID. Prefer [`LLAMA_3_3_70B_VERSATILE`].
pub const LLAMA_3_2_70B_VERSATILE: &str = LLAMA_3_3_70B_VERSATILE;
/// The `llama-guard-3-8b` model. Used for chat completion.
pub const LLAMA_GUARD_3_8B: &str = "llama-guard-3-8b";
/// The `llama3-70b-8192` model. Used for chat completion.
pub const LLAMA_3_70B_8192: &str = "llama3-70b-8192";
/// The `llama3-8b-8192` model. Used for chat completion.
pub const LLAMA_3_8B_8192: &str = "llama3-8b-8192";
/// The `mixtral-8x7b-32768` model. Used for chat completion.
pub const MIXTRAL_8X7B_32768: &str = "mixtral-8x7b-32768";
244
245#[derive(Clone)]
246pub struct CompletionModel {
247    client: Client,
248    /// Name of the model (e.g.: deepseek-r1-distill-llama-70b)
249    pub model: String,
250}
251
252impl CompletionModel {
253    pub fn new(client: Client, model: &str) -> Self {
254        Self {
255            client,
256            model: model.to_string(),
257        }
258    }
259}
260
261impl completion::CompletionModel for CompletionModel {
262    type Response = CompletionResponse;
263
264    #[cfg_attr(feature = "worker", worker::send)]
265    async fn completion(
266        &self,
267        completion_request: CompletionRequest,
268    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
269        // Add preamble to chat history (if available)
270        let mut full_history: Vec<Message> = match &completion_request.preamble {
271            Some(preamble) => vec![Message {
272                role: "system".to_string(),
273                content: Some(preamble.to_string()),
274            }],
275            None => vec![],
276        };
277
278        // Convert prompt to user message
279        let prompt: Message = completion_request.prompt_with_context().try_into()?;
280
281        // Convert existing chat history
282        let chat_history: Vec<Message> = completion_request
283            .chat_history
284            .into_iter()
285            .map(|message| message.try_into())
286            .collect::<Result<Vec<Message>, _>>()?;
287
288        // Combine all messages into a single history
289        full_history.extend(chat_history);
290        full_history.push(prompt);
291
292        let request = if completion_request.tools.is_empty() {
293            json!({
294                "model": self.model,
295                "messages": full_history,
296                "temperature": completion_request.temperature,
297            })
298        } else {
299            json!({
300                "model": self.model,
301                "messages": full_history,
302                "temperature": completion_request.temperature,
303                "tools": completion_request.tools.into_iter().map(ToolDefinition::from).collect::<Vec<_>>(),
304                "tool_choice": "auto",
305            })
306        };
307
308        let response = self
309            .client
310            .post("/chat/completions")
311            .json(
312                &if let Some(params) = completion_request.additional_params {
313                    json_utils::merge(request, params)
314                } else {
315                    request
316                },
317            )
318            .send()
319            .await?;
320
321        if response.status().is_success() {
322            match response.json::<ApiResponse<CompletionResponse>>().await? {
323                ApiResponse::Ok(response) => {
324                    tracing::info!(target: "rig",
325                        "groq completion token usage: {:?}",
326                        response.usage.clone().map(|usage| format!("{usage}")).unwrap_or("N/A".to_string())
327                    );
328                    response.try_into()
329                }
330                ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
331            }
332        } else {
333            Err(CompletionError::ProviderError(response.text().await?))
334        }
335    }
336}