1use bytes::Bytes;
15use http::Request;
16use serde_json::{Map, Value};
17use tracing::info_span;
18
19use super::openai::{
20 CompletionResponse, Message as OpenAIMessage, StreamingToolCall, TranscriptionResponse, Usage,
21};
22use crate::client::{
23 self, BearerAuth, Capabilities, Capable, DebugExt, Nothing, Provider, ProviderBuilder,
24 ProviderClient,
25};
26use crate::completion::GetTokenUsage;
27use crate::http_client::multipart::Part;
28use crate::http_client::{self, HttpClientExt, MultipartForm};
29use crate::providers::internal::openai_chat_completions_compatible::{
30 self, CompatibleChoiceData, CompatibleChunk, CompatibleFinishReason, CompatibleStreamProfile,
31};
32
33use crate::{
34 completion::{self, CompletionError, CompletionRequest},
35 json_utils,
36 message::{self},
37 providers::openai::ToolDefinition,
38 transcription::{self, TranscriptionError},
39};
40use serde::{Deserialize, Serialize};
41
42const GROQ_API_BASE_URL: &str = "https://api.groq.com/openai/v1";
46
/// Marker type identifying Groq as a provider; it carries no state.
#[derive(Debug, Default, Clone, Copy)]
pub struct GroqExt;
/// Marker type that builds the [`GroqExt`] provider extension; it carries no state.
#[derive(Debug, Default, Clone, Copy)]
pub struct GroqBuilder;

/// API key type used by the Groq provider (an alias for [`BearerAuth`]).
type GroqApiKey = BearerAuth;
53
/// Registers [`GroqExt`] as a provider built by [`GroqBuilder`].
impl Provider for GroqExt {
    type Builder = GroqBuilder;
    // NOTE(review): presumably resolved relative to the base URL when verifying
    // credentials — confirm against the `Provider` trait.
    const VERIFY_PATH: &'static str = "/models";
}
58
/// Capability table for Groq: completion and transcription are `Capable`; every other
/// capability is declared as `Nothing`.
impl<H> Capabilities<H> for GroqExt {
    type Completion = Capable<CompletionModel<H>>;
    type Embeddings = Nothing;
    type Transcription = Capable<TranscriptionModel<H>>;
    type ModelListing = Nothing;
    #[cfg(feature = "image")]
    type ImageGeneration = Nothing;

    #[cfg(feature = "audio")]
    type AudioGeneration = Nothing;
    type Rerank = Nothing;
}
71
// Empty impl: `GroqExt` overrides nothing from `DebugExt`.
impl DebugExt for GroqExt {}
73
/// Builder wiring for Groq: the extension is always [`GroqExt`], keys are [`GroqApiKey`],
/// and the default base URL is [`GROQ_API_BASE_URL`].
impl ProviderBuilder for GroqBuilder {
    type Extension<H>
        = GroqExt
    where
        H: HttpClientExt;
    type ApiKey = GroqApiKey;

    const BASE_URL: &'static str = GROQ_API_BASE_URL;

    /// Produces the provider extension. The builder argument is unused because
    /// [`GroqExt`] has no state to configure, so this never fails.
    fn build<H>(
        _builder: &client::ClientBuilder<Self, Self::ApiKey, H>,
    ) -> http_client::Result<Self::Extension<H>>
    where
        H: HttpClientExt,
    {
        Ok(GroqExt)
    }
}
92
/// Groq client; `H` is the HTTP client implementation and defaults to `reqwest::Client`.
pub type Client<H = reqwest::Client> = client::Client<GroqExt, H>;
/// Builder for a Groq [`Client`]; `H` defaults to the `Missing` marker until an HTTP
/// client is supplied.
pub type ClientBuilder<H = crate::markers::Missing> = client::ClientBuilder<GroqBuilder, String, H>;
95
96impl ProviderClient for Client {
97 type Input = String;
98 type Error = crate::client::ProviderClientError;
99
100 fn from_env() -> Result<Self, Self::Error> {
102 let api_key = crate::client::required_env_var("GROQ_API_KEY")?;
103 Self::new(&api_key).map_err(Into::into)
104 }
105
106 fn from_val(input: Self::Input) -> Result<Self, Self::Error> {
107 Self::new(&input).map_err(Into::into)
108 }
109}
110
/// Error payload shape: a JSON object carrying a `message` string.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    message: String,
}

/// Either a successful payload `T` or an [`ApiErrorResponse`].
///
/// The enum is `untagged`, so serde tries the variants in declaration order:
/// `Ok(T)` first, then `Err`.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    Ok(T),
    Err(ApiErrorResponse),
}
122
/// The `deepseek-r1-distill-llama-70b` model identifier.
pub const DEEPSEEK_R1_DISTILL_LLAMA_70B: &str = "deepseek-r1-distill-llama-70b";
/// The `gemma2-9b-it` model identifier.
pub const GEMMA2_9B_IT: &str = "gemma2-9b-it";
/// The `llama-3.1-8b-instant` model identifier.
pub const LLAMA_3_1_8B_INSTANT: &str = "llama-3.1-8b-instant";
/// The `llama-3.2-11b-vision-preview` model identifier.
pub const LLAMA_3_2_11B_VISION_PREVIEW: &str = "llama-3.2-11b-vision-preview";
/// The `llama-3.2-1b-preview` model identifier.
pub const LLAMA_3_2_1B_PREVIEW: &str = "llama-3.2-1b-preview";
/// The `llama-3.2-3b-preview` model identifier.
pub const LLAMA_3_2_3B_PREVIEW: &str = "llama-3.2-3b-preview";
/// The `llama-3.2-90b-vision-preview` model identifier.
pub const LLAMA_3_2_90B_VISION_PREVIEW: &str = "llama-3.2-90b-vision-preview";
/// The `llama-3.2-70b-specdec` model identifier.
// NOTE(review): Groq's published 70B model IDs may use `llama-3.3-…` — verify this
// string against the current Groq model list.
pub const LLAMA_3_2_70B_SPECDEC: &str = "llama-3.2-70b-specdec";
/// The `llama-3.2-70b-versatile` model identifier.
// NOTE(review): same concern as `LLAMA_3_2_70B_SPECDEC` — verify against Groq's model list.
pub const LLAMA_3_2_70B_VERSATILE: &str = "llama-3.2-70b-versatile";
/// The `llama-guard-3-8b` model identifier.
pub const LLAMA_GUARD_3_8B: &str = "llama-guard-3-8b";
/// The `llama3-70b-8192` model identifier.
pub const LLAMA_3_70B_8192: &str = "llama3-70b-8192";
/// The `llama3-8b-8192` model identifier.
pub const LLAMA_3_8B_8192: &str = "llama3-8b-8192";
/// The `mixtral-8x7b-32768` model identifier.
pub const MIXTRAL_8X7B_32768: &str = "mixtral-8x7b-32768";
153
/// Value of the `reasoning_format` request parameter.
///
/// Variants (de)serialize as lowercase strings: `"parsed"`, `"raw"`, `"hidden"`.
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ReasoningFormat {
    Parsed,
    Raw,
    Hidden,
}
161
/// JSON body serialized for the `/chat/completions` endpoint.
///
/// Optional fields and an empty `tools` list are omitted from the serialized output.
#[derive(Debug, Serialize, Deserialize)]
pub(super) struct GroqCompletionRequest {
    model: String,
    pub messages: Vec<OpenAIMessage>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f64>,
    #[serde(skip_serializing_if = "Vec::is_empty")]
    tools: Vec<ToolDefinition>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_choice: Option<crate::providers::openai::completion::ToolChoice>,
    // Flattened: these parameters appear as top-level keys of the request object.
    #[serde(flatten, skip_serializing_if = "Option::is_none")]
    pub additional_params: Option<GroqAdditionalParameters>,
    // Always serialized; `false` for regular requests, set to `true` for streaming.
    pub(super) stream: bool,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub(super) stream_options: Option<StreamOptions>,
}
178
/// The `stream_options` object of a streaming request.
#[derive(Debug, Serialize, Deserialize, Default)]
pub(super) struct StreamOptions {
    // The streaming path in this file sets this to `true`.
    pub(super) include_usage: bool,
}
183
184impl TryFrom<(&str, CompletionRequest)> for GroqCompletionRequest {
185 type Error = CompletionError;
186
187 fn try_from((model, mut req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
188 let chat_history = req.chat_history_with_documents();
189 if req.output_schema.is_some() {
190 tracing::warn!("Structured outputs currently not supported for Groq");
191 }
192 let model = req.model.clone().unwrap_or_else(|| model.to_string());
193 let mut partial_history = vec![];
195 partial_history.extend(chat_history);
196
197 let mut full_history: Vec<OpenAIMessage> = match &req.preamble {
199 Some(preamble) => vec![OpenAIMessage::system(preamble)],
200 None => vec![],
201 };
202
203 full_history.extend(
205 partial_history
206 .into_iter()
207 .map(message::Message::try_into)
208 .collect::<Result<Vec<Vec<OpenAIMessage>>, _>>()?
209 .into_iter()
210 .flatten()
211 .collect::<Vec<_>>(),
212 );
213
214 let tool_choice = req
215 .tool_choice
216 .clone()
217 .map(crate::providers::openai::ToolChoice::try_from)
218 .transpose()?;
219
220 let mut additional_params_payload = req.additional_params.take().unwrap_or(Value::Null);
221 let native_tools =
222 extract_native_tools_from_additional_params(&mut additional_params_payload)?;
223
224 let mut additional_params: Option<GroqAdditionalParameters> =
225 if additional_params_payload.is_null() {
226 None
227 } else {
228 Some(serde_json::from_value(additional_params_payload)?)
229 };
230 apply_native_tools_to_additional_params(&mut additional_params, native_tools);
231
232 Ok(Self {
233 model: model.to_string(),
234 messages: full_history,
235 temperature: req.temperature,
236 tools: req
237 .tools
238 .clone()
239 .into_iter()
240 .map(ToolDefinition::from)
241 .collect::<Vec<_>>(),
242 tool_choice,
243 additional_params,
244 stream: false,
245 stream_options: None,
246 })
247 }
248}
249
250fn extract_native_tools_from_additional_params(
251 additional_params: &mut Value,
252) -> Result<Vec<Value>, CompletionError> {
253 if let Some(map) = additional_params.as_object_mut()
254 && let Some(raw_tools) = map.remove("tools")
255 {
256 return serde_json::from_value::<Vec<Value>>(raw_tools).map_err(|err| {
257 CompletionError::RequestError(
258 format!("Invalid Groq `additional_params.tools` payload: {err}").into(),
259 )
260 });
261 }
262
263 Ok(Vec::new())
264}
265
266fn apply_native_tools_to_additional_params(
267 additional_params: &mut Option<GroqAdditionalParameters>,
268 native_tools: Vec<Value>,
269) {
270 if native_tools.is_empty() {
271 return;
272 }
273
274 let params = additional_params.get_or_insert_with(GroqAdditionalParameters::default);
275 let extra = params.extra.get_or_insert_with(Map::new);
276
277 let mut compound_custom = match extra.remove("compound_custom") {
278 Some(Value::Object(map)) => map,
279 _ => Map::new(),
280 };
281
282 let mut enabled_tools = match compound_custom.remove("enabled_tools") {
283 Some(Value::Array(values)) => values,
284 _ => Vec::new(),
285 };
286
287 for native_tool in native_tools {
288 let already_enabled = enabled_tools
289 .iter()
290 .any(|existing| native_tools_match(existing, &native_tool));
291 if !already_enabled {
292 enabled_tools.push(native_tool);
293 }
294 }
295
296 compound_custom.insert("enabled_tools".to_string(), Value::Array(enabled_tools));
297 extra.insert(
298 "compound_custom".to_string(),
299 Value::Object(compound_custom),
300 );
301}
302
303fn native_tools_match(lhs: &Value, rhs: &Value) -> bool {
304 if let (Some(lhs_type), Some(rhs_type)) = (native_tool_kind(lhs), native_tool_kind(rhs)) {
305 return lhs_type == rhs_type;
306 }
307
308 lhs == rhs
309}
310
311fn native_tool_kind(value: &Value) -> Option<&str> {
312 match value {
313 Value::String(kind) => Some(kind),
314 Value::Object(map) => map.get("type").and_then(Value::as_str),
315 _ => None,
316 }
317}
318
/// Groq-specific request parameters, flattened into the top level of the request body.
///
/// Fields that are `None` are omitted when serializing.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub struct GroqAdditionalParameters {
    /// Value for the `reasoning_format` request key.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub reasoning_format: Option<ReasoningFormat>,
    /// Value for the `include_reasoning` request key.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub include_reasoning: Option<bool>,
    /// Any other keys; flattened, so they sit alongside the named fields.
    #[serde(flatten, skip_serializing_if = "Option::is_none")]
    pub extra: Option<Map<String, serde_json::Value>>,
}
332
/// A Groq chat-completion model bound to a [`Client`].
#[derive(Clone, Debug)]
pub struct CompletionModel<T = reqwest::Client> {
    client: Client<T>,
    /// Model name used for requests unless the request itself names a model.
    pub model: String,
}
339
340impl<T> CompletionModel<T> {
341 pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
342 Self {
343 client,
344 model: model.into(),
345 }
346 }
347}
348
349impl<T> completion::CompletionModel for CompletionModel<T>
350where
351 T: HttpClientExt + Clone + Send + std::fmt::Debug + Default + 'static,
352{
353 type Response = CompletionResponse;
354 type StreamingResponse = StreamingCompletionResponse;
355
356 type Client = Client<T>;
357
358 fn make(client: &Self::Client, model: impl Into<String>) -> Self {
359 Self::new(client.clone(), model)
360 }
361
362 async fn completion(
363 &self,
364 completion_request: CompletionRequest,
365 ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
366 let span = if tracing::Span::current().is_disabled() {
367 info_span!(
368 target: "rig::completions",
369 "chat",
370 gen_ai.operation.name = "chat",
371 gen_ai.provider.name = "groq",
372 gen_ai.request.model = self.model,
373 gen_ai.system_instructions = tracing::field::Empty,
374 gen_ai.response.id = tracing::field::Empty,
375 gen_ai.response.model = tracing::field::Empty,
376 gen_ai.usage.output_tokens = tracing::field::Empty,
377 gen_ai.usage.input_tokens = tracing::field::Empty,
378 gen_ai.usage.cache_read.input_tokens = tracing::field::Empty,
379 )
380 } else {
381 tracing::Span::current()
382 };
383
384 span.record("gen_ai.system_instructions", &completion_request.preamble);
385
386 let request = GroqCompletionRequest::try_from((self.model.as_ref(), completion_request))?;
387
388 if tracing::enabled!(tracing::Level::TRACE) {
389 tracing::trace!(target: "rig::completions",
390 "Groq completion request: {}",
391 serde_json::to_string_pretty(&request)?
392 );
393 }
394
395 let body = serde_json::to_vec(&request)?;
396 let req = self
397 .client
398 .post("/chat/completions")?
399 .body(body)
400 .map_err(|e| http_client::Error::Instance(e.into()))?;
401
402 let async_block = async move {
403 let response = self.client.send::<_, Bytes>(req).await?;
404 let status = response.status();
405 let response_body = response.into_body().into_future().await?.to_vec();
406
407 if status.is_success() {
408 match serde_json::from_slice::<ApiResponse<CompletionResponse>>(&response_body)? {
409 ApiResponse::Ok(response) => {
410 let span = tracing::Span::current();
411 span.record("gen_ai.response.id", response.id.clone());
412 span.record("gen_ai.response.model", response.model.clone());
413 if let Some(ref usage) = response.usage {
414 span.record("gen_ai.usage.input_tokens", usage.prompt_tokens);
415 span.record(
416 "gen_ai.usage.output_tokens",
417 usage.total_tokens - usage.prompt_tokens,
418 );
419 span.record(
420 "gen_ai.usage.cache_read.input_tokens",
421 usage
422 .prompt_tokens_details
423 .as_ref()
424 .map(|d| d.cached_tokens)
425 .unwrap_or(0),
426 );
427 }
428
429 if tracing::enabled!(tracing::Level::TRACE) {
430 tracing::trace!(target: "rig::completions",
431 "Groq completion response: {}",
432 serde_json::to_string_pretty(&response)?
433 );
434 }
435
436 response.try_into()
437 }
438 ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
439 }
440 } else {
441 Err(CompletionError::ProviderError(
442 String::from_utf8_lossy(&response_body).to_string(),
443 ))
444 }
445 };
446
447 tracing::Instrument::instrument(async_block, span).await
448 }
449
450 async fn stream(
451 &self,
452 request: CompletionRequest,
453 ) -> Result<
454 crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
455 CompletionError,
456 > {
457 let span = if tracing::Span::current().is_disabled() {
458 info_span!(
459 target: "rig::completions",
460 "chat_streaming",
461 gen_ai.operation.name = "chat_streaming",
462 gen_ai.provider.name = "groq",
463 gen_ai.request.model = self.model,
464 gen_ai.system_instructions = tracing::field::Empty,
465 gen_ai.response.id = tracing::field::Empty,
466 gen_ai.response.model = tracing::field::Empty,
467 gen_ai.usage.output_tokens = tracing::field::Empty,
468 gen_ai.usage.input_tokens = tracing::field::Empty,
469 gen_ai.usage.cache_read.input_tokens = tracing::field::Empty,
470 )
471 } else {
472 tracing::Span::current()
473 };
474
475 span.record("gen_ai.system_instructions", &request.preamble);
476
477 let mut request = GroqCompletionRequest::try_from((self.model.as_ref(), request))?;
478
479 request.stream = true;
480 request.stream_options = Some(StreamOptions {
481 include_usage: true,
482 });
483
484 if tracing::enabled!(tracing::Level::TRACE) {
485 tracing::trace!(target: "rig::completions",
486 "Groq streaming completion request: {}",
487 serde_json::to_string_pretty(&request)?
488 );
489 }
490
491 let body = serde_json::to_vec(&request)?;
492 let req = self
493 .client
494 .post("/chat/completions")?
495 .body(body)
496 .map_err(|e| http_client::Error::Instance(e.into()))?;
497
498 tracing::Instrument::instrument(
499 send_compatible_streaming_request(self.client.clone(), req),
500 span,
501 )
502 .await
503 }
504}
505
/// The `whisper-large-v3` model identifier.
pub const WHISPER_LARGE_V3: &str = "whisper-large-v3";
/// The `whisper-large-v3-turbo` model identifier.
pub const WHISPER_LARGE_V3_TURBO: &str = "whisper-large-v3-turbo";
/// The `distil-whisper-large-v3-en` model identifier.
pub const DISTIL_WHISPER_LARGE_V3_EN: &str = "distil-whisper-large-v3-en";
513
514#[derive(Clone)]
515pub struct TranscriptionModel<T> {
516 client: Client<T>,
517 pub model: String,
519}
520
521impl<T> TranscriptionModel<T> {
522 pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
523 Self {
524 client,
525 model: model.into(),
526 }
527 }
528}
529impl<T> transcription::TranscriptionModel for TranscriptionModel<T>
530where
531 T: HttpClientExt + Clone + Send + std::fmt::Debug + Default + 'static,
532{
533 type Response = TranscriptionResponse;
534
535 type Client = Client<T>;
536
537 fn make(client: &Self::Client, model: impl Into<String>) -> Self {
538 Self::new(client.clone(), model)
539 }
540
541 async fn transcription(
542 &self,
543 request: transcription::TranscriptionRequest,
544 ) -> Result<
545 transcription::TranscriptionResponse<Self::Response>,
546 transcription::TranscriptionError,
547 > {
548 let data = request.data;
549
550 let mut body = MultipartForm::new()
551 .text("model", self.model.clone())
552 .part(Part::bytes("file", data).filename(request.filename.clone()));
553
554 if let Some(language) = request.language {
555 body = body.text("language", language);
556 }
557
558 if let Some(prompt) = request.prompt {
559 body = body.text("prompt", prompt.clone());
560 }
561
562 if let Some(ref temperature) = request.temperature {
563 body = body.text("temperature", temperature.to_string());
564 }
565
566 if let Some(ref additional_params) = request.additional_params {
567 let params = additional_params.as_object().ok_or_else(|| {
568 TranscriptionError::RequestError(Box::new(std::io::Error::new(
569 std::io::ErrorKind::InvalidInput,
570 "additional transcription parameters must be a JSON object",
571 )))
572 })?;
573
574 for (key, value) in params {
575 body = body.text(key.to_owned(), value.to_string());
576 }
577 }
578
579 let req = self
580 .client
581 .post("/audio/transcriptions")?
582 .body(body)
583 .map_err(|e| TranscriptionError::HttpError(e.into()))?;
584
585 let response = self.client.send_multipart::<Bytes>(req).await?;
586
587 let status = response.status();
588 let response_body = response.into_body().into_future().await?.to_vec();
589
590 if status.is_success() {
591 match serde_json::from_slice::<ApiResponse<TranscriptionResponse>>(&response_body)? {
592 ApiResponse::Ok(response) => response.try_into(),
593 ApiResponse::Err(api_error_response) => Err(TranscriptionError::ProviderError(
594 api_error_response.message,
595 )),
596 }
597 } else {
598 Err(TranscriptionError::ProviderError(
599 String::from_utf8_lossy(&response_body).to_string(),
600 ))
601 }
602 }
603}
604
/// The `delta` object of a streamed choice.
///
/// The enum is `untagged`, so serde tries the variants in declaration order: a delta
/// with a string `reasoning` field parses as `Reasoning`; anything else falls through
/// to `MessageContent`.
#[derive(Deserialize, Debug)]
#[serde(untagged)]
enum StreamingDelta {
    Reasoning {
        reasoning: String,
    },
    MessageContent {
        #[serde(default)]
        content: Option<String>,
        // A missing `tool_calls` key yields the default (empty) list; parsing is
        // delegated to `json_utils::null_or_vec`.
        #[serde(default, deserialize_with = "json_utils::null_or_vec")]
        tool_calls: Vec<StreamingToolCall>,
    },
}
618
/// One entry of a streamed chunk's `choices` array; only `delta` is read.
#[derive(Deserialize, Debug)]
struct StreamingChoice {
    delta: StreamingDelta,
}
623
/// A single streamed payload as parsed by `normalize_chunk`.
#[derive(Deserialize, Debug)]
struct StreamingCompletionChunk {
    id: Option<String>,
    model: Option<String>,
    choices: Vec<StreamingChoice>,
    usage: Option<Usage>,
}
631
/// Final response of a streamed completion; it carries the token usage.
#[derive(Clone, Deserialize, Serialize, Debug)]
pub struct StreamingCompletionResponse {
    pub usage: Usage,
}
636
637impl GetTokenUsage for StreamingCompletionResponse {
638 fn token_usage(&self) -> crate::completion::Usage {
639 self.usage.token_usage()
640 }
641}
642
/// Stateless stream profile that adapts Groq's streamed chunks for the shared
/// OpenAI-compatible streaming code.
#[derive(Clone, Copy)]
struct GroqCompatibleProfile;
645
646impl CompatibleStreamProfile for GroqCompatibleProfile {
647 type Usage = Usage;
648 type Detail = ();
649 type FinalResponse = StreamingCompletionResponse;
650
651 fn normalize_chunk(
652 &self,
653 data: &str,
654 ) -> Result<Option<CompatibleChunk<Self::Usage, Self::Detail>>, CompletionError> {
655 let data = match serde_json::from_str::<StreamingCompletionChunk>(data) {
656 Ok(data) => data,
657 Err(error) => {
658 tracing::debug!(
659 "Couldn't parse SSE payload as StreamingCompletionChunk: {:?}",
660 error
661 );
662 return Ok(None);
663 }
664 };
665
666 Ok(Some(
667 openai_chat_completions_compatible::normalize_first_choice_chunk(
668 data.id,
669 data.model,
670 data.usage,
671 &data.choices,
672 |choice| match &choice.delta {
673 StreamingDelta::Reasoning { reasoning } => CompatibleChoiceData {
674 finish_reason: CompatibleFinishReason::Other,
675 text: None,
676 reasoning: Some(reasoning.clone()),
677 tool_calls: Vec::new(),
678 details: Vec::new(),
679 },
680 StreamingDelta::MessageContent {
681 content,
682 tool_calls,
683 } => CompatibleChoiceData {
684 finish_reason: CompatibleFinishReason::Other,
685 text: content.clone(),
686 reasoning: None,
687 tool_calls: openai_chat_completions_compatible::tool_call_chunks(
688 tool_calls,
689 ),
690 details: Vec::new(),
691 },
692 },
693 ),
694 ))
695 }
696
697 fn build_final_response(&self, usage: Self::Usage) -> Self::FinalResponse {
698 StreamingCompletionResponse { usage }
699 }
700
701 fn uses_distinct_tool_call_eviction(&self) -> bool {
702 true
703 }
704
705 fn emits_complete_single_chunk_tool_calls(&self) -> bool {
706 true
707 }
708}
709
710pub async fn send_compatible_streaming_request<T>(
711 client: T,
712 req: Request<Vec<u8>>,
713) -> Result<
714 crate::streaming::StreamingCompletionResponse<StreamingCompletionResponse>,
715 CompletionError,
716>
717where
718 T: HttpClientExt + Clone + 'static,
719{
720 openai_chat_completions_compatible::send_compatible_streaming_request(
721 client,
722 req,
723 GroqCompatibleProfile,
724 )
725 .await
726}
727
#[cfg(test)]
mod tests {
    use crate::{
        OneOrMany,
        providers::{
            groq::{GroqAdditionalParameters, GroqCompletionRequest},
            openai::{Message, UserContent},
        },
    };

    /// The additional parameters must be flattened into the top-level request object
    /// and unset optional fields must be omitted.
    #[test]
    fn serialize_groq_request() {
        let user_message = Message::User {
            content: OneOrMany::one(UserContent::Text {
                text: "Hello world!".to_string(),
            }),
            name: None,
        };

        let request = GroqCompletionRequest {
            model: "openai/gpt-120b-oss".to_string(),
            messages: vec![user_message],
            temperature: None,
            tools: Vec::new(),
            tool_choice: None,
            additional_params: Some(GroqAdditionalParameters {
                reasoning_format: Some(super::ReasoningFormat::Parsed),
                include_reasoning: Some(true),
                ..Default::default()
            }),
            stream: false,
            stream_options: None,
        };

        let expected = serde_json::json!({
            "model": "openai/gpt-120b-oss",
            "messages": [
                {
                    "role": "user",
                    "content": "Hello world!"
                }
            ],
            "stream": false,
            "include_reasoning": true,
            "reasoning_format": "parsed"
        });

        let actual = serde_json::to_value(&request).unwrap();
        assert_eq!(actual, expected);
    }

    /// Both construction paths must accept a plain string key.
    #[test]
    fn test_client_initialization() {
        let direct = crate::providers::groq::Client::new("dummy-key");
        let _client = direct.expect("Client::new() failed");

        let built = crate::providers::groq::Client::builder()
            .api_key("dummy-key")
            .build();
        let _client_from_builder = built.expect("Client::builder() failed");
    }
}