Skip to main content

rig/providers/
groq.rs

//! Groq API client and Rig integration
//!
//! # Example
//! ```
//! use rig::providers::groq;
//!
//! let client = groq::Client::new("YOUR_API_KEY").unwrap();
//!
//! let llama = client.completion_model(groq::LLAMA_3_1_8B_INSTANT);
//! ```
11use bytes::Bytes;
12use http::Request;
13use serde_json::Map;
14use std::collections::HashMap;
15use tracing::info_span;
16use tracing_futures::Instrument;
17
18use super::openai::{
19    CompletionResponse, Message as OpenAIMessage, StreamingToolCall, TranscriptionResponse, Usage,
20};
21use crate::client::{
22    self, BearerAuth, Capabilities, Capable, DebugExt, Nothing, Provider, ProviderBuilder,
23    ProviderClient,
24};
25use crate::completion::GetTokenUsage;
26use crate::http_client::multipart::Part;
27use crate::http_client::sse::{Event, GenericEventSource};
28use crate::http_client::{self, HttpClientExt, MultipartForm};
29use crate::json_utils::empty_or_none;
30use crate::providers::openai::{AssistantContent, Function, ToolType};
31use async_stream::stream;
32use futures::StreamExt;
33
34use crate::{
35    completion::{self, CompletionError, CompletionRequest},
36    json_utils,
37    message::{self},
38    providers::openai::ToolDefinition,
39    transcription::{self, TranscriptionError},
40};
41use serde::{Deserialize, Serialize};
42
// ================================================================
// Main Groq Client
// ================================================================

/// Base URL of Groq's OpenAI-compatible API.
const GROQ_API_BASE_URL: &str = "https://api.groq.com/openai/v1";

/// Zero-sized marker type implementing the Groq provider.
#[derive(Debug, Default, Clone, Copy)]
pub struct GroqExt;
/// Zero-sized builder marker for [`GroqExt`].
#[derive(Debug, Default, Clone, Copy)]
pub struct GroqBuilder;

// Groq authenticates with a bearer token in the `Authorization` header.
type GroqApiKey = BearerAuth;
impl Provider for GroqExt {
    type Builder = GroqBuilder;

    /// Endpoint used to verify that an API key is valid.
    const VERIFY_PATH: &'static str = "/models";

    /// Construct the provider marker. Groq keeps no per-client state, so
    /// the builder's configuration is ignored and this never fails.
    fn build<H>(
        _: &crate::client::ClientBuilder<
            Self::Builder,
            <Self::Builder as crate::client::ProviderBuilder>::ApiKey,
            H,
        >,
    ) -> http_client::Result<Self> {
        Ok(Self)
    }
}
70
/// Capability matrix for Groq: chat completion and audio transcription are
/// supported; embeddings, model listing, image and audio generation are not.
impl<H> Capabilities<H> for GroqExt {
    type Completion = Capable<CompletionModel<H>>;
    type Embeddings = Nothing;
    type Transcription = Capable<TranscriptionModel<H>>;
    type ModelListing = Nothing;
    #[cfg(feature = "image")]
    type ImageGeneration = Nothing;

    #[cfg(feature = "audio")]
    type AudioGeneration = Nothing;
}
82
// Marker impl: use `DebugExt`'s default behavior for GroqExt.
impl DebugExt for GroqExt {}

impl ProviderBuilder for GroqBuilder {
    type Output = GroqExt;
    // Bearer-token auth (see `GroqApiKey` alias above).
    type ApiKey = GroqApiKey;

    const BASE_URL: &'static str = GROQ_API_BASE_URL;
}

/// Groq client, generic over the underlying HTTP implementation.
pub type Client<H = reqwest::Client> = client::Client<GroqExt, H>;
/// Builder for [`Client`]; the API key is supplied as a `String`.
pub type ClientBuilder<H = reqwest::Client> = client::ClientBuilder<GroqBuilder, String, H>;
94
95impl ProviderClient for Client {
96    type Input = String;
97
98    /// Create a new Groq client from the `GROQ_API_KEY` environment variable.
99    /// Panics if the environment variable is not set.
100    fn from_env() -> Self {
101        let api_key = std::env::var("GROQ_API_KEY").expect("GROQ_API_KEY not set");
102        Self::new(&api_key).unwrap()
103    }
104
105    fn from_val(input: Self::Input) -> Self {
106        Self::new(&input).unwrap()
107    }
108}
109
/// Error payload returned by the Groq API.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    message: String,
}

/// A Groq API response: either the expected payload `T` or an error body.
///
/// `untagged` means serde attempts to deserialize `T` first and falls back
/// to the error shape.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    Ok(T),
    Err(ApiErrorResponse),
}
121
// ================================================================
// Groq Completion API
// ================================================================

// Model identifiers accepted by Groq's chat-completions endpoint.

/// The `deepseek-r1-distill-llama-70b` model. Used for chat completion.
pub const DEEPSEEK_R1_DISTILL_LLAMA_70B: &str = "deepseek-r1-distill-llama-70b";
/// The `gemma2-9b-it` model. Used for chat completion.
pub const GEMMA2_9B_IT: &str = "gemma2-9b-it";
/// The `llama-3.1-8b-instant` model. Used for chat completion.
pub const LLAMA_3_1_8B_INSTANT: &str = "llama-3.1-8b-instant";
/// The `llama-3.2-11b-vision-preview` model. Used for chat completion.
pub const LLAMA_3_2_11B_VISION_PREVIEW: &str = "llama-3.2-11b-vision-preview";
/// The `llama-3.2-1b-preview` model. Used for chat completion.
pub const LLAMA_3_2_1B_PREVIEW: &str = "llama-3.2-1b-preview";
/// The `llama-3.2-3b-preview` model. Used for chat completion.
pub const LLAMA_3_2_3B_PREVIEW: &str = "llama-3.2-3b-preview";
/// The `llama-3.2-90b-vision-preview` model. Used for chat completion.
pub const LLAMA_3_2_90B_VISION_PREVIEW: &str = "llama-3.2-90b-vision-preview";
/// The `llama-3.2-70b-specdec` model. Used for chat completion.
pub const LLAMA_3_2_70B_SPECDEC: &str = "llama-3.2-70b-specdec";
/// The `llama-3.2-70b-versatile` model. Used for chat completion.
pub const LLAMA_3_2_70B_VERSATILE: &str = "llama-3.2-70b-versatile";
/// The `llama-guard-3-8b` model. Used for chat completion.
pub const LLAMA_GUARD_3_8B: &str = "llama-guard-3-8b";
/// The `llama3-70b-8192` model. Used for chat completion.
pub const LLAMA_3_70B_8192: &str = "llama3-70b-8192";
/// The `llama3-8b-8192` model. Used for chat completion.
pub const LLAMA_3_8B_8192: &str = "llama3-8b-8192";
/// The `mixtral-8x7b-32768` model. Used for chat completion.
pub const MIXTRAL_8X7B_32768: &str = "mixtral-8x7b-32768";
152
/// Reasoning output format for Groq's `reasoning_format` parameter.
///
/// Serialized in lowercase (`"parsed"`, `"raw"`, `"hidden"`). See Groq's
/// API docs for the exact semantics of each variant.
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ReasoningFormat {
    Parsed,
    Raw,
    Hidden,
}
160
/// Request body for Groq's `/chat/completions` endpoint.
///
/// Mirrors the OpenAI-compatible wire format; optional fields are omitted
/// from the serialized JSON when unset.
#[derive(Debug, Serialize, Deserialize)]
pub(super) struct GroqCompletionRequest {
    model: String,
    pub messages: Vec<OpenAIMessage>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f64>,
    // Omitted entirely when no tools are supplied.
    #[serde(skip_serializing_if = "Vec::is_empty")]
    tools: Vec<ToolDefinition>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_choice: Option<crate::providers::openai::completion::ToolChoice>,
    // Flattened so provider-specific parameters serialize at the top level
    // of the request body.
    #[serde(flatten, skip_serializing_if = "Option::is_none")]
    pub additional_params: Option<GroqAdditionalParameters>,
    pub(super) stream: bool,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub(super) stream_options: Option<StreamOptions>,
}

/// Streaming options; `include_usage` asks the server to report token usage
/// with the stream (see Groq's API docs).
#[derive(Debug, Serialize, Deserialize, Default)]
pub(super) struct StreamOptions {
    pub(super) include_usage: bool,
}
182
183impl TryFrom<(&str, CompletionRequest)> for GroqCompletionRequest {
184    type Error = CompletionError;
185
186    fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
187        if req.output_schema.is_some() {
188            tracing::warn!("Structured outputs currently not supported for Groq");
189        }
190        let model = req.model.clone().unwrap_or_else(|| model.to_string());
191        // Build up the order of messages (context, chat_history, prompt)
192        let mut partial_history = vec![];
193        if let Some(docs) = req.normalized_documents() {
194            partial_history.push(docs);
195        }
196        partial_history.extend(req.chat_history);
197
198        // Add preamble to chat history (if available)
199        let mut full_history: Vec<OpenAIMessage> = match &req.preamble {
200            Some(preamble) => vec![OpenAIMessage::system(preamble)],
201            None => vec![],
202        };
203
204        // Convert and extend the rest of the history
205        full_history.extend(
206            partial_history
207                .into_iter()
208                .map(message::Message::try_into)
209                .collect::<Result<Vec<Vec<OpenAIMessage>>, _>>()?
210                .into_iter()
211                .flatten()
212                .collect::<Vec<_>>(),
213        );
214
215        let tool_choice = req
216            .tool_choice
217            .clone()
218            .map(crate::providers::openai::ToolChoice::try_from)
219            .transpose()?;
220
221        let additional_params: Option<GroqAdditionalParameters> =
222            if let Some(params) = req.additional_params {
223                Some(serde_json::from_value(params)?)
224            } else {
225                None
226            };
227
228        Ok(Self {
229            model: model.to_string(),
230            messages: full_history,
231            temperature: req.temperature,
232            tools: req
233                .tools
234                .clone()
235                .into_iter()
236                .map(ToolDefinition::from)
237                .collect::<Vec<_>>(),
238            tool_choice,
239            additional_params,
240            stream: false,
241            stream_options: None,
242        })
243    }
244}
245
246/// Additional parameters to send to the Groq API
247#[derive(Clone, Debug, Default, Serialize, Deserialize)]
248pub struct GroqAdditionalParameters {
249    /// The reasoning format. See Groq's API docs for more details.
250    #[serde(skip_serializing_if = "Option::is_none")]
251    pub reasoning_format: Option<ReasoningFormat>,
252    /// Whether or not to include reasoning. See Groq's API docs for more details.
253    #[serde(skip_serializing_if = "Option::is_none")]
254    pub include_reasoning: Option<bool>,
255    /// Any other properties not included by default on this struct (that you want to send)
256    #[serde(flatten, skip_serializing_if = "Option::is_none")]
257    pub extra: Option<Map<String, serde_json::Value>>,
258}
259
/// A Groq chat-completion model bound to a [`Client`].
#[derive(Clone, Debug)]
pub struct CompletionModel<T = reqwest::Client> {
    client: Client<T>,
    /// Name of the model (e.g.: deepseek-r1-distill-llama-70b)
    pub model: String,
}

impl<T> CompletionModel<T> {
    /// Create a completion-model handle from a client and a model name.
    pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
        Self {
            client,
            model: model.into(),
        }
    }
}
275
276impl<T> completion::CompletionModel for CompletionModel<T>
277where
278    T: HttpClientExt + Clone + Send + std::fmt::Debug + Default + 'static,
279{
280    type Response = CompletionResponse;
281    type StreamingResponse = StreamingCompletionResponse;
282
283    type Client = Client<T>;
284
285    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
286        Self::new(client.clone(), model)
287    }
288
289    async fn completion(
290        &self,
291        completion_request: CompletionRequest,
292    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
293        let span = if tracing::Span::current().is_disabled() {
294            info_span!(
295                target: "rig::completions",
296                "chat",
297                gen_ai.operation.name = "chat",
298                gen_ai.provider.name = "groq",
299                gen_ai.request.model = self.model,
300                gen_ai.system_instructions = tracing::field::Empty,
301                gen_ai.response.id = tracing::field::Empty,
302                gen_ai.response.model = tracing::field::Empty,
303                gen_ai.usage.output_tokens = tracing::field::Empty,
304                gen_ai.usage.input_tokens = tracing::field::Empty,
305            )
306        } else {
307            tracing::Span::current()
308        };
309
310        span.record("gen_ai.system_instructions", &completion_request.preamble);
311
312        let request = GroqCompletionRequest::try_from((self.model.as_ref(), completion_request))?;
313
314        if tracing::enabled!(tracing::Level::TRACE) {
315            tracing::trace!(target: "rig::completions",
316                "Groq completion request: {}",
317                serde_json::to_string_pretty(&request)?
318            );
319        }
320
321        let body = serde_json::to_vec(&request)?;
322        let req = self
323            .client
324            .post("/chat/completions")?
325            .body(body)
326            .map_err(|e| http_client::Error::Instance(e.into()))?;
327
328        let async_block = async move {
329            let response = self.client.send::<_, Bytes>(req).await?;
330            let status = response.status();
331            let response_body = response.into_body().into_future().await?.to_vec();
332
333            if status.is_success() {
334                match serde_json::from_slice::<ApiResponse<CompletionResponse>>(&response_body)? {
335                    ApiResponse::Ok(response) => {
336                        let span = tracing::Span::current();
337                        span.record("gen_ai.response.id", response.id.clone());
338                        span.record("gen_ai.response.model_name", response.model.clone());
339                        if let Some(ref usage) = response.usage {
340                            span.record("gen_ai.usage.input_tokens", usage.prompt_tokens);
341                            span.record(
342                                "gen_ai.usage.output_tokens",
343                                usage.total_tokens - usage.prompt_tokens,
344                            );
345                        }
346
347                        if tracing::enabled!(tracing::Level::TRACE) {
348                            tracing::trace!(target: "rig::completions",
349                                "Groq completion response: {}",
350                                serde_json::to_string_pretty(&response)?
351                            );
352                        }
353
354                        response.try_into()
355                    }
356                    ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
357                }
358            } else {
359                Err(CompletionError::ProviderError(
360                    String::from_utf8_lossy(&response_body).to_string(),
361                ))
362            }
363        };
364
365        tracing::Instrument::instrument(async_block, span).await
366    }
367
368    async fn stream(
369        &self,
370        request: CompletionRequest,
371    ) -> Result<
372        crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
373        CompletionError,
374    > {
375        let span = if tracing::Span::current().is_disabled() {
376            info_span!(
377                target: "rig::completions",
378                "chat_streaming",
379                gen_ai.operation.name = "chat_streaming",
380                gen_ai.provider.name = "groq",
381                gen_ai.request.model = self.model,
382                gen_ai.system_instructions = tracing::field::Empty,
383                gen_ai.response.id = tracing::field::Empty,
384                gen_ai.response.model = tracing::field::Empty,
385                gen_ai.usage.output_tokens = tracing::field::Empty,
386                gen_ai.usage.input_tokens = tracing::field::Empty,
387            )
388        } else {
389            tracing::Span::current()
390        };
391
392        span.record("gen_ai.system_instructions", &request.preamble);
393
394        let mut request = GroqCompletionRequest::try_from((self.model.as_ref(), request))?;
395
396        request.stream = true;
397        request.stream_options = Some(StreamOptions {
398            include_usage: true,
399        });
400
401        if tracing::enabled!(tracing::Level::TRACE) {
402            tracing::trace!(target: "rig::completions",
403                "Groq streaming completion request: {}",
404                serde_json::to_string_pretty(&request)?
405            );
406        }
407
408        let body = serde_json::to_vec(&request)?;
409        let req = self
410            .client
411            .post("/chat/completions")?
412            .body(body)
413            .map_err(|e| http_client::Error::Instance(e.into()))?;
414
415        tracing::Instrument::instrument(
416            send_compatible_streaming_request(self.client.clone(), req),
417            span,
418        )
419        .await
420    }
421}
422
// ================================================================
// Groq Transcription API
// ================================================================

/// The `whisper-large-v3` model. Used for audio transcription.
pub const WHISPER_LARGE_V3: &str = "whisper-large-v3";
/// The `whisper-large-v3-turbo` model. Used for audio transcription.
pub const WHISPER_LARGE_V3_TURBO: &str = "whisper-large-v3-turbo";
/// The `distil-whisper-large-v3-en` model. Used for audio transcription.
pub const DISTIL_WHISPER_LARGE_V3_EN: &str = "distil-whisper-large-v3-en";

/// A Groq transcription model bound to a [`Client`].
#[derive(Clone)]
pub struct TranscriptionModel<T> {
    client: Client<T>,
    /// Name of the model (e.g.: whisper-large-v3)
    pub model: String,
}

impl<T> TranscriptionModel<T> {
    /// Create a transcription-model handle from a client and a model name.
    pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
        Self {
            client,
            model: model.into(),
        }
    }
}
446impl<T> transcription::TranscriptionModel for TranscriptionModel<T>
447where
448    T: HttpClientExt + Clone + Send + std::fmt::Debug + Default + 'static,
449{
450    type Response = TranscriptionResponse;
451
452    type Client = Client<T>;
453
454    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
455        Self::new(client.clone(), model)
456    }
457
458    async fn transcription(
459        &self,
460        request: transcription::TranscriptionRequest,
461    ) -> Result<
462        transcription::TranscriptionResponse<Self::Response>,
463        transcription::TranscriptionError,
464    > {
465        let data = request.data;
466
467        let mut body = MultipartForm::new()
468            .text("model", self.model.clone())
469            .part(Part::bytes("file", data).filename(request.filename.clone()));
470
471        if let Some(language) = request.language {
472            body = body.text("language", language);
473        }
474
475        if let Some(prompt) = request.prompt {
476            body = body.text("prompt", prompt.clone());
477        }
478
479        if let Some(ref temperature) = request.temperature {
480            body = body.text("temperature", temperature.to_string());
481        }
482
483        if let Some(ref additional_params) = request.additional_params {
484            for (key, value) in additional_params
485                .as_object()
486                .expect("Additional Parameters to OpenAI Transcription should be a map")
487            {
488                body = body.text(key.to_owned(), value.to_string());
489            }
490        }
491
492        let req = self
493            .client
494            .post("/audio/transcriptions")?
495            .body(body)
496            .unwrap();
497
498        let response = self.client.send_multipart::<Bytes>(req).await.unwrap();
499
500        let status = response.status();
501        let response_body = response.into_body().into_future().await?.to_vec();
502
503        if status.is_success() {
504            match serde_json::from_slice::<ApiResponse<TranscriptionResponse>>(&response_body)? {
505                ApiResponse::Ok(response) => response.try_into(),
506                ApiResponse::Err(api_error_response) => Err(TranscriptionError::ProviderError(
507                    api_error_response.message,
508                )),
509            }
510        } else {
511            Err(TranscriptionError::ProviderError(
512                String::from_utf8_lossy(&response_body).to_string(),
513            ))
514        }
515    }
516}
517
/// One streamed delta from Groq. Untagged: serde tries `Reasoning` first,
/// then falls back to `MessageContent` — variant order matters here.
#[derive(Deserialize, Debug)]
#[serde(untagged)]
enum StreamingDelta {
    Reasoning {
        reasoning: String,
    },
    MessageContent {
        #[serde(default)]
        content: Option<String>,
        // `null` and a missing field both deserialize to an empty vec.
        #[serde(default, deserialize_with = "json_utils::null_or_vec")]
        tool_calls: Vec<StreamingToolCall>,
    },
}

/// A single choice within a streaming chunk.
#[derive(Deserialize, Debug)]
struct StreamingChoice {
    delta: StreamingDelta,
}

/// One SSE chunk of a streaming completion. `usage` is only present when
/// the server reports it (requested via `stream_options.include_usage`).
#[derive(Deserialize, Debug)]
struct StreamingCompletionChunk {
    choices: Vec<StreamingChoice>,
    usage: Option<Usage>,
}

/// Final metadata attached to a completed stream: total token usage.
#[derive(Clone, Deserialize, Serialize, Debug)]
pub struct StreamingCompletionResponse {
    pub usage: Usage,
}
547
548impl GetTokenUsage for StreamingCompletionResponse {
549    fn token_usage(&self) -> Option<crate::completion::Usage> {
550        let mut usage = crate::completion::Usage::new();
551
552        usage.input_tokens = self.usage.prompt_tokens as u64;
553        usage.total_tokens = self.usage.total_tokens as u64;
554        usage.output_tokens = self.usage.total_tokens as u64 - self.usage.prompt_tokens as u64;
555        usage.cached_input_tokens = self
556            .usage
557            .prompt_tokens_details
558            .as_ref()
559            .map(|d| d.cached_tokens as u64)
560            .unwrap_or(0);
561
562        Some(usage)
563    }
564}
565
566pub async fn send_compatible_streaming_request<T>(
567    client: T,
568    req: Request<Vec<u8>>,
569) -> Result<
570    crate::streaming::StreamingCompletionResponse<StreamingCompletionResponse>,
571    CompletionError,
572>
573where
574    T: HttpClientExt + Clone + 'static,
575{
576    let span = tracing::Span::current();
577
578    let mut event_source = GenericEventSource::new(client, req);
579
580    let stream = stream! {
581        let span = tracing::Span::current();
582        let mut final_usage = Usage {
583            prompt_tokens: 0,
584            total_tokens: 0,
585            prompt_tokens_details: None,
586        };
587
588        let mut text_response = String::new();
589
590        let mut calls: HashMap<usize, (String, String, String)> = HashMap::new();
591
592        while let Some(event_result) = event_source.next().await {
593            match event_result {
594                Ok(Event::Open) => {
595                    tracing::trace!("SSE connection opened");
596                    continue;
597                }
598
599                Ok(Event::Message(message)) => {
600                    let data_str = message.data.trim();
601
602                    let parsed = serde_json::from_str::<StreamingCompletionChunk>(data_str);
603                    let Ok(data) = parsed else {
604                        let err = parsed.unwrap_err();
605                        tracing::debug!("Couldn't parse SSE payload as StreamingCompletionChunk: {:?}", err);
606                        continue;
607                    };
608
609                    if let Some(choice) = data.choices.first() {
610                        match &choice.delta {
611                            StreamingDelta::Reasoning { reasoning } => {
612                                yield Ok(crate::streaming::RawStreamingChoice::ReasoningDelta {
613                                    id: None,
614                                    reasoning: reasoning.to_string(),
615                                });
616                            }
617
618                            StreamingDelta::MessageContent { content, tool_calls } => {
619                                // Handle tool calls
620                                for tool_call in tool_calls {
621                                    let function = &tool_call.function;
622
623                                    // Start of tool call
624                                    if function.name.as_ref().map(|s| !s.is_empty()).unwrap_or(false)
625                                        && empty_or_none(&function.arguments)
626                                    {
627                                        let id = tool_call.id.clone().unwrap_or_default();
628                                        let name = function.name.clone().unwrap();
629                                        calls.insert(tool_call.index, (id, name, String::new()));
630                                    }
631                                    // Continuation
632                                    else if function.name.as_ref().map(|s| s.is_empty()).unwrap_or(true)
633                                        && let Some(arguments) = &function.arguments
634                                        && !arguments.is_empty()
635                                    {
636                                        if let Some((id, name, existing_args)) = calls.get(&tool_call.index) {
637                                            let combined = format!("{}{}", existing_args, arguments);
638                                            calls.insert(tool_call.index, (id.clone(), name.clone(), combined));
639                                        } else {
640                                            tracing::debug!("Partial tool call received but tool call was never started.");
641                                        }
642                                    }
643                                    // Complete tool call
644                                    else {
645                                        let id = tool_call.id.clone().unwrap_or_default();
646                                        let name = function.name.clone().unwrap_or_default();
647                                        let arguments_str = function.arguments.clone().unwrap_or_default();
648
649                                        let Ok(arguments_json) = serde_json::from_str::<serde_json::Value>(&arguments_str) else {
650                                            tracing::debug!("Couldn't parse tool call args '{}'", arguments_str);
651                                            continue;
652                                        };
653
654                                        yield Ok(crate::streaming::RawStreamingChoice::ToolCall(
655                                            crate::streaming::RawStreamingToolCall::new(id, name, arguments_json)
656                                        ));
657                                    }
658                                }
659
660                                // Streamed content
661                                if let Some(content) = content {
662                                    text_response += content;
663                                    yield Ok(crate::streaming::RawStreamingChoice::Message(content.clone()));
664                                }
665                            }
666                        }
667                    }
668
669                    if let Some(usage) = data.usage {
670                        final_usage = usage.clone();
671                    }
672                }
673
674                Err(crate::http_client::Error::StreamEnded) => break,
675                Err(err) => {
676                    tracing::error!(?err, "SSE error");
677                    yield Err(CompletionError::ResponseError(err.to_string()));
678                    break;
679                }
680            }
681        }
682
683        event_source.close();
684
685        let mut tool_calls = Vec::new();
686        // Flush accumulated tool calls
687        for (_, (id, name, arguments)) in calls {
688            let Ok(arguments_json) = serde_json::from_str::<serde_json::Value>(&arguments) else {
689                continue;
690            };
691
692            tool_calls.push(rig::providers::openai::completion::ToolCall {
693                id: id.clone(),
694                r#type: ToolType::Function,
695                function: Function {
696                    name: name.clone(),
697                    arguments: arguments_json.clone()
698                }
699            });
700            yield Ok(crate::streaming::RawStreamingChoice::ToolCall(
701                crate::streaming::RawStreamingToolCall::new(id, name, arguments_json)
702            ));
703        }
704
705        let response_message = crate::providers::openai::completion::Message::Assistant {
706            content: vec![AssistantContent::Text { text: text_response }],
707            refusal: None,
708            audio: None,
709            name: None,
710            tool_calls
711        };
712
713        span.record("gen_ai.output.messages", serde_json::to_string(&vec![response_message]).unwrap());
714        span.record("gen_ai.usage.input_tokens", final_usage.prompt_tokens);
715        span.record("gen_ai.usage.output_tokens", final_usage.total_tokens - final_usage.prompt_tokens);
716
717        // Final response
718        yield Ok(crate::streaming::RawStreamingChoice::FinalResponse(
719            StreamingCompletionResponse { usage: final_usage.clone() }
720        ));
721    }.instrument(span);
722
723    Ok(crate::streaming::StreamingCompletionResponse::stream(
724        Box::pin(stream),
725    ))
726}
727
#[cfg(test)]
mod tests {
    use crate::{
        OneOrMany,
        providers::{
            groq::{GroqAdditionalParameters, GroqCompletionRequest},
            openai::{Message, UserContent},
        },
    };

    /// Serializing a request should flatten additional params to the top
    /// level and omit every unset optional field.
    #[test]
    fn serialize_groq_request() {
        let params = GroqAdditionalParameters {
            reasoning_format: Some(super::ReasoningFormat::Parsed),
            include_reasoning: Some(true),
            extra: None,
        };

        let user_message = Message::User {
            content: OneOrMany::one(UserContent::Text {
                text: "Hello world!".to_string(),
            }),
            name: None,
        };

        let request = GroqCompletionRequest {
            model: "openai/gpt-120b-oss".to_string(),
            messages: vec![user_message],
            temperature: None,
            tools: Vec::new(),
            tool_choice: None,
            additional_params: Some(params),
            stream: false,
            stream_options: None,
        };

        let serialized = serde_json::to_value(&request).unwrap();
        let expected = serde_json::json!({
            "model": "openai/gpt-120b-oss",
            "messages": [
                {
                    "role": "user",
                    "content": "Hello world!"
                }
            ],
            "stream": false,
            "include_reasoning": true,
            "reasoning_format": "parsed"
        });

        assert_eq!(serialized, expected);
    }

    /// Both construction paths (direct and via builder) should succeed.
    #[test]
    fn test_client_initialization() {
        let _direct: crate::providers::groq::Client =
            crate::providers::groq::Client::new("dummy-key").expect("Client::new() failed");
        let _built: crate::providers::groq::Client = crate::providers::groq::Client::builder()
            .api_key("dummy-key")
            .build()
            .expect("Client::builder() failed");
    }
}
790}