use bytes::Bytes;
use http::Request;
use serde_json::{Map, Value};
use tracing::info_span;

use super::openai::{
    CompletionResponse, Message as OpenAIMessage, StreamingToolCall, TranscriptionResponse, Usage,
};
use crate::client::{
    self, BearerAuth, Capabilities, Capable, DebugExt, Nothing, Provider, ProviderBuilder,
    ProviderClient,
};
use crate::completion::GetTokenUsage;
use crate::http_client::multipart::Part;
use crate::http_client::{self, HttpClientExt, MultipartForm};
use crate::providers::internal::openai_chat_completions_compatible::{
    self, CompatibleChoiceData, CompatibleChunk, CompatibleFinishReason, CompatibleStreamProfile,
};

use crate::{
    completion::{self, CompletionError, CompletionRequest},
    json_utils,
    message::{self},
    providers::openai::ToolDefinition,
    transcription::{self, TranscriptionError},
};
use serde::{Deserialize, Serialize};

const GROQ_API_BASE_URL: &str = "https://api.groq.com/openai/v1";

#[derive(Debug, Default, Clone, Copy)]
pub struct GroqExt;
#[derive(Debug, Default, Clone, Copy)]
pub struct GroqBuilder;

type GroqApiKey = BearerAuth;

impl Provider for GroqExt {
    type Builder = GroqBuilder;
    const VERIFY_PATH: &'static str = "/models";
}

impl<H> Capabilities<H> for GroqExt {
    type Completion = Capable<CompletionModel<H>>;
    type Embeddings = Nothing;
    type Transcription = Capable<TranscriptionModel<H>>;
    type ModelListing = Nothing;
    #[cfg(feature = "image")]
    type ImageGeneration = Nothing;

    #[cfg(feature = "audio")]
    type AudioGeneration = Nothing;
}

impl DebugExt for GroqExt {}

impl ProviderBuilder for GroqBuilder {
    type Extension<H>
        = GroqExt
    where
        H: HttpClientExt;
    type ApiKey = GroqApiKey;

    const BASE_URL: &'static str = GROQ_API_BASE_URL;

    fn build<H>(
        _builder: &client::ClientBuilder<Self, Self::ApiKey, H>,
    ) -> http_client::Result<Self::Extension<H>>
    where
        H: HttpClientExt,
    {
        Ok(GroqExt)
    }
}

pub type Client<H = reqwest::Client> = client::Client<GroqExt, H>;
pub type ClientBuilder<H = reqwest::Client> = client::ClientBuilder<GroqBuilder, String, H>;

impl ProviderClient for Client {
    type Input = String;
    type Error = crate::client::ProviderClientError;

    fn from_env() -> Result<Self, Self::Error> {
        let api_key = crate::client::required_env_var("GROQ_API_KEY")?;
        Self::new(&api_key).map_err(Into::into)
    }

    fn from_val(input: Self::Input) -> Result<Self, Self::Error> {
        Self::new(&input).map_err(Into::into)
    }
}
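
// A minimal usage sketch (hypothetical variable names; assumes `GROQ_API_KEY` is set
// and that the generic client traits in `crate::client` expose the usual
// completion-model constructor):
//
//     let client = Client::from_env()?;
//     let model = client.completion_model(LLAMA_3_3_70B_VERSATILE);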

#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    message: String,
}

#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    Ok(T),
    Err(ApiErrorResponse),
}

pub const DEEPSEEK_R1_DISTILL_LLAMA_70B: &str = "deepseek-r1-distill-llama-70b";
pub const GEMMA2_9B_IT: &str = "gemma2-9b-it";
pub const LLAMA_3_1_8B_INSTANT: &str = "llama-3.1-8b-instant";
pub const LLAMA_3_2_11B_VISION_PREVIEW: &str = "llama-3.2-11b-vision-preview";
pub const LLAMA_3_2_1B_PREVIEW: &str = "llama-3.2-1b-preview";
pub const LLAMA_3_2_3B_PREVIEW: &str = "llama-3.2-3b-preview";
pub const LLAMA_3_2_90B_VISION_PREVIEW: &str = "llama-3.2-90b-vision-preview";
pub const LLAMA_3_3_70B_SPECDEC: &str = "llama-3.3-70b-specdec";
pub const LLAMA_3_3_70B_VERSATILE: &str = "llama-3.3-70b-versatile";
pub const LLAMA_GUARD_3_8B: &str = "llama-guard-3-8b";
pub const LLAMA_3_70B_8192: &str = "llama3-70b-8192";
pub const LLAMA_3_8B_8192: &str = "llama3-8b-8192";
pub const MIXTRAL_8X7B_32768: &str = "mixtral-8x7b-32768";

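/// How Groq should return reasoning tokens for reasoning-capable models:
/// parsed into a separate field, left raw in the message content, or hidden.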
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ReasoningFormat {
    Parsed,
    Raw,
    Hidden,
}

#[derive(Debug, Serialize, Deserialize)]
pub(super) struct GroqCompletionRequest {
    model: String,
    pub messages: Vec<OpenAIMessage>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f64>,
    #[serde(skip_serializing_if = "Vec::is_empty")]
    tools: Vec<ToolDefinition>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_choice: Option<crate::providers::openai::completion::ToolChoice>,
    #[serde(flatten, skip_serializing_if = "Option::is_none")]
    pub additional_params: Option<GroqAdditionalParameters>,
    pub(super) stream: bool,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub(super) stream_options: Option<StreamOptions>,
}

#[derive(Debug, Serialize, Deserialize, Default)]
pub(super) struct StreamOptions {
    pub(super) include_usage: bool,
}

impl TryFrom<(&str, CompletionRequest)> for GroqCompletionRequest {
    type Error = CompletionError;

    fn try_from((model, mut req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
        if req.output_schema.is_some() {
            tracing::warn!("Structured outputs currently not supported for Groq");
        }
        let model = req.model.clone().unwrap_or_else(|| model.to_string());
        let mut partial_history = vec![];
        if let Some(docs) = req.normalized_documents() {
            partial_history.push(docs);
        }
        partial_history.extend(req.chat_history);

        let mut full_history: Vec<OpenAIMessage> = match &req.preamble {
            Some(preamble) => vec![OpenAIMessage::system(preamble)],
            None => vec![],
        };

        full_history.extend(
            partial_history
                .into_iter()
                .map(message::Message::try_into)
                .collect::<Result<Vec<Vec<OpenAIMessage>>, _>>()?
                .into_iter()
                .flatten()
                .collect::<Vec<_>>(),
        );

        let tool_choice = req
            .tool_choice
            .clone()
            .map(crate::providers::openai::ToolChoice::try_from)
            .transpose()?;

        let mut additional_params_payload = req.additional_params.take().unwrap_or(Value::Null);
        let native_tools =
            extract_native_tools_from_additional_params(&mut additional_params_payload)?;

        let mut additional_params: Option<GroqAdditionalParameters> =
            if additional_params_payload.is_null() {
                None
            } else {
                Some(serde_json::from_value(additional_params_payload)?)
            };
        apply_native_tools_to_additional_params(&mut additional_params, native_tools);

        Ok(Self {
            model,
            messages: full_history,
            temperature: req.temperature,
            tools: req
                .tools
                .clone()
                .into_iter()
                .map(ToolDefinition::from)
                .collect::<Vec<_>>(),
            tool_choice,
            additional_params,
            stream: false,
            stream_options: None,
        })
    }
}

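/// Pulls a caller-supplied `tools` array out of `additional_params` so it can be
/// re-attached as native Groq tools instead of clashing with the OpenAI-style
/// `tools` field built from `CompletionRequest::tools`.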
fn extract_native_tools_from_additional_params(
    additional_params: &mut Value,
) -> Result<Vec<Value>, CompletionError> {
    if let Some(map) = additional_params.as_object_mut()
        && let Some(raw_tools) = map.remove("tools")
    {
        return serde_json::from_value::<Vec<Value>>(raw_tools).map_err(|err| {
            CompletionError::RequestError(
                format!("Invalid Groq `additional_params.tools` payload: {err}").into(),
            )
        });
    }

    Ok(Vec::new())
}

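/// Folds the extracted native tool entries into `compound_custom.enabled_tools`,
/// skipping tools whose `type` is already enabled. For example (tool name
/// illustrative), `{"tools": [{"type": "browser_search"}]}` supplied via
/// `additional_params` is re-emitted as
/// `{"compound_custom": {"enabled_tools": [{"type": "browser_search"}]}}`.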
fn apply_native_tools_to_additional_params(
    additional_params: &mut Option<GroqAdditionalParameters>,
    native_tools: Vec<Value>,
) {
    if native_tools.is_empty() {
        return;
    }

    let params = additional_params.get_or_insert_with(GroqAdditionalParameters::default);
    let extra = params.extra.get_or_insert_with(Map::new);

    let mut compound_custom = match extra.remove("compound_custom") {
        Some(Value::Object(map)) => map,
        _ => Map::new(),
    };

    let mut enabled_tools = match compound_custom.remove("enabled_tools") {
        Some(Value::Array(values)) => values,
        _ => Vec::new(),
    };

    for native_tool in native_tools {
        let already_enabled = enabled_tools
            .iter()
            .any(|existing| native_tools_match(existing, &native_tool));
        if !already_enabled {
            enabled_tools.push(native_tool);
        }
    }

    compound_custom.insert("enabled_tools".to_string(), Value::Array(enabled_tools));
    extra.insert(
        "compound_custom".to_string(),
        Value::Object(compound_custom),
    );
}

fn native_tools_match(lhs: &Value, rhs: &Value) -> bool {
    if let (Some(lhs_type), Some(rhs_type)) = (native_tool_kind(lhs), native_tool_kind(rhs)) {
        return lhs_type == rhs_type;
    }

    lhs == rhs
}

fn native_tool_kind(value: &Value) -> Option<&str> {
    match value {
        Value::String(kind) => Some(kind),
        Value::Object(map) => map.get("type").and_then(Value::as_str),
        _ => None,
    }
}

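/// Groq-specific parameters that are flattened into the request body. Unknown keys
/// are preserved in `extra`. A minimal sketch of supplying them through
/// `CompletionRequest::additional_params` (the `browser_search` tool is illustrative):
///
/// ```json
/// {
///     "reasoning_format": "parsed",
///     "include_reasoning": true,
///     "tools": [{ "type": "browser_search" }]
/// }
/// ```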
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub struct GroqAdditionalParameters {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub reasoning_format: Option<ReasoningFormat>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub include_reasoning: Option<bool>,
    #[serde(flatten, skip_serializing_if = "Option::is_none")]
    pub extra: Option<Map<String, serde_json::Value>>,
}

#[derive(Clone, Debug)]
pub struct CompletionModel<T = reqwest::Client> {
    client: Client<T>,
    pub model: String,
}

impl<T> CompletionModel<T> {
    pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
        Self {
            client,
            model: model.into(),
        }
    }
}

impl<T> completion::CompletionModel for CompletionModel<T>
where
    T: HttpClientExt + Clone + Send + std::fmt::Debug + Default + 'static,
{
    type Response = CompletionResponse;
    type StreamingResponse = StreamingCompletionResponse;

    type Client = Client<T>;

    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
        Self::new(client.clone(), model)
    }

    async fn completion(
        &self,
        completion_request: CompletionRequest,
    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat",
                gen_ai.operation.name = "chat",
                gen_ai.provider.name = "groq",
                gen_ai.request.model = self.model,
                gen_ai.system_instructions = tracing::field::Empty,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.usage.cache_read.input_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        span.record("gen_ai.system_instructions", &completion_request.preamble);

        let request = GroqCompletionRequest::try_from((self.model.as_ref(), completion_request))?;

        if tracing::enabled!(tracing::Level::TRACE) {
            tracing::trace!(target: "rig::completions",
                "Groq completion request: {}",
                serde_json::to_string_pretty(&request)?
            );
        }

        let body = serde_json::to_vec(&request)?;
        let req = self
            .client
            .post("/chat/completions")?
            .body(body)
            .map_err(|e| http_client::Error::Instance(e.into()))?;

        let async_block = async move {
            let response = self.client.send::<_, Bytes>(req).await?;
            let status = response.status();
            let response_body = response.into_body().into_future().await?.to_vec();

            if status.is_success() {
                match serde_json::from_slice::<ApiResponse<CompletionResponse>>(&response_body)? {
                    ApiResponse::Ok(response) => {
                        let span = tracing::Span::current();
                        span.record("gen_ai.response.id", response.id.clone());
                        span.record("gen_ai.response.model", response.model.clone());
                        if let Some(ref usage) = response.usage {
                            span.record("gen_ai.usage.input_tokens", usage.prompt_tokens);
                            span.record(
                                "gen_ai.usage.output_tokens",
                                usage.total_tokens - usage.prompt_tokens,
                            );
                            span.record(
                                "gen_ai.usage.cache_read.input_tokens",
                                usage
                                    .prompt_tokens_details
                                    .as_ref()
                                    .map(|d| d.cached_tokens)
                                    .unwrap_or(0),
                            );
                        }

                        if tracing::enabled!(tracing::Level::TRACE) {
                            tracing::trace!(target: "rig::completions",
                                "Groq completion response: {}",
                                serde_json::to_string_pretty(&response)?
                            );
                        }

                        response.try_into()
                    }
                    ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
                }
            } else {
                Err(CompletionError::ProviderError(
                    String::from_utf8_lossy(&response_body).to_string(),
                ))
            }
        };

        tracing::Instrument::instrument(async_block, span).await
    }

    async fn stream(
        &self,
        request: CompletionRequest,
    ) -> Result<
        crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
        CompletionError,
    > {
        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat_streaming",
                gen_ai.operation.name = "chat_streaming",
                gen_ai.provider.name = "groq",
                gen_ai.request.model = self.model,
                gen_ai.system_instructions = tracing::field::Empty,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.usage.cache_read.input_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        span.record("gen_ai.system_instructions", &request.preamble);

        let mut request = GroqCompletionRequest::try_from((self.model.as_ref(), request))?;

        request.stream = true;
        request.stream_options = Some(StreamOptions {
            include_usage: true,
        });

        if tracing::enabled!(tracing::Level::TRACE) {
            tracing::trace!(target: "rig::completions",
                "Groq streaming completion request: {}",
                serde_json::to_string_pretty(&request)?
            );
        }

        let body = serde_json::to_vec(&request)?;
        let req = self
            .client
            .post("/chat/completions")?
            .body(body)
            .map_err(|e| http_client::Error::Instance(e.into()))?;

        tracing::Instrument::instrument(
            send_compatible_streaming_request(self.client.clone(), req),
            span,
        )
        .await
    }
}

pub const WHISPER_LARGE_V3: &str = "whisper-large-v3";
pub const WHISPER_LARGE_V3_TURBO: &str = "whisper-large-v3-turbo";
pub const DISTIL_WHISPER_LARGE_V3_EN: &str = "distil-whisper-large-v3-en";

#[derive(Clone)]
pub struct TranscriptionModel<T> {
    client: Client<T>,
    pub model: String,
}

impl<T> TranscriptionModel<T> {
    pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
        Self {
            client,
            model: model.into(),
        }
    }
}

impl<T> transcription::TranscriptionModel for TranscriptionModel<T>
where
    T: HttpClientExt + Clone + Send + std::fmt::Debug + Default + 'static,
{
    type Response = TranscriptionResponse;

    type Client = Client<T>;

    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
        Self::new(client.clone(), model)
    }

    async fn transcription(
        &self,
        request: transcription::TranscriptionRequest,
    ) -> Result<
        transcription::TranscriptionResponse<Self::Response>,
        transcription::TranscriptionError,
    > {
        let data = request.data;

        let mut body = MultipartForm::new()
            .text("model", self.model.clone())
            .part(Part::bytes("file", data).filename(request.filename.clone()));

        if let Some(language) = request.language {
            body = body.text("language", language);
        }

        if let Some(prompt) = request.prompt {
            body = body.text("prompt", prompt.clone());
        }

        if let Some(ref temperature) = request.temperature {
            body = body.text("temperature", temperature.to_string());
        }

        if let Some(ref additional_params) = request.additional_params {
            let params = additional_params.as_object().ok_or_else(|| {
                TranscriptionError::RequestError(Box::new(std::io::Error::new(
                    std::io::ErrorKind::InvalidInput,
                    "additional transcription parameters must be a JSON object",
                )))
            })?;

            for (key, value) in params {
                body = body.text(key.to_owned(), value.to_string());
            }
        }

        let req = self
            .client
            .post("/audio/transcriptions")?
            .body(body)
            .map_err(|e| TranscriptionError::HttpError(e.into()))?;

        let response = self.client.send_multipart::<Bytes>(req).await?;

        let status = response.status();
        let response_body = response.into_body().into_future().await?.to_vec();

        if status.is_success() {
            match serde_json::from_slice::<ApiResponse<TranscriptionResponse>>(&response_body)? {
                ApiResponse::Ok(response) => response.try_into(),
                ApiResponse::Err(api_error_response) => Err(TranscriptionError::ProviderError(
                    api_error_response.message,
                )),
            }
        } else {
            Err(TranscriptionError::ProviderError(
                String::from_utf8_lossy(&response_body).to_string(),
            ))
        }
    }
}

#[derive(Deserialize, Debug)]
#[serde(untagged)]
enum StreamingDelta {
    Reasoning {
        reasoning: String,
    },
    MessageContent {
        #[serde(default)]
        content: Option<String>,
        #[serde(default, deserialize_with = "json_utils::null_or_vec")]
        tool_calls: Vec<StreamingToolCall>,
    },
}

#[derive(Deserialize, Debug)]
struct StreamingChoice {
    delta: StreamingDelta,
}

#[derive(Deserialize, Debug)]
struct StreamingCompletionChunk {
    id: Option<String>,
    model: Option<String>,
    choices: Vec<StreamingChoice>,
    usage: Option<Usage>,
}

#[derive(Clone, Deserialize, Serialize, Debug)]
pub struct StreamingCompletionResponse {
    pub usage: Usage,
}

impl GetTokenUsage for StreamingCompletionResponse {
    fn token_usage(&self) -> Option<crate::completion::Usage> {
        self.usage.token_usage()
    }
}

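/// Adapter that plugs Groq's SSE chunks into the shared OpenAI-compatible
/// streaming pipeline, normalizing each chunk's first choice into text,
/// reasoning, or tool-call deltas and collecting usage for the final response.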
#[derive(Clone, Copy)]
struct GroqCompatibleProfile;

impl CompatibleStreamProfile for GroqCompatibleProfile {
    type Usage = Usage;
    type Detail = ();
    type FinalResponse = StreamingCompletionResponse;

    fn normalize_chunk(
        &self,
        data: &str,
    ) -> Result<Option<CompatibleChunk<Self::Usage, Self::Detail>>, CompletionError> {
        let data = match serde_json::from_str::<StreamingCompletionChunk>(data) {
            Ok(data) => data,
            Err(error) => {
                tracing::debug!(
                    "Couldn't parse SSE payload as StreamingCompletionChunk: {:?}",
                    error
                );
                return Ok(None);
            }
        };

        Ok(Some(
            openai_chat_completions_compatible::normalize_first_choice_chunk(
                data.id,
                data.model,
                data.usage,
                &data.choices,
                |choice| match &choice.delta {
                    StreamingDelta::Reasoning { reasoning } => CompatibleChoiceData {
                        finish_reason: CompatibleFinishReason::Other,
                        text: None,
                        reasoning: Some(reasoning.clone()),
                        tool_calls: Vec::new(),
                        details: Vec::new(),
                    },
                    StreamingDelta::MessageContent {
                        content,
                        tool_calls,
                    } => CompatibleChoiceData {
                        finish_reason: CompatibleFinishReason::Other,
                        text: content.clone(),
                        reasoning: None,
                        tool_calls: openai_chat_completions_compatible::tool_call_chunks(
                            tool_calls,
                        ),
                        details: Vec::new(),
                    },
                },
            ),
        ))
    }

    fn build_final_response(&self, usage: Self::Usage) -> Self::FinalResponse {
        StreamingCompletionResponse { usage }
    }

    fn uses_distinct_tool_call_eviction(&self) -> bool {
        true
    }

    fn emits_complete_single_chunk_tool_calls(&self) -> bool {
        true
    }
}

pub async fn send_compatible_streaming_request<T>(
    client: T,
    req: Request<Vec<u8>>,
) -> Result<
    crate::streaming::StreamingCompletionResponse<StreamingCompletionResponse>,
    CompletionError,
>
where
    T: HttpClientExt + Clone + 'static,
{
    openai_chat_completions_compatible::send_compatible_streaming_request(
        client,
        req,
        GroqCompatibleProfile,
    )
    .await
}

#[cfg(test)]
mod tests {
    use crate::{
        OneOrMany,
        providers::{
            groq::{GroqAdditionalParameters, GroqCompletionRequest},
            openai::{Message, UserContent},
        },
    };

    #[test]
    fn serialize_groq_request() {
        let additional_params = GroqAdditionalParameters {
            include_reasoning: Some(true),
            reasoning_format: Some(super::ReasoningFormat::Parsed),
            ..Default::default()
        };

        let groq = GroqCompletionRequest {
            model: "openai/gpt-120b-oss".to_string(),
            temperature: None,
            tool_choice: None,
            stream_options: None,
            tools: Vec::new(),
            messages: vec![Message::User {
                content: OneOrMany::one(UserContent::Text {
                    text: "Hello world!".to_string(),
                }),
                name: None,
            }],
            stream: false,
            additional_params: Some(additional_params),
        };

        let json = serde_json::to_value(&groq).unwrap();

        assert_eq!(
            json,
            serde_json::json!({
                "model": "openai/gpt-120b-oss",
                "messages": [
                    {
                        "role": "user",
                        "content": "Hello world!"
                    }
                ],
                "stream": false,
                "include_reasoning": true,
                "reasoning_format": "parsed"
            })
        )
    }

    #[test]
    fn test_client_initialization() {
        let _client =
            crate::providers::groq::Client::new("dummy-key").expect("Client::new() failed");
        let _client_from_builder = crate::providers::groq::Client::builder()
            .api_key("dummy-key")
            .build()
            .expect("Client::builder() failed");
    }
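
    // A minimal sketch of the native-tool handling used by
    // `GroqCompletionRequest::try_from`: a `tools` array supplied through
    // `additional_params` should be moved under `compound_custom.enabled_tools`.
    // The `browser_search` tool name is illustrative.
    #[test]
    fn native_tools_move_into_compound_custom() {
        let mut payload = serde_json::json!({ "tools": [{ "type": "browser_search" }] });
        let native_tools =
            super::extract_native_tools_from_additional_params(&mut payload).unwrap();
        assert_eq!(native_tools.len(), 1);

        let mut additional_params: Option<GroqAdditionalParameters> = None;
        super::apply_native_tools_to_additional_params(&mut additional_params, native_tools);

        let extra = additional_params.unwrap().extra.unwrap();
        let enabled_tools = extra["compound_custom"]["enabled_tools"]
            .as_array()
            .unwrap();
        assert_eq!(enabled_tools.len(), 1);
        assert_eq!(enabled_tools[0]["type"], "browser_search");
    }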
}