Skip to main content

rig/providers/openai/completion/
streaming.rs

1use std::collections::HashMap;
2
3use async_stream::stream;
4use futures::StreamExt;
5use http::Request;
6use serde::{Deserialize, Serialize};
7use serde_json::json;
8use tracing::{Level, enabled, info_span};
9use tracing_futures::Instrument;
10
11use crate::completion::{CompletionError, CompletionRequest, GetTokenUsage};
12use crate::http_client::HttpClientExt;
13use crate::http_client::sse::{Event, GenericEventSource};
14use crate::json_utils::{self, merge};
15use crate::providers::openai::completion::{self, CompletionModel, OpenAIRequestParams, Usage};
16use crate::streaming::{self, RawStreamingChoice};
17
18// ================================================================
19// OpenAI Completion Streaming API
20// ================================================================
21#[derive(Deserialize, Debug)]
22pub(crate) struct StreamingFunction {
23    pub(crate) name: Option<String>,
24    pub(crate) arguments: Option<String>,
25}
26
27#[derive(Deserialize, Debug)]
28pub(crate) struct StreamingToolCall {
29    pub(crate) index: usize,
30    pub(crate) id: Option<String>,
31    pub(crate) function: StreamingFunction,
32}
33
34#[derive(Deserialize, Debug)]
35struct StreamingDelta {
36    #[serde(default)]
37    content: Option<String>,
38    #[serde(default, deserialize_with = "json_utils::null_or_vec")]
39    tool_calls: Vec<StreamingToolCall>,
40}
41
42#[derive(Deserialize, Debug, PartialEq)]
43#[serde(rename_all = "snake_case")]
44pub enum FinishReason {
45    ToolCalls,
46    Stop,
47    ContentFilter,
48    Length,
49    #[serde(untagged)]
50    Other(String), // This will handle the deprecated function_call
51}
52
53#[derive(Deserialize, Debug)]
54struct StreamingChoice {
55    delta: StreamingDelta,
56    finish_reason: Option<FinishReason>,
57}
58
59#[derive(Deserialize, Debug)]
60struct StreamingCompletionChunk {
61    choices: Vec<StreamingChoice>,
62    usage: Option<Usage>,
63}
64
65#[derive(Clone, Serialize, Deserialize)]
66pub struct StreamingCompletionResponse {
67    pub usage: Usage,
68}
69
70impl GetTokenUsage for StreamingCompletionResponse {
71    fn token_usage(&self) -> Option<crate::completion::Usage> {
72        let mut usage = crate::completion::Usage::new();
73        usage.input_tokens = self.usage.prompt_tokens as u64;
74        usage.output_tokens = self.usage.total_tokens as u64 - self.usage.prompt_tokens as u64;
75        usage.total_tokens = self.usage.total_tokens as u64;
76        Some(usage)
77    }
78}
79
80impl<T> CompletionModel<T>
81where
82    T: HttpClientExt + Clone + 'static,
83{
84    pub(crate) async fn stream(
85        &self,
86        completion_request: CompletionRequest,
87    ) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError>
88    {
89        let request = super::CompletionRequest::try_from(OpenAIRequestParams {
90            model: self.model.clone(),
91            request: completion_request,
92            strict_tools: self.strict_tools,
93            tool_result_array_content: self.tool_result_array_content,
94        })?;
95        let request_messages = serde_json::to_string(&request.messages)
96            .expect("Converting to JSON from a Rust struct shouldn't fail");
97        let mut request_as_json = serde_json::to_value(request).expect("this should never fail");
98
99        request_as_json = merge(
100            request_as_json,
101            json!({"stream": true, "stream_options": {"include_usage": true}}),
102        );
103
104        if enabled!(Level::TRACE) {
105            tracing::trace!(
106                target: "rig::completions",
107                "OpenAI Chat Completions streaming completion request: {}",
108                serde_json::to_string_pretty(&request_as_json)?
109            );
110        }
111
112        let req_body = serde_json::to_vec(&request_as_json)?;
113
114        let req = self
115            .client
116            .post("/chat/completions")?
117            .body(req_body)
118            .map_err(|e| CompletionError::HttpError(e.into()))?;
119
120        let span = if tracing::Span::current().is_disabled() {
121            info_span!(
122                target: "rig::completions",
123                "chat",
124                gen_ai.operation.name = "chat",
125                gen_ai.provider.name = "openai",
126                gen_ai.request.model = self.model,
127                gen_ai.response.id = tracing::field::Empty,
128                gen_ai.response.model = self.model,
129                gen_ai.usage.output_tokens = tracing::field::Empty,
130                gen_ai.usage.input_tokens = tracing::field::Empty,
131                gen_ai.input.messages = request_messages,
132                gen_ai.output.messages = tracing::field::Empty,
133            )
134        } else {
135            tracing::Span::current()
136        };
137
138        let client = self.client.clone();
139
140        tracing::Instrument::instrument(send_compatible_streaming_request(client, req), span).await
141    }
142}
143
144pub async fn send_compatible_streaming_request<T>(
145    http_client: T,
146    req: Request<Vec<u8>>,
147) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError>
148where
149    T: HttpClientExt + Clone + 'static,
150{
151    let span = tracing::Span::current();
152    // Build the request with proper headers for SSE
153    let mut event_source = GenericEventSource::new(http_client, req);
154
155    let stream = stream! {
156        let span = tracing::Span::current();
157
158        // Accumulate tool calls by index while streaming
159        let mut tool_calls: HashMap<usize, streaming::RawStreamingToolCall> = HashMap::new();
160        let mut text_content = String::new();
161        let mut final_tool_calls: Vec<completion::ToolCall> = Vec::new();
162        let mut final_usage = None;
163
164        while let Some(event_result) = event_source.next().await {
165            match event_result {
166                Ok(Event::Open) => {
167                    tracing::trace!("SSE connection opened");
168                    continue;
169                }
170
171                Ok(Event::Message(message)) => {
172                    if message.data.trim().is_empty() || message.data == "[DONE]" {
173                        continue;
174                    }
175
176                    let data = match serde_json::from_str::<StreamingCompletionChunk>(&message.data) {
177                        Ok(data) => data,
178                        Err(error) => {
179                            tracing::error!(?error, message = message.data, "Failed to parse SSE message");
180                            continue;
181                        }
182                    };
183
184                    // Usage updates (some providers send a final "usage-only" chunk with empty choices)
185                    if let Some(usage) = data.usage {
186                        final_usage = Some(usage);
187                    }
188
189                    // Expect at least one choice
190                     let Some(choice) = data.choices.first() else {
191                        tracing::debug!("There is no choice");
192                        continue;
193                    };
194                    let delta = &choice.delta;
195
196                    if !delta.tool_calls.is_empty() {
197                        for tool_call in &delta.tool_calls {
198                            let index = tool_call.index;
199
200                            // Get or create tool call entry
201                            let existing_tool_call = tool_calls.entry(index).or_insert_with(streaming::RawStreamingToolCall::empty);
202
203                            // Update fields if present
204                            if let Some(id) = &tool_call.id && !id.is_empty() {
205                                    existing_tool_call.id = id.clone();
206                            }
207
208                            if let Some(name) = &tool_call.function.name && !name.is_empty() {
209                                    existing_tool_call.name = name.clone();
210                                    yield Ok(streaming::RawStreamingChoice::ToolCallDelta {
211                                        id: existing_tool_call.id.clone(),
212                                        internal_call_id: existing_tool_call.internal_call_id.clone(),
213                                        content: streaming::ToolCallDeltaContent::Name(name.clone()),
214                                    });
215                            }
216
217                                // Convert current arguments to string if needed
218                            if let Some(chunk) = &tool_call.function.arguments && !chunk.is_empty() {
219                                let current_args = match &existing_tool_call.arguments {
220                                    serde_json::Value::Null => String::new(),
221                                    serde_json::Value::String(s) => s.clone(),
222                                    v => v.to_string(),
223                                };
224
225                                // Concatenate the new chunk
226                                let combined = format!("{current_args}{chunk}");
227
228                                // Try to parse as JSON if it looks complete
229                                if combined.trim_start().starts_with('{') && combined.trim_end().ends_with('}') {
230                                    match serde_json::from_str(&combined) {
231                                        Ok(parsed) => existing_tool_call.arguments = parsed,
232                                        Err(_) => existing_tool_call.arguments = serde_json::Value::String(combined),
233                                    }
234                                } else {
235                                    existing_tool_call.arguments = serde_json::Value::String(combined);
236                                }
237
238                                // Emit the delta so UI can show progress
239                                yield Ok(streaming::RawStreamingChoice::ToolCallDelta {
240                                    id: existing_tool_call.id.clone(),
241                                    internal_call_id: existing_tool_call.internal_call_id.clone(),
242                                    content: streaming::ToolCallDeltaContent::Delta(chunk.clone()),
243                                });
244                            }
245                        }
246                    }
247
248                    // Streamed text content
249                    if let Some(content) = &delta.content && !content.is_empty() {
250                        text_content += content;
251                        yield Ok(streaming::RawStreamingChoice::Message(content.clone()));
252                    }
253
254                    // Finish reason
255                    if let Some(finish_reason) = &choice.finish_reason && *finish_reason == FinishReason::ToolCalls {
256                        for (_idx, tool_call) in tool_calls.into_iter() {
257                            final_tool_calls.push(completion::ToolCall {
258                                id: tool_call.id.clone(),
259                                r#type: completion::ToolType::Function,
260                                function: completion::Function {
261                                    name: tool_call.name.clone(),
262                                    arguments: tool_call.arguments.clone(),
263                                },
264                            });
265                            yield Ok(streaming::RawStreamingChoice::ToolCall(tool_call));
266                        }
267                        tool_calls = HashMap::new();
268                    }
269                }
270                Err(crate::http_client::Error::StreamEnded) => {
271                    break;
272                }
273                Err(error) => {
274                    tracing::error!(?error, "SSE error");
275                    yield Err(CompletionError::ProviderError(error.to_string()));
276                    break;
277                }
278            }
279        }
280
281
282        // Ensure event source is closed when stream ends
283        event_source.close();
284
285        // Flush any accumulated tool calls (that weren't emitted as ToolCall earlier)
286        for (_idx, tool_call) in tool_calls.into_iter() {
287            yield Ok(streaming::RawStreamingChoice::ToolCall(tool_call));
288        }
289
290        let final_usage = final_usage.unwrap_or_default();
291        if !span.is_disabled() {
292            span.record("gen_ai.usage.input_tokens", final_usage.prompt_tokens);
293            span.record("gen_ai.usage.output_tokens", final_usage.total_tokens - final_usage.prompt_tokens);
294        }
295
296        yield Ok(RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
297            usage: final_usage
298        }));
299    }.instrument(span);
300
301    Ok(streaming::StreamingCompletionResponse::stream(Box::pin(
302        stream,
303    )))
304}
305
#[cfg(test)]
mod tests {
    use super::*;

    /// A complete function payload yields both the name and the raw arguments string.
    #[test]
    fn test_streaming_function_deserialization() {
        let json = r#"{"name": "get_weather", "arguments": "{\"location\":\"Paris\"}"}"#;
        let function: StreamingFunction = serde_json::from_str(json).unwrap();
        assert_eq!(function.name, Some("get_weather".to_string()));
        assert_eq!(
            function.arguments.as_ref().unwrap(),
            r#"{"location":"Paris"}"#
        );
    }

    /// A complete tool call yields its index, id and function name.
    #[test]
    fn test_streaming_tool_call_deserialization() {
        let json = r#"{
            "index": 0,
            "id": "call_abc123",
            "function": {
                "name": "get_weather",
                "arguments": "{\"city\":\"London\"}"
            }
        }"#;
        let tool_call: StreamingToolCall = serde_json::from_str(json).unwrap();
        assert_eq!(tool_call.index, 0);
        assert_eq!(tool_call.id, Some("call_abc123".to_string()));
        assert_eq!(tool_call.function.name, Some("get_weather".to_string()));
    }

    /// `null` id and name deserialize to `None` while the argument fragment is kept.
    #[test]
    fn test_streaming_tool_call_partial_deserialization() {
        // Partial tool calls have no name and partial arguments
        let json = r#"{
            "index": 0,
            "id": null,
            "function": {
                "name": null,
                "arguments": "Paris"
            }
        }"#;
        let tool_call: StreamingToolCall = serde_json::from_str(json).unwrap();
        assert_eq!(tool_call.index, 0);
        assert!(tool_call.id.is_none());
        assert!(tool_call.function.name.is_none());
        assert_eq!(tool_call.function.arguments.as_ref().unwrap(), "Paris");
    }

    /// A delta with `null` content and one tool call keeps the tool call.
    #[test]
    fn test_streaming_delta_with_tool_calls() {
        let json = r#"{
            "content": null,
            "tool_calls": [{
                "index": 0,
                "id": "call_xyz",
                "function": {
                    "name": "search",
                    "arguments": ""
                }
            }]
        }"#;
        let delta: StreamingDelta = serde_json::from_str(json).unwrap();
        assert!(delta.content.is_none());
        assert_eq!(delta.tool_calls.len(), 1);
        assert_eq!(delta.tool_calls[0].id, Some("call_xyz".to_string()));
    }

    /// A chunk with text content and usage deserializes both parts.
    #[test]
    fn test_streaming_chunk_deserialization() {
        let json = r#"{
            "choices": [{
                "delta": {
                    "content": "Hello",
                    "tool_calls": []
                }
            }],
            "usage": {
                "prompt_tokens": 10,
                "completion_tokens": 5,
                "total_tokens": 15
            }
        }"#;
        let chunk: StreamingCompletionChunk = serde_json::from_str(json).unwrap();
        assert_eq!(chunk.choices.len(), 1);
        assert_eq!(chunk.choices[0].delta.content, Some("Hello".to_string()));
        assert!(chunk.usage.is_some());
    }

    /// Each of three consecutive tool-call chunks deserializes on its own.
    #[test]
    fn test_streaming_chunk_with_multiple_tool_call_deltas() {
        // Simulates multiple partial tool call chunks arriving
        let json_start = r#"{
            "choices": [{
                "delta": {
                    "content": null,
                    "tool_calls": [{
                        "index": 0,
                        "id": "call_123",
                        "function": {
                            "name": "get_weather",
                            "arguments": ""
                        }
                    }]
                }
            }],
            "usage": null
        }"#;

        let json_chunk1 = r#"{
            "choices": [{
                "delta": {
                    "content": null,
                    "tool_calls": [{
                        "index": 0,
                        "id": null,
                        "function": {
                            "name": null,
                            "arguments": "{\"loc"
                        }
                    }]
                }
            }],
            "usage": null
        }"#;

        let json_chunk2 = r#"{
            "choices": [{
                "delta": {
                    "content": null,
                    "tool_calls": [{
                        "index": 0,
                        "id": null,
                        "function": {
                            "name": null,
                            "arguments": "ation\":\"NYC\"}"
                        }
                    }]
                }
            }],
            "usage": null
        }"#;

        // Verify each chunk deserializes correctly
        let start_chunk: StreamingCompletionChunk = serde_json::from_str(json_start).unwrap();
        assert_eq!(start_chunk.choices[0].delta.tool_calls.len(), 1);
        assert_eq!(
            start_chunk.choices[0].delta.tool_calls[0]
                .function
                .name
                .as_ref()
                .unwrap(),
            "get_weather"
        );

        let chunk1: StreamingCompletionChunk = serde_json::from_str(json_chunk1).unwrap();
        assert_eq!(chunk1.choices[0].delta.tool_calls.len(), 1);
        assert_eq!(
            chunk1.choices[0].delta.tool_calls[0]
                .function
                .arguments
                .as_ref()
                .unwrap(),
            "{\"loc"
        );

        let chunk2: StreamingCompletionChunk = serde_json::from_str(json_chunk2).unwrap();
        assert_eq!(chunk2.choices[0].delta.tool_calls.len(), 1);
        assert_eq!(
            chunk2.choices[0].delta.tool_calls[0]
                .function
                .arguments
                .as_ref()
                .unwrap(),
            "ation\":\"NYC\"}"
        );
    }

    /// Usage reported on a chunk with an empty `choices` array must still reach the
    /// final response.
    #[tokio::test]
    async fn test_streaming_usage_only_chunk_is_not_ignored() {
        use bytes::Bytes;
        use futures::StreamExt;

        /// HTTP client stub that answers every streaming request with a fixed SSE body.
        #[derive(Clone)]
        struct MockHttpClient {
            sse_bytes: Bytes,
        }

        impl crate::http_client::HttpClientExt for MockHttpClient {
            // Non-streaming requests always fail with NOT_IMPLEMENTED; this test only
            // supplies a streaming body.
            fn send<T, U>(
                &self,
                _req: http::Request<T>,
            ) -> impl std::future::Future<
                Output = crate::http_client::Result<
                    http::Response<crate::http_client::LazyBody<U>>,
                >,
            > + crate::wasm_compat::WasmCompatSend
            + 'static
            where
                T: Into<Bytes>,
                T: crate::wasm_compat::WasmCompatSend,
                U: From<Bytes>,
                U: crate::wasm_compat::WasmCompatSend + 'static,
            {
                std::future::ready(Err(crate::http_client::Error::InvalidStatusCode(
                    http::StatusCode::NOT_IMPLEMENTED,
                )))
            }

            // Multipart requests likewise always fail with NOT_IMPLEMENTED.
            fn send_multipart<U>(
                &self,
                _req: http::Request<crate::http_client::MultipartForm>,
            ) -> impl std::future::Future<
                Output = crate::http_client::Result<
                    http::Response<crate::http_client::LazyBody<U>>,
                >,
            > + crate::wasm_compat::WasmCompatSend
            + 'static
            where
                U: From<Bytes>,
                U: crate::wasm_compat::WasmCompatSend + 'static,
            {
                std::future::ready(Err(crate::http_client::Error::InvalidStatusCode(
                    http::StatusCode::NOT_IMPLEMENTED,
                )))
            }

            // Streaming requests get a 200 `text/event-stream` response whose body is the
            // canned SSE payload delivered as a single chunk.
            fn send_streaming<T>(
                &self,
                _req: http::Request<T>,
            ) -> impl std::future::Future<
                Output = crate::http_client::Result<crate::http_client::StreamingResponse>,
            > + crate::wasm_compat::WasmCompatSend
            where
                T: Into<Bytes>,
            {
                let sse_bytes = self.sse_bytes.clone();
                async move {
                    let byte_stream = futures::stream::iter(vec![Ok::<
                        Bytes,
                        crate::http_client::Error,
                    >(sse_bytes)]);
                    let boxed_stream: crate::http_client::sse::BoxedStream = Box::pin(byte_stream);

                    http::Response::builder()
                        .status(http::StatusCode::OK)
                        // Use the `http` crate's constant so the mock does not depend on
                        // `reqwest` just for a header name.
                        .header(http::header::CONTENT_TYPE, "text/event-stream")
                        .body(boxed_stream)
                        .map_err(crate::http_client::Error::Protocol)
                }
            }
        }

        // Some providers emit a final "usage-only" chunk where `choices` is empty.
        let sse = concat!(
            "data: {\"choices\":[{\"delta\":{\"content\":\"Hello\",\"tool_calls\":[]}}],\"usage\":null}\n\n",
            "data: {\"choices\":[],\"usage\":{\"prompt_tokens\":10,\"completion_tokens\":5,\"total_tokens\":15}}\n\n",
            "data: [DONE]\n\n",
        );

        let client = MockHttpClient {
            sse_bytes: Bytes::from(sse),
        };

        let req = http::Request::builder()
            .method("POST")
            .uri("http://localhost/v1/chat/completions")
            .body(Vec::new())
            .unwrap();

        let mut stream = send_compatible_streaming_request(client, req)
            .await
            .unwrap();

        // Drain the stream until the final response shows up and capture its usage.
        let mut final_usage = None;
        while let Some(chunk) = stream.next().await {
            if let streaming::StreamedAssistantContent::Final(res) = chunk.unwrap() {
                final_usage = Some(res.usage);
                break;
            }
        }

        let usage = final_usage.expect("expected a final response with usage");
        assert_eq!(usage.prompt_tokens, 10);
        assert_eq!(usage.total_tokens, 15);
    }
}