1use super::openai::{CompletionResponse, TranscriptionResponse, send_compatible_streaming_request};
12use crate::client::{CompletionClient, TranscriptionClient};
13use crate::json_utils::merge;
14use crate::providers::openai;
15use crate::streaming::StreamingCompletionResponse;
16use crate::{
17 OneOrMany,
18 completion::{self, CompletionError, CompletionRequest},
19 json_utils,
20 message::{self, MessageError},
21 providers::openai::ToolDefinition,
22 transcription::{self, TranscriptionError},
23};
24use reqwest::multipart::Part;
25use rig::client::ProviderClient;
26use rig::impl_conversion_traits;
27use serde::{Deserialize, Serialize};
28use serde_json::{Value, json};
29
/// Default base URL of Groq's OpenAI-compatible REST API.
const GROQ_API_BASE_URL: &str = "https://api.groq.com/openai/v1";
34
/// HTTP client for Groq's OpenAI-compatible API.
///
/// Construct with [`Client::new`] or [`Client::from_url`]; all requests are
/// authenticated with a bearer token built from `api_key`.
#[derive(Clone)]
pub struct Client {
    // Base URL all request paths are joined onto (no trailing slash expected).
    base_url: String,
    // Bearer token; intentionally redacted in the manual `Debug` impl below.
    api_key: String,
    // Shared reqwest client; replaceable via `with_custom_client`.
    http_client: reqwest::Client,
}
41
// Manual `Debug` impl (instead of `#[derive(Debug)]`) so the API key is never
// written to logs: the `api_key` field is printed as `<REDACTED>`.
impl std::fmt::Debug for Client {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Client")
            .field("base_url", &self.base_url)
            .field("http_client", &self.http_client)
            .field("api_key", &"<REDACTED>")
            .finish()
    }
}
51
52impl Client {
53 pub fn new(api_key: &str) -> Self {
55 Self::from_url(api_key, GROQ_API_BASE_URL)
56 }
57
58 pub fn from_url(api_key: &str, base_url: &str) -> Self {
60 Self {
61 base_url: base_url.to_string(),
62 api_key: api_key.to_string(),
63 http_client: reqwest::Client::builder()
64 .build()
65 .expect("Groq reqwest client should build"),
66 }
67 }
68
69 pub fn with_custom_client(mut self, client: reqwest::Client) -> Self {
72 self.http_client = client;
73
74 self
75 }
76
77 fn post(&self, path: &str) -> reqwest::RequestBuilder {
78 let url = format!("{}/{}", self.base_url, path).replace("//", "/");
79 self.http_client.post(url).bearer_auth(&self.api_key)
80 }
81}
82
83impl ProviderClient for Client {
84 fn from_env() -> Self {
87 let api_key = std::env::var("GROQ_API_KEY").expect("GROQ_API_KEY not set");
88 Self::new(&api_key)
89 }
90
91 fn from_val(input: crate::client::ProviderValue) -> Self {
92 let crate::client::ProviderValue::Simple(api_key) = input else {
93 panic!("Incorrect provider value type")
94 };
95 Self::new(&api_key)
96 }
97}
98
99impl CompletionClient for Client {
100 type CompletionModel = CompletionModel;
101
102 fn completion_model(&self, model: &str) -> CompletionModel {
114 CompletionModel::new(self.clone(), model)
115 }
116}
117
118impl TranscriptionClient for Client {
119 type TranscriptionModel = TranscriptionModel;
120
121 fn transcription_model(&self, model: &str) -> TranscriptionModel {
133 TranscriptionModel::new(self.clone(), model)
134 }
135}
136
// NOTE(review): presumably generates stub/"unsupported" conversions for the
// capabilities Groq does not provide here (embeddings, image generation,
// audio generation) — confirm against the `impl_conversion_traits!` macro.
impl_conversion_traits!(
    AsEmbeddings,
    AsImageGeneration,
    AsAudioGeneration for Client
);
142
/// Error payload shape returned by the Groq API (a single `message` field).
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    message: String,
}
147
/// Either a successful payload `T` or an API error body.
///
/// `#[serde(untagged)]` tries the variants in declaration order, so serde
/// attempts to deserialize `T` first and falls back to the error shape.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    Ok(T),
    Err(ApiErrorResponse),
}
154
/// OpenAI-style chat message as sent to / received from Groq.
#[derive(Debug, Serialize, Deserialize)]
pub struct Message {
    // "system", "user", or "assistant" (see the TryFrom impls below).
    pub role: String,
    // Optional because a message body may be absent on the wire.
    pub content: Option<String>,
}
160
161impl TryFrom<Message> for message::Message {
162 type Error = message::MessageError;
163
164 fn try_from(message: Message) -> Result<Self, Self::Error> {
165 match message.role.as_str() {
166 "user" => Ok(Self::User {
167 content: OneOrMany::one(
168 message
169 .content
170 .map(|content| message::UserContent::text(&content))
171 .ok_or_else(|| {
172 message::MessageError::ConversionError("Empty user message".to_string())
173 })?,
174 ),
175 }),
176 "assistant" => Ok(Self::Assistant {
177 id: None,
178 content: OneOrMany::one(
179 message
180 .content
181 .map(|content| message::AssistantContent::text(&content))
182 .ok_or_else(|| {
183 message::MessageError::ConversionError(
184 "Empty assistant message".to_string(),
185 )
186 })?,
187 ),
188 }),
189 _ => Err(message::MessageError::ConversionError(format!(
190 "Unknown role: {}",
191 message.role
192 ))),
193 }
194 }
195}
196
197impl TryFrom<message::Message> for Message {
198 type Error = message::MessageError;
199
200 fn try_from(message: message::Message) -> Result<Self, Self::Error> {
201 match message {
202 message::Message::User { content } => Ok(Self {
203 role: "user".to_string(),
204 content: content.iter().find_map(|c| match c {
205 message::UserContent::Text(text) => Some(text.text.clone()),
206 _ => None,
207 }),
208 }),
209 message::Message::Assistant { content, .. } => {
210 let mut text_content: Option<String> = None;
211
212 for c in content.iter() {
213 match c {
214 message::AssistantContent::Text(text) => {
215 text_content = Some(
216 text_content
217 .map(|mut existing| {
218 existing.push('\n');
219 existing.push_str(&text.text);
220 existing
221 })
222 .unwrap_or_else(|| text.text.clone()),
223 );
224 }
225 message::AssistantContent::ToolCall(_tool_call) => {
226 return Err(MessageError::ConversionError(
227 "Tool calls do not exist on this message".into(),
228 ));
229 }
230 }
231 }
232
233 Ok(Self {
234 role: "assistant".to_string(),
235 content: text_content,
236 })
237 }
238 }
239 }
240}
241
// ---------------------------------------------------------------------------
// Groq-hosted completion model identifiers, usable with `completion_model`.
// NOTE(review): several `-preview` / `-specdec` names look like time-limited
// Groq previews — verify against Groq's current model list before relying on
// them.
// ---------------------------------------------------------------------------
pub const DEEPSEEK_R1_DISTILL_LLAMA_70B: &str = "deepseek-r1-distill-llama-70b";
pub const GEMMA2_9B_IT: &str = "gemma2-9b-it";
pub const LLAMA_3_1_8B_INSTANT: &str = "llama-3.1-8b-instant";
pub const LLAMA_3_2_11B_VISION_PREVIEW: &str = "llama-3.2-11b-vision-preview";
pub const LLAMA_3_2_1B_PREVIEW: &str = "llama-3.2-1b-preview";
pub const LLAMA_3_2_3B_PREVIEW: &str = "llama-3.2-3b-preview";
pub const LLAMA_3_2_90B_VISION_PREVIEW: &str = "llama-3.2-90b-vision-preview";
pub const LLAMA_3_2_70B_SPECDEC: &str = "llama-3.2-70b-specdec";
pub const LLAMA_3_2_70B_VERSATILE: &str = "llama-3.2-70b-versatile";
pub const LLAMA_GUARD_3_8B: &str = "llama-guard-3-8b";
pub const LLAMA_3_70B_8192: &str = "llama3-70b-8192";
pub const LLAMA_3_8B_8192: &str = "llama3-8b-8192";
pub const MIXTRAL_8X7B_32768: &str = "mixtral-8x7b-32768";
271
/// A Groq completion (chat) model bound to a [`Client`].
#[derive(Clone, Debug)]
pub struct CompletionModel {
    client: Client,
    /// Model identifier, e.g. [`LLAMA_3_1_8B_INSTANT`].
    pub model: String,
}
278
279impl CompletionModel {
280 pub fn new(client: Client, model: &str) -> Self {
281 Self {
282 client,
283 model: model.to_string(),
284 }
285 }
286
287 fn create_completion_request(
288 &self,
289 completion_request: CompletionRequest,
290 ) -> Result<Value, CompletionError> {
291 let mut partial_history = vec![];
293 if let Some(docs) = completion_request.normalized_documents() {
294 partial_history.push(docs);
295 }
296 partial_history.extend(completion_request.chat_history);
297
298 let mut full_history: Vec<Message> =
300 completion_request
301 .preamble
302 .map_or_else(Vec::new, |preamble| {
303 vec![Message {
304 role: "system".to_string(),
305 content: Some(preamble),
306 }]
307 });
308
309 full_history.extend(
311 partial_history
312 .into_iter()
313 .map(message::Message::try_into)
314 .collect::<Result<Vec<Message>, _>>()?,
315 );
316
317 let request = if completion_request.tools.is_empty() {
318 json!({
319 "model": self.model,
320 "messages": full_history,
321 "temperature": completion_request.temperature,
322 })
323 } else {
324 json!({
325 "model": self.model,
326 "messages": full_history,
327 "temperature": completion_request.temperature,
328 "tools": completion_request.tools.into_iter().map(ToolDefinition::from).collect::<Vec<_>>(),
329 "tool_choice": "auto",
330 })
331 };
332
333 let request = if let Some(params) = completion_request.additional_params {
334 json_utils::merge(request, params)
335 } else {
336 request
337 };
338
339 Ok(request)
340 }
341}
342
343impl completion::CompletionModel for CompletionModel {
344 type Response = CompletionResponse;
345 type StreamingResponse = openai::StreamingCompletionResponse;
346
347 #[cfg_attr(feature = "worker", worker::send)]
348 async fn completion(
349 &self,
350 completion_request: CompletionRequest,
351 ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
352 let request = self.create_completion_request(completion_request)?;
353
354 let response = self
355 .client
356 .post("/chat/completions")
357 .json(&request)
358 .send()
359 .await?;
360
361 if response.status().is_success() {
362 match response.json::<ApiResponse<CompletionResponse>>().await? {
363 ApiResponse::Ok(response) => {
364 tracing::info!(target: "rig",
365 "groq completion token usage: {:?}",
366 response.usage.clone().map(|usage| format!("{usage}")).unwrap_or("N/A".to_string())
367 );
368 response.try_into()
369 }
370 ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
371 }
372 } else {
373 Err(CompletionError::ProviderError(response.text().await?))
374 }
375 }
376
377 #[cfg_attr(feature = "worker", worker::send)]
378 async fn stream(
379 &self,
380 request: CompletionRequest,
381 ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
382 let mut request = self.create_completion_request(request)?;
383
384 request = merge(
385 request,
386 json!({"stream": true, "stream_options": {"include_usage": true}}),
387 );
388
389 let builder = self.client.post("/chat/completions").json(&request);
390
391 send_compatible_streaming_request(builder).await
392 }
393}
394
// Groq-hosted Whisper transcription model identifiers, usable with
// `transcription_model`.
pub const WHISPER_LARGE_V3: &str = "whisper-large-v3";
pub const WHISPER_LARGE_V3_TURBO: &str = "whisper-large-v3-turbo";
pub const DISTIL_WHISPER_LARGE_V3: &str = "distil-whisper-large-v3-en";
401
402#[derive(Clone)]
403pub struct TranscriptionModel {
404 client: Client,
405 pub model: String,
407}
408
409impl TranscriptionModel {
410 pub fn new(client: Client, model: &str) -> Self {
411 Self {
412 client,
413 model: model.to_string(),
414 }
415 }
416}
417impl transcription::TranscriptionModel for TranscriptionModel {
418 type Response = TranscriptionResponse;
419
420 #[cfg_attr(feature = "worker", worker::send)]
421 async fn transcription(
422 &self,
423 request: transcription::TranscriptionRequest,
424 ) -> Result<
425 transcription::TranscriptionResponse<Self::Response>,
426 transcription::TranscriptionError,
427 > {
428 let data = request.data;
429
430 let mut body = reqwest::multipart::Form::new()
431 .text("model", self.model.clone())
432 .text("language", request.language)
433 .part(
434 "file",
435 Part::bytes(data).file_name(request.filename.clone()),
436 );
437
438 if let Some(prompt) = request.prompt {
439 body = body.text("prompt", prompt.clone());
440 }
441
442 if let Some(ref temperature) = request.temperature {
443 body = body.text("temperature", temperature.to_string());
444 }
445
446 if let Some(ref additional_params) = request.additional_params {
447 for (key, value) in additional_params
448 .as_object()
449 .expect("Additional Parameters to OpenAI Transcription should be a map")
450 {
451 body = body.text(key.to_owned(), value.to_string());
452 }
453 }
454
455 let response = self
456 .client
457 .post("audio/transcriptions")
458 .multipart(body)
459 .send()
460 .await?;
461
462 if response.status().is_success() {
463 match response
464 .json::<ApiResponse<TranscriptionResponse>>()
465 .await?
466 {
467 ApiResponse::Ok(response) => response.try_into(),
468 ApiResponse::Err(api_error_response) => Err(TranscriptionError::ProviderError(
469 api_error_response.message,
470 )),
471 }
472 } else {
473 Err(TranscriptionError::ProviderError(response.text().await?))
474 }
475 }
476}