use bytes::Bytes;
use http::Request;
use serde_json::{Map, Value};
use tracing::info_span;

use super::openai::{
    CompletionResponse, Message as OpenAIMessage, StreamingToolCall, TranscriptionResponse, Usage,
};
use crate::client::{
    self, BearerAuth, Capabilities, Capable, DebugExt, Nothing, Provider, ProviderBuilder,
    ProviderClient,
};
use crate::completion::GetTokenUsage;
use crate::http_client::multipart::Part;
use crate::http_client::{self, HttpClientExt, MultipartForm};
use crate::providers::internal::openai_chat_completions_compatible::{
    self, CompatibleChoiceData, CompatibleChunk, CompatibleFinishReason, CompatibleStreamProfile,
};

use crate::{
    completion::{self, CompletionError, CompletionRequest},
    json_utils,
    message::{self},
    providers::openai::ToolDefinition,
    transcription::{self, TranscriptionError},
};
use serde::{Deserialize, Serialize};

const GROQ_API_BASE_URL: &str = "https://api.groq.com/openai/v1";

#[derive(Debug, Default, Clone, Copy)]
pub struct GroqExt;
#[derive(Debug, Default, Clone, Copy)]
pub struct GroqBuilder;

type GroqApiKey = BearerAuth;

impl Provider for GroqExt {
    type Builder = GroqBuilder;
    const VERIFY_PATH: &'static str = "/models";
}

impl<H> Capabilities<H> for GroqExt {
    type Completion = Capable<CompletionModel<H>>;
    type Embeddings = Nothing;
    type Transcription = Capable<TranscriptionModel<H>>;
    type ModelListing = Nothing;
    #[cfg(feature = "image")]
    type ImageGeneration = Nothing;

    #[cfg(feature = "audio")]
    type AudioGeneration = Nothing;
}

impl DebugExt for GroqExt {}

impl ProviderBuilder for GroqBuilder {
    type Extension<H>
        = GroqExt
    where
        H: HttpClientExt;
    type ApiKey = GroqApiKey;

    const BASE_URL: &'static str = GROQ_API_BASE_URL;

    fn build<H>(
        _builder: &client::ClientBuilder<Self, Self::ApiKey, H>,
    ) -> http_client::Result<Self::Extension<H>>
    where
        H: HttpClientExt,
    {
        Ok(GroqExt)
    }
}

pub type Client<H = reqwest::Client> = client::Client<GroqExt, H>;
pub type ClientBuilder<H = crate::markers::Missing> = client::ClientBuilder<GroqBuilder, String, H>;

impl ProviderClient for Client {
    type Input = String;
    type Error = crate::client::ProviderClientError;

    fn from_env() -> Result<Self, Self::Error> {
        let api_key = crate::client::required_env_var("GROQ_API_KEY")?;
        Self::new(&api_key).map_err(Into::into)
    }

    fn from_val(input: Self::Input) -> Result<Self, Self::Error> {
        Self::new(&input).map_err(Into::into)
    }
}
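
// Illustrative usage (a sketch; both constructors are exercised by the tests
// at the bottom of this file):
//
//     let client = Client::from_env()?;        // reads GROQ_API_KEY
//     let client = Client::new("dummy-key")?;  // explicit API key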

#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    message: String,
}

#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    Ok(T),
    Err(ApiErrorResponse),
}

pub const DEEPSEEK_R1_DISTILL_LLAMA_70B: &str = "deepseek-r1-distill-llama-70b";
pub const GEMMA2_9B_IT: &str = "gemma2-9b-it";
pub const LLAMA_3_1_8B_INSTANT: &str = "llama-3.1-8b-instant";
pub const LLAMA_3_2_11B_VISION_PREVIEW: &str = "llama-3.2-11b-vision-preview";
pub const LLAMA_3_2_1B_PREVIEW: &str = "llama-3.2-1b-preview";
pub const LLAMA_3_2_3B_PREVIEW: &str = "llama-3.2-3b-preview";
pub const LLAMA_3_2_90B_VISION_PREVIEW: &str = "llama-3.2-90b-vision-preview";
pub const LLAMA_3_2_70B_SPECDEC: &str = "llama-3.2-70b-specdec";
pub const LLAMA_3_2_70B_VERSATILE: &str = "llama-3.2-70b-versatile";
pub const LLAMA_GUARD_3_8B: &str = "llama-guard-3-8b";
pub const LLAMA_3_70B_8192: &str = "llama3-70b-8192";
pub const LLAMA_3_8B_8192: &str = "llama3-8b-8192";
pub const MIXTRAL_8X7B_32768: &str = "mixtral-8x7b-32768";

#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ReasoningFormat {
    Parsed,
    Raw,
    Hidden,
}
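
// With `rename_all = "lowercase"`, these variants serialize to the strings
// Groq accepts for `reasoning_format`: `Parsed` -> "parsed", `Raw` -> "raw",
// `Hidden` -> "hidden" (see `serialize_groq_request` in the tests below).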

#[derive(Debug, Serialize, Deserialize)]
pub(super) struct GroqCompletionRequest {
    model: String,
    pub messages: Vec<OpenAIMessage>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f64>,
    #[serde(skip_serializing_if = "Vec::is_empty")]
    tools: Vec<ToolDefinition>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_choice: Option<crate::providers::openai::completion::ToolChoice>,
    #[serde(flatten, skip_serializing_if = "Option::is_none")]
    pub additional_params: Option<GroqAdditionalParameters>,
    pub(super) stream: bool,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub(super) stream_options: Option<StreamOptions>,
}

#[derive(Debug, Serialize, Deserialize, Default)]
pub(super) struct StreamOptions {
    pub(super) include_usage: bool,
}
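
// `include_usage: true` follows the OpenAI-compatible streaming convention:
// the API appends a final SSE chunk carrying token usage, which `stream()`
// below relies on to populate `StreamingCompletionResponse::usage`.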

impl TryFrom<(&str, CompletionRequest)> for GroqCompletionRequest {
    type Error = CompletionError;

    fn try_from((model, mut req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
        if req.output_schema.is_some() {
            tracing::warn!("Structured outputs are currently not supported for Groq");
        }
        let model = req.model.clone().unwrap_or_else(|| model.to_string());
        let mut partial_history = vec![];
        if let Some(docs) = req.normalized_documents() {
            partial_history.push(docs);
        }
        partial_history.extend(req.chat_history);

        let mut full_history: Vec<OpenAIMessage> = match &req.preamble {
            Some(preamble) => vec![OpenAIMessage::system(preamble)],
            None => vec![],
        };

        full_history.extend(
            partial_history
                .into_iter()
                .map(message::Message::try_into)
                .collect::<Result<Vec<Vec<OpenAIMessage>>, _>>()?
                .into_iter()
                .flatten()
                .collect::<Vec<_>>(),
        );

        let tool_choice = req
            .tool_choice
            .clone()
            .map(crate::providers::openai::ToolChoice::try_from)
            .transpose()?;

        let mut additional_params_payload = req.additional_params.take().unwrap_or(Value::Null);
        let native_tools =
            extract_native_tools_from_additional_params(&mut additional_params_payload)?;

        let mut additional_params: Option<GroqAdditionalParameters> =
            if additional_params_payload.is_null() {
                None
            } else {
                Some(serde_json::from_value(additional_params_payload)?)
            };
        apply_native_tools_to_additional_params(&mut additional_params, native_tools);

        Ok(Self {
            model,
            messages: full_history,
            temperature: req.temperature,
            tools: req
                .tools
                .clone()
                .into_iter()
                .map(ToolDefinition::from)
                .collect::<Vec<_>>(),
            tool_choice,
            additional_params,
            stream: false,
            stream_options: None,
        })
    }
}

fn extract_native_tools_from_additional_params(
    additional_params: &mut Value,
) -> Result<Vec<Value>, CompletionError> {
    if let Some(map) = additional_params.as_object_mut()
        && let Some(raw_tools) = map.remove("tools")
    {
        return serde_json::from_value::<Vec<Value>>(raw_tools).map_err(|err| {
            CompletionError::RequestError(
                format!("Invalid Groq `additional_params.tools` payload: {err}").into(),
            )
        });
    }

    Ok(Vec::new())
}

fn apply_native_tools_to_additional_params(
    additional_params: &mut Option<GroqAdditionalParameters>,
    native_tools: Vec<Value>,
) {
    if native_tools.is_empty() {
        return;
    }

    let params = additional_params.get_or_insert_with(GroqAdditionalParameters::default);
    let extra = params.extra.get_or_insert_with(Map::new);

    let mut compound_custom = match extra.remove("compound_custom") {
        Some(Value::Object(map)) => map,
        _ => Map::new(),
    };

    let mut enabled_tools = match compound_custom.remove("enabled_tools") {
        Some(Value::Array(values)) => values,
        _ => Vec::new(),
    };

    for native_tool in native_tools {
        let already_enabled = enabled_tools
            .iter()
            .any(|existing| native_tools_match(existing, &native_tool));
        if !already_enabled {
            enabled_tools.push(native_tool);
        }
    }

    compound_custom.insert("enabled_tools".to_string(), Value::Array(enabled_tools));
    extra.insert(
        "compound_custom".to_string(),
        Value::Object(compound_custom),
    );
}
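
// Sketch of the shape this produces (assuming a caller passed a hypothetical
// native tool `{"type": "web_search"}` via `additional_params.tools`): the
// tool is moved out of `tools` and merged, without duplicates, into
//
//     { "compound_custom": { "enabled_tools": [{ "type": "web_search" }] } }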

fn native_tools_match(lhs: &Value, rhs: &Value) -> bool {
    if let (Some(lhs_type), Some(rhs_type)) = (native_tool_kind(lhs), native_tool_kind(rhs)) {
        return lhs_type == rhs_type;
    }

    lhs == rhs
}

fn native_tool_kind(value: &Value) -> Option<&str> {
    match value {
        Value::String(kind) => Some(kind),
        Value::Object(map) => map.get("type").and_then(Value::as_str),
        _ => None,
    }
}
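
// Matching is by tool kind, so the string form and the object form of the
// same tool compare equal: `"web_search"` matches `{"type": "web_search"}`
// (the tool name here is purely illustrative). Values with no recoverable
// kind fall back to strict equality.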

#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub struct GroqAdditionalParameters {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub reasoning_format: Option<ReasoningFormat>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub include_reasoning: Option<bool>,
    #[serde(flatten, skip_serializing_if = "Option::is_none")]
    pub extra: Option<Map<String, serde_json::Value>>,
}
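
// Because `extra` is flattened, unknown keys sit alongside the typed fields
// when serialized. For example (mirroring `serialize_groq_request` below),
// `reasoning_format: Some(Parsed)` plus `include_reasoning: Some(true)`
// serializes to `{"reasoning_format": "parsed", "include_reasoning": true}`.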

#[derive(Clone, Debug)]
pub struct CompletionModel<T = reqwest::Client> {
    client: Client<T>,
    pub model: String,
}

impl<T> CompletionModel<T> {
    pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
        Self {
            client,
            model: model.into(),
        }
    }
}

impl<T> completion::CompletionModel for CompletionModel<T>
where
    T: HttpClientExt + Clone + Send + std::fmt::Debug + Default + 'static,
{
    type Response = CompletionResponse;
    type StreamingResponse = StreamingCompletionResponse;

    type Client = Client<T>;

    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
        Self::new(client.clone(), model)
    }

    async fn completion(
        &self,
        completion_request: CompletionRequest,
    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat",
                gen_ai.operation.name = "chat",
                gen_ai.provider.name = "groq",
                gen_ai.request.model = self.model,
                gen_ai.system_instructions = tracing::field::Empty,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.usage.cache_read.input_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        span.record("gen_ai.system_instructions", &completion_request.preamble);

        let request = GroqCompletionRequest::try_from((self.model.as_ref(), completion_request))?;

        if tracing::enabled!(tracing::Level::TRACE) {
            tracing::trace!(target: "rig::completions",
                "Groq completion request: {}",
                serde_json::to_string_pretty(&request)?
            );
        }

        let body = serde_json::to_vec(&request)?;
        let req = self
            .client
            .post("/chat/completions")?
            .body(body)
            .map_err(|e| http_client::Error::Instance(e.into()))?;

        let async_block = async move {
            let response = self.client.send::<_, Bytes>(req).await?;
            let status = response.status();
            let response_body = response.into_body().into_future().await?.to_vec();

            if status.is_success() {
                match serde_json::from_slice::<ApiResponse<CompletionResponse>>(&response_body)? {
                    ApiResponse::Ok(response) => {
                        let span = tracing::Span::current();
                        span.record("gen_ai.response.id", response.id.clone());
                        span.record("gen_ai.response.model", response.model.clone());
                        if let Some(ref usage) = response.usage {
                            span.record("gen_ai.usage.input_tokens", usage.prompt_tokens);
                            span.record(
                                "gen_ai.usage.output_tokens",
                                usage.total_tokens - usage.prompt_tokens,
                            );
                            span.record(
                                "gen_ai.usage.cache_read.input_tokens",
                                usage
                                    .prompt_tokens_details
                                    .as_ref()
                                    .map(|d| d.cached_tokens)
                                    .unwrap_or(0),
                            );
                        }

                        if tracing::enabled!(tracing::Level::TRACE) {
                            tracing::trace!(target: "rig::completions",
                                "Groq completion response: {}",
                                serde_json::to_string_pretty(&response)?
                            );
                        }

                        response.try_into()
                    }
                    ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
                }
            } else {
                Err(CompletionError::ProviderError(
                    String::from_utf8_lossy(&response_body).to_string(),
                ))
            }
        };

        tracing::Instrument::instrument(async_block, span).await
    }

    async fn stream(
        &self,
        request: CompletionRequest,
    ) -> Result<
        crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
        CompletionError,
    > {
        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat_streaming",
                gen_ai.operation.name = "chat_streaming",
                gen_ai.provider.name = "groq",
                gen_ai.request.model = self.model,
                gen_ai.system_instructions = tracing::field::Empty,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.usage.cache_read.input_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        span.record("gen_ai.system_instructions", &request.preamble);

        let mut request = GroqCompletionRequest::try_from((self.model.as_ref(), request))?;

        request.stream = true;
        request.stream_options = Some(StreamOptions {
            include_usage: true,
        });

        if tracing::enabled!(tracing::Level::TRACE) {
            tracing::trace!(target: "rig::completions",
                "Groq streaming completion request: {}",
                serde_json::to_string_pretty(&request)?
            );
        }

        let body = serde_json::to_vec(&request)?;
        let req = self
            .client
            .post("/chat/completions")?
            .body(body)
            .map_err(|e| http_client::Error::Instance(e.into()))?;

        tracing::Instrument::instrument(
            send_compatible_streaming_request(self.client.clone(), req),
            span,
        )
        .await
    }
}

pub const WHISPER_LARGE_V3: &str = "whisper-large-v3";
pub const WHISPER_LARGE_V3_TURBO: &str = "whisper-large-v3-turbo";
pub const DISTIL_WHISPER_LARGE_V3_EN: &str = "distil-whisper-large-v3-en";

#[derive(Clone)]
pub struct TranscriptionModel<T> {
    client: Client<T>,
    pub model: String,
}

impl<T> TranscriptionModel<T> {
    pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
        Self {
            client,
            model: model.into(),
        }
    }
}

impl<T> transcription::TranscriptionModel for TranscriptionModel<T>
where
    T: HttpClientExt + Clone + Send + std::fmt::Debug + Default + 'static,
{
    type Response = TranscriptionResponse;

    type Client = Client<T>;

    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
        Self::new(client.clone(), model)
    }

    async fn transcription(
        &self,
        request: transcription::TranscriptionRequest,
    ) -> Result<
        transcription::TranscriptionResponse<Self::Response>,
        transcription::TranscriptionError,
    > {
        let data = request.data;

        let mut body = MultipartForm::new()
            .text("model", self.model.clone())
            .part(Part::bytes("file", data).filename(request.filename.clone()));

        if let Some(language) = request.language {
            body = body.text("language", language);
        }

        if let Some(prompt) = request.prompt {
            body = body.text("prompt", prompt);
        }

        if let Some(ref temperature) = request.temperature {
            body = body.text("temperature", temperature.to_string());
        }

        if let Some(ref additional_params) = request.additional_params {
            let params = additional_params.as_object().ok_or_else(|| {
                TranscriptionError::RequestError(Box::new(std::io::Error::new(
                    std::io::ErrorKind::InvalidInput,
                    "additional transcription parameters must be a JSON object",
                )))
            })?;

            for (key, value) in params {
                body = body.text(key.to_owned(), value.to_string());
            }
        }

        let req = self
            .client
            .post("/audio/transcriptions")?
            .body(body)
            .map_err(|e| TranscriptionError::HttpError(e.into()))?;

        let response = self.client.send_multipart::<Bytes>(req).await?;

        let status = response.status();
        let response_body = response.into_body().into_future().await?.to_vec();

        if status.is_success() {
            match serde_json::from_slice::<ApiResponse<TranscriptionResponse>>(&response_body)? {
                ApiResponse::Ok(response) => response.try_into(),
                ApiResponse::Err(api_error_response) => Err(TranscriptionError::ProviderError(
                    api_error_response.message,
                )),
            }
        } else {
            Err(TranscriptionError::ProviderError(
                String::from_utf8_lossy(&response_body).to_string(),
            ))
        }
    }
}

#[derive(Deserialize, Debug)]
#[serde(untagged)]
enum StreamingDelta {
    Reasoning {
        reasoning: String,
    },
    MessageContent {
        #[serde(default)]
        content: Option<String>,
        #[serde(default, deserialize_with = "json_utils::null_or_vec")]
        tool_calls: Vec<StreamingToolCall>,
    },
}
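
// With `untagged`, serde tries the variants in declaration order: a delta
// such as `{"reasoning": "..."}` parses as `Reasoning`, while deltas like
// `{"content": "hi"}` or `{"tool_calls": [...]}` (both fields optional)
// fall through to `MessageContent`.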

#[derive(Deserialize, Debug)]
struct StreamingChoice {
    delta: StreamingDelta,
}

#[derive(Deserialize, Debug)]
struct StreamingCompletionChunk {
    id: Option<String>,
    model: Option<String>,
    choices: Vec<StreamingChoice>,
    usage: Option<Usage>,
}

#[derive(Clone, Deserialize, Serialize, Debug)]
pub struct StreamingCompletionResponse {
    pub usage: Usage,
}

impl GetTokenUsage for StreamingCompletionResponse {
    fn token_usage(&self) -> Option<crate::completion::Usage> {
        self.usage.token_usage()
    }
}

#[derive(Clone, Copy)]
struct GroqCompatibleProfile;

impl CompatibleStreamProfile for GroqCompatibleProfile {
    type Usage = Usage;
    type Detail = ();
    type FinalResponse = StreamingCompletionResponse;

    fn normalize_chunk(
        &self,
        data: &str,
    ) -> Result<Option<CompatibleChunk<Self::Usage, Self::Detail>>, CompletionError> {
        let data = match serde_json::from_str::<StreamingCompletionChunk>(data) {
            Ok(data) => data,
            Err(error) => {
                tracing::debug!(
                    "Couldn't parse SSE payload as StreamingCompletionChunk: {:?}",
                    error
                );
                return Ok(None);
            }
        };

        Ok(Some(
            openai_chat_completions_compatible::normalize_first_choice_chunk(
                data.id,
                data.model,
                data.usage,
                &data.choices,
                |choice| match &choice.delta {
                    StreamingDelta::Reasoning { reasoning } => CompatibleChoiceData {
                        finish_reason: CompatibleFinishReason::Other,
                        text: None,
                        reasoning: Some(reasoning.clone()),
                        tool_calls: Vec::new(),
                        details: Vec::new(),
                    },
                    StreamingDelta::MessageContent {
                        content,
                        tool_calls,
                    } => CompatibleChoiceData {
                        finish_reason: CompatibleFinishReason::Other,
                        text: content.clone(),
                        reasoning: None,
                        tool_calls: openai_chat_completions_compatible::tool_call_chunks(
                            tool_calls,
                        ),
                        details: Vec::new(),
                    },
                },
            ),
        ))
    }

    fn build_final_response(&self, usage: Self::Usage) -> Self::FinalResponse {
        StreamingCompletionResponse { usage }
    }

    fn uses_distinct_tool_call_eviction(&self) -> bool {
        true
    }

    fn emits_complete_single_chunk_tool_calls(&self) -> bool {
        true
    }
}

pub async fn send_compatible_streaming_request<T>(
    client: T,
    req: Request<Vec<u8>>,
) -> Result<
    crate::streaming::StreamingCompletionResponse<StreamingCompletionResponse>,
    CompletionError,
>
where
    T: HttpClientExt + Clone + 'static,
{
    openai_chat_completions_compatible::send_compatible_streaming_request(
        client,
        req,
        GroqCompatibleProfile,
    )
    .await
}

#[cfg(test)]
mod tests {
    use crate::{
        OneOrMany,
        providers::{
            groq::{GroqAdditionalParameters, GroqCompletionRequest},
            openai::{Message, UserContent},
        },
    };

    #[test]
    fn serialize_groq_request() {
        let additional_params = GroqAdditionalParameters {
            include_reasoning: Some(true),
            reasoning_format: Some(super::ReasoningFormat::Parsed),
            ..Default::default()
        };

        let groq = GroqCompletionRequest {
            model: "openai/gpt-oss-120b".to_string(),
            temperature: None,
            tool_choice: None,
            stream_options: None,
            tools: Vec::new(),
            messages: vec![Message::User {
                content: OneOrMany::one(UserContent::Text {
                    text: "Hello world!".to_string(),
                }),
                name: None,
            }],
            stream: false,
            additional_params: Some(additional_params),
        };

        let json = serde_json::to_value(&groq).unwrap();

        assert_eq!(
            json,
            serde_json::json!({
                "model": "openai/gpt-oss-120b",
                "messages": [
                    {
                        "role": "user",
                        "content": "Hello world!"
                    }
                ],
                "stream": false,
                "include_reasoning": true,
                "reasoning_format": "parsed"
            })
        )
    }

    #[test]
    fn test_client_initialization() {
        let _client =
            crate::providers::groq::Client::new("dummy-key").expect("Client::new() failed");
        let _client_from_builder = crate::providers::groq::Client::builder()
            .api_key("dummy-key")
            .build()
            .expect("Client::builder() failed");
    }
}