rig/providers/openai/completion/streaming.rs

1use std::collections::HashMap;
2
3use async_stream::stream;
4use futures::StreamExt;
5use http::Request;
6use serde::{Deserialize, Serialize};
7use serde_json::json;
8use tracing::{Level, enabled, info_span};
9use tracing_futures::Instrument;
10
11use crate::completion::{CompletionError, CompletionRequest, GetTokenUsage};
12use crate::http_client::HttpClientExt;
13use crate::http_client::sse::{Event, GenericEventSource};
14use crate::json_utils::{self, merge};
15use crate::message::{ToolCall, ToolFunction};
16use crate::providers::openai::completion::{self, CompletionModel, Usage};
17use crate::streaming::{self, RawStreamingChoice};
18
19// ================================================================
20// OpenAI Completion Streaming API
21// ================================================================
22#[derive(Deserialize, Debug)]
23pub(crate) struct StreamingFunction {
24    pub(crate) name: Option<String>,
25    pub(crate) arguments: Option<String>,
26}
27
28#[derive(Deserialize, Debug)]
29pub(crate) struct StreamingToolCall {
30    pub(crate) index: usize,
31    pub(crate) id: Option<String>,
32    pub(crate) function: StreamingFunction,
33}
34
35#[derive(Deserialize, Debug)]
36struct StreamingDelta {
37    #[serde(default)]
38    content: Option<String>,
39    #[serde(default, deserialize_with = "json_utils::null_or_vec")]
40    tool_calls: Vec<StreamingToolCall>,
41}
42
43#[derive(Deserialize, Debug, PartialEq)]
44#[serde(rename_all = "snake_case")]
45pub enum FinishReason {
46    ToolCalls,
47    Stop,
48    ContentFilter,
49    Length,
50    #[serde(untagged)]
51    Other(String), // This will handle the deprecated function_call
52}
53
54#[derive(Deserialize, Debug)]
55struct StreamingChoice {
56    delta: StreamingDelta,
57    finish_reason: Option<FinishReason>,
58}
59
60#[derive(Deserialize, Debug)]
61struct StreamingCompletionChunk {
62    choices: Vec<StreamingChoice>,
63    usage: Option<Usage>,
64}
65
66#[derive(Clone, Serialize, Deserialize)]
67pub struct StreamingCompletionResponse {
68    pub usage: Usage,
69}
70
71impl GetTokenUsage for StreamingCompletionResponse {
72    fn token_usage(&self) -> Option<crate::completion::Usage> {
73        let mut usage = crate::completion::Usage::new();
74        usage.input_tokens = self.usage.prompt_tokens as u64;
75        usage.output_tokens = self.usage.total_tokens as u64 - self.usage.prompt_tokens as u64;
76        usage.total_tokens = self.usage.total_tokens as u64;
77        Some(usage)
78    }
79}
80
81impl<T> CompletionModel<T>
82where
83    T: HttpClientExt + Clone + 'static,
84{
85    pub(crate) async fn stream(
86        &self,
87        completion_request: CompletionRequest,
88    ) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError>
89    {
90        let request = super::CompletionRequest::try_from((self.model.clone(), completion_request))?;
91        let request_messages = serde_json::to_string(&request.messages)
92            .expect("Converting to JSON from a Rust struct shouldn't fail");
93        let mut request_as_json = serde_json::to_value(request).expect("this should never fail");
94
95        request_as_json = merge(
96            request_as_json,
97            json!({"stream": true, "stream_options": {"include_usage": true}}),
98        );
99
100        if enabled!(Level::TRACE) {
101            tracing::trace!(
102                target: "rig::completions",
103                "OpenAI Chat Completions streaming completion request: {}",
104                serde_json::to_string_pretty(&request_as_json)?
105            );
106        }
107
108        let req_body = serde_json::to_vec(&request_as_json)?;
109
110        let req = self
111            .client
112            .post("/chat/completions")?
113            .body(req_body)
114            .map_err(|e| CompletionError::HttpError(e.into()))?;
115
116        let span = if tracing::Span::current().is_disabled() {
117            info_span!(
118                target: "rig::completions",
119                "chat",
120                gen_ai.operation.name = "chat",
121                gen_ai.provider.name = "openai",
122                gen_ai.request.model = self.model,
123                gen_ai.response.id = tracing::field::Empty,
124                gen_ai.response.model = self.model,
125                gen_ai.usage.output_tokens = tracing::field::Empty,
126                gen_ai.usage.input_tokens = tracing::field::Empty,
127                gen_ai.input.messages = request_messages,
128                gen_ai.output.messages = tracing::field::Empty,
129            )
130        } else {
131            tracing::Span::current()
132        };
133
134        let client = self.client.clone();
135
136        tracing::Instrument::instrument(send_compatible_streaming_request(client, req), span).await
137    }
138}
139
140pub async fn send_compatible_streaming_request<T>(
141    http_client: T,
142    req: Request<Vec<u8>>,
143) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError>
144where
145    T: HttpClientExt + Clone + 'static,
146{
147    let span = tracing::Span::current();
148    // Build the request with proper headers for SSE
149    let mut event_source = GenericEventSource::new(http_client, req);
150
151    let stream = stream! {
152        let span = tracing::Span::current();
153
154        // Accumulate tool calls by index while streaming
155        let mut tool_calls: HashMap<usize, ToolCall> = HashMap::new();
156        let mut text_content = String::new();
157        let mut final_tool_calls: Vec<completion::ToolCall> = Vec::new();
158        let mut final_usage = None;
159
160        while let Some(event_result) = event_source.next().await {
161            match event_result {
162                Ok(Event::Open) => {
163                    tracing::trace!("SSE connection opened");
164                    continue;
165                }
166
167                Ok(Event::Message(message)) => {
168                    if message.data.trim().is_empty() || message.data == "[DONE]" {
169                        continue;
170                    }
171
172                    let data = match serde_json::from_str::<StreamingCompletionChunk>(&message.data) {
173                        Ok(data) => data,
174                        Err(error) => {
175                            tracing::error!(?error, message = message.data, "Failed to parse SSE message");
176                            continue;
177                        }
178                    };
179
180                    // Expect at least one choice
181                     let Some(choice) = data.choices.first() else {
182                        tracing::debug!("There is no choice");
183                        continue;
184                    };
185                    let delta = &choice.delta;
186
187                    if !delta.tool_calls.is_empty() {
188                        for tool_call in &delta.tool_calls {
189                            let index = tool_call.index;
190
191                            // Get or create tool call entry
192                            let existing_tool_call = tool_calls.entry(index).or_insert_with(|| ToolCall {
193                                id: String::new(),
194                                call_id: None,
195                                function: ToolFunction {
196                                    name: String::new(),
197                                    arguments: serde_json::Value::Null,
198                                },
199                            });
200
201                            // Update fields if present
202                            if let Some(id) = &tool_call.id && !id.is_empty() {
203                                    existing_tool_call.id = id.clone();
204                            }
205
206                            if let Some(name) = &tool_call.function.name && !name.is_empty() {
207                                    existing_tool_call.function.name = name.clone();
208                            }
209
210                            if let Some(chunk) = &tool_call.function.arguments {
211                                // Convert current arguments to string if needed
212                                let current_args = match &existing_tool_call.function.arguments {
213                                    serde_json::Value::Null => String::new(),
214                                    serde_json::Value::String(s) => s.clone(),
215                                    v => v.to_string(),
216                                };
217
218                                // Concatenate the new chunk
219                                let combined = format!("{current_args}{chunk}");
220
221                                // Try to parse as JSON if it looks complete
222                                if combined.trim_start().starts_with('{') && combined.trim_end().ends_with('}') {
223                                    match serde_json::from_str(&combined) {
224                                        Ok(parsed) => existing_tool_call.function.arguments = parsed,
225                                        Err(_) => existing_tool_call.function.arguments = serde_json::Value::String(combined),
226                                    }
227                                } else {
228                                    existing_tool_call.function.arguments = serde_json::Value::String(combined);
229                                }
230
231                                // Emit the delta so UI can show progress
232                                yield Ok(streaming::RawStreamingChoice::ToolCallDelta {
233                                    id: existing_tool_call.id.clone(),
234                                    delta: chunk.clone(),
235                                });
236                            }
237                        }
238                    }
239
240                    // Streamed text content
241                    if let Some(content) = &delta.content && !content.is_empty() {
242                        text_content += content;
243                        yield Ok(streaming::RawStreamingChoice::Message(content.clone()));
244                    }
245
246                    // Usage updates
247                    if let Some(usage) = data.usage {
248                        final_usage = Some(usage);
249                    }
250
251                    // Finish reason
252                    if let Some(finish_reason) = &choice.finish_reason && *finish_reason == FinishReason::ToolCalls {
253                        for (_idx, tool_call) in tool_calls.into_iter() {
254                            final_tool_calls.push(completion::ToolCall {
255                                id: tool_call.id.clone(),
256                                r#type: completion::ToolType::Function,
257                                function: completion::Function {
258                                    name: tool_call.function.name.clone(),
259                                    arguments: tool_call.function.arguments.clone(),
260                                },
261                            });
262                            yield Ok(streaming::RawStreamingChoice::ToolCall {
263                                name: tool_call.function.name,
264                                id: tool_call.id,
265                                arguments: tool_call.function.arguments,
266                                call_id: None,
267                            });
268                        }
269                        tool_calls = HashMap::new();
270                    }
271                }
272                Err(crate::http_client::Error::StreamEnded) => {
273                    break;
274                }
275                Err(error) => {
276                    tracing::error!(?error, "SSE error");
277                    yield Err(CompletionError::ProviderError(error.to_string()));
278                    break;
279                }
280            }
281        }
282
283
284        // Ensure event source is closed when stream ends
285        event_source.close();
286
287        // Flush any accumulated tool calls (that weren't emitted as ToolCall earlier)
288        for (_idx, tool_call) in tool_calls.into_iter() {
289            yield Ok(streaming::RawStreamingChoice::ToolCall {
290                name: tool_call.function.name,
291                id: tool_call.id,
292                arguments: tool_call.function.arguments,
293                call_id: None,
294            });
295        }
296
297        let final_usage = final_usage.unwrap_or_default();
298        if !span.is_disabled() {
299            span.record("gen_ai.usage.input_tokens", final_usage.prompt_tokens);
300            span.record("gen_ai.usage.output_tokens", final_usage.total_tokens - final_usage.prompt_tokens);
301        }
302
303        yield Ok(RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
304            usage: final_usage
305        }));
306    }.instrument(span);
307
308    Ok(streaming::StreamingCompletionResponse::stream(Box::pin(
309        stream,
310    )))
311}
312
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_streaming_function_deserialization() {
        let json = r#"{"name": "get_weather", "arguments": "{\"location\":\"Paris\"}"}"#;
        let function: StreamingFunction = serde_json::from_str(json).unwrap();
        assert_eq!(function.name, Some("get_weather".to_string()));
        assert_eq!(
            function.arguments.as_ref().unwrap(),
            r#"{"location":"Paris"}"#
        );
    }

    #[test]
    fn test_streaming_tool_call_deserialization() {
        let json = r#"{
            "index": 0,
            "id": "call_abc123",
            "function": {
                "name": "get_weather",
                "arguments": "{\"city\":\"London\"}"
            }
        }"#;
        let tool_call: StreamingToolCall = serde_json::from_str(json).unwrap();
        assert_eq!(tool_call.index, 0);
        assert_eq!(tool_call.id, Some("call_abc123".to_string()));
        assert_eq!(tool_call.function.name, Some("get_weather".to_string()));
    }

    #[test]
    fn test_streaming_tool_call_partial_deserialization() {
        // Partial tool calls have no name and partial arguments
        let json = r#"{
            "index": 0,
            "id": null,
            "function": {
                "name": null,
                "arguments": "Paris"
            }
        }"#;
        let tool_call: StreamingToolCall = serde_json::from_str(json).unwrap();
        assert_eq!(tool_call.index, 0);
        assert!(tool_call.id.is_none());
        assert!(tool_call.function.name.is_none());
        assert_eq!(tool_call.function.arguments.as_ref().unwrap(), "Paris");
    }

    #[test]
    fn test_streaming_delta_with_tool_calls() {
        let json = r#"{
            "content": null,
            "tool_calls": [{
                "index": 0,
                "id": "call_xyz",
                "function": {
                    "name": "search",
                    "arguments": ""
                }
            }]
        }"#;
        let delta: StreamingDelta = serde_json::from_str(json).unwrap();
        assert!(delta.content.is_none());
        assert_eq!(delta.tool_calls.len(), 1);
        assert_eq!(delta.tool_calls[0].id, Some("call_xyz".to_string()));
    }

    #[test]
    fn test_streaming_delta_with_null_tool_calls() {
        // `tool_calls: null` must not fail deserialization.
        let json = r#"{"content": "Hi", "tool_calls": null}"#;
        let delta: StreamingDelta = serde_json::from_str(json).unwrap();
        assert_eq!(delta.content, Some("Hi".to_string()));
        assert!(delta.tool_calls.is_empty());
    }

    #[test]
    fn test_streaming_chunk_deserialization() {
        let json = r#"{
            "choices": [{
                "delta": {
                    "content": "Hello",
                    "tool_calls": []
                }
            }],
            "usage": {
                "prompt_tokens": 10,
                "completion_tokens": 5,
                "total_tokens": 15
            }
        }"#;
        let chunk: StreamingCompletionChunk = serde_json::from_str(json).unwrap();
        assert_eq!(chunk.choices.len(), 1);
        assert_eq!(chunk.choices[0].delta.content, Some("Hello".to_string()));
        assert!(chunk.usage.is_some());
    }

    #[test]
    fn test_usage_only_chunk_with_empty_choices() {
        // With `stream_options.include_usage`, the chunk reporting usage carries an
        // empty `choices` array; it must still deserialize with its usage intact.
        let json = r#"{
            "choices": [],
            "usage": {
                "prompt_tokens": 12,
                "completion_tokens": 8,
                "total_tokens": 20
            }
        }"#;
        let chunk: StreamingCompletionChunk = serde_json::from_str(json).unwrap();
        assert!(chunk.choices.is_empty());
        let usage = chunk.usage.expect("usage should be present");
        assert_eq!(usage.prompt_tokens, 12);
        assert_eq!(usage.total_tokens, 20);
    }

    #[test]
    fn test_finish_chunk_with_empty_delta() {
        let json = r#"{"choices": [{"delta": {}, "finish_reason": "tool_calls"}]}"#;
        let chunk: StreamingCompletionChunk = serde_json::from_str(json).unwrap();
        assert!(chunk.usage.is_none());
        let choice = &chunk.choices[0];
        assert!(choice.delta.content.is_none());
        assert!(choice.delta.tool_calls.is_empty());
        assert_eq!(choice.finish_reason, Some(FinishReason::ToolCalls));
    }

    #[test]
    fn test_finish_reason_deserialization() {
        let parse = |raw: &str| serde_json::from_str::<FinishReason>(raw).unwrap();
        assert_eq!(parse(r#""tool_calls""#), FinishReason::ToolCalls);
        assert_eq!(parse(r#""stop""#), FinishReason::Stop);
        assert_eq!(parse(r#""content_filter""#), FinishReason::ContentFilter);
        assert_eq!(parse(r#""length""#), FinishReason::Length);
        // Unknown values (e.g. the deprecated `function_call`) fall back to `Other`.
        assert_eq!(
            parse(r#""function_call""#),
            FinishReason::Other("function_call".to_string())
        );
    }

    #[test]
    fn test_streaming_chunk_with_multiple_tool_call_deltas() {
        // Simulates multiple partial tool call chunks arriving
        let json_start = r#"{
            "choices": [{
                "delta": {
                    "content": null,
                    "tool_calls": [{
                        "index": 0,
                        "id": "call_123",
                        "function": {
                            "name": "get_weather",
                            "arguments": ""
                        }
                    }]
                }
            }],
            "usage": null
        }"#;

        let json_chunk1 = r#"{
            "choices": [{
                "delta": {
                    "content": null,
                    "tool_calls": [{
                        "index": 0,
                        "id": null,
                        "function": {
                            "name": null,
                            "arguments": "{\"loc"
                        }
                    }]
                }
            }],
            "usage": null
        }"#;

        let json_chunk2 = r#"{
            "choices": [{
                "delta": {
                    "content": null,
                    "tool_calls": [{
                        "index": 0,
                        "id": null,
                        "function": {
                            "name": null,
                            "arguments": "ation\":\"NYC\"}"
                        }
                    }]
                }
            }],
            "usage": null
        }"#;

        // Verify each chunk deserializes correctly
        let start_chunk: StreamingCompletionChunk = serde_json::from_str(json_start).unwrap();
        assert_eq!(start_chunk.choices[0].delta.tool_calls.len(), 1);
        assert_eq!(
            start_chunk.choices[0].delta.tool_calls[0]
                .function
                .name
                .as_ref()
                .unwrap(),
            "get_weather"
        );

        let chunk1: StreamingCompletionChunk = serde_json::from_str(json_chunk1).unwrap();
        assert_eq!(chunk1.choices[0].delta.tool_calls.len(), 1);
        assert_eq!(
            chunk1.choices[0].delta.tool_calls[0]
                .function
                .arguments
                .as_ref()
                .unwrap(),
            "{\"loc"
        );

        let chunk2: StreamingCompletionChunk = serde_json::from_str(json_chunk2).unwrap();
        assert_eq!(chunk2.choices[0].delta.tool_calls.len(), 1);
        assert_eq!(
            chunk2.choices[0].delta.tool_calls[0]
                .function
                .arguments
                .as_ref()
                .unwrap(),
            "ation\":\"NYC\"}"
        );
    }
}