1use bytes::Bytes;
12use http::Request;
13use serde_json::Map;
14use std::collections::HashMap;
15use tracing::info_span;
16use tracing_futures::Instrument;
17
18use super::openai::{
19 CompletionResponse, Message as OpenAIMessage, StreamingToolCall, TranscriptionResponse, Usage,
20};
21use crate::client::{
22 self, BearerAuth, Capabilities, Capable, DebugExt, Nothing, Provider, ProviderBuilder,
23 ProviderClient,
24};
25use crate::completion::GetTokenUsage;
26use crate::http_client::multipart::Part;
27use crate::http_client::sse::{Event, GenericEventSource};
28use crate::http_client::{self, HttpClientExt, MultipartForm};
29use crate::json_utils::empty_or_none;
30use crate::providers::openai::{AssistantContent, Function, ToolType};
31use async_stream::stream;
32use futures::StreamExt;
33
34use crate::{
35 completion::{self, CompletionError, CompletionRequest},
36 json_utils,
37 message::{self},
38 providers::openai::ToolDefinition,
39 transcription::{self, TranscriptionError},
40};
41use serde::{Deserialize, Serialize};
42
/// Base URL of Groq's OpenAI-compatible REST API.
const GROQ_API_BASE_URL: &str = "https://api.groq.com/openai/v1";
47
/// Marker type describing the Groq provider's capabilities.
#[derive(Debug, Default, Clone, Copy)]
pub struct GroqExt;
/// Builder marker type used to construct a Groq [`Client`].
#[derive(Debug, Default, Clone, Copy)]
pub struct GroqBuilder;

// Groq authenticates with a standard `Authorization: Bearer <key>` header.
type GroqApiKey = BearerAuth;
54
impl Provider for GroqExt {
    type Builder = GroqBuilder;
    // `/models` is a cheap authenticated endpoint used to verify the API key.
    const VERIFY_PATH: &'static str = "/models";
}
59
// Capability matrix: Groq supports chat completions and audio transcription;
// embeddings, model listing, image and audio generation are not wired up.
impl<H> Capabilities<H> for GroqExt {
    type Completion = Capable<CompletionModel<H>>;
    type Embeddings = Nothing;
    type Transcription = Capable<TranscriptionModel<H>>;
    type ModelListing = Nothing;
    #[cfg(feature = "image")]
    type ImageGeneration = Nothing;

    #[cfg(feature = "audio")]
    type AudioGeneration = Nothing;
}
71
// Marker impl: rely on the default `DebugExt` behavior, no Groq-specific output.
impl DebugExt for GroqExt {}
73
impl ProviderBuilder for GroqBuilder {
    type Extension<H>
        = GroqExt
    where
        H: HttpClientExt;
    type ApiKey = GroqApiKey;

    const BASE_URL: &'static str = GROQ_API_BASE_URL;

    /// Produces the (stateless) Groq extension; the builder itself carries
    /// no Groq-specific configuration, so this cannot fail.
    fn build<H>(
        _builder: &client::ClientBuilder<Self, Self::ApiKey, H>,
    ) -> http_client::Result<Self::Extension<H>>
    where
        H: HttpClientExt,
    {
        Ok(GroqExt)
    }
}
92
/// Groq client; defaults to `reqwest` as the HTTP backend.
pub type Client<H = reqwest::Client> = client::Client<GroqExt, H>;
/// Builder for [`Client`]; the API key is supplied as a `String`.
pub type ClientBuilder<H = reqwest::Client> = client::ClientBuilder<GroqBuilder, String, H>;
95
impl ProviderClient for Client {
    type Input = String;

    /// Creates a client from the `GROQ_API_KEY` environment variable.
    ///
    /// # Panics
    /// Panics if `GROQ_API_KEY` is unset or client construction fails.
    fn from_env() -> Self {
        let api_key = std::env::var("GROQ_API_KEY").expect("GROQ_API_KEY not set");
        Self::new(&api_key).unwrap()
    }

    /// Creates a client from an API key value.
    ///
    /// # Panics
    /// Panics if client construction fails.
    fn from_val(input: Self::Input) -> Self {
        Self::new(&input).unwrap()
    }
}
110
/// Error body returned by the Groq API.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    message: String,
}

/// Either a successful payload of type `T` or a Groq error body.
/// `untagged`: serde tries the `Ok(T)` shape first, then the error shape.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    Ok(T),
    Err(ApiErrorResponse),
}
122
/// The `deepseek-r1-distill-llama-70b` completion model.
pub const DEEPSEEK_R1_DISTILL_LLAMA_70B: &str = "deepseek-r1-distill-llama-70b";
/// The `gemma2-9b-it` completion model.
pub const GEMMA2_9B_IT: &str = "gemma2-9b-it";
/// The `llama-3.1-8b-instant` completion model.
pub const LLAMA_3_1_8B_INSTANT: &str = "llama-3.1-8b-instant";
/// The `llama-3.2-11b-vision-preview` completion model.
pub const LLAMA_3_2_11B_VISION_PREVIEW: &str = "llama-3.2-11b-vision-preview";
/// The `llama-3.2-1b-preview` completion model.
pub const LLAMA_3_2_1B_PREVIEW: &str = "llama-3.2-1b-preview";
/// The `llama-3.2-3b-preview` completion model.
pub const LLAMA_3_2_3B_PREVIEW: &str = "llama-3.2-3b-preview";
/// The `llama-3.2-90b-vision-preview` completion model.
pub const LLAMA_3_2_90B_VISION_PREVIEW: &str = "llama-3.2-90b-vision-preview";
/// The `llama-3.2-70b-specdec` completion model.
pub const LLAMA_3_2_70B_SPECDEC: &str = "llama-3.2-70b-specdec";
/// The `llama-3.2-70b-versatile` completion model.
pub const LLAMA_3_2_70B_VERSATILE: &str = "llama-3.2-70b-versatile";
/// The `llama-guard-3-8b` completion model.
pub const LLAMA_GUARD_3_8B: &str = "llama-guard-3-8b";
/// The `llama3-70b-8192` completion model.
pub const LLAMA_3_70B_8192: &str = "llama3-70b-8192";
/// The `llama3-8b-8192` completion model.
pub const LLAMA_3_8B_8192: &str = "llama3-8b-8192";
/// The `mixtral-8x7b-32768` completion model.
pub const MIXTRAL_8X7B_32768: &str = "mixtral-8x7b-32768";
153
/// How Groq should return model reasoning in responses
/// (serialized lowercase: `"parsed"`, `"raw"`, `"hidden"`).
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ReasoningFormat {
    Parsed,
    Raw,
    Hidden,
}
161
/// Request body for Groq's `/chat/completions` endpoint.
#[derive(Debug, Serialize, Deserialize)]
pub(super) struct GroqCompletionRequest {
    model: String,
    pub messages: Vec<OpenAIMessage>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f64>,
    #[serde(skip_serializing_if = "Vec::is_empty")]
    tools: Vec<ToolDefinition>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_choice: Option<crate::providers::openai::completion::ToolChoice>,
    // Flattened so Groq-specific keys are emitted at the top level of the JSON body.
    #[serde(flatten, skip_serializing_if = "Option::is_none")]
    pub additional_params: Option<GroqAdditionalParameters>,
    pub(super) stream: bool,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub(super) stream_options: Option<StreamOptions>,
}

/// Options sent alongside `stream: true`.
#[derive(Debug, Serialize, Deserialize, Default)]
pub(super) struct StreamOptions {
    // When true, the final SSE chunk carries token usage totals.
    pub(super) include_usage: bool,
}
183
184impl TryFrom<(&str, CompletionRequest)> for GroqCompletionRequest {
185 type Error = CompletionError;
186
187 fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
188 if req.output_schema.is_some() {
189 tracing::warn!("Structured outputs currently not supported for Groq");
190 }
191 let model = req.model.clone().unwrap_or_else(|| model.to_string());
192 let mut partial_history = vec![];
194 if let Some(docs) = req.normalized_documents() {
195 partial_history.push(docs);
196 }
197 partial_history.extend(req.chat_history);
198
199 let mut full_history: Vec<OpenAIMessage> = match &req.preamble {
201 Some(preamble) => vec![OpenAIMessage::system(preamble)],
202 None => vec![],
203 };
204
205 full_history.extend(
207 partial_history
208 .into_iter()
209 .map(message::Message::try_into)
210 .collect::<Result<Vec<Vec<OpenAIMessage>>, _>>()?
211 .into_iter()
212 .flatten()
213 .collect::<Vec<_>>(),
214 );
215
216 let tool_choice = req
217 .tool_choice
218 .clone()
219 .map(crate::providers::openai::ToolChoice::try_from)
220 .transpose()?;
221
222 let additional_params: Option<GroqAdditionalParameters> =
223 if let Some(params) = req.additional_params {
224 Some(serde_json::from_value(params)?)
225 } else {
226 None
227 };
228
229 Ok(Self {
230 model: model.to_string(),
231 messages: full_history,
232 temperature: req.temperature,
233 tools: req
234 .tools
235 .clone()
236 .into_iter()
237 .map(ToolDefinition::from)
238 .collect::<Vec<_>>(),
239 tool_choice,
240 additional_params,
241 stream: false,
242 stream_options: None,
243 })
244 }
245}
246
/// Groq-specific request parameters, flattened into the request body.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub struct GroqAdditionalParameters {
    // Controls how reasoning output is formatted; see [`ReasoningFormat`].
    #[serde(skip_serializing_if = "Option::is_none")]
    pub reasoning_format: Option<ReasoningFormat>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub include_reasoning: Option<bool>,
    // Any further keys are passed through verbatim at the top level.
    #[serde(flatten, skip_serializing_if = "Option::is_none")]
    pub extra: Option<Map<String, serde_json::Value>>,
}
260
/// A Groq completion model bound to a [`Client`].
#[derive(Clone, Debug)]
pub struct CompletionModel<T = reqwest::Client> {
    client: Client<T>,
    // Model name used when the request does not specify one.
    pub model: String,
}

impl<T> CompletionModel<T> {
    /// Creates a completion model handle for `model` on the given client.
    pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
        Self {
            client,
            model: model.into(),
        }
    }
}
276
277impl<T> completion::CompletionModel for CompletionModel<T>
278where
279 T: HttpClientExt + Clone + Send + std::fmt::Debug + Default + 'static,
280{
281 type Response = CompletionResponse;
282 type StreamingResponse = StreamingCompletionResponse;
283
284 type Client = Client<T>;
285
286 fn make(client: &Self::Client, model: impl Into<String>) -> Self {
287 Self::new(client.clone(), model)
288 }
289
290 async fn completion(
291 &self,
292 completion_request: CompletionRequest,
293 ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
294 let span = if tracing::Span::current().is_disabled() {
295 info_span!(
296 target: "rig::completions",
297 "chat",
298 gen_ai.operation.name = "chat",
299 gen_ai.provider.name = "groq",
300 gen_ai.request.model = self.model,
301 gen_ai.system_instructions = tracing::field::Empty,
302 gen_ai.response.id = tracing::field::Empty,
303 gen_ai.response.model = tracing::field::Empty,
304 gen_ai.usage.output_tokens = tracing::field::Empty,
305 gen_ai.usage.input_tokens = tracing::field::Empty,
306 )
307 } else {
308 tracing::Span::current()
309 };
310
311 span.record("gen_ai.system_instructions", &completion_request.preamble);
312
313 let request = GroqCompletionRequest::try_from((self.model.as_ref(), completion_request))?;
314
315 if tracing::enabled!(tracing::Level::TRACE) {
316 tracing::trace!(target: "rig::completions",
317 "Groq completion request: {}",
318 serde_json::to_string_pretty(&request)?
319 );
320 }
321
322 let body = serde_json::to_vec(&request)?;
323 let req = self
324 .client
325 .post("/chat/completions")?
326 .body(body)
327 .map_err(|e| http_client::Error::Instance(e.into()))?;
328
329 let async_block = async move {
330 let response = self.client.send::<_, Bytes>(req).await?;
331 let status = response.status();
332 let response_body = response.into_body().into_future().await?.to_vec();
333
334 if status.is_success() {
335 match serde_json::from_slice::<ApiResponse<CompletionResponse>>(&response_body)? {
336 ApiResponse::Ok(response) => {
337 let span = tracing::Span::current();
338 span.record("gen_ai.response.id", response.id.clone());
339 span.record("gen_ai.response.model_name", response.model.clone());
340 if let Some(ref usage) = response.usage {
341 span.record("gen_ai.usage.input_tokens", usage.prompt_tokens);
342 span.record(
343 "gen_ai.usage.output_tokens",
344 usage.total_tokens - usage.prompt_tokens,
345 );
346 }
347
348 if tracing::enabled!(tracing::Level::TRACE) {
349 tracing::trace!(target: "rig::completions",
350 "Groq completion response: {}",
351 serde_json::to_string_pretty(&response)?
352 );
353 }
354
355 response.try_into()
356 }
357 ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
358 }
359 } else {
360 Err(CompletionError::ProviderError(
361 String::from_utf8_lossy(&response_body).to_string(),
362 ))
363 }
364 };
365
366 tracing::Instrument::instrument(async_block, span).await
367 }
368
369 async fn stream(
370 &self,
371 request: CompletionRequest,
372 ) -> Result<
373 crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
374 CompletionError,
375 > {
376 let span = if tracing::Span::current().is_disabled() {
377 info_span!(
378 target: "rig::completions",
379 "chat_streaming",
380 gen_ai.operation.name = "chat_streaming",
381 gen_ai.provider.name = "groq",
382 gen_ai.request.model = self.model,
383 gen_ai.system_instructions = tracing::field::Empty,
384 gen_ai.response.id = tracing::field::Empty,
385 gen_ai.response.model = tracing::field::Empty,
386 gen_ai.usage.output_tokens = tracing::field::Empty,
387 gen_ai.usage.input_tokens = tracing::field::Empty,
388 )
389 } else {
390 tracing::Span::current()
391 };
392
393 span.record("gen_ai.system_instructions", &request.preamble);
394
395 let mut request = GroqCompletionRequest::try_from((self.model.as_ref(), request))?;
396
397 request.stream = true;
398 request.stream_options = Some(StreamOptions {
399 include_usage: true,
400 });
401
402 if tracing::enabled!(tracing::Level::TRACE) {
403 tracing::trace!(target: "rig::completions",
404 "Groq streaming completion request: {}",
405 serde_json::to_string_pretty(&request)?
406 );
407 }
408
409 let body = serde_json::to_vec(&request)?;
410 let req = self
411 .client
412 .post("/chat/completions")?
413 .body(body)
414 .map_err(|e| http_client::Error::Instance(e.into()))?;
415
416 tracing::Instrument::instrument(
417 send_compatible_streaming_request(self.client.clone(), req),
418 span,
419 )
420 .await
421 }
422}
423
/// The `whisper-large-v3` transcription model.
pub const WHISPER_LARGE_V3: &str = "whisper-large-v3";
/// The `whisper-large-v3-turbo` transcription model.
pub const WHISPER_LARGE_V3_TURBO: &str = "whisper-large-v3-turbo";
/// The `distil-whisper-large-v3-en` transcription model.
pub const DISTIL_WHISPER_LARGE_V3_EN: &str = "distil-whisper-large-v3-en";
431
/// A Groq transcription (speech-to-text) model bound to a [`Client`].
#[derive(Clone)]
pub struct TranscriptionModel<T> {
    client: Client<T>,
    // Model name sent with each transcription request.
    pub model: String,
}

impl<T> TranscriptionModel<T> {
    /// Creates a transcription model handle for `model` on the given client.
    pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
        Self {
            client,
            model: model.into(),
        }
    }
}
447impl<T> transcription::TranscriptionModel for TranscriptionModel<T>
448where
449 T: HttpClientExt + Clone + Send + std::fmt::Debug + Default + 'static,
450{
451 type Response = TranscriptionResponse;
452
453 type Client = Client<T>;
454
455 fn make(client: &Self::Client, model: impl Into<String>) -> Self {
456 Self::new(client.clone(), model)
457 }
458
459 async fn transcription(
460 &self,
461 request: transcription::TranscriptionRequest,
462 ) -> Result<
463 transcription::TranscriptionResponse<Self::Response>,
464 transcription::TranscriptionError,
465 > {
466 let data = request.data;
467
468 let mut body = MultipartForm::new()
469 .text("model", self.model.clone())
470 .part(Part::bytes("file", data).filename(request.filename.clone()));
471
472 if let Some(language) = request.language {
473 body = body.text("language", language);
474 }
475
476 if let Some(prompt) = request.prompt {
477 body = body.text("prompt", prompt.clone());
478 }
479
480 if let Some(ref temperature) = request.temperature {
481 body = body.text("temperature", temperature.to_string());
482 }
483
484 if let Some(ref additional_params) = request.additional_params {
485 for (key, value) in additional_params
486 .as_object()
487 .expect("Additional Parameters to OpenAI Transcription should be a map")
488 {
489 body = body.text(key.to_owned(), value.to_string());
490 }
491 }
492
493 let req = self
494 .client
495 .post("/audio/transcriptions")?
496 .body(body)
497 .unwrap();
498
499 let response = self.client.send_multipart::<Bytes>(req).await.unwrap();
500
501 let status = response.status();
502 let response_body = response.into_body().into_future().await?.to_vec();
503
504 if status.is_success() {
505 match serde_json::from_slice::<ApiResponse<TranscriptionResponse>>(&response_body)? {
506 ApiResponse::Ok(response) => response.try_into(),
507 ApiResponse::Err(api_error_response) => Err(TranscriptionError::ProviderError(
508 api_error_response.message,
509 )),
510 }
511 } else {
512 Err(TranscriptionError::ProviderError(
513 String::from_utf8_lossy(&response_body).to_string(),
514 ))
515 }
516 }
517}
518
/// A single `delta` object inside a streaming chunk.
///
/// `untagged`: a delta either carries a `reasoning` string or regular
/// message content / tool-call fragments; serde picks whichever shape
/// deserializes first.
#[derive(Deserialize, Debug)]
#[serde(untagged)]
enum StreamingDelta {
    Reasoning {
        reasoning: String,
    },
    MessageContent {
        #[serde(default)]
        content: Option<String>,
        // `null` and missing both decode to an empty vec.
        #[serde(default, deserialize_with = "json_utils::null_or_vec")]
        tool_calls: Vec<StreamingToolCall>,
    },
}

/// One choice entry of a streaming chunk; only its delta is consumed.
#[derive(Deserialize, Debug)]
struct StreamingChoice {
    delta: StreamingDelta,
}

/// A single SSE chunk from `/chat/completions` with `stream: true`.
#[derive(Deserialize, Debug)]
struct StreamingCompletionChunk {
    choices: Vec<StreamingChoice>,
    // Present only on the final chunk when `include_usage` was requested.
    usage: Option<Usage>,
}

/// Final item yielded at the end of a Groq completion stream,
/// carrying the accumulated token usage.
#[derive(Clone, Deserialize, Serialize, Debug)]
pub struct StreamingCompletionResponse {
    pub usage: Usage,
}
548
549impl GetTokenUsage for StreamingCompletionResponse {
550 fn token_usage(&self) -> Option<crate::completion::Usage> {
551 let mut usage = crate::completion::Usage::new();
552
553 usage.input_tokens = self.usage.prompt_tokens as u64;
554 usage.total_tokens = self.usage.total_tokens as u64;
555 usage.output_tokens = self.usage.total_tokens as u64 - self.usage.prompt_tokens as u64;
556 usage.cached_input_tokens = self
557 .usage
558 .prompt_tokens_details
559 .as_ref()
560 .map(|d| d.cached_tokens as u64)
561 .unwrap_or(0);
562
563 Some(usage)
564 }
565}
566
567pub async fn send_compatible_streaming_request<T>(
568 client: T,
569 req: Request<Vec<u8>>,
570) -> Result<
571 crate::streaming::StreamingCompletionResponse<StreamingCompletionResponse>,
572 CompletionError,
573>
574where
575 T: HttpClientExt + Clone + 'static,
576{
577 let span = tracing::Span::current();
578
579 let mut event_source = GenericEventSource::new(client, req);
580
581 let stream = stream! {
582 let span = tracing::Span::current();
583 let mut final_usage = Usage {
584 prompt_tokens: 0,
585 total_tokens: 0,
586 prompt_tokens_details: None,
587 };
588
589 let mut text_response = String::new();
590
591 let mut calls: HashMap<usize, (String, String, String)> = HashMap::new();
592
593 while let Some(event_result) = event_source.next().await {
594 match event_result {
595 Ok(Event::Open) => {
596 tracing::trace!("SSE connection opened");
597 continue;
598 }
599
600 Ok(Event::Message(message)) => {
601 let data_str = message.data.trim();
602
603 let parsed = serde_json::from_str::<StreamingCompletionChunk>(data_str);
604 let Ok(data) = parsed else {
605 let err = parsed.unwrap_err();
606 tracing::debug!("Couldn't parse SSE payload as StreamingCompletionChunk: {:?}", err);
607 continue;
608 };
609
610 if let Some(choice) = data.choices.first() {
611 match &choice.delta {
612 StreamingDelta::Reasoning { reasoning } => {
613 yield Ok(crate::streaming::RawStreamingChoice::ReasoningDelta {
614 id: None,
615 reasoning: reasoning.to_string(),
616 });
617 }
618
619 StreamingDelta::MessageContent { content, tool_calls } => {
620 for tool_call in tool_calls {
622 let function = &tool_call.function;
623
624 if function.name.as_ref().map(|s| !s.is_empty()).unwrap_or(false)
626 && empty_or_none(&function.arguments)
627 {
628 let id = tool_call.id.clone().unwrap_or_default();
629 let name = function.name.clone().unwrap();
630 calls.insert(tool_call.index, (id, name, String::new()));
631 }
632 else if function.name.as_ref().map(|s| s.is_empty()).unwrap_or(true)
634 && let Some(arguments) = &function.arguments
635 && !arguments.is_empty()
636 {
637 if let Some((id, name, existing_args)) = calls.get(&tool_call.index) {
638 let combined = format!("{}{}", existing_args, arguments);
639 calls.insert(tool_call.index, (id.clone(), name.clone(), combined));
640 } else {
641 tracing::debug!("Partial tool call received but tool call was never started.");
642 }
643 }
644 else {
646 let id = tool_call.id.clone().unwrap_or_default();
647 let name = function.name.clone().unwrap_or_default();
648 let arguments_str = function.arguments.clone().unwrap_or_default();
649
650 let Ok(arguments_json) = serde_json::from_str::<serde_json::Value>(&arguments_str) else {
651 tracing::debug!("Couldn't parse tool call args '{}'", arguments_str);
652 continue;
653 };
654
655 yield Ok(crate::streaming::RawStreamingChoice::ToolCall(
656 crate::streaming::RawStreamingToolCall::new(id, name, arguments_json)
657 ));
658 }
659 }
660
661 if let Some(content) = content {
663 text_response += content;
664 yield Ok(crate::streaming::RawStreamingChoice::Message(content.clone()));
665 }
666 }
667 }
668 }
669
670 if let Some(usage) = data.usage {
671 final_usage = usage.clone();
672 }
673 }
674
675 Err(crate::http_client::Error::StreamEnded) => break,
676 Err(err) => {
677 tracing::error!(?err, "SSE error");
678 yield Err(CompletionError::ResponseError(err.to_string()));
679 break;
680 }
681 }
682 }
683
684 event_source.close();
685
686 let mut tool_calls = Vec::new();
687 for (_, (id, name, arguments)) in calls {
689 let Ok(arguments_json) = serde_json::from_str::<serde_json::Value>(&arguments) else {
690 continue;
691 };
692
693 tool_calls.push(rig::providers::openai::completion::ToolCall {
694 id: id.clone(),
695 r#type: ToolType::Function,
696 function: Function {
697 name: name.clone(),
698 arguments: arguments_json.clone()
699 }
700 });
701 yield Ok(crate::streaming::RawStreamingChoice::ToolCall(
702 crate::streaming::RawStreamingToolCall::new(id, name, arguments_json)
703 ));
704 }
705
706 let response_message = crate::providers::openai::completion::Message::Assistant {
707 content: vec![AssistantContent::Text { text: text_response }],
708 refusal: None,
709 audio: None,
710 name: None,
711 tool_calls
712 };
713
714 span.record("gen_ai.output.messages", serde_json::to_string(&vec![response_message]).unwrap());
715 span.record("gen_ai.usage.input_tokens", final_usage.prompt_tokens);
716 span.record("gen_ai.usage.output_tokens", final_usage.total_tokens - final_usage.prompt_tokens);
717
718 yield Ok(crate::streaming::RawStreamingChoice::FinalResponse(
720 StreamingCompletionResponse { usage: final_usage.clone() }
721 ));
722 }.instrument(span);
723
724 Ok(crate::streaming::StreamingCompletionResponse::stream(
725 Box::pin(stream),
726 ))
727}
728
#[cfg(test)]
mod tests {
    use crate::{
        OneOrMany,
        providers::{
            groq::{GroqAdditionalParameters, GroqCompletionRequest},
            openai::{Message, UserContent},
        },
    };

    // Verifies that Groq-specific additional parameters are flattened to the
    // top level of the serialized request and that skip-if-empty fields
    // (temperature, tools, tool_choice, stream_options) are omitted.
    #[test]
    fn serialize_groq_request() {
        let additional_params = GroqAdditionalParameters {
            include_reasoning: Some(true),
            reasoning_format: Some(super::ReasoningFormat::Parsed),
            ..Default::default()
        };

        let groq = GroqCompletionRequest {
            model: "openai/gpt-120b-oss".to_string(),
            temperature: None,
            tool_choice: None,
            stream_options: None,
            tools: Vec::new(),
            messages: vec![Message::User {
                content: OneOrMany::one(UserContent::Text {
                    text: "Hello world!".to_string(),
                }),
                name: None,
            }],
            stream: false,
            additional_params: Some(additional_params),
        };

        let json = serde_json::to_value(&groq).unwrap();

        assert_eq!(
            json,
            serde_json::json!({
                "model": "openai/gpt-120b-oss",
                "messages": [
                    {
                        "role": "user",
                        "content": "Hello world!"
                    }
                ],
                "stream": false,
                "include_reasoning": true,
                "reasoning_format": "parsed"
            })
        )
    }

    // Smoke test: both construction paths (direct and builder) succeed with
    // a dummy key; no network traffic is performed.
    #[test]
    fn test_client_initialization() {
        let _client =
            crate::providers::groq::Client::new("dummy-key").expect("Client::new() failed");
        let _client_from_builder = crate::providers::groq::Client::builder()
            .api_key("dummy-key")
            .build()
            .expect("Client::builder() failed");
    }
}