rig/providers/openai/completion/streaming.rs

1use crate::completion::{CompletionError, CompletionRequest, GetTokenUsage};
2use crate::http_client::HttpClientExt;
3use crate::http_client::sse::{Event, GenericEventSource};
4use crate::json_utils;
5use crate::json_utils::merge;
6use crate::providers::openai::completion::{CompletionModel, Usage};
7use crate::streaming;
8use crate::streaming::RawStreamingChoice;
9use async_stream::stream;
10use futures::StreamExt;
11use http::Request;
12use serde::{Deserialize, Serialize};
13use serde_json::json;
14use std::collections::HashMap;
15use tracing::{debug, info_span};
16use tracing_futures::Instrument;
17
18// ================================================================
19// OpenAI Completion Streaming API
20// ================================================================
21#[derive(Debug, Serialize, Deserialize, Clone)]
22pub struct StreamingFunction {
23    #[serde(default)]
24    pub name: Option<String>,
25    #[serde(default)]
26    pub arguments: String,
27}
28
29#[derive(Debug, Serialize, Deserialize, Clone)]
30pub struct StreamingToolCall {
31    pub index: usize,
32    pub id: Option<String>,
33    pub function: StreamingFunction,
34}
35
36#[derive(Deserialize, Debug)]
37struct StreamingDelta {
38    #[serde(default)]
39    content: Option<String>,
40    #[serde(default, deserialize_with = "json_utils::null_or_vec")]
41    tool_calls: Vec<StreamingToolCall>,
42}
43
44#[derive(Deserialize, Debug)]
45struct StreamingChoice {
46    delta: StreamingDelta,
47}
48
49#[derive(Deserialize, Debug)]
50struct StreamingCompletionChunk {
51    choices: Vec<StreamingChoice>,
52    usage: Option<Usage>,
53}
54
55#[derive(Clone, Serialize, Deserialize)]
56pub struct StreamingCompletionResponse {
57    pub usage: Usage,
58}
59
60impl GetTokenUsage for StreamingCompletionResponse {
61    fn token_usage(&self) -> Option<crate::completion::Usage> {
62        let mut usage = crate::completion::Usage::new();
63        usage.input_tokens = self.usage.prompt_tokens as u64;
64        usage.output_tokens = self.usage.total_tokens as u64 - self.usage.prompt_tokens as u64;
65        usage.total_tokens = self.usage.total_tokens as u64;
66        Some(usage)
67    }
68}
69
70impl CompletionModel<reqwest::Client> {
71    pub(crate) async fn stream(
72        &self,
73        completion_request: CompletionRequest,
74    ) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError>
75    {
76        let request = super::CompletionRequest::try_from((self.model.clone(), completion_request))?;
77        let request_messages = serde_json::to_string(&request.messages)
78            .expect("Converting to JSON from a Rust struct shouldn't fail");
79        let mut request_as_json = serde_json::to_value(request).expect("this should never fail");
80
81        request_as_json = merge(
82            request_as_json,
83            json!({"stream": true, "stream_options": {"include_usage": true}}),
84        );
85
86        let req_body = serde_json::to_vec(&request_as_json)?;
87
88        let req = self
89            .client
90            .post("/chat/completions")?
91            .header("Content-Type", "application/json")
92            .body(req_body)
93            .map_err(|e| CompletionError::HttpError(e.into()))?;
94
95        let span = if tracing::Span::current().is_disabled() {
96            info_span!(
97                target: "rig::completions",
98                "chat",
99                gen_ai.operation.name = "chat",
100                gen_ai.provider.name = "openai",
101                gen_ai.request.model = self.model,
102                gen_ai.response.id = tracing::field::Empty,
103                gen_ai.response.model = self.model,
104                gen_ai.usage.output_tokens = tracing::field::Empty,
105                gen_ai.usage.input_tokens = tracing::field::Empty,
106                gen_ai.input.messages = request_messages,
107                gen_ai.output.messages = tracing::field::Empty,
108            )
109        } else {
110            tracing::Span::current()
111        };
112
113        tracing::Instrument::instrument(
114            send_compatible_streaming_request(self.client.http_client.clone(), req),
115            span,
116        )
117        .await
118    }
119}
120
121pub async fn send_compatible_streaming_request<T>(
122    http_client: T,
123    req: Request<Vec<u8>>,
124) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError>
125where
126    T: HttpClientExt + Clone + 'static,
127{
128    let span = tracing::Span::current();
129    // Build the request with proper headers for SSE
130    let mut event_source = GenericEventSource::new(http_client, req);
131
132    let stream = stream! {
133        let span = tracing::Span::current();
134        let mut final_usage = Usage::new();
135
136        // Track in-progress tool calls
137        let mut tool_calls: HashMap<usize, (String, String, String)> = HashMap::new();
138
139        let mut text_content = String::new();
140
141        while let Some(event_result) = event_source.next().await {
142            match event_result {
143                Ok(Event::Open) => {
144                    tracing::trace!("SSE connection opened");
145                    continue;
146                }
147                Ok(Event::Message(message)) => {
148                    if message.data.trim().is_empty() || message.data == "[DONE]" {
149                        continue;
150                    }
151
152                    let data = serde_json::from_str::<StreamingCompletionChunk>(&message.data);
153                    let Ok(data) = data else {
154                        let err = data.unwrap_err();
155                        debug!("Couldn't serialize data as StreamingCompletionChunk: {:?}", err);
156                        continue;
157                    };
158
159                    if let Some(choice) = data.choices.first() {
160                        let delta = &choice.delta;
161
162                        // Tool calls
163                        if !delta.tool_calls.is_empty() {
164                            for tool_call in &delta.tool_calls {
165                                let function = tool_call.function.clone();
166
167                                // Start of tool call
168                                if function.name.is_some() && function.arguments.is_empty() {
169                                    let id = tool_call.id.clone().unwrap_or_default();
170                                    tool_calls.insert(
171                                        tool_call.index,
172                                        (id, function.name.clone().unwrap(), "".to_string()),
173                                    );
174                                }
175                                // tool call partial (ie, a continuation of a previously received tool call)
176                                // name: None or Empty String
177                                // arguments: Some(String)
178                                else if function.name.clone().is_none_or(|s| s.is_empty())
179                                    && !function.arguments.is_empty()
180                                {
181                                    if let Some((id, name, arguments)) =
182                                        tool_calls.get(&tool_call.index).cloned()
183                                    {
184                                        let new_arguments = &tool_call.function.arguments;
185                                        let combined_arguments = format!("{arguments}{new_arguments}");
186                                        tool_calls.insert(
187                                            tool_call.index,
188                                            (id.clone(), name.clone(), combined_arguments),
189                                        );
190
191                                        // Emit the delta so UI can show progress
192                                        yield Ok(streaming::RawStreamingChoice::ToolCallDelta {
193                                            id: id.clone(),
194                                            delta: new_arguments.clone(),
195                                        });
196                                    } else {
197                                        debug!("Partial tool call received but tool call was never started.");
198                                    }
199                                }
200                                // Complete tool call
201                                else {
202                                    let id = tool_call.id.clone().unwrap_or_default();
203                                    let name = function.name.expect("tool call should have a name");
204                                    let arguments = function.arguments;
205                                    let Ok(arguments) = serde_json::from_str(&arguments) else {
206                                        debug!("Couldn't serialize '{arguments}' as JSON");
207                                        continue;
208                                    };
209
210                                    yield Ok(streaming::RawStreamingChoice::ToolCall {
211                                        id,
212                                        name,
213                                        arguments,
214                                        call_id: None,
215                                    });
216                                }
217                            }
218                        }
219
220                        // Message content
221                        if let Some(content) = &choice.delta.content {
222                            text_content += content;
223                            yield Ok(streaming::RawStreamingChoice::Message(content.clone()))
224                        }
225                    }
226
227                    // Usage updates
228                    if let Some(usage) = data.usage {
229                        final_usage = usage.clone();
230                    }
231                }
232                Err(crate::http_client::Error::StreamEnded) => {
233                    break;
234                }
235                Err(error) => {
236                    tracing::error!(?error, "SSE error");
237                    yield Err(CompletionError::ResponseError(error.to_string()));
238                    break;
239                }
240            }
241        }
242
243        // Ensure event source is closed when stream ends
244        event_source.close();
245
246        let mut vec_toolcalls = vec![];
247
248        // Flush any tool calls that weren’t fully yielded
249        for (_, (id, name, arguments)) in tool_calls {
250            let Ok(arguments) = serde_json::from_str::<serde_json::Value>(&arguments) else {
251                continue;
252            };
253
254            vec_toolcalls.push(super::ToolCall {
255                r#type: super::ToolType::Function,
256                id: id.clone(),
257                function: super::Function {
258                    name: name.clone(), arguments: arguments.clone()
259                },
260            });
261
262            yield Ok(RawStreamingChoice::ToolCall {
263                id,
264                name,
265                arguments,
266                call_id: None,
267            });
268        }
269
270        let message_output = super::Message::Assistant {
271            content: vec![super::AssistantContent::Text { text: text_content }],
272            refusal: None,
273            audio: None,
274            name: None,
275            tool_calls: vec_toolcalls
276        };
277
278        span.record("gen_ai.usage.input_tokens", final_usage.prompt_tokens);
279        span.record("gen_ai.usage.output_tokens", final_usage.total_tokens - final_usage.prompt_tokens);
280        span.record("gen_ai.output.messages", serde_json::to_string(&vec![message_output]).expect("Converting from a Rust struct should always convert to JSON without failing"));
281
282        yield Ok(RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
283            usage: final_usage.clone()
284        }));
285    }.instrument(span);
286
287    Ok(streaming::StreamingCompletionResponse::stream(Box::pin(
288        stream,
289    )))
290}
291
#[cfg(test)]
mod tests {
    use super::*;

    /// Deserializes a JSON fixture into `T`, panicking if it does not match.
    fn parse<T: serde::de::DeserializeOwned>(json: &str) -> T {
        serde_json::from_str(json).unwrap()
    }

    /// Builds a streamed chunk holding a single tool-call fragment at index 0.
    fn tool_call_chunk(id: Option<&str>, name: Option<&str>, arguments: &str) -> String {
        serde_json::json!({
            "choices": [{
                "delta": {
                    "content": null,
                    "tool_calls": [{
                        "index": 0,
                        "id": id,
                        "function": { "name": name, "arguments": arguments }
                    }]
                }
            }],
            "usage": null
        })
        .to_string()
    }

    #[test]
    fn test_streaming_function_deserialization() {
        let function: StreamingFunction =
            parse(r#"{"name": "get_weather", "arguments": "{\"location\":\"Paris\"}"}"#);

        assert_eq!(function.name.as_deref(), Some("get_weather"));
        assert_eq!(function.arguments, r#"{"location":"Paris"}"#);
    }

    #[test]
    fn test_streaming_tool_call_deserialization() {
        let tool_call: StreamingToolCall = parse(
            r#"{"index": 0, "id": "call_abc123",
                "function": {"name": "get_weather", "arguments": "{\"city\":\"London\"}"}}"#,
        );

        assert_eq!(tool_call.index, 0);
        assert_eq!(tool_call.id.as_deref(), Some("call_abc123"));
        assert_eq!(tool_call.function.name.as_deref(), Some("get_weather"));
    }

    #[test]
    fn test_streaming_tool_call_partial_deserialization() {
        // Partial tool calls have no name and partial arguments
        let tool_call: StreamingToolCall = parse(
            r#"{"index": 0, "id": null, "function": {"name": null, "arguments": "Paris"}}"#,
        );

        assert_eq!(tool_call.index, 0);
        assert_eq!(tool_call.id, None);
        assert_eq!(tool_call.function.name, None);
        assert_eq!(tool_call.function.arguments, "Paris");
    }

    #[test]
    fn test_streaming_delta_with_tool_calls() {
        let delta: StreamingDelta = parse(
            r#"{"content": null,
                "tool_calls": [{"index": 0, "id": "call_xyz",
                                "function": {"name": "search", "arguments": ""}}]}"#,
        );

        assert_eq!(delta.content, None);
        assert_eq!(delta.tool_calls.len(), 1);
        assert_eq!(delta.tool_calls[0].id.as_deref(), Some("call_xyz"));
    }

    #[test]
    fn test_streaming_chunk_deserialization() {
        let chunk: StreamingCompletionChunk = parse(
            r#"{"choices": [{"delta": {"content": "Hello", "tool_calls": []}}],
                "usage": {"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15}}"#,
        );

        assert_eq!(chunk.choices.len(), 1);
        assert_eq!(chunk.choices[0].delta.content.as_deref(), Some("Hello"));
        assert!(chunk.usage.is_some());
    }

    #[test]
    fn test_streaming_chunk_with_multiple_tool_call_deltas() {
        // Simulates multiple partial tool call chunks arriving
        let start: StreamingCompletionChunk =
            parse(&tool_call_chunk(Some("call_123"), Some("get_weather"), ""));
        let first_fragment: StreamingCompletionChunk =
            parse(&tool_call_chunk(None, None, "{\"loc"));
        let second_fragment: StreamingCompletionChunk =
            parse(&tool_call_chunk(None, None, "ation\":\"NYC\"}"));

        // Verify each chunk deserializes correctly
        let start_calls = &start.choices[0].delta.tool_calls;
        assert_eq!(start_calls.len(), 1);
        assert_eq!(start_calls[0].function.name.as_deref(), Some("get_weather"));

        let first_calls = &first_fragment.choices[0].delta.tool_calls;
        assert_eq!(first_calls.len(), 1);
        assert_eq!(first_calls[0].function.arguments, "{\"loc");

        let second_calls = &second_fragment.choices[0].delta.tool_calls;
        assert_eq!(second_calls.len(), 1);
        assert_eq!(second_calls[0].function.arguments, "ation\":\"NYC\"}");
    }
}