Skip to main content

rig/providers/
groq.rs

//! Groq API client and Rig integration
//!
//! # Example
//! ```
//! use rig::providers::groq;
//!
//! let client = groq::Client::new("YOUR_API_KEY").unwrap();
//!
//! let llama = client.completion_model(groq::LLAMA_3_1_8B_INSTANT);
//! ```
11use bytes::Bytes;
12use http::Request;
13use serde_json::{Map, Value};
14use tracing::info_span;
15
16use super::openai::{
17    CompletionResponse, Message as OpenAIMessage, StreamingToolCall, TranscriptionResponse, Usage,
18};
19use crate::client::{
20    self, BearerAuth, Capabilities, Capable, DebugExt, Nothing, Provider, ProviderBuilder,
21    ProviderClient,
22};
23use crate::completion::GetTokenUsage;
24use crate::http_client::multipart::Part;
25use crate::http_client::{self, HttpClientExt, MultipartForm};
26use crate::providers::internal::openai_chat_completions_compatible::{
27    self, CompatibleChoiceData, CompatibleChunk, CompatibleFinishReason, CompatibleStreamProfile,
28};
29
30use crate::{
31    completion::{self, CompletionError, CompletionRequest},
32    json_utils,
33    message::{self},
34    providers::openai::ToolDefinition,
35    transcription::{self, TranscriptionError},
36};
37use serde::{Deserialize, Serialize};
38
// ================================================================
// Main Groq Client
// ================================================================
// Groq exposes an OpenAI-compatible REST API under this base path.
const GROQ_API_BASE_URL: &str = "https://api.groq.com/openai/v1";

/// Zero-sized marker type carrying the Groq provider's capability set.
#[derive(Debug, Default, Clone, Copy)]
pub struct GroqExt;
/// Zero-sized marker used by [`client::ClientBuilder`] to build a Groq client.
#[derive(Debug, Default, Clone, Copy)]
pub struct GroqBuilder;

// Groq authenticates with a standard `Authorization: Bearer <key>` header.
type GroqApiKey = BearerAuth;
50
impl Provider for GroqExt {
    type Builder = GroqBuilder;
    // Credentials are verified by issuing a request against `/models`.
    const VERIFY_PATH: &'static str = "/models";
}

// Capabilities wired up for Groq: chat completions and audio transcription.
// Embeddings, model listing, image and audio generation are not implemented.
impl<H> Capabilities<H> for GroqExt {
    type Completion = Capable<CompletionModel<H>>;
    type Embeddings = Nothing;
    type Transcription = Capable<TranscriptionModel<H>>;
    type ModelListing = Nothing;
    #[cfg(feature = "image")]
    type ImageGeneration = Nothing;

    #[cfg(feature = "audio")]
    type AudioGeneration = Nothing;
}

impl DebugExt for GroqExt {}
69
impl ProviderBuilder for GroqBuilder {
    type Extension<H>
        = GroqExt
    where
        H: HttpClientExt;
    type ApiKey = GroqApiKey;

    const BASE_URL: &'static str = GROQ_API_BASE_URL;

    /// Build the Groq extension; it is stateless, so the builder is unused.
    fn build<H>(
        _builder: &client::ClientBuilder<Self, Self::ApiKey, H>,
    ) -> http_client::Result<Self::Extension<H>>
    where
        H: HttpClientExt,
    {
        Ok(GroqExt)
    }
}

/// A Groq client over an arbitrary HTTP backend (defaults to `reqwest`).
pub type Client<H = reqwest::Client> = client::Client<GroqExt, H>;
/// Builder for [`Client`]; the API key is supplied as a `String`.
pub type ClientBuilder<H = reqwest::Client> = client::ClientBuilder<GroqBuilder, String, H>;
91
impl ProviderClient for Client {
    type Input = String;
    type Error = crate::client::ProviderClientError;

    /// Create a new Groq client from the `GROQ_API_KEY` environment variable.
    ///
    /// # Errors
    /// Returns an error if the variable is unset or client construction fails.
    fn from_env() -> Result<Self, Self::Error> {
        let api_key = crate::client::required_env_var("GROQ_API_KEY")?;
        Self::new(&api_key).map_err(Into::into)
    }

    /// Create a new Groq client directly from an API key string.
    fn from_val(input: Self::Input) -> Result<Self, Self::Error> {
        Self::new(&input).map_err(Into::into)
    }
}
106
/// Error payload returned by the Groq API.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    message: String,
}

/// Either the expected payload `T` or a Groq error body.
/// `untagged`: serde tries the variants in order, `Ok(T)` first.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    Ok(T),
    Err(ApiErrorResponse),
}
118
119// ================================================================
120// Groq Completion API
121// ================================================================
122
/// The `deepseek-r1-distill-llama-70b` model. Used for chat completion.
pub const DEEPSEEK_R1_DISTILL_LLAMA_70B: &str = "deepseek-r1-distill-llama-70b";
/// The `gemma2-9b-it` model. Used for chat completion.
pub const GEMMA2_9B_IT: &str = "gemma2-9b-it";
/// The `llama-3.1-8b-instant` model. Used for chat completion.
pub const LLAMA_3_1_8B_INSTANT: &str = "llama-3.1-8b-instant";
/// The `llama-3.2-11b-vision-preview` model. Used for chat completion.
pub const LLAMA_3_2_11B_VISION_PREVIEW: &str = "llama-3.2-11b-vision-preview";
/// The `llama-3.2-1b-preview` model. Used for chat completion.
pub const LLAMA_3_2_1B_PREVIEW: &str = "llama-3.2-1b-preview";
/// The `llama-3.2-3b-preview` model. Used for chat completion.
pub const LLAMA_3_2_3B_PREVIEW: &str = "llama-3.2-3b-preview";
/// The `llama-3.2-90b-vision-preview` model. Used for chat completion.
pub const LLAMA_3_2_90B_VISION_PREVIEW: &str = "llama-3.2-90b-vision-preview";
// NOTE(review): Groq publishes the 70b specdec/versatile models under
// `llama-3.3-*`; confirm these `llama-3.2-70b-*` ids are still served.
/// The `llama-3.2-70b-specdec` model. Used for chat completion.
pub const LLAMA_3_2_70B_SPECDEC: &str = "llama-3.2-70b-specdec";
/// The `llama-3.2-70b-versatile` model. Used for chat completion.
pub const LLAMA_3_2_70B_VERSATILE: &str = "llama-3.2-70b-versatile";
/// The `llama-guard-3-8b` model. Used for chat completion.
pub const LLAMA_GUARD_3_8B: &str = "llama-guard-3-8b";
/// The `llama3-70b-8192` model. Used for chat completion.
pub const LLAMA_3_70B_8192: &str = "llama3-70b-8192";
/// The `llama3-8b-8192` model. Used for chat completion.
pub const LLAMA_3_8B_8192: &str = "llama3-8b-8192";
/// The `mixtral-8x7b-32768` model. Used for chat completion.
pub const MIXTRAL_8X7B_32768: &str = "mixtral-8x7b-32768";
149
/// Reasoning output format accepted by Groq's `reasoning_format` parameter.
/// Serialized lowercase (`"parsed"`, `"raw"`, `"hidden"`); see Groq's API
/// docs for the exact semantics of each value.
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ReasoningFormat {
    Parsed,
    Raw,
    Hidden,
}
157
/// Request body for Groq's `/chat/completions` endpoint.
#[derive(Debug, Serialize, Deserialize)]
pub(super) struct GroqCompletionRequest {
    // Model identifier, e.g. `llama-3.1-8b-instant`.
    model: String,
    /// Full conversation in OpenAI message format (system + history + prompt).
    pub messages: Vec<OpenAIMessage>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f64>,
    // Omitted from the payload entirely when no tools are configured.
    #[serde(skip_serializing_if = "Vec::is_empty")]
    tools: Vec<ToolDefinition>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_choice: Option<crate::providers::openai::completion::ToolChoice>,
    // Flattened so Groq-specific fields serialize at the top level of the body.
    #[serde(flatten, skip_serializing_if = "Option::is_none")]
    pub additional_params: Option<GroqAdditionalParameters>,
    pub(super) stream: bool,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub(super) stream_options: Option<StreamOptions>,
}

/// Extra options sent only for streaming requests.
#[derive(Debug, Serialize, Deserialize, Default)]
pub(super) struct StreamOptions {
    // Request that token usage be included in the stream (set by `stream()`).
    pub(super) include_usage: bool,
}
179
180impl TryFrom<(&str, CompletionRequest)> for GroqCompletionRequest {
181    type Error = CompletionError;
182
183    fn try_from((model, mut req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
184        if req.output_schema.is_some() {
185            tracing::warn!("Structured outputs currently not supported for Groq");
186        }
187        let model = req.model.clone().unwrap_or_else(|| model.to_string());
188        // Build up the order of messages (context, chat_history, prompt)
189        let mut partial_history = vec![];
190        if let Some(docs) = req.normalized_documents() {
191            partial_history.push(docs);
192        }
193        partial_history.extend(req.chat_history);
194
195        // Add preamble to chat history (if available)
196        let mut full_history: Vec<OpenAIMessage> = match &req.preamble {
197            Some(preamble) => vec![OpenAIMessage::system(preamble)],
198            None => vec![],
199        };
200
201        // Convert and extend the rest of the history
202        full_history.extend(
203            partial_history
204                .into_iter()
205                .map(message::Message::try_into)
206                .collect::<Result<Vec<Vec<OpenAIMessage>>, _>>()?
207                .into_iter()
208                .flatten()
209                .collect::<Vec<_>>(),
210        );
211
212        let tool_choice = req
213            .tool_choice
214            .clone()
215            .map(crate::providers::openai::ToolChoice::try_from)
216            .transpose()?;
217
218        let mut additional_params_payload = req.additional_params.take().unwrap_or(Value::Null);
219        let native_tools =
220            extract_native_tools_from_additional_params(&mut additional_params_payload)?;
221
222        let mut additional_params: Option<GroqAdditionalParameters> =
223            if additional_params_payload.is_null() {
224                None
225            } else {
226                Some(serde_json::from_value(additional_params_payload)?)
227            };
228        apply_native_tools_to_additional_params(&mut additional_params, native_tools);
229
230        Ok(Self {
231            model: model.to_string(),
232            messages: full_history,
233            temperature: req.temperature,
234            tools: req
235                .tools
236                .clone()
237                .into_iter()
238                .map(ToolDefinition::from)
239                .collect::<Vec<_>>(),
240            tool_choice,
241            additional_params,
242            stream: false,
243            stream_options: None,
244        })
245    }
246}
247
248fn extract_native_tools_from_additional_params(
249    additional_params: &mut Value,
250) -> Result<Vec<Value>, CompletionError> {
251    if let Some(map) = additional_params.as_object_mut()
252        && let Some(raw_tools) = map.remove("tools")
253    {
254        return serde_json::from_value::<Vec<Value>>(raw_tools).map_err(|err| {
255            CompletionError::RequestError(
256                format!("Invalid Groq `additional_params.tools` payload: {err}").into(),
257            )
258        });
259    }
260
261    Ok(Vec::new())
262}
263
264fn apply_native_tools_to_additional_params(
265    additional_params: &mut Option<GroqAdditionalParameters>,
266    native_tools: Vec<Value>,
267) {
268    if native_tools.is_empty() {
269        return;
270    }
271
272    let params = additional_params.get_or_insert_with(GroqAdditionalParameters::default);
273    let extra = params.extra.get_or_insert_with(Map::new);
274
275    let mut compound_custom = match extra.remove("compound_custom") {
276        Some(Value::Object(map)) => map,
277        _ => Map::new(),
278    };
279
280    let mut enabled_tools = match compound_custom.remove("enabled_tools") {
281        Some(Value::Array(values)) => values,
282        _ => Vec::new(),
283    };
284
285    for native_tool in native_tools {
286        let already_enabled = enabled_tools
287            .iter()
288            .any(|existing| native_tools_match(existing, &native_tool));
289        if !already_enabled {
290            enabled_tools.push(native_tool);
291        }
292    }
293
294    compound_custom.insert("enabled_tools".to_string(), Value::Array(enabled_tools));
295    extra.insert(
296        "compound_custom".to_string(),
297        Value::Object(compound_custom),
298    );
299}
300
301fn native_tools_match(lhs: &Value, rhs: &Value) -> bool {
302    if let (Some(lhs_type), Some(rhs_type)) = (native_tool_kind(lhs), native_tool_kind(rhs)) {
303        return lhs_type == rhs_type;
304    }
305
306    lhs == rhs
307}
308
309fn native_tool_kind(value: &Value) -> Option<&str> {
310    match value {
311        Value::String(kind) => Some(kind),
312        Value::Object(map) => map.get("type").and_then(Value::as_str),
313        _ => None,
314    }
315}
316
/// Additional parameters to send to the Groq API
///
/// Flattened into [`GroqCompletionRequest`], so every field here serializes
/// at the top level of the request body.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub struct GroqAdditionalParameters {
    /// The reasoning format. See Groq's API docs for more details.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub reasoning_format: Option<ReasoningFormat>,
    /// Whether or not to include reasoning. See Groq's API docs for more details.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub include_reasoning: Option<bool>,
    /// Any other properties not included by default on this struct (that you want to send)
    #[serde(flatten, skip_serializing_if = "Option::is_none")]
    pub extra: Option<Map<String, serde_json::Value>>,
}
330
/// A Groq chat-completion model bound to a [`Client`].
#[derive(Clone, Debug)]
pub struct CompletionModel<T = reqwest::Client> {
    client: Client<T>,
    /// Name of the model (e.g.: deepseek-r1-distill-llama-70b)
    pub model: String,
}

impl<T> CompletionModel<T> {
    /// Bind `model` to `client`. The model name is not validated here.
    pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
        Self {
            client,
            model: model.into(),
        }
    }
}
346
impl<T> completion::CompletionModel for CompletionModel<T>
where
    T: HttpClientExt + Clone + Send + std::fmt::Debug + Default + 'static,
{
    type Response = CompletionResponse;
    type StreamingResponse = StreamingCompletionResponse;

    type Client = Client<T>;

    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
        Self::new(client.clone(), model)
    }

    /// Send a non-streaming chat completion to `/chat/completions`.
    ///
    /// Runs inside a `gen_ai`-convention tracing span (reusing the caller's
    /// span when one is already active) and records the response id/model and
    /// token usage on it.
    ///
    /// # Errors
    /// Returns `CompletionError` on serialization or transport failure; a
    /// non-2xx response surfaces its raw body as `ProviderError`.
    async fn completion(
        &self,
        completion_request: CompletionRequest,
    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
        // Open a fresh span only when the caller has not set one up already.
        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat",
                gen_ai.operation.name = "chat",
                gen_ai.provider.name = "groq",
                gen_ai.request.model = self.model,
                gen_ai.system_instructions = tracing::field::Empty,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.usage.cache_read.input_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        span.record("gen_ai.system_instructions", &completion_request.preamble);

        let request = GroqCompletionRequest::try_from((self.model.as_ref(), completion_request))?;

        if tracing::enabled!(tracing::Level::TRACE) {
            tracing::trace!(target: "rig::completions",
                "Groq completion request: {}",
                serde_json::to_string_pretty(&request)?
            );
        }

        let body = serde_json::to_vec(&request)?;
        let req = self
            .client
            .post("/chat/completions")?
            .body(body)
            .map_err(|e| http_client::Error::Instance(e.into()))?;

        let async_block = async move {
            let response = self.client.send::<_, Bytes>(req).await?;
            let status = response.status();
            let response_body = response.into_body().into_future().await?.to_vec();

            if status.is_success() {
                match serde_json::from_slice::<ApiResponse<CompletionResponse>>(&response_body)? {
                    ApiResponse::Ok(response) => {
                        let span = tracing::Span::current();
                        span.record("gen_ai.response.id", response.id.clone());
                        span.record("gen_ai.response.model", response.model.clone());
                        if let Some(ref usage) = response.usage {
                            span.record("gen_ai.usage.input_tokens", usage.prompt_tokens);
                            // Output tokens are derived as total minus prompt
                            // rather than read from a dedicated field.
                            span.record(
                                "gen_ai.usage.output_tokens",
                                usage.total_tokens - usage.prompt_tokens,
                            );
                            span.record(
                                "gen_ai.usage.cache_read.input_tokens",
                                usage
                                    .prompt_tokens_details
                                    .as_ref()
                                    .map(|d| d.cached_tokens)
                                    .unwrap_or(0),
                            );
                        }

                        if tracing::enabled!(tracing::Level::TRACE) {
                            tracing::trace!(target: "rig::completions",
                                "Groq completion response: {}",
                                serde_json::to_string_pretty(&response)?
                            );
                        }

                        response.try_into()
                    }
                    // A 2xx status can still carry an API error payload.
                    ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
                }
            } else {
                Err(CompletionError::ProviderError(
                    String::from_utf8_lossy(&response_body).to_string(),
                ))
            }
        };

        tracing::Instrument::instrument(async_block, span).await
    }

    /// Send a streaming chat completion (`stream: true` with usage included)
    /// and hand the response off to the shared OpenAI-compatible SSE decoder.
    ///
    /// # Errors
    /// Returns `CompletionError` on serialization, transport, or provider
    /// failure.
    async fn stream(
        &self,
        request: CompletionRequest,
    ) -> Result<
        crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
        CompletionError,
    > {
        // Same span policy as `completion`: reuse the active span if any.
        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat_streaming",
                gen_ai.operation.name = "chat_streaming",
                gen_ai.provider.name = "groq",
                gen_ai.request.model = self.model,
                gen_ai.system_instructions = tracing::field::Empty,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.usage.cache_read.input_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        span.record("gen_ai.system_instructions", &request.preamble);

        let mut request = GroqCompletionRequest::try_from((self.model.as_ref(), request))?;

        // Opt in to streaming and ask for usage stats in the stream.
        request.stream = true;
        request.stream_options = Some(StreamOptions {
            include_usage: true,
        });

        if tracing::enabled!(tracing::Level::TRACE) {
            tracing::trace!(target: "rig::completions",
                "Groq streaming completion request: {}",
                serde_json::to_string_pretty(&request)?
            );
        }

        let body = serde_json::to_vec(&request)?;
        let req = self
            .client
            .post("/chat/completions")?
            .body(body)
            .map_err(|e| http_client::Error::Instance(e.into()))?;

        tracing::Instrument::instrument(
            send_compatible_streaming_request(self.client.clone(), req),
            span,
        )
        .await
    }
}
503
504// ================================================================
505// Groq Transcription API
506// ================================================================
507
/// The `whisper-large-v3` model. Used for audio transcription.
pub const WHISPER_LARGE_V3: &str = "whisper-large-v3";
/// The `whisper-large-v3-turbo` model. Used for audio transcription.
pub const WHISPER_LARGE_V3_TURBO: &str = "whisper-large-v3-turbo";
/// The `distil-whisper-large-v3-en` model. Used for audio transcription.
pub const DISTIL_WHISPER_LARGE_V3_EN: &str = "distil-whisper-large-v3-en";

/// A Groq audio-transcription model bound to a [`Client`].
#[derive(Clone)]
pub struct TranscriptionModel<T> {
    client: Client<T>,
    /// Name of the model (e.g.: whisper-large-v3)
    pub model: String,
}
518
impl<T> TranscriptionModel<T> {
    /// Bind `model` to `client`. The model name is not validated here.
    pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
        Self {
            client,
            model: model.into(),
        }
    }
}
527impl<T> transcription::TranscriptionModel for TranscriptionModel<T>
528where
529    T: HttpClientExt + Clone + Send + std::fmt::Debug + Default + 'static,
530{
531    type Response = TranscriptionResponse;
532
533    type Client = Client<T>;
534
535    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
536        Self::new(client.clone(), model)
537    }
538
539    async fn transcription(
540        &self,
541        request: transcription::TranscriptionRequest,
542    ) -> Result<
543        transcription::TranscriptionResponse<Self::Response>,
544        transcription::TranscriptionError,
545    > {
546        let data = request.data;
547
548        let mut body = MultipartForm::new()
549            .text("model", self.model.clone())
550            .part(Part::bytes("file", data).filename(request.filename.clone()));
551
552        if let Some(language) = request.language {
553            body = body.text("language", language);
554        }
555
556        if let Some(prompt) = request.prompt {
557            body = body.text("prompt", prompt.clone());
558        }
559
560        if let Some(ref temperature) = request.temperature {
561            body = body.text("temperature", temperature.to_string());
562        }
563
564        if let Some(ref additional_params) = request.additional_params {
565            let params = additional_params.as_object().ok_or_else(|| {
566                TranscriptionError::RequestError(Box::new(std::io::Error::new(
567                    std::io::ErrorKind::InvalidInput,
568                    "additional transcription parameters must be a JSON object",
569                )))
570            })?;
571
572            for (key, value) in params {
573                body = body.text(key.to_owned(), value.to_string());
574            }
575        }
576
577        let req = self
578            .client
579            .post("/audio/transcriptions")?
580            .body(body)
581            .map_err(|e| TranscriptionError::HttpError(e.into()))?;
582
583        let response = self.client.send_multipart::<Bytes>(req).await?;
584
585        let status = response.status();
586        let response_body = response.into_body().into_future().await?.to_vec();
587
588        if status.is_success() {
589            match serde_json::from_slice::<ApiResponse<TranscriptionResponse>>(&response_body)? {
590                ApiResponse::Ok(response) => response.try_into(),
591                ApiResponse::Err(api_error_response) => Err(TranscriptionError::ProviderError(
592                    api_error_response.message,
593                )),
594            }
595        } else {
596            Err(TranscriptionError::ProviderError(
597                String::from_utf8_lossy(&response_body).to_string(),
598            ))
599        }
600    }
601}
602
/// One `delta` payload inside a streaming choice.
///
/// `untagged`: a delta carrying a `reasoning` field parses as `Reasoning`;
/// anything else falls through to `MessageContent`.
#[derive(Deserialize, Debug)]
#[serde(untagged)]
enum StreamingDelta {
    Reasoning {
        reasoning: String,
    },
    MessageContent {
        #[serde(default)]
        content: Option<String>,
        // `null` or a missing field both deserialize to an empty vec.
        #[serde(default, deserialize_with = "json_utils::null_or_vec")]
        tool_calls: Vec<StreamingToolCall>,
    },
}

/// A single choice entry within a streaming chunk.
#[derive(Deserialize, Debug)]
struct StreamingChoice {
    delta: StreamingDelta,
}

/// One SSE chunk of a Groq streaming completion.
#[derive(Deserialize, Debug)]
struct StreamingCompletionChunk {
    id: Option<String>,
    model: Option<String>,
    choices: Vec<StreamingChoice>,
    // Token usage, when the provider includes it (see `StreamOptions`).
    usage: Option<Usage>,
}

/// Final aggregated result of a Groq streaming completion.
#[derive(Clone, Deserialize, Serialize, Debug)]
pub struct StreamingCompletionResponse {
    pub usage: Usage,
}
634
impl GetTokenUsage for StreamingCompletionResponse {
    /// Delegate token accounting to the underlying OpenAI-style usage struct.
    fn token_usage(&self) -> Option<crate::completion::Usage> {
        self.usage.token_usage()
    }
}
640
641#[derive(Clone, Copy)]
642struct GroqCompatibleProfile;
643
644impl CompatibleStreamProfile for GroqCompatibleProfile {
645    type Usage = Usage;
646    type Detail = ();
647    type FinalResponse = StreamingCompletionResponse;
648
649    fn normalize_chunk(
650        &self,
651        data: &str,
652    ) -> Result<Option<CompatibleChunk<Self::Usage, Self::Detail>>, CompletionError> {
653        let data = match serde_json::from_str::<StreamingCompletionChunk>(data) {
654            Ok(data) => data,
655            Err(error) => {
656                tracing::debug!(
657                    "Couldn't parse SSE payload as StreamingCompletionChunk: {:?}",
658                    error
659                );
660                return Ok(None);
661            }
662        };
663
664        Ok(Some(
665            openai_chat_completions_compatible::normalize_first_choice_chunk(
666                data.id,
667                data.model,
668                data.usage,
669                &data.choices,
670                |choice| match &choice.delta {
671                    StreamingDelta::Reasoning { reasoning } => CompatibleChoiceData {
672                        finish_reason: CompatibleFinishReason::Other,
673                        text: None,
674                        reasoning: Some(reasoning.clone()),
675                        tool_calls: Vec::new(),
676                        details: Vec::new(),
677                    },
678                    StreamingDelta::MessageContent {
679                        content,
680                        tool_calls,
681                    } => CompatibleChoiceData {
682                        finish_reason: CompatibleFinishReason::Other,
683                        text: content.clone(),
684                        reasoning: None,
685                        tool_calls: openai_chat_completions_compatible::tool_call_chunks(
686                            tool_calls,
687                        ),
688                        details: Vec::new(),
689                    },
690                },
691            ),
692        ))
693    }
694
695    fn build_final_response(&self, usage: Self::Usage) -> Self::FinalResponse {
696        StreamingCompletionResponse { usage }
697    }
698
699    fn uses_distinct_tool_call_eviction(&self) -> bool {
700        true
701    }
702
703    fn emits_complete_single_chunk_tool_calls(&self) -> bool {
704        true
705    }
706}
707
/// Forward a prepared streaming request to the shared OpenAI-compatible SSE
/// handler, using the Groq-specific [`GroqCompatibleProfile`].
pub async fn send_compatible_streaming_request<T>(
    client: T,
    req: Request<Vec<u8>>,
) -> Result<
    crate::streaming::StreamingCompletionResponse<StreamingCompletionResponse>,
    CompletionError,
>
where
    T: HttpClientExt + Clone + 'static,
{
    openai_chat_completions_compatible::send_compatible_streaming_request(
        client,
        req,
        GroqCompatibleProfile,
    )
    .await
}
725
#[cfg(test)]
mod tests {
    use crate::{
        OneOrMany,
        providers::{
            groq::{GroqAdditionalParameters, GroqCompletionRequest},
            openai::{Message, UserContent},
        },
    };

    /// Serializing a request should flatten `GroqAdditionalParameters` into
    /// top-level fields and omit all `None`/empty optional fields.
    #[test]
    fn serialize_groq_request() {
        let additional_params = GroqAdditionalParameters {
            include_reasoning: Some(true),
            reasoning_format: Some(super::ReasoningFormat::Parsed),
            ..Default::default()
        };

        let groq = GroqCompletionRequest {
            model: "openai/gpt-120b-oss".to_string(),
            temperature: None,
            tool_choice: None,
            stream_options: None,
            tools: Vec::new(),
            messages: vec![Message::User {
                content: OneOrMany::one(UserContent::Text {
                    text: "Hello world!".to_string(),
                }),
                name: None,
            }],
            stream: false,
            additional_params: Some(additional_params),
        };

        let json = serde_json::to_value(&groq).unwrap();

        assert_eq!(
            json,
            serde_json::json!({
                "model": "openai/gpt-120b-oss",
                "messages": [
                    {
                        "role": "user",
                        "content": "Hello world!"
                    }
                ],
                "stream": false,
                "include_reasoning": true,
                "reasoning_format": "parsed"
            })
        )
    }

    /// Both construction paths (direct `new` and the builder) should accept
    /// an API key without touching the network.
    #[test]
    fn test_client_initialization() {
        let _client =
            crate::providers::groq::Client::new("dummy-key").expect("Client::new() failed");
        let _client_from_builder = crate::providers::groq::Client::builder()
            .api_key("dummy-key")
            .build()
            .expect("Client::builder() failed");
    }
}