//! Streaming support for the OpenAI chat-completions provider.
//! Source path: rig/providers/openai/completion/streaming.rs

1use std::collections::HashMap;
2
3use async_stream::stream;
4use futures::StreamExt;
5use http::Request;
6use serde::{Deserialize, Serialize};
7use serde_json::json;
8use tracing::{Level, enabled, info_span};
9use tracing_futures::Instrument;
10
11use crate::completion::{CompletionError, CompletionRequest, GetTokenUsage};
12use crate::http_client::HttpClientExt;
13use crate::http_client::sse::{Event, GenericEventSource};
14use crate::json_utils::{self, merge};
15use crate::providers::openai::completion::{CompletionModel, OpenAIRequestParams, Usage};
16use crate::streaming::{self, RawStreamingChoice};
17
18// ================================================================
19// OpenAI Completion Streaming API
20// ================================================================
/// The `function` portion of a streaming tool-call fragment.
///
/// Both fields are optional because OpenAI streams tool calls incrementally:
/// the name typically arrives on the first fragment, and `arguments` arrives
/// as partial JSON text spread over many fragments.
#[derive(Deserialize, Debug)]
pub(crate) struct StreamingFunction {
    pub(crate) name: Option<String>,
    pub(crate) arguments: Option<String>,
}
26
/// A (possibly partial) tool-call fragment inside a streaming delta.
#[derive(Deserialize, Debug)]
pub(crate) struct StreamingToolCall {
    // Slot this fragment belongs to; fragments sharing an index are merged.
    pub(crate) index: usize,
    // Provider-assigned call id; typically only present on the first fragment.
    pub(crate) id: Option<String>,
    pub(crate) function: StreamingFunction,
}
33
/// The `delta` payload of a streaming choice: incremental text content,
/// optional reasoning content, and/or tool-call fragments.
#[derive(Deserialize, Debug)]
struct StreamingDelta {
    #[serde(default)]
    content: Option<String>,
    #[serde(default)]
    reasoning_content: Option<String>, // This is not part of the official OpenAI API
    // `null_or_vec` — presumably maps a JSON `null` to an empty vec, since some
    // compatible providers send `"tool_calls": null` instead of omitting it.
    #[serde(default, deserialize_with = "json_utils::null_or_vec")]
    tool_calls: Vec<StreamingToolCall>,
}
43
/// The `finish_reason` reported on the terminal chunk of a choice.
#[derive(Deserialize, Debug, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum FinishReason {
    ToolCalls,
    Stop,
    ContentFilter,
    Length,
    /// Catch-all for non-standard values; this will handle the deprecated
    /// `function_call` finish reason among others.
    #[serde(untagged)]
    Other(String),
}
54
/// One choice in a streaming chunk: an incremental delta plus an optional
/// finish reason on the final chunk for that choice.
#[derive(Deserialize, Debug)]
struct StreamingChoice {
    delta: StreamingDelta,
    finish_reason: Option<FinishReason>,
}
60
/// A single SSE `data:` payload from the chat-completions stream.
///
/// `choices` may be empty on the final "usage-only" chunk some providers emit
/// when `stream_options.include_usage` is set.
#[derive(Deserialize, Debug)]
struct StreamingCompletionChunk {
    choices: Vec<StreamingChoice>,
    usage: Option<Usage>,
}
66
/// Final payload surfaced once the stream completes, carrying the
/// provider-reported token usage.
#[derive(Clone, Serialize, Deserialize)]
pub struct StreamingCompletionResponse {
    pub usage: Usage,
}
71
72impl GetTokenUsage for StreamingCompletionResponse {
73    fn token_usage(&self) -> Option<crate::completion::Usage> {
74        let mut usage = crate::completion::Usage::new();
75        usage.input_tokens = self.usage.prompt_tokens as u64;
76        usage.output_tokens = self.usage.total_tokens as u64 - self.usage.prompt_tokens as u64;
77        usage.total_tokens = self.usage.total_tokens as u64;
78        Some(usage)
79    }
80}
81
impl<T> CompletionModel<T>
where
    T: HttpClientExt + Clone + 'static,
{
    /// Builds and sends a streaming request to the OpenAI chat-completions
    /// endpoint.
    ///
    /// Converts the generic [`CompletionRequest`] into the OpenAI wire shape,
    /// forces `stream: true` with `stream_options.include_usage` so the
    /// provider emits a final usage chunk, wraps the call in a `gen_ai`
    /// tracing span, and delegates to [`send_compatible_streaming_request`].
    ///
    /// # Errors
    /// Propagates conversion and serialization failures as
    /// [`CompletionError`], and HTTP builder failures as
    /// [`CompletionError::HttpError`].
    pub(crate) async fn stream(
        &self,
        completion_request: CompletionRequest,
    ) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError>
    {
        let request = super::CompletionRequest::try_from(OpenAIRequestParams {
            model: self.model.clone(),
            request: completion_request,
            strict_tools: self.strict_tools,
            tool_result_array_content: self.tool_result_array_content,
        })?;
        // Serialized separately so the input messages can be attached to the
        // tracing span below.
        let request_messages = serde_json::to_string(&request.messages)
            .expect("Converting to JSON from a Rust struct shouldn't fail");
        let mut request_as_json = serde_json::to_value(request).expect("this should never fail");

        // Force streaming and request a final usage chunk from the provider.
        request_as_json = merge(
            request_as_json,
            json!({"stream": true, "stream_options": {"include_usage": true}}),
        );

        // Guarded so the pretty-print cost is only paid when TRACE is enabled.
        if enabled!(Level::TRACE) {
            tracing::trace!(
                target: "rig::completions",
                "OpenAI Chat Completions streaming completion request: {}",
                serde_json::to_string_pretty(&request_as_json)?
            );
        }

        let req_body = serde_json::to_vec(&request_as_json)?;

        let req = self
            .client
            .post("/chat/completions")?
            .body(req_body)
            .map_err(|e| CompletionError::HttpError(e.into()))?;

        // Reuse the ambient span if the caller already opened one; otherwise
        // open a fresh chat span following the gen_ai semantic conventions.
        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat",
                gen_ai.operation.name = "chat",
                gen_ai.provider.name = "openai",
                gen_ai.request.model = self.model,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = self.model,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.input.messages = request_messages,
                gen_ai.output.messages = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        let client = self.client.clone();

        tracing::Instrument::instrument(send_compatible_streaming_request(client, req), span).await
    }
}
145
146pub async fn send_compatible_streaming_request<T>(
147    http_client: T,
148    req: Request<Vec<u8>>,
149) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError>
150where
151    T: HttpClientExt + Clone + 'static,
152{
153    let span = tracing::Span::current();
154    // Build the request with proper headers for SSE
155    let mut event_source = GenericEventSource::new(http_client, req);
156
157    let stream = stream! {
158        let span = tracing::Span::current();
159
160        // Accumulate tool calls by index while streaming
161        let mut tool_calls: HashMap<usize, streaming::RawStreamingToolCall> = HashMap::new();
162        let mut text_content = String::new();
163        let mut final_usage = None;
164
165        while let Some(event_result) = event_source.next().await {
166            match event_result {
167                Ok(Event::Open) => {
168                    tracing::trace!("SSE connection opened");
169                    continue;
170                }
171
172                Ok(Event::Message(message)) => {
173                    if message.data.trim().is_empty() || message.data == "[DONE]" {
174                        continue;
175                    }
176
177                    let data = match serde_json::from_str::<StreamingCompletionChunk>(&message.data) {
178                        Ok(data) => data,
179                        Err(error) => {
180                            tracing::error!(?error, message = message.data, "Failed to parse SSE message");
181                            continue;
182                        }
183                    };
184
185                    // Usage updates (some providers send a final "usage-only" chunk with empty choices)
186                    if let Some(usage) = data.usage {
187                        final_usage = Some(usage);
188                    }
189
190                    // Expect at least one choice
191                     let Some(choice) = data.choices.first() else {
192                        tracing::debug!("There is no choice");
193                        continue;
194                    };
195                    let delta = &choice.delta;
196
197                    if !delta.tool_calls.is_empty() {
198                        for tool_call in &delta.tool_calls {
199                            let index = tool_call.index;
200
201                            // Some API gateways (e.g. LiteLLM, OneAPI) emit multiple
202                            // distinct tool calls all sharing index 0.  Detect this by
203                            // comparing the provider-supplied `id`: if a new, non-empty
204                            // id arrives for an index that already has a different id,
205                            // flush the old entry as a completed tool call first.
206                            if let Some(new_id) = &tool_call.id
207                                && !new_id.is_empty()
208                                && let Some(existing) = tool_calls.get(&index)
209                                && !existing.id.is_empty()
210                                && existing.id != *new_id
211                            {
212                                let evicted = tool_calls.remove(&index).expect("checked above");
213                                yield Ok(streaming::RawStreamingChoice::ToolCall(evicted));
214                            }
215
216                            let existing_tool_call = tool_calls.entry(index).or_insert_with(streaming::RawStreamingToolCall::empty);
217
218                            if let Some(id) = &tool_call.id && !id.is_empty() {
219                                existing_tool_call.id = id.clone();
220                            }
221
222                            if let Some(name) = &tool_call.function.name && !name.is_empty() {
223                                existing_tool_call.name = name.clone();
224                                yield Ok(streaming::RawStreamingChoice::ToolCallDelta {
225                                    id: existing_tool_call.id.clone(),
226                                    internal_call_id: existing_tool_call.internal_call_id.clone(),
227                                    content: streaming::ToolCallDeltaContent::Name(name.clone()),
228                                });
229                            }
230
231                            // Convert current arguments to string if needed
232                            if let Some(chunk) = &tool_call.function.arguments && !chunk.is_empty() {
233                                let current_args = match &existing_tool_call.arguments {
234                                    serde_json::Value::Null => String::new(),
235                                    serde_json::Value::String(s) => s.clone(),
236                                    v => v.to_string(),
237                                };
238
239                                // Concatenate the new chunk
240                                let combined = format!("{current_args}{chunk}");
241
242                                // Try to parse as JSON if it looks complete
243                                if combined.trim_start().starts_with('{') && combined.trim_end().ends_with('}') {
244                                    match serde_json::from_str(&combined) {
245                                        Ok(parsed) => existing_tool_call.arguments = parsed,
246                                        Err(_) => existing_tool_call.arguments = serde_json::Value::String(combined),
247                                    }
248                                } else {
249                                    existing_tool_call.arguments = serde_json::Value::String(combined);
250                                }
251
252                                // Emit the delta so UI can show progress
253                                yield Ok(streaming::RawStreamingChoice::ToolCallDelta {
254                                    id: existing_tool_call.id.clone(),
255                                    internal_call_id: existing_tool_call.internal_call_id.clone(),
256                                    content: streaming::ToolCallDeltaContent::Delta(chunk.clone()),
257                                });
258                            }
259                        }
260                    }
261
262                    // Streamed reasoning/thinking content (e.g. GLM-4, DeepSeek via compatible endpoint)
263                    if let Some(reasoning) = &delta.reasoning_content && !reasoning.is_empty() {
264                        yield Ok(streaming::RawStreamingChoice::ReasoningDelta {
265                            id: None,
266                            reasoning: reasoning.clone(),
267                        });
268                    }
269
270                    // Streamed text content
271                    if let Some(content) = &delta.content && !content.is_empty() {
272                        text_content += content;
273                        yield Ok(streaming::RawStreamingChoice::Message(content.clone()));
274                    }
275
276                    // Finish reason
277                    if let Some(finish_reason) = &choice.finish_reason && *finish_reason == FinishReason::ToolCalls {
278                        for (_idx, tool_call) in tool_calls.into_iter() {
279                            yield Ok(streaming::RawStreamingChoice::ToolCall(tool_call));
280                        }
281                        tool_calls = HashMap::new();
282                    }
283                }
284                Err(crate::http_client::Error::StreamEnded) => {
285                    break;
286                }
287                Err(error) => {
288                    tracing::error!(?error, "SSE error");
289                    yield Err(CompletionError::ProviderError(error.to_string()));
290                    break;
291                }
292            }
293        }
294
295
296        // Ensure event source is closed when stream ends
297        event_source.close();
298
299        // Flush any accumulated tool calls (that weren't emitted as ToolCall earlier)
300        for (_idx, tool_call) in tool_calls.into_iter() {
301            yield Ok(streaming::RawStreamingChoice::ToolCall(tool_call));
302        }
303
304        let final_usage = final_usage.unwrap_or_default();
305        if !span.is_disabled() {
306            span.record("gen_ai.usage.input_tokens", final_usage.prompt_tokens);
307            span.record("gen_ai.usage.output_tokens", final_usage.total_tokens - final_usage.prompt_tokens);
308        }
309
310        yield Ok(RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
311            usage: final_usage
312        }));
313    }.instrument(span);
314
315    Ok(streaming::StreamingCompletionResponse::stream(Box::pin(
316        stream,
317    )))
318}
319
#[cfg(test)]
mod tests {
    use super::*;

    // A complete function payload (name + full JSON arguments) deserializes.
    #[test]
    fn test_streaming_function_deserialization() {
        let json = r#"{"name": "get_weather", "arguments": "{\"location\":\"Paris\"}"}"#;
        let function: StreamingFunction = serde_json::from_str(json).unwrap();
        assert_eq!(function.name, Some("get_weather".to_string()));
        assert_eq!(
            function.arguments.as_ref().unwrap(),
            r#"{"location":"Paris"}"#
        );
    }

    // A full tool-call chunk (index + id + function) deserializes intact.
    #[test]
    fn test_streaming_tool_call_deserialization() {
        let json = r#"{
            "index": 0,
            "id": "call_abc123",
            "function": {
                "name": "get_weather",
                "arguments": "{\"city\":\"London\"}"
            }
        }"#;
        let tool_call: StreamingToolCall = serde_json::from_str(json).unwrap();
        assert_eq!(tool_call.index, 0);
        assert_eq!(tool_call.id, Some("call_abc123".to_string()));
        assert_eq!(tool_call.function.name, Some("get_weather".to_string()));
    }

    #[test]
    fn test_streaming_tool_call_partial_deserialization() {
        // Partial tool calls have no name and partial arguments
        let json = r#"{
            "index": 0,
            "id": null,
            "function": {
                "name": null,
                "arguments": "Paris"
            }
        }"#;
        let tool_call: StreamingToolCall = serde_json::from_str(json).unwrap();
        assert_eq!(tool_call.index, 0);
        assert!(tool_call.id.is_none());
        assert!(tool_call.function.name.is_none());
        assert_eq!(tool_call.function.arguments.as_ref().unwrap(), "Paris");
    }

    // A delta carrying tool calls but null text content deserializes.
    #[test]
    fn test_streaming_delta_with_tool_calls() {
        let json = r#"{
            "content": null,
            "tool_calls": [{
                "index": 0,
                "id": "call_xyz",
                "function": {
                    "name": "search",
                    "arguments": ""
                }
            }]
        }"#;
        let delta: StreamingDelta = serde_json::from_str(json).unwrap();
        assert!(delta.content.is_none());
        assert_eq!(delta.tool_calls.len(), 1);
        assert_eq!(delta.tool_calls[0].id, Some("call_xyz".to_string()));
    }

    // A whole chunk with one text choice and a usage block deserializes.
    #[test]
    fn test_streaming_chunk_deserialization() {
        let json = r#"{
            "choices": [{
                "delta": {
                    "content": "Hello",
                    "tool_calls": []
                }
            }],
            "usage": {
                "prompt_tokens": 10,
                "completion_tokens": 5,
                "total_tokens": 15
            }
        }"#;
        let chunk: StreamingCompletionChunk = serde_json::from_str(json).unwrap();
        assert_eq!(chunk.choices.len(), 1);
        assert_eq!(chunk.choices[0].delta.content, Some("Hello".to_string()));
        assert!(chunk.usage.is_some());
    }

    #[test]
    fn test_streaming_chunk_with_multiple_tool_call_deltas() {
        // Simulates multiple partial tool call chunks arriving
        let json_start = r#"{
            "choices": [{
                "delta": {
                    "content": null,
                    "tool_calls": [{
                        "index": 0,
                        "id": "call_123",
                        "function": {
                            "name": "get_weather",
                            "arguments": ""
                        }
                    }]
                }
            }],
            "usage": null
        }"#;

        let json_chunk1 = r#"{
            "choices": [{
                "delta": {
                    "content": null,
                    "tool_calls": [{
                        "index": 0,
                        "id": null,
                        "function": {
                            "name": null,
                            "arguments": "{\"loc"
                        }
                    }]
                }
            }],
            "usage": null
        }"#;

        let json_chunk2 = r#"{
            "choices": [{
                "delta": {
                    "content": null,
                    "tool_calls": [{
                        "index": 0,
                        "id": null,
                        "function": {
                            "name": null,
                            "arguments": "ation\":\"NYC\"}"
                        }
                    }]
                }
            }],
            "usage": null
        }"#;

        // Verify each chunk deserializes correctly
        let start_chunk: StreamingCompletionChunk = serde_json::from_str(json_start).unwrap();
        assert_eq!(start_chunk.choices[0].delta.tool_calls.len(), 1);
        assert_eq!(
            start_chunk.choices[0].delta.tool_calls[0]
                .function
                .name
                .as_ref()
                .unwrap(),
            "get_weather"
        );

        let chunk1: StreamingCompletionChunk = serde_json::from_str(json_chunk1).unwrap();
        assert_eq!(chunk1.choices[0].delta.tool_calls.len(), 1);
        assert_eq!(
            chunk1.choices[0].delta.tool_calls[0]
                .function
                .arguments
                .as_ref()
                .unwrap(),
            "{\"loc"
        );

        let chunk2: StreamingCompletionChunk = serde_json::from_str(json_chunk2).unwrap();
        assert_eq!(chunk2.choices[0].delta.tool_calls.len(), 1);
        assert_eq!(
            chunk2.choices[0].delta.tool_calls[0]
                .function
                .arguments
                .as_ref()
                .unwrap(),
            "ation\":\"NYC\"}"
        );
    }

    // End-to-end: a trailing usage-only chunk (empty `choices`) must still
    // populate the usage carried by the FinalResponse.
    #[tokio::test]
    async fn test_streaming_usage_only_chunk_is_not_ignored() {
        use crate::http_client::mock::MockStreamingClient;
        use bytes::Bytes;
        use futures::StreamExt;

        // Some providers emit a final "usage-only" chunk where `choices` is empty.
        let sse = concat!(
            "data: {\"choices\":[{\"delta\":{\"content\":\"Hello\",\"tool_calls\":[]}}],\"usage\":null}\n\n",
            "data: {\"choices\":[],\"usage\":{\"prompt_tokens\":10,\"completion_tokens\":5,\"total_tokens\":15}}\n\n",
            "data: [DONE]\n\n",
        );

        let client = MockStreamingClient {
            sse_bytes: Bytes::from(sse),
        };

        let req = http::Request::builder()
            .method("POST")
            .uri("http://localhost/v1/chat/completions")
            .body(Vec::new())
            .unwrap();

        let mut stream = send_compatible_streaming_request(client, req)
            .await
            .unwrap();

        // Drain the stream until the final response surfaces the usage.
        let mut final_usage = None;
        while let Some(chunk) = stream.next().await {
            if let streaming::StreamedAssistantContent::Final(res) = chunk.unwrap() {
                final_usage = Some(res.usage);
                break;
            }
        }

        let usage = final_usage.expect("expected a final response with usage");
        assert_eq!(usage.prompt_tokens, 10);
        assert_eq!(usage.total_tokens, 15);
    }

    /// Reproduces the bug where a proxy/gateway sends multiple parallel tool
    /// calls all sharing `index: 0` but with distinct `id` values.  Without
    /// the fix, rig merges both calls into one corrupted entry.
    #[tokio::test]
    async fn test_duplicate_index_different_id_tool_calls() {
        use crate::http_client::mock::MockStreamingClient;
        use bytes::Bytes;
        use futures::StreamExt;

        // Simulate a gateway that sends two tool calls both at index 0.
        // First tool call: id="call_aaa", name="command", args={"cmd":"ls"}
        // Second tool call: id="call_bbb", name="git", args={"action":"log"}
        let sse = concat!(
            // First tool call starts
            "data: {\"choices\":[{\"delta\":{\"tool_calls\":[{\"index\":0,\"id\":\"call_aaa\",\"function\":{\"name\":\"command\",\"arguments\":\"\"}}]},\"finish_reason\":null}],\"usage\":null}\n\n",
            // First tool call argument chunks
            "data: {\"choices\":[{\"delta\":{\"tool_calls\":[{\"index\":0,\"id\":null,\"function\":{\"name\":null,\"arguments\":\"{\\\"cmd\\\"\"}}]},\"finish_reason\":null}],\"usage\":null}\n\n",
            "data: {\"choices\":[{\"delta\":{\"tool_calls\":[{\"index\":0,\"id\":null,\"function\":{\"name\":null,\"arguments\":\":\\\"ls\\\"}\"}}]},\"finish_reason\":null}],\"usage\":null}\n\n",
            // Second tool call starts AT THE SAME index 0 but with a NEW id
            "data: {\"choices\":[{\"delta\":{\"tool_calls\":[{\"index\":0,\"id\":\"call_bbb\",\"function\":{\"name\":\"git\",\"arguments\":\"\"}}]},\"finish_reason\":null}],\"usage\":null}\n\n",
            // Second tool call argument chunks
            "data: {\"choices\":[{\"delta\":{\"tool_calls\":[{\"index\":0,\"id\":null,\"function\":{\"name\":null,\"arguments\":\"{\\\"action\\\"\"}}]},\"finish_reason\":null}],\"usage\":null}\n\n",
            "data: {\"choices\":[{\"delta\":{\"tool_calls\":[{\"index\":0,\"id\":null,\"function\":{\"name\":null,\"arguments\":\":\\\"log\\\"}\"}}]},\"finish_reason\":null}],\"usage\":null}\n\n",
            // Finish with tool_calls reason
            "data: {\"choices\":[{\"delta\":{\"tool_calls\":[]},\"finish_reason\":\"tool_calls\"}],\"usage\":null}\n\n",
            // Usage chunk
            "data: {\"choices\":[],\"usage\":{\"prompt_tokens\":20,\"completion_tokens\":10,\"total_tokens\":30}}\n\n",
            "data: [DONE]\n\n",
        );

        let client = MockStreamingClient {
            sse_bytes: Bytes::from(sse),
        };

        let req = http::Request::builder()
            .method("POST")
            .uri("http://localhost/v1/chat/completions")
            .body(Vec::new())
            .unwrap();

        let mut stream = send_compatible_streaming_request(client, req)
            .await
            .unwrap();

        // Collect every completed tool call emitted by the adapter.
        let mut collected_tool_calls = Vec::new();
        while let Some(chunk) = stream.next().await {
            if let streaming::StreamedAssistantContent::ToolCall {
                tool_call,
                internal_call_id: _,
            } = chunk.unwrap()
            {
                collected_tool_calls.push(tool_call);
            }
        }

        assert_eq!(
            collected_tool_calls.len(),
            2,
            "expected 2 separate tool calls, got {collected_tool_calls:?}"
        );

        assert_eq!(collected_tool_calls[0].id, "call_aaa");
        assert_eq!(collected_tool_calls[0].function.name, "command");
        assert_eq!(
            collected_tool_calls[0].function.arguments,
            serde_json::json!({"cmd": "ls"})
        );

        assert_eq!(collected_tool_calls[1].id, "call_bbb");
        assert_eq!(collected_tool_calls[1].function.name, "git");
        assert_eq!(
            collected_tool_calls[1].function.arguments,
            serde_json::json!({"action": "log"})
        );
    }
}