1use super::openai::{send_compatible_streaming_request, CompletionResponse, TranscriptionResponse};
12use crate::json_utils::merge;
13use crate::streaming::{StreamingCompletionModel, StreamingResult};
14use crate::{
15 agent::AgentBuilder,
16 completion::{self, CompletionError, CompletionRequest},
17 extractor::ExtractorBuilder,
18 json_utils,
19 message::{self, MessageError},
20 providers::openai::ToolDefinition,
21 transcription::{self, TranscriptionError},
22 OneOrMany,
23};
24use reqwest::multipart::Part;
25use schemars::JsonSchema;
26use serde::{Deserialize, Serialize};
27use serde_json::{json, Value};
28
/// Default API root; `Client::new` passes this to `Client::from_url`.
29const GROQ_API_BASE_URL: &str = "https://api.groq.com/openai/v1";
33
34#[derive(Clone)]
35pub struct Client {
36 base_url: String,
37 http_client: reqwest::Client,
38}
39
40impl Client {
41 pub fn new(api_key: &str) -> Self {
43 Self::from_url(api_key, GROQ_API_BASE_URL)
44 }
45
46 pub fn from_url(api_key: &str, base_url: &str) -> Self {
48 Self {
49 base_url: base_url.to_string(),
50 http_client: reqwest::Client::builder()
51 .default_headers({
52 let mut headers = reqwest::header::HeaderMap::new();
53 headers.insert(
54 "Authorization",
55 format!("Bearer {}", api_key)
56 .parse()
57 .expect("Bearer token should parse"),
58 );
59 headers
60 })
61 .build()
62 .expect("Groq reqwest client should build"),
63 }
64 }
65
66 pub fn from_env() -> Self {
69 let api_key = std::env::var("GROQ_API_KEY").expect("GROQ_API_KEY not set");
70 Self::new(&api_key)
71 }
72
73 fn post(&self, path: &str) -> reqwest::RequestBuilder {
74 let url = format!("{}/{}", self.base_url, path).replace("//", "/");
75 self.http_client.post(url)
76 }
77
78 pub fn completion_model(&self, model: &str) -> CompletionModel {
90 CompletionModel::new(self.clone(), model)
91 }
92
93 pub fn transcription_model(&self, model: &str) -> TranscriptionModel {
105 TranscriptionModel::new(self.clone(), model)
106 }
107
108 pub fn agent(&self, model: &str) -> AgentBuilder<CompletionModel> {
123 AgentBuilder::new(self.completion_model(model))
124 }
125
126 pub fn extractor<T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync>(
128 &self,
129 model: &str,
130 ) -> ExtractorBuilder<T, CompletionModel> {
131 ExtractorBuilder::new(self.completion_model(model))
132 }
133}
134
/// Error payload returned by the API, deserialized from an object that has a
/// `message` field.
135#[derive(Debug, Deserialize)]
136struct ApiErrorResponse {
    // Error text; callers in this file wrap it in a `ProviderError`.
137 message: String,
138}
139
/// Body of a response that is either the expected payload `T` or an
/// [`ApiErrorResponse`].
///
/// The enum is `untagged`, so serde tries the variants in declaration order:
/// `Ok(T)` first, then `Err`.
140#[derive(Debug, Deserialize)]
141#[serde(untagged)]
142enum ApiResponse<T> {
    // The body deserialized as the expected payload.
143 Ok(T),
    // The body deserialized as an error object instead.
144 Err(ApiErrorResponse),
145}
146
/// Chat message in the shape serialized into the `messages` array of a
/// completion request.
147#[derive(Debug, Serialize, Deserialize)]
148pub struct Message {
    /// Role of the author. This file produces `"system"`, `"user"` and
    /// `"assistant"`; only the latter two convert back to `message::Message`.
149 pub role: String,
    /// Text of the message, if any.
150 pub content: Option<String>,
151}
152
153impl TryFrom<Message> for message::Message {
154 type Error = message::MessageError;
155
156 fn try_from(message: Message) -> Result<Self, Self::Error> {
157 match message.role.as_str() {
158 "user" => Ok(Self::User {
159 content: OneOrMany::one(
160 message
161 .content
162 .map(|content| message::UserContent::text(&content))
163 .ok_or_else(|| {
164 message::MessageError::ConversionError("Empty user message".to_string())
165 })?,
166 ),
167 }),
168 "assistant" => Ok(Self::Assistant {
169 content: OneOrMany::one(
170 message
171 .content
172 .map(|content| message::AssistantContent::text(&content))
173 .ok_or_else(|| {
174 message::MessageError::ConversionError(
175 "Empty assistant message".to_string(),
176 )
177 })?,
178 ),
179 }),
180 _ => Err(message::MessageError::ConversionError(format!(
181 "Unknown role: {}",
182 message.role
183 ))),
184 }
185 }
186}
187
188impl TryFrom<message::Message> for Message {
189 type Error = message::MessageError;
190
191 fn try_from(message: message::Message) -> Result<Self, Self::Error> {
192 match message {
193 message::Message::User { content } => Ok(Self {
194 role: "user".to_string(),
195 content: content.iter().find_map(|c| match c {
196 message::UserContent::Text(text) => Some(text.text.clone()),
197 _ => None,
198 }),
199 }),
200 message::Message::Assistant { content } => {
201 let mut text_content: Option<String> = None;
202
203 for c in content.iter() {
204 match c {
205 message::AssistantContent::Text(text) => {
206 text_content = Some(
207 text_content
208 .map(|mut existing| {
209 existing.push('\n');
210 existing.push_str(&text.text);
211 existing
212 })
213 .unwrap_or_else(|| text.text.clone()),
214 );
215 }
216 message::AssistantContent::ToolCall(_tool_call) => {
217 return Err(MessageError::ConversionError(
218 "Tool calls do not exist on this message".into(),
219 ))
220 }
221 }
222 }
223
224 Ok(Self {
225 role: "assistant".to_string(),
226 content: text_content,
227 })
228 }
229 }
230 }
231}
232
// Model identifier strings. Presumably meant to be passed as the `model`
// argument of `Client::completion_model` — confirm against the provider's
// current model list, since hosted models get added and retired.

/// The `deepseek-r1-distill-llama-70b` model identifier.
233pub const DEEPSEEK_R1_DISTILL_LLAMA_70B: &str = "deepseek-r1-distill-llama-70b";
/// The `gemma2-9b-it` model identifier.
238pub const GEMMA2_9B_IT: &str = "gemma2-9b-it";
/// The `llama-3.1-8b-instant` model identifier.
240pub const LLAMA_3_1_8B_INSTANT: &str = "llama-3.1-8b-instant";
/// The `llama-3.2-11b-vision-preview` model identifier.
242pub const LLAMA_3_2_11B_VISION_PREVIEW: &str = "llama-3.2-11b-vision-preview";
/// The `llama-3.2-1b-preview` model identifier.
244pub const LLAMA_3_2_1B_PREVIEW: &str = "llama-3.2-1b-preview";
/// The `llama-3.2-3b-preview` model identifier.
246pub const LLAMA_3_2_3B_PREVIEW: &str = "llama-3.2-3b-preview";
/// The `llama-3.2-90b-vision-preview` model identifier.
248pub const LLAMA_3_2_90B_VISION_PREVIEW: &str = "llama-3.2-90b-vision-preview";
/// The `llama-3.2-70b-specdec` model identifier.
// NOTE(review): the 70B ids may actually be `llama-3.3-70b-*` upstream —
// verify this string against the provider's model list before relying on it.
250pub const LLAMA_3_2_70B_SPECDEC: &str = "llama-3.2-70b-specdec";
/// The `llama-3.2-70b-versatile` model identifier.
// NOTE(review): same concern as `LLAMA_3_2_70B_SPECDEC` — verify the id.
252pub const LLAMA_3_2_70B_VERSATILE: &str = "llama-3.2-70b-versatile";
/// The `llama-guard-3-8b` model identifier.
254pub const LLAMA_GUARD_3_8B: &str = "llama-guard-3-8b";
/// The `llama3-70b-8192` model identifier.
256pub const LLAMA_3_70B_8192: &str = "llama3-70b-8192";
/// The `llama3-8b-8192` model identifier.
258pub const LLAMA_3_8B_8192: &str = "llama3-8b-8192";
/// The `mixtral-8x7b-32768` model identifier.
260pub const MIXTRAL_8X7B_32768: &str = "mixtral-8x7b-32768";
262
263#[derive(Clone)]
264pub struct CompletionModel {
265 client: Client,
266 pub model: String,
268}
269
270impl CompletionModel {
271 pub fn new(client: Client, model: &str) -> Self {
272 Self {
273 client,
274 model: model.to_string(),
275 }
276 }
277
278 fn create_completion_request(
279 &self,
280 completion_request: CompletionRequest,
281 ) -> Result<Value, CompletionError> {
282 let mut full_history: Vec<Message> = match &completion_request.preamble {
284 Some(preamble) => vec![Message {
285 role: "system".to_string(),
286 content: Some(preamble.to_string()),
287 }],
288 None => vec![],
289 };
290
291 let prompt: Message = completion_request.prompt_with_context().try_into()?;
293
294 let chat_history: Vec<Message> = completion_request
296 .chat_history
297 .into_iter()
298 .map(|message| message.try_into())
299 .collect::<Result<Vec<Message>, _>>()?;
300
301 full_history.extend(chat_history);
303 full_history.push(prompt);
304
305 let request = if completion_request.tools.is_empty() {
306 json!({
307 "model": self.model,
308 "messages": full_history,
309 "temperature": completion_request.temperature,
310 })
311 } else {
312 json!({
313 "model": self.model,
314 "messages": full_history,
315 "temperature": completion_request.temperature,
316 "tools": completion_request.tools.into_iter().map(ToolDefinition::from).collect::<Vec<_>>(),
317 "tool_choice": "auto",
318 })
319 };
320
321 let request = if let Some(params) = completion_request.additional_params {
322 json_utils::merge(request, params)
323 } else {
324 request
325 };
326
327 Ok(request)
328 }
329}
330
331impl completion::CompletionModel for CompletionModel {
332 type Response = CompletionResponse;
333
334 #[cfg_attr(feature = "worker", worker::send)]
335 async fn completion(
336 &self,
337 completion_request: CompletionRequest,
338 ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
339 let request = self.create_completion_request(completion_request)?;
340
341 let response = self
342 .client
343 .post("/chat/completions")
344 .json(&request)
345 .send()
346 .await?;
347
348 if response.status().is_success() {
349 match response.json::<ApiResponse<CompletionResponse>>().await? {
350 ApiResponse::Ok(response) => {
351 tracing::info!(target: "rig",
352 "groq completion token usage: {:?}",
353 response.usage.clone().map(|usage| format!("{usage}")).unwrap_or("N/A".to_string())
354 );
355 response.try_into()
356 }
357 ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
358 }
359 } else {
360 Err(CompletionError::ProviderError(response.text().await?))
361 }
362 }
363}
364
365impl StreamingCompletionModel for CompletionModel {
366 async fn stream(&self, request: CompletionRequest) -> Result<StreamingResult, CompletionError> {
367 let mut request = self.create_completion_request(request)?;
368
369 request = merge(request, json!({"stream": true}));
370
371 let builder = self.client.post("/chat/completions").json(&request);
372
373 send_compatible_streaming_request(builder).await
374 }
375}
376
// Model identifier strings. Presumably meant to be passed as the `model`
// argument of `Client::transcription_model` — confirm against the provider's
// current model list.

/// The `whisper-large-v3` model identifier.
377pub const WHISPER_LARGE_V3: &str = "whisper-large-v3";
/// The `whisper-large-v3-turbo` model identifier.
381pub const WHISPER_LARGE_V3_TURBO: &str = "whisper-large-v3-turbo";
/// The `distil-whisper-large-v3-en` model identifier.
382pub const DISTIL_WHISPER_LARGE_V3: &str = "distil-whisper-large-v3-en";
383
384#[derive(Clone)]
385pub struct TranscriptionModel {
386 client: Client,
387 pub model: String,
389}
390
391impl TranscriptionModel {
392 pub fn new(client: Client, model: &str) -> Self {
393 Self {
394 client,
395 model: model.to_string(),
396 }
397 }
398}
399impl transcription::TranscriptionModel for TranscriptionModel {
400 type Response = TranscriptionResponse;
401
402 #[cfg_attr(feature = "worker", worker::send)]
403 async fn transcription(
404 &self,
405 request: transcription::TranscriptionRequest,
406 ) -> Result<
407 transcription::TranscriptionResponse<Self::Response>,
408 transcription::TranscriptionError,
409 > {
410 let data = request.data;
411
412 let mut body = reqwest::multipart::Form::new()
413 .text("model", self.model.clone())
414 .text("language", request.language)
415 .part(
416 "file",
417 Part::bytes(data).file_name(request.filename.clone()),
418 );
419
420 if let Some(prompt) = request.prompt {
421 body = body.text("prompt", prompt.clone());
422 }
423
424 if let Some(ref temperature) = request.temperature {
425 body = body.text("temperature", temperature.to_string());
426 }
427
428 if let Some(ref additional_params) = request.additional_params {
429 for (key, value) in additional_params
430 .as_object()
431 .expect("Additional Parameters to OpenAI Transcription should be a map")
432 {
433 body = body.text(key.to_owned(), value.to_string());
434 }
435 }
436
437 let response = self
438 .client
439 .post("audio/transcriptions")
440 .multipart(body)
441 .send()
442 .await?;
443
444 if response.status().is_success() {
445 match response
446 .json::<ApiResponse<TranscriptionResponse>>()
447 .await?
448 {
449 ApiResponse::Ok(response) => response.try_into(),
450 ApiResponse::Err(api_error_response) => Err(TranscriptionError::ProviderError(
451 api_error_response.message,
452 )),
453 }
454 } else {
455 Err(TranscriptionError::ProviderError(response.text().await?))
456 }
457 }
458}