use bytes::Bytes;
use http::Request;
use std::collections::HashMap;
use tracing::info_span;
use tracing_futures::Instrument;

use super::openai::{CompletionResponse, StreamingToolCall, TranscriptionResponse, Usage};
use crate::client::{
    self, BearerAuth, Capabilities, Capable, DebugExt, Nothing, Provider, ProviderBuilder,
    ProviderClient,
};
use crate::completion::GetTokenUsage;
use crate::http_client::sse::{Event, GenericEventSource};
use crate::http_client::{self, HttpClientExt};
use crate::json_utils::empty_or_none;
use crate::providers::openai::{AssistantContent, Function, ToolType};
use async_stream::stream;
use futures::StreamExt;

use crate::{
    OneOrMany,
    completion::{self, CompletionError, CompletionRequest},
    json_utils,
    message::{self, MessageError},
    providers::openai::ToolDefinition,
    transcription::{self, TranscriptionError},
};
use reqwest::multipart::Part;
use serde::{Deserialize, Serialize};

const GROQ_API_BASE_URL: &str = "https://api.groq.com/openai/v1";

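/// Provider marker type for Groq's OpenAI-compatible API.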
#[derive(Debug, Default, Clone, Copy)]
pub struct GroqExt;
#[derive(Debug, Default, Clone, Copy)]
pub struct GroqBuilder;

type GroqApiKey = BearerAuth;

impl Provider for GroqExt {
    type Builder = GroqBuilder;

    const VERIFY_PATH: &'static str = "/models";

    fn build<H>(
        _: &crate::client::ClientBuilder<
            Self::Builder,
            <Self::Builder as crate::client::ProviderBuilder>::ApiKey,
            H,
        >,
    ) -> http_client::Result<Self> {
        Ok(Self)
    }
}

impl<H> Capabilities<H> for GroqExt {
    type Completion = Capable<CompletionModel<H>>;
    type Embeddings = Nothing;
    type Transcription = Capable<TranscriptionModel<H>>;
    #[cfg(feature = "image")]
    type ImageGeneration = Nothing;

    #[cfg(feature = "audio")]
    type AudioGeneration = Nothing;
}

impl DebugExt for GroqExt {}

impl ProviderBuilder for GroqBuilder {
    type Output = GroqExt;
    type ApiKey = GroqApiKey;

    const BASE_URL: &'static str = GROQ_API_BASE_URL;
}

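/// A Groq API client.
///
/// A minimal usage sketch (assumes the `GROQ_API_KEY` environment variable is set):
///
/// ```ignore
/// use rig::client::ProviderClient;
/// use rig::providers::groq;
///
/// let client = groq::Client::from_env();
/// ```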
pub type Client<H = reqwest::Client> = client::Client<GroqExt, H>;
pub type ClientBuilder<H = reqwest::Client> = client::ClientBuilder<GroqBuilder, String, H>;

impl ProviderClient for Client {
    type Input = String;

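    /// Creates a Groq client from the `GROQ_API_KEY` environment variable.
    /// Panics if the variable is not set or the client fails to build.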
    fn from_env() -> Self {
        let api_key = std::env::var("GROQ_API_KEY").expect("GROQ_API_KEY not set");
        Self::new(&api_key).unwrap()
    }

    fn from_val(input: Self::Input) -> Self {
        Self::new(&input).unwrap()
    }
}

#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    message: String,
}

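/// A Groq API response: either the expected payload or an error body.
/// With `#[serde(untagged)]`, deserialization tries the `Ok` variant first.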
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    Ok(T),
    Err(ApiErrorResponse),
}

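/// A chat message in Groq's wire format. The optional `reasoning` field carries
/// the model's reasoning channel when `reasoning_format` is set to `parsed`.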
#[derive(Debug, Serialize, Deserialize)]
pub struct Message {
    pub role: String,
    pub content: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub reasoning: Option<String>,
}

impl Message {
    fn system(preamble: &str) -> Self {
        Self {
            role: "system".to_string(),
            content: Some(preamble.to_string()),
            reasoning: None,
        }
    }
}

impl TryFrom<Message> for message::Message {
    type Error = message::MessageError;

    fn try_from(message: Message) -> Result<Self, Self::Error> {
        match message.role.as_str() {
            "user" => Ok(Self::User {
                content: OneOrMany::one(
                    message
                        .content
                        .map(|content| message::UserContent::text(&content))
                        .ok_or_else(|| {
                            message::MessageError::ConversionError("Empty user message".to_string())
                        })?,
                ),
            }),
            "assistant" => Ok(Self::Assistant {
                id: None,
                content: OneOrMany::one(
                    message
                        .content
                        .map(|content| message::AssistantContent::text(&content))
                        .ok_or_else(|| {
                            message::MessageError::ConversionError(
                                "Empty assistant message".to_string(),
                            )
                        })?,
                ),
            }),
            _ => Err(message::MessageError::ConversionError(format!(
                "Unknown role: {}",
                message.role
            ))),
        }
    }
}

impl TryFrom<message::Message> for Message {
    type Error = message::MessageError;

    fn try_from(message: message::Message) -> Result<Self, Self::Error> {
        match message {
            message::Message::User { content } => Ok(Self {
                role: "user".to_string(),
                content: content.iter().find_map(|c| match c {
                    message::UserContent::Text(text) => Some(text.text.clone()),
                    _ => None,
                }),
                reasoning: None,
            }),
            message::Message::Assistant { content, .. } => {
                let mut text_content: Option<String> = None;
                let mut groq_reasoning: Option<String> = None;

                for c in content.iter() {
                    match c {
                        message::AssistantContent::Text(text) => {
                            text_content = Some(
                                text_content
                                    .map(|mut existing| {
                                        existing.push('\n');
                                        existing.push_str(&text.text);
                                        existing
                                    })
                                    .unwrap_or_else(|| text.text.clone()),
                            );
                        }
                        message::AssistantContent::ToolCall(_tool_call) => {
                            return Err(MessageError::ConversionError(
                                "Tool calls are not supported by this message format".into(),
                            ));
                        }
                        message::AssistantContent::Reasoning(message::Reasoning {
                            reasoning,
                            ..
                        }) => {
                            groq_reasoning =
                                Some(reasoning.first().cloned().unwrap_or_default());
                        }
                        message::AssistantContent::Image(_) => {
                            return Err(MessageError::ConversionError(
                                "Groq currently doesn't support images.".into(),
                            ));
                        }
                    }
                }

                Ok(Self {
                    role: "assistant".to_string(),
                    content: text_content,
                    reasoning: groq_reasoning,
                })
            }
        }
    }
}

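// Completion model identifiers exposed by the Groq API.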
pub const DEEPSEEK_R1_DISTILL_LLAMA_70B: &str = "deepseek-r1-distill-llama-70b";
pub const GEMMA2_9B_IT: &str = "gemma2-9b-it";
pub const LLAMA_3_1_8B_INSTANT: &str = "llama-3.1-8b-instant";
pub const LLAMA_3_2_11B_VISION_PREVIEW: &str = "llama-3.2-11b-vision-preview";
pub const LLAMA_3_2_1B_PREVIEW: &str = "llama-3.2-1b-preview";
pub const LLAMA_3_2_3B_PREVIEW: &str = "llama-3.2-3b-preview";
pub const LLAMA_3_2_90B_VISION_PREVIEW: &str = "llama-3.2-90b-vision-preview";
pub const LLAMA_3_3_70B_SPECDEC: &str = "llama-3.3-70b-specdec";
pub const LLAMA_3_3_70B_VERSATILE: &str = "llama-3.3-70b-versatile";
pub const LLAMA_GUARD_3_8B: &str = "llama-guard-3-8b";
pub const LLAMA_3_70B_8192: &str = "llama3-70b-8192";
pub const LLAMA_3_8B_8192: &str = "llama3-8b-8192";
pub const MIXTRAL_8X7B_32768: &str = "mixtral-8x7b-32768";

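/// Controls how Groq returns model reasoning. `Parsed` (serialized as `"parsed"`)
/// asks the API to return reasoning in a separate field rather than inline.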
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ReasoningFormat {
    Parsed,
}

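/// Request body for Groq's `/chat/completions` endpoint.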
#[derive(Debug, Serialize, Deserialize)]
pub(super) struct GroqCompletionRequest {
    model: String,
    pub messages: Vec<Message>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f64>,
    #[serde(skip_serializing_if = "Vec::is_empty")]
    tools: Vec<ToolDefinition>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_choice: Option<crate::providers::openai::completion::ToolChoice>,
    #[serde(flatten, skip_serializing_if = "Option::is_none")]
    pub additional_params: Option<serde_json::Value>,
    reasoning_format: ReasoningFormat,
}

impl TryFrom<(&str, CompletionRequest)> for GroqCompletionRequest {
    type Error = CompletionError;

    fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
        let mut partial_history = vec![];
        if let Some(docs) = req.normalized_documents() {
            partial_history.push(docs);
        }
        partial_history.extend(req.chat_history);

        let mut full_history: Vec<Message> = match &req.preamble {
            Some(preamble) => vec![Message::system(preamble)],
            None => vec![],
        };

        full_history.extend(
            partial_history
                .into_iter()
                .map(message::Message::try_into)
                .collect::<Result<Vec<Message>, _>>()?,
        );

        let tool_choice = req
            .tool_choice
            .clone()
            .map(crate::providers::openai::ToolChoice::try_from)
            .transpose()?;

        Ok(Self {
            model: model.to_string(),
            messages: full_history,
            temperature: req.temperature,
            tools: req
                .tools
                .clone()
                .into_iter()
                .map(ToolDefinition::from)
                .collect::<Vec<_>>(),
            tool_choice,
            additional_params: req.additional_params,
            reasoning_format: ReasoningFormat::Parsed,
        })
    }
}

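/// A completion model backed by Groq's `/chat/completions` endpoint.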
#[derive(Clone, Debug)]
pub struct CompletionModel<T = reqwest::Client> {
    client: Client<T>,
    pub model: String,
}

impl<T> CompletionModel<T> {
    pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
        Self {
            client,
            model: model.into(),
        }
    }
}

impl<T> completion::CompletionModel for CompletionModel<T>
where
    T: HttpClientExt + Clone + Send + std::fmt::Debug + Default + 'static,
{
    type Response = CompletionResponse;
    type StreamingResponse = StreamingCompletionResponse;

    type Client = Client<T>;

    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
        Self::new(client.clone(), model)
    }

    #[cfg_attr(feature = "worker", worker::send)]
    async fn completion(
        &self,
        completion_request: CompletionRequest,
    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat",
                gen_ai.operation.name = "chat",
                gen_ai.provider.name = "groq",
                gen_ai.request.model = self.model,
                gen_ai.system_instructions = tracing::field::Empty,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        span.record("gen_ai.system_instructions", &completion_request.preamble);

        let request = GroqCompletionRequest::try_from((self.model.as_ref(), completion_request))?;

        if tracing::enabled!(tracing::Level::TRACE) {
            tracing::trace!(target: "rig::completions",
                "Groq completion request: {}",
                serde_json::to_string_pretty(&request)?
            );
        }

        let body = serde_json::to_vec(&request)?;
        let req = self
            .client
            .post("/chat/completions")?
            .body(body)
            .map_err(|e| http_client::Error::Instance(e.into()))?;

        let async_block = async move {
            let response = self.client.send::<_, Bytes>(req).await?;
            let status = response.status();
            let response_body = response.into_body().into_future().await?.to_vec();

            if status.is_success() {
                match serde_json::from_slice::<ApiResponse<CompletionResponse>>(&response_body)? {
                    ApiResponse::Ok(response) => {
                        let span = tracing::Span::current();
                        span.record("gen_ai.response.id", response.id.clone());
                        // Record into the declared `gen_ai.response.model` field;
                        // recording an undeclared field name is silently dropped.
                        span.record("gen_ai.response.model", response.model.clone());
                        if let Some(ref usage) = response.usage {
                            span.record("gen_ai.usage.input_tokens", usage.prompt_tokens);
                            span.record(
                                "gen_ai.usage.output_tokens",
                                usage.total_tokens - usage.prompt_tokens,
                            );
                        }

                        if tracing::enabled!(tracing::Level::TRACE) {
                            tracing::trace!(target: "rig::completions",
                                "Groq completion response: {}",
                                serde_json::to_string_pretty(&response)?
                            );
                        }

                        response.try_into()
                    }
                    ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
                }
            } else {
                Err(CompletionError::ProviderError(
                    String::from_utf8_lossy(&response_body).to_string(),
                ))
            }
        };

        tracing::Instrument::instrument(async_block, span).await
    }

    #[cfg_attr(feature = "worker", worker::send)]
    async fn stream(
        &self,
        request: CompletionRequest,
    ) -> Result<
        crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
        CompletionError,
    > {
        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat_streaming",
                gen_ai.operation.name = "chat_streaming",
                gen_ai.provider.name = "groq",
                gen_ai.request.model = self.model,
                gen_ai.system_instructions = tracing::field::Empty,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        span.record("gen_ai.system_instructions", &request.preamble);

        let mut request = GroqCompletionRequest::try_from((self.model.as_ref(), request))?;

        let params = json_utils::merge(
            request.additional_params.unwrap_or(serde_json::json!({})),
            serde_json::json!({"stream": true, "stream_options": {"include_usage": true} }),
        );

        request.additional_params = Some(params);

        if tracing::enabled!(tracing::Level::TRACE) {
            tracing::trace!(target: "rig::completions",
                "Groq streaming completion request: {}",
                serde_json::to_string_pretty(&request)?
            );
        }

        let body = serde_json::to_vec(&request)?;
        let req = self
            .client
            .post("/chat/completions")?
            .body(body)
            .map_err(|e| http_client::Error::Instance(e.into()))?;

        tracing::Instrument::instrument(
            send_compatible_streaming_request(self.client.clone(), req),
            span,
        )
        .await
    }
}

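// Transcription model identifiers exposed by the Groq API.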
pub const WHISPER_LARGE_V3: &str = "whisper-large-v3";
pub const WHISPER_LARGE_V3_TURBO: &str = "whisper-large-v3-turbo";
pub const DISTIL_WHISPER_LARGE_V3_EN: &str = "distil-whisper-large-v3-en";

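/// A transcription model backed by Groq's `/audio/transcriptions` endpoint.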
#[derive(Clone)]
pub struct TranscriptionModel<T> {
    client: Client<T>,
    pub model: String,
}

impl<T> TranscriptionModel<T> {
    pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
        Self {
            client,
            model: model.into(),
        }
    }
}

impl<T> transcription::TranscriptionModel for TranscriptionModel<T>
where
    T: HttpClientExt + Clone + Send + std::fmt::Debug + Default + 'static,
{
    type Response = TranscriptionResponse;

    type Client = Client<T>;

    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
        Self::new(client.clone(), model)
    }

    #[cfg_attr(feature = "worker", worker::send)]
    async fn transcription(
        &self,
        request: transcription::TranscriptionRequest,
    ) -> Result<
        transcription::TranscriptionResponse<Self::Response>,
        transcription::TranscriptionError,
    > {
        let data = request.data;

        let mut body = reqwest::multipart::Form::new()
            .text("model", self.model.clone())
            .part(
                "file",
                Part::bytes(data).file_name(request.filename.clone()),
            );

        if let Some(language) = request.language {
            body = body.text("language", language);
        }

        if let Some(prompt) = request.prompt {
            body = body.text("prompt", prompt.clone());
        }

        if let Some(ref temperature) = request.temperature {
            body = body.text("temperature", temperature.to_string());
        }

        if let Some(ref additional_params) = request.additional_params {
            for (key, value) in additional_params
                .as_object()
                .expect("Additional parameters to Groq transcription should be a map")
            {
                body = body.text(key.to_owned(), value.to_string());
            }
        }

        // Propagate request-building and transport errors instead of panicking.
        let req = self
            .client
            .post("/audio/transcriptions")?
            .body(body)
            .map_err(|e| http_client::Error::Instance(e.into()))?;

        let response = self.client.send_multipart::<Bytes>(req).await?;

        let status = response.status();
        let response_body = response.into_body().into_future().await?.to_vec();

        if status.is_success() {
            match serde_json::from_slice::<ApiResponse<TranscriptionResponse>>(&response_body)? {
                ApiResponse::Ok(response) => response.try_into(),
                ApiResponse::Err(api_error_response) => Err(TranscriptionError::ProviderError(
                    api_error_response.message,
                )),
            }
        } else {
            Err(TranscriptionError::ProviderError(
                String::from_utf8_lossy(&response_body).to_string(),
            ))
        }
    }
}

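/// An incremental delta from Groq's SSE stream. `#[serde(untagged)]` distinguishes
/// reasoning chunks from content/tool-call chunks by their shape.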
#[derive(Deserialize, Debug)]
#[serde(untagged)]
enum StreamingDelta {
    Reasoning {
        reasoning: String,
    },
    MessageContent {
        #[serde(default)]
        content: Option<String>,
        #[serde(default, deserialize_with = "json_utils::null_or_vec")]
        tool_calls: Vec<StreamingToolCall>,
    },
}

#[derive(Deserialize, Debug)]
struct StreamingChoice {
    delta: StreamingDelta,
}

#[derive(Deserialize, Debug)]
struct StreamingCompletionChunk {
    choices: Vec<StreamingChoice>,
    usage: Option<Usage>,
}

#[derive(Clone, Deserialize, Serialize, Debug)]
pub struct StreamingCompletionResponse {
    pub usage: Usage,
}

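// Groq reports prompt and total tokens; output tokens are derived as the difference.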
impl GetTokenUsage for StreamingCompletionResponse {
    fn token_usage(&self) -> Option<crate::completion::Usage> {
        let mut usage = crate::completion::Usage::new();

        usage.input_tokens = self.usage.prompt_tokens as u64;
        usage.total_tokens = self.usage.total_tokens as u64;
        usage.output_tokens = self.usage.total_tokens as u64 - self.usage.prompt_tokens as u64;

        Some(usage)
    }
}

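/// Sends an OpenAI-compatible streaming chat request and adapts Groq's SSE stream
/// into rig's streaming response, accumulating partial tool calls and usage along
/// the way.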
pub async fn send_compatible_streaming_request<T>(
    client: T,
    req: Request<Vec<u8>>,
) -> Result<
    crate::streaming::StreamingCompletionResponse<StreamingCompletionResponse>,
    CompletionError,
>
where
    T: HttpClientExt + Clone + 'static,
{
    let span = tracing::Span::current();

    let mut event_source = GenericEventSource::new(client, req);

    let stream = stream! {
        let span = tracing::Span::current();
        let mut final_usage = Usage {
            prompt_tokens: 0,
            total_tokens: 0
        };

        let mut text_response = String::new();

        // Partial tool calls keyed by index: (id, name, accumulated JSON arguments).
        let mut calls: HashMap<usize, (String, String, String)> = HashMap::new();

        while let Some(event_result) = event_source.next().await {
            match event_result {
                Ok(Event::Open) => {
                    tracing::trace!("SSE connection opened");
                    continue;
                }

                Ok(Event::Message(message)) => {
                    let data_str = message.data.trim();

                    let parsed = serde_json::from_str::<StreamingCompletionChunk>(data_str);
                    let Ok(data) = parsed else {
                        let err = parsed.unwrap_err();
                        tracing::debug!("Couldn't parse SSE payload as StreamingCompletionChunk: {:?}", err);
                        continue;
                    };

                    if let Some(choice) = data.choices.first() {
                        match &choice.delta {
                            StreamingDelta::Reasoning { reasoning } => {
                                yield Ok(crate::streaming::RawStreamingChoice::Reasoning {
                                    id: None,
                                    reasoning: reasoning.to_string(),
                                    signature: None,
                                });
                            }

                            StreamingDelta::MessageContent { content, tool_calls } => {
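                                // Groq streams tool calls incrementally: one chunk carries
                                // the id and name, later chunks append argument fragments
                                // that are joined below.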
                                for tool_call in tool_calls {
                                    let function = &tool_call.function;

                                    // A chunk with a name but no arguments starts a new tool call.
                                    if function.name.as_ref().map(|s| !s.is_empty()).unwrap_or(false)
                                        && empty_or_none(&function.arguments)
                                    {
                                        let id = tool_call.id.clone().unwrap_or_default();
                                        let name = function.name.clone().unwrap();
                                        calls.insert(tool_call.index, (id, name, String::new()));
                                    }
                                    // A nameless chunk with arguments continues an existing call.
                                    else if function.name.as_ref().map(|s| s.is_empty()).unwrap_or(true)
                                        && let Some(arguments) = &function.arguments
                                        && !arguments.is_empty()
                                    {
                                        if let Some((id, name, existing_args)) = calls.get(&tool_call.index) {
                                            let combined = format!("{}{}", existing_args, arguments);
                                            calls.insert(tool_call.index, (id.clone(), name.clone(), combined));
                                        } else {
                                            tracing::debug!("Partial tool call received but tool call was never started.");
                                        }
                                    }
                                    // Otherwise the chunk carries a complete call; emit it directly.
                                    else {
                                        let id = tool_call.id.clone().unwrap_or_default();
                                        let name = function.name.clone().unwrap_or_default();
                                        let arguments_str = function.arguments.clone().unwrap_or_default();

                                        let Ok(arguments_json) = serde_json::from_str::<serde_json::Value>(&arguments_str) else {
                                            tracing::debug!("Couldn't parse tool call args '{}'", arguments_str);
                                            continue;
                                        };

                                        yield Ok(crate::streaming::RawStreamingChoice::ToolCall {
                                            id,
                                            name,
                                            arguments: arguments_json,
                                            call_id: None
                                        });
                                    }
                                }

                                if let Some(content) = content {
                                    text_response += content;
                                    yield Ok(crate::streaming::RawStreamingChoice::Message(content.clone()));
                                }
                            }
                        }
                    }

                    if let Some(usage) = data.usage {
                        final_usage = usage.clone();
                    }
                }

                Err(crate::http_client::Error::StreamEnded) => break,
                Err(err) => {
                    tracing::error!(?err, "SSE error");
                    yield Err(CompletionError::ResponseError(err.to_string()));
                    break;
                }
            }
        }

        event_source.close();

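        // Flush tool calls that were accumulated across chunks but never emitted inline.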
        let mut tool_calls = Vec::new();
        for (_, (id, name, arguments)) in calls {
            let Ok(arguments_json) = serde_json::from_str::<serde_json::Value>(&arguments) else {
                tracing::debug!("Couldn't parse accumulated tool call args '{}'", arguments);
                continue;
            };

            tool_calls.push(crate::providers::openai::completion::ToolCall {
                id: id.clone(),
                r#type: ToolType::Function,
                function: Function {
                    name: name.clone(),
                    arguments: arguments_json.clone()
                }
            });
            yield Ok(crate::streaming::RawStreamingChoice::ToolCall {
                id,
                name,
                arguments: arguments_json,
                call_id: None,
            });
        }

        let response_message = crate::providers::openai::completion::Message::Assistant {
            content: vec![AssistantContent::Text { text: text_response }],
            refusal: None,
            audio: None,
            name: None,
            tool_calls
        };

        span.record("gen_ai.output.messages", serde_json::to_string(&vec![response_message]).unwrap());
        span.record("gen_ai.usage.input_tokens", final_usage.prompt_tokens);
        span.record("gen_ai.usage.output_tokens", final_usage.total_tokens - final_usage.prompt_tokens);

        yield Ok(crate::streaming::RawStreamingChoice::FinalResponse(
            StreamingCompletionResponse { usage: final_usage.clone() }
        ));
    }.instrument(span);

    Ok(crate::streaming::StreamingCompletionResponse::stream(
        Box::pin(stream),
    ))
}