// rig/providers/groq.rs
1//! Groq API client and Rig integration
2//!
3//! # Example
4//! ```
5//! use rig::providers::groq;
6//!
7//! let client = groq::Client::new("YOUR_API_KEY");
8//!
//! let deepseek = client.completion_model(groq::DEEPSEEK_R1_DISTILL_LLAMA_70B);
10//! ```
11use bytes::Bytes;
12use http::Request;
13use serde_json::Map;
14use std::collections::HashMap;
15use tracing::info_span;
16use tracing_futures::Instrument;
17
18use super::openai::{
19    CompletionResponse, Message as OpenAIMessage, StreamingToolCall, TranscriptionResponse, Usage,
20};
21use crate::client::{
22    self, BearerAuth, Capabilities, Capable, DebugExt, Nothing, Provider, ProviderBuilder,
23    ProviderClient,
24};
25use crate::completion::GetTokenUsage;
26use crate::http_client::multipart::Part;
27use crate::http_client::sse::{Event, GenericEventSource};
28use crate::http_client::{self, HttpClientExt, MultipartForm};
29use crate::json_utils::empty_or_none;
30use crate::providers::openai::{AssistantContent, Function, ToolType};
31use async_stream::stream;
32use futures::StreamExt;
33
34use crate::{
35    completion::{self, CompletionError, CompletionRequest},
36    json_utils,
37    message::{self},
38    providers::openai::ToolDefinition,
39    transcription::{self, TranscriptionError},
40};
41use serde::{Deserialize, Serialize};
42
// ================================================================
// Main Groq Client
// ================================================================

/// Base URL for Groq's OpenAI-compatible REST API.
const GROQ_API_BASE_URL: &str = "https://api.groq.com/openai/v1";

/// Marker type carrying the Groq-specific provider capabilities.
#[derive(Debug, Default, Clone, Copy)]
pub struct GroqExt;
/// Builder marker used to construct a Groq [`Client`].
#[derive(Debug, Default, Clone, Copy)]
pub struct GroqBuilder;

// Groq authenticates with a standard `Authorization: Bearer <key>` header.
type GroqApiKey = BearerAuth;

impl Provider for GroqExt {
    type Builder = GroqBuilder;
    // Requesting `/models` with the configured key is enough to verify credentials.
    const VERIFY_PATH: &'static str = "/models";
}
59
// Groq exposes chat completions and audio transcription here; embeddings,
// model listing, image generation, and audio generation are not wired up.
impl<H> Capabilities<H> for GroqExt {
    type Completion = Capable<CompletionModel<H>>;
    type Embeddings = Nothing;
    type Transcription = Capable<TranscriptionModel<H>>;
    type ModelListing = Nothing;
    #[cfg(feature = "image")]
    type ImageGeneration = Nothing;

    #[cfg(feature = "audio")]
    type AudioGeneration = Nothing;
}

impl DebugExt for GroqExt {}
73
impl ProviderBuilder for GroqBuilder {
    // The extension is stateless, so it is independent of the HTTP backend.
    type Extension<H>
        = GroqExt
    where
        H: HttpClientExt;
    type ApiKey = GroqApiKey;

    const BASE_URL: &'static str = GROQ_API_BASE_URL;

    /// Build the (stateless) Groq extension; this never fails.
    fn build<H>(
        _builder: &client::ClientBuilder<Self, Self::ApiKey, H>,
    ) -> http_client::Result<Self::Extension<H>>
    where
        H: HttpClientExt,
    {
        Ok(GroqExt)
    }
}

/// A Groq client over an arbitrary HTTP backend (defaults to `reqwest`).
pub type Client<H = reqwest::Client> = client::Client<GroqExt, H>;
/// Builder for [`Client`]; the API key is supplied as a `String`.
pub type ClientBuilder<H = reqwest::Client> = client::ClientBuilder<GroqBuilder, String, H>;
95
96impl ProviderClient for Client {
97    type Input = String;
98
99    /// Create a new Groq client from the `GROQ_API_KEY` environment variable.
100    /// Panics if the environment variable is not set.
101    fn from_env() -> Self {
102        let api_key = std::env::var("GROQ_API_KEY").expect("GROQ_API_KEY not set");
103        Self::new(&api_key).unwrap()
104    }
105
106    fn from_val(input: Self::Input) -> Self {
107        Self::new(&input).unwrap()
108    }
109}
110
/// Error payload returned by the Groq API.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    message: String,
}

/// A Groq API response body: either the expected payload or an error.
/// `untagged` makes serde try the success shape first, then the error shape.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    Ok(T),
    Err(ApiErrorResponse),
}
122
// ================================================================
// Groq Completion API
// ================================================================

/// The `deepseek-r1-distill-llama-70b` model. Used for chat completion.
pub const DEEPSEEK_R1_DISTILL_LLAMA_70B: &str = "deepseek-r1-distill-llama-70b";
/// The `gemma2-9b-it` model. Used for chat completion.
pub const GEMMA2_9B_IT: &str = "gemma2-9b-it";
/// The `llama-3.1-8b-instant` model. Used for chat completion.
pub const LLAMA_3_1_8B_INSTANT: &str = "llama-3.1-8b-instant";
/// The `llama-3.2-11b-vision-preview` model. Used for chat completion.
pub const LLAMA_3_2_11B_VISION_PREVIEW: &str = "llama-3.2-11b-vision-preview";
/// The `llama-3.2-1b-preview` model. Used for chat completion.
pub const LLAMA_3_2_1B_PREVIEW: &str = "llama-3.2-1b-preview";
/// The `llama-3.2-3b-preview` model. Used for chat completion.
pub const LLAMA_3_2_3B_PREVIEW: &str = "llama-3.2-3b-preview";
/// The `llama-3.2-90b-vision-preview` model. Used for chat completion.
pub const LLAMA_3_2_90B_VISION_PREVIEW: &str = "llama-3.2-90b-vision-preview";
/// The `llama-3.2-70b-specdec` model. Used for chat completion.
// NOTE(review): Groq's catalog lists `llama-3.3-70b-specdec`; confirm this
// `3.2` id is (still) valid before relying on it.
pub const LLAMA_3_2_70B_SPECDEC: &str = "llama-3.2-70b-specdec";
/// The `llama-3.2-70b-versatile` model. Used for chat completion.
// NOTE(review): Groq's catalog lists `llama-3.3-70b-versatile`; confirm this
// `3.2` id is (still) valid before relying on it.
pub const LLAMA_3_2_70B_VERSATILE: &str = "llama-3.2-70b-versatile";
/// The `llama-guard-3-8b` model. Used for chat completion.
pub const LLAMA_GUARD_3_8B: &str = "llama-guard-3-8b";
/// The `llama3-70b-8192` model. Used for chat completion.
pub const LLAMA_3_70B_8192: &str = "llama3-70b-8192";
/// The `llama3-8b-8192` model. Used for chat completion.
pub const LLAMA_3_8B_8192: &str = "llama3-8b-8192";
/// The `mixtral-8x7b-32768` model. Used for chat completion.
pub const MIXTRAL_8X7B_32768: &str = "mixtral-8x7b-32768";
153
/// How Groq should expose model reasoning in responses.
/// Serialized lowercase (`"parsed"`, `"raw"`, `"hidden"`); see Groq's API
/// docs (`reasoning_format`) for the exact semantics of each mode.
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ReasoningFormat {
    Parsed,
    Raw,
    Hidden,
}
161
/// Request body for Groq's `/chat/completions` endpoint.
/// Built from a generic [`CompletionRequest`] via `TryFrom` below.
#[derive(Debug, Serialize, Deserialize)]
pub(super) struct GroqCompletionRequest {
    /// Model id to run the completion against.
    model: String,
    /// Conversation history in OpenAI wire format (Groq is OpenAI-compatible).
    pub messages: Vec<OpenAIMessage>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f64>,
    #[serde(skip_serializing_if = "Vec::is_empty")]
    tools: Vec<ToolDefinition>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_choice: Option<crate::providers::openai::completion::ToolChoice>,
    // Flattened so fields like `reasoning_format` appear at the top level.
    #[serde(flatten, skip_serializing_if = "Option::is_none")]
    pub additional_params: Option<GroqAdditionalParameters>,
    // Toggled to `true` by `stream()` before sending.
    pub(super) stream: bool,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub(super) stream_options: Option<StreamOptions>,
}

/// Streaming options; `include_usage` asks Groq to append a usage chunk.
#[derive(Debug, Serialize, Deserialize, Default)]
pub(super) struct StreamOptions {
    pub(super) include_usage: bool,
}
183
184impl TryFrom<(&str, CompletionRequest)> for GroqCompletionRequest {
185    type Error = CompletionError;
186
187    fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
188        if req.output_schema.is_some() {
189            tracing::warn!("Structured outputs currently not supported for Groq");
190        }
191        let model = req.model.clone().unwrap_or_else(|| model.to_string());
192        // Build up the order of messages (context, chat_history, prompt)
193        let mut partial_history = vec![];
194        if let Some(docs) = req.normalized_documents() {
195            partial_history.push(docs);
196        }
197        partial_history.extend(req.chat_history);
198
199        // Add preamble to chat history (if available)
200        let mut full_history: Vec<OpenAIMessage> = match &req.preamble {
201            Some(preamble) => vec![OpenAIMessage::system(preamble)],
202            None => vec![],
203        };
204
205        // Convert and extend the rest of the history
206        full_history.extend(
207            partial_history
208                .into_iter()
209                .map(message::Message::try_into)
210                .collect::<Result<Vec<Vec<OpenAIMessage>>, _>>()?
211                .into_iter()
212                .flatten()
213                .collect::<Vec<_>>(),
214        );
215
216        let tool_choice = req
217            .tool_choice
218            .clone()
219            .map(crate::providers::openai::ToolChoice::try_from)
220            .transpose()?;
221
222        let additional_params: Option<GroqAdditionalParameters> =
223            if let Some(params) = req.additional_params {
224                Some(serde_json::from_value(params)?)
225            } else {
226                None
227            };
228
229        Ok(Self {
230            model: model.to_string(),
231            messages: full_history,
232            temperature: req.temperature,
233            tools: req
234                .tools
235                .clone()
236                .into_iter()
237                .map(ToolDefinition::from)
238                .collect::<Vec<_>>(),
239            tool_choice,
240            additional_params,
241            stream: false,
242            stream_options: None,
243        })
244    }
245}
246
/// Additional parameters to send to the Groq API.
/// Flattened into the top level of [`GroqCompletionRequest`] on serialization.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub struct GroqAdditionalParameters {
    /// The reasoning format. See Groq's API docs for more details.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub reasoning_format: Option<ReasoningFormat>,
    /// Whether or not to include reasoning. See Groq's API docs for more details.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub include_reasoning: Option<bool>,
    /// Any other properties not included by default on this struct (that you want to send)
    #[serde(flatten, skip_serializing_if = "Option::is_none")]
    pub extra: Option<Map<String, serde_json::Value>>,
}
260
/// A Groq chat-completion model bound to a [`Client`].
#[derive(Clone, Debug)]
pub struct CompletionModel<T = reqwest::Client> {
    client: Client<T>,
    /// Name of the model (e.g.: deepseek-r1-distill-llama-70b)
    pub model: String,
}
267
268impl<T> CompletionModel<T> {
269    pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
270        Self {
271            client,
272            model: model.into(),
273        }
274    }
275}
276
277impl<T> completion::CompletionModel for CompletionModel<T>
278where
279    T: HttpClientExt + Clone + Send + std::fmt::Debug + Default + 'static,
280{
281    type Response = CompletionResponse;
282    type StreamingResponse = StreamingCompletionResponse;
283
284    type Client = Client<T>;
285
286    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
287        Self::new(client.clone(), model)
288    }
289
290    async fn completion(
291        &self,
292        completion_request: CompletionRequest,
293    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
294        let span = if tracing::Span::current().is_disabled() {
295            info_span!(
296                target: "rig::completions",
297                "chat",
298                gen_ai.operation.name = "chat",
299                gen_ai.provider.name = "groq",
300                gen_ai.request.model = self.model,
301                gen_ai.system_instructions = tracing::field::Empty,
302                gen_ai.response.id = tracing::field::Empty,
303                gen_ai.response.model = tracing::field::Empty,
304                gen_ai.usage.output_tokens = tracing::field::Empty,
305                gen_ai.usage.input_tokens = tracing::field::Empty,
306            )
307        } else {
308            tracing::Span::current()
309        };
310
311        span.record("gen_ai.system_instructions", &completion_request.preamble);
312
313        let request = GroqCompletionRequest::try_from((self.model.as_ref(), completion_request))?;
314
315        if tracing::enabled!(tracing::Level::TRACE) {
316            tracing::trace!(target: "rig::completions",
317                "Groq completion request: {}",
318                serde_json::to_string_pretty(&request)?
319            );
320        }
321
322        let body = serde_json::to_vec(&request)?;
323        let req = self
324            .client
325            .post("/chat/completions")?
326            .body(body)
327            .map_err(|e| http_client::Error::Instance(e.into()))?;
328
329        let async_block = async move {
330            let response = self.client.send::<_, Bytes>(req).await?;
331            let status = response.status();
332            let response_body = response.into_body().into_future().await?.to_vec();
333
334            if status.is_success() {
335                match serde_json::from_slice::<ApiResponse<CompletionResponse>>(&response_body)? {
336                    ApiResponse::Ok(response) => {
337                        let span = tracing::Span::current();
338                        span.record("gen_ai.response.id", response.id.clone());
339                        span.record("gen_ai.response.model_name", response.model.clone());
340                        if let Some(ref usage) = response.usage {
341                            span.record("gen_ai.usage.input_tokens", usage.prompt_tokens);
342                            span.record(
343                                "gen_ai.usage.output_tokens",
344                                usage.total_tokens - usage.prompt_tokens,
345                            );
346                        }
347
348                        if tracing::enabled!(tracing::Level::TRACE) {
349                            tracing::trace!(target: "rig::completions",
350                                "Groq completion response: {}",
351                                serde_json::to_string_pretty(&response)?
352                            );
353                        }
354
355                        response.try_into()
356                    }
357                    ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
358                }
359            } else {
360                Err(CompletionError::ProviderError(
361                    String::from_utf8_lossy(&response_body).to_string(),
362                ))
363            }
364        };
365
366        tracing::Instrument::instrument(async_block, span).await
367    }
368
369    async fn stream(
370        &self,
371        request: CompletionRequest,
372    ) -> Result<
373        crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
374        CompletionError,
375    > {
376        let span = if tracing::Span::current().is_disabled() {
377            info_span!(
378                target: "rig::completions",
379                "chat_streaming",
380                gen_ai.operation.name = "chat_streaming",
381                gen_ai.provider.name = "groq",
382                gen_ai.request.model = self.model,
383                gen_ai.system_instructions = tracing::field::Empty,
384                gen_ai.response.id = tracing::field::Empty,
385                gen_ai.response.model = tracing::field::Empty,
386                gen_ai.usage.output_tokens = tracing::field::Empty,
387                gen_ai.usage.input_tokens = tracing::field::Empty,
388            )
389        } else {
390            tracing::Span::current()
391        };
392
393        span.record("gen_ai.system_instructions", &request.preamble);
394
395        let mut request = GroqCompletionRequest::try_from((self.model.as_ref(), request))?;
396
397        request.stream = true;
398        request.stream_options = Some(StreamOptions {
399            include_usage: true,
400        });
401
402        if tracing::enabled!(tracing::Level::TRACE) {
403            tracing::trace!(target: "rig::completions",
404                "Groq streaming completion request: {}",
405                serde_json::to_string_pretty(&request)?
406            );
407        }
408
409        let body = serde_json::to_vec(&request)?;
410        let req = self
411            .client
412            .post("/chat/completions")?
413            .body(body)
414            .map_err(|e| http_client::Error::Instance(e.into()))?;
415
416        tracing::Instrument::instrument(
417            send_compatible_streaming_request(self.client.clone(), req),
418            span,
419        )
420        .await
421    }
422}
423
// ================================================================
// Groq Transcription API
// ================================================================

/// The `whisper-large-v3` transcription model.
pub const WHISPER_LARGE_V3: &str = "whisper-large-v3";
/// The `whisper-large-v3-turbo` transcription model.
pub const WHISPER_LARGE_V3_TURBO: &str = "whisper-large-v3-turbo";
/// The `distil-whisper-large-v3-en` transcription model.
pub const DISTIL_WHISPER_LARGE_V3_EN: &str = "distil-whisper-large-v3-en";

/// A Groq audio-transcription model bound to a [`Client`].
#[derive(Clone)]
pub struct TranscriptionModel<T> {
    client: Client<T>,
    /// Name of the model (e.g.: whisper-large-v3)
    pub model: String,
}

impl<T> TranscriptionModel<T> {
    /// Create a transcription model for `model`, backed by `client`.
    pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
        Self {
            client,
            model: model.into(),
        }
    }
}
447impl<T> transcription::TranscriptionModel for TranscriptionModel<T>
448where
449    T: HttpClientExt + Clone + Send + std::fmt::Debug + Default + 'static,
450{
451    type Response = TranscriptionResponse;
452
453    type Client = Client<T>;
454
455    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
456        Self::new(client.clone(), model)
457    }
458
459    async fn transcription(
460        &self,
461        request: transcription::TranscriptionRequest,
462    ) -> Result<
463        transcription::TranscriptionResponse<Self::Response>,
464        transcription::TranscriptionError,
465    > {
466        let data = request.data;
467
468        let mut body = MultipartForm::new()
469            .text("model", self.model.clone())
470            .part(Part::bytes("file", data).filename(request.filename.clone()));
471
472        if let Some(language) = request.language {
473            body = body.text("language", language);
474        }
475
476        if let Some(prompt) = request.prompt {
477            body = body.text("prompt", prompt.clone());
478        }
479
480        if let Some(ref temperature) = request.temperature {
481            body = body.text("temperature", temperature.to_string());
482        }
483
484        if let Some(ref additional_params) = request.additional_params {
485            for (key, value) in additional_params
486                .as_object()
487                .expect("Additional Parameters to OpenAI Transcription should be a map")
488            {
489                body = body.text(key.to_owned(), value.to_string());
490            }
491        }
492
493        let req = self
494            .client
495            .post("/audio/transcriptions")?
496            .body(body)
497            .unwrap();
498
499        let response = self.client.send_multipart::<Bytes>(req).await.unwrap();
500
501        let status = response.status();
502        let response_body = response.into_body().into_future().await?.to_vec();
503
504        if status.is_success() {
505            match serde_json::from_slice::<ApiResponse<TranscriptionResponse>>(&response_body)? {
506                ApiResponse::Ok(response) => response.try_into(),
507                ApiResponse::Err(api_error_response) => Err(TranscriptionError::ProviderError(
508                    api_error_response.message,
509                )),
510            }
511        } else {
512            Err(TranscriptionError::ProviderError(
513                String::from_utf8_lossy(&response_body).to_string(),
514            ))
515        }
516    }
517}
518
/// One incremental delta inside a streaming choice.
/// `untagged`: a delta either carries reasoning text, or message content
/// and/or tool-call fragments.
#[derive(Deserialize, Debug)]
#[serde(untagged)]
enum StreamingDelta {
    Reasoning {
        reasoning: String,
    },
    MessageContent {
        #[serde(default)]
        content: Option<String>,
        // `null_or_vec`: the provider may send `null` instead of `[]`.
        #[serde(default, deserialize_with = "json_utils::null_or_vec")]
        tool_calls: Vec<StreamingToolCall>,
    },
}

/// A single streamed choice; only the delta is consumed downstream.
#[derive(Deserialize, Debug)]
struct StreamingChoice {
    delta: StreamingDelta,
}

/// One SSE chunk of a streaming completion.
#[derive(Deserialize, Debug)]
struct StreamingCompletionChunk {
    choices: Vec<StreamingChoice>,
    // With `stream_options.include_usage`, usage is expected on the final
    // chunk; earlier chunks typically omit it.
    usage: Option<Usage>,
}

/// Final aggregated response surfaced when a Groq stream completes.
#[derive(Clone, Deserialize, Serialize, Debug)]
pub struct StreamingCompletionResponse {
    pub usage: Usage,
}
548
549impl GetTokenUsage for StreamingCompletionResponse {
550    fn token_usage(&self) -> Option<crate::completion::Usage> {
551        let mut usage = crate::completion::Usage::new();
552
553        usage.input_tokens = self.usage.prompt_tokens as u64;
554        usage.total_tokens = self.usage.total_tokens as u64;
555        usage.output_tokens = self.usage.total_tokens as u64 - self.usage.prompt_tokens as u64;
556        usage.cached_input_tokens = self
557            .usage
558            .prompt_tokens_details
559            .as_ref()
560            .map(|d| d.cached_tokens as u64)
561            .unwrap_or(0);
562
563        Some(usage)
564    }
565}
566
567pub async fn send_compatible_streaming_request<T>(
568    client: T,
569    req: Request<Vec<u8>>,
570) -> Result<
571    crate::streaming::StreamingCompletionResponse<StreamingCompletionResponse>,
572    CompletionError,
573>
574where
575    T: HttpClientExt + Clone + 'static,
576{
577    let span = tracing::Span::current();
578
579    let mut event_source = GenericEventSource::new(client, req);
580
581    let stream = stream! {
582        let span = tracing::Span::current();
583        let mut final_usage = Usage {
584            prompt_tokens: 0,
585            total_tokens: 0,
586            prompt_tokens_details: None,
587        };
588
589        let mut text_response = String::new();
590
591        let mut calls: HashMap<usize, (String, String, String)> = HashMap::new();
592
593        while let Some(event_result) = event_source.next().await {
594            match event_result {
595                Ok(Event::Open) => {
596                    tracing::trace!("SSE connection opened");
597                    continue;
598                }
599
600                Ok(Event::Message(message)) => {
601                    let data_str = message.data.trim();
602
603                    let parsed = serde_json::from_str::<StreamingCompletionChunk>(data_str);
604                    let Ok(data) = parsed else {
605                        let err = parsed.unwrap_err();
606                        tracing::debug!("Couldn't parse SSE payload as StreamingCompletionChunk: {:?}", err);
607                        continue;
608                    };
609
610                    if let Some(choice) = data.choices.first() {
611                        match &choice.delta {
612                            StreamingDelta::Reasoning { reasoning } => {
613                                yield Ok(crate::streaming::RawStreamingChoice::ReasoningDelta {
614                                    id: None,
615                                    reasoning: reasoning.to_string(),
616                                });
617                            }
618
619                            StreamingDelta::MessageContent { content, tool_calls } => {
620                                // Handle tool calls
621                                for tool_call in tool_calls {
622                                    let function = &tool_call.function;
623
624                                    // Start of tool call
625                                    if function.name.as_ref().map(|s| !s.is_empty()).unwrap_or(false)
626                                        && empty_or_none(&function.arguments)
627                                    {
628                                        let id = tool_call.id.clone().unwrap_or_default();
629                                        let name = function.name.clone().unwrap();
630                                        calls.insert(tool_call.index, (id, name, String::new()));
631                                    }
632                                    // Continuation
633                                    else if function.name.as_ref().map(|s| s.is_empty()).unwrap_or(true)
634                                        && let Some(arguments) = &function.arguments
635                                        && !arguments.is_empty()
636                                    {
637                                        if let Some((id, name, existing_args)) = calls.get(&tool_call.index) {
638                                            let combined = format!("{}{}", existing_args, arguments);
639                                            calls.insert(tool_call.index, (id.clone(), name.clone(), combined));
640                                        } else {
641                                            tracing::debug!("Partial tool call received but tool call was never started.");
642                                        }
643                                    }
644                                    // Complete tool call
645                                    else {
646                                        let id = tool_call.id.clone().unwrap_or_default();
647                                        let name = function.name.clone().unwrap_or_default();
648                                        let arguments_str = function.arguments.clone().unwrap_or_default();
649
650                                        let Ok(arguments_json) = serde_json::from_str::<serde_json::Value>(&arguments_str) else {
651                                            tracing::debug!("Couldn't parse tool call args '{}'", arguments_str);
652                                            continue;
653                                        };
654
655                                        yield Ok(crate::streaming::RawStreamingChoice::ToolCall(
656                                            crate::streaming::RawStreamingToolCall::new(id, name, arguments_json)
657                                        ));
658                                    }
659                                }
660
661                                // Streamed content
662                                if let Some(content) = content {
663                                    text_response += content;
664                                    yield Ok(crate::streaming::RawStreamingChoice::Message(content.clone()));
665                                }
666                            }
667                        }
668                    }
669
670                    if let Some(usage) = data.usage {
671                        final_usage = usage.clone();
672                    }
673                }
674
675                Err(crate::http_client::Error::StreamEnded) => break,
676                Err(err) => {
677                    tracing::error!(?err, "SSE error");
678                    yield Err(CompletionError::ResponseError(err.to_string()));
679                    break;
680                }
681            }
682        }
683
684        event_source.close();
685
686        let mut tool_calls = Vec::new();
687        // Flush accumulated tool calls
688        for (_, (id, name, arguments)) in calls {
689            let Ok(arguments_json) = serde_json::from_str::<serde_json::Value>(&arguments) else {
690                continue;
691            };
692
693            tool_calls.push(rig::providers::openai::completion::ToolCall {
694                id: id.clone(),
695                r#type: ToolType::Function,
696                function: Function {
697                    name: name.clone(),
698                    arguments: arguments_json.clone()
699                }
700            });
701            yield Ok(crate::streaming::RawStreamingChoice::ToolCall(
702                crate::streaming::RawStreamingToolCall::new(id, name, arguments_json)
703            ));
704        }
705
706        let response_message = crate::providers::openai::completion::Message::Assistant {
707            content: vec![AssistantContent::Text { text: text_response }],
708            refusal: None,
709            audio: None,
710            name: None,
711            tool_calls
712        };
713
714        span.record("gen_ai.output.messages", serde_json::to_string(&vec![response_message]).unwrap());
715        span.record("gen_ai.usage.input_tokens", final_usage.prompt_tokens);
716        span.record("gen_ai.usage.output_tokens", final_usage.total_tokens - final_usage.prompt_tokens);
717
718        // Final response
719        yield Ok(crate::streaming::RawStreamingChoice::FinalResponse(
720            StreamingCompletionResponse { usage: final_usage.clone() }
721        ));
722    }.instrument(span);
723
724    Ok(crate::streaming::StreamingCompletionResponse::stream(
725        Box::pin(stream),
726    ))
727}
728
#[cfg(test)]
mod tests {
    use crate::{
        OneOrMany,
        providers::{
            groq::{GroqAdditionalParameters, GroqCompletionRequest},
            openai::{Message, UserContent},
        },
    };

    /// Verifies the Groq request serializes with `additional_params`
    /// flattened to the top level and `None`/empty fields omitted.
    #[test]
    fn serialize_groq_request() {
        let additional_params = GroqAdditionalParameters {
            include_reasoning: Some(true),
            reasoning_format: Some(super::ReasoningFormat::Parsed),
            ..Default::default()
        };

        let groq = GroqCompletionRequest {
            model: "openai/gpt-120b-oss".to_string(),
            temperature: None,
            tool_choice: None,
            stream_options: None,
            tools: Vec::new(),
            messages: vec![Message::User {
                content: OneOrMany::one(UserContent::Text {
                    text: "Hello world!".to_string(),
                }),
                name: None,
            }],
            stream: false,
            additional_params: Some(additional_params),
        };

        let json = serde_json::to_value(&groq).unwrap();

        assert_eq!(
            json,
            serde_json::json!({
                "model": "openai/gpt-120b-oss",
                "messages": [
                    {
                        "role": "user",
                        "content": "Hello world!"
                    }
                ],
                "stream": false,
                "include_reasoning": true,
                "reasoning_format": "parsed"
            })
        )
    }

    /// Smoke test: both construction paths (direct and builder) succeed with
    /// a dummy key — no network calls are made.
    #[test]
    fn test_client_initialization() {
        let _client =
            crate::providers::groq::Client::new("dummy-key").expect("Client::new() failed");
        let _client_from_builder = crate::providers::groq::Client::builder()
            .api_key("dummy-key")
            .build()
            .expect("Client::builder() failed");
    }
}