rig/providers/
groq.rs

//! Groq API client and Rig integration
//!
//! # Example
//! ```
//! use rig::providers::groq;
//!
//! let client = groq::Client::new("YOUR_API_KEY").unwrap();
//!
//! let llama = client.completion_model(groq::LLAMA_3_1_8B_INSTANT);
//! ```
11use bytes::Bytes;
12use http::Request;
13use std::collections::HashMap;
14use tracing::info_span;
15use tracing_futures::Instrument;
16
17use super::openai::{CompletionResponse, StreamingToolCall, TranscriptionResponse, Usage};
18use crate::client::{
19    self, BearerAuth, Capabilities, Capable, DebugExt, Nothing, Provider, ProviderBuilder,
20    ProviderClient,
21};
22use crate::completion::GetTokenUsage;
23use crate::http_client::sse::{Event, GenericEventSource};
24use crate::http_client::{self, HttpClientExt};
25use crate::json_utils::empty_or_none;
26use crate::providers::openai::{AssistantContent, Function, ToolType};
27use async_stream::stream;
28use futures::StreamExt;
29
30use crate::{
31    OneOrMany,
32    completion::{self, CompletionError, CompletionRequest},
33    json_utils,
34    message::{self, MessageError},
35    providers::openai::ToolDefinition,
36    transcription::{self, TranscriptionError},
37};
38use reqwest::multipart::Part;
39use serde::{Deserialize, Serialize};
40
// ================================================================
// Main Groq Client
// ================================================================
/// Base URL of Groq's OpenAI-compatible REST API; all request paths are joined onto this.
const GROQ_API_BASE_URL: &str = "https://api.groq.com/openai/v1";
45
/// Marker type implementing the Groq [`Provider`]; carries no state of its own.
#[derive(Debug, Default, Clone, Copy)]
pub struct GroqExt;
/// Builder marker for [`GroqExt`]; supplies the base URL and API-key type.
#[derive(Debug, Default, Clone, Copy)]
pub struct GroqBuilder;

// Groq authenticates with a standard `Authorization: Bearer <key>` header.
type GroqApiKey = BearerAuth;
52
impl Provider for GroqExt {
    type Builder = GroqBuilder;

    // Endpoint used to verify that the client / API key is usable.
    const VERIFY_PATH: &'static str = "/models";

    /// Build the provider extension. Groq needs no per-client state, so the
    /// builder contents are ignored and the unit `GroqExt` is returned.
    fn build<H>(
        _: &crate::client::ClientBuilder<
            Self::Builder,
            <Self::Builder as crate::client::ProviderBuilder>::ApiKey,
            H,
        >,
    ) -> http_client::Result<Self> {
        Ok(Self)
    }
}
68
// Capabilities exposed by the Groq provider: chat completions and audio
// transcription; no embeddings, image generation, or audio generation.
impl<H> Capabilities<H> for GroqExt {
    type Completion = Capable<CompletionModel<H>>;
    type Embeddings = Nothing;
    type Transcription = Capable<TranscriptionModel<H>>;
    #[cfg(feature = "image")]
    type ImageGeneration = Nothing;

    #[cfg(feature = "audio")]
    type AudioGeneration = Nothing;
}
79
// Default Debug formatting is sufficient; `GroqExt` is a unit struct.
impl DebugExt for GroqExt {}
81
impl ProviderBuilder for GroqBuilder {
    type Output = GroqExt;
    type ApiKey = GroqApiKey;

    // All request paths (e.g. "/chat/completions") resolve relative to this.
    const BASE_URL: &'static str = GROQ_API_BASE_URL;
}
88
/// Groq client over a pluggable HTTP backend (defaults to `reqwest`).
pub type Client<H = reqwest::Client> = client::Client<GroqExt, H>;
/// Builder for [`Client`]; the API key is supplied as a `String`.
pub type ClientBuilder<H = reqwest::Client> = client::ClientBuilder<GroqBuilder, String, H>;
91
92impl ProviderClient for Client {
93    type Input = String;
94
95    /// Create a new Groq client from the `GROQ_API_KEY` environment variable.
96    /// Panics if the environment variable is not set.
97    fn from_env() -> Self {
98        let api_key = std::env::var("GROQ_API_KEY").expect("GROQ_API_KEY not set");
99        Self::new(&api_key).unwrap()
100    }
101
102    fn from_val(input: Self::Input) -> Self {
103        Self::new(&input).unwrap()
104    }
105}
106
/// Error payload shape returned by the Groq API.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    message: String,
}

/// Either a successful payload `T` or a Groq error body. `untagged` makes
/// serde try `Ok(T)` first and fall back to the error shape.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    Ok(T),
    Err(ApiErrorResponse),
}
118
/// Wire-format chat message used in Groq requests and responses.
#[derive(Debug, Serialize, Deserialize)]
pub struct Message {
    /// One of "system", "user" or "assistant" (see the `TryFrom` impls below).
    pub role: String,
    /// Message text; note this serializes as `null` when `None` (no skip attribute).
    pub content: Option<String>,
    /// Parsed reasoning text returned by reasoning models; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub reasoning: Option<String>,
}
126
127impl Message {
128    fn system(preamble: &str) -> Self {
129        Self {
130            role: "system".to_string(),
131            content: Some(preamble.to_string()),
132            reasoning: None,
133        }
134    }
135}
136
137impl TryFrom<Message> for message::Message {
138    type Error = message::MessageError;
139
140    fn try_from(message: Message) -> Result<Self, Self::Error> {
141        match message.role.as_str() {
142            "user" => Ok(Self::User {
143                content: OneOrMany::one(
144                    message
145                        .content
146                        .map(|content| message::UserContent::text(&content))
147                        .ok_or_else(|| {
148                            message::MessageError::ConversionError("Empty user message".to_string())
149                        })?,
150                ),
151            }),
152            "assistant" => Ok(Self::Assistant {
153                id: None,
154                content: OneOrMany::one(
155                    message
156                        .content
157                        .map(|content| message::AssistantContent::text(&content))
158                        .ok_or_else(|| {
159                            message::MessageError::ConversionError(
160                                "Empty assistant message".to_string(),
161                            )
162                        })?,
163                ),
164            }),
165            _ => Err(message::MessageError::ConversionError(format!(
166                "Unknown role: {}",
167                message.role
168            ))),
169        }
170    }
171}
172
173impl TryFrom<message::Message> for Message {
174    type Error = message::MessageError;
175
176    fn try_from(message: message::Message) -> Result<Self, Self::Error> {
177        match message {
178            message::Message::User { content } => Ok(Self {
179                role: "user".to_string(),
180                content: content.iter().find_map(|c| match c {
181                    message::UserContent::Text(text) => Some(text.text.clone()),
182                    _ => None,
183                }),
184                reasoning: None,
185            }),
186            message::Message::Assistant { content, .. } => {
187                let mut text_content: Option<String> = None;
188                let mut groq_reasoning: Option<String> = None;
189
190                for c in content.iter() {
191                    match c {
192                        message::AssistantContent::Text(text) => {
193                            text_content = Some(
194                                text_content
195                                    .map(|mut existing| {
196                                        existing.push('\n');
197                                        existing.push_str(&text.text);
198                                        existing
199                                    })
200                                    .unwrap_or_else(|| text.text.clone()),
201                            );
202                        }
203                        message::AssistantContent::ToolCall(_tool_call) => {
204                            return Err(MessageError::ConversionError(
205                                "Tool calls do not exist on this message".into(),
206                            ));
207                        }
208                        message::AssistantContent::Reasoning(message::Reasoning {
209                            reasoning,
210                            ..
211                        }) => {
212                            groq_reasoning =
213                                Some(reasoning.first().cloned().unwrap_or(String::new()));
214                        }
215                        message::AssistantContent::Image(_) => {
216                            return Err(MessageError::ConversionError(
217                                "Ollama currently doesn't support images.".into(),
218                            ));
219                        }
220                    }
221                }
222
223                Ok(Self {
224                    role: "assistant".to_string(),
225                    content: text_content,
226                    reasoning: groq_reasoning,
227                })
228            }
229        }
230    }
231}
232
233// ================================================================
234// Groq Completion API
235// ================================================================
236
// Chat-completion model identifiers accepted by Groq's OpenAI-compatible API.

/// The `deepseek-r1-distill-llama-70b` model. Used for chat completion.
pub const DEEPSEEK_R1_DISTILL_LLAMA_70B: &str = "deepseek-r1-distill-llama-70b";
/// The `gemma2-9b-it` model. Used for chat completion.
pub const GEMMA2_9B_IT: &str = "gemma2-9b-it";
/// The `llama-3.1-8b-instant` model. Used for chat completion.
pub const LLAMA_3_1_8B_INSTANT: &str = "llama-3.1-8b-instant";
/// The `llama-3.2-11b-vision-preview` model. Used for chat completion.
pub const LLAMA_3_2_11B_VISION_PREVIEW: &str = "llama-3.2-11b-vision-preview";
/// The `llama-3.2-1b-preview` model. Used for chat completion.
pub const LLAMA_3_2_1B_PREVIEW: &str = "llama-3.2-1b-preview";
/// The `llama-3.2-3b-preview` model. Used for chat completion.
pub const LLAMA_3_2_3B_PREVIEW: &str = "llama-3.2-3b-preview";
/// The `llama-3.2-90b-vision-preview` model. Used for chat completion.
pub const LLAMA_3_2_90B_VISION_PREVIEW: &str = "llama-3.2-90b-vision-preview";
/// The `llama-3.2-70b-specdec` model. Used for chat completion.
pub const LLAMA_3_2_70B_SPECDEC: &str = "llama-3.2-70b-specdec";
/// The `llama-3.2-70b-versatile` model. Used for chat completion.
pub const LLAMA_3_2_70B_VERSATILE: &str = "llama-3.2-70b-versatile";
/// The `llama-guard-3-8b` model. Used for chat completion.
pub const LLAMA_GUARD_3_8B: &str = "llama-guard-3-8b";
/// The `llama3-70b-8192` model. Used for chat completion.
pub const LLAMA_3_70B_8192: &str = "llama3-70b-8192";
/// The `llama3-8b-8192` model. Used for chat completion.
pub const LLAMA_3_8B_8192: &str = "llama3-8b-8192";
/// The `mixtral-8x7b-32768` model. Used for chat completion.
pub const MIXTRAL_8X7B_32768: &str = "mixtral-8x7b-32768";
263
/// Requested format for model reasoning output. Serialized lowercase
/// ("parsed"). Presumably this asks Groq to return reasoning in the separate
/// `reasoning` field (see [`Message::reasoning`]) rather than inline in the
/// content — confirm against the Groq API docs.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ReasoningFormat {
    Parsed,
}
269
/// JSON body for Groq's `/chat/completions` endpoint.
#[derive(Debug, Serialize, Deserialize)]
pub(super) struct GroqCompletionRequest {
    model: String,
    pub messages: Vec<Message>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f64>,
    #[serde(skip_serializing_if = "Vec::is_empty")]
    tools: Vec<ToolDefinition>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_choice: Option<crate::providers::openai::completion::ToolChoice>,
    // Extra provider-specific parameters, flattened into the top-level JSON
    // object (used by `stream` to inject `stream` / `stream_options`).
    #[serde(flatten, skip_serializing_if = "Option::is_none")]
    pub additional_params: Option<serde_json::Value>,
    // Always serialized; set to `Parsed` by the TryFrom constructor below.
    reasoning_format: ReasoningFormat,
}
284
285impl TryFrom<(&str, CompletionRequest)> for GroqCompletionRequest {
286    type Error = CompletionError;
287
288    fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
289        // Build up the order of messages (context, chat_history, prompt)
290        let mut partial_history = vec![];
291        if let Some(docs) = req.normalized_documents() {
292            partial_history.push(docs);
293        }
294        partial_history.extend(req.chat_history);
295
296        // Add preamble to chat history (if available)
297        let mut full_history: Vec<Message> = match &req.preamble {
298            Some(preamble) => vec![Message::system(preamble)],
299            None => vec![],
300        };
301
302        // Convert and extend the rest of the history
303        full_history.extend(
304            partial_history
305                .into_iter()
306                .map(message::Message::try_into)
307                .collect::<Result<Vec<Message>, _>>()?,
308        );
309
310        let tool_choice = req
311            .tool_choice
312            .clone()
313            .map(crate::providers::openai::ToolChoice::try_from)
314            .transpose()?;
315
316        Ok(Self {
317            model: model.to_string(),
318            messages: full_history,
319            temperature: req.temperature,
320            tools: req
321                .tools
322                .clone()
323                .into_iter()
324                .map(ToolDefinition::from)
325                .collect::<Vec<_>>(),
326            tool_choice,
327            additional_params: req.additional_params,
328            reasoning_format: ReasoningFormat::Parsed,
329        })
330    }
331}
332
/// A Groq chat-completion model bound to a specific [`Client`].
#[derive(Clone, Debug)]
pub struct CompletionModel<T = reqwest::Client> {
    client: Client<T>,
    /// Name of the model (e.g.: deepseek-r1-distill-llama-70b)
    pub model: String,
}
339
340impl<T> CompletionModel<T> {
341    pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
342        Self {
343            client,
344            model: model.into(),
345        }
346    }
347}
348
349impl<T> completion::CompletionModel for CompletionModel<T>
350where
351    T: HttpClientExt + Clone + Send + std::fmt::Debug + Default + 'static,
352{
353    type Response = CompletionResponse;
354    type StreamingResponse = StreamingCompletionResponse;
355
356    type Client = Client<T>;
357
358    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
359        Self::new(client.clone(), model)
360    }
361
362    #[cfg_attr(feature = "worker", worker::send)]
363    async fn completion(
364        &self,
365        completion_request: CompletionRequest,
366    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
367        let span = if tracing::Span::current().is_disabled() {
368            info_span!(
369                target: "rig::completions",
370                "chat",
371                gen_ai.operation.name = "chat",
372                gen_ai.provider.name = "groq",
373                gen_ai.request.model = self.model,
374                gen_ai.system_instructions = tracing::field::Empty,
375                gen_ai.response.id = tracing::field::Empty,
376                gen_ai.response.model = tracing::field::Empty,
377                gen_ai.usage.output_tokens = tracing::field::Empty,
378                gen_ai.usage.input_tokens = tracing::field::Empty,
379            )
380        } else {
381            tracing::Span::current()
382        };
383
384        span.record("gen_ai.system_instructions", &completion_request.preamble);
385
386        let request = GroqCompletionRequest::try_from((self.model.as_ref(), completion_request))?;
387
388        if tracing::enabled!(tracing::Level::TRACE) {
389            tracing::trace!(target: "rig::completions",
390                "Groq completion request: {}",
391                serde_json::to_string_pretty(&request)?
392            );
393        }
394
395        let body = serde_json::to_vec(&request)?;
396        let req = self
397            .client
398            .post("/chat/completions")?
399            .body(body)
400            .map_err(|e| http_client::Error::Instance(e.into()))?;
401
402        let async_block = async move {
403            let response = self.client.send::<_, Bytes>(req).await?;
404            let status = response.status();
405            let response_body = response.into_body().into_future().await?.to_vec();
406
407            if status.is_success() {
408                match serde_json::from_slice::<ApiResponse<CompletionResponse>>(&response_body)? {
409                    ApiResponse::Ok(response) => {
410                        let span = tracing::Span::current();
411                        span.record("gen_ai.response.id", response.id.clone());
412                        span.record("gen_ai.response.model_name", response.model.clone());
413                        if let Some(ref usage) = response.usage {
414                            span.record("gen_ai.usage.input_tokens", usage.prompt_tokens);
415                            span.record(
416                                "gen_ai.usage.output_tokens",
417                                usage.total_tokens - usage.prompt_tokens,
418                            );
419                        }
420
421                        if tracing::enabled!(tracing::Level::TRACE) {
422                            tracing::trace!(target: "rig::completions",
423                                "Groq completion response: {}",
424                                serde_json::to_string_pretty(&response)?
425                            );
426                        }
427
428                        response.try_into()
429                    }
430                    ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
431                }
432            } else {
433                Err(CompletionError::ProviderError(
434                    String::from_utf8_lossy(&response_body).to_string(),
435                ))
436            }
437        };
438
439        tracing::Instrument::instrument(async_block, span).await
440    }
441
442    #[cfg_attr(feature = "worker", worker::send)]
443    async fn stream(
444        &self,
445        request: CompletionRequest,
446    ) -> Result<
447        crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
448        CompletionError,
449    > {
450        let span = if tracing::Span::current().is_disabled() {
451            info_span!(
452                target: "rig::completions",
453                "chat_streaming",
454                gen_ai.operation.name = "chat_streaming",
455                gen_ai.provider.name = "groq",
456                gen_ai.request.model = self.model,
457                gen_ai.system_instructions = tracing::field::Empty,
458                gen_ai.response.id = tracing::field::Empty,
459                gen_ai.response.model = tracing::field::Empty,
460                gen_ai.usage.output_tokens = tracing::field::Empty,
461                gen_ai.usage.input_tokens = tracing::field::Empty,
462            )
463        } else {
464            tracing::Span::current()
465        };
466
467        span.record("gen_ai.system_instructions", &request.preamble);
468
469        let mut request = GroqCompletionRequest::try_from((self.model.as_ref(), request))?;
470
471        let params = json_utils::merge(
472            request.additional_params.unwrap_or(serde_json::json!({})),
473            serde_json::json!({"stream": true, "stream_options": {"include_usage": true} }),
474        );
475
476        request.additional_params = Some(params);
477
478        if tracing::enabled!(tracing::Level::TRACE) {
479            tracing::trace!(target: "rig::completions",
480                "Groq streaming completion request: {}",
481                serde_json::to_string_pretty(&request)?
482            );
483        }
484
485        let body = serde_json::to_vec(&request)?;
486        let req = self
487            .client
488            .post("/chat/completions")?
489            .body(body)
490            .map_err(|e| http_client::Error::Instance(e.into()))?;
491
492        tracing::Instrument::instrument(
493            send_compatible_streaming_request(self.client.clone(), req),
494            span,
495        )
496        .await
497    }
498}
499
500// ================================================================
501// Groq Transcription API
502// ================================================================
503
/// The `whisper-large-v3` model. Used for audio transcription.
pub const WHISPER_LARGE_V3: &str = "whisper-large-v3";
/// The `whisper-large-v3-turbo` model. Used for audio transcription.
pub const WHISPER_LARGE_V3_TURBO: &str = "whisper-large-v3-turbo";
/// The `distil-whisper-large-v3-en` model. Used for audio transcription.
pub const DISTIL_WHISPER_LARGE_V3_EN: &str = "distil-whisper-large-v3-en";
507
/// A Groq transcription model bound to a specific [`Client`].
#[derive(Clone)]
pub struct TranscriptionModel<T> {
    client: Client<T>,
    /// Name of the model (e.g.: whisper-large-v3)
    pub model: String,
}
514
515impl<T> TranscriptionModel<T> {
516    pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
517        Self {
518            client,
519            model: model.into(),
520        }
521    }
522}
523impl<T> transcription::TranscriptionModel for TranscriptionModel<T>
524where
525    T: HttpClientExt + Clone + Send + std::fmt::Debug + Default + 'static,
526{
527    type Response = TranscriptionResponse;
528
529    type Client = Client<T>;
530
531    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
532        Self::new(client.clone(), model)
533    }
534
535    #[cfg_attr(feature = "worker", worker::send)]
536    async fn transcription(
537        &self,
538        request: transcription::TranscriptionRequest,
539    ) -> Result<
540        transcription::TranscriptionResponse<Self::Response>,
541        transcription::TranscriptionError,
542    > {
543        let data = request.data;
544
545        let mut body = reqwest::multipart::Form::new()
546            .text("model", self.model.clone())
547            .part(
548                "file",
549                Part::bytes(data).file_name(request.filename.clone()),
550            );
551
552        if let Some(language) = request.language {
553            body = body.text("language", language);
554        }
555
556        if let Some(prompt) = request.prompt {
557            body = body.text("prompt", prompt.clone());
558        }
559
560        if let Some(ref temperature) = request.temperature {
561            body = body.text("temperature", temperature.to_string());
562        }
563
564        if let Some(ref additional_params) = request.additional_params {
565            for (key, value) in additional_params
566                .as_object()
567                .expect("Additional Parameters to OpenAI Transcription should be a map")
568            {
569                body = body.text(key.to_owned(), value.to_string());
570            }
571        }
572
573        let req = self
574            .client
575            .post("/audio/transcriptions")?
576            .body(body)
577            .unwrap();
578
579        let response = self.client.send_multipart::<Bytes>(req).await.unwrap();
580
581        let status = response.status();
582        let response_body = response.into_body().into_future().await?.to_vec();
583
584        if status.is_success() {
585            match serde_json::from_slice::<ApiResponse<TranscriptionResponse>>(&response_body)? {
586                ApiResponse::Ok(response) => response.try_into(),
587                ApiResponse::Err(api_error_response) => Err(TranscriptionError::ProviderError(
588                    api_error_response.message,
589                )),
590            }
591        } else {
592            Err(TranscriptionError::ProviderError(
593                String::from_utf8_lossy(&response_body).to_string(),
594            ))
595        }
596    }
597}
598
/// A single streamed delta. `untagged` deserialization tries `Reasoning`
/// first (a chunk carrying a `reasoning` field), then falls back to regular
/// message-content / tool-call deltas.
#[derive(Deserialize, Debug)]
#[serde(untagged)]
enum StreamingDelta {
    Reasoning {
        reasoning: String,
    },
    MessageContent {
        #[serde(default)]
        content: Option<String>,
        // `null` or missing tool_calls deserialize to an empty Vec.
        #[serde(default, deserialize_with = "json_utils::null_or_vec")]
        tool_calls: Vec<StreamingToolCall>,
    },
}
612
/// One choice within a streamed chunk; only the delta is consumed here.
#[derive(Deserialize, Debug)]
struct StreamingChoice {
    delta: StreamingDelta,
}

/// A single SSE chunk of a streaming completion. `usage` is typically only
/// present on the terminal chunk (requested via `stream_options.include_usage`
/// in `stream`).
#[derive(Deserialize, Debug)]
struct StreamingCompletionChunk {
    choices: Vec<StreamingChoice>,
    usage: Option<Usage>,
}
623
/// Final payload emitted at the end of a Groq completion stream, carrying the
/// token usage reported by the provider.
#[derive(Clone, Deserialize, Serialize, Debug)]
pub struct StreamingCompletionResponse {
    pub usage: Usage,
}
628
629impl GetTokenUsage for StreamingCompletionResponse {
630    fn token_usage(&self) -> Option<crate::completion::Usage> {
631        let mut usage = crate::completion::Usage::new();
632
633        usage.input_tokens = self.usage.prompt_tokens as u64;
634        usage.total_tokens = self.usage.total_tokens as u64;
635        usage.output_tokens = self.usage.total_tokens as u64 - self.usage.prompt_tokens as u64;
636
637        Some(usage)
638    }
639}
640
641pub async fn send_compatible_streaming_request<T>(
642    client: T,
643    req: Request<Vec<u8>>,
644) -> Result<
645    crate::streaming::StreamingCompletionResponse<StreamingCompletionResponse>,
646    CompletionError,
647>
648where
649    T: HttpClientExt + Clone + 'static,
650{
651    let span = tracing::Span::current();
652
653    let mut event_source = GenericEventSource::new(client, req);
654
655    let stream = stream! {
656        let span = tracing::Span::current();
657        let mut final_usage = Usage {
658            prompt_tokens: 0,
659            total_tokens: 0
660        };
661
662        let mut text_response = String::new();
663
664        let mut calls: HashMap<usize, (String, String, String)> = HashMap::new();
665
666        while let Some(event_result) = event_source.next().await {
667            match event_result {
668                Ok(Event::Open) => {
669                    tracing::trace!("SSE connection opened");
670                    continue;
671                }
672
673                Ok(Event::Message(message)) => {
674                    let data_str = message.data.trim();
675
676                    let parsed = serde_json::from_str::<StreamingCompletionChunk>(data_str);
677                    let Ok(data) = parsed else {
678                        let err = parsed.unwrap_err();
679                        tracing::debug!("Couldn't parse SSE payload as StreamingCompletionChunk: {:?}", err);
680                        continue;
681                    };
682
683                    if let Some(choice) = data.choices.first() {
684                        match &choice.delta {
685                            StreamingDelta::Reasoning { reasoning } => {
686                                yield Ok(crate::streaming::RawStreamingChoice::Reasoning {
687                                    id: None,
688                                    reasoning: reasoning.to_string(),
689                                    signature: None,
690                                });
691                            }
692
693                            StreamingDelta::MessageContent { content, tool_calls } => {
694                                // Handle tool calls
695                                for tool_call in tool_calls {
696                                    let function = &tool_call.function;
697
698                                    // Start of tool call
699                                    if function.name.as_ref().map(|s| !s.is_empty()).unwrap_or(false)
700                                        && empty_or_none(&function.arguments)
701                                    {
702                                        let id = tool_call.id.clone().unwrap_or_default();
703                                        let name = function.name.clone().unwrap();
704                                        calls.insert(tool_call.index, (id, name, String::new()));
705                                    }
706                                    // Continuation
707                                    else if function.name.as_ref().map(|s| s.is_empty()).unwrap_or(true)
708                                        && let Some(arguments) = &function.arguments
709                                        && !arguments.is_empty()
710                                    {
711                                        if let Some((id, name, existing_args)) = calls.get(&tool_call.index) {
712                                            let combined = format!("{}{}", existing_args, arguments);
713                                            calls.insert(tool_call.index, (id.clone(), name.clone(), combined));
714                                        } else {
715                                            tracing::debug!("Partial tool call received but tool call was never started.");
716                                        }
717                                    }
718                                    // Complete tool call
719                                    else {
720                                        let id = tool_call.id.clone().unwrap_or_default();
721                                        let name = function.name.clone().unwrap_or_default();
722                                        let arguments_str = function.arguments.clone().unwrap_or_default();
723
724                                        let Ok(arguments_json) = serde_json::from_str::<serde_json::Value>(&arguments_str) else {
725                                            tracing::debug!("Couldn't parse tool call args '{}'", arguments_str);
726                                            continue;
727                                        };
728
729                                        yield Ok(crate::streaming::RawStreamingChoice::ToolCall {
730                                            id,
731                                            name,
732                                            arguments: arguments_json,
733                                            call_id: None
734                                        });
735                                    }
736                                }
737
738                                // Streamed content
739                                if let Some(content) = content {
740                                    text_response += content;
741                                    yield Ok(crate::streaming::RawStreamingChoice::Message(content.clone()));
742                                }
743                            }
744                        }
745                    }
746
747                    if let Some(usage) = data.usage {
748                        final_usage = usage.clone();
749                    }
750                }
751
752                Err(crate::http_client::Error::StreamEnded) => break,
753                Err(err) => {
754                    tracing::error!(?err, "SSE error");
755                    yield Err(CompletionError::ResponseError(err.to_string()));
756                    break;
757                }
758            }
759        }
760
761        event_source.close();
762
763        let mut tool_calls = Vec::new();
764        // Flush accumulated tool calls
765        for (_, (id, name, arguments)) in calls {
766            let Ok(arguments_json) = serde_json::from_str::<serde_json::Value>(&arguments) else {
767                continue;
768            };
769
770            tool_calls.push(rig::providers::openai::completion::ToolCall {
771                id: id.clone(),
772                r#type: ToolType::Function,
773                function: Function {
774                    name: name.clone(),
775                    arguments: arguments_json.clone()
776                }
777            });
778            yield Ok(crate::streaming::RawStreamingChoice::ToolCall {
779                id,
780                name,
781                arguments: arguments_json,
782                call_id: None,
783            });
784        }
785
786        let response_message = crate::providers::openai::completion::Message::Assistant {
787            content: vec![AssistantContent::Text { text: text_response }],
788            refusal: None,
789            audio: None,
790            name: None,
791            tool_calls
792        };
793
794        span.record("gen_ai.output.messages", serde_json::to_string(&vec![response_message]).unwrap());
795        span.record("gen_ai.usage.input_tokens", final_usage.prompt_tokens);
796        span.record("gen_ai.usage.output_tokens", final_usage.total_tokens - final_usage.prompt_tokens);
797
798        // Final response
799        yield Ok(crate::streaming::RawStreamingChoice::FinalResponse(
800            StreamingCompletionResponse { usage: final_usage.clone() }
801        ));
802    }.instrument(span);
803
804    Ok(crate::streaming::StreamingCompletionResponse::stream(
805        Box::pin(stream),
806    ))
807}