rig/providers/openai/completion/
streaming.rs

1use std::collections::HashMap;
2
3use async_stream::stream;
4use futures::StreamExt;
5use http::Request;
6use serde::{Deserialize, Serialize};
7use serde_json::json;
8use tracing::{Level, enabled, info_span};
9use tracing_futures::Instrument;
10
11use crate::completion::{CompletionError, CompletionRequest, GetTokenUsage};
12use crate::http_client::HttpClientExt;
13use crate::http_client::sse::{Event, GenericEventSource};
14use crate::json_utils::{self, merge};
15use crate::message::{ToolCall, ToolFunction};
16use crate::providers::openai::completion::{self, CompletionModel, OpenAIRequestParams, Usage};
17use crate::streaming::{self, RawStreamingChoice};
18
19// ================================================================
20// OpenAI Completion Streaming API
21// ================================================================
22#[derive(Deserialize, Debug)]
23pub(crate) struct StreamingFunction {
24    pub(crate) name: Option<String>,
25    pub(crate) arguments: Option<String>,
26}
27
28#[derive(Deserialize, Debug)]
29pub(crate) struct StreamingToolCall {
30    pub(crate) index: usize,
31    pub(crate) id: Option<String>,
32    pub(crate) function: StreamingFunction,
33}
34
35#[derive(Deserialize, Debug)]
36struct StreamingDelta {
37    #[serde(default)]
38    content: Option<String>,
39    #[serde(default, deserialize_with = "json_utils::null_or_vec")]
40    tool_calls: Vec<StreamingToolCall>,
41}
42
43#[derive(Deserialize, Debug, PartialEq)]
44#[serde(rename_all = "snake_case")]
45pub enum FinishReason {
46    ToolCalls,
47    Stop,
48    ContentFilter,
49    Length,
50    #[serde(untagged)]
51    Other(String), // This will handle the deprecated function_call
52}
53
54#[derive(Deserialize, Debug)]
55struct StreamingChoice {
56    delta: StreamingDelta,
57    finish_reason: Option<FinishReason>,
58}
59
60#[derive(Deserialize, Debug)]
61struct StreamingCompletionChunk {
62    choices: Vec<StreamingChoice>,
63    usage: Option<Usage>,
64}
65
66#[derive(Clone, Serialize, Deserialize)]
67pub struct StreamingCompletionResponse {
68    pub usage: Usage,
69}
70
71impl GetTokenUsage for StreamingCompletionResponse {
72    fn token_usage(&self) -> Option<crate::completion::Usage> {
73        let mut usage = crate::completion::Usage::new();
74        usage.input_tokens = self.usage.prompt_tokens as u64;
75        usage.output_tokens = self.usage.total_tokens as u64 - self.usage.prompt_tokens as u64;
76        usage.total_tokens = self.usage.total_tokens as u64;
77        Some(usage)
78    }
79}
80
81impl<T> CompletionModel<T>
82where
83    T: HttpClientExt + Clone + 'static,
84{
85    pub(crate) async fn stream(
86        &self,
87        completion_request: CompletionRequest,
88    ) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError>
89    {
90        let request = super::CompletionRequest::try_from(OpenAIRequestParams {
91            model: self.model.clone(),
92            request: completion_request,
93            strict_tools: self.strict_tools,
94            tool_result_array_content: self.tool_result_array_content,
95        })?;
96        let request_messages = serde_json::to_string(&request.messages)
97            .expect("Converting to JSON from a Rust struct shouldn't fail");
98        let mut request_as_json = serde_json::to_value(request).expect("this should never fail");
99
100        request_as_json = merge(
101            request_as_json,
102            json!({"stream": true, "stream_options": {"include_usage": true}}),
103        );
104
105        if enabled!(Level::TRACE) {
106            tracing::trace!(
107                target: "rig::completions",
108                "OpenAI Chat Completions streaming completion request: {}",
109                serde_json::to_string_pretty(&request_as_json)?
110            );
111        }
112
113        let req_body = serde_json::to_vec(&request_as_json)?;
114
115        let req = self
116            .client
117            .post("/chat/completions")?
118            .body(req_body)
119            .map_err(|e| CompletionError::HttpError(e.into()))?;
120
121        let span = if tracing::Span::current().is_disabled() {
122            info_span!(
123                target: "rig::completions",
124                "chat",
125                gen_ai.operation.name = "chat",
126                gen_ai.provider.name = "openai",
127                gen_ai.request.model = self.model,
128                gen_ai.response.id = tracing::field::Empty,
129                gen_ai.response.model = self.model,
130                gen_ai.usage.output_tokens = tracing::field::Empty,
131                gen_ai.usage.input_tokens = tracing::field::Empty,
132                gen_ai.input.messages = request_messages,
133                gen_ai.output.messages = tracing::field::Empty,
134            )
135        } else {
136            tracing::Span::current()
137        };
138
139        let client = self.client.clone();
140
141        tracing::Instrument::instrument(send_compatible_streaming_request(client, req), span).await
142    }
143}
144
145pub async fn send_compatible_streaming_request<T>(
146    http_client: T,
147    req: Request<Vec<u8>>,
148) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError>
149where
150    T: HttpClientExt + Clone + 'static,
151{
152    let span = tracing::Span::current();
153    // Build the request with proper headers for SSE
154    let mut event_source = GenericEventSource::new(http_client, req);
155
156    let stream = stream! {
157        let span = tracing::Span::current();
158
159        // Accumulate tool calls by index while streaming
160        let mut tool_calls: HashMap<usize, ToolCall> = HashMap::new();
161        let mut text_content = String::new();
162        let mut final_tool_calls: Vec<completion::ToolCall> = Vec::new();
163        let mut final_usage = None;
164
165        while let Some(event_result) = event_source.next().await {
166            match event_result {
167                Ok(Event::Open) => {
168                    tracing::trace!("SSE connection opened");
169                    continue;
170                }
171
172                Ok(Event::Message(message)) => {
173                    if message.data.trim().is_empty() || message.data == "[DONE]" {
174                        continue;
175                    }
176
177                    let data = match serde_json::from_str::<StreamingCompletionChunk>(&message.data) {
178                        Ok(data) => data,
179                        Err(error) => {
180                            tracing::error!(?error, message = message.data, "Failed to parse SSE message");
181                            continue;
182                        }
183                    };
184
185                    // Usage updates (some providers send a final "usage-only" chunk with empty choices)
186                    if let Some(usage) = data.usage {
187                        final_usage = Some(usage);
188                    }
189
190                    // Expect at least one choice
191                     let Some(choice) = data.choices.first() else {
192                        tracing::debug!("There is no choice");
193                        continue;
194                    };
195                    let delta = &choice.delta;
196
197                    if !delta.tool_calls.is_empty() {
198                        for tool_call in &delta.tool_calls {
199                            let index = tool_call.index;
200
201                            // Get or create tool call entry
202                            let existing_tool_call = tool_calls.entry(index).or_insert_with(|| ToolCall {
203                                id: String::new(),
204                                call_id: None,
205                                function: ToolFunction {
206                                    name: String::new(),
207                                    arguments: serde_json::Value::Null,
208                                },
209                                signature: None,
210                                additional_params: None,
211                            });
212
213                            // Update fields if present
214                            if let Some(id) = &tool_call.id && !id.is_empty() {
215                                    existing_tool_call.id = id.clone();
216                            }
217
218                            if let Some(name) = &tool_call.function.name && !name.is_empty() {
219                                    existing_tool_call.function.name = name.clone();
220                                    yield Ok(streaming::RawStreamingChoice::ToolCallDelta {
221                                        id: existing_tool_call.id.clone(),
222                                        content: streaming::ToolCallDeltaContent::Name(name.clone()),
223                                    });
224                            }
225
226                                // Convert current arguments to string if needed
227                            if let Some(chunk) = &tool_call.function.arguments && !chunk.is_empty() {
228                                let current_args = match &existing_tool_call.function.arguments {
229                                    serde_json::Value::Null => String::new(),
230                                    serde_json::Value::String(s) => s.clone(),
231                                    v => v.to_string(),
232                                };
233
234                                // Concatenate the new chunk
235                                let combined = format!("{current_args}{chunk}");
236
237                                // Try to parse as JSON if it looks complete
238                                if combined.trim_start().starts_with('{') && combined.trim_end().ends_with('}') {
239                                    match serde_json::from_str(&combined) {
240                                        Ok(parsed) => existing_tool_call.function.arguments = parsed,
241                                        Err(_) => existing_tool_call.function.arguments = serde_json::Value::String(combined),
242                                    }
243                                } else {
244                                    existing_tool_call.function.arguments = serde_json::Value::String(combined);
245                                }
246
247                                // Emit the delta so UI can show progress
248                                yield Ok(streaming::RawStreamingChoice::ToolCallDelta {
249                                    id: existing_tool_call.id.clone(),
250                                    content: streaming::ToolCallDeltaContent::Delta(chunk.clone()),
251                                });
252                            }
253                        }
254                    }
255
256                    // Streamed text content
257                    if let Some(content) = &delta.content && !content.is_empty() {
258                        text_content += content;
259                        yield Ok(streaming::RawStreamingChoice::Message(content.clone()));
260                    }
261
262                    // Finish reason
263                    if let Some(finish_reason) = &choice.finish_reason && *finish_reason == FinishReason::ToolCalls {
264                        for (_idx, tool_call) in tool_calls.into_iter() {
265                            final_tool_calls.push(completion::ToolCall {
266                                id: tool_call.id.clone(),
267                                r#type: completion::ToolType::Function,
268                                function: completion::Function {
269                                    name: tool_call.function.name.clone(),
270                                    arguments: tool_call.function.arguments.clone(),
271                                },
272                            });
273                            yield Ok(streaming::RawStreamingChoice::ToolCall(
274                                streaming::RawStreamingToolCall::new(
275                                    tool_call.id,
276                                    tool_call.function.name,
277                                    tool_call.function.arguments,
278                                )
279                            ));
280                        }
281                        tool_calls = HashMap::new();
282                    }
283                }
284                Err(crate::http_client::Error::StreamEnded) => {
285                    break;
286                }
287                Err(error) => {
288                    tracing::error!(?error, "SSE error");
289                    yield Err(CompletionError::ProviderError(error.to_string()));
290                    break;
291                }
292            }
293        }
294
295
296        // Ensure event source is closed when stream ends
297        event_source.close();
298
299        // Flush any accumulated tool calls (that weren't emitted as ToolCall earlier)
300        for (_idx, tool_call) in tool_calls.into_iter() {
301            yield Ok(streaming::RawStreamingChoice::ToolCall(
302                streaming::RawStreamingToolCall::new(
303                    tool_call.id,
304                    tool_call.function.name,
305                    tool_call.function.arguments,
306                )
307            ));
308        }
309
310        let final_usage = final_usage.unwrap_or_default();
311        if !span.is_disabled() {
312            span.record("gen_ai.usage.input_tokens", final_usage.prompt_tokens);
313            span.record("gen_ai.usage.output_tokens", final_usage.total_tokens - final_usage.prompt_tokens);
314        }
315
316        yield Ok(RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
317            usage: final_usage
318        }));
319    }.instrument(span);
320
321    Ok(streaming::StreamingCompletionResponse::stream(Box::pin(
322        stream,
323    )))
324}
325
326#[cfg(test)]
327mod tests {
328    use super::*;
329
330    #[test]
331    fn test_streaming_function_deserialization() {
332        let json = r#"{"name": "get_weather", "arguments": "{\"location\":\"Paris\"}"}"#;
333        let function: StreamingFunction = serde_json::from_str(json).unwrap();
334        assert_eq!(function.name, Some("get_weather".to_string()));
335        assert_eq!(
336            function.arguments.as_ref().unwrap(),
337            r#"{"location":"Paris"}"#
338        );
339    }
340
341    #[test]
342    fn test_streaming_tool_call_deserialization() {
343        let json = r#"{
344            "index": 0,
345            "id": "call_abc123",
346            "function": {
347                "name": "get_weather",
348                "arguments": "{\"city\":\"London\"}"
349            }
350        }"#;
351        let tool_call: StreamingToolCall = serde_json::from_str(json).unwrap();
352        assert_eq!(tool_call.index, 0);
353        assert_eq!(tool_call.id, Some("call_abc123".to_string()));
354        assert_eq!(tool_call.function.name, Some("get_weather".to_string()));
355    }
356
357    #[test]
358    fn test_streaming_tool_call_partial_deserialization() {
359        // Partial tool calls have no name and partial arguments
360        let json = r#"{
361            "index": 0,
362            "id": null,
363            "function": {
364                "name": null,
365                "arguments": "Paris"
366            }
367        }"#;
368        let tool_call: StreamingToolCall = serde_json::from_str(json).unwrap();
369        assert_eq!(tool_call.index, 0);
370        assert!(tool_call.id.is_none());
371        assert!(tool_call.function.name.is_none());
372        assert_eq!(tool_call.function.arguments.as_ref().unwrap(), "Paris");
373    }
374
375    #[test]
376    fn test_streaming_delta_with_tool_calls() {
377        let json = r#"{
378            "content": null,
379            "tool_calls": [{
380                "index": 0,
381                "id": "call_xyz",
382                "function": {
383                    "name": "search",
384                    "arguments": ""
385                }
386            }]
387        }"#;
388        let delta: StreamingDelta = serde_json::from_str(json).unwrap();
389        assert!(delta.content.is_none());
390        assert_eq!(delta.tool_calls.len(), 1);
391        assert_eq!(delta.tool_calls[0].id, Some("call_xyz".to_string()));
392    }
393
394    #[test]
395    fn test_streaming_chunk_deserialization() {
396        let json = r#"{
397            "choices": [{
398                "delta": {
399                    "content": "Hello",
400                    "tool_calls": []
401                }
402            }],
403            "usage": {
404                "prompt_tokens": 10,
405                "completion_tokens": 5,
406                "total_tokens": 15
407            }
408        }"#;
409        let chunk: StreamingCompletionChunk = serde_json::from_str(json).unwrap();
410        assert_eq!(chunk.choices.len(), 1);
411        assert_eq!(chunk.choices[0].delta.content, Some("Hello".to_string()));
412        assert!(chunk.usage.is_some());
413    }
414
415    #[test]
416    fn test_streaming_chunk_with_multiple_tool_call_deltas() {
417        // Simulates multiple partial tool call chunks arriving
418        let json_start = r#"{
419            "choices": [{
420                "delta": {
421                    "content": null,
422                    "tool_calls": [{
423                        "index": 0,
424                        "id": "call_123",
425                        "function": {
426                            "name": "get_weather",
427                            "arguments": ""
428                        }
429                    }]
430                }
431            }],
432            "usage": null
433        }"#;
434
435        let json_chunk1 = r#"{
436            "choices": [{
437                "delta": {
438                    "content": null,
439                    "tool_calls": [{
440                        "index": 0,
441                        "id": null,
442                        "function": {
443                            "name": null,
444                            "arguments": "{\"loc"
445                        }
446                    }]
447                }
448            }],
449            "usage": null
450        }"#;
451
452        let json_chunk2 = r#"{
453            "choices": [{
454                "delta": {
455                    "content": null,
456                    "tool_calls": [{
457                        "index": 0,
458                        "id": null,
459                        "function": {
460                            "name": null,
461                            "arguments": "ation\":\"NYC\"}"
462                        }
463                    }]
464                }
465            }],
466            "usage": null
467        }"#;
468
469        // Verify each chunk deserializes correctly
470        let start_chunk: StreamingCompletionChunk = serde_json::from_str(json_start).unwrap();
471        assert_eq!(start_chunk.choices[0].delta.tool_calls.len(), 1);
472        assert_eq!(
473            start_chunk.choices[0].delta.tool_calls[0]
474                .function
475                .name
476                .as_ref()
477                .unwrap(),
478            "get_weather"
479        );
480
481        let chunk1: StreamingCompletionChunk = serde_json::from_str(json_chunk1).unwrap();
482        assert_eq!(chunk1.choices[0].delta.tool_calls.len(), 1);
483        assert_eq!(
484            chunk1.choices[0].delta.tool_calls[0]
485                .function
486                .arguments
487                .as_ref()
488                .unwrap(),
489            "{\"loc"
490        );
491
492        let chunk2: StreamingCompletionChunk = serde_json::from_str(json_chunk2).unwrap();
493        assert_eq!(chunk2.choices[0].delta.tool_calls.len(), 1);
494        assert_eq!(
495            chunk2.choices[0].delta.tool_calls[0]
496                .function
497                .arguments
498                .as_ref()
499                .unwrap(),
500            "ation\":\"NYC\"}"
501        );
502    }
503
504    #[tokio::test]
505    async fn test_streaming_usage_only_chunk_is_not_ignored() {
506        use bytes::Bytes;
507        use futures::StreamExt;
508
509        #[derive(Clone)]
510        struct MockHttpClient {
511            sse_bytes: Bytes,
512        }
513
514        impl crate::http_client::HttpClientExt for MockHttpClient {
515            fn send<T, U>(
516                &self,
517                _req: http::Request<T>,
518            ) -> impl std::future::Future<
519                Output = crate::http_client::Result<
520                    http::Response<crate::http_client::LazyBody<U>>,
521                >,
522            > + crate::wasm_compat::WasmCompatSend
523            + 'static
524            where
525                T: Into<Bytes>,
526                T: crate::wasm_compat::WasmCompatSend,
527                U: From<Bytes>,
528                U: crate::wasm_compat::WasmCompatSend + 'static,
529            {
530                std::future::ready(Err(crate::http_client::Error::InvalidStatusCode(
531                    http::StatusCode::NOT_IMPLEMENTED,
532                )))
533            }
534
535            fn send_multipart<U>(
536                &self,
537                _req: http::Request<crate::http_client::MultipartForm>,
538            ) -> impl std::future::Future<
539                Output = crate::http_client::Result<
540                    http::Response<crate::http_client::LazyBody<U>>,
541                >,
542            > + crate::wasm_compat::WasmCompatSend
543            + 'static
544            where
545                U: From<Bytes>,
546                U: crate::wasm_compat::WasmCompatSend + 'static,
547            {
548                std::future::ready(Err(crate::http_client::Error::InvalidStatusCode(
549                    http::StatusCode::NOT_IMPLEMENTED,
550                )))
551            }
552
553            fn send_streaming<T>(
554                &self,
555                _req: http::Request<T>,
556            ) -> impl std::future::Future<
557                Output = crate::http_client::Result<crate::http_client::StreamingResponse>,
558            > + crate::wasm_compat::WasmCompatSend
559            where
560                T: Into<Bytes>,
561            {
562                let sse_bytes = self.sse_bytes.clone();
563                async move {
564                    let byte_stream = futures::stream::iter(vec![Ok::<
565                        Bytes,
566                        crate::http_client::Error,
567                    >(sse_bytes)]);
568                    let boxed_stream: crate::http_client::sse::BoxedStream = Box::pin(byte_stream);
569
570                    http::Response::builder()
571                        .status(http::StatusCode::OK)
572                        .header(reqwest::header::CONTENT_TYPE, "text/event-stream")
573                        .body(boxed_stream)
574                        .map_err(crate::http_client::Error::Protocol)
575                }
576            }
577        }
578
579        // Some providers emit a final "usage-only" chunk where `choices` is empty.
580        let sse = concat!(
581            "data: {\"choices\":[{\"delta\":{\"content\":\"Hello\",\"tool_calls\":[]}}],\"usage\":null}\n\n",
582            "data: {\"choices\":[],\"usage\":{\"prompt_tokens\":10,\"completion_tokens\":5,\"total_tokens\":15}}\n\n",
583            "data: [DONE]\n\n",
584        );
585
586        let client = MockHttpClient {
587            sse_bytes: Bytes::from(sse),
588        };
589
590        let req = http::Request::builder()
591            .method("POST")
592            .uri("http://localhost/v1/chat/completions")
593            .body(Vec::new())
594            .unwrap();
595
596        let mut stream = send_compatible_streaming_request(client, req)
597            .await
598            .unwrap();
599
600        let mut final_usage = None;
601        while let Some(chunk) = stream.next().await {
602            if let streaming::StreamedAssistantContent::Final(res) = chunk.unwrap() {
603                final_usage = Some(res.usage);
604                break;
605            }
606        }
607
608        let usage = final_usage.expect("expected a final response with usage");
609        assert_eq!(usage.prompt_tokens, 10);
610        assert_eq!(usage.total_tokens, 15);
611    }
612}