use bytes::Bytes;
use http::Request;
use serde_json::Map;
use std::collections::HashMap;
use tracing::info_span;
use tracing_futures::Instrument;

use super::openai::{
    CompletionResponse, Message as OpenAIMessage, StreamingToolCall, TranscriptionResponse, Usage,
};
use crate::client::{
    self, BearerAuth, Capabilities, Capable, DebugExt, Nothing, Provider, ProviderBuilder,
    ProviderClient,
};
use crate::completion::GetTokenUsage;
use crate::http_client::multipart::Part;
use crate::http_client::sse::{Event, GenericEventSource};
use crate::http_client::{self, HttpClientExt, MultipartForm};
use crate::json_utils::empty_or_none;
use crate::providers::openai::{AssistantContent, Function, ToolType};
use async_stream::stream;
use futures::StreamExt;

use crate::{
    completion::{self, CompletionError, CompletionRequest},
    json_utils,
    message::{self},
    providers::openai::ToolDefinition,
    transcription::{self, TranscriptionError},
};
use serde::{Deserialize, Serialize};

const GROQ_API_BASE_URL: &str = "https://api.groq.com/openai/v1";

#[derive(Debug, Default, Clone, Copy)]
pub struct GroqExt;
#[derive(Debug, Default, Clone, Copy)]
pub struct GroqBuilder;

type GroqApiKey = BearerAuth;

impl Provider for GroqExt {
    type Builder = GroqBuilder;

    const VERIFY_PATH: &'static str = "/models";

    fn build<H>(
        _: &crate::client::ClientBuilder<
            Self::Builder,
            <Self::Builder as crate::client::ProviderBuilder>::ApiKey,
            H,
        >,
    ) -> http_client::Result<Self> {
        Ok(Self)
    }
}

impl<H> Capabilities<H> for GroqExt {
    type Completion = Capable<CompletionModel<H>>;
    type Embeddings = Nothing;
    type Transcription = Capable<TranscriptionModel<H>>;
    type ModelListing = Nothing;
    #[cfg(feature = "image")]
    type ImageGeneration = Nothing;

    #[cfg(feature = "audio")]
    type AudioGeneration = Nothing;
}

impl DebugExt for GroqExt {}

impl ProviderBuilder for GroqBuilder {
    type Output = GroqExt;
    type ApiKey = GroqApiKey;

    const BASE_URL: &'static str = GROQ_API_BASE_URL;
}

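/// Groq client over the shared provider infrastructure. A minimal construction
/// sketch using only constructors exercised in this module's tests; the API key
/// value is a placeholder:
///
/// ```ignore
/// use rig::client::ProviderClient;
/// use rig::providers::groq::Client;
///
/// // From an explicit API key:
/// let client = Client::new("your-groq-api-key").unwrap();
///
/// // Via the builder:
/// let client = Client::builder().api_key("your-groq-api-key").build().unwrap();
///
/// // From the GROQ_API_KEY environment variable:
/// let client = Client::from_env();
/// ```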
pub type Client<H = reqwest::Client> = client::Client<GroqExt, H>;
pub type ClientBuilder<H = reqwest::Client> = client::ClientBuilder<GroqBuilder, String, H>;

impl ProviderClient for Client {
    type Input = String;

    fn from_env() -> Self {
        let api_key = std::env::var("GROQ_API_KEY").expect("GROQ_API_KEY not set");
        Self::new(&api_key).unwrap()
    }

    fn from_val(input: Self::Input) -> Self {
        Self::new(&input).unwrap()
    }
}

#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    message: String,
}

#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    Ok(T),
    Err(ApiErrorResponse),
}

pub const DEEPSEEK_R1_DISTILL_LLAMA_70B: &str = "deepseek-r1-distill-llama-70b";
pub const GEMMA2_9B_IT: &str = "gemma2-9b-it";
pub const LLAMA_3_1_8B_INSTANT: &str = "llama-3.1-8b-instant";
pub const LLAMA_3_2_11B_VISION_PREVIEW: &str = "llama-3.2-11b-vision-preview";
pub const LLAMA_3_2_1B_PREVIEW: &str = "llama-3.2-1b-preview";
pub const LLAMA_3_2_3B_PREVIEW: &str = "llama-3.2-3b-preview";
pub const LLAMA_3_2_90B_VISION_PREVIEW: &str = "llama-3.2-90b-vision-preview";
pub const LLAMA_3_3_70B_SPECDEC: &str = "llama-3.3-70b-specdec";
pub const LLAMA_3_3_70B_VERSATILE: &str = "llama-3.3-70b-versatile";
pub const LLAMA_GUARD_3_8B: &str = "llama-guard-3-8b";
pub const LLAMA_3_70B_8192: &str = "llama3-70b-8192";
pub const LLAMA_3_8B_8192: &str = "llama3-8b-8192";
pub const MIXTRAL_8X7B_32768: &str = "mixtral-8x7b-32768";

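/// How Groq should surface model reasoning in the response. Serialized in
/// lowercase (`"parsed"`, `"raw"`, `"hidden"`) to match Groq's `reasoning_format`
/// parameter.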
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ReasoningFormat {
    Parsed,
    Raw,
    Hidden,
}

#[derive(Debug, Serialize, Deserialize)]
pub(super) struct GroqCompletionRequest {
    model: String,
    pub messages: Vec<OpenAIMessage>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f64>,
    #[serde(skip_serializing_if = "Vec::is_empty")]
    tools: Vec<ToolDefinition>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_choice: Option<crate::providers::openai::completion::ToolChoice>,
    #[serde(flatten, skip_serializing_if = "Option::is_none")]
    pub additional_params: Option<GroqAdditionalParameters>,
    pub(super) stream: bool,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub(super) stream_options: Option<StreamOptions>,
}

#[derive(Debug, Serialize, Deserialize, Default)]
pub(super) struct StreamOptions {
    pub(super) include_usage: bool,
}

impl TryFrom<(&str, CompletionRequest)> for GroqCompletionRequest {
    type Error = CompletionError;

    fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
        if req.output_schema.is_some() {
            tracing::warn!("Structured outputs are currently not supported by Groq");
        }
        let model = req.model.clone().unwrap_or_else(|| model.to_string());
        let mut partial_history = vec![];
        if let Some(docs) = req.normalized_documents() {
            partial_history.push(docs);
        }
        partial_history.extend(req.chat_history);

        // Start with the system preamble (if any), then convert the rest of the
        // chat history into OpenAI-style messages.
        let mut full_history: Vec<OpenAIMessage> = match &req.preamble {
            Some(preamble) => vec![OpenAIMessage::system(preamble)],
            None => vec![],
        };

        full_history.extend(
            partial_history
                .into_iter()
                .map(message::Message::try_into)
                .collect::<Result<Vec<Vec<OpenAIMessage>>, _>>()?
                .into_iter()
                .flatten()
                .collect::<Vec<_>>(),
        );

        let tool_choice = req
            .tool_choice
            .clone()
            .map(crate::providers::openai::ToolChoice::try_from)
            .transpose()?;

        let additional_params: Option<GroqAdditionalParameters> =
            if let Some(params) = req.additional_params {
                Some(serde_json::from_value(params)?)
            } else {
                None
            };

        Ok(Self {
            model,
            messages: full_history,
            temperature: req.temperature,
            tools: req
                .tools
                .clone()
                .into_iter()
                .map(ToolDefinition::from)
                .collect::<Vec<_>>(),
            tool_choice,
            additional_params,
            stream: false,
            stream_options: None,
        })
    }
}

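/// Groq-specific request parameters, flattened into the top level of the request
/// body. A sketch of constructing them; the `serde_json::to_value` round-trip
/// mirrors how `additional_params` is deserialized in the `TryFrom` above, and
/// the field values are illustrative:
///
/// ```ignore
/// use rig::providers::groq::{GroqAdditionalParameters, ReasoningFormat};
///
/// let params = GroqAdditionalParameters {
///     reasoning_format: Some(ReasoningFormat::Parsed),
///     include_reasoning: Some(true),
///     extra: None,
/// };
///
/// // Flattens to: {"reasoning_format": "parsed", "include_reasoning": true}
/// let json = serde_json::to_value(&params).unwrap();
/// ```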
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub struct GroqAdditionalParameters {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub reasoning_format: Option<ReasoningFormat>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub include_reasoning: Option<bool>,
    #[serde(flatten, skip_serializing_if = "Option::is_none")]
    pub extra: Option<Map<String, serde_json::Value>>,
}

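/// Groq chat-completions model. A minimal construction sketch using only items
/// defined in this module; `completion` and `stream` (from the
/// `completion::CompletionModel` trait implemented below) then take a
/// `completion::CompletionRequest`:
///
/// ```ignore
/// use rig::providers::groq::{Client, CompletionModel, GEMMA2_9B_IT};
///
/// let client = Client::new("your-groq-api-key").unwrap();
/// let model = CompletionModel::new(client, GEMMA2_9B_IT);
/// // let response = model.completion(request).await?;
/// // let stream = model.stream(request).await?;
/// ```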
#[derive(Clone, Debug)]
pub struct CompletionModel<T = reqwest::Client> {
    client: Client<T>,
    pub model: String,
}

impl<T> CompletionModel<T> {
    pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
        Self {
            client,
            model: model.into(),
        }
    }
}

impl<T> completion::CompletionModel for CompletionModel<T>
where
    T: HttpClientExt + Clone + Send + std::fmt::Debug + Default + 'static,
{
    type Response = CompletionResponse;
    type StreamingResponse = StreamingCompletionResponse;

    type Client = Client<T>;

    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
        Self::new(client.clone(), model)
    }

    async fn completion(
        &self,
        completion_request: CompletionRequest,
    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat",
                gen_ai.operation.name = "chat",
                gen_ai.provider.name = "groq",
                gen_ai.request.model = self.model,
                gen_ai.system_instructions = tracing::field::Empty,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        span.record("gen_ai.system_instructions", &completion_request.preamble);

        let request = GroqCompletionRequest::try_from((self.model.as_ref(), completion_request))?;

        if tracing::enabled!(tracing::Level::TRACE) {
            tracing::trace!(target: "rig::completions",
                "Groq completion request: {}",
                serde_json::to_string_pretty(&request)?
            );
        }

        let body = serde_json::to_vec(&request)?;
        let req = self
            .client
            .post("/chat/completions")?
            .body(body)
            .map_err(|e| http_client::Error::Instance(e.into()))?;

        let async_block = async move {
            let response = self.client.send::<_, Bytes>(req).await?;
            let status = response.status();
            let response_body = response.into_body().into_future().await?.to_vec();

            if status.is_success() {
                match serde_json::from_slice::<ApiResponse<CompletionResponse>>(&response_body)? {
                    ApiResponse::Ok(response) => {
                        let span = tracing::Span::current();
                        span.record("gen_ai.response.id", response.id.clone());
                        // Record against the `gen_ai.response.model` field declared on
                        // the span above; a mismatched name would be silently dropped.
                        span.record("gen_ai.response.model", response.model.clone());
                        if let Some(ref usage) = response.usage {
                            span.record("gen_ai.usage.input_tokens", usage.prompt_tokens);
                            span.record(
                                "gen_ai.usage.output_tokens",
                                usage.total_tokens - usage.prompt_tokens,
                            );
                        }

                        if tracing::enabled!(tracing::Level::TRACE) {
                            tracing::trace!(target: "rig::completions",
                                "Groq completion response: {}",
                                serde_json::to_string_pretty(&response)?
                            );
                        }

                        response.try_into()
                    }
                    ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
                }
            } else {
                Err(CompletionError::ProviderError(
                    String::from_utf8_lossy(&response_body).to_string(),
                ))
            }
        };

        tracing::Instrument::instrument(async_block, span).await
    }

    async fn stream(
        &self,
        request: CompletionRequest,
    ) -> Result<
        crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
        CompletionError,
    > {
        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat_streaming",
                gen_ai.operation.name = "chat_streaming",
                gen_ai.provider.name = "groq",
                gen_ai.request.model = self.model,
                gen_ai.system_instructions = tracing::field::Empty,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        span.record("gen_ai.system_instructions", &request.preamble);

        let mut request = GroqCompletionRequest::try_from((self.model.as_ref(), request))?;

        request.stream = true;
        request.stream_options = Some(StreamOptions {
            include_usage: true,
        });

        if tracing::enabled!(tracing::Level::TRACE) {
            tracing::trace!(target: "rig::completions",
                "Groq streaming completion request: {}",
                serde_json::to_string_pretty(&request)?
            );
        }

        let body = serde_json::to_vec(&request)?;
        let req = self
            .client
            .post("/chat/completions")?
            .body(body)
            .map_err(|e| http_client::Error::Instance(e.into()))?;

        tracing::Instrument::instrument(
            send_compatible_streaming_request(self.client.clone(), req),
            span,
        )
        .await
    }
}

pub const WHISPER_LARGE_V3: &str = "whisper-large-v3";
pub const WHISPER_LARGE_V3_TURBO: &str = "whisper-large-v3-turbo";
pub const DISTIL_WHISPER_LARGE_V3_EN: &str = "distil-whisper-large-v3-en";

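/// Groq audio-transcription model. A minimal construction sketch using only
/// constructors defined in this module; `transcription` (from the
/// `transcription::TranscriptionModel` trait implemented below) then takes a
/// `transcription::TranscriptionRequest` carrying the audio bytes, filename,
/// and optional language/prompt/temperature:
///
/// ```ignore
/// use rig::providers::groq::{Client, TranscriptionModel, WHISPER_LARGE_V3};
///
/// let client = Client::new("your-groq-api-key").unwrap();
/// let model = TranscriptionModel::new(client, WHISPER_LARGE_V3);
/// // let response = model.transcription(request).await?;
/// ```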
#[derive(Clone)]
pub struct TranscriptionModel<T> {
    client: Client<T>,
    pub model: String,
}

impl<T> TranscriptionModel<T> {
    pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
        Self {
            client,
            model: model.into(),
        }
    }
}

impl<T> transcription::TranscriptionModel for TranscriptionModel<T>
where
    T: HttpClientExt + Clone + Send + std::fmt::Debug + Default + 'static,
{
    type Response = TranscriptionResponse;

    type Client = Client<T>;

    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
        Self::new(client.clone(), model)
    }

    async fn transcription(
        &self,
        request: transcription::TranscriptionRequest,
    ) -> Result<
        transcription::TranscriptionResponse<Self::Response>,
        transcription::TranscriptionError,
    > {
        let data = request.data;

        let mut body = MultipartForm::new()
            .text("model", self.model.clone())
            .part(Part::bytes("file", data).filename(request.filename.clone()));

        if let Some(language) = request.language {
            body = body.text("language", language);
        }

        if let Some(prompt) = request.prompt {
            body = body.text("prompt", prompt.clone());
        }

        if let Some(ref temperature) = request.temperature {
            body = body.text("temperature", temperature.to_string());
        }

        if let Some(ref additional_params) = request.additional_params {
            for (key, value) in additional_params
                .as_object()
                .expect("Additional Parameters to Groq Transcription should be a map")
            {
                body = body.text(key.to_owned(), value.to_string());
            }
        }

        // Propagate request-building and transport errors instead of panicking.
        let req = self
            .client
            .post("/audio/transcriptions")?
            .body(body)
            .map_err(|e| http_client::Error::Instance(e.into()))?;

        let response = self.client.send_multipart::<Bytes>(req).await?;

        let status = response.status();
        let response_body = response.into_body().into_future().await?.to_vec();

        if status.is_success() {
            match serde_json::from_slice::<ApiResponse<TranscriptionResponse>>(&response_body)? {
                ApiResponse::Ok(response) => response.try_into(),
                ApiResponse::Err(api_error_response) => Err(TranscriptionError::ProviderError(
                    api_error_response.message,
                )),
            }
        } else {
            Err(TranscriptionError::ProviderError(
                String::from_utf8_lossy(&response_body).to_string(),
            ))
        }
    }
}

517
518#[derive(Deserialize, Debug)]
519#[serde(untagged)]
520enum StreamingDelta {
521 Reasoning {
522 reasoning: String,
523 },
524 MessageContent {
525 #[serde(default)]
526 content: Option<String>,
527 #[serde(default, deserialize_with = "json_utils::null_or_vec")]
528 tool_calls: Vec<StreamingToolCall>,
529 },
530}
531
532#[derive(Deserialize, Debug)]
533struct StreamingChoice {
534 delta: StreamingDelta,
535}
536
537#[derive(Deserialize, Debug)]
538struct StreamingCompletionChunk {
539 choices: Vec<StreamingChoice>,
540 usage: Option<Usage>,
541}
542
543#[derive(Clone, Deserialize, Serialize, Debug)]
544pub struct StreamingCompletionResponse {
545 pub usage: Usage,
546}
547
548impl GetTokenUsage for StreamingCompletionResponse {
549 fn token_usage(&self) -> Option<crate::completion::Usage> {
550 let mut usage = crate::completion::Usage::new();
551
552 usage.input_tokens = self.usage.prompt_tokens as u64;
553 usage.total_tokens = self.usage.total_tokens as u64;
554 usage.output_tokens = self.usage.total_tokens as u64 - self.usage.prompt_tokens as u64;
555 usage.cached_input_tokens = self
556 .usage
557 .prompt_tokens_details
558 .as_ref()
559 .map(|d| d.cached_tokens as u64)
560 .unwrap_or(0);
561
562 Some(usage)
563 }
564}
565
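/// Drives an OpenAI-compatible SSE stream and converts each chunk into
/// `RawStreamingChoice` items. Tool calls may arrive fragmented across chunks;
/// a sketch of how the loop below accumulates them:
///
/// ```ignore
/// // name-only fragment -> start buffering under the call's index
/// // args-only fragment -> append to the buffered arguments
/// // complete fragment  -> yield a ToolCall immediately
/// // stream end         -> flush buffered calls, then yield FinalResponse with usage
/// ```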
pub async fn send_compatible_streaming_request<T>(
    client: T,
    req: Request<Vec<u8>>,
) -> Result<
    crate::streaming::StreamingCompletionResponse<StreamingCompletionResponse>,
    CompletionError,
>
where
    T: HttpClientExt + Clone + 'static,
{
    let span = tracing::Span::current();

    let mut event_source = GenericEventSource::new(client, req);

    let stream = stream! {
        let span = tracing::Span::current();
        let mut final_usage = Usage {
            prompt_tokens: 0,
            total_tokens: 0,
            prompt_tokens_details: None,
        };

        let mut text_response = String::new();

        // Partial tool calls keyed by index: (id, name, accumulated JSON arguments).
        let mut calls: HashMap<usize, (String, String, String)> = HashMap::new();

        while let Some(event_result) = event_source.next().await {
            match event_result {
                Ok(Event::Open) => {
                    tracing::trace!("SSE connection opened");
                    continue;
                }

                Ok(Event::Message(message)) => {
                    let data_str = message.data.trim();

                    let parsed = serde_json::from_str::<StreamingCompletionChunk>(data_str);
                    let Ok(data) = parsed else {
                        let err = parsed.unwrap_err();
                        tracing::debug!("Couldn't parse SSE payload as StreamingCompletionChunk: {:?}", err);
                        continue;
                    };

                    if let Some(choice) = data.choices.first() {
                        match &choice.delta {
                            StreamingDelta::Reasoning { reasoning } => {
                                yield Ok(crate::streaming::RawStreamingChoice::ReasoningDelta {
                                    id: None,
                                    reasoning: reasoning.to_string(),
                                });
                            }

                            StreamingDelta::MessageContent { content, tool_calls } => {
                                for tool_call in tool_calls {
                                    let function = &tool_call.function;

                                    // A fragment with a name but no arguments starts a new tool call.
                                    if function.name.as_ref().map(|s| !s.is_empty()).unwrap_or(false)
                                        && empty_or_none(&function.arguments)
                                    {
                                        let id = tool_call.id.clone().unwrap_or_default();
                                        let name = function.name.clone().unwrap();
                                        calls.insert(tool_call.index, (id, name, String::new()));
                                    }
                                    // A fragment with arguments but no name extends an in-flight call.
                                    else if function.name.as_ref().map(|s| s.is_empty()).unwrap_or(true)
                                        && let Some(arguments) = &function.arguments
                                        && !arguments.is_empty()
                                    {
                                        if let Some((id, name, existing_args)) = calls.get(&tool_call.index) {
                                            let combined = format!("{}{}", existing_args, arguments);
                                            calls.insert(tool_call.index, (id.clone(), name.clone(), combined));
                                        } else {
                                            tracing::debug!("Partial tool call received but tool call was never started.");
                                        }
                                    }
                                    // Otherwise the fragment is a complete call; yield it immediately.
                                    else {
                                        let id = tool_call.id.clone().unwrap_or_default();
                                        let name = function.name.clone().unwrap_or_default();
                                        let arguments_str = function.arguments.clone().unwrap_or_default();

                                        let Ok(arguments_json) = serde_json::from_str::<serde_json::Value>(&arguments_str) else {
                                            tracing::debug!("Couldn't parse tool call args '{}'", arguments_str);
                                            continue;
                                        };

                                        yield Ok(crate::streaming::RawStreamingChoice::ToolCall(
                                            crate::streaming::RawStreamingToolCall::new(id, name, arguments_json)
                                        ));
                                    }
                                }

                                if let Some(content) = content {
                                    text_response += content;
                                    yield Ok(crate::streaming::RawStreamingChoice::Message(content.clone()));
                                }
                            }
                        }
                    }

                    if let Some(usage) = data.usage {
                        final_usage = usage.clone();
                    }
                }

                Err(crate::http_client::Error::StreamEnded) => break,
                Err(err) => {
                    tracing::error!(?err, "SSE error");
                    yield Err(CompletionError::ResponseError(err.to_string()));
                    break;
                }
            }
        }

        event_source.close();

        // Flush any tool calls that were still being accumulated when the stream ended.
        let mut tool_calls = Vec::new();
        for (_, (id, name, arguments)) in calls {
            let Ok(arguments_json) = serde_json::from_str::<serde_json::Value>(&arguments) else {
                continue;
            };

            tool_calls.push(crate::providers::openai::completion::ToolCall {
                id: id.clone(),
                r#type: ToolType::Function,
                function: Function {
                    name: name.clone(),
                    arguments: arguments_json.clone()
                }
            });
            yield Ok(crate::streaming::RawStreamingChoice::ToolCall(
                crate::streaming::RawStreamingToolCall::new(id, name, arguments_json)
            ));
        }

        let response_message = crate::providers::openai::completion::Message::Assistant {
            content: vec![AssistantContent::Text { text: text_response }],
            refusal: None,
            audio: None,
            name: None,
            tool_calls
        };

        span.record("gen_ai.output.messages", serde_json::to_string(&vec![response_message]).unwrap());
        span.record("gen_ai.usage.input_tokens", final_usage.prompt_tokens);
        span.record("gen_ai.usage.output_tokens", final_usage.total_tokens - final_usage.prompt_tokens);

        yield Ok(crate::streaming::RawStreamingChoice::FinalResponse(
            StreamingCompletionResponse { usage: final_usage.clone() }
        ));
    }.instrument(span);

    Ok(crate::streaming::StreamingCompletionResponse::stream(
        Box::pin(stream),
    ))
}

#[cfg(test)]
mod tests {
    use crate::{
        OneOrMany,
        providers::{
            groq::{GroqAdditionalParameters, GroqCompletionRequest},
            openai::{Message, UserContent},
        },
    };

    #[test]
    fn serialize_groq_request() {
        let additional_params = GroqAdditionalParameters {
            include_reasoning: Some(true),
            reasoning_format: Some(super::ReasoningFormat::Parsed),
            ..Default::default()
        };

        let groq = GroqCompletionRequest {
            model: "openai/gpt-oss-120b".to_string(),
            temperature: None,
            tool_choice: None,
            stream_options: None,
            tools: Vec::new(),
            messages: vec![Message::User {
                content: OneOrMany::one(UserContent::Text {
                    text: "Hello world!".to_string(),
                }),
                name: None,
            }],
            stream: false,
            additional_params: Some(additional_params),
        };

        let json = serde_json::to_value(&groq).unwrap();

        assert_eq!(
            json,
            serde_json::json!({
                "model": "openai/gpt-oss-120b",
                "messages": [
                    {
                        "role": "user",
                        "content": "Hello world!"
                    }
                ],
                "stream": false,
                "include_reasoning": true,
                "reasoning_format": "parsed"
            })
        )
    }

    #[test]
    fn test_client_initialization() {
        let _client: crate::providers::groq::Client =
            crate::providers::groq::Client::new("dummy-key").expect("Client::new() failed");
        let _client_from_builder: crate::providers::groq::Client =
            crate::providers::groq::Client::builder()
                .api_key("dummy-key")
                .build()
                .expect("Client::builder() failed");
    }
}
790}