// rig/providers/groq.rs

//! Groq API client and Rig integration
//!
//! # Example
//! ```
//! use rig::providers::groq;
//!
//! let client = groq::Client::new("YOUR_API_KEY").unwrap();
//!
//! let llama = client.completion_model(groq::LLAMA_3_1_8B_INSTANT);
//! ```
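//!
//! Transcription is also supported (a minimal sketch, using the constructor
//! defined in this module):
//! ```
//! use rig::providers::groq;
//!
//! let client = groq::Client::new("YOUR_API_KEY").unwrap();
//! let whisper = groq::TranscriptionModel::new(client, groq::WHISPER_LARGE_V3);
//! ```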
use bytes::Bytes;
use http::Request;
use std::collections::HashMap;
use tracing::info_span;
use tracing_futures::Instrument;

use super::openai::{CompletionResponse, StreamingToolCall, TranscriptionResponse, Usage};
use crate::client::{
    self, BearerAuth, Capabilities, Capable, DebugExt, Nothing, Provider, ProviderBuilder,
    ProviderClient,
};
use crate::completion::GetTokenUsage;
use crate::http_client::sse::{Event, GenericEventSource};
use crate::http_client::{self, HttpClientExt};
use crate::json_utils::empty_or_none;
use crate::providers::openai::{AssistantContent, Function, ToolType};
use async_stream::stream;
use futures::StreamExt;

use crate::{
    OneOrMany,
    completion::{self, CompletionError, CompletionRequest},
    json_utils,
    message::{self, MessageError},
    providers::openai::ToolDefinition,
    transcription::{self, TranscriptionError},
};
use reqwest::multipart::Part;
use serde::{Deserialize, Serialize};

// ================================================================
// Main Groq Client
// ================================================================
const GROQ_API_BASE_URL: &str = "https://api.groq.com/openai/v1";

#[derive(Debug, Default, Clone, Copy)]
pub struct GroqExt;
#[derive(Debug, Default, Clone, Copy)]
pub struct GroqBuilder;

type GroqApiKey = BearerAuth;

impl Provider for GroqExt {
    type Builder = GroqBuilder;

    const VERIFY_PATH: &'static str = "/models";

    fn build<H>(
        _: &crate::client::ClientBuilder<
            Self::Builder,
            <Self::Builder as crate::client::ProviderBuilder>::ApiKey,
            H,
        >,
    ) -> http_client::Result<Self> {
        Ok(Self)
    }
}

impl<H> Capabilities<H> for GroqExt {
    type Completion = Capable<CompletionModel<H>>;
    type Embeddings = Nothing;
    type Transcription = Capable<TranscriptionModel<H>>;
    #[cfg(feature = "image")]
    type ImageGeneration = Nothing;

    #[cfg(feature = "audio")]
    type AudioGeneration = Nothing;
}

impl DebugExt for GroqExt {}

impl ProviderBuilder for GroqBuilder {
    type Output = GroqExt;
    type ApiKey = GroqApiKey;

    const BASE_URL: &'static str = GROQ_API_BASE_URL;
}

pub type Client<H = reqwest::Client> = client::Client<GroqExt, H>;
pub type ClientBuilder<H = reqwest::Client> = client::ClientBuilder<GroqBuilder, String, H>;

impl ProviderClient for Client {
    type Input = String;

    /// Create a new Groq client from the `GROQ_API_KEY` environment variable.
    /// Panics if the environment variable is not set.
    fn from_env() -> Self {
        let api_key = std::env::var("GROQ_API_KEY").expect("GROQ_API_KEY not set");
        Self::new(&api_key).unwrap()
    }

    fn from_val(input: Self::Input) -> Self {
        Self::new(&input).unwrap()
    }
}
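
// Example (sketch; the "gsk_..." key value is illustrative only):
// let client = Client::from_env(); // requires GROQ_API_KEY to be set
// let client = Client::new("gsk_...").unwrap();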

#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    message: String,
}

#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    Ok(T),
    Err(ApiErrorResponse),
}

#[derive(Debug, Serialize, Deserialize)]
pub struct Message {
    pub role: String,
    pub content: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub reasoning: Option<String>,
}

impl Message {
    fn system(preamble: &str) -> Self {
        Self {
            role: "system".to_string(),
            content: Some(preamble.to_string()),
            reasoning: None,
        }
    }
}

impl TryFrom<Message> for message::Message {
    type Error = message::MessageError;

    fn try_from(message: Message) -> Result<Self, Self::Error> {
        match message.role.as_str() {
            "user" => Ok(Self::User {
                content: OneOrMany::one(
                    message
                        .content
                        .map(|content| message::UserContent::text(&content))
                        .ok_or_else(|| {
                            message::MessageError::ConversionError("Empty user message".to_string())
                        })?,
                ),
            }),
            "assistant" => Ok(Self::Assistant {
                id: None,
                content: OneOrMany::one(
                    message
                        .content
                        .map(|content| message::AssistantContent::text(&content))
                        .ok_or_else(|| {
                            message::MessageError::ConversionError(
                                "Empty assistant message".to_string(),
                            )
                        })?,
                ),
            }),
            _ => Err(message::MessageError::ConversionError(format!(
                "Unknown role: {}",
                message.role
            ))),
        }
    }
}

impl TryFrom<message::Message> for Message {
    type Error = message::MessageError;

    fn try_from(message: message::Message) -> Result<Self, Self::Error> {
        match message {
            message::Message::User { content } => Ok(Self {
                role: "user".to_string(),
                content: content.iter().find_map(|c| match c {
                    message::UserContent::Text(text) => Some(text.text.clone()),
                    _ => None,
                }),
                reasoning: None,
            }),
            message::Message::Assistant { content, .. } => {
                let mut text_content: Option<String> = None;
                let mut groq_reasoning: Option<String> = None;

                for c in content.iter() {
                    match c {
                        message::AssistantContent::Text(text) => {
                            text_content = Some(
                                text_content
                                    .map(|mut existing| {
                                        existing.push('\n');
                                        existing.push_str(&text.text);
                                        existing
                                    })
                                    .unwrap_or_else(|| text.text.clone()),
                            );
                        }
                        message::AssistantContent::ToolCall(_tool_call) => {
                            return Err(MessageError::ConversionError(
                                "Tool calls cannot be converted to a Groq message".into(),
                            ));
                        }
                        message::AssistantContent::Reasoning(message::Reasoning {
                            reasoning,
                            ..
                        }) => {
                            groq_reasoning =
                                Some(reasoning.first().cloned().unwrap_or(String::new()));
                        }
                        message::AssistantContent::Image(_) => {
                            return Err(MessageError::ConversionError(
                                "Groq currently doesn't support images.".into(),
                            ));
                        }
                    }
                }

                Ok(Self {
                    role: "assistant".to_string(),
                    content: text_content,
                    reasoning: groq_reasoning,
                })
            }
        }
    }
}

// ================================================================
// Groq Completion API
// ================================================================

/// The `deepseek-r1-distill-llama-70b` model. Used for chat completion.
pub const DEEPSEEK_R1_DISTILL_LLAMA_70B: &str = "deepseek-r1-distill-llama-70b";
/// The `gemma2-9b-it` model. Used for chat completion.
pub const GEMMA2_9B_IT: &str = "gemma2-9b-it";
/// The `llama-3.1-8b-instant` model. Used for chat completion.
pub const LLAMA_3_1_8B_INSTANT: &str = "llama-3.1-8b-instant";
/// The `llama-3.2-11b-vision-preview` model. Used for chat completion.
pub const LLAMA_3_2_11B_VISION_PREVIEW: &str = "llama-3.2-11b-vision-preview";
/// The `llama-3.2-1b-preview` model. Used for chat completion.
pub const LLAMA_3_2_1B_PREVIEW: &str = "llama-3.2-1b-preview";
/// The `llama-3.2-3b-preview` model. Used for chat completion.
pub const LLAMA_3_2_3B_PREVIEW: &str = "llama-3.2-3b-preview";
/// The `llama-3.2-90b-vision-preview` model. Used for chat completion.
pub const LLAMA_3_2_90B_VISION_PREVIEW: &str = "llama-3.2-90b-vision-preview";
/// The `llama-3.2-70b-specdec` model. Used for chat completion.
pub const LLAMA_3_2_70B_SPECDEC: &str = "llama-3.2-70b-specdec";
/// The `llama-3.2-70b-versatile` model. Used for chat completion.
pub const LLAMA_3_2_70B_VERSATILE: &str = "llama-3.2-70b-versatile";
/// The `llama-guard-3-8b` model. Used for chat completion.
pub const LLAMA_GUARD_3_8B: &str = "llama-guard-3-8b";
/// The `llama3-70b-8192` model. Used for chat completion.
pub const LLAMA_3_70B_8192: &str = "llama3-70b-8192";
/// The `llama3-8b-8192` model. Used for chat completion.
pub const LLAMA_3_8B_8192: &str = "llama3-8b-8192";
/// The `mixtral-8x7b-32768` model. Used for chat completion.
pub const MIXTRAL_8X7B_32768: &str = "mixtral-8x7b-32768";

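/// Format in which Groq should return the model's reasoning. With `Parsed`
/// (serialized as `"parsed"`), Groq places the chain of thought in a separate
/// `reasoning` field on the message instead of inline in `content`, which is
/// what [`Message::reasoning`] captures.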
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ReasoningFormat {
    Parsed,
}

#[derive(Debug, Serialize, Deserialize)]
pub(super) struct GroqCompletionRequest {
    model: String,
    pub messages: Vec<Message>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f64>,
    #[serde(skip_serializing_if = "Vec::is_empty")]
    tools: Vec<ToolDefinition>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_choice: Option<crate::providers::openai::completion::ToolChoice>,
    #[serde(flatten, skip_serializing_if = "Option::is_none")]
    pub additional_params: Option<serde_json::Value>,
    reasoning_format: ReasoningFormat,
}
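
// Illustrative request body produced by this struct (values are examples only):
// {
//   "model": "llama-3.1-8b-instant",
//   "messages": [{ "role": "user", "content": "Hello" }],
//   "temperature": 0.7,
//   "reasoning_format": "parsed"
// }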

impl TryFrom<(&str, CompletionRequest)> for GroqCompletionRequest {
    type Error = CompletionError;

    fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
        // Build up the order of messages (context, chat_history, prompt)
        let mut partial_history = vec![];
        if let Some(docs) = req.normalized_documents() {
            partial_history.push(docs);
        }
        partial_history.extend(req.chat_history);

        // Add preamble to chat history (if available)
        let mut full_history: Vec<Message> = match &req.preamble {
            Some(preamble) => vec![Message::system(preamble)],
            None => vec![],
        };

        // Convert and extend the rest of the history
        full_history.extend(
            partial_history
                .into_iter()
                .map(message::Message::try_into)
                .collect::<Result<Vec<Message>, _>>()?,
        );

        let tool_choice = req
            .tool_choice
            .clone()
            .map(crate::providers::openai::ToolChoice::try_from)
            .transpose()?;

        Ok(Self {
            model: model.to_string(),
            messages: full_history,
            temperature: req.temperature,
            tools: req
                .tools
                .clone()
                .into_iter()
                .map(ToolDefinition::from)
                .collect::<Vec<_>>(),
            tool_choice,
            additional_params: req.additional_params,
            reasoning_format: ReasoningFormat::Parsed,
        })
    }
}

#[derive(Clone, Debug)]
pub struct CompletionModel<T = reqwest::Client> {
    client: Client<T>,
    /// Name of the model (e.g.: deepseek-r1-distill-llama-70b)
    pub model: String,
}

impl<T> CompletionModel<T> {
    pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
        Self {
            client,
            model: model.into(),
        }
    }
}

impl<T> completion::CompletionModel for CompletionModel<T>
where
    T: HttpClientExt + Clone + Send + std::fmt::Debug + Default + 'static,
{
    type Response = CompletionResponse;
    type StreamingResponse = StreamingCompletionResponse;

    type Client = Client<T>;

    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
        Self::new(client.clone(), model)
    }

    #[cfg_attr(feature = "worker", worker::send)]
    async fn completion(
        &self,
        completion_request: CompletionRequest,
    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
        let preamble = completion_request.preamble.clone();

        let request = GroqCompletionRequest::try_from((self.model.as_ref(), completion_request))?;
        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat",
                gen_ai.operation.name = "chat",
                gen_ai.provider.name = "groq",
                gen_ai.request.model = self.model,
                gen_ai.system_instructions = preamble,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.input.messages = serde_json::to_string(&request.messages)?,
                gen_ai.output.messages = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        let body = serde_json::to_vec(&request)?;
        let req = self
            .client
            .post("/chat/completions")?
            .body(body)
            .map_err(|e| http_client::Error::Instance(e.into()))?;

        let async_block = async move {
            let response = self.client.send::<_, Bytes>(req).await?;
            let status = response.status();
            let response_body = response.into_body().into_future().await?.to_vec();

            if status.is_success() {
                match serde_json::from_slice::<ApiResponse<CompletionResponse>>(&response_body)? {
                    ApiResponse::Ok(response) => {
                        let span = tracing::Span::current();
                        span.record("gen_ai.response.id", response.id.clone());
                        span.record("gen_ai.response.model", response.model.clone());
                        span.record(
                            "gen_ai.output.messages",
                            serde_json::to_string(&response.choices)?,
                        );
                        if let Some(ref usage) = response.usage {
                            span.record("gen_ai.usage.input_tokens", usage.prompt_tokens);
                            span.record(
                                "gen_ai.usage.output_tokens",
                                usage.total_tokens - usage.prompt_tokens,
                            );
                        }
                        response.try_into()
                    }
                    ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
                }
            } else {
                Err(CompletionError::ProviderError(
                    String::from_utf8_lossy(&response_body).to_string(),
                ))
            }
        };

        tracing::Instrument::instrument(async_block, span).await
    }

    #[cfg_attr(feature = "worker", worker::send)]
    async fn stream(
        &self,
        request: CompletionRequest,
    ) -> Result<
        crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
        CompletionError,
    > {
        let preamble = request.preamble.clone();
        let mut request = GroqCompletionRequest::try_from((self.model.as_ref(), request))?;

        let params = json_utils::merge(
            request.additional_params.unwrap_or(serde_json::json!({})),
            serde_json::json!({"stream": true, "stream_options": {"include_usage": true} }),
        );

        request.additional_params = Some(params);

        let body = serde_json::to_vec(&request)?;
        let req = self
            .client
            .post("/chat/completions")?
            .body(body)
            .map_err(|e| http_client::Error::Instance(e.into()))?;

        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat_streaming",
                gen_ai.operation.name = "chat_streaming",
                gen_ai.provider.name = "groq",
                gen_ai.request.model = self.model,
                gen_ai.system_instructions = preamble,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.input.messages = serde_json::to_string(&request.messages)?,
                gen_ai.output.messages = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        tracing::Instrument::instrument(
            send_compatible_streaming_request(self.client.http_client().clone(), req),
            span,
        )
        .await
    }
}

// ================================================================
// Groq Transcription API
// ================================================================

pub const WHISPER_LARGE_V3: &str = "whisper-large-v3";
pub const WHISPER_LARGE_V3_TURBO: &str = "whisper-large-v3-turbo";
pub const DISTIL_WHISPER_LARGE_V3_EN: &str = "distil-whisper-large-v3-en";

#[derive(Clone)]
pub struct TranscriptionModel<T> {
    client: Client<T>,
    /// Name of the model (e.g.: whisper-large-v3)
    pub model: String,
}

impl<T> TranscriptionModel<T> {
    pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
        Self {
            client,
            model: model.into(),
        }
    }
}

impl<T> transcription::TranscriptionModel for TranscriptionModel<T>
where
    T: HttpClientExt + Clone + Send + std::fmt::Debug + Default + 'static,
{
    type Response = TranscriptionResponse;

    type Client = Client<T>;

    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
        Self::new(client.clone(), model)
    }

    #[cfg_attr(feature = "worker", worker::send)]
    async fn transcription(
        &self,
        request: transcription::TranscriptionRequest,
    ) -> Result<
        transcription::TranscriptionResponse<Self::Response>,
        transcription::TranscriptionError,
    > {
        let data = request.data;

        let mut body = reqwest::multipart::Form::new()
            .text("model", self.model.clone())
            .part(
                "file",
                Part::bytes(data).file_name(request.filename.clone()),
            );

        if let Some(language) = request.language {
            body = body.text("language", language);
        }

        if let Some(prompt) = request.prompt {
            body = body.text("prompt", prompt.clone());
        }

        if let Some(ref temperature) = request.temperature {
            body = body.text("temperature", temperature.to_string());
        }

        if let Some(ref additional_params) = request.additional_params {
            for (key, value) in additional_params
                .as_object()
                .expect("Additional Parameters to Groq Transcription should be a map")
            {
                body = body.text(key.to_owned(), value.to_string());
            }
        }

        let req = self
            .client
            .post("/audio/transcriptions")?
            .body(body)
            .map_err(|e| http_client::Error::Instance(e.into()))?;

        let response = self
            .client
            .http_client()
            .send_multipart::<Bytes>(req)
            .await?;

        let status = response.status();
        let response_body = response.into_body().into_future().await?.to_vec();

        if status.is_success() {
            match serde_json::from_slice::<ApiResponse<TranscriptionResponse>>(&response_body)? {
                ApiResponse::Ok(response) => response.try_into(),
                ApiResponse::Err(api_error_response) => Err(TranscriptionError::ProviderError(
                    api_error_response.message,
                )),
            }
        } else {
            Err(TranscriptionError::ProviderError(
                String::from_utf8_lossy(&response_body).to_string(),
            ))
        }
    }
}

#[derive(Deserialize, Debug)]
#[serde(untagged)]
enum StreamingDelta {
    Reasoning {
        reasoning: String,
    },
    MessageContent {
        #[serde(default)]
        content: Option<String>,
        #[serde(default, deserialize_with = "json_utils::null_or_vec")]
        tool_calls: Vec<StreamingToolCall>,
    },
}
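
// NOTE: `StreamingDelta` is untagged, so serde tries the variants in declaration
// order: a delta only matches `Reasoning` when it carries a `reasoning` field;
// everything else falls through to `MessageContent`.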

#[derive(Deserialize, Debug)]
struct StreamingChoice {
    delta: StreamingDelta,
}

#[derive(Deserialize, Debug)]
struct StreamingCompletionChunk {
    choices: Vec<StreamingChoice>,
    usage: Option<Usage>,
}
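
// Illustrative SSE payloads this deserializes (OpenAI-compatible shape, trimmed):
//   data: {"choices":[{"delta":{"content":"Hello"}}],"usage":null}
//   data: {"choices":[{"delta":{"reasoning":"Let me think..."}}],"usage":null}
//   data: {"choices":[],"usage":{"prompt_tokens":12,"total_tokens":40}}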

#[derive(Clone, Deserialize, Serialize, Debug)]
pub struct StreamingCompletionResponse {
    pub usage: Usage,
}

impl GetTokenUsage for StreamingCompletionResponse {
    fn token_usage(&self) -> Option<crate::completion::Usage> {
        let mut usage = crate::completion::Usage::new();

        usage.input_tokens = self.usage.prompt_tokens as u64;
        usage.total_tokens = self.usage.total_tokens as u64;
        usage.output_tokens = self.usage.total_tokens as u64 - self.usage.prompt_tokens as u64;

        Some(usage)
    }
}
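
// Output tokens are derived rather than reported: output = total - prompt
// (e.g. total_tokens = 40, prompt_tokens = 12 => output_tokens = 28).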

pub async fn send_compatible_streaming_request<T>(
    client: T,
    req: Request<Vec<u8>>,
) -> Result<
    crate::streaming::StreamingCompletionResponse<StreamingCompletionResponse>,
    CompletionError,
>
where
    T: HttpClientExt + Clone + 'static,
{
    let span = tracing::Span::current();

    let mut event_source = GenericEventSource::new(client, req);

    let stream = stream! {
        let span = tracing::Span::current();
        let mut final_usage = Usage {
            prompt_tokens: 0,
            total_tokens: 0
        };

        let mut text_response = String::new();

        let mut calls: HashMap<usize, (String, String, String)> = HashMap::new();

        while let Some(event_result) = event_source.next().await {
            match event_result {
                Ok(Event::Open) => {
                    tracing::trace!("SSE connection opened");
                    continue;
                }

                Ok(Event::Message(message)) => {
                    let data_str = message.data.trim();

                    let parsed = serde_json::from_str::<StreamingCompletionChunk>(data_str);
                    let Ok(data) = parsed else {
                        let err = parsed.unwrap_err();
                        tracing::debug!("Couldn't parse SSE payload as StreamingCompletionChunk: {:?}", err);
                        continue;
                    };

                    if let Some(choice) = data.choices.first() {
                        match &choice.delta {
                            StreamingDelta::Reasoning { reasoning } => {
                                yield Ok(crate::streaming::RawStreamingChoice::Reasoning {
                                    id: None,
                                    reasoning: reasoning.to_string(),
                                    signature: None,
                                });
                            }

                            StreamingDelta::MessageContent { content, tool_calls } => {
                                // Handle tool calls
                                for tool_call in tool_calls {
                                    let function = &tool_call.function;

                                    // Start of tool call
                                    if function.name.as_ref().map(|s| !s.is_empty()).unwrap_or(false)
                                        && empty_or_none(&function.arguments)
                                    {
                                        let id = tool_call.id.clone().unwrap_or_default();
                                        let name = function.name.clone().unwrap();
                                        calls.insert(tool_call.index, (id, name, String::new()));
                                    }
                                    // Continuation
                                    else if function.name.as_ref().map(|s| s.is_empty()).unwrap_or(true)
                                        && let Some(arguments) = &function.arguments
                                        && !arguments.is_empty()
                                    {
                                        if let Some((id, name, existing_args)) = calls.get(&tool_call.index) {
                                            let combined = format!("{}{}", existing_args, arguments);
                                            calls.insert(tool_call.index, (id.clone(), name.clone(), combined));
                                        } else {
                                            tracing::debug!("Partial tool call received but tool call was never started.");
                                        }
                                    }
                                    // Complete tool call
                                    else {
                                        let id = tool_call.id.clone().unwrap_or_default();
                                        let name = function.name.clone().unwrap_or_default();
                                        let arguments_str = function.arguments.clone().unwrap_or_default();

                                        let Ok(arguments_json) = serde_json::from_str::<serde_json::Value>(&arguments_str) else {
                                            tracing::debug!("Couldn't parse tool call args '{}'", arguments_str);
                                            continue;
                                        };

                                        yield Ok(crate::streaming::RawStreamingChoice::ToolCall {
                                            id,
                                            name,
                                            arguments: arguments_json,
                                            call_id: None
                                        });
                                    }
                                }

                                // Streamed content
                                if let Some(content) = content {
                                    text_response += content;
                                    yield Ok(crate::streaming::RawStreamingChoice::Message(content.clone()));
                                }
                            }
                        }
                    }

                    if let Some(usage) = data.usage {
                        final_usage = usage.clone();
                    }
                }

                Err(crate::http_client::Error::StreamEnded) => break,
                Err(err) => {
                    tracing::error!(?err, "SSE error");
                    yield Err(CompletionError::ResponseError(err.to_string()));
                    break;
                }
            }
        }

        event_source.close();

        let mut tool_calls = Vec::new();
        // Flush accumulated tool calls
        for (_, (id, name, arguments)) in calls {
            let Ok(arguments_json) = serde_json::from_str::<serde_json::Value>(&arguments) else {
                continue;
            };

            tool_calls.push(crate::providers::openai::completion::ToolCall {
                id: id.clone(),
                r#type: ToolType::Function,
                function: Function {
                    name: name.clone(),
                    arguments: arguments_json.clone()
                }
            });
            yield Ok(crate::streaming::RawStreamingChoice::ToolCall {
                id,
                name,
                arguments: arguments_json,
                call_id: None,
            });
        }

        let response_message = crate::providers::openai::completion::Message::Assistant {
            content: vec![AssistantContent::Text { text: text_response }],
            refusal: None,
            audio: None,
            name: None,
            tool_calls
        };

        span.record("gen_ai.output.messages", serde_json::to_string(&vec![response_message]).unwrap());
        span.record("gen_ai.usage.input_tokens", final_usage.prompt_tokens);
        span.record("gen_ai.usage.output_tokens", final_usage.total_tokens - final_usage.prompt_tokens);

        // Final response
        yield Ok(crate::streaming::RawStreamingChoice::FinalResponse(
            StreamingCompletionResponse { usage: final_usage.clone() }
        ));
    }.instrument(span);

    Ok(crate::streaming::StreamingCompletionResponse::stream(
        Box::pin(stream),
    ))
}
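
// Example (sketch, assuming the returned `StreamingCompletionResponse` is polled
// as a `futures::Stream` by the caller):
// let mut response = model.stream(request).await?;
// while let Some(item) = response.next().await {
//     match item? { /* text deltas, reasoning, tool calls, then the final usage */ }
// }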