1use std::collections::HashMap;
12
13use super::openai::{CompletionResponse, StreamingToolCall, TranscriptionResponse, Usage};
14use crate::client::{ClientBuilderError, CompletionClient, TranscriptionClient};
15use crate::completion::GetTokenUsage;
16use crate::json_utils::merge;
17use futures::StreamExt;
18
19use crate::streaming::RawStreamingChoice;
20use crate::{
21 OneOrMany,
22 completion::{self, CompletionError, CompletionRequest},
23 json_utils,
24 message::{self, MessageError},
25 providers::openai::ToolDefinition,
26 transcription::{self, TranscriptionError},
27};
28use reqwest::RequestBuilder;
29use reqwest::multipart::Part;
30use rig::client::ProviderClient;
31use rig::impl_conversion_traits;
32use serde::{Deserialize, Serialize};
33use serde_json::{Value, json};
34
/// Root URL of Groq's OpenAI-compatible API; [`ClientBuilder::new`] uses it
/// unless a different base URL is supplied.
const GROQ_API_BASE_URL: &str = "https://api.groq.com/openai/v1";
39
40pub struct ClientBuilder<'a> {
41 api_key: &'a str,
42 base_url: &'a str,
43 http_client: Option<reqwest::Client>,
44}
45
46impl<'a> ClientBuilder<'a> {
47 pub fn new(api_key: &'a str) -> Self {
48 Self {
49 api_key,
50 base_url: GROQ_API_BASE_URL,
51 http_client: None,
52 }
53 }
54
55 pub fn base_url(mut self, base_url: &'a str) -> Self {
56 self.base_url = base_url;
57 self
58 }
59
60 pub fn custom_client(mut self, client: reqwest::Client) -> Self {
61 self.http_client = Some(client);
62 self
63 }
64
65 pub fn build(self) -> Result<Client, ClientBuilderError> {
66 let http_client = if let Some(http_client) = self.http_client {
67 http_client
68 } else {
69 reqwest::Client::builder().build()?
70 };
71
72 Ok(Client {
73 base_url: self.base_url.to_string(),
74 api_key: self.api_key.to_string(),
75 http_client,
76 })
77 }
78}
79
80#[derive(Clone)]
81pub struct Client {
82 base_url: String,
83 api_key: String,
84 http_client: reqwest::Client,
85}
86
87impl std::fmt::Debug for Client {
88 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
89 f.debug_struct("Client")
90 .field("base_url", &self.base_url)
91 .field("http_client", &self.http_client)
92 .field("api_key", &"<REDACTED>")
93 .finish()
94 }
95}
96
97impl Client {
98 pub fn builder(api_key: &str) -> ClientBuilder<'_> {
109 ClientBuilder::new(api_key)
110 }
111
112 pub fn new(api_key: &str) -> Self {
117 Self::builder(api_key)
118 .build()
119 .expect("Groq client should build")
120 }
121
122 pub(crate) fn post(&self, path: &str) -> reqwest::RequestBuilder {
123 let url = format!("{}/{}", self.base_url, path).replace("//", "/");
124 self.http_client.post(url).bearer_auth(&self.api_key)
125 }
126}
127
128impl ProviderClient for Client {
129 fn from_env() -> Self {
132 let api_key = std::env::var("GROQ_API_KEY").expect("GROQ_API_KEY not set");
133 Self::new(&api_key)
134 }
135
136 fn from_val(input: crate::client::ProviderValue) -> Self {
137 let crate::client::ProviderValue::Simple(api_key) = input else {
138 panic!("Incorrect provider value type")
139 };
140 Self::new(&api_key)
141 }
142}
143
144impl CompletionClient for Client {
145 type CompletionModel = CompletionModel;
146
147 fn completion_model(&self, model: &str) -> CompletionModel {
159 CompletionModel::new(self.clone(), model)
160 }
161}
162
163impl TranscriptionClient for Client {
164 type TranscriptionModel = TranscriptionModel;
165
166 fn transcription_model(&self, model: &str) -> TranscriptionModel {
178 TranscriptionModel::new(self.clone(), model)
179 }
180}
181
182impl_conversion_traits!(
183 AsEmbeddings,
184 AsImageGeneration,
185 AsAudioGeneration for Client
186);
187
188#[derive(Debug, Deserialize)]
189struct ApiErrorResponse {
190 message: String,
191}
192
193#[derive(Debug, Deserialize)]
194#[serde(untagged)]
195enum ApiResponse<T> {
196 Ok(T),
197 Err(ApiErrorResponse),
198}
199
200#[derive(Debug, Serialize, Deserialize)]
201pub struct Message {
202 pub role: String,
203 pub content: Option<String>,
204 #[serde(skip_serializing_if = "Option::is_none")]
205 pub reasoning: Option<String>,
206}
207
208impl TryFrom<Message> for message::Message {
209 type Error = message::MessageError;
210
211 fn try_from(message: Message) -> Result<Self, Self::Error> {
212 match message.role.as_str() {
213 "user" => Ok(Self::User {
214 content: OneOrMany::one(
215 message
216 .content
217 .map(|content| message::UserContent::text(&content))
218 .ok_or_else(|| {
219 message::MessageError::ConversionError("Empty user message".to_string())
220 })?,
221 ),
222 }),
223 "assistant" => Ok(Self::Assistant {
224 id: None,
225 content: OneOrMany::one(
226 message
227 .content
228 .map(|content| message::AssistantContent::text(&content))
229 .ok_or_else(|| {
230 message::MessageError::ConversionError(
231 "Empty assistant message".to_string(),
232 )
233 })?,
234 ),
235 }),
236 _ => Err(message::MessageError::ConversionError(format!(
237 "Unknown role: {}",
238 message.role
239 ))),
240 }
241 }
242}
243
244impl TryFrom<message::Message> for Message {
245 type Error = message::MessageError;
246
247 fn try_from(message: message::Message) -> Result<Self, Self::Error> {
248 match message {
249 message::Message::User { content } => Ok(Self {
250 role: "user".to_string(),
251 content: content.iter().find_map(|c| match c {
252 message::UserContent::Text(text) => Some(text.text.clone()),
253 _ => None,
254 }),
255 reasoning: None,
256 }),
257 message::Message::Assistant { content, .. } => {
258 let mut text_content: Option<String> = None;
259 let mut groq_reasoning: Option<String> = None;
260
261 for c in content.iter() {
262 match c {
263 message::AssistantContent::Text(text) => {
264 text_content = Some(
265 text_content
266 .map(|mut existing| {
267 existing.push('\n');
268 existing.push_str(&text.text);
269 existing
270 })
271 .unwrap_or_else(|| text.text.clone()),
272 );
273 }
274 message::AssistantContent::ToolCall(_tool_call) => {
275 return Err(MessageError::ConversionError(
276 "Tool calls do not exist on this message".into(),
277 ));
278 }
279 message::AssistantContent::Reasoning(message::Reasoning {
280 reasoning,
281 ..
282 }) => {
283 groq_reasoning =
284 Some(reasoning.first().cloned().unwrap_or(String::new()));
285 }
286 }
287 }
288
289 Ok(Self {
290 role: "assistant".to_string(),
291 content: text_content,
292 reasoning: groq_reasoning,
293 })
294 }
295 }
296 }
297}
298
/// Model ID `deepseek-r1-distill-llama-70b`.
pub const DEEPSEEK_R1_DISTILL_LLAMA_70B: &str = "deepseek-r1-distill-llama-70b";
/// Model ID `gemma2-9b-it`.
pub const GEMMA2_9B_IT: &str = "gemma2-9b-it";
/// Model ID `llama-3.1-8b-instant`.
pub const LLAMA_3_1_8B_INSTANT: &str = "llama-3.1-8b-instant";
/// Model ID `llama-3.2-11b-vision-preview`.
pub const LLAMA_3_2_11B_VISION_PREVIEW: &str = "llama-3.2-11b-vision-preview";
/// Model ID `llama-3.2-1b-preview`.
pub const LLAMA_3_2_1B_PREVIEW: &str = "llama-3.2-1b-preview";
/// Model ID `llama-3.2-3b-preview`.
pub const LLAMA_3_2_3B_PREVIEW: &str = "llama-3.2-3b-preview";
/// Model ID `llama-3.2-90b-vision-preview`.
pub const LLAMA_3_2_90B_VISION_PREVIEW: &str = "llama-3.2-90b-vision-preview";
/// Model ID `llama-3.2-70b-specdec`.
// NOTE(review): Groq's published 70B speculative-decoding model appears to be
// `llama-3.3-70b-specdec`; verify that this ID is accepted by the API.
pub const LLAMA_3_2_70B_SPECDEC: &str = "llama-3.2-70b-specdec";
/// Model ID `llama-3.2-70b-versatile`.
// NOTE(review): Groq's published 70B versatile model appears to be
// `llama-3.3-70b-versatile`; verify that this ID is accepted by the API.
pub const LLAMA_3_2_70B_VERSATILE: &str = "llama-3.2-70b-versatile";
/// Model ID `llama-guard-3-8b`.
pub const LLAMA_GUARD_3_8B: &str = "llama-guard-3-8b";
/// Model ID `llama3-70b-8192`.
pub const LLAMA_3_70B_8192: &str = "llama3-70b-8192";
/// Model ID `llama3-8b-8192`.
pub const LLAMA_3_8B_8192: &str = "llama3-8b-8192";
/// Model ID `mixtral-8x7b-32768`.
pub const MIXTRAL_8X7B_32768: &str = "mixtral-8x7b-32768";
328
329#[derive(Clone, Debug)]
330pub struct CompletionModel {
331 client: Client,
332 pub model: String,
334}
335
336impl CompletionModel {
337 pub fn new(client: Client, model: &str) -> Self {
338 Self {
339 client,
340 model: model.to_string(),
341 }
342 }
343
344 fn create_completion_request(
345 &self,
346 completion_request: CompletionRequest,
347 ) -> Result<Value, CompletionError> {
348 let mut partial_history = vec![];
350 if let Some(docs) = completion_request.normalized_documents() {
351 partial_history.push(docs);
352 }
353 partial_history.extend(completion_request.chat_history);
354
355 let mut full_history: Vec<Message> =
357 completion_request
358 .preamble
359 .map_or_else(Vec::new, |preamble| {
360 vec![Message {
361 role: "system".to_string(),
362 content: Some(preamble),
363 reasoning: None,
364 }]
365 });
366
367 full_history.extend(
369 partial_history
370 .into_iter()
371 .map(message::Message::try_into)
372 .collect::<Result<Vec<Message>, _>>()?,
373 );
374
375 let request = if completion_request.tools.is_empty() {
376 json!({
377 "model": self.model,
378 "messages": full_history,
379 "temperature": completion_request.temperature,
380 })
381 } else {
382 json!({
383 "model": self.model,
384 "messages": full_history,
385 "temperature": completion_request.temperature,
386 "tools": completion_request.tools.into_iter().map(ToolDefinition::from).collect::<Vec<_>>(),
387 "tool_choice": "auto",
388 "reasoning_format": "parsed"
389 })
390 };
391
392 let request = if let Some(params) = completion_request.additional_params {
393 json_utils::merge(request, params)
394 } else {
395 request
396 };
397
398 Ok(request)
399 }
400}
401
402impl completion::CompletionModel for CompletionModel {
403 type Response = CompletionResponse;
404 type StreamingResponse = StreamingCompletionResponse;
405
406 #[cfg_attr(feature = "worker", worker::send)]
407 async fn completion(
408 &self,
409 completion_request: CompletionRequest,
410 ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
411 let request = self.create_completion_request(completion_request)?;
412
413 let response = self
414 .client
415 .post("/chat/completions")
416 .json(&request)
417 .send()
418 .await?;
419
420 if response.status().is_success() {
421 match response.json::<ApiResponse<CompletionResponse>>().await? {
422 ApiResponse::Ok(response) => {
423 tracing::info!(target: "rig",
424 "groq completion token usage: {:?}",
425 response.usage.clone().map(|usage| format!("{usage}")).unwrap_or("N/A".to_string())
426 );
427 response.try_into()
428 }
429 ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
430 }
431 } else {
432 Err(CompletionError::ProviderError(response.text().await?))
433 }
434 }
435
436 #[cfg_attr(feature = "worker", worker::send)]
437 async fn stream(
438 &self,
439 request: CompletionRequest,
440 ) -> Result<
441 crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
442 CompletionError,
443 > {
444 let mut request = self.create_completion_request(request)?;
445
446 request = merge(
447 request,
448 json!({"stream": true, "stream_options": {"include_usage": true}}),
449 );
450
451 let builder = self.client.post("/chat/completions").json(&request);
452
453 send_compatible_streaming_request(builder).await
454 }
455}
456
/// Model ID `whisper-large-v3`.
pub const WHISPER_LARGE_V3: &str = "whisper-large-v3";
/// Model ID `whisper-large-v3-turbo`.
pub const WHISPER_LARGE_V3_TURBO: &str = "whisper-large-v3-turbo";
/// Model ID `distil-whisper-large-v3-en`.
pub const DISTIL_WHISPER_LARGE_V3: &str = "distil-whisper-large-v3-en";
463
464#[derive(Clone)]
465pub struct TranscriptionModel {
466 client: Client,
467 pub model: String,
469}
470
471impl TranscriptionModel {
472 pub fn new(client: Client, model: &str) -> Self {
473 Self {
474 client,
475 model: model.to_string(),
476 }
477 }
478}
479impl transcription::TranscriptionModel for TranscriptionModel {
480 type Response = TranscriptionResponse;
481
482 #[cfg_attr(feature = "worker", worker::send)]
483 async fn transcription(
484 &self,
485 request: transcription::TranscriptionRequest,
486 ) -> Result<
487 transcription::TranscriptionResponse<Self::Response>,
488 transcription::TranscriptionError,
489 > {
490 let data = request.data;
491
492 let mut body = reqwest::multipart::Form::new()
493 .text("model", self.model.clone())
494 .text("language", request.language)
495 .part(
496 "file",
497 Part::bytes(data).file_name(request.filename.clone()),
498 );
499
500 if let Some(prompt) = request.prompt {
501 body = body.text("prompt", prompt.clone());
502 }
503
504 if let Some(ref temperature) = request.temperature {
505 body = body.text("temperature", temperature.to_string());
506 }
507
508 if let Some(ref additional_params) = request.additional_params {
509 for (key, value) in additional_params
510 .as_object()
511 .expect("Additional Parameters to OpenAI Transcription should be a map")
512 {
513 body = body.text(key.to_owned(), value.to_string());
514 }
515 }
516
517 let response = self
518 .client
519 .post("audio/transcriptions")
520 .multipart(body)
521 .send()
522 .await?;
523
524 if response.status().is_success() {
525 match response
526 .json::<ApiResponse<TranscriptionResponse>>()
527 .await?
528 {
529 ApiResponse::Ok(response) => response.try_into(),
530 ApiResponse::Err(api_error_response) => Err(TranscriptionError::ProviderError(
531 api_error_response.message,
532 )),
533 }
534 } else {
535 Err(TranscriptionError::ProviderError(response.text().await?))
536 }
537 }
538}
539
540#[derive(Deserialize, Debug)]
541#[serde(untagged)]
542pub enum StreamingDelta {
543 Reasoning {
544 reasoning: String,
545 },
546 MessageContent {
547 #[serde(default)]
548 content: Option<String>,
549 #[serde(default, deserialize_with = "json_utils::null_or_vec")]
550 tool_calls: Vec<StreamingToolCall>,
551 },
552}
553
554#[derive(Deserialize, Debug)]
555struct StreamingChoice {
556 delta: StreamingDelta,
557}
558
559#[derive(Deserialize, Debug)]
560struct StreamingCompletionChunk {
561 choices: Vec<StreamingChoice>,
562 usage: Option<Usage>,
563}
564
565#[derive(Clone, Deserialize, Serialize, Debug)]
566pub struct StreamingCompletionResponse {
567 pub usage: Usage,
568}
569
570impl GetTokenUsage for StreamingCompletionResponse {
571 fn token_usage(&self) -> Option<crate::completion::Usage> {
572 let mut usage = crate::completion::Usage::new();
573
574 usage.input_tokens = self.usage.prompt_tokens as u64;
575 usage.total_tokens = self.usage.total_tokens as u64;
576 usage.output_tokens = self.usage.total_tokens as u64 - self.usage.prompt_tokens as u64;
577
578 Some(usage)
579 }
580}
581
582pub async fn send_compatible_streaming_request(
583 request_builder: RequestBuilder,
584) -> Result<
585 crate::streaming::StreamingCompletionResponse<StreamingCompletionResponse>,
586 CompletionError,
587> {
588 let response = request_builder.send().await?;
589
590 if !response.status().is_success() {
591 return Err(CompletionError::ProviderError(format!(
592 "{}: {}",
593 response.status(),
594 response.text().await?
595 )));
596 }
597
598 let inner = Box::pin(async_stream::stream! {
600 let mut stream = response.bytes_stream();
601
602 let mut final_usage = Usage {
603 prompt_tokens: 0,
604 total_tokens: 0
605 };
606
607 let mut partial_data = None;
608 let mut calls: HashMap<usize, (String, String, String)> = HashMap::new();
609
610 while let Some(chunk_result) = stream.next().await {
611 let chunk = match chunk_result {
612 Ok(c) => c,
613 Err(e) => {
614 yield Err(CompletionError::from(e));
615 break;
616 }
617 };
618
619 let text = match String::from_utf8(chunk.to_vec()) {
620 Ok(t) => t,
621 Err(e) => {
622 yield Err(CompletionError::ResponseError(e.to_string()));
623 break;
624 }
625 };
626
627
628 for line in text.lines() {
629 let mut line = line.to_string();
630
631 if partial_data.is_some() {
633 line = format!("{}{}", partial_data.unwrap(), line);
634 partial_data = None;
635 }
636 else {
638 let Some(data) = line.strip_prefix("data:") else {
639 continue;
640 };
641
642 let data = data.trim_start();
643
644 if !line.ends_with("}") {
646 partial_data = Some(data.to_string());
647 } else {
648 line = data.to_string();
649 }
650 }
651
652 let data = serde_json::from_str::<StreamingCompletionChunk>(&line);
653
654 let Ok(data) = data else {
655 let err = data.unwrap_err();
656 tracing::debug!("Couldn't serialize data as StreamingCompletionChunk: {:?}", err);
657 continue;
658 };
659
660
661 if let Some(choice) = data.choices.first() {
662 let delta = &choice.delta;
663
664 match delta {
665 StreamingDelta::Reasoning { reasoning } => {
666 yield Ok(crate::streaming::RawStreamingChoice::Reasoning { id: None, reasoning: reasoning.to_string() })
667 },
668 StreamingDelta::MessageContent { content, tool_calls } => {
669 if !tool_calls.is_empty() {
670 for tool_call in tool_calls {
671 let function = tool_call.function.clone();
672 if function.name.is_some() && function.arguments.is_empty() {
676 let id = tool_call.id.clone().unwrap_or("".to_string());
677
678 calls.insert(tool_call.index, (id, function.name.clone().unwrap(), "".to_string()));
679 }
680 else if function.name.clone().is_none_or(|s| s.is_empty()) && !function.arguments.is_empty() {
684 let Some((id, name, arguments)) = calls.get(&tool_call.index) else {
685 tracing::debug!("Partial tool call received but tool call was never started.");
686 continue;
687 };
688
689 let new_arguments = &function.arguments;
690 let arguments = format!("{arguments}{new_arguments}");
691
692 calls.insert(tool_call.index, (id.clone(), name.clone(), arguments));
693 }
694 else {
696 let id = tool_call.id.clone().unwrap_or("".to_string());
697 let name = function.name.expect("function name should be present for complete tool call");
698 let arguments = function.arguments;
699 let Ok(arguments) = serde_json::from_str(&arguments) else {
700 tracing::debug!("Couldn't serialize '{}' as a json value", arguments);
701 continue;
702 };
703
704 yield Ok(crate::streaming::RawStreamingChoice::ToolCall {id, name, arguments, call_id: None })
705 }
706 }
707 }
708
709 if let Some(content) = &content {
710 yield Ok(crate::streaming::RawStreamingChoice::Message(content.clone()))
711 }
712 }
713 }
714 }
715
716
717 if let Some(usage) = data.usage {
718 final_usage = usage.clone();
719 }
720 }
721 }
722
723 for (_, (id, name, arguments)) in calls {
724 let Ok(arguments) = serde_json::from_str(&arguments) else {
725 continue;
726 };
727
728 yield Ok(RawStreamingChoice::ToolCall {id, name, arguments, call_id: None });
729 }
730
731 yield Ok(RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
732 usage: final_usage.clone()
733 }))
734 });
735
736 Ok(crate::streaming::StreamingCompletionResponse::stream(inner))
737}