rig/providers/openai/completion/
streaming.rs

1use std::collections::HashMap;
2
3use async_stream::stream;
4use futures::StreamExt;
5use http::Request;
6use serde::{Deserialize, Serialize};
7use serde_json::json;
8use tracing::info_span;
9use tracing_futures::Instrument;
10
11use crate::completion::{CompletionError, CompletionRequest, GetTokenUsage};
12use crate::http_client::HttpClientExt;
13use crate::http_client::sse::{Event, GenericEventSource};
14use crate::json_utils::{self, merge};
15use crate::message::{ToolCall, ToolFunction};
16use crate::providers::openai::completion::{self, CompletionModel, Usage};
17use crate::streaming::{self, RawStreamingChoice};
18
19// ================================================================
20// OpenAI Completion Streaming API
21// ================================================================
22#[derive(Deserialize, Debug)]
23pub(crate) struct StreamingFunction {
24    pub(crate) name: Option<String>,
25    pub(crate) arguments: Option<String>,
26}
27
28#[derive(Deserialize, Debug)]
29pub(crate) struct StreamingToolCall {
30    pub(crate) index: usize,
31    pub(crate) id: Option<String>,
32    pub(crate) function: StreamingFunction,
33}
34
35#[derive(Deserialize, Debug)]
36struct StreamingDelta {
37    #[serde(default)]
38    content: Option<String>,
39    #[serde(default, deserialize_with = "json_utils::null_or_vec")]
40    tool_calls: Vec<StreamingToolCall>,
41}
42
43#[derive(Deserialize, Debug, PartialEq)]
44#[serde(rename_all = "snake_case")]
45pub enum FinishReason {
46    ToolCalls,
47    Stop,
48    ContentFilter,
49    Length,
50    #[serde(untagged)]
51    Other(String), // This will handle the deprecated function_call
52}
53
54#[derive(Deserialize, Debug)]
55struct StreamingChoice {
56    delta: StreamingDelta,
57    finish_reason: Option<FinishReason>,
58}
59
60#[derive(Deserialize, Debug)]
61struct StreamingCompletionChunk {
62    choices: Vec<StreamingChoice>,
63    usage: Option<Usage>,
64}
65
66#[derive(Clone, Serialize, Deserialize)]
67pub struct StreamingCompletionResponse {
68    pub usage: Usage,
69}
70
71impl GetTokenUsage for StreamingCompletionResponse {
72    fn token_usage(&self) -> Option<crate::completion::Usage> {
73        let mut usage = crate::completion::Usage::new();
74        usage.input_tokens = self.usage.prompt_tokens as u64;
75        usage.output_tokens = self.usage.total_tokens as u64 - self.usage.prompt_tokens as u64;
76        usage.total_tokens = self.usage.total_tokens as u64;
77        Some(usage)
78    }
79}
80
81impl<T> CompletionModel<T>
82where
83    T: HttpClientExt + Clone + 'static,
84{
85    pub(crate) async fn stream(
86        &self,
87        completion_request: CompletionRequest,
88    ) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError>
89    {
90        let request = super::CompletionRequest::try_from((self.model.clone(), completion_request))?;
91        let request_messages = serde_json::to_string(&request.messages)
92            .expect("Converting to JSON from a Rust struct shouldn't fail");
93        let mut request_as_json = serde_json::to_value(request).expect("this should never fail");
94
95        request_as_json = merge(
96            request_as_json,
97            json!({"stream": true, "stream_options": {"include_usage": true}}),
98        );
99
100        let req_body = serde_json::to_vec(&request_as_json)?;
101
102        let req = self
103            .client
104            .post("/chat/completions")?
105            .body(req_body)
106            .map_err(|e| CompletionError::HttpError(e.into()))?;
107
108        let span = if tracing::Span::current().is_disabled() {
109            info_span!(
110                target: "rig::completions",
111                "chat",
112                gen_ai.operation.name = "chat",
113                gen_ai.provider.name = "openai",
114                gen_ai.request.model = self.model,
115                gen_ai.response.id = tracing::field::Empty,
116                gen_ai.response.model = self.model,
117                gen_ai.usage.output_tokens = tracing::field::Empty,
118                gen_ai.usage.input_tokens = tracing::field::Empty,
119                gen_ai.input.messages = request_messages,
120                gen_ai.output.messages = tracing::field::Empty,
121            )
122        } else {
123            tracing::Span::current()
124        };
125
126        let client = self.client.http_client().clone();
127
128        tracing::Instrument::instrument(send_compatible_streaming_request(client, req), span).await
129    }
130}
131
132pub async fn send_compatible_streaming_request<T>(
133    http_client: T,
134    req: Request<Vec<u8>>,
135) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError>
136where
137    T: HttpClientExt + Clone + 'static,
138{
139    let span = tracing::Span::current();
140    // Build the request with proper headers for SSE
141    let mut event_source = GenericEventSource::new(http_client, req);
142
143    let stream = stream! {
144        let span = tracing::Span::current();
145
146        // Accumulate tool calls by index while streaming
147        let mut tool_calls: HashMap<usize, ToolCall> = HashMap::new();
148        let mut text_content = String::new();
149        let mut final_tool_calls: Vec<completion::ToolCall> = Vec::new();
150        let mut final_usage = None;
151
152        while let Some(event_result) = event_source.next().await {
153            match event_result {
154                Ok(Event::Open) => {
155                    tracing::trace!("SSE connection opened");
156                    continue;
157                }
158
159                Ok(Event::Message(message)) => {
160                    if message.data.trim().is_empty() || message.data == "[DONE]" {
161                        continue;
162                    }
163
164                    let data = match serde_json::from_str::<StreamingCompletionChunk>(&message.data) {
165                        Ok(data) => data,
166                        Err(error) => {
167                            tracing::error!(?error, message = message.data, "Failed to parse SSE message");
168                            continue;
169                        }
170                    };
171
172                    // Expect at least one choice
173                     let Some(choice) = data.choices.first() else {
174                        tracing::debug!("There is no choice");
175                        continue;
176                    };
177                    let delta = &choice.delta;
178
179                    if !delta.tool_calls.is_empty() {
180                        for tool_call in &delta.tool_calls {
181                            let index = tool_call.index;
182
183                            // Get or create tool call entry
184                            let existing_tool_call = tool_calls.entry(index).or_insert_with(|| ToolCall {
185                                id: String::new(),
186                                call_id: None,
187                                function: ToolFunction {
188                                    name: String::new(),
189                                    arguments: serde_json::Value::Null,
190                                },
191                            });
192
193                            // Update fields if present
194                            if let Some(id) = &tool_call.id && !id.is_empty() {
195                                    existing_tool_call.id = id.clone();
196                            }
197
198                            if let Some(name) = &tool_call.function.name && !name.is_empty() {
199                                    existing_tool_call.function.name = name.clone();
200                            }
201
202                            if let Some(chunk) = &tool_call.function.arguments {
203                                // Convert current arguments to string if needed
204                                let current_args = match &existing_tool_call.function.arguments {
205                                    serde_json::Value::Null => String::new(),
206                                    serde_json::Value::String(s) => s.clone(),
207                                    v => v.to_string(),
208                                };
209
210                                // Concatenate the new chunk
211                                let combined = format!("{current_args}{chunk}");
212
213                                // Try to parse as JSON if it looks complete
214                                if combined.trim_start().starts_with('{') && combined.trim_end().ends_with('}') {
215                                    match serde_json::from_str(&combined) {
216                                        Ok(parsed) => existing_tool_call.function.arguments = parsed,
217                                        Err(_) => existing_tool_call.function.arguments = serde_json::Value::String(combined),
218                                    }
219                                } else {
220                                    existing_tool_call.function.arguments = serde_json::Value::String(combined);
221                                }
222
223                                // Emit the delta so UI can show progress
224                                yield Ok(streaming::RawStreamingChoice::ToolCallDelta {
225                                    id: existing_tool_call.id.clone(),
226                                    delta: chunk.clone(),
227                                });
228                            }
229                        }
230                    }
231
232                    // Streamed text content
233                    if let Some(content) = &delta.content && !content.is_empty() {
234                        text_content += content;
235                        yield Ok(streaming::RawStreamingChoice::Message(content.clone()));
236                    }
237
238                    // Usage updates
239                    if let Some(usage) = data.usage {
240                        final_usage = Some(usage);
241                    }
242
243                    // Finish reason
244                    if let Some(finish_reason) = &choice.finish_reason && *finish_reason == FinishReason::ToolCalls {
245                        for (_idx, tool_call) in tool_calls.into_iter() {
246                            final_tool_calls.push(completion::ToolCall {
247                                id: tool_call.id.clone(),
248                                r#type: completion::ToolType::Function,
249                                function: completion::Function {
250                                    name: tool_call.function.name.clone(),
251                                    arguments: tool_call.function.arguments.clone(),
252                                },
253                            });
254                            yield Ok(streaming::RawStreamingChoice::ToolCall {
255                                name: tool_call.function.name,
256                                id: tool_call.id,
257                                arguments: tool_call.function.arguments,
258                                call_id: None,
259                            });
260                        }
261                        tool_calls = HashMap::new();
262                    }
263                }
264                Err(crate::http_client::Error::StreamEnded) => {
265                    break;
266                }
267                Err(error) => {
268                    tracing::error!(?error, "SSE error");
269                    yield Err(CompletionError::ProviderError(error.to_string()));
270                    break;
271                }
272            }
273        }
274
275
276        // Ensure event source is closed when stream ends
277        event_source.close();
278
279        // Flush any accumulated tool calls (that weren't emitted as ToolCall earlier)
280        for (_idx, tool_call) in tool_calls.into_iter() {
281            yield Ok(streaming::RawStreamingChoice::ToolCall {
282                name: tool_call.function.name,
283                id: tool_call.id,
284                arguments: tool_call.function.arguments,
285                call_id: None,
286            });
287        }
288
289        let final_usage = final_usage.unwrap_or_default();
290        if !span.is_disabled() {
291            let message_output = super::Message::Assistant {
292                content: vec![super::AssistantContent::Text { text: text_content }],
293                refusal: None,
294                audio: None,
295                name: None,
296                tool_calls: final_tool_calls
297            };
298            span.record("gen_ai.usage.input_tokens", final_usage.prompt_tokens);
299            span.record("gen_ai.usage.output_tokens", final_usage.total_tokens - final_usage.prompt_tokens);
300            span.record("gen_ai.output.messages", serde_json::to_string(&vec![message_output]).expect("Converting from a Rust struct should always convert to JSON without failing"));
301        }
302
303        yield Ok(RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
304            usage: final_usage
305        }));
306    }.instrument(span);
307
308    Ok(streaming::StreamingCompletionResponse::stream(Box::pin(
309        stream,
310    )))
311}
312
#[cfg(test)]
mod tests {
    use super::*;

    /// A complete function payload exposes both the name and the raw arguments string.
    #[test]
    fn test_streaming_function_deserialization() {
        let payload = r#"{"name": "get_weather", "arguments": "{\"location\":\"Paris\"}"}"#;
        let parsed: StreamingFunction = serde_json::from_str(payload).unwrap();
        assert_eq!(parsed.name.as_deref(), Some("get_weather"));
        assert_eq!(parsed.arguments.as_deref(), Some(r#"{"location":"Paris"}"#));
    }

    /// A fully populated tool call keeps its index, id and function name.
    #[test]
    fn test_streaming_tool_call_deserialization() {
        let payload = r#"{
            "index": 0,
            "id": "call_abc123",
            "function": {
                "name": "get_weather",
                "arguments": "{\"city\":\"London\"}"
            }
        }"#;
        let parsed: StreamingToolCall = serde_json::from_str(payload).unwrap();
        assert_eq!(parsed.index, 0);
        assert_eq!(parsed.id.as_deref(), Some("call_abc123"));
        assert_eq!(parsed.function.name.as_deref(), Some("get_weather"));
    }

    /// Follow-up fragments carry `null` id/name and only a piece of the arguments.
    #[test]
    fn test_streaming_tool_call_partial_deserialization() {
        // Partial tool calls have no name and partial arguments
        let payload = r#"{
            "index": 0,
            "id": null,
            "function": {
                "name": null,
                "arguments": "Paris"
            }
        }"#;
        let parsed: StreamingToolCall = serde_json::from_str(payload).unwrap();
        assert_eq!(parsed.index, 0);
        assert!(parsed.id.is_none());
        assert!(parsed.function.name.is_none());
        assert_eq!(parsed.function.arguments.as_deref(), Some("Paris"));
    }

    /// A delta with `null` content and one tool call deserializes into both fields.
    #[test]
    fn test_streaming_delta_with_tool_calls() {
        let payload = r#"{
            "content": null,
            "tool_calls": [{
                "index": 0,
                "id": "call_xyz",
                "function": {
                    "name": "search",
                    "arguments": ""
                }
            }]
        }"#;
        let parsed: StreamingDelta = serde_json::from_str(payload).unwrap();
        assert!(parsed.content.is_none());
        assert_eq!(parsed.tool_calls.len(), 1);
        assert_eq!(parsed.tool_calls[0].id.as_deref(), Some("call_xyz"));
    }

    /// A text chunk with a usage block deserializes its choice, content and usage.
    #[test]
    fn test_streaming_chunk_deserialization() {
        let payload = r#"{
            "choices": [{
                "delta": {
                    "content": "Hello",
                    "tool_calls": []
                }
            }],
            "usage": {
                "prompt_tokens": 10,
                "completion_tokens": 5,
                "total_tokens": 15
            }
        }"#;
        let parsed: StreamingCompletionChunk = serde_json::from_str(payload).unwrap();
        assert_eq!(parsed.choices.len(), 1);
        assert_eq!(parsed.choices[0].delta.content.as_deref(), Some("Hello"));
        assert!(parsed.usage.is_some());
    }

    /// Each chunk of a tool call split across several events deserializes on its own.
    #[test]
    fn test_streaming_chunk_with_multiple_tool_call_deltas() {
        // Simulates multiple partial tool call chunks arriving
        let json_start = r#"{
            "choices": [{
                "delta": {
                    "content": null,
                    "tool_calls": [{
                        "index": 0,
                        "id": "call_123",
                        "function": {
                            "name": "get_weather",
                            "arguments": ""
                        }
                    }]
                }
            }],
            "usage": null
        }"#;

        let json_chunk1 = r#"{
            "choices": [{
                "delta": {
                    "content": null,
                    "tool_calls": [{
                        "index": 0,
                        "id": null,
                        "function": {
                            "name": null,
                            "arguments": "{\"loc"
                        }
                    }]
                }
            }],
            "usage": null
        }"#;

        let json_chunk2 = r#"{
            "choices": [{
                "delta": {
                    "content": null,
                    "tool_calls": [{
                        "index": 0,
                        "id": null,
                        "function": {
                            "name": null,
                            "arguments": "ation\":\"NYC\"}"
                        }
                    }]
                }
            }],
            "usage": null
        }"#;

        // Verify each chunk deserializes correctly
        let opening: StreamingCompletionChunk = serde_json::from_str(json_start).unwrap();
        let opening_calls = &opening.choices[0].delta.tool_calls;
        assert_eq!(opening_calls.len(), 1);
        assert_eq!(opening_calls[0].function.name.as_deref(), Some("get_weather"));

        let first_part: StreamingCompletionChunk = serde_json::from_str(json_chunk1).unwrap();
        let first_calls = &first_part.choices[0].delta.tool_calls;
        assert_eq!(first_calls.len(), 1);
        assert_eq!(first_calls[0].function.arguments.as_deref(), Some("{\"loc"));

        let second_part: StreamingCompletionChunk = serde_json::from_str(json_chunk2).unwrap();
        let second_calls = &second_part.choices[0].delta.tool_calls;
        assert_eq!(second_calls.len(), 1);
        assert_eq!(
            second_calls[0].function.arguments.as_deref(),
            Some("ation\":\"NYC\"}")
        );
    }
}