1use super::openai::{send_compatible_streaming_request, CompletionResponse, TranscriptionResponse};
12use crate::json_utils::merge;
13use crate::streaming::{StreamingCompletionModel, StreamingResult};
14use crate::{
15 agent::AgentBuilder,
16 completion::{self, CompletionError, CompletionRequest},
17 extractor::ExtractorBuilder,
18 json_utils,
19 message::{self, MessageError},
20 providers::openai::ToolDefinition,
21 transcription::{self, TranscriptionError},
22 OneOrMany,
23};
24use reqwest::multipart::Part;
25use schemars::JsonSchema;
26use serde::{Deserialize, Serialize};
27use serde_json::{json, Value};
28
29const GROQ_API_BASE_URL: &str = "https://api.groq.com/openai/v1";
33
34#[derive(Clone)]
35pub struct Client {
36 base_url: String,
37 http_client: reqwest::Client,
38}
39
40impl Client {
41 pub fn new(api_key: &str) -> Self {
43 Self::from_url(api_key, GROQ_API_BASE_URL)
44 }
45
46 pub fn from_url(api_key: &str, base_url: &str) -> Self {
48 Self {
49 base_url: base_url.to_string(),
50 http_client: reqwest::Client::builder()
51 .default_headers({
52 let mut headers = reqwest::header::HeaderMap::new();
53 headers.insert(
54 "Authorization",
55 format!("Bearer {}", api_key)
56 .parse()
57 .expect("Bearer token should parse"),
58 );
59 headers
60 })
61 .build()
62 .expect("Groq reqwest client should build"),
63 }
64 }
65
66 pub fn from_env() -> Self {
69 let api_key = std::env::var("GROQ_API_KEY").expect("GROQ_API_KEY not set");
70 Self::new(&api_key)
71 }
72
73 fn post(&self, path: &str) -> reqwest::RequestBuilder {
74 let url = format!("{}/{}", self.base_url, path).replace("//", "/");
75 self.http_client.post(url)
76 }
77
78 pub fn completion_model(&self, model: &str) -> CompletionModel {
90 CompletionModel::new(self.clone(), model)
91 }
92
93 pub fn transcription_model(&self, model: &str) -> TranscriptionModel {
105 TranscriptionModel::new(self.clone(), model)
106 }
107
108 pub fn agent(&self, model: &str) -> AgentBuilder<CompletionModel> {
123 AgentBuilder::new(self.completion_model(model))
124 }
125
126 pub fn extractor<T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync>(
128 &self,
129 model: &str,
130 ) -> ExtractorBuilder<T, CompletionModel> {
131 ExtractorBuilder::new(self.completion_model(model))
132 }
133}
134
135#[derive(Debug, Deserialize)]
136struct ApiErrorResponse {
137 message: String,
138}
139
140#[derive(Debug, Deserialize)]
141#[serde(untagged)]
142enum ApiResponse<T> {
143 Ok(T),
144 Err(ApiErrorResponse),
145}
146
147#[derive(Debug, Serialize, Deserialize)]
148pub struct Message {
149 pub role: String,
150 pub content: Option<String>,
151}
152
153impl TryFrom<Message> for message::Message {
154 type Error = message::MessageError;
155
156 fn try_from(message: Message) -> Result<Self, Self::Error> {
157 match message.role.as_str() {
158 "user" => Ok(Self::User {
159 content: OneOrMany::one(
160 message
161 .content
162 .map(|content| message::UserContent::text(&content))
163 .ok_or_else(|| {
164 message::MessageError::ConversionError("Empty user message".to_string())
165 })?,
166 ),
167 }),
168 "assistant" => Ok(Self::Assistant {
169 content: OneOrMany::one(
170 message
171 .content
172 .map(|content| message::AssistantContent::text(&content))
173 .ok_or_else(|| {
174 message::MessageError::ConversionError(
175 "Empty assistant message".to_string(),
176 )
177 })?,
178 ),
179 }),
180 _ => Err(message::MessageError::ConversionError(format!(
181 "Unknown role: {}",
182 message.role
183 ))),
184 }
185 }
186}
187
188impl TryFrom<message::Message> for Message {
189 type Error = message::MessageError;
190
191 fn try_from(message: message::Message) -> Result<Self, Self::Error> {
192 match message {
193 message::Message::User { content } => Ok(Self {
194 role: "user".to_string(),
195 content: content.iter().find_map(|c| match c {
196 message::UserContent::Text(text) => Some(text.text.clone()),
197 _ => None,
198 }),
199 }),
200 message::Message::Assistant { content } => {
201 let mut text_content: Option<String> = None;
202
203 for c in content.iter() {
204 match c {
205 message::AssistantContent::Text(text) => {
206 text_content = Some(
207 text_content
208 .map(|mut existing| {
209 existing.push('\n');
210 existing.push_str(&text.text);
211 existing
212 })
213 .unwrap_or_else(|| text.text.clone()),
214 );
215 }
216 message::AssistantContent::ToolCall(_tool_call) => {
217 return Err(MessageError::ConversionError(
218 "Tool calls do not exist on this message".into(),
219 ))
220 }
221 }
222 }
223
224 Ok(Self {
225 role: "assistant".to_string(),
226 content: text_content,
227 })
228 }
229 }
230 }
231}
232
/// Groq model ID `deepseek-r1-distill-llama-70b`.
pub const DEEPSEEK_R1_DISTILL_LLAMA_70B: &str = "deepseek-r1-distill-llama-70b";
/// Groq model ID `gemma2-9b-it`.
pub const GEMMA2_9B_IT: &str = "gemma2-9b-it";
/// Groq model ID `llama-3.1-8b-instant`.
pub const LLAMA_3_1_8B_INSTANT: &str = "llama-3.1-8b-instant";
/// Groq model ID `llama-3.2-11b-vision-preview`.
pub const LLAMA_3_2_11B_VISION_PREVIEW: &str = "llama-3.2-11b-vision-preview";
/// Groq model ID `llama-3.2-1b-preview`.
pub const LLAMA_3_2_1B_PREVIEW: &str = "llama-3.2-1b-preview";
/// Groq model ID `llama-3.2-3b-preview`.
pub const LLAMA_3_2_3B_PREVIEW: &str = "llama-3.2-3b-preview";
/// Groq model ID `llama-3.2-90b-vision-preview`.
pub const LLAMA_3_2_90B_VISION_PREVIEW: &str = "llama-3.2-90b-vision-preview";
/// Groq model ID `llama-3.3-70b-specdec`.
pub const LLAMA_3_3_70B_SPECDEC: &str = "llama-3.3-70b-specdec";
/// Groq model ID `llama-3.3-70b-versatile`.
pub const LLAMA_3_3_70B_VERSATILE: &str = "llama-3.3-70b-versatile";
/// Alias of [`LLAMA_3_3_70B_SPECDEC`], kept so existing code keeps compiling.
///
/// This constant used to hold `"llama-3.2-70b-specdec"`, which is not a Groq model ID
/// (Llama 3.2 has no 70B variant). Prefer the `LLAMA_3_3_*` name.
pub const LLAMA_3_2_70B_SPECDEC: &str = LLAMA_3_3_70B_SPECDEC;
/// Alias of [`LLAMA_3_3_70B_VERSATILE`], kept so existing code keeps compiling.
///
/// This constant used to hold `"llama-3.2-70b-versatile"`, which is not a Groq model ID
/// (Llama 3.2 has no 70B variant). Prefer the `LLAMA_3_3_*` name.
pub const LLAMA_3_2_70B_VERSATILE: &str = LLAMA_3_3_70B_VERSATILE;
/// Groq model ID `llama-guard-3-8b`.
pub const LLAMA_GUARD_3_8B: &str = "llama-guard-3-8b";
/// Groq model ID `llama3-70b-8192`.
pub const LLAMA_3_70B_8192: &str = "llama3-70b-8192";
/// Groq model ID `llama3-8b-8192`.
pub const LLAMA_3_8B_8192: &str = "llama3-8b-8192";
/// Groq model ID `mixtral-8x7b-32768`.
pub const MIXTRAL_8X7B_32768: &str = "mixtral-8x7b-32768";
262
263#[derive(Clone)]
264pub struct CompletionModel {
265 client: Client,
266 pub model: String,
268}
269
270impl CompletionModel {
271 pub fn new(client: Client, model: &str) -> Self {
272 Self {
273 client,
274 model: model.to_string(),
275 }
276 }
277
278 fn create_completion_request(
279 &self,
280 completion_request: CompletionRequest,
281 ) -> Result<Value, CompletionError> {
282 let mut partial_history = vec![];
284 if let Some(docs) = completion_request.normalized_documents() {
285 partial_history.push(docs);
286 }
287 partial_history.extend(completion_request.chat_history);
288
289 let mut full_history: Vec<Message> =
291 completion_request
292 .preamble
293 .map_or_else(Vec::new, |preamble| {
294 vec![Message {
295 role: "system".to_string(),
296 content: Some(preamble),
297 }]
298 });
299
300 full_history.extend(
302 partial_history
303 .into_iter()
304 .map(message::Message::try_into)
305 .collect::<Result<Vec<Message>, _>>()?,
306 );
307
308 let request = if completion_request.tools.is_empty() {
309 json!({
310 "model": self.model,
311 "messages": full_history,
312 "temperature": completion_request.temperature,
313 })
314 } else {
315 json!({
316 "model": self.model,
317 "messages": full_history,
318 "temperature": completion_request.temperature,
319 "tools": completion_request.tools.into_iter().map(ToolDefinition::from).collect::<Vec<_>>(),
320 "tool_choice": "auto",
321 })
322 };
323
324 let request = if let Some(params) = completion_request.additional_params {
325 json_utils::merge(request, params)
326 } else {
327 request
328 };
329
330 Ok(request)
331 }
332}
333
334impl completion::CompletionModel for CompletionModel {
335 type Response = CompletionResponse;
336
337 #[cfg_attr(feature = "worker", worker::send)]
338 async fn completion(
339 &self,
340 completion_request: CompletionRequest,
341 ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
342 let request = self.create_completion_request(completion_request)?;
343
344 let response = self
345 .client
346 .post("/chat/completions")
347 .json(&request)
348 .send()
349 .await?;
350
351 if response.status().is_success() {
352 match response.json::<ApiResponse<CompletionResponse>>().await? {
353 ApiResponse::Ok(response) => {
354 tracing::info!(target: "rig",
355 "groq completion token usage: {:?}",
356 response.usage.clone().map(|usage| format!("{usage}")).unwrap_or("N/A".to_string())
357 );
358 response.try_into()
359 }
360 ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
361 }
362 } else {
363 Err(CompletionError::ProviderError(response.text().await?))
364 }
365 }
366}
367
368impl StreamingCompletionModel for CompletionModel {
369 async fn stream(&self, request: CompletionRequest) -> Result<StreamingResult, CompletionError> {
370 let mut request = self.create_completion_request(request)?;
371
372 request = merge(request, json!({"stream": true}));
373
374 let builder = self.client.post("/chat/completions").json(&request);
375
376 send_compatible_streaming_request(builder).await
377 }
378}
379
380pub const WHISPER_LARGE_V3: &str = "whisper-large-v3";
384pub const WHISPER_LARGE_V3_TURBO: &str = "whisper-large-v3-turbo";
385pub const DISTIL_WHISPER_LARGE_V3: &str = "distil-whisper-large-v3-en";
386
387#[derive(Clone)]
388pub struct TranscriptionModel {
389 client: Client,
390 pub model: String,
392}
393
394impl TranscriptionModel {
395 pub fn new(client: Client, model: &str) -> Self {
396 Self {
397 client,
398 model: model.to_string(),
399 }
400 }
401}
402impl transcription::TranscriptionModel for TranscriptionModel {
403 type Response = TranscriptionResponse;
404
405 #[cfg_attr(feature = "worker", worker::send)]
406 async fn transcription(
407 &self,
408 request: transcription::TranscriptionRequest,
409 ) -> Result<
410 transcription::TranscriptionResponse<Self::Response>,
411 transcription::TranscriptionError,
412 > {
413 let data = request.data;
414
415 let mut body = reqwest::multipart::Form::new()
416 .text("model", self.model.clone())
417 .text("language", request.language)
418 .part(
419 "file",
420 Part::bytes(data).file_name(request.filename.clone()),
421 );
422
423 if let Some(prompt) = request.prompt {
424 body = body.text("prompt", prompt.clone());
425 }
426
427 if let Some(ref temperature) = request.temperature {
428 body = body.text("temperature", temperature.to_string());
429 }
430
431 if let Some(ref additional_params) = request.additional_params {
432 for (key, value) in additional_params
433 .as_object()
434 .expect("Additional Parameters to OpenAI Transcription should be a map")
435 {
436 body = body.text(key.to_owned(), value.to_string());
437 }
438 }
439
440 let response = self
441 .client
442 .post("audio/transcriptions")
443 .multipart(body)
444 .send()
445 .await?;
446
447 if response.status().is_success() {
448 match response
449 .json::<ApiResponse<TranscriptionResponse>>()
450 .await?
451 {
452 ApiResponse::Ok(response) => response.try_into(),
453 ApiResponse::Err(api_error_response) => Err(TranscriptionError::ProviderError(
454 api_error_response.message,
455 )),
456 }
457 } else {
458 Err(TranscriptionError::ProviderError(response.text().await?))
459 }
460 }
461}