rig/providers/
groq.rs

1//! Groq API client and Rig integration
2//!
3//! # Example
4//! ```
5//! use rig::providers::groq;
6//!
7//! let client = groq::Client::new("YOUR_API_KEY");
8//!
9//! let model = client.completion_model(groq::LLAMA_3_1_8B_INSTANT);
10//! ```
11use bytes::Bytes;
12use http::{Method, Request};
13use std::collections::HashMap;
14use tracing::info_span;
15use tracing_futures::Instrument;
16
17use super::openai::{CompletionResponse, StreamingToolCall, TranscriptionResponse, Usage};
18use crate::client::{CompletionClient, TranscriptionClient, VerifyClient, VerifyError};
19use crate::completion::GetTokenUsage;
20use crate::http_client::sse::{Event, GenericEventSource};
21use crate::http_client::{self, HttpClientExt};
22use crate::json_utils::merge;
23use crate::providers::openai::{AssistantContent, Function, ToolType};
24use async_stream::stream;
25use futures::StreamExt;
26
27use crate::{
28    OneOrMany,
29    completion::{self, CompletionError, CompletionRequest},
30    json_utils,
31    message::{self, MessageError},
32    providers::openai::ToolDefinition,
33    transcription::{self, TranscriptionError},
34};
35use reqwest::multipart::Part;
36use rig::client::ProviderClient;
37use rig::impl_conversion_traits;
38use serde::{Deserialize, Serialize};
39use serde_json::{Value, json};
40
41// ================================================================
42// Main Groq Client
43// ================================================================
44const GROQ_API_BASE_URL: &str = "https://api.groq.com/openai/v1";
45
/// Builder for [`Client`]. Collects the API key, base URL and HTTP client
/// implementation before the final client is constructed with
/// [`ClientBuilder::build`].
pub struct ClientBuilder<'a, T = reqwest::Client> {
    api_key: &'a str,
    base_url: &'a str,
    http_client: T,
}
51
impl<'a, T> ClientBuilder<'a, T>
where
    T: Default,
{
    /// Create a builder for `api_key`, using the default Groq base URL and a
    /// default-constructed HTTP client.
    pub fn new(api_key: &'a str) -> Self {
        Self {
            api_key,
            base_url: GROQ_API_BASE_URL,
            http_client: Default::default(),
        }
    }
}
64
65impl<'a, T> ClientBuilder<'a, T> {
66    pub fn new_with_client(api_key: &'a str, http_client: T) -> Self {
67        Self {
68            api_key,
69            base_url: GROQ_API_BASE_URL,
70            http_client,
71        }
72    }
73
74    pub fn base_url(mut self, base_url: &'a str) -> Self {
75        self.base_url = base_url;
76        self
77    }
78
79    pub fn with_client<U>(self, http_client: U) -> ClientBuilder<'a, U> {
80        ClientBuilder {
81            api_key: self.api_key,
82            base_url: self.base_url,
83            http_client,
84        }
85    }
86
87    pub fn build(self) -> Client<T> {
88        Client {
89            base_url: self.base_url.to_string(),
90            api_key: self.api_key.to_string(),
91            http_client: self.http_client,
92        }
93    }
94}
95
/// Groq API client, generic over the underlying HTTP client implementation
/// (defaults to [`reqwest::Client`]).
#[derive(Clone)]
pub struct Client<T = reqwest::Client> {
    base_url: String,
    api_key: String,
    http_client: T,
}
102
103impl<T> std::fmt::Debug for Client<T>
104where
105    T: std::fmt::Debug,
106{
107    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
108        f.debug_struct("Client")
109            .field("base_url", &self.base_url)
110            .field("http_client", &self.http_client)
111            .field("api_key", &"<REDACTED>")
112            .finish()
113    }
114}
115
impl<T> Client<T>
where
    T: HttpClientExt,
{
    /// Start a request builder for `path` relative to the configured base
    /// URL, with the bearer `Authorization` header already attached.
    fn req(
        &self,
        method: http_client::Method,
        path: &str,
    ) -> http_client::Result<http_client::Builder> {
        // Strip a leading '/' from `path` to avoid a double slash in the URL.
        let url = format!("{}/{}", self.base_url, path.trim_start_matches('/'));

        http_client::with_bearer_auth(
            http_client::Builder::new().method(method).uri(url),
            &self.api_key,
        )
    }

    /// Convenience wrapper for authenticated GET requests.
    fn get(&self, path: &str) -> http_client::Result<http_client::Builder> {
        self.req(http_client::Method::GET, path)
    }
}
137
138impl Client<reqwest::Client> {
139    pub fn builder(api_key: &str) -> ClientBuilder<'_, reqwest::Client> {
140        ClientBuilder::new(api_key)
141    }
142
143    pub fn new(api_key: &str) -> Self {
144        ClientBuilder::new(api_key).build()
145    }
146
147    pub fn from_env() -> Self {
148        <Self as ProviderClient>::from_env()
149    }
150}
151
impl<T> ProviderClient for Client<T>
where
    T: HttpClientExt + Clone + Send + std::fmt::Debug + Default + 'static,
{
    /// Create a new Groq client from the `GROQ_API_KEY` environment variable.
    ///
    /// # Panics
    /// Panics if the environment variable is not set.
    fn from_env() -> Self {
        let api_key = std::env::var("GROQ_API_KEY").expect("GROQ_API_KEY not set");
        ClientBuilder::<T>::new(&api_key).build()
    }

    /// Create a client from a dynamic provider value.
    ///
    /// # Panics
    /// Panics if `input` is not `ProviderValue::Simple` — a bare API key is
    /// the only configuration shape Groq accepts.
    fn from_val(input: crate::client::ProviderValue) -> Self {
        let crate::client::ProviderValue::Simple(api_key) = input else {
            panic!("Incorrect provider value type")
        };
        ClientBuilder::<T>::new(&api_key).build()
    }
}
170
impl<T> CompletionClient for Client<T>
where
    T: HttpClientExt + Clone + Send + std::fmt::Debug + Default + 'static,
{
    type CompletionModel = CompletionModel<T>;

    /// Create a completion model with the given name.
    ///
    /// # Example
    /// ```
    /// use rig::providers::groq::{Client, self};
    ///
    /// // Initialize the Groq client
    /// let groq = Client::new("your-groq-api-key");
    ///
    /// let llama = groq.completion_model(groq::LLAMA_3_1_8B_INSTANT);
    /// ```
    fn completion_model(&self, model: &str) -> Self::CompletionModel {
        CompletionModel::new(self.clone(), model)
    }
}
192
impl<T> TranscriptionClient for Client<T>
where
    T: HttpClientExt + Clone + Send + std::fmt::Debug + Default + 'static,
{
    type TranscriptionModel = TranscriptionModel<T>;

    /// Create a transcription model with the given name.
    ///
    /// # Example
    /// ```
    /// use rig::providers::groq::{Client, self};
    ///
    /// // Initialize the Groq client
    /// let groq = Client::new("your-groq-api-key");
    ///
    /// let whisper = groq.transcription_model(groq::WHISPER_LARGE_V3);
    /// ```
    fn transcription_model(&self, model: &str) -> Self::TranscriptionModel {
        TranscriptionModel::new(self.clone(), model)
    }
}
214
impl<T> VerifyClient for Client<T>
where
    T: HttpClientExt + Clone + Send + std::fmt::Debug + Default + 'static,
{
    /// Verify the API key by listing models. 401 maps to
    /// `InvalidAuthentication`; 500/502/503 surface the response body as a
    /// `ProviderError`.
    #[cfg_attr(feature = "worker", worker::send)]
    async fn verify(&self) -> Result<(), VerifyError> {
        let req = self
            .get("/models")?
            .body(http_client::NoBody)
            .map_err(http_client::Error::from)?;

        let response = HttpClientExt::send(&self.http_client, req).await?;

        match response.status() {
            reqwest::StatusCode::OK => Ok(()),
            reqwest::StatusCode::UNAUTHORIZED => Err(VerifyError::InvalidAuthentication),
            reqwest::StatusCode::INTERNAL_SERVER_ERROR
            | reqwest::StatusCode::SERVICE_UNAVAILABLE
            | reqwest::StatusCode::BAD_GATEWAY => {
                let text = http_client::text(response).await?;
                Err(VerifyError::ProviderError(text))
            }
            _ => {
                // NOTE(review): every other non-OK status (e.g. 403, 429) is
                // currently treated as a successful verification; the
                // commented-out call below would have surfaced them — confirm
                // whether this leniency is intentional.
                //response.error_for_status()?;
                Ok(())
            }
        }
    }
}
244
// Groq exposes no embeddings, image-generation, or audio-generation API via
// this client; the macro presumably provides the corresponding "not
// supported" conversions for `Client<T>` — see the macro definition.
impl_conversion_traits!(
    AsEmbeddings,
    AsImageGeneration,
    AsAudioGeneration for Client<T>
);
250
/// Error payload returned by the Groq API.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    message: String,
}

/// Untagged wrapper: a Groq response body deserializes either as the expected
/// payload `T` or as an [`ApiErrorResponse`].
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    Ok(T),
    Err(ApiErrorResponse),
}
262
/// Groq wire-format chat message.
#[derive(Debug, Serialize, Deserialize)]
pub struct Message {
    // "system", "user" or "assistant" (see the conversions below).
    pub role: String,
    // Plain-text body; may be absent on the wire.
    pub content: Option<String>,
    // Reasoning text returned by reasoning-capable models; omitted from the
    // serialized form when not present.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub reasoning: Option<String>,
}
270
271impl TryFrom<Message> for message::Message {
272    type Error = message::MessageError;
273
274    fn try_from(message: Message) -> Result<Self, Self::Error> {
275        match message.role.as_str() {
276            "user" => Ok(Self::User {
277                content: OneOrMany::one(
278                    message
279                        .content
280                        .map(|content| message::UserContent::text(&content))
281                        .ok_or_else(|| {
282                            message::MessageError::ConversionError("Empty user message".to_string())
283                        })?,
284                ),
285            }),
286            "assistant" => Ok(Self::Assistant {
287                id: None,
288                content: OneOrMany::one(
289                    message
290                        .content
291                        .map(|content| message::AssistantContent::text(&content))
292                        .ok_or_else(|| {
293                            message::MessageError::ConversionError(
294                                "Empty assistant message".to_string(),
295                            )
296                        })?,
297                ),
298            }),
299            _ => Err(message::MessageError::ConversionError(format!(
300                "Unknown role: {}",
301                message.role
302            ))),
303        }
304    }
305}
306
impl TryFrom<message::Message> for Message {
    type Error = message::MessageError;

    /// Flatten a rig message into Groq's single-text wire format.
    ///
    /// # Errors
    /// Fails when an assistant message carries tool calls, which this wire
    /// representation cannot express.
    fn try_from(message: message::Message) -> Result<Self, Self::Error> {
        match message {
            message::Message::User { content } => Ok(Self {
                role: "user".to_string(),
                // Only the first text part is kept; non-text user content
                // (images, documents, ...) is silently dropped here.
                content: content.iter().find_map(|c| match c {
                    message::UserContent::Text(text) => Some(text.text.clone()),
                    _ => None,
                }),
                reasoning: None,
            }),
            message::Message::Assistant { content, .. } => {
                let mut text_content: Option<String> = None;
                let mut groq_reasoning: Option<String> = None;

                for c in content.iter() {
                    match c {
                        message::AssistantContent::Text(text) => {
                            // Multiple text parts are joined with newlines.
                            text_content = Some(
                                text_content
                                    .map(|mut existing| {
                                        existing.push('\n');
                                        existing.push_str(&text.text);
                                        existing
                                    })
                                    .unwrap_or_else(|| text.text.clone()),
                            );
                        }
                        message::AssistantContent::ToolCall(_tool_call) => {
                            return Err(MessageError::ConversionError(
                                "Tool calls do not exist on this message".into(),
                            ));
                        }
                        message::AssistantContent::Reasoning(message::Reasoning {
                            reasoning,
                            ..
                        }) => {
                            // Only the first reasoning segment is kept; a later
                            // Reasoning part overwrites an earlier one.
                            groq_reasoning =
                                Some(reasoning.first().cloned().unwrap_or(String::new()));
                        }
                    }
                }

                Ok(Self {
                    role: "assistant".to_string(),
                    content: text_content,
                    reasoning: groq_reasoning,
                })
            }
        }
    }
}
361
362// ================================================================
363// Groq Completion API
364// ================================================================
365/// The `deepseek-r1-distill-llama-70b` model. Used for chat completion.
366pub const DEEPSEEK_R1_DISTILL_LLAMA_70B: &str = "deepseek-r1-distill-llama-70b";
367/// The `gemma2-9b-it` model. Used for chat completion.
368pub const GEMMA2_9B_IT: &str = "gemma2-9b-it";
369/// The `llama-3.1-8b-instant` model. Used for chat completion.
370pub const LLAMA_3_1_8B_INSTANT: &str = "llama-3.1-8b-instant";
371/// The `llama-3.2-11b-vision-preview` model. Used for chat completion.
372pub const LLAMA_3_2_11B_VISION_PREVIEW: &str = "llama-3.2-11b-vision-preview";
373/// The `llama-3.2-1b-preview` model. Used for chat completion.
374pub const LLAMA_3_2_1B_PREVIEW: &str = "llama-3.2-1b-preview";
375/// The `llama-3.2-3b-preview` model. Used for chat completion.
376pub const LLAMA_3_2_3B_PREVIEW: &str = "llama-3.2-3b-preview";
377/// The `llama-3.2-90b-vision-preview` model. Used for chat completion.
378pub const LLAMA_3_2_90B_VISION_PREVIEW: &str = "llama-3.2-90b-vision-preview";
379/// The `llama-3.2-70b-specdec` model. Used for chat completion.
380pub const LLAMA_3_2_70B_SPECDEC: &str = "llama-3.2-70b-specdec";
381/// The `llama-3.2-70b-versatile` model. Used for chat completion.
382pub const LLAMA_3_2_70B_VERSATILE: &str = "llama-3.2-70b-versatile";
383/// The `llama-guard-3-8b` model. Used for chat completion.
384pub const LLAMA_GUARD_3_8B: &str = "llama-guard-3-8b";
385/// The `llama3-70b-8192` model. Used for chat completion.
386pub const LLAMA_3_70B_8192: &str = "llama3-70b-8192";
387/// The `llama3-8b-8192` model. Used for chat completion.
388pub const LLAMA_3_8B_8192: &str = "llama3-8b-8192";
389/// The `mixtral-8x7b-32768` model. Used for chat completion.
390pub const MIXTRAL_8X7B_32768: &str = "mixtral-8x7b-32768";
391
/// A Groq chat-completion model bound to a [`Client`].
#[derive(Clone, Debug)]
pub struct CompletionModel<T> {
    client: Client<T>,
    /// Name of the model (e.g.: deepseek-r1-distill-llama-70b)
    pub model: String,
}
398
impl<T> CompletionModel<T> {
    /// Bind `model` to `client`.
    pub fn new(client: Client<T>, model: &str) -> Self {
        Self {
            client,
            model: model.to_string(),
        }
    }

    /// Assemble the JSON body for Groq's `/chat/completions` endpoint from a
    /// rig [`CompletionRequest`].
    ///
    /// Message order is: system preamble (if any), normalized documents, then
    /// chat history. Tool definitions — and `reasoning_format` — are attached
    /// only when the request carries tools. `additional_params` is merged
    /// last and can therefore override any generated field.
    fn create_completion_request(
        &self,
        completion_request: CompletionRequest,
    ) -> Result<Value, CompletionError> {
        // Build up the order of messages (context, chat_history, prompt)
        let mut partial_history = vec![];
        if let Some(docs) = completion_request.normalized_documents() {
            partial_history.push(docs);
        }
        partial_history.extend(completion_request.chat_history);

        // Initialize full history with preamble (or empty if non-existent)
        let mut full_history: Vec<Message> =
            completion_request
                .preamble
                .map_or_else(Vec::new, |preamble| {
                    vec![Message {
                        role: "system".to_string(),
                        content: Some(preamble),
                        reasoning: None,
                    }]
                });

        // Convert and extend the rest of the history; fails if any message
        // cannot be represented in Groq's wire format (e.g. tool calls).
        full_history.extend(
            partial_history
                .into_iter()
                .map(message::Message::try_into)
                .collect::<Result<Vec<Message>, _>>()?,
        );

        let tool_choice = completion_request
            .tool_choice
            .map(crate::providers::openai::ToolChoice::try_from)
            .transpose()?;

        let request = if completion_request.tools.is_empty() {
            json!({
                "model": self.model,
                "messages": full_history,
                "temperature": completion_request.temperature,
            })
        } else {
            json!({
                "model": self.model,
                "messages": full_history,
                "temperature": completion_request.temperature,
                "tools": completion_request.tools.into_iter().map(ToolDefinition::from).collect::<Vec<_>>(),
                "tool_choice": tool_choice,
                "reasoning_format": "parsed"
            })
        };

        // Caller-supplied params win over generated fields.
        let request = if let Some(params) = completion_request.additional_params {
            json_utils::merge(request, params)
        } else {
            request
        };

        Ok(request)
    }
}
469
470impl<T> completion::CompletionModel for CompletionModel<T>
471where
472    T: HttpClientExt + Clone + Send + std::fmt::Debug + Default + 'static,
473{
474    type Response = CompletionResponse;
475    type StreamingResponse = StreamingCompletionResponse;
476
477    #[cfg_attr(feature = "worker", worker::send)]
478    async fn completion(
479        &self,
480        completion_request: CompletionRequest,
481    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
482        let preamble = completion_request.preamble.clone();
483
484        let request = self.create_completion_request(completion_request)?;
485        let span = if tracing::Span::current().is_disabled() {
486            info_span!(
487                target: "rig::completions",
488                "chat",
489                gen_ai.operation.name = "chat",
490                gen_ai.provider.name = "groq",
491                gen_ai.request.model = self.model,
492                gen_ai.system_instructions = preamble,
493                gen_ai.response.id = tracing::field::Empty,
494                gen_ai.response.model = tracing::field::Empty,
495                gen_ai.usage.output_tokens = tracing::field::Empty,
496                gen_ai.usage.input_tokens = tracing::field::Empty,
497                gen_ai.input.messages = serde_json::to_string(&request.get("messages").unwrap()).unwrap(),
498                gen_ai.output.messages = tracing::field::Empty,
499            )
500        } else {
501            tracing::Span::current()
502        };
503
504        let body = serde_json::to_vec(&request)?;
505        let req = self
506            .client
507            .req(Method::POST, "/chat/completions")?
508            .header("Content-Type", "application/json")
509            .body(body)
510            .map_err(|e| http_client::Error::Instance(e.into()))?;
511
512        let async_block = async move {
513            let response = self.client.http_client.send::<_, Bytes>(req).await?;
514            let status = response.status();
515            let response_body = response.into_body().into_future().await?.to_vec();
516
517            if status.is_success() {
518                match serde_json::from_slice::<ApiResponse<CompletionResponse>>(&response_body)? {
519                    ApiResponse::Ok(response) => {
520                        let span = tracing::Span::current();
521                        span.record("gen_ai.response.id", response.id.clone());
522                        span.record("gen_ai.response.model_name", response.model.clone());
523                        span.record(
524                            "gen_ai.output.messages",
525                            serde_json::to_string(&response.choices).unwrap(),
526                        );
527                        if let Some(ref usage) = response.usage {
528                            span.record("gen_ai.usage.input_tokens", usage.prompt_tokens);
529                            span.record(
530                                "gen_ai.usage.output_tokens",
531                                usage.total_tokens - usage.prompt_tokens,
532                            );
533                        }
534                        response.try_into()
535                    }
536                    ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
537                }
538            } else {
539                Err(CompletionError::ProviderError(
540                    String::from_utf8_lossy(&response_body).to_string(),
541                ))
542            }
543        };
544
545        tracing::Instrument::instrument(async_block, span).await
546    }
547
548    #[cfg_attr(feature = "worker", worker::send)]
549    async fn stream(
550        &self,
551        request: CompletionRequest,
552    ) -> Result<
553        crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
554        CompletionError,
555    > {
556        let preamble = request.preamble.clone();
557        let mut request = self.create_completion_request(request)?;
558
559        request = merge(
560            request,
561            json!({"stream": true, "stream_options": {"include_usage": true}}),
562        );
563
564        let body = serde_json::to_vec(&request)?;
565        let req = self
566            .client
567            .req(Method::POST, "/chat/completions")?
568            .header("Content-Type", "application/json")
569            .body(body)
570            .map_err(|e| http_client::Error::Instance(e.into()))?;
571
572        let span = if tracing::Span::current().is_disabled() {
573            info_span!(
574                target: "rig::completions",
575                "chat_streaming",
576                gen_ai.operation.name = "chat_streaming",
577                gen_ai.provider.name = "groq",
578                gen_ai.request.model = self.model,
579                gen_ai.system_instructions = preamble,
580                gen_ai.response.id = tracing::field::Empty,
581                gen_ai.response.model = tracing::field::Empty,
582                gen_ai.usage.output_tokens = tracing::field::Empty,
583                gen_ai.usage.input_tokens = tracing::field::Empty,
584                gen_ai.input.messages = serde_json::to_string(&request.get("messages").unwrap()).unwrap(),
585                gen_ai.output.messages = tracing::field::Empty,
586            )
587        } else {
588            tracing::Span::current()
589        };
590
591        tracing::Instrument::instrument(
592            send_compatible_streaming_request(self.client.http_client.clone(), req),
593            span,
594        )
595        .await
596    }
597}
598
599// ================================================================
600// Groq Transcription API
601// ================================================================
602pub const WHISPER_LARGE_V3: &str = "whisper-large-v3";
603pub const WHISPER_LARGE_V3_TURBO: &str = "whisper-large-v3-turbo";
604pub const DISTIL_WHISPER_LARGE_V3: &str = "distil-whisper-large-v3-en";
605
/// A Groq transcription model bound to a [`Client`].
#[derive(Clone)]
pub struct TranscriptionModel<T> {
    client: Client<T>,
    /// Name of the model (e.g.: whisper-large-v3)
    pub model: String,
}
612
613impl<T> TranscriptionModel<T> {
614    pub fn new(client: Client<T>, model: &str) -> Self {
615        Self {
616            client,
617            model: model.to_string(),
618        }
619    }
620}
621impl<T> transcription::TranscriptionModel for TranscriptionModel<T>
622where
623    T: HttpClientExt + Clone + Send + std::fmt::Debug + Default + 'static,
624{
625    type Response = TranscriptionResponse;
626
627    #[cfg_attr(feature = "worker", worker::send)]
628    async fn transcription(
629        &self,
630        request: transcription::TranscriptionRequest,
631    ) -> Result<
632        transcription::TranscriptionResponse<Self::Response>,
633        transcription::TranscriptionError,
634    > {
635        let data = request.data;
636
637        let mut body = reqwest::multipart::Form::new()
638            .text("model", self.model.clone())
639            .part(
640                "file",
641                Part::bytes(data).file_name(request.filename.clone()),
642            );
643
644        if let Some(language) = request.language {
645            body = body.text("language", language);
646        }
647
648        if let Some(prompt) = request.prompt {
649            body = body.text("prompt", prompt.clone());
650        }
651
652        if let Some(ref temperature) = request.temperature {
653            body = body.text("temperature", temperature.to_string());
654        }
655
656        if let Some(ref additional_params) = request.additional_params {
657            for (key, value) in additional_params
658                .as_object()
659                .expect("Additional Parameters to OpenAI Transcription should be a map")
660            {
661                body = body.text(key.to_owned(), value.to_string());
662            }
663        }
664
665        let req = self
666            .client
667            .req(Method::POST, "/audio/transcriptions")?
668            .body(body)
669            .unwrap();
670
671        let response = self
672            .client
673            .http_client
674            .send_multipart::<Bytes>(req)
675            .await
676            .unwrap();
677
678        let status = response.status();
679        let response_body = response.into_body().into_future().await?.to_vec();
680
681        if status.is_success() {
682            match serde_json::from_slice::<ApiResponse<TranscriptionResponse>>(&response_body)? {
683                ApiResponse::Ok(response) => response.try_into(),
684                ApiResponse::Err(api_error_response) => Err(TranscriptionError::ProviderError(
685                    api_error_response.message,
686                )),
687            }
688        } else {
689            Err(TranscriptionError::ProviderError(
690                String::from_utf8_lossy(&response_body).to_string(),
691            ))
692        }
693    }
694}
695
/// One incremental `delta` from a streamed chat chunk: either a reasoning
/// fragment or regular message content (text and/or tool-call fragments).
#[derive(Deserialize, Debug)]
#[serde(untagged)]
pub enum StreamingDelta {
    Reasoning {
        reasoning: String,
    },
    MessageContent {
        #[serde(default)]
        content: Option<String>,
        // `tool_calls` may be JSON null on the wire; treat that as empty.
        #[serde(default, deserialize_with = "json_utils::null_or_vec")]
        tool_calls: Vec<StreamingToolCall>,
    },
}
709
/// A single choice within a streamed chunk.
#[derive(Deserialize, Debug)]
struct StreamingChoice {
    delta: StreamingDelta,
}

/// One SSE payload from the streaming completions endpoint. `usage` is
/// expected only on the usage chunk requested via `include_usage`.
#[derive(Deserialize, Debug)]
struct StreamingCompletionChunk {
    choices: Vec<StreamingChoice>,
    usage: Option<Usage>,
}
720
/// Final payload attached to the end of a completion stream, carrying the
/// token usage reported by Groq.
#[derive(Clone, Deserialize, Serialize, Debug)]
pub struct StreamingCompletionResponse {
    pub usage: Usage,
}
725
726impl GetTokenUsage for StreamingCompletionResponse {
727    fn token_usage(&self) -> Option<crate::completion::Usage> {
728        let mut usage = crate::completion::Usage::new();
729
730        usage.input_tokens = self.usage.prompt_tokens as u64;
731        usage.total_tokens = self.usage.total_tokens as u64;
732        usage.output_tokens = self.usage.total_tokens as u64 - self.usage.prompt_tokens as u64;
733
734        Some(usage)
735    }
736}
737
/// Drive a streaming `/chat/completions` request over SSE and adapt the
/// chunks into rig's streaming choices.
///
/// Tool-call fragments are accumulated per choice index until complete;
/// token usage from the usage chunk is attached as the stream's final
/// response. The returned stream is instrumented with the caller's span.
pub async fn send_compatible_streaming_request<T>(
    client: T,
    req: Request<Vec<u8>>,
) -> Result<
    crate::streaming::StreamingCompletionResponse<StreamingCompletionResponse>,
    CompletionError,
>
where
    T: HttpClientExt + Clone + 'static,
{
    let span = tracing::Span::current();

    let mut event_source = GenericEventSource::new(client, req);

    let stream = stream! {
        let span = tracing::Span::current();
        let mut final_usage = Usage {
            prompt_tokens: 0,
            total_tokens: 0
        };

        let mut text_response = String::new();

        // Tool calls under assembly, keyed by choice index:
        // (id, name, accumulated-arguments-json).
        let mut calls: HashMap<usize, (String, String, String)> = HashMap::new();

        while let Some(event_result) = event_source.next().await {
            match event_result {
                Ok(Event::Open) => {
                    tracing::trace!("SSE connection opened");
                    continue;
                }

                Ok(Event::Message(message)) => {
                    let data_str = message.data.trim();

                    // Non-chunk payloads (e.g. the "[DONE]" sentinel) fail to
                    // parse and are skipped rather than treated as fatal.
                    let parsed = serde_json::from_str::<StreamingCompletionChunk>(data_str);
                    let Ok(data) = parsed else {
                        let err = parsed.unwrap_err();
                        tracing::debug!("Couldn't parse SSE payload as StreamingCompletionChunk: {:?}", err);
                        continue;
                    };

                    if let Some(choice) = data.choices.first() {
                        match &choice.delta {
                            StreamingDelta::Reasoning { reasoning } => {
                                yield Ok(crate::streaming::RawStreamingChoice::Reasoning {
                                    id: None,
                                    reasoning: reasoning.to_string(),
                                    signature: None,
                                });
                            }

                            StreamingDelta::MessageContent { content, tool_calls } => {
                                // Handle tool calls
                                for tool_call in tool_calls {
                                    let function = &tool_call.function;

                                    // Start of tool call: a name with no
                                    // arguments yet — open an accumulator.
                                    if function.name.as_ref().map(|s| !s.is_empty()).unwrap_or(false)
                                        && function.arguments.is_empty()
                                    {
                                        let id = tool_call.id.clone().unwrap_or_default();
                                        let name = function.name.clone().unwrap();
                                        calls.insert(tool_call.index, (id, name, String::new()));
                                    }
                                    // Continuation: nameless fragment with
                                    // arguments — append to the accumulator.
                                    else if function.name.as_ref().map(|s| s.is_empty()).unwrap_or(true)
                                        && !function.arguments.is_empty()
                                    {
                                        if let Some((id, name, existing_args)) = calls.get(&tool_call.index) {
                                            let combined = format!("{}{}", existing_args, function.arguments);
                                            calls.insert(tool_call.index, (id.clone(), name.clone(), combined));
                                        } else {
                                            tracing::debug!("Partial tool call received but tool call was never started.");
                                        }
                                    }
                                    // Complete tool call: name and arguments
                                    // arrived in a single delta — emit now.
                                    else {
                                        let id = tool_call.id.clone().unwrap_or_default();
                                        let name = function.name.clone().unwrap_or_default();
                                        let arguments_str = function.arguments.clone();

                                        let Ok(arguments_json) = serde_json::from_str::<serde_json::Value>(&arguments_str) else {
                                            tracing::debug!("Couldn't parse tool call args '{}'", arguments_str);
                                            continue;
                                        };

                                        yield Ok(crate::streaming::RawStreamingChoice::ToolCall {
                                            id,
                                            name,
                                            arguments: arguments_json,
                                            call_id: None
                                        });
                                    }
                                }

                                // Streamed content
                                if let Some(content) = content {
                                    text_response += content;
                                    yield Ok(crate::streaming::RawStreamingChoice::Message(content.clone()));
                                }
                            }
                        }
                    }

                    // Usage arrives on its own chunk; remember the last seen.
                    if let Some(usage) = data.usage {
                        final_usage = usage.clone();
                    }
                }

                Err(crate::http_client::Error::StreamEnded) => break,
                Err(err) => {
                    tracing::error!(?err, "SSE error");
                    yield Err(CompletionError::ResponseError(err.to_string()));
                    break;
                }
            }
        }

        event_source.close();

        let mut tool_calls = Vec::new();
        // Flush accumulated tool calls
        for (_, (id, name, arguments)) in calls {
            // Accumulators whose arguments never became valid JSON are dropped.
            let Ok(arguments_json) = serde_json::from_str::<serde_json::Value>(&arguments) else {
                continue;
            };

            tool_calls.push(rig::providers::openai::completion::ToolCall {
                id: id.clone(),
                r#type: ToolType::Function,
                function: Function {
                    name: name.clone(),
                    arguments: arguments_json.clone()
                }
            });
            yield Ok(crate::streaming::RawStreamingChoice::ToolCall {
                id,
                name,
                arguments: arguments_json,
                call_id: None,
            });
        }

        // Reconstruct the full assistant message for the tracing span.
        let response_message = crate::providers::openai::completion::Message::Assistant {
            content: vec![AssistantContent::Text { text: text_response }],
            refusal: None,
            audio: None,
            name: None,
            tool_calls
        };

        span.record("gen_ai.output.messages", serde_json::to_string(&vec![response_message]).unwrap());
        span.record("gen_ai.usage.input_tokens", final_usage.prompt_tokens);
        span.record("gen_ai.usage.output_tokens", final_usage.total_tokens - final_usage.prompt_tokens);

        // Final response
        yield Ok(crate::streaming::RawStreamingChoice::FinalResponse(
            StreamingCompletionResponse { usage: final_usage.clone() }
        ));
    }.instrument(span);

    Ok(crate::streaming::StreamingCompletionResponse::stream(
        Box::pin(stream),
    ))
}