Skip to main content

rig/providers/
groq.rs

1//! Groq API client and Rig integration
2//!
3//! # Example
4//! ```
5//! use rig::providers::groq;
6//!
7//! let client = groq::Client::new("YOUR_API_KEY").unwrap();
8//!
9//! let model = client.completion_model(groq::DEEPSEEK_R1_DISTILL_LLAMA_70B);
10//! ```
11use bytes::Bytes;
12use http::Request;
13use serde_json::{Map, Value};
14use std::collections::HashMap;
15use tracing::info_span;
16use tracing_futures::Instrument;
17
18use super::openai::{
19    CompletionResponse, Message as OpenAIMessage, StreamingToolCall, TranscriptionResponse, Usage,
20};
21use crate::client::{
22    self, BearerAuth, Capabilities, Capable, DebugExt, Nothing, Provider, ProviderBuilder,
23    ProviderClient,
24};
25use crate::completion::GetTokenUsage;
26use crate::http_client::multipart::Part;
27use crate::http_client::sse::{Event, GenericEventSource};
28use crate::http_client::{self, HttpClientExt, MultipartForm};
29use crate::json_utils::empty_or_none;
30use crate::providers::openai::{AssistantContent, Function, ToolType};
31use async_stream::stream;
32use futures::StreamExt;
33
34use crate::{
35    completion::{self, CompletionError, CompletionRequest},
36    json_utils,
37    message::{self},
38    providers::openai::ToolDefinition,
39    transcription::{self, TranscriptionError},
40};
41use serde::{Deserialize, Serialize};
42
43// ================================================================
44// Main Groq Client
45// ================================================================
46const GROQ_API_BASE_URL: &str = "https://api.groq.com/openai/v1";
47
48#[derive(Debug, Default, Clone, Copy)]
49pub struct GroqExt;
50#[derive(Debug, Default, Clone, Copy)]
51pub struct GroqBuilder;
52
53type GroqApiKey = BearerAuth;
54
impl Provider for GroqExt {
    type Builder = GroqBuilder;
    // Endpoint used by the generic client machinery to verify the API key.
    const VERIFY_PATH: &'static str = "/models";
}

/// Groq is wired up for chat completion and audio transcription only;
/// embeddings, model listing, image and audio generation are `Nothing`.
impl<H> Capabilities<H> for GroqExt {
    type Completion = Capable<CompletionModel<H>>;
    type Embeddings = Nothing;
    type Transcription = Capable<TranscriptionModel<H>>;
    type ModelListing = Nothing;
    #[cfg(feature = "image")]
    type ImageGeneration = Nothing;

    #[cfg(feature = "audio")]
    type AudioGeneration = Nothing;
}

impl DebugExt for GroqExt {}
73
impl ProviderBuilder for GroqBuilder {
    type Extension<H>
        = GroqExt
    where
        H: HttpClientExt;
    type ApiKey = GroqApiKey;

    const BASE_URL: &'static str = GROQ_API_BASE_URL;

    /// Build the (stateless) Groq extension; construction never fails.
    fn build<H>(
        _builder: &client::ClientBuilder<Self, Self::ApiKey, H>,
    ) -> http_client::Result<Self::Extension<H>>
    where
        H: HttpClientExt,
    {
        Ok(GroqExt)
    }
}

/// Groq client over a pluggable HTTP backend (defaults to `reqwest`).
pub type Client<H = reqwest::Client> = client::Client<GroqExt, H>;
/// Builder for [`Client`]; the API key is supplied as a `String`.
pub type ClientBuilder<H = reqwest::Client> = client::ClientBuilder<GroqBuilder, String, H>;
95
96impl ProviderClient for Client {
97    type Input = String;
98
99    /// Create a new Groq client from the `GROQ_API_KEY` environment variable.
100    /// Panics if the environment variable is not set.
101    fn from_env() -> Self {
102        let api_key = std::env::var("GROQ_API_KEY").expect("GROQ_API_KEY not set");
103        Self::new(&api_key).unwrap()
104    }
105
106    fn from_val(input: Self::Input) -> Self {
107        Self::new(&input).unwrap()
108    }
109}
110
/// Error payload shape returned by the Groq API.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    message: String,
}

/// A Groq response is either the expected payload or an error object;
/// `untagged` lets serde try the success shape first.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    Ok(T),
    Err(ApiErrorResponse),
}
122
123// ================================================================
124// Groq Completion API
125// ================================================================
126
// NOTE(review): several of these model ids (in particular the `-preview`
// and `specdec` variants) may have been deprecated by Groq since this list
// was written — verify against Groq's current model catalog before use.
/// The `deepseek-r1-distill-llama-70b` model. Used for chat completion.
pub const DEEPSEEK_R1_DISTILL_LLAMA_70B: &str = "deepseek-r1-distill-llama-70b";
/// The `gemma2-9b-it` model. Used for chat completion.
pub const GEMMA2_9B_IT: &str = "gemma2-9b-it";
/// The `llama-3.1-8b-instant` model. Used for chat completion.
pub const LLAMA_3_1_8B_INSTANT: &str = "llama-3.1-8b-instant";
/// The `llama-3.2-11b-vision-preview` model. Used for chat completion.
pub const LLAMA_3_2_11B_VISION_PREVIEW: &str = "llama-3.2-11b-vision-preview";
/// The `llama-3.2-1b-preview` model. Used for chat completion.
pub const LLAMA_3_2_1B_PREVIEW: &str = "llama-3.2-1b-preview";
/// The `llama-3.2-3b-preview` model. Used for chat completion.
pub const LLAMA_3_2_3B_PREVIEW: &str = "llama-3.2-3b-preview";
/// The `llama-3.2-90b-vision-preview` model. Used for chat completion.
pub const LLAMA_3_2_90B_VISION_PREVIEW: &str = "llama-3.2-90b-vision-preview";
/// The `llama-3.2-70b-specdec` model. Used for chat completion.
pub const LLAMA_3_2_70B_SPECDEC: &str = "llama-3.2-70b-specdec";
/// The `llama-3.2-70b-versatile` model. Used for chat completion.
pub const LLAMA_3_2_70B_VERSATILE: &str = "llama-3.2-70b-versatile";
/// The `llama-guard-3-8b` model. Used for chat completion.
pub const LLAMA_GUARD_3_8B: &str = "llama-guard-3-8b";
/// The `llama3-70b-8192` model. Used for chat completion.
pub const LLAMA_3_70B_8192: &str = "llama3-70b-8192";
/// The `llama3-8b-8192` model. Used for chat completion.
pub const LLAMA_3_8B_8192: &str = "llama3-8b-8192";
/// The `mixtral-8x7b-32768` model. Used for chat completion.
pub const MIXTRAL_8X7B_32768: &str = "mixtral-8x7b-32768";
153
/// How Groq should return model reasoning (serialized lowercase as the
/// `reasoning_format` request field). See Groq's API docs for the exact
/// semantics of each variant.
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ReasoningFormat {
    Parsed,
    Raw,
    Hidden,
}
161
/// Request body for Groq's `/chat/completions` endpoint.
///
/// Groq exposes an OpenAI-compatible API, so messages and tool definitions
/// reuse the OpenAI provider's types.
#[derive(Debug, Serialize, Deserialize)]
pub(super) struct GroqCompletionRequest {
    model: String,
    pub messages: Vec<OpenAIMessage>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f64>,
    #[serde(skip_serializing_if = "Vec::is_empty")]
    tools: Vec<ToolDefinition>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_choice: Option<crate::providers::openai::completion::ToolChoice>,
    // Flattened so the reasoning options and `extra` keys serialize at the
    // top level of the request body.
    #[serde(flatten, skip_serializing_if = "Option::is_none")]
    pub additional_params: Option<GroqAdditionalParameters>,
    pub(super) stream: bool,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub(super) stream_options: Option<StreamOptions>,
}

/// Options applied only to streaming requests.
#[derive(Debug, Serialize, Deserialize, Default)]
pub(super) struct StreamOptions {
    // When true, the final SSE chunk carries the token usage totals.
    pub(super) include_usage: bool,
}
183
184impl TryFrom<(&str, CompletionRequest)> for GroqCompletionRequest {
185    type Error = CompletionError;
186
187    fn try_from((model, mut req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
188        if req.output_schema.is_some() {
189            tracing::warn!("Structured outputs currently not supported for Groq");
190        }
191        let model = req.model.clone().unwrap_or_else(|| model.to_string());
192        // Build up the order of messages (context, chat_history, prompt)
193        let mut partial_history = vec![];
194        if let Some(docs) = req.normalized_documents() {
195            partial_history.push(docs);
196        }
197        partial_history.extend(req.chat_history);
198
199        // Add preamble to chat history (if available)
200        let mut full_history: Vec<OpenAIMessage> = match &req.preamble {
201            Some(preamble) => vec![OpenAIMessage::system(preamble)],
202            None => vec![],
203        };
204
205        // Convert and extend the rest of the history
206        full_history.extend(
207            partial_history
208                .into_iter()
209                .map(message::Message::try_into)
210                .collect::<Result<Vec<Vec<OpenAIMessage>>, _>>()?
211                .into_iter()
212                .flatten()
213                .collect::<Vec<_>>(),
214        );
215
216        let tool_choice = req
217            .tool_choice
218            .clone()
219            .map(crate::providers::openai::ToolChoice::try_from)
220            .transpose()?;
221
222        let mut additional_params_payload = req.additional_params.take().unwrap_or(Value::Null);
223        let native_tools =
224            extract_native_tools_from_additional_params(&mut additional_params_payload)?;
225
226        let mut additional_params: Option<GroqAdditionalParameters> =
227            if additional_params_payload.is_null() {
228                None
229            } else {
230                Some(serde_json::from_value(additional_params_payload)?)
231            };
232        apply_native_tools_to_additional_params(&mut additional_params, native_tools);
233
234        Ok(Self {
235            model: model.to_string(),
236            messages: full_history,
237            temperature: req.temperature,
238            tools: req
239                .tools
240                .clone()
241                .into_iter()
242                .map(ToolDefinition::from)
243                .collect::<Vec<_>>(),
244            tool_choice,
245            additional_params,
246            stream: false,
247            stream_options: None,
248        })
249    }
250}
251
252fn extract_native_tools_from_additional_params(
253    additional_params: &mut Value,
254) -> Result<Vec<Value>, CompletionError> {
255    if let Some(map) = additional_params.as_object_mut()
256        && let Some(raw_tools) = map.remove("tools")
257    {
258        return serde_json::from_value::<Vec<Value>>(raw_tools).map_err(|err| {
259            CompletionError::RequestError(
260                format!("Invalid Groq `additional_params.tools` payload: {err}").into(),
261            )
262        });
263    }
264
265    Ok(Vec::new())
266}
267
268fn apply_native_tools_to_additional_params(
269    additional_params: &mut Option<GroqAdditionalParameters>,
270    native_tools: Vec<Value>,
271) {
272    if native_tools.is_empty() {
273        return;
274    }
275
276    let params = additional_params.get_or_insert_with(GroqAdditionalParameters::default);
277    let extra = params.extra.get_or_insert_with(Map::new);
278
279    let mut compound_custom = match extra.remove("compound_custom") {
280        Some(Value::Object(map)) => map,
281        _ => Map::new(),
282    };
283
284    let mut enabled_tools = match compound_custom.remove("enabled_tools") {
285        Some(Value::Array(values)) => values,
286        _ => Vec::new(),
287    };
288
289    for native_tool in native_tools {
290        let already_enabled = enabled_tools
291            .iter()
292            .any(|existing| native_tools_match(existing, &native_tool));
293        if !already_enabled {
294            enabled_tools.push(native_tool);
295        }
296    }
297
298    compound_custom.insert("enabled_tools".to_string(), Value::Array(enabled_tools));
299    extra.insert(
300        "compound_custom".to_string(),
301        Value::Object(compound_custom),
302    );
303}
304
305fn native_tools_match(lhs: &Value, rhs: &Value) -> bool {
306    if let (Some(lhs_type), Some(rhs_type)) = (native_tool_kind(lhs), native_tool_kind(rhs)) {
307        return lhs_type == rhs_type;
308    }
309
310    lhs == rhs
311}
312
313fn native_tool_kind(value: &Value) -> Option<&str> {
314    match value {
315        Value::String(kind) => Some(kind),
316        Value::Object(map) => map.get("type").and_then(Value::as_str),
317        _ => None,
318    }
319}
320
/// Additional parameters to send to the Groq API
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub struct GroqAdditionalParameters {
    /// The reasoning format. See Groq's API docs for more details.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub reasoning_format: Option<ReasoningFormat>,
    /// Whether or not to include reasoning. See Groq's API docs for more details.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub include_reasoning: Option<bool>,
    /// Any other properties not included by default on this struct (that you want to send)
    /// Flattened: each key in this map serializes at the request's top level.
    #[serde(flatten, skip_serializing_if = "Option::is_none")]
    pub extra: Option<Map<String, serde_json::Value>>,
}
334
/// Groq chat-completion model handle, generic over the HTTP backend.
#[derive(Clone, Debug)]
pub struct CompletionModel<T = reqwest::Client> {
    client: Client<T>,
    /// Name of the model (e.g.: deepseek-r1-distill-llama-70b)
    pub model: String,
}

impl<T> CompletionModel<T> {
    /// Create a completion model handle bound to `client` and `model`.
    pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
        Self {
            client,
            model: model.into(),
        }
    }
}
350
351impl<T> completion::CompletionModel for CompletionModel<T>
352where
353    T: HttpClientExt + Clone + Send + std::fmt::Debug + Default + 'static,
354{
355    type Response = CompletionResponse;
356    type StreamingResponse = StreamingCompletionResponse;
357
358    type Client = Client<T>;
359
360    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
361        Self::new(client.clone(), model)
362    }
363
364    async fn completion(
365        &self,
366        completion_request: CompletionRequest,
367    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
368        let span = if tracing::Span::current().is_disabled() {
369            info_span!(
370                target: "rig::completions",
371                "chat",
372                gen_ai.operation.name = "chat",
373                gen_ai.provider.name = "groq",
374                gen_ai.request.model = self.model,
375                gen_ai.system_instructions = tracing::field::Empty,
376                gen_ai.response.id = tracing::field::Empty,
377                gen_ai.response.model = tracing::field::Empty,
378                gen_ai.usage.output_tokens = tracing::field::Empty,
379                gen_ai.usage.input_tokens = tracing::field::Empty,
380                gen_ai.usage.cached_tokens = tracing::field::Empty,
381            )
382        } else {
383            tracing::Span::current()
384        };
385
386        span.record("gen_ai.system_instructions", &completion_request.preamble);
387
388        let request = GroqCompletionRequest::try_from((self.model.as_ref(), completion_request))?;
389
390        if tracing::enabled!(tracing::Level::TRACE) {
391            tracing::trace!(target: "rig::completions",
392                "Groq completion request: {}",
393                serde_json::to_string_pretty(&request)?
394            );
395        }
396
397        let body = serde_json::to_vec(&request)?;
398        let req = self
399            .client
400            .post("/chat/completions")?
401            .body(body)
402            .map_err(|e| http_client::Error::Instance(e.into()))?;
403
404        let async_block = async move {
405            let response = self.client.send::<_, Bytes>(req).await?;
406            let status = response.status();
407            let response_body = response.into_body().into_future().await?.to_vec();
408
409            if status.is_success() {
410                match serde_json::from_slice::<ApiResponse<CompletionResponse>>(&response_body)? {
411                    ApiResponse::Ok(response) => {
412                        let span = tracing::Span::current();
413                        span.record("gen_ai.response.id", response.id.clone());
414                        span.record("gen_ai.response.model_name", response.model.clone());
415                        if let Some(ref usage) = response.usage {
416                            span.record("gen_ai.usage.input_tokens", usage.prompt_tokens);
417                            span.record(
418                                "gen_ai.usage.output_tokens",
419                                usage.total_tokens - usage.prompt_tokens,
420                            );
421                            span.record(
422                                "gen_ai.usage.cached_tokens",
423                                usage
424                                    .prompt_tokens_details
425                                    .as_ref()
426                                    .map(|d| d.cached_tokens)
427                                    .unwrap_or(0),
428                            );
429                        }
430
431                        if tracing::enabled!(tracing::Level::TRACE) {
432                            tracing::trace!(target: "rig::completions",
433                                "Groq completion response: {}",
434                                serde_json::to_string_pretty(&response)?
435                            );
436                        }
437
438                        response.try_into()
439                    }
440                    ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
441                }
442            } else {
443                Err(CompletionError::ProviderError(
444                    String::from_utf8_lossy(&response_body).to_string(),
445                ))
446            }
447        };
448
449        tracing::Instrument::instrument(async_block, span).await
450    }
451
452    async fn stream(
453        &self,
454        request: CompletionRequest,
455    ) -> Result<
456        crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
457        CompletionError,
458    > {
459        let span = if tracing::Span::current().is_disabled() {
460            info_span!(
461                target: "rig::completions",
462                "chat_streaming",
463                gen_ai.operation.name = "chat_streaming",
464                gen_ai.provider.name = "groq",
465                gen_ai.request.model = self.model,
466                gen_ai.system_instructions = tracing::field::Empty,
467                gen_ai.response.id = tracing::field::Empty,
468                gen_ai.response.model = tracing::field::Empty,
469                gen_ai.usage.output_tokens = tracing::field::Empty,
470                gen_ai.usage.input_tokens = tracing::field::Empty,
471                gen_ai.usage.cached_tokens = tracing::field::Empty,
472            )
473        } else {
474            tracing::Span::current()
475        };
476
477        span.record("gen_ai.system_instructions", &request.preamble);
478
479        let mut request = GroqCompletionRequest::try_from((self.model.as_ref(), request))?;
480
481        request.stream = true;
482        request.stream_options = Some(StreamOptions {
483            include_usage: true,
484        });
485
486        if tracing::enabled!(tracing::Level::TRACE) {
487            tracing::trace!(target: "rig::completions",
488                "Groq streaming completion request: {}",
489                serde_json::to_string_pretty(&request)?
490            );
491        }
492
493        let body = serde_json::to_vec(&request)?;
494        let req = self
495            .client
496            .post("/chat/completions")?
497            .body(body)
498            .map_err(|e| http_client::Error::Instance(e.into()))?;
499
500        tracing::Instrument::instrument(
501            send_compatible_streaming_request(self.client.clone(), req),
502            span,
503        )
504        .await
505    }
506}
507
508// ================================================================
509// Groq Transcription API
510// ================================================================
511
/// The `whisper-large-v3` model. Used for audio transcription.
pub const WHISPER_LARGE_V3: &str = "whisper-large-v3";
/// The `whisper-large-v3-turbo` model. Used for audio transcription.
pub const WHISPER_LARGE_V3_TURBO: &str = "whisper-large-v3-turbo";
/// The `distil-whisper-large-v3-en` model. Used for audio transcription.
pub const DISTIL_WHISPER_LARGE_V3_EN: &str = "distil-whisper-large-v3-en";

/// Groq transcription model handle, generic over the HTTP backend.
#[derive(Clone)]
pub struct TranscriptionModel<T> {
    client: Client<T>,
    /// Name of the model (e.g.: whisper-large-v3)
    pub model: String,
}

impl<T> TranscriptionModel<T> {
    /// Create a transcription model handle bound to `client` and `model`.
    pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
        Self {
            client,
            model: model.into(),
        }
    }
}
531impl<T> transcription::TranscriptionModel for TranscriptionModel<T>
532where
533    T: HttpClientExt + Clone + Send + std::fmt::Debug + Default + 'static,
534{
535    type Response = TranscriptionResponse;
536
537    type Client = Client<T>;
538
539    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
540        Self::new(client.clone(), model)
541    }
542
543    async fn transcription(
544        &self,
545        request: transcription::TranscriptionRequest,
546    ) -> Result<
547        transcription::TranscriptionResponse<Self::Response>,
548        transcription::TranscriptionError,
549    > {
550        let data = request.data;
551
552        let mut body = MultipartForm::new()
553            .text("model", self.model.clone())
554            .part(Part::bytes("file", data).filename(request.filename.clone()));
555
556        if let Some(language) = request.language {
557            body = body.text("language", language);
558        }
559
560        if let Some(prompt) = request.prompt {
561            body = body.text("prompt", prompt.clone());
562        }
563
564        if let Some(ref temperature) = request.temperature {
565            body = body.text("temperature", temperature.to_string());
566        }
567
568        if let Some(ref additional_params) = request.additional_params {
569            for (key, value) in additional_params
570                .as_object()
571                .expect("Additional Parameters to OpenAI Transcription should be a map")
572            {
573                body = body.text(key.to_owned(), value.to_string());
574            }
575        }
576
577        let req = self
578            .client
579            .post("/audio/transcriptions")?
580            .body(body)
581            .unwrap();
582
583        let response = self.client.send_multipart::<Bytes>(req).await.unwrap();
584
585        let status = response.status();
586        let response_body = response.into_body().into_future().await?.to_vec();
587
588        if status.is_success() {
589            match serde_json::from_slice::<ApiResponse<TranscriptionResponse>>(&response_body)? {
590                ApiResponse::Ok(response) => response.try_into(),
591                ApiResponse::Err(api_error_response) => Err(TranscriptionError::ProviderError(
592                    api_error_response.message,
593                )),
594            }
595        } else {
596            Err(TranscriptionError::ProviderError(
597                String::from_utf8_lossy(&response_body).to_string(),
598            ))
599        }
600    }
601}
602
/// One streamed delta: either a reasoning fragment or message content
/// (text and/or tool-call fragments). `untagged` tries variants in order.
#[derive(Deserialize, Debug)]
#[serde(untagged)]
enum StreamingDelta {
    Reasoning {
        reasoning: String,
    },
    MessageContent {
        #[serde(default)]
        content: Option<String>,
        #[serde(default, deserialize_with = "json_utils::null_or_vec")]
        tool_calls: Vec<StreamingToolCall>,
    },
}

/// A single choice inside a streamed chunk.
#[derive(Deserialize, Debug)]
struct StreamingChoice {
    delta: StreamingDelta,
}

/// One SSE payload from Groq's streaming completions endpoint.
#[derive(Deserialize, Debug)]
struct StreamingCompletionChunk {
    choices: Vec<StreamingChoice>,
    // Present only on the final chunk when `stream_options.include_usage`
    // was requested.
    usage: Option<Usage>,
}
627
/// Final aggregate yielded once a Groq completion stream finishes.
#[derive(Clone, Deserialize, Serialize, Debug)]
pub struct StreamingCompletionResponse {
    pub usage: Usage,
}
632
633impl GetTokenUsage for StreamingCompletionResponse {
634    fn token_usage(&self) -> Option<crate::completion::Usage> {
635        let mut usage = crate::completion::Usage::new();
636
637        usage.input_tokens = self.usage.prompt_tokens as u64;
638        usage.total_tokens = self.usage.total_tokens as u64;
639        usage.output_tokens = self.usage.total_tokens as u64 - self.usage.prompt_tokens as u64;
640        usage.cached_input_tokens = self
641            .usage
642            .prompt_tokens_details
643            .as_ref()
644            .map(|d| d.cached_tokens as u64)
645            .unwrap_or(0);
646
647        Some(usage)
648    }
649}
650
651pub async fn send_compatible_streaming_request<T>(
652    client: T,
653    req: Request<Vec<u8>>,
654) -> Result<
655    crate::streaming::StreamingCompletionResponse<StreamingCompletionResponse>,
656    CompletionError,
657>
658where
659    T: HttpClientExt + Clone + 'static,
660{
661    let span = tracing::Span::current();
662
663    let mut event_source = GenericEventSource::new(client, req);
664
665    let stream = stream! {
666        let span = tracing::Span::current();
667        let mut final_usage = Usage {
668            prompt_tokens: 0,
669            total_tokens: 0,
670            prompt_tokens_details: None,
671        };
672
673        let mut text_response = String::new();
674
675        let mut calls: HashMap<usize, (String, String, String)> = HashMap::new();
676
677        while let Some(event_result) = event_source.next().await {
678            match event_result {
679                Ok(Event::Open) => {
680                    tracing::trace!("SSE connection opened");
681                    continue;
682                }
683
684                Ok(Event::Message(message)) => {
685                    let data_str = message.data.trim();
686
687                    let parsed = serde_json::from_str::<StreamingCompletionChunk>(data_str);
688                    let Ok(data) = parsed else {
689                        let err = parsed.unwrap_err();
690                        tracing::debug!("Couldn't parse SSE payload as StreamingCompletionChunk: {:?}", err);
691                        continue;
692                    };
693
694                    if let Some(choice) = data.choices.first() {
695                        match &choice.delta {
696                            StreamingDelta::Reasoning { reasoning } => {
697                                yield Ok(crate::streaming::RawStreamingChoice::ReasoningDelta {
698                                    id: None,
699                                    reasoning: reasoning.to_string(),
700                                });
701                            }
702
703                            StreamingDelta::MessageContent { content, tool_calls } => {
704                                // Handle tool calls
705                                for tool_call in tool_calls {
706                                    let function = &tool_call.function;
707
708                                    // Start of tool call
709                                    if function.name.as_ref().map(|s| !s.is_empty()).unwrap_or(false)
710                                        && empty_or_none(&function.arguments)
711                                    {
712                                        let id = tool_call.id.clone().unwrap_or_default();
713                                        let name = function.name.clone().unwrap();
714                                        calls.insert(tool_call.index, (id, name, String::new()));
715                                    }
716                                    // Continuation
717                                    else if function.name.as_ref().map(|s| s.is_empty()).unwrap_or(true)
718                                        && let Some(arguments) = &function.arguments
719                                        && !arguments.is_empty()
720                                    {
721                                        if let Some((id, name, existing_args)) = calls.get(&tool_call.index) {
722                                            let combined = format!("{}{}", existing_args, arguments);
723                                            calls.insert(tool_call.index, (id.clone(), name.clone(), combined));
724                                        } else {
725                                            tracing::debug!("Partial tool call received but tool call was never started.");
726                                        }
727                                    }
728                                    // Complete tool call
729                                    else {
730                                        let id = tool_call.id.clone().unwrap_or_default();
731                                        let name = function.name.clone().unwrap_or_default();
732                                        let arguments_str = function.arguments.clone().unwrap_or_default();
733
734                                        let Ok(arguments_json) = serde_json::from_str::<serde_json::Value>(&arguments_str) else {
735                                            tracing::debug!("Couldn't parse tool call args '{}'", arguments_str);
736                                            continue;
737                                        };
738
739                                        yield Ok(crate::streaming::RawStreamingChoice::ToolCall(
740                                            crate::streaming::RawStreamingToolCall::new(id, name, arguments_json)
741                                        ));
742                                    }
743                                }
744
745                                // Streamed content
746                                if let Some(content) = content {
747                                    text_response += content;
748                                    yield Ok(crate::streaming::RawStreamingChoice::Message(content.clone()));
749                                }
750                            }
751                        }
752                    }
753
754                    if let Some(usage) = data.usage {
755                        final_usage = usage.clone();
756                    }
757                }
758
759                Err(crate::http_client::Error::StreamEnded) => break,
760                Err(err) => {
761                    tracing::error!(?err, "SSE error");
762                    yield Err(CompletionError::ResponseError(err.to_string()));
763                    break;
764                }
765            }
766        }
767
768        event_source.close();
769
770        let mut tool_calls = Vec::new();
771        // Flush accumulated tool calls
772        for (_, (id, name, arguments)) in calls {
773            let Ok(arguments_json) = serde_json::from_str::<serde_json::Value>(&arguments) else {
774                continue;
775            };
776
777            tool_calls.push(rig::providers::openai::completion::ToolCall {
778                id: id.clone(),
779                r#type: ToolType::Function,
780                function: Function {
781                    name: name.clone(),
782                    arguments: arguments_json.clone()
783                }
784            });
785            yield Ok(crate::streaming::RawStreamingChoice::ToolCall(
786                crate::streaming::RawStreamingToolCall::new(id, name, arguments_json)
787            ));
788        }
789
790        let response_message = crate::providers::openai::completion::Message::Assistant {
791            content: vec![AssistantContent::Text { text: text_response }],
792            refusal: None,
793            audio: None,
794            name: None,
795            tool_calls
796        };
797
798        span.record("gen_ai.output.messages", serde_json::to_string(&vec![response_message]).unwrap());
799        span.record("gen_ai.usage.input_tokens", final_usage.prompt_tokens);
800        span.record("gen_ai.usage.output_tokens", final_usage.total_tokens - final_usage.prompt_tokens);
801        span.record(
802            "gen_ai.usage.cached_tokens",
803            final_usage
804                .prompt_tokens_details
805                .as_ref()
806                .map(|d| d.cached_tokens)
807                .unwrap_or(0),
808        );
809
810        // Final response
811        yield Ok(crate::streaming::RawStreamingChoice::FinalResponse(
812            StreamingCompletionResponse { usage: final_usage.clone() }
813        ));
814    }.instrument(span);
815
816    Ok(crate::streaming::StreamingCompletionResponse::stream(
817        Box::pin(stream),
818    ))
819}
820
#[cfg(test)]
mod tests {
    use crate::{
        OneOrMany,
        providers::{
            groq::{GroqAdditionalParameters, GroqCompletionRequest},
            openai::{Message, UserContent},
        },
    };

    /// Checks that `None`/empty fields are skipped during serialization and
    /// that the flattened additional parameters appear at the request's top
    /// level (not nested under an `additional_params` key).
    #[test]
    fn serialize_groq_request() {
        let additional_params = GroqAdditionalParameters {
            include_reasoning: Some(true),
            reasoning_format: Some(super::ReasoningFormat::Parsed),
            ..Default::default()
        };

        let groq = GroqCompletionRequest {
            model: "openai/gpt-120b-oss".to_string(),
            temperature: None,
            tool_choice: None,
            stream_options: None,
            tools: Vec::new(),
            messages: vec![Message::User {
                content: OneOrMany::one(UserContent::Text {
                    text: "Hello world!".to_string(),
                }),
                name: None,
            }],
            stream: false,
            additional_params: Some(additional_params),
        };

        let json = serde_json::to_value(&groq).unwrap();

        assert_eq!(
            json,
            serde_json::json!({
                "model": "openai/gpt-120b-oss",
                "messages": [
                    {
                        "role": "user",
                        "content": "Hello world!"
                    }
                ],
                "stream": false,
                "include_reasoning": true,
                "reasoning_format": "parsed"
            })
        )
    }
    /// Smoke test: both client construction paths succeed with a dummy key.
    #[test]
    fn test_client_initialization() {
        let _client =
            crate::providers::groq::Client::new("dummy-key").expect("Client::new() failed");
        let _client_from_builder = crate::providers::groq::Client::builder()
            .api_key("dummy-key")
            .build()
            .expect("Client::builder() failed");
    }
}