1use crate::{
12 agent::AgentBuilder,
13 completion::{self, CompletionError, CompletionRequest},
14 extractor::ExtractorBuilder,
15 json_utils,
16 message::{self, MessageError},
17 providers::openai::ToolDefinition,
18 OneOrMany,
19};
20use schemars::JsonSchema;
21use serde::{Deserialize, Serialize};
22use serde_json::json;
23
24use super::openai::CompletionResponse;
25
/// Base URL that [`Client::new`] targets when no explicit URL is supplied.
const GROQ_API_BASE_URL: &str = "https://api.groq.com/openai/v1";
30
31#[derive(Clone)]
32pub struct Client {
33 base_url: String,
34 http_client: reqwest::Client,
35}
36
37impl Client {
38 pub fn new(api_key: &str) -> Self {
40 Self::from_url(api_key, GROQ_API_BASE_URL)
41 }
42
43 pub fn from_url(api_key: &str, base_url: &str) -> Self {
45 Self {
46 base_url: base_url.to_string(),
47 http_client: reqwest::Client::builder()
48 .default_headers({
49 let mut headers = reqwest::header::HeaderMap::new();
50 headers.insert(
51 "Authorization",
52 format!("Bearer {}", api_key)
53 .parse()
54 .expect("Bearer token should parse"),
55 );
56 headers
57 })
58 .build()
59 .expect("Groq reqwest client should build"),
60 }
61 }
62
63 pub fn from_env() -> Self {
66 let api_key = std::env::var("GROQ_API_KEY").expect("GROQ_API_KEY not set");
67 Self::new(&api_key)
68 }
69
70 fn post(&self, path: &str) -> reqwest::RequestBuilder {
71 let url = format!("{}/{}", self.base_url, path).replace("//", "/");
72 self.http_client.post(url)
73 }
74
75 pub fn completion_model(&self, model: &str) -> CompletionModel {
87 CompletionModel::new(self.clone(), model)
88 }
89
90 pub fn agent(&self, model: &str) -> AgentBuilder<CompletionModel> {
105 AgentBuilder::new(self.completion_model(model))
106 }
107
108 pub fn extractor<T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync>(
110 &self,
111 model: &str,
112 ) -> ExtractorBuilder<T, CompletionModel> {
113 ExtractorBuilder::new(self.completion_model(model))
114 }
115}
116
117#[derive(Debug, Deserialize)]
118struct ApiErrorResponse {
119 message: String,
120}
121
122#[derive(Debug, Deserialize)]
123#[serde(untagged)]
124enum ApiResponse<T> {
125 Ok(T),
126 Err(ApiErrorResponse),
127}
128
129#[derive(Debug, Serialize, Deserialize)]
130pub struct Message {
131 pub role: String,
132 pub content: Option<String>,
133}
134
135impl TryFrom<Message> for message::Message {
136 type Error = message::MessageError;
137
138 fn try_from(message: Message) -> Result<Self, Self::Error> {
139 match message.role.as_str() {
140 "user" => Ok(Self::User {
141 content: OneOrMany::one(
142 message
143 .content
144 .map(|content| message::UserContent::text(&content))
145 .ok_or_else(|| {
146 message::MessageError::ConversionError("Empty user message".to_string())
147 })?,
148 ),
149 }),
150 "assistant" => Ok(Self::Assistant {
151 content: OneOrMany::one(
152 message
153 .content
154 .map(|content| message::AssistantContent::text(&content))
155 .ok_or_else(|| {
156 message::MessageError::ConversionError(
157 "Empty assistant message".to_string(),
158 )
159 })?,
160 ),
161 }),
162 _ => Err(message::MessageError::ConversionError(format!(
163 "Unknown role: {}",
164 message.role
165 ))),
166 }
167 }
168}
169
170impl TryFrom<message::Message> for Message {
171 type Error = message::MessageError;
172
173 fn try_from(message: message::Message) -> Result<Self, Self::Error> {
174 match message {
175 message::Message::User { content } => Ok(Self {
176 role: "user".to_string(),
177 content: content.iter().find_map(|c| match c {
178 message::UserContent::Text(text) => Some(text.text.clone()),
179 _ => None,
180 }),
181 }),
182 message::Message::Assistant { content } => {
183 let mut text_content: Option<String> = None;
184
185 for c in content.iter() {
186 match c {
187 message::AssistantContent::Text(text) => {
188 text_content = Some(
189 text_content
190 .map(|mut existing| {
191 existing.push('\n');
192 existing.push_str(&text.text);
193 existing
194 })
195 .unwrap_or_else(|| text.text.clone()),
196 );
197 }
198 message::AssistantContent::ToolCall(_tool_call) => {
199 return Err(MessageError::ConversionError(
200 "Tool calls do not exist on this message".into(),
201 ))
202 }
203 }
204 }
205
206 Ok(Self {
207 role: "assistant".to_string(),
208 content: text_content,
209 })
210 }
211 }
212 }
213}
214
/// The `deepseek-r1-distill-llama-70b` model id.
pub const DEEPSEEK_R1_DISTILL_LLAMA_70B: &str = "deepseek-r1-distill-llama-70b";
/// The `gemma2-9b-it` model id.
pub const GEMMA2_9B_IT: &str = "gemma2-9b-it";
/// The `llama-3.1-8b-instant` model id.
pub const LLAMA_3_1_8B_INSTANT: &str = "llama-3.1-8b-instant";
/// The `llama-3.2-11b-vision-preview` model id.
pub const LLAMA_3_2_11B_VISION_PREVIEW: &str = "llama-3.2-11b-vision-preview";
/// The `llama-3.2-1b-preview` model id.
pub const LLAMA_3_2_1B_PREVIEW: &str = "llama-3.2-1b-preview";
/// The `llama-3.2-3b-preview` model id.
pub const LLAMA_3_2_3B_PREVIEW: &str = "llama-3.2-3b-preview";
/// The `llama-3.2-90b-vision-preview` model id.
pub const LLAMA_3_2_90B_VISION_PREVIEW: &str = "llama-3.2-90b-vision-preview";
/// The `llama-3.3-70b-specdec` model id.
pub const LLAMA_3_3_70B_SPECDEC: &str = "llama-3.3-70b-specdec";
/// The `llama-3.3-70b-versatile` model id.
pub const LLAMA_3_3_70B_VERSATILE: &str = "llama-3.3-70b-versatile";
/// Alias of [`LLAMA_3_3_70B_SPECDEC`], kept so existing callers keep compiling.
///
/// The name is a historical misnomer: the 70B model belongs to the Llama 3.3 family, so
/// this constant resolves to the `llama-3.3-70b-specdec` id.
pub const LLAMA_3_2_70B_SPECDEC: &str = LLAMA_3_3_70B_SPECDEC;
/// Alias of [`LLAMA_3_3_70B_VERSATILE`], kept so existing callers keep compiling.
///
/// The name is a historical misnomer: the 70B model belongs to the Llama 3.3 family, so
/// this constant resolves to the `llama-3.3-70b-versatile` id.
pub const LLAMA_3_2_70B_VERSATILE: &str = LLAMA_3_3_70B_VERSATILE;
/// The `llama-guard-3-8b` model id.
pub const LLAMA_GUARD_3_8B: &str = "llama-guard-3-8b";
/// The `llama3-70b-8192` model id.
pub const LLAMA_3_70B_8192: &str = "llama3-70b-8192";
/// The `llama3-8b-8192` model id.
pub const LLAMA_3_8B_8192: &str = "llama3-8b-8192";
/// The `mixtral-8x7b-32768` model id.
pub const MIXTRAL_8X7B_32768: &str = "mixtral-8x7b-32768";
244
245#[derive(Clone)]
246pub struct CompletionModel {
247 client: Client,
248 pub model: String,
250}
251
252impl CompletionModel {
253 pub fn new(client: Client, model: &str) -> Self {
254 Self {
255 client,
256 model: model.to_string(),
257 }
258 }
259}
260
261impl completion::CompletionModel for CompletionModel {
262 type Response = CompletionResponse;
263
264 #[cfg_attr(feature = "worker", worker::send)]
265 async fn completion(
266 &self,
267 completion_request: CompletionRequest,
268 ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
269 let mut full_history: Vec<Message> = match &completion_request.preamble {
271 Some(preamble) => vec![Message {
272 role: "system".to_string(),
273 content: Some(preamble.to_string()),
274 }],
275 None => vec![],
276 };
277
278 let prompt: Message = completion_request.prompt_with_context().try_into()?;
280
281 let chat_history: Vec<Message> = completion_request
283 .chat_history
284 .into_iter()
285 .map(|message| message.try_into())
286 .collect::<Result<Vec<Message>, _>>()?;
287
288 full_history.extend(chat_history);
290 full_history.push(prompt);
291
292 let request = if completion_request.tools.is_empty() {
293 json!({
294 "model": self.model,
295 "messages": full_history,
296 "temperature": completion_request.temperature,
297 })
298 } else {
299 json!({
300 "model": self.model,
301 "messages": full_history,
302 "temperature": completion_request.temperature,
303 "tools": completion_request.tools.into_iter().map(ToolDefinition::from).collect::<Vec<_>>(),
304 "tool_choice": "auto",
305 })
306 };
307
308 let response = self
309 .client
310 .post("/chat/completions")
311 .json(
312 &if let Some(params) = completion_request.additional_params {
313 json_utils::merge(request, params)
314 } else {
315 request
316 },
317 )
318 .send()
319 .await?;
320
321 if response.status().is_success() {
322 match response.json::<ApiResponse<CompletionResponse>>().await? {
323 ApiResponse::Ok(response) => {
324 tracing::info!(target: "rig",
325 "groq completion token usage: {:?}",
326 response.usage.clone().map(|usage| format!("{usage}")).unwrap_or("N/A".to_string())
327 );
328 response.try_into()
329 }
330 ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
331 }
332 } else {
333 Err(CompletionError::ProviderError(response.text().await?))
334 }
335 }
336}