use bytes::Bytes;
use http::Request;
use std::collections::HashMap;
use tracing::info_span;
use tracing_futures::Instrument;

use super::openai::{CompletionResponse, StreamingToolCall, TranscriptionResponse, Usage};
use crate::client::{
    self, BearerAuth, Capabilities, Capable, DebugExt, Nothing, Provider, ProviderBuilder,
    ProviderClient,
};
use crate::completion::GetTokenUsage;
use crate::http_client::sse::{Event, GenericEventSource};
use crate::http_client::{self, HttpClientExt};
use crate::json_utils::empty_or_none;
use crate::providers::openai::{AssistantContent, Function, ToolType};
use async_stream::stream;
use futures::StreamExt;

use crate::{
    OneOrMany,
    completion::{self, CompletionError, CompletionRequest},
    json_utils,
    message::{self, MessageError},
    providers::openai::ToolDefinition,
    transcription::{self, TranscriptionError},
};
use reqwest::multipart::Part;
use serde::{Deserialize, Serialize};

const GROQ_API_BASE_URL: &str = "https://api.groq.com/openai/v1";

#[derive(Debug, Default, Clone, Copy)]
pub struct GroqExt;
#[derive(Debug, Default, Clone, Copy)]
pub struct GroqBuilder;

type GroqApiKey = BearerAuth;

impl Provider for GroqExt {
    type Builder = GroqBuilder;

    const VERIFY_PATH: &'static str = "/models";

    fn build<H>(
        _: &crate::client::ClientBuilder<
            Self::Builder,
            <Self::Builder as crate::client::ProviderBuilder>::ApiKey,
            H,
        >,
    ) -> http_client::Result<Self> {
        Ok(Self)
    }
}

impl<H> Capabilities<H> for GroqExt {
    type Completion = Capable<CompletionModel<H>>;
    type Embeddings = Nothing;
    type Transcription = Capable<TranscriptionModel<H>>;
    #[cfg(feature = "image")]
    type ImageGeneration = Nothing;

    #[cfg(feature = "audio")]
    type AudioGeneration = Nothing;
}

impl DebugExt for GroqExt {}

impl ProviderBuilder for GroqBuilder {
    type Output = GroqExt;
    type ApiKey = GroqApiKey;

    const BASE_URL: &'static str = GROQ_API_BASE_URL;
}

pub type Client<H = reqwest::Client> = client::Client<GroqExt, H>;
pub type ClientBuilder<H = reqwest::Client> = client::ClientBuilder<GroqBuilder, String, H>;
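// Example (sketch, not compiled here): construct a client from the environment
// and derive models from it. Assumes `GROQ_API_KEY` is set; `from_env` panics
// otherwise.
//
//     let client = Client::from_env();
//     let completion = CompletionModel::new(client.clone(), LLAMA_3_1_8B_INSTANT);
//     let transcription = TranscriptionModel::new(client, WHISPER_LARGE_V3);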
impl ProviderClient for Client {
    type Input = String;

    fn from_env() -> Self {
        let api_key = std::env::var("GROQ_API_KEY").expect("GROQ_API_KEY not set");
        Self::new(&api_key).unwrap()
    }

    fn from_val(input: Self::Input) -> Self {
        Self::new(&input).unwrap()
    }
}

#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    message: String,
}

#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    Ok(T),
    Err(ApiErrorResponse),
}
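// Groq's chat message shape. It mirrors the OpenAI schema but adds an optional
// `reasoning` field, populated when the request sets `reasoning_format: "parsed"`.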
#[derive(Debug, Serialize, Deserialize)]
pub struct Message {
    pub role: String,
    pub content: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub reasoning: Option<String>,
}

impl Message {
    fn system(preamble: &str) -> Self {
        Self {
            role: "system".to_string(),
            content: Some(preamble.to_string()),
            reasoning: None,
        }
    }
}

impl TryFrom<Message> for message::Message {
    type Error = message::MessageError;

    fn try_from(message: Message) -> Result<Self, Self::Error> {
        match message.role.as_str() {
            "user" => Ok(Self::User {
                content: OneOrMany::one(
                    message
                        .content
                        .map(|content| message::UserContent::text(&content))
                        .ok_or_else(|| {
                            message::MessageError::ConversionError("Empty user message".to_string())
                        })?,
                ),
            }),
            "assistant" => Ok(Self::Assistant {
                id: None,
                content: OneOrMany::one(
                    message
                        .content
                        .map(|content| message::AssistantContent::text(&content))
                        .ok_or_else(|| {
                            message::MessageError::ConversionError(
                                "Empty assistant message".to_string(),
                            )
                        })?,
                ),
            }),
            _ => Err(message::MessageError::ConversionError(format!(
                "Unknown role: {}",
                message.role
            ))),
        }
    }
}
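// The reverse conversion is lossy: multiple text parts are joined with
// newlines, only a single reasoning chunk is kept, and tool calls and
// images are rejected outright.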
impl TryFrom<message::Message> for Message {
    type Error = message::MessageError;

    fn try_from(message: message::Message) -> Result<Self, Self::Error> {
        match message {
            message::Message::User { content } => Ok(Self {
                role: "user".to_string(),
                content: content.iter().find_map(|c| match c {
                    message::UserContent::Text(text) => Some(text.text.clone()),
                    _ => None,
                }),
                reasoning: None,
            }),
            message::Message::Assistant { content, .. } => {
                let mut text_content: Option<String> = None;
                let mut groq_reasoning: Option<String> = None;

                for c in content.iter() {
                    match c {
                        message::AssistantContent::Text(text) => {
                            text_content = Some(
                                text_content
                                    .map(|mut existing| {
                                        existing.push('\n');
                                        existing.push_str(&text.text);
                                        existing
                                    })
                                    .unwrap_or_else(|| text.text.clone()),
                            );
                        }
                        message::AssistantContent::ToolCall(_tool_call) => {
                            return Err(MessageError::ConversionError(
                                "Tool calls cannot be represented in a Groq chat message".into(),
                            ));
                        }
                        message::AssistantContent::Reasoning(message::Reasoning {
                            reasoning,
                            ..
                        }) => {
                            groq_reasoning =
                                Some(reasoning.first().cloned().unwrap_or(String::new()));
                        }
                        message::AssistantContent::Image(_) => {
                            return Err(MessageError::ConversionError(
                                "Groq currently doesn't support images.".into(),
                            ));
                        }
                    }
                }

                Ok(Self {
                    role: "assistant".to_string(),
                    content: text_content,
                    reasoning: groq_reasoning,
                })
            }
        }
    }
}
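// Completion model identifiers exposed by Groq's OpenAI-compatible API.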
pub const DEEPSEEK_R1_DISTILL_LLAMA_70B: &str = "deepseek-r1-distill-llama-70b";
pub const GEMMA2_9B_IT: &str = "gemma2-9b-it";
pub const LLAMA_3_1_8B_INSTANT: &str = "llama-3.1-8b-instant";
pub const LLAMA_3_2_11B_VISION_PREVIEW: &str = "llama-3.2-11b-vision-preview";
pub const LLAMA_3_2_1B_PREVIEW: &str = "llama-3.2-1b-preview";
pub const LLAMA_3_2_3B_PREVIEW: &str = "llama-3.2-3b-preview";
pub const LLAMA_3_2_90B_VISION_PREVIEW: &str = "llama-3.2-90b-vision-preview";
pub const LLAMA_3_2_70B_SPECDEC: &str = "llama-3.2-70b-specdec";
pub const LLAMA_3_2_70B_VERSATILE: &str = "llama-3.2-70b-versatile";
pub const LLAMA_GUARD_3_8B: &str = "llama-guard-3-8b";
pub const LLAMA_3_70B_8192: &str = "llama3-70b-8192";
pub const LLAMA_3_8B_8192: &str = "llama3-8b-8192";
pub const MIXTRAL_8X7B_32768: &str = "mixtral-8x7b-32768";
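// Groq's `reasoning_format` option. This module always requests `"parsed"`,
// which asks Groq to return a reasoning model's chain of thought in the
// dedicated `reasoning` field rather than inline in the message content.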
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ReasoningFormat {
    Parsed,
}

#[derive(Debug, Serialize, Deserialize)]
pub(super) struct GroqCompletionRequest {
    model: String,
    pub messages: Vec<Message>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f64>,
    #[serde(skip_serializing_if = "Vec::is_empty")]
    tools: Vec<ToolDefinition>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_choice: Option<crate::providers::openai::completion::ToolChoice>,
    #[serde(flatten, skip_serializing_if = "Option::is_none")]
    pub additional_params: Option<serde_json::Value>,
    reasoning_format: ReasoningFormat,
}
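// A request body serializes roughly as (sketch):
//
//     {"model":"llama-3.1-8b-instant","messages":[...],"reasoning_format":"parsed"}
//
// `additional_params` is flattened into the top-level object, and empty
// `tools` are omitted entirely.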

impl TryFrom<(&str, CompletionRequest)> for GroqCompletionRequest {
    type Error = CompletionError;

    fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
        let mut partial_history = vec![];
        if let Some(docs) = req.normalized_documents() {
            partial_history.push(docs);
        }
        partial_history.extend(req.chat_history);

        let mut full_history: Vec<Message> = match &req.preamble {
            Some(preamble) => vec![Message::system(preamble)],
            None => vec![],
        };

        full_history.extend(
            partial_history
                .into_iter()
                .map(message::Message::try_into)
                .collect::<Result<Vec<Message>, _>>()?,
        );

        let tool_choice = req
            .tool_choice
            .clone()
            .map(crate::providers::openai::ToolChoice::try_from)
            .transpose()?;

        Ok(Self {
            model: model.to_string(),
            messages: full_history,
            temperature: req.temperature,
            tools: req
                .tools
                .clone()
                .into_iter()
                .map(ToolDefinition::from)
                .collect::<Vec<_>>(),
            tool_choice,
            additional_params: req.additional_params,
            reasoning_format: ReasoningFormat::Parsed,
        })
    }
}

#[derive(Clone, Debug)]
pub struct CompletionModel<T = reqwest::Client> {
    client: Client<T>,
    pub model: String,
}

impl<T> CompletionModel<T> {
    pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
        Self {
            client,
            model: model.into(),
        }
    }
}

impl<T> completion::CompletionModel for CompletionModel<T>
where
    T: HttpClientExt + Clone + Send + std::fmt::Debug + Default + 'static,
{
    type Response = CompletionResponse;
    type StreamingResponse = StreamingCompletionResponse;

    type Client = Client<T>;

    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
        Self::new(client.clone(), model)
    }

    #[cfg_attr(feature = "worker", worker::send)]
    async fn completion(
        &self,
        completion_request: CompletionRequest,
    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
        let preamble = completion_request.preamble.clone();

        let request = GroqCompletionRequest::try_from((self.model.as_ref(), completion_request))?;
        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat",
                gen_ai.operation.name = "chat",
                gen_ai.provider.name = "groq",
                gen_ai.request.model = self.model,
                gen_ai.system_instructions = preamble,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.input.messages = serde_json::to_string(&request.messages)?,
                gen_ai.output.messages = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        let body = serde_json::to_vec(&request)?;
        let req = self
            .client
            .post("/chat/completions")?
            .body(body)
            .map_err(|e| http_client::Error::Instance(e.into()))?;

        let async_block = async move {
            let response = self.client.send::<_, Bytes>(req).await?;
            let status = response.status();
            let response_body = response.into_body().into_future().await?.to_vec();

            if status.is_success() {
                match serde_json::from_slice::<ApiResponse<CompletionResponse>>(&response_body)? {
                    ApiResponse::Ok(response) => {
                        let span = tracing::Span::current();
                        span.record("gen_ai.response.id", response.id.clone());
                        span.record("gen_ai.response.model", response.model.clone());
                        span.record(
                            "gen_ai.output.messages",
                            serde_json::to_string(&response.choices)?,
                        );
                        if let Some(ref usage) = response.usage {
                            span.record("gen_ai.usage.input_tokens", usage.prompt_tokens);
                            span.record(
                                "gen_ai.usage.output_tokens",
                                usage.total_tokens - usage.prompt_tokens,
                            );
                        }
                        response.try_into()
                    }
                    ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
                }
            } else {
                Err(CompletionError::ProviderError(
                    String::from_utf8_lossy(&response_body).to_string(),
                ))
            }
        };

        tracing::Instrument::instrument(async_block, span).await
    }

    #[cfg_attr(feature = "worker", worker::send)]
    async fn stream(
        &self,
        request: CompletionRequest,
    ) -> Result<
        crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
        CompletionError,
    > {
        let preamble = request.preamble.clone();
        let mut request = GroqCompletionRequest::try_from((self.model.as_ref(), request))?;

        let params = json_utils::merge(
            request.additional_params.unwrap_or(serde_json::json!({})),
            serde_json::json!({"stream": true, "stream_options": {"include_usage": true} }),
        );

        request.additional_params = Some(params);

        let body = serde_json::to_vec(&request)?;
        let req = self
            .client
            .post("/chat/completions")?
            .body(body)
            .map_err(|e| http_client::Error::Instance(e.into()))?;

        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat_streaming",
                gen_ai.operation.name = "chat_streaming",
                gen_ai.provider.name = "groq",
                gen_ai.request.model = self.model,
                gen_ai.system_instructions = preamble,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.input.messages = serde_json::to_string(&request.messages)?,
                gen_ai.output.messages = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        tracing::Instrument::instrument(
            send_compatible_streaming_request(self.client.http_client().clone(), req),
            span,
        )
        .await
    }
}
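// Transcription (Whisper) model identifiers exposed by Groq's OpenAI-compatible API.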
pub const WHISPER_LARGE_V3: &str = "whisper-large-v3";
pub const WHISPER_LARGE_V3_TURBO: &str = "whisper-large-v3-turbo";
pub const DISTIL_WHISPER_LARGE_V3_EN: &str = "distil-whisper-large-v3-en";

#[derive(Clone)]
pub struct TranscriptionModel<T> {
    client: Client<T>,
    pub model: String,
}

impl<T> TranscriptionModel<T> {
    pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
        Self {
            client,
            model: model.into(),
        }
    }
}

impl<T> transcription::TranscriptionModel for TranscriptionModel<T>
where
    T: HttpClientExt + Clone + Send + std::fmt::Debug + Default + 'static,
{
    type Response = TranscriptionResponse;

    type Client = Client<T>;

    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
        Self::new(client.clone(), model)
    }

    #[cfg_attr(feature = "worker", worker::send)]
    async fn transcription(
        &self,
        request: transcription::TranscriptionRequest,
    ) -> Result<
        transcription::TranscriptionResponse<Self::Response>,
        transcription::TranscriptionError,
    > {
        let data = request.data;

        let mut body = reqwest::multipart::Form::new()
            .text("model", self.model.clone())
            .part(
                "file",
                Part::bytes(data).file_name(request.filename.clone()),
            );

        if let Some(language) = request.language {
            body = body.text("language", language);
        }

        if let Some(prompt) = request.prompt {
            body = body.text("prompt", prompt.clone());
        }

        if let Some(ref temperature) = request.temperature {
            body = body.text("temperature", temperature.to_string());
        }

        if let Some(ref additional_params) = request.additional_params {
            for (key, value) in additional_params
                .as_object()
                .expect("additional parameters to the Groq transcription API should be a map")
            {
                body = body.text(key.to_owned(), value.to_string());
            }
        }

        let req = self
            .client
            .post("/audio/transcriptions")?
            .body(body)
            .map_err(|e| http_client::Error::Instance(e.into()))?;

        let response = self
            .client
            .http_client()
            .send_multipart::<Bytes>(req)
            .await?;

        let status = response.status();
        let response_body = response.into_body().into_future().await?.to_vec();

        if status.is_success() {
            match serde_json::from_slice::<ApiResponse<TranscriptionResponse>>(&response_body)? {
                ApiResponse::Ok(response) => response.try_into(),
                ApiResponse::Err(api_error_response) => Err(TranscriptionError::ProviderError(
                    api_error_response.message,
                )),
            }
        } else {
            Err(TranscriptionError::ProviderError(
                String::from_utf8_lossy(&response_body).to_string(),
            ))
        }
    }
}
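// Note on `#[serde(untagged)]` ordering: `Reasoning` must be tried first. Since
// every field of `MessageContent` has a default, it would otherwise match any
// delta, including ones that only carry a `reasoning` key.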
#[derive(Deserialize, Debug)]
#[serde(untagged)]
enum StreamingDelta {
    Reasoning {
        reasoning: String,
    },
    MessageContent {
        #[serde(default)]
        content: Option<String>,
        #[serde(default, deserialize_with = "json_utils::null_or_vec")]
        tool_calls: Vec<StreamingToolCall>,
    },
}

#[derive(Deserialize, Debug)]
struct StreamingChoice {
    delta: StreamingDelta,
}

#[derive(Deserialize, Debug)]
struct StreamingCompletionChunk {
    choices: Vec<StreamingChoice>,
    usage: Option<Usage>,
}

#[derive(Clone, Deserialize, Serialize, Debug)]
pub struct StreamingCompletionResponse {
    pub usage: Usage,
}
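// Groq's usage payload reports prompt and total tokens only; output tokens are
// derived as `total - prompt`.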
impl GetTokenUsage for StreamingCompletionResponse {
    fn token_usage(&self) -> Option<crate::completion::Usage> {
        let mut usage = crate::completion::Usage::new();

        usage.input_tokens = self.usage.prompt_tokens as u64;
        usage.total_tokens = self.usage.total_tokens as u64;
        usage.output_tokens = self.usage.total_tokens as u64 - self.usage.prompt_tokens as u64;

        Some(usage)
    }
}

pub async fn send_compatible_streaming_request<T>(
    client: T,
    req: Request<Vec<u8>>,
) -> Result<
    crate::streaming::StreamingCompletionResponse<StreamingCompletionResponse>,
    CompletionError,
>
where
    T: HttpClientExt + Clone + 'static,
{
    let span = tracing::Span::current();

    let mut event_source = GenericEventSource::new(client, req);

    let stream = stream! {
        let span = tracing::Span::current();
        let mut final_usage = Usage {
            prompt_tokens: 0,
            total_tokens: 0
        };

        let mut text_response = String::new();
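        // Partial tool calls accumulate here, keyed by stream index:
        // (id, name, arguments-so-far).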
        let mut calls: HashMap<usize, (String, String, String)> = HashMap::new();

        while let Some(event_result) = event_source.next().await {
            match event_result {
                Ok(Event::Open) => {
                    tracing::trace!("SSE connection opened");
                    continue;
                }

                Ok(Event::Message(message)) => {
                    let data_str = message.data.trim();

                    let parsed = serde_json::from_str::<StreamingCompletionChunk>(data_str);
                    let Ok(data) = parsed else {
                        let err = parsed.unwrap_err();
                        tracing::debug!("Couldn't parse SSE payload as StreamingCompletionChunk: {:?}", err);
                        continue;
                    };

                    if let Some(choice) = data.choices.first() {
                        match &choice.delta {
                            StreamingDelta::Reasoning { reasoning } => {
                                yield Ok(crate::streaming::RawStreamingChoice::Reasoning {
                                    id: None,
                                    reasoning: reasoning.to_string(),
                                    signature: None,
                                });
                            }

                            StreamingDelta::MessageContent { content, tool_calls } => {
                                for tool_call in tool_calls {
                                    let function = &tool_call.function;
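                                    // Tool calls can arrive in fragments: one chunk opens the
                                    // call with an id and name, later chunks append argument
                                    // text, and some chunks carry a complete call in one piece.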
                                    if function.name.as_ref().map(|s| !s.is_empty()).unwrap_or(false)
                                        && empty_or_none(&function.arguments)
                                    {
                                        let id = tool_call.id.clone().unwrap_or_default();
                                        let name = function.name.clone().unwrap();
                                        calls.insert(tool_call.index, (id, name, String::new()));
                                    }
                                    else if function.name.as_ref().map(|s| s.is_empty()).unwrap_or(true)
                                        && let Some(arguments) = &function.arguments
                                        && !arguments.is_empty()
                                    {
                                        if let Some((id, name, existing_args)) = calls.get(&tool_call.index) {
                                            let combined = format!("{}{}", existing_args, arguments);
                                            calls.insert(tool_call.index, (id.clone(), name.clone(), combined));
                                        } else {
                                            tracing::debug!("Partial tool call received but tool call was never started.");
                                        }
                                    }
                                    else {
                                        let id = tool_call.id.clone().unwrap_or_default();
                                        let name = function.name.clone().unwrap_or_default();
                                        let arguments_str = function.arguments.clone().unwrap_or_default();

                                        let Ok(arguments_json) = serde_json::from_str::<serde_json::Value>(&arguments_str) else {
                                            tracing::debug!("Couldn't parse tool call args '{}'", arguments_str);
                                            continue;
                                        };

                                        yield Ok(crate::streaming::RawStreamingChoice::ToolCall {
                                            id,
                                            name,
                                            arguments: arguments_json,
                                            call_id: None
                                        });
                                    }
                                }

                                if let Some(content) = content {
                                    text_response += content;
                                    yield Ok(crate::streaming::RawStreamingChoice::Message(content.clone()));
                                }
                            }
                        }
                    }

                    if let Some(usage) = data.usage {
                        final_usage = usage.clone();
                    }
                }

                Err(crate::http_client::Error::StreamEnded) => break,
                Err(err) => {
                    tracing::error!(?err, "SSE error");
                    yield Err(CompletionError::ResponseError(err.to_string()));
                    break;
                }
            }
        }

        event_source.close();
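        // Flush any accumulated tool calls that were never emitted as a single
        // complete chunk; entries whose arguments never parsed as JSON are dropped.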
        let mut tool_calls = Vec::new();
        for (_, (id, name, arguments)) in calls {
            let Ok(arguments_json) = serde_json::from_str::<serde_json::Value>(&arguments) else {
                continue;
            };

            tool_calls.push(crate::providers::openai::completion::ToolCall {
                id: id.clone(),
                r#type: ToolType::Function,
                function: Function {
                    name: name.clone(),
                    arguments: arguments_json.clone()
                }
            });
            yield Ok(crate::streaming::RawStreamingChoice::ToolCall {
                id,
                name,
                arguments: arguments_json,
                call_id: None,
            });
        }

        let response_message = crate::providers::openai::completion::Message::Assistant {
            content: vec![AssistantContent::Text { text: text_response }],
            refusal: None,
            audio: None,
            name: None,
            tool_calls
        };

        span.record("gen_ai.output.messages", serde_json::to_string(&vec![response_message]).unwrap());
        span.record("gen_ai.usage.input_tokens", final_usage.prompt_tokens);
        span.record("gen_ai.usage.output_tokens", final_usage.total_tokens - final_usage.prompt_tokens);

        yield Ok(crate::streaming::RawStreamingChoice::FinalResponse(
            StreamingCompletionResponse { usage: final_usage.clone() }
        ));
    }.instrument(span);

    Ok(crate::streaming::StreamingCompletionResponse::stream(
        Box::pin(stream),
    ))
}