//! Groq API client and Rig integration
//!
//! # Example
//! ```
//! use rig::providers::groq;
//!
//! let client = groq::Client::new("YOUR_API_KEY").unwrap();
//!
//! let llama = client.completion_model(groq::LLAMA_3_1_8B_INSTANT);
//! ```
use bytes::Bytes;
use http::Request;
use serde_json::Map;
use std::collections::HashMap;
use tracing::info_span;
use tracing_futures::Instrument;

use super::openai::{
    CompletionResponse, Message as OpenAIMessage, StreamingToolCall, TranscriptionResponse, Usage,
};
use crate::client::{
    self, BearerAuth, Capabilities, Capable, DebugExt, Nothing, Provider, ProviderBuilder,
    ProviderClient,
};
use crate::completion::GetTokenUsage;
use crate::http_client::multipart::Part;
use crate::http_client::sse::{Event, GenericEventSource};
use crate::http_client::{self, HttpClientExt, MultipartForm};
use crate::json_utils::empty_or_none;
use crate::providers::openai::{AssistantContent, Function, ToolType};
use async_stream::stream;
use futures::StreamExt;

use crate::{
    completion::{self, CompletionError, CompletionRequest},
    json_utils,
    message::{self},
    providers::openai::ToolDefinition,
    transcription::{self, TranscriptionError},
};
use serde::{Deserialize, Serialize};

// ================================================================
// Main Groq Client
// ================================================================
const GROQ_API_BASE_URL: &str = "https://api.groq.com/openai/v1";

#[derive(Debug, Default, Clone, Copy)]
pub struct GroqExt;
#[derive(Debug, Default, Clone, Copy)]
pub struct GroqBuilder;

type GroqApiKey = BearerAuth;

impl Provider for GroqExt {
    type Builder = GroqBuilder;

    const VERIFY_PATH: &'static str = "/models";

    fn build<H>(
        _: &crate::client::ClientBuilder<
            Self::Builder,
            <Self::Builder as crate::client::ProviderBuilder>::ApiKey,
            H,
        >,
    ) -> http_client::Result<Self> {
        Ok(Self)
    }
}

impl<H> Capabilities<H> for GroqExt {
    type Completion = Capable<CompletionModel<H>>;
    type Embeddings = Nothing;
    type Transcription = Capable<TranscriptionModel<H>>;
    #[cfg(feature = "image")]
    type ImageGeneration = Nothing;

    #[cfg(feature = "audio")]
    type AudioGeneration = Nothing;
}

impl DebugExt for GroqExt {}

impl ProviderBuilder for GroqBuilder {
    type Output = GroqExt;
    type ApiKey = GroqApiKey;

    const BASE_URL: &'static str = GROQ_API_BASE_URL;
}

pub type Client<H = reqwest::Client> = client::Client<GroqExt, H>;
pub type ClientBuilder<H = reqwest::Client> = client::ClientBuilder<GroqBuilder, String, H>;

impl ProviderClient for Client {
    type Input = String;

    /// Create a new Groq client from the `GROQ_API_KEY` environment variable.
    /// Panics if the environment variable is not set.
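    ///
    /// A minimal sketch (marked `no_run` since it needs a real key at runtime):
    ///
    /// ```no_run
    /// use rig::client::ProviderClient;
    /// use rig::providers::groq;
    ///
    /// // Reads `GROQ_API_KEY`; panics if it is unset.
    /// let client = groq::Client::from_env();
    /// ```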
    fn from_env() -> Self {
        let api_key = std::env::var("GROQ_API_KEY").expect("GROQ_API_KEY not set");
        Self::new(&api_key).unwrap()
    }

    fn from_val(input: Self::Input) -> Self {
        Self::new(&input).unwrap()
    }
}

#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    message: String,
}

#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    Ok(T),
    Err(ApiErrorResponse),
}

// ================================================================
// Groq Completion API
// ================================================================

/// The `deepseek-r1-distill-llama-70b` model. Used for chat completion.
pub const DEEPSEEK_R1_DISTILL_LLAMA_70B: &str = "deepseek-r1-distill-llama-70b";
/// The `gemma2-9b-it` model. Used for chat completion.
pub const GEMMA2_9B_IT: &str = "gemma2-9b-it";
/// The `llama-3.1-8b-instant` model. Used for chat completion.
pub const LLAMA_3_1_8B_INSTANT: &str = "llama-3.1-8b-instant";
/// The `llama-3.2-11b-vision-preview` model. Used for chat completion.
pub const LLAMA_3_2_11B_VISION_PREVIEW: &str = "llama-3.2-11b-vision-preview";
/// The `llama-3.2-1b-preview` model. Used for chat completion.
pub const LLAMA_3_2_1B_PREVIEW: &str = "llama-3.2-1b-preview";
/// The `llama-3.2-3b-preview` model. Used for chat completion.
pub const LLAMA_3_2_3B_PREVIEW: &str = "llama-3.2-3b-preview";
/// The `llama-3.2-90b-vision-preview` model. Used for chat completion.
pub const LLAMA_3_2_90B_VISION_PREVIEW: &str = "llama-3.2-90b-vision-preview";
/// The `llama-3.3-70b-specdec` model. Used for chat completion.
pub const LLAMA_3_3_70B_SPECDEC: &str = "llama-3.3-70b-specdec";
/// The `llama-3.3-70b-versatile` model. Used for chat completion.
pub const LLAMA_3_3_70B_VERSATILE: &str = "llama-3.3-70b-versatile";
/// The `llama-guard-3-8b` model. Used for chat completion.
pub const LLAMA_GUARD_3_8B: &str = "llama-guard-3-8b";
/// The `llama3-70b-8192` model. Used for chat completion.
pub const LLAMA_3_70B_8192: &str = "llama3-70b-8192";
/// The `llama3-8b-8192` model. Used for chat completion.
pub const LLAMA_3_8B_8192: &str = "llama3-8b-8192";
/// The `mixtral-8x7b-32768` model. Used for chat completion.
pub const MIXTRAL_8X7B_32768: &str = "mixtral-8x7b-32768";

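/// How the model's reasoning should be returned; serialized in lowercase
/// (`"parsed"`, `"raw"`, `"hidden"`). See Groq's API docs for details.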
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ReasoningFormat {
    Parsed,
    Raw,
    Hidden,
}

#[derive(Debug, Serialize, Deserialize)]
pub(super) struct GroqCompletionRequest {
    model: String,
    pub messages: Vec<OpenAIMessage>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f64>,
    #[serde(skip_serializing_if = "Vec::is_empty")]
    tools: Vec<ToolDefinition>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_choice: Option<crate::providers::openai::completion::ToolChoice>,
    #[serde(flatten, skip_serializing_if = "Option::is_none")]
    pub additional_params: Option<GroqAdditionalParameters>,
    pub(super) stream: bool,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub(super) stream_options: Option<StreamOptions>,
}

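/// Streaming options forwarded to the API; `include_usage` asks the provider
/// to attach token usage to the final stream chunk.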
#[derive(Debug, Serialize, Deserialize, Default)]
pub(super) struct StreamOptions {
    pub(super) include_usage: bool,
}

impl TryFrom<(&str, CompletionRequest)> for GroqCompletionRequest {
    type Error = CompletionError;

    fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
        // Build up the order of messages (context, chat_history, prompt)
        let mut partial_history = vec![];
        if let Some(docs) = req.normalized_documents() {
            partial_history.push(docs);
        }
        partial_history.extend(req.chat_history);

        // Add preamble to chat history (if available)
        let mut full_history: Vec<OpenAIMessage> = match &req.preamble {
            Some(preamble) => vec![OpenAIMessage::system(preamble)],
            None => vec![],
        };

        // Convert and extend the rest of the history
        full_history.extend(
            partial_history
                .into_iter()
                .map(message::Message::try_into)
                .collect::<Result<Vec<Vec<OpenAIMessage>>, _>>()?
                .into_iter()
                .flatten()
                .collect::<Vec<_>>(),
        );

        let tool_choice = req
            .tool_choice
            .clone()
            .map(crate::providers::openai::ToolChoice::try_from)
            .transpose()?;

        let additional_params: Option<GroqAdditionalParameters> =
            if let Some(params) = req.additional_params {
                Some(serde_json::from_value(params)?)
            } else {
                None
            };

        Ok(Self {
            model: model.to_string(),
            messages: full_history,
            temperature: req.temperature,
            tools: req
                .tools
                .clone()
                .into_iter()
                .map(ToolDefinition::from)
                .collect::<Vec<_>>(),
            tool_choice,
            additional_params,
            stream: false,
            stream_options: None,
        })
    }
}

/// Additional parameters to send to the Groq API
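///
/// # Example
///
/// A minimal sketch of building these parameters and serializing them to the
/// JSON value that a completion request's `additional_params` field expects:
///
/// ```
/// use rig::providers::groq::{GroqAdditionalParameters, ReasoningFormat};
///
/// let params = GroqAdditionalParameters {
///     reasoning_format: Some(ReasoningFormat::Parsed),
///     include_reasoning: Some(true),
///     ..Default::default()
/// };
///
/// // These fields are flattened into the request body next to the standard ones.
/// let value = serde_json::to_value(&params).unwrap();
/// assert_eq!(value["reasoning_format"], "parsed");
/// ```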
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub struct GroqAdditionalParameters {
    /// The reasoning format. See Groq's API docs for more details.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub reasoning_format: Option<ReasoningFormat>,
    /// Whether or not to include reasoning. See Groq's API docs for more details.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub include_reasoning: Option<bool>,
    /// Any other properties not included by default on this struct (that you want to send)
    #[serde(flatten, skip_serializing_if = "Option::is_none")]
    pub extra: Option<Map<String, serde_json::Value>>,
}

#[derive(Clone, Debug)]
pub struct CompletionModel<T = reqwest::Client> {
    client: Client<T>,
    /// Name of the model (e.g.: deepseek-r1-distill-llama-70b)
    pub model: String,
}

impl<T> CompletionModel<T> {
    pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
        Self {
            client,
            model: model.into(),
        }
    }
}

impl<T> completion::CompletionModel for CompletionModel<T>
where
    T: HttpClientExt + Clone + Send + std::fmt::Debug + Default + 'static,
{
    type Response = CompletionResponse;
    type StreamingResponse = StreamingCompletionResponse;

    type Client = Client<T>;

    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
        Self::new(client.clone(), model)
    }

    async fn completion(
        &self,
        completion_request: CompletionRequest,
    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat",
                gen_ai.operation.name = "chat",
                gen_ai.provider.name = "groq",
                gen_ai.request.model = self.model,
                gen_ai.system_instructions = tracing::field::Empty,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        span.record("gen_ai.system_instructions", &completion_request.preamble);

        let request = GroqCompletionRequest::try_from((self.model.as_ref(), completion_request))?;

        if tracing::enabled!(tracing::Level::TRACE) {
            tracing::trace!(target: "rig::completions",
                "Groq completion request: {}",
                serde_json::to_string_pretty(&request)?
            );
        }

        let body = serde_json::to_vec(&request)?;
        let req = self
            .client
            .post("/chat/completions")?
            .body(body)
            .map_err(|e| http_client::Error::Instance(e.into()))?;

        let async_block = async move {
            let response = self.client.send::<_, Bytes>(req).await?;
            let status = response.status();
            let response_body = response.into_body().into_future().await?.to_vec();

            if status.is_success() {
                match serde_json::from_slice::<ApiResponse<CompletionResponse>>(&response_body)? {
                    ApiResponse::Ok(response) => {
                        let span = tracing::Span::current();
                        span.record("gen_ai.response.id", response.id.clone());
                        span.record("gen_ai.response.model", response.model.clone());
                        if let Some(ref usage) = response.usage {
                            span.record("gen_ai.usage.input_tokens", usage.prompt_tokens);
                            span.record(
                                "gen_ai.usage.output_tokens",
                                usage.total_tokens - usage.prompt_tokens,
                            );
                        }

                        if tracing::enabled!(tracing::Level::TRACE) {
                            tracing::trace!(target: "rig::completions",
                                "Groq completion response: {}",
                                serde_json::to_string_pretty(&response)?
                            );
                        }

                        response.try_into()
                    }
                    ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
                }
            } else {
                Err(CompletionError::ProviderError(
                    String::from_utf8_lossy(&response_body).to_string(),
                ))
            }
        };

        tracing::Instrument::instrument(async_block, span).await
    }

    async fn stream(
        &self,
        request: CompletionRequest,
    ) -> Result<
        crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
        CompletionError,
    > {
        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat_streaming",
                gen_ai.operation.name = "chat_streaming",
                gen_ai.provider.name = "groq",
                gen_ai.request.model = self.model,
                gen_ai.system_instructions = tracing::field::Empty,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = tracing::field::Empty,
                // Declared so the streaming task can record the assistant output below.
                gen_ai.output.messages = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        span.record("gen_ai.system_instructions", &request.preamble);

        let mut request = GroqCompletionRequest::try_from((self.model.as_ref(), request))?;

        request.stream = true;
        request.stream_options = Some(StreamOptions {
            include_usage: true,
        });

        if tracing::enabled!(tracing::Level::TRACE) {
            tracing::trace!(target: "rig::completions",
                "Groq streaming completion request: {}",
                serde_json::to_string_pretty(&request)?
            );
        }

        let body = serde_json::to_vec(&request)?;
        let req = self
            .client
            .post("/chat/completions")?
            .body(body)
            .map_err(|e| http_client::Error::Instance(e.into()))?;

        tracing::Instrument::instrument(
            send_compatible_streaming_request(self.client.clone(), req),
            span,
        )
        .await
    }
}

// ================================================================
// Groq Transcription API
// ================================================================

pub const WHISPER_LARGE_V3: &str = "whisper-large-v3";
pub const WHISPER_LARGE_V3_TURBO: &str = "whisper-large-v3-turbo";
pub const DISTIL_WHISPER_LARGE_V3_EN: &str = "distil-whisper-large-v3-en";

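/// A Groq transcription model (Whisper family).
///
/// A hedged usage sketch; it assumes the client exposes a `transcription_model`
/// accessor for providers with transcription support, as other providers in
/// this crate do (hence `ignore`):
///
/// ```ignore
/// use rig::providers::groq;
///
/// let client = groq::Client::new("YOUR_API_KEY").unwrap();
/// let whisper = client.transcription_model(groq::WHISPER_LARGE_V3);
/// ```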
#[derive(Clone)]
pub struct TranscriptionModel<T> {
    client: Client<T>,
    /// Name of the model (e.g.: whisper-large-v3)
    pub model: String,
}

impl<T> TranscriptionModel<T> {
    pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
        Self {
            client,
            model: model.into(),
        }
    }
}

impl<T> transcription::TranscriptionModel for TranscriptionModel<T>
where
    T: HttpClientExt + Clone + Send + std::fmt::Debug + Default + 'static,
{
    type Response = TranscriptionResponse;

    type Client = Client<T>;

    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
        Self::new(client.clone(), model)
    }

    async fn transcription(
        &self,
        request: transcription::TranscriptionRequest,
    ) -> Result<
        transcription::TranscriptionResponse<Self::Response>,
        transcription::TranscriptionError,
    > {
        let data = request.data;

        let mut body = MultipartForm::new()
            .text("model", self.model.clone())
            .part(Part::bytes("file", data).filename(request.filename.clone()));

        if let Some(language) = request.language {
            body = body.text("language", language);
        }

        if let Some(prompt) = request.prompt {
            body = body.text("prompt", prompt.clone());
        }

        if let Some(ref temperature) = request.temperature {
            body = body.text("temperature", temperature.to_string());
        }

        if let Some(ref additional_params) = request.additional_params {
            for (key, value) in additional_params
                .as_object()
                .expect("Additional parameters to Groq transcription should be a map")
            {
                body = body.text(key.to_owned(), value.to_string());
            }
        }

        let req = self
            .client
            .post("/audio/transcriptions")?
            .body(body)
            .map_err(|e| http_client::Error::Instance(e.into()))?;

        let response = self.client.send_multipart::<Bytes>(req).await?;

        let status = response.status();
        let response_body = response.into_body().into_future().await?.to_vec();

        if status.is_success() {
            match serde_json::from_slice::<ApiResponse<TranscriptionResponse>>(&response_body)? {
                ApiResponse::Ok(response) => response.try_into(),
                ApiResponse::Err(api_error_response) => Err(TranscriptionError::ProviderError(
                    api_error_response.message,
                )),
            }
        } else {
            Err(TranscriptionError::ProviderError(
                String::from_utf8_lossy(&response_body).to_string(),
            ))
        }
    }
}

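// `#[serde(untagged)]` tries variants in declaration order: a delta carrying a
// `reasoning` field parses as `Reasoning`; anything else falls through to
// `MessageContent`.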
#[derive(Deserialize, Debug)]
#[serde(untagged)]
enum StreamingDelta {
    Reasoning {
        reasoning: String,
    },
    MessageContent {
        #[serde(default)]
        content: Option<String>,
        #[serde(default, deserialize_with = "json_utils::null_or_vec")]
        tool_calls: Vec<StreamingToolCall>,
    },
}

#[derive(Deserialize, Debug)]
struct StreamingChoice {
    delta: StreamingDelta,
}

#[derive(Deserialize, Debug)]
struct StreamingCompletionChunk {
    choices: Vec<StreamingChoice>,
    usage: Option<Usage>,
}

#[derive(Clone, Deserialize, Serialize, Debug)]
pub struct StreamingCompletionResponse {
    pub usage: Usage,
}

impl GetTokenUsage for StreamingCompletionResponse {
    fn token_usage(&self) -> Option<crate::completion::Usage> {
        let mut usage = crate::completion::Usage::new();

        usage.input_tokens = self.usage.prompt_tokens as u64;
        usage.total_tokens = self.usage.total_tokens as u64;
        usage.output_tokens = self.usage.total_tokens as u64 - self.usage.prompt_tokens as u64;

        Some(usage)
    }
}

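/// Drives an OpenAI-compatible SSE stream: yields text and reasoning deltas as
/// they arrive, stitches partial tool-call fragments back together by index,
/// and finishes with a response carrying the usage reported in the last chunk.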
pub async fn send_compatible_streaming_request<T>(
    client: T,
    req: Request<Vec<u8>>,
) -> Result<
    crate::streaming::StreamingCompletionResponse<StreamingCompletionResponse>,
    CompletionError,
>
where
    T: HttpClientExt + Clone + 'static,
{
    let span = tracing::Span::current();

    let mut event_source = GenericEventSource::new(client, req);

    let stream = stream! {
        let span = tracing::Span::current();
        let mut final_usage = Usage {
            prompt_tokens: 0,
            total_tokens: 0
        };

        let mut text_response = String::new();

        let mut calls: HashMap<usize, (String, String, String)> = HashMap::new();

        while let Some(event_result) = event_source.next().await {
            match event_result {
                Ok(Event::Open) => {
                    tracing::trace!("SSE connection opened");
                    continue;
                }

                Ok(Event::Message(message)) => {
                    let data_str = message.data.trim();

                    let parsed = serde_json::from_str::<StreamingCompletionChunk>(data_str);
                    let Ok(data) = parsed else {
                        let err = parsed.unwrap_err();
                        tracing::debug!("Couldn't parse SSE payload as StreamingCompletionChunk: {:?}", err);
                        continue;
                    };

                    if let Some(choice) = data.choices.first() {
                        match &choice.delta {
                            StreamingDelta::Reasoning { reasoning } => {
                                yield Ok(crate::streaming::RawStreamingChoice::ReasoningDelta {
                                    id: None,
                                    reasoning: reasoning.to_string(),
                                });
                            }

                            StreamingDelta::MessageContent { content, tool_calls } => {
                                // Handle tool calls
                                for tool_call in tool_calls {
                                    let function = &tool_call.function;

                                    // Start of tool call
                                    if function.name.as_ref().map(|s| !s.is_empty()).unwrap_or(false)
                                        && empty_or_none(&function.arguments)
                                    {
                                        let id = tool_call.id.clone().unwrap_or_default();
                                        let name = function.name.clone().unwrap();
                                        calls.insert(tool_call.index, (id, name, String::new()));
                                    }
                                    // Continuation
                                    else if function.name.as_ref().map(|s| s.is_empty()).unwrap_or(true)
                                        && let Some(arguments) = &function.arguments
                                        && !arguments.is_empty()
                                    {
                                        if let Some((id, name, existing_args)) = calls.get(&tool_call.index) {
                                            let combined = format!("{}{}", existing_args, arguments);
                                            calls.insert(tool_call.index, (id.clone(), name.clone(), combined));
                                        } else {
                                            tracing::debug!("Partial tool call received but tool call was never started.");
                                        }
                                    }
                                    // Complete tool call
                                    else {
                                        let id = tool_call.id.clone().unwrap_or_default();
                                        let name = function.name.clone().unwrap_or_default();
                                        let arguments_str = function.arguments.clone().unwrap_or_default();

                                        let Ok(arguments_json) = serde_json::from_str::<serde_json::Value>(&arguments_str) else {
                                            tracing::debug!("Couldn't parse tool call args '{}'", arguments_str);
                                            continue;
                                        };

                                        yield Ok(crate::streaming::RawStreamingChoice::ToolCall(
                                            crate::streaming::RawStreamingToolCall::new(id, name, arguments_json)
                                        ));
                                    }
                                }

                                // Streamed content
                                if let Some(content) = content {
                                    text_response += content;
                                    yield Ok(crate::streaming::RawStreamingChoice::Message(content.clone()));
                                }
                            }
                        }
                    }

                    if let Some(usage) = data.usage {
                        final_usage = usage.clone();
                    }
                }

                Err(crate::http_client::Error::StreamEnded) => break,
                Err(err) => {
                    tracing::error!(?err, "SSE error");
                    yield Err(CompletionError::ResponseError(err.to_string()));
                    break;
                }
            }
        }

        event_source.close();

        let mut tool_calls = Vec::new();
        // Flush accumulated tool calls
        for (_, (id, name, arguments)) in calls {
            let Ok(arguments_json) = serde_json::from_str::<serde_json::Value>(&arguments) else {
                continue;
            };

            tool_calls.push(crate::providers::openai::completion::ToolCall {
                id: id.clone(),
                r#type: ToolType::Function,
                function: Function {
                    name: name.clone(),
                    arguments: arguments_json.clone()
                }
            });
            yield Ok(crate::streaming::RawStreamingChoice::ToolCall(
                crate::streaming::RawStreamingToolCall::new(id, name, arguments_json)
            ));
        }

        let response_message = crate::providers::openai::completion::Message::Assistant {
            content: vec![AssistantContent::Text { text: text_response }],
            refusal: None,
            audio: None,
            name: None,
            tool_calls
        };

        span.record("gen_ai.output.messages", serde_json::to_string(&vec![response_message]).unwrap());
        span.record("gen_ai.usage.input_tokens", final_usage.prompt_tokens);
        span.record("gen_ai.usage.output_tokens", final_usage.total_tokens - final_usage.prompt_tokens);

        // Final response
        yield Ok(crate::streaming::RawStreamingChoice::FinalResponse(
            StreamingCompletionResponse { usage: final_usage.clone() }
        ));
    }.instrument(span);

    Ok(crate::streaming::StreamingCompletionResponse::stream(
        Box::pin(stream),
    ))
}

#[cfg(test)]
mod tests {
    use crate::{
        OneOrMany,
        providers::{
            groq::{GroqAdditionalParameters, GroqCompletionRequest},
            openai::{Message, UserContent},
        },
    };

    #[test]
    fn serialize_groq_request() {
        let additional_params = GroqAdditionalParameters {
            include_reasoning: Some(true),
            reasoning_format: Some(super::ReasoningFormat::Parsed),
            ..Default::default()
        };

        let groq = GroqCompletionRequest {
            model: "openai/gpt-oss-120b".to_string(),
            temperature: None,
            tool_choice: None,
            stream_options: None,
            tools: Vec::new(),
            messages: vec![Message::User {
                content: OneOrMany::one(UserContent::Text {
                    text: "Hello world!".to_string(),
                }),
                name: None,
            }],
            stream: false,
            additional_params: Some(additional_params),
        };

        let json = serde_json::to_value(&groq).unwrap();

        assert_eq!(
            json,
            serde_json::json!({
                "model": "openai/gpt-oss-120b",
                "messages": [
                    {
                        "role": "user",
                        "content": [{
                            "type": "text",
                            "text": "Hello world!"
                        }]
                    }
                ],
                "stream": false,
                "include_reasoning": true,
                "reasoning_format": "parsed"
            })
        )
    }
}
771}