use bytes::Bytes;
use http::Request;
use serde_json::Map;
use std::collections::HashMap;
use tracing::info_span;
use tracing_futures::Instrument;

use super::openai::{
    CompletionResponse, Message as OpenAIMessage, StreamingToolCall, TranscriptionResponse, Usage,
};
use crate::client::{
    self, BearerAuth, Capabilities, Capable, DebugExt, Nothing, Provider, ProviderBuilder,
    ProviderClient,
};
use crate::completion::GetTokenUsage;
use crate::http_client::multipart::Part;
use crate::http_client::sse::{Event, GenericEventSource};
use crate::http_client::{self, HttpClientExt, MultipartForm};
use crate::json_utils::empty_or_none;
use crate::providers::openai::{AssistantContent, Function, ToolType};
use async_stream::stream;
use futures::StreamExt;

use crate::{
    completion::{self, CompletionError, CompletionRequest},
    json_utils,
    message::{self},
    providers::openai::ToolDefinition,
    transcription::{self, TranscriptionError},
};
use serde::{Deserialize, Serialize};

const GROQ_API_BASE_URL: &str = "https://api.groq.com/openai/v1";

#[derive(Debug, Default, Clone, Copy)]
pub struct GroqExt;
#[derive(Debug, Default, Clone, Copy)]
pub struct GroqBuilder;

type GroqApiKey = BearerAuth;

impl Provider for GroqExt {
    type Builder = GroqBuilder;

    const VERIFY_PATH: &'static str = "/models";

    fn build<H>(
        _: &crate::client::ClientBuilder<
            Self::Builder,
            <Self::Builder as crate::client::ProviderBuilder>::ApiKey,
            H,
        >,
    ) -> http_client::Result<Self> {
        Ok(Self)
    }
}

impl<H> Capabilities<H> for GroqExt {
    type Completion = Capable<CompletionModel<H>>;
    type Embeddings = Nothing;
    type Transcription = Capable<TranscriptionModel<H>>;
    #[cfg(feature = "image")]
    type ImageGeneration = Nothing;

    #[cfg(feature = "audio")]
    type AudioGeneration = Nothing;
}

impl DebugExt for GroqExt {}

impl ProviderBuilder for GroqBuilder {
    type Output = GroqExt;
    type ApiKey = GroqApiKey;

    const BASE_URL: &'static str = GROQ_API_BASE_URL;
}

pub type Client<H = reqwest::Client> = client::Client<GroqExt, H>;
pub type ClientBuilder<H = reqwest::Client> = client::ClientBuilder<GroqBuilder, String, H>;

impl ProviderClient for Client {
    type Input = String;

    fn from_env() -> Self {
        let api_key = std::env::var("GROQ_API_KEY").expect("GROQ_API_KEY not set");
        Self::new(&api_key).unwrap()
    }

    fn from_val(input: Self::Input) -> Self {
        Self::new(&input).unwrap()
    }
}
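// Usage sketch (hypothetical, for orientation only): build a client from the
// environment, then request a completion model. Assumes `GROQ_API_KEY` is set
// and that the standard `completion_model` accessor is exposed through the
// `Capable` completion capability declared above.
//
//     let client = Client::from_env();
//     let model = client.completion_model(LLAMA_3_1_8B_INSTANT);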

#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    message: String,
}

#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    Ok(T),
    Err(ApiErrorResponse),
}

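// Model identifiers accepted by the Groq chat completions endpoint; pass one
// of these (or any other Groq model id string) when constructing a
// `CompletionModel`.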
pub const DEEPSEEK_R1_DISTILL_LLAMA_70B: &str = "deepseek-r1-distill-llama-70b";
pub const GEMMA2_9B_IT: &str = "gemma2-9b-it";
pub const LLAMA_3_1_8B_INSTANT: &str = "llama-3.1-8b-instant";
pub const LLAMA_3_2_11B_VISION_PREVIEW: &str = "llama-3.2-11b-vision-preview";
pub const LLAMA_3_2_1B_PREVIEW: &str = "llama-3.2-1b-preview";
pub const LLAMA_3_2_3B_PREVIEW: &str = "llama-3.2-3b-preview";
pub const LLAMA_3_2_90B_VISION_PREVIEW: &str = "llama-3.2-90b-vision-preview";
pub const LLAMA_3_2_70B_SPECDEC: &str = "llama-3.2-70b-specdec";
pub const LLAMA_3_2_70B_VERSATILE: &str = "llama-3.2-70b-versatile";
pub const LLAMA_GUARD_3_8B: &str = "llama-guard-3-8b";
pub const LLAMA_3_70B_8192: &str = "llama3-70b-8192";
pub const LLAMA_3_8B_8192: &str = "llama3-8b-8192";
pub const MIXTRAL_8X7B_32768: &str = "mixtral-8x7b-32768";

#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ReasoningFormat {
    Parsed,
    Raw,
    Hidden,
}
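// `ReasoningFormat` serializes to the lowercase strings Groq expects; a quick
// sanity check of the `rename_all = "lowercase"` attribute above:
//
//     assert_eq!(
//         serde_json::to_string(&ReasoningFormat::Parsed).unwrap(),
//         r#""parsed""#
//     );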

#[derive(Debug, Serialize, Deserialize)]
pub(super) struct GroqCompletionRequest {
    model: String,
    pub messages: Vec<OpenAIMessage>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f64>,
    #[serde(skip_serializing_if = "Vec::is_empty")]
    tools: Vec<ToolDefinition>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_choice: Option<crate::providers::openai::completion::ToolChoice>,
    #[serde(flatten, skip_serializing_if = "Option::is_none")]
    pub additional_params: Option<GroqAdditionalParameters>,
    pub(super) stream: bool,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub(super) stream_options: Option<StreamOptions>,
}

#[derive(Debug, Serialize, Deserialize, Default)]
pub(super) struct StreamOptions {
    pub(super) include_usage: bool,
}

impl TryFrom<(&str, CompletionRequest)> for GroqCompletionRequest {
    type Error = CompletionError;

    fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
        let mut partial_history = vec![];
        if let Some(docs) = req.normalized_documents() {
            partial_history.push(docs);
        }
        partial_history.extend(req.chat_history);

        let mut full_history: Vec<OpenAIMessage> = match &req.preamble {
            Some(preamble) => vec![OpenAIMessage::system(preamble)],
            None => vec![],
        };

        full_history.extend(
            partial_history
                .into_iter()
                .map(message::Message::try_into)
                .collect::<Result<Vec<Vec<OpenAIMessage>>, _>>()?
                .into_iter()
                .flatten()
                .collect::<Vec<_>>(),
        );

        let tool_choice = req
            .tool_choice
            .clone()
            .map(crate::providers::openai::ToolChoice::try_from)
            .transpose()?;

        let additional_params: Option<GroqAdditionalParameters> =
            if let Some(params) = req.additional_params {
                Some(serde_json::from_value(params)?)
            } else {
                None
            };

        Ok(Self {
            model: model.to_string(),
            messages: full_history,
            temperature: req.temperature,
            tools: req
                .tools
                .clone()
                .into_iter()
                .map(ToolDefinition::from)
                .collect::<Vec<_>>(),
            tool_choice,
            additional_params,
            stream: false,
            stream_options: None,
        })
    }
}

#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub struct GroqAdditionalParameters {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub reasoning_format: Option<ReasoningFormat>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub include_reasoning: Option<bool>,
    #[serde(flatten, skip_serializing_if = "Option::is_none")]
    pub extra: Option<Map<String, serde_json::Value>>,
}
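// Sketch of supplying these parameters through a completion request's
// `additional_params` JSON blob (deserialized into `GroqAdditionalParameters`
// by the `TryFrom` impl above; unrecognized keys are captured by the flattened
// `extra` map rather than rejected). The `service_tier` key below is a
// hypothetical example of such an extra field:
//
//     let params = serde_json::json!({
//         "reasoning_format": "parsed",
//         "include_reasoning": true,
//         "service_tier": "auto",
//     });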

#[derive(Clone, Debug)]
pub struct CompletionModel<T = reqwest::Client> {
    client: Client<T>,
    pub model: String,
}

impl<T> CompletionModel<T> {
    pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
        Self {
            client,
            model: model.into(),
        }
    }
}

impl<T> completion::CompletionModel for CompletionModel<T>
where
    T: HttpClientExt + Clone + Send + std::fmt::Debug + Default + 'static,
{
    type Response = CompletionResponse;
    type StreamingResponse = StreamingCompletionResponse;

    type Client = Client<T>;

    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
        Self::new(client.clone(), model)
    }

    async fn completion(
        &self,
        completion_request: CompletionRequest,
    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat",
                gen_ai.operation.name = "chat",
                gen_ai.provider.name = "groq",
                gen_ai.request.model = self.model,
                gen_ai.system_instructions = tracing::field::Empty,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        span.record("gen_ai.system_instructions", &completion_request.preamble);

        let request = GroqCompletionRequest::try_from((self.model.as_ref(), completion_request))?;

        if tracing::enabled!(tracing::Level::TRACE) {
            tracing::trace!(target: "rig::completions",
                "Groq completion request: {}",
                serde_json::to_string_pretty(&request)?
            );
        }

        let body = serde_json::to_vec(&request)?;
        let req = self
            .client
            .post("/chat/completions")?
            .body(body)
            .map_err(|e| http_client::Error::Instance(e.into()))?;

        let async_block = async move {
            let response = self.client.send::<_, Bytes>(req).await?;
            let status = response.status();
            let response_body = response.into_body().into_future().await?.to_vec();

            if status.is_success() {
                match serde_json::from_slice::<ApiResponse<CompletionResponse>>(&response_body)? {
                    ApiResponse::Ok(response) => {
                        let span = tracing::Span::current();
                        span.record("gen_ai.response.id", response.id.clone());
                        span.record("gen_ai.response.model", response.model.clone());
                        if let Some(ref usage) = response.usage {
                            span.record("gen_ai.usage.input_tokens", usage.prompt_tokens);
                            span.record(
                                "gen_ai.usage.output_tokens",
                                usage.total_tokens.saturating_sub(usage.prompt_tokens),
                            );
                        }

                        if tracing::enabled!(tracing::Level::TRACE) {
                            tracing::trace!(target: "rig::completions",
                                "Groq completion response: {}",
                                serde_json::to_string_pretty(&response)?
                            );
                        }

                        response.try_into()
                    }
                    ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
                }
            } else {
                Err(CompletionError::ProviderError(
                    String::from_utf8_lossy(&response_body).to_string(),
                ))
            }
        };

        tracing::Instrument::instrument(async_block, span).await
    }

    async fn stream(
        &self,
        request: CompletionRequest,
    ) -> Result<
        crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
        CompletionError,
    > {
        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat_streaming",
                gen_ai.operation.name = "chat_streaming",
                gen_ai.provider.name = "groq",
                gen_ai.request.model = self.model,
                gen_ai.system_instructions = tracing::field::Empty,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        span.record("gen_ai.system_instructions", &request.preamble);

        let mut request = GroqCompletionRequest::try_from((self.model.as_ref(), request))?;

        request.stream = true;
        request.stream_options = Some(StreamOptions {
            include_usage: true,
        });

        if tracing::enabled!(tracing::Level::TRACE) {
            tracing::trace!(target: "rig::completions",
                "Groq streaming completion request: {}",
                serde_json::to_string_pretty(&request)?
            );
        }

        let body = serde_json::to_vec(&request)?;
        let req = self
            .client
            .post("/chat/completions")?
            .body(body)
            .map_err(|e| http_client::Error::Instance(e.into()))?;

        tracing::Instrument::instrument(
            send_compatible_streaming_request(self.client.clone(), req),
            span,
        )
        .await
    }
}

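// Model identifiers accepted by the Groq audio transcription endpoint.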
pub const WHISPER_LARGE_V3: &str = "whisper-large-v3";
pub const WHISPER_LARGE_V3_TURBO: &str = "whisper-large-v3-turbo";
pub const DISTIL_WHISPER_LARGE_V3_EN: &str = "distil-whisper-large-v3-en";

#[derive(Clone)]
pub struct TranscriptionModel<T> {
    client: Client<T>,
    pub model: String,
}

impl<T> TranscriptionModel<T> {
    pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
        Self {
            client,
            model: model.into(),
        }
    }
}

impl<T> transcription::TranscriptionModel for TranscriptionModel<T>
where
    T: HttpClientExt + Clone + Send + std::fmt::Debug + Default + 'static,
{
    type Response = TranscriptionResponse;

    type Client = Client<T>;

    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
        Self::new(client.clone(), model)
    }

    async fn transcription(
        &self,
        request: transcription::TranscriptionRequest,
    ) -> Result<
        transcription::TranscriptionResponse<Self::Response>,
        transcription::TranscriptionError,
    > {
        let data = request.data;

        let mut body = MultipartForm::new()
            .text("model", self.model.clone())
            .part(Part::bytes("file", data).filename(request.filename.clone()));

        if let Some(language) = request.language {
            body = body.text("language", language);
        }

        if let Some(prompt) = request.prompt {
            body = body.text("prompt", prompt.clone());
        }

        if let Some(ref temperature) = request.temperature {
            body = body.text("temperature", temperature.to_string());
        }

        if let Some(ref additional_params) = request.additional_params {
            for (key, value) in additional_params
                .as_object()
                .expect("additional parameters to a Groq transcription should be a map")
            {
                body = body.text(key.to_owned(), value.to_string());
            }
        }

        // Propagate request-building and transport errors instead of panicking.
        let req = self
            .client
            .post("/audio/transcriptions")?
            .body(body)
            .map_err(|e| http_client::Error::Instance(e.into()))?;

        let response = self.client.send_multipart::<Bytes>(req).await?;

        let status = response.status();
        let response_body = response.into_body().into_future().await?.to_vec();

        if status.is_success() {
            match serde_json::from_slice::<ApiResponse<TranscriptionResponse>>(&response_body)? {
                ApiResponse::Ok(response) => response.try_into(),
                ApiResponse::Err(api_error_response) => Err(TranscriptionError::ProviderError(
                    api_error_response.message,
                )),
            }
        } else {
            Err(TranscriptionError::ProviderError(
                String::from_utf8_lossy(&response_body).to_string(),
            ))
        }
    }
}
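// Transcription usage sketch (hypothetical; assumes the usual
// `transcription_model` accessor provided through the `Capable` transcription
// capability, and a `TranscriptionRequest` built elsewhere):
//
//     let model = client.transcription_model(WHISPER_LARGE_V3);
//     let response = model.transcription(request).await?;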

#[derive(Deserialize, Debug)]
#[serde(untagged)]
enum StreamingDelta {
    Reasoning {
        reasoning: String,
    },
    MessageContent {
        #[serde(default)]
        content: Option<String>,
        #[serde(default, deserialize_with = "json_utils::null_or_vec")]
        tool_calls: Vec<StreamingToolCall>,
    },
}

#[derive(Deserialize, Debug)]
struct StreamingChoice {
    delta: StreamingDelta,
}

#[derive(Deserialize, Debug)]
struct StreamingCompletionChunk {
    choices: Vec<StreamingChoice>,
    usage: Option<Usage>,
}
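// Illustrative SSE `data:` payloads that `StreamingCompletionChunk` above is
// shaped to parse (field layout only; the values are made up):
//
//     {"choices":[{"delta":{"content":"Hello"}}],"usage":null}
//     {"choices":[{"delta":{"tool_calls":[{"index":0,"id":"call_1",
//         "function":{"name":"get_weather","arguments":"{\"city\":\"Oslo\"}"}}]}}]}
//     {"choices":[],"usage":{"prompt_tokens":10,"total_tokens":25}}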

#[derive(Clone, Deserialize, Serialize, Debug)]
pub struct StreamingCompletionResponse {
    pub usage: Usage,
}

impl GetTokenUsage for StreamingCompletionResponse {
    fn token_usage(&self) -> Option<crate::completion::Usage> {
        let mut usage = crate::completion::Usage::new();

        usage.input_tokens = self.usage.prompt_tokens as u64;
        usage.total_tokens = self.usage.total_tokens as u64;
        usage.output_tokens = (self.usage.total_tokens as u64)
            .saturating_sub(self.usage.prompt_tokens as u64);

        Some(usage)
    }
}
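// Groq's usage payload carries `prompt_tokens` and `total_tokens` only, so
// output tokens are derived as `total - prompt` (e.g. total 120 with prompt
// 100 yields 20 output tokens); `saturating_sub` guards against a malformed
// payload where total < prompt.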

pub async fn send_compatible_streaming_request<T>(
    client: T,
    req: Request<Vec<u8>>,
) -> Result<
    crate::streaming::StreamingCompletionResponse<StreamingCompletionResponse>,
    CompletionError,
>
where
    T: HttpClientExt + Clone + 'static,
{
    let span = tracing::Span::current();

    let mut event_source = GenericEventSource::new(client, req);

    let stream = stream! {
        let span = tracing::Span::current();
        let mut final_usage = Usage {
            prompt_tokens: 0,
            total_tokens: 0
        };

        let mut text_response = String::new();

        // Partial tool calls, keyed by index: (id, name, accumulated JSON arguments).
        let mut calls: HashMap<usize, (String, String, String)> = HashMap::new();

        while let Some(event_result) = event_source.next().await {
            match event_result {
                Ok(Event::Open) => {
                    tracing::trace!("SSE connection opened");
                    continue;
                }

                Ok(Event::Message(message)) => {
                    let data_str = message.data.trim();

                    let parsed = serde_json::from_str::<StreamingCompletionChunk>(data_str);
                    let Ok(data) = parsed else {
                        let err = parsed.unwrap_err();
                        tracing::debug!("Couldn't parse SSE payload as StreamingCompletionChunk: {:?}", err);
                        continue;
                    };

                    if let Some(choice) = data.choices.first() {
                        match &choice.delta {
                            StreamingDelta::Reasoning { reasoning } => {
                                yield Ok(crate::streaming::RawStreamingChoice::ReasoningDelta {
                                    id: None,
                                    reasoning: reasoning.to_string(),
                                });
                            }

                            StreamingDelta::MessageContent { content, tool_calls } => {
                                for tool_call in tool_calls {
                                    let function = &tool_call.function;

                                    // Start of a new tool call: a name with no arguments yet.
                                    if function.name.as_ref().map(|s| !s.is_empty()).unwrap_or(false)
                                        && empty_or_none(&function.arguments)
                                    {
                                        let id = tool_call.id.clone().unwrap_or_default();
                                        let name = function.name.clone().unwrap();
                                        calls.insert(tool_call.index, (id, name, String::new()));
                                    }
                                    // Continuation chunk: arguments without a name.
                                    else if function.name.as_ref().map(|s| s.is_empty()).unwrap_or(true)
                                        && let Some(arguments) = &function.arguments
                                        && !arguments.is_empty()
                                    {
                                        if let Some((id, name, existing_args)) = calls.get(&tool_call.index) {
                                            let combined = format!("{}{}", existing_args, arguments);
                                            calls.insert(tool_call.index, (id.clone(), name.clone(), combined));
                                        } else {
                                            tracing::debug!("Partial tool call received but tool call was never started.");
                                        }
                                    }
                                    // Complete tool call delivered in a single chunk.
                                    else {
                                        let id = tool_call.id.clone().unwrap_or_default();
                                        let name = function.name.clone().unwrap_or_default();
                                        let arguments_str = function.arguments.clone().unwrap_or_default();

                                        let Ok(arguments_json) = serde_json::from_str::<serde_json::Value>(&arguments_str) else {
                                            tracing::debug!("Couldn't parse tool call args '{}'", arguments_str);
                                            continue;
                                        };

                                        yield Ok(crate::streaming::RawStreamingChoice::ToolCall(
                                            crate::streaming::RawStreamingToolCall::new(id, name, arguments_json)
                                        ));
                                    }
                                }

                                if let Some(content) = content {
                                    text_response += content;
                                    yield Ok(crate::streaming::RawStreamingChoice::Message(content.clone()));
                                }
                            }
                        }
                    }

                    if let Some(usage) = data.usage {
                        final_usage = usage.clone();
                    }
                }

                Err(crate::http_client::Error::StreamEnded) => break,
                Err(err) => {
                    tracing::error!(?err, "SSE error");
                    yield Err(CompletionError::ResponseError(err.to_string()));
                    break;
                }
            }
        }

        event_source.close();

        // Flush tool calls that were accumulated across multiple chunks.
        let mut tool_calls = Vec::new();
        for (_, (id, name, arguments)) in calls {
            let Ok(arguments_json) = serde_json::from_str::<serde_json::Value>(&arguments) else {
                continue;
            };

            tool_calls.push(crate::providers::openai::completion::ToolCall {
                id: id.clone(),
                r#type: ToolType::Function,
                function: Function {
                    name: name.clone(),
                    arguments: arguments_json.clone()
                }
            });
            yield Ok(crate::streaming::RawStreamingChoice::ToolCall(
                crate::streaming::RawStreamingToolCall::new(id, name, arguments_json)
            ));
        }

        let response_message = crate::providers::openai::completion::Message::Assistant {
            content: vec![AssistantContent::Text { text: text_response }],
            refusal: None,
            audio: None,
            name: None,
            tool_calls
        };

        span.record("gen_ai.output.messages", serde_json::to_string(&vec![response_message]).unwrap());
        span.record("gen_ai.usage.input_tokens", final_usage.prompt_tokens);
        span.record("gen_ai.usage.output_tokens", final_usage.total_tokens.saturating_sub(final_usage.prompt_tokens));

        yield Ok(crate::streaming::RawStreamingChoice::FinalResponse(
            StreamingCompletionResponse { usage: final_usage.clone() }
        ));
    }.instrument(span);

    Ok(crate::streaming::StreamingCompletionResponse::stream(
        Box::pin(stream),
    ))
}

#[cfg(test)]
mod tests {
    use crate::{
        OneOrMany,
        providers::{
            groq::{GroqAdditionalParameters, GroqCompletionRequest},
            openai::{Message, UserContent},
        },
    };

    #[test]
    fn serialize_groq_request() {
        let additional_params = GroqAdditionalParameters {
            include_reasoning: Some(true),
            reasoning_format: Some(super::ReasoningFormat::Parsed),
            ..Default::default()
        };

        let groq = GroqCompletionRequest {
            model: "openai/gpt-oss-120b".to_string(),
            temperature: None,
            tool_choice: None,
            stream_options: None,
            tools: Vec::new(),
            messages: vec![Message::User {
                content: OneOrMany::one(UserContent::Text {
                    text: "Hello world!".to_string(),
                }),
                name: None,
            }],
            stream: false,
            additional_params: Some(additional_params),
        };

        let json = serde_json::to_value(&groq).unwrap();

        assert_eq!(
            json,
            serde_json::json!({
                "model": "openai/gpt-oss-120b",
                "messages": [
                    {
                        "role": "user",
                        "content": [{
                            "type": "text",
                            "text": "Hello world!"
                        }]
                    }
                ],
                "stream": false,
                "include_reasoning": true,
                "reasoning_format": "parsed"
            })
        )
    }
}