rig/providers/groq.rs

//! Groq API client and Rig integration
//!
//! # Example
//! ```
//! use rig::client::CompletionClient;
//! use rig::providers::groq;
//!
//! let client = groq::Client::new("YOUR_API_KEY");
//!
//! let deepseek = client.completion_model(groq::DEEPSEEK_R1_DISTILL_LLAMA_70B);
//! ```
use reqwest_eventsource::{Event, RequestBuilderExt};
use std::collections::HashMap;
use tracing::info_span;
use tracing_futures::Instrument;

use super::openai::{CompletionResponse, StreamingToolCall, TranscriptionResponse, Usage};
use crate::client::{CompletionClient, TranscriptionClient, VerifyClient, VerifyError};
use crate::completion::GetTokenUsage;
use crate::http_client::{self, HttpClientExt};
use crate::json_utils::merge;
use crate::providers::openai::{AssistantContent, Function, ToolType};
use async_stream::stream;
use futures::StreamExt;

use crate::{
    OneOrMany,
    completion::{self, CompletionError, CompletionRequest},
    json_utils,
    message::{self, MessageError},
    providers::openai::ToolDefinition,
    transcription::{self, TranscriptionError},
};
use reqwest::RequestBuilder;
use reqwest::multipart::Part;
use rig::client::ProviderClient;
use rig::impl_conversion_traits;
use serde::{Deserialize, Serialize};
use serde_json::{Value, json};

// ================================================================
// Main Groq Client
// ================================================================
const GROQ_API_BASE_URL: &str = "https://api.groq.com/openai/v1";

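/// Builder for a Groq [`Client`].
///
/// The example below is a minimal sketch; the API key is a placeholder and the
/// base URL shown is the default, overridden here only to illustrate the method.
///
/// # Example
/// ```
/// use rig::providers::groq::{Client, ClientBuilder};
///
/// let groq: Client = ClientBuilder::new("your-groq-api-key")
///     .base_url("https://api.groq.com/openai/v1")
///     .build();
/// ```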
pub struct ClientBuilder<'a, T = reqwest::Client> {
    api_key: &'a str,
    base_url: &'a str,
    http_client: T,
}

impl<'a, T> ClientBuilder<'a, T>
where
    T: Default,
{
    pub fn new(api_key: &'a str) -> Self {
        Self {
            api_key,
            base_url: GROQ_API_BASE_URL,
            http_client: Default::default(),
        }
    }
}

impl<'a, T> ClientBuilder<'a, T> {
    pub fn base_url(mut self, base_url: &'a str) -> Self {
        self.base_url = base_url;
        self
    }

    pub fn with_client<U>(self, http_client: U) -> ClientBuilder<'a, U> {
        ClientBuilder {
            api_key: self.api_key,
            base_url: self.base_url,
            http_client,
        }
    }

    pub fn build(self) -> Client<T> {
        Client {
            base_url: self.base_url.to_string(),
            api_key: self.api_key.to_string(),
            http_client: self.http_client,
        }
    }
}

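/// A Groq API client, holding the base URL, API key and underlying HTTP client.
///
/// The example is a minimal sketch of supplying a preconfigured `reqwest` client
/// via [`ClientBuilder::with_client`]; the timeout value is illustrative only.
///
/// # Example
/// ```
/// use rig::providers::groq::{Client, ClientBuilder};
///
/// let http = reqwest::Client::builder()
///     .timeout(std::time::Duration::from_secs(30))
///     .build()
///     .expect("reqwest client should build");
///
/// let groq: Client = ClientBuilder::<reqwest::Client>::new("your-groq-api-key")
///     .with_client(http)
///     .build();
/// ```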
#[derive(Clone)]
pub struct Client<T = reqwest::Client> {
    base_url: String,
    api_key: String,
    http_client: T,
}

impl<T> std::fmt::Debug for Client<T>
where
    T: std::fmt::Debug,
{
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Client")
            .field("base_url", &self.base_url)
            .field("http_client", &self.http_client)
            .field("api_key", &"<REDACTED>")
            .finish()
    }
}

impl<T> Client<T>
where
    T: Default,
{
    /// Create a new Groq client builder.
    ///
    /// # Example
    /// ```
    /// use rig::providers::groq::Client;
    ///
    /// // Initialize the Groq client
    /// let groq: Client = Client::builder("your-groq-api-key")
    ///     .build();
    /// ```
    pub fn builder(api_key: &str) -> ClientBuilder<'_, T> {
        ClientBuilder::new(api_key)
    }

    /// Create a new Groq client with the given API key.
    ///
    /// # Panics
    /// - If the reqwest client cannot be built (if the TLS backend cannot be initialized).
    pub fn new(api_key: &str) -> Self {
        Self::builder(api_key).build()
    }
}

impl<T> Client<T>
where
    T: HttpClientExt,
{
    fn req(
        &self,
        method: http_client::Method,
        path: &str,
    ) -> http_client::Result<http_client::Builder> {
        let url = format!("{}/{}", self.base_url, path.trim_start_matches('/'));

        http_client::with_bearer_auth(
            http_client::Builder::new().method(method).uri(url),
            &self.api_key,
        )
    }

    fn get(&self, path: &str) -> http_client::Result<http_client::Builder> {
        self.req(http_client::Method::GET, path)
    }
}

impl Client<reqwest::Client> {
    fn reqwest_post(&self, path: &str) -> reqwest::RequestBuilder {
        // Join the base URL and path without doubling the slash (and without
        // touching the `https://` scheme separator).
        let url = format!("{}/{}", self.base_url, path.trim_start_matches('/'));

        self.http_client.post(url).bearer_auth(&self.api_key)
    }
}

impl ProviderClient for Client<reqwest::Client> {
    /// Create a new Groq client from the `GROQ_API_KEY` environment variable.
    /// Panics if the environment variable is not set.
    fn from_env() -> Self {
        let api_key = std::env::var("GROQ_API_KEY").expect("GROQ_API_KEY not set");
        Self::new(&api_key)
    }

    fn from_val(input: crate::client::ProviderValue) -> Self {
        let crate::client::ProviderValue::Simple(api_key) = input else {
            panic!("Incorrect provider value type")
        };
        Self::new(&api_key)
    }
}

impl CompletionClient for Client<reqwest::Client> {
    type CompletionModel = CompletionModel<reqwest::Client>;

    /// Create a completion model with the given name.
    ///
    /// # Example
    /// ```
    /// use rig::client::CompletionClient;
    /// use rig::providers::groq::{self, Client};
    ///
    /// // Initialize the Groq client
    /// let groq = Client::new("your-groq-api-key");
    ///
    /// let llama = groq.completion_model(groq::LLAMA_3_1_8B_INSTANT);
    /// ```
    fn completion_model(&self, model: &str) -> CompletionModel<reqwest::Client> {
        CompletionModel::new(self.clone(), model)
    }
}
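
// Beyond calling `completion`/`stream` on the model directly, a completion model
// is usually driven through rig's higher-level agent API. A minimal sketch,
// assuming rig's `Agent`/`Prompt` API and an async context (not compiled here):
//
//     use rig::client::CompletionClient;
//     use rig::completion::Prompt;
//
//     let groq = Client::new("your-groq-api-key");
//     let agent = groq
//         .agent(DEEPSEEK_R1_DISTILL_LLAMA_70B)
//         .preamble("You are a concise assistant.")
//         .build();
//     let answer = agent.prompt("Why is the sky blue?").await?;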

impl TranscriptionClient for Client<reqwest::Client> {
    type TranscriptionModel = TranscriptionModel<reqwest::Client>;

    /// Create a transcription model with the given name.
    ///
    /// # Example
    /// ```
    /// use rig::client::TranscriptionClient;
    /// use rig::providers::groq::{self, Client};
    ///
    /// // Initialize the Groq client
    /// let groq = Client::new("your-groq-api-key");
    ///
    /// let whisper = groq.transcription_model(groq::WHISPER_LARGE_V3);
    /// ```
    fn transcription_model(&self, model: &str) -> TranscriptionModel<reqwest::Client> {
        TranscriptionModel::new(self.clone(), model)
    }
}
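
// Sketch of a transcription round trip, assuming an async context; constructing
// the `transcription::TranscriptionRequest` itself (audio bytes, filename,
// language, ...) is elided here:
//
//     use rig::client::TranscriptionClient;
//     use rig::transcription::TranscriptionModel as _;
//
//     let groq = Client::new("your-groq-api-key");
//     let whisper = groq.transcription_model(WHISPER_LARGE_V3);
//     // let response = whisper.transcription(request).await?;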

impl VerifyClient for Client<reqwest::Client> {
    #[cfg_attr(feature = "worker", worker::send)]
    async fn verify(&self) -> Result<(), VerifyError> {
        let req = self
            .get("/models")?
            .body(http_client::NoBody)
            .map_err(http_client::Error::from)?;

        let response = HttpClientExt::send(&self.http_client, req).await?;

        match response.status() {
            reqwest::StatusCode::OK => Ok(()),
            reqwest::StatusCode::UNAUTHORIZED => Err(VerifyError::InvalidAuthentication),
            reqwest::StatusCode::INTERNAL_SERVER_ERROR
            | reqwest::StatusCode::SERVICE_UNAVAILABLE
            | reqwest::StatusCode::BAD_GATEWAY => {
                let text = http_client::text(response).await?;
                Err(VerifyError::ProviderError(text))
            }
            _ => {
                //response.error_for_status()?;
                Ok(())
            }
        }
    }
}
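
// Sketch: checking the API key up front before issuing any other requests
// (assumes an async context; `verify` is defined just above):
//
//     let groq = Client::new("your-groq-api-key");
//     if let Err(err) = groq.verify().await {
//         eprintln!("Groq credentials rejected: {err}");
//     }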

impl_conversion_traits!(
    AsEmbeddings,
    AsImageGeneration,
    AsAudioGeneration for Client<T>
);

#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    message: String,
}

#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    Ok(T),
    Err(ApiErrorResponse),
}

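/// Wire-format chat message for Groq's OpenAI-compatible API. The `reasoning`
/// field carries the model's parsed reasoning when present and is skipped when
/// `None`.
///
/// # Example
/// ```
/// use rig::providers::groq::Message;
///
/// // Minimal sketch of the serialized shape.
/// let msg = Message {
///     role: "assistant".to_string(),
///     content: Some("Hello!".to_string()),
///     reasoning: None,
/// };
/// assert_eq!(
///     serde_json::to_value(&msg).unwrap(),
///     serde_json::json!({ "role": "assistant", "content": "Hello!" })
/// );
/// ```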
#[derive(Debug, Serialize, Deserialize)]
pub struct Message {
    pub role: String,
    pub content: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub reasoning: Option<String>,
}

impl TryFrom<Message> for message::Message {
    type Error = message::MessageError;

    fn try_from(message: Message) -> Result<Self, Self::Error> {
        match message.role.as_str() {
            "user" => Ok(Self::User {
                content: OneOrMany::one(
                    message
                        .content
                        .map(|content| message::UserContent::text(&content))
                        .ok_or_else(|| {
                            message::MessageError::ConversionError("Empty user message".to_string())
                        })?,
                ),
            }),
            "assistant" => Ok(Self::Assistant {
                id: None,
                content: OneOrMany::one(
                    message
                        .content
                        .map(|content| message::AssistantContent::text(&content))
                        .ok_or_else(|| {
                            message::MessageError::ConversionError(
                                "Empty assistant message".to_string(),
                            )
                        })?,
                ),
            }),
            _ => Err(message::MessageError::ConversionError(format!(
                "Unknown role: {}",
                message.role
            ))),
        }
    }
}

impl TryFrom<message::Message> for Message {
    type Error = message::MessageError;

    fn try_from(message: message::Message) -> Result<Self, Self::Error> {
        match message {
            message::Message::User { content } => Ok(Self {
                role: "user".to_string(),
                content: content.iter().find_map(|c| match c {
                    message::UserContent::Text(text) => Some(text.text.clone()),
                    _ => None,
                }),
                reasoning: None,
            }),
            message::Message::Assistant { content, .. } => {
                let mut text_content: Option<String> = None;
                let mut groq_reasoning: Option<String> = None;

                for c in content.iter() {
                    match c {
                        message::AssistantContent::Text(text) => {
                            text_content = Some(
                                text_content
                                    .map(|mut existing| {
                                        existing.push('\n');
                                        existing.push_str(&text.text);
                                        existing
                                    })
                                    .unwrap_or_else(|| text.text.clone()),
                            );
                        }
                        message::AssistantContent::ToolCall(_tool_call) => {
                            return Err(MessageError::ConversionError(
                                "Tool calls do not exist on this message".into(),
                            ));
                        }
                        message::AssistantContent::Reasoning(message::Reasoning {
                            reasoning,
                            ..
                        }) => {
                            groq_reasoning =
                                Some(reasoning.first().cloned().unwrap_or(String::new()));
                        }
                    }
                }

                Ok(Self {
                    role: "assistant".to_string(),
                    content: text_content,
                    reasoning: groq_reasoning,
                })
            }
        }
    }
}
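
// Sketch of the conversion above: an assistant message in Groq's wire format
// maps onto rig's `message::Message` (reasoning is dropped in this direction,
// and tool calls are rejected in the other):
//
//     let wire = Message {
//         role: "assistant".to_string(),
//         content: Some("The answer is 4.".to_string()),
//         reasoning: Some("2 + 2 = 4".to_string()),
//     };
//     let msg: message::Message = wire.try_into().expect("assistant message converts");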

// ================================================================
// Groq Completion API
// ================================================================
/// The `deepseek-r1-distill-llama-70b` model. Used for chat completion.
pub const DEEPSEEK_R1_DISTILL_LLAMA_70B: &str = "deepseek-r1-distill-llama-70b";
/// The `gemma2-9b-it` model. Used for chat completion.
pub const GEMMA2_9B_IT: &str = "gemma2-9b-it";
/// The `llama-3.1-8b-instant` model. Used for chat completion.
pub const LLAMA_3_1_8B_INSTANT: &str = "llama-3.1-8b-instant";
/// The `llama-3.2-11b-vision-preview` model. Used for chat completion.
pub const LLAMA_3_2_11B_VISION_PREVIEW: &str = "llama-3.2-11b-vision-preview";
/// The `llama-3.2-1b-preview` model. Used for chat completion.
pub const LLAMA_3_2_1B_PREVIEW: &str = "llama-3.2-1b-preview";
/// The `llama-3.2-3b-preview` model. Used for chat completion.
pub const LLAMA_3_2_3B_PREVIEW: &str = "llama-3.2-3b-preview";
/// The `llama-3.2-90b-vision-preview` model. Used for chat completion.
pub const LLAMA_3_2_90B_VISION_PREVIEW: &str = "llama-3.2-90b-vision-preview";
/// The `llama-3.3-70b-specdec` model. Used for chat completion.
pub const LLAMA_3_3_70B_SPECDEC: &str = "llama-3.3-70b-specdec";
/// The `llama-3.3-70b-versatile` model. Used for chat completion.
pub const LLAMA_3_3_70B_VERSATILE: &str = "llama-3.3-70b-versatile";
/// The `llama-guard-3-8b` model. Used for chat completion.
pub const LLAMA_GUARD_3_8B: &str = "llama-guard-3-8b";
/// The `llama3-70b-8192` model. Used for chat completion.
pub const LLAMA_3_70B_8192: &str = "llama3-70b-8192";
/// The `llama3-8b-8192` model. Used for chat completion.
pub const LLAMA_3_8B_8192: &str = "llama3-8b-8192";
/// The `mixtral-8x7b-32768` model. Used for chat completion.
pub const MIXTRAL_8X7B_32768: &str = "mixtral-8x7b-32768";

#[derive(Clone, Debug)]
pub struct CompletionModel<T> {
    client: Client<T>,
    /// Name of the model (e.g.: deepseek-r1-distill-llama-70b)
    pub model: String,
}

impl<T> CompletionModel<T> {
    pub fn new(client: Client<T>, model: &str) -> Self {
        Self {
            client,
            model: model.to_string(),
        }
    }

    fn create_completion_request(
        &self,
        completion_request: CompletionRequest,
    ) -> Result<Value, CompletionError> {
        // Build up the order of messages (context, chat_history, prompt)
        let mut partial_history = vec![];
        if let Some(docs) = completion_request.normalized_documents() {
            partial_history.push(docs);
        }
        partial_history.extend(completion_request.chat_history);

        // Initialize full history with preamble (or empty if non-existent)
        let mut full_history: Vec<Message> =
            completion_request
                .preamble
                .map_or_else(Vec::new, |preamble| {
                    vec![Message {
                        role: "system".to_string(),
                        content: Some(preamble),
                        reasoning: None,
                    }]
                });

        // Convert and extend the rest of the history
        full_history.extend(
            partial_history
                .into_iter()
                .map(message::Message::try_into)
                .collect::<Result<Vec<Message>, _>>()?,
        );

        let tool_choice = completion_request
            .tool_choice
            .map(crate::providers::openai::ToolChoice::try_from)
            .transpose()?;

        let request = if completion_request.tools.is_empty() {
            json!({
                "model": self.model,
                "messages": full_history,
                "temperature": completion_request.temperature,
            })
        } else {
            json!({
                "model": self.model,
                "messages": full_history,
                "temperature": completion_request.temperature,
                "tools": completion_request.tools.into_iter().map(ToolDefinition::from).collect::<Vec<_>>(),
                "tool_choice": tool_choice,
                "reasoning_format": "parsed"
            })
        };

        let request = if let Some(params) = completion_request.additional_params {
            json_utils::merge(request, params)
        } else {
            request
        };

        Ok(request)
    }
}
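
// For reference, `create_completion_request` produces a body of roughly this
// shape when tools are attached (a sketch; values are illustrative):
//
//     json!({
//         "model": "deepseek-r1-distill-llama-70b",
//         "messages": [
//             { "role": "system", "content": "You are a helpful assistant" },
//             { "role": "user", "content": "Hello!" }
//         ],
//         "temperature": 0.7,
//         "tools": [ /* serialized ToolDefinition values */ ],
//         "tool_choice": "auto",
//         "reasoning_format": "parsed"
//     })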

impl completion::CompletionModel for CompletionModel<reqwest::Client> {
    type Response = CompletionResponse;
    type StreamingResponse = StreamingCompletionResponse;

    #[cfg_attr(feature = "worker", worker::send)]
    async fn completion(
        &self,
        completion_request: CompletionRequest,
    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
        let preamble = completion_request.preamble.clone();

        let request = self.create_completion_request(completion_request)?;
        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat",
                gen_ai.operation.name = "chat",
                gen_ai.provider.name = "groq",
                gen_ai.request.model = self.model,
                gen_ai.system_instructions = preamble,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.input.messages = serde_json::to_string(&request.get("messages").unwrap()).unwrap(),
                gen_ai.output.messages = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        let async_block = async move {
            let response = self
                .client
                .reqwest_post("/chat/completions")
                .json(&request)
                .send()
                .await
                .map_err(|e| http_client::Error::Instance(e.into()))?;

            if response.status().is_success() {
                match response
                    .json::<ApiResponse<CompletionResponse>>()
                    .await
                    .map_err(|e| http_client::Error::Instance(e.into()))?
                {
                    ApiResponse::Ok(response) => {
                        let span = tracing::Span::current();
                        span.record("gen_ai.response.id", response.id.clone());
                        span.record("gen_ai.response.model", response.model.clone());
                        span.record(
                            "gen_ai.output.messages",
                            serde_json::to_string(&response.choices).unwrap(),
                        );
                        if let Some(ref usage) = response.usage {
                            span.record("gen_ai.usage.input_tokens", usage.prompt_tokens);
                            span.record(
                                "gen_ai.usage.output_tokens",
                                usage.total_tokens - usage.prompt_tokens,
                            );
                        }
                        response.try_into()
                    }
                    ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
                }
            } else {
                Err(CompletionError::ProviderError(
                    response
                        .text()
                        .await
                        .map_err(|e| http_client::Error::Instance(e.into()))?,
                ))
            }
        };

        tracing::Instrument::instrument(async_block, span).await
    }

    #[cfg_attr(feature = "worker", worker::send)]
    async fn stream(
        &self,
        request: CompletionRequest,
    ) -> Result<
        crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
        CompletionError,
    > {
        let preamble = request.preamble.clone();
        let mut request = self.create_completion_request(request)?;

        request = merge(
            request,
            json!({"stream": true, "stream_options": {"include_usage": true}}),
        );

        let builder = self.client.reqwest_post("/chat/completions").json(&request);

        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat_streaming",
                gen_ai.operation.name = "chat_streaming",
                gen_ai.provider.name = "groq",
                gen_ai.request.model = self.model,
                gen_ai.system_instructions = preamble,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.input.messages = serde_json::to_string(&request.get("messages").unwrap()).unwrap(),
                gen_ai.output.messages = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        tracing::Instrument::instrument(send_compatible_streaming_request(builder), span).await
    }
}

// ================================================================
// Groq Transcription API
// ================================================================
pub const WHISPER_LARGE_V3: &str = "whisper-large-v3";
pub const WHISPER_LARGE_V3_TURBO: &str = "whisper-large-v3-turbo";
pub const DISTIL_WHISPER_LARGE_V3: &str = "distil-whisper-large-v3-en";

#[derive(Clone)]
pub struct TranscriptionModel<T> {
    client: Client<T>,
    /// Name of the model (e.g.: whisper-large-v3)
    pub model: String,
}

impl<T> TranscriptionModel<T> {
    pub fn new(client: Client<T>, model: &str) -> Self {
        Self {
            client,
            model: model.to_string(),
        }
    }
}

impl transcription::TranscriptionModel for TranscriptionModel<reqwest::Client> {
    type Response = TranscriptionResponse;

    #[cfg_attr(feature = "worker", worker::send)]
    async fn transcription(
        &self,
        request: transcription::TranscriptionRequest,
    ) -> Result<
        transcription::TranscriptionResponse<Self::Response>,
        transcription::TranscriptionError,
    > {
        let data = request.data;

        let mut body = reqwest::multipart::Form::new()
            .text("model", self.model.clone())
            .text("language", request.language)
            .part(
                "file",
                Part::bytes(data).file_name(request.filename.clone()),
            );

        if let Some(prompt) = request.prompt {
            body = body.text("prompt", prompt.clone());
        }

        if let Some(ref temperature) = request.temperature {
            body = body.text("temperature", temperature.to_string());
        }

        if let Some(ref additional_params) = request.additional_params {
            for (key, value) in additional_params
                .as_object()
                .expect("Additional Parameters to Groq Transcription should be a map")
            {
                body = body.text(key.to_owned(), value.to_string());
            }
        }

        let response = self
            .client
            .reqwest_post("audio/transcriptions")
            .multipart(body)
            .send()
            .await
            .map_err(|e| TranscriptionError::HttpError(http_client::Error::Instance(e.into())))?;

        if response.status().is_success() {
            match response
                .json::<ApiResponse<TranscriptionResponse>>()
                .await
                .map_err(|e| {
                    TranscriptionError::HttpError(http_client::Error::Instance(e.into()))
                })? {
                ApiResponse::Ok(response) => response.try_into(),
                ApiResponse::Err(api_error_response) => Err(TranscriptionError::ProviderError(
                    api_error_response.message,
                )),
            }
        } else {
            Err(TranscriptionError::ProviderError(
                response.text().await.map_err(|e| {
                    TranscriptionError::HttpError(http_client::Error::Instance(e.into()))
                })?,
            ))
        }
    }
}

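/// A single streamed delta from Groq's chat completions SSE stream: either a
/// fragment of parsed `reasoning`, or regular message content carrying text
/// and/or partial tool calls.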
#[derive(Deserialize, Debug)]
#[serde(untagged)]
pub enum StreamingDelta {
    Reasoning {
        reasoning: String,
    },
    MessageContent {
        #[serde(default)]
        content: Option<String>,
        #[serde(default, deserialize_with = "json_utils::null_or_vec")]
        tool_calls: Vec<StreamingToolCall>,
    },
}

#[derive(Deserialize, Debug)]
struct StreamingChoice {
    delta: StreamingDelta,
}

#[derive(Deserialize, Debug)]
struct StreamingCompletionChunk {
    choices: Vec<StreamingChoice>,
    usage: Option<Usage>,
}

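// Sketch of the payloads these types model: a content delta, and the final
// usage-only chunk emitted when `stream_options.include_usage` is enabled
// (token counts are illustrative). The usage from that tail chunk is what ends
// up in the final `StreamingCompletionResponse` below:
//
//     let delta: StreamingCompletionChunk = serde_json::from_str(
//         r#"{"choices":[{"delta":{"content":"Hello"}}]}"#,
//     ).unwrap();
//     let tail: StreamingCompletionChunk = serde_json::from_str(
//         r#"{"choices":[],"usage":{"prompt_tokens":12,"total_tokens":40}}"#,
//     ).unwrap();
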
#[derive(Clone, Deserialize, Serialize, Debug)]
pub struct StreamingCompletionResponse {
    pub usage: Usage,
}

impl GetTokenUsage for StreamingCompletionResponse {
    fn token_usage(&self) -> Option<crate::completion::Usage> {
        let mut usage = crate::completion::Usage::new();

        usage.input_tokens = self.usage.prompt_tokens as u64;
        usage.total_tokens = self.usage.total_tokens as u64;
        usage.output_tokens = self.usage.total_tokens as u64 - self.usage.prompt_tokens as u64;

        Some(usage)
    }
}

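/// Send an OpenAI-compatible streaming chat request and adapt the SSE stream
/// into rig's streaming types.
///
/// Text and reasoning deltas are yielded as they arrive, partial tool-call
/// fragments are accumulated by index and flushed once the stream ends, token
/// usage is recorded on the current tracing span, and a final
/// [`StreamingCompletionResponse`] carrying that usage closes the stream.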
pub async fn send_compatible_streaming_request(
    request_builder: RequestBuilder,
) -> Result<
    crate::streaming::StreamingCompletionResponse<StreamingCompletionResponse>,
    CompletionError,
> {
    let span = tracing::Span::current();
    let mut event_source = request_builder
        .eventsource()
        .expect("Cloning request must succeed");

    let stream = stream! {
        let span = tracing::Span::current();
        let mut final_usage = Usage {
            prompt_tokens: 0,
            total_tokens: 0
        };

        let mut text_response = String::new();

        let mut calls: HashMap<usize, (String, String, String)> = HashMap::new();

        while let Some(event_result) = event_source.next().await {
            match event_result {
                Ok(Event::Open) => {
                    tracing::trace!("SSE connection opened");
                    continue;
                }

                Ok(Event::Message(message)) => {
                    let data_str = message.data.trim();

                    let parsed = serde_json::from_str::<StreamingCompletionChunk>(data_str);
                    let Ok(data) = parsed else {
                        let err = parsed.unwrap_err();
                        tracing::debug!("Couldn't parse SSE payload as StreamingCompletionChunk: {:?}", err);
                        continue;
                    };

                    if let Some(choice) = data.choices.first() {
                        match &choice.delta {
                            StreamingDelta::Reasoning { reasoning } => {
                                yield Ok(crate::streaming::RawStreamingChoice::Reasoning {
                                    id: None,
                                    reasoning: reasoning.to_string()
                                });
                            }

                            StreamingDelta::MessageContent { content, tool_calls } => {
                                // Handle tool calls
                                for tool_call in tool_calls {
                                    let function = &tool_call.function;

                                    // Start of tool call
                                    if function.name.as_ref().map(|s| !s.is_empty()).unwrap_or(false)
                                        && function.arguments.is_empty()
                                    {
                                        let id = tool_call.id.clone().unwrap_or_default();
                                        let name = function.name.clone().unwrap();
                                        calls.insert(tool_call.index, (id, name, String::new()));
                                    }
                                    // Continuation
                                    else if function.name.as_ref().map(|s| s.is_empty()).unwrap_or(true)
                                        && !function.arguments.is_empty()
                                    {
                                        if let Some((id, name, existing_args)) = calls.get(&tool_call.index) {
                                            let combined = format!("{}{}", existing_args, function.arguments);
                                            calls.insert(tool_call.index, (id.clone(), name.clone(), combined));
                                        } else {
                                            tracing::debug!("Partial tool call received but tool call was never started.");
                                        }
                                    }
                                    // Complete tool call
                                    else {
                                        let id = tool_call.id.clone().unwrap_or_default();
                                        let name = function.name.clone().unwrap_or_default();
                                        let arguments_str = function.arguments.clone();

                                        let Ok(arguments_json) = serde_json::from_str::<serde_json::Value>(&arguments_str) else {
                                            tracing::debug!("Couldn't parse tool call args '{}'", arguments_str);
                                            continue;
                                        };

                                        yield Ok(crate::streaming::RawStreamingChoice::ToolCall {
                                            id,
                                            name,
                                            arguments: arguments_json,
                                            call_id: None
                                        });
                                    }
                                }

                                // Streamed content
                                if let Some(content) = content {
                                    text_response += content;
                                    yield Ok(crate::streaming::RawStreamingChoice::Message(content.clone()));
                                }
                            }
                        }
                    }

                    if let Some(usage) = data.usage {
                        final_usage = usage.clone();
                    }
                }

                Err(reqwest_eventsource::Error::StreamEnded) => break,

                Err(err) => {
                    tracing::error!(?err, "SSE error");
                    yield Err(CompletionError::ResponseError(err.to_string()));
                    break;
                }
            }
        }

        let mut tool_calls = Vec::new();
        // Flush accumulated tool calls
        for (_, (id, name, arguments)) in calls {
            let Ok(arguments_json) = serde_json::from_str::<serde_json::Value>(&arguments) else {
                continue;
            };

            tool_calls.push(rig::providers::openai::completion::ToolCall {
                id: id.clone(),
                r#type: ToolType::Function,
                function: Function {
                    name: name.clone(),
                    arguments: arguments_json.clone()
                }
            });
            yield Ok(crate::streaming::RawStreamingChoice::ToolCall {
                id,
                name,
                arguments: arguments_json,
                call_id: None,
            });
        }

        let response_message = crate::providers::openai::completion::Message::Assistant {
            content: vec![AssistantContent::Text { text: text_response }],
            refusal: None,
            audio: None,
            name: None,
            tool_calls
        };

        span.record("gen_ai.output.messages", serde_json::to_string(&vec![response_message]).unwrap());
        span.record("gen_ai.usage.input_tokens", final_usage.prompt_tokens);
        span.record("gen_ai.usage.output_tokens", final_usage.total_tokens - final_usage.prompt_tokens);

        // Final response
        yield Ok(crate::streaming::RawStreamingChoice::FinalResponse(
            StreamingCompletionResponse { usage: final_usage.clone() }
        ));
    }.instrument(span);

    Ok(crate::streaming::StreamingCompletionResponse::stream(
        Box::pin(stream),
    ))
}