rig/providers/openai/
streaming.rs

1use super::completion::CompletionModel;
2use crate::completion::{CompletionError, CompletionRequest};
3use crate::json_utils;
4use crate::json_utils::merge;
5use crate::providers::openai::Usage;
6use crate::streaming;
7use crate::streaming::RawStreamingChoice;
8use async_stream::stream;
9use futures::StreamExt;
10use reqwest::RequestBuilder;
11use serde::{Deserialize, Serialize};
12use serde_json::json;
13use std::collections::HashMap;
14use tracing::debug;
15
16// ================================================================
17// OpenAI Completion Streaming API
18// ================================================================
19#[derive(Debug, Serialize, Deserialize, Clone)]
20pub struct StreamingFunction {
21    #[serde(default)]
22    name: Option<String>,
23    #[serde(default)]
24    arguments: String,
25}
26
27#[derive(Debug, Serialize, Deserialize, Clone)]
28pub struct StreamingToolCall {
29    pub index: usize,
30    pub id: Option<String>,
31    pub function: StreamingFunction,
32}
33
34#[derive(Deserialize, Debug)]
35struct StreamingDelta {
36    #[serde(default)]
37    content: Option<String>,
38    #[serde(default, deserialize_with = "json_utils::null_or_vec")]
39    tool_calls: Vec<StreamingToolCall>,
40}
41
42#[derive(Deserialize, Debug)]
43struct StreamingChoice {
44    delta: StreamingDelta,
45}
46
47#[derive(Deserialize, Debug)]
48struct StreamingCompletionChunk {
49    choices: Vec<StreamingChoice>,
50    usage: Option<Usage>,
51}
52
53#[derive(Clone)]
54pub struct StreamingCompletionResponse {
55    pub usage: Usage,
56}
57
58impl CompletionModel {
59    pub(crate) async fn stream(
60        &self,
61        completion_request: CompletionRequest,
62    ) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError>
63    {
64        let mut request = self.create_completion_request(completion_request)?;
65        request = merge(
66            request,
67            json!({"stream": true, "stream_options": {"include_usage": true}}),
68        );
69
70        let builder = self.client.post("/chat/completions").json(&request);
71        send_compatible_streaming_request(builder).await
72    }
73}
74
75pub async fn send_compatible_streaming_request(
76    request_builder: RequestBuilder,
77) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError> {
78    let response = request_builder.send().await?;
79
80    if !response.status().is_success() {
81        return Err(CompletionError::ProviderError(format!(
82            "{}: {}",
83            response.status(),
84            response.text().await?
85        )));
86    }
87
88    // Handle OpenAI Compatible SSE chunks
89    let inner = Box::pin(stream! {
90        let mut stream = response.bytes_stream();
91
92        let mut final_usage = Usage {
93            prompt_tokens: 0,
94            total_tokens: 0
95        };
96
97        let mut partial_data = None;
98        let mut calls: HashMap<usize, (String, String, String)> = HashMap::new();
99
100        while let Some(chunk_result) = stream.next().await {
101            let chunk = match chunk_result {
102                Ok(c) => c,
103                Err(e) => {
104                    yield Err(CompletionError::from(e));
105                    break;
106                }
107            };
108
109            let text = match String::from_utf8(chunk.to_vec()) {
110                Ok(t) => t,
111                Err(e) => {
112                    yield Err(CompletionError::ResponseError(e.to_string()));
113                    break;
114                }
115            };
116
117
118            for line in text.lines() {
119                let mut line = line.to_string();
120
121                // If there was a remaining part, concat with current line
122                if partial_data.is_some() {
123                    line = format!("{}{}", partial_data.unwrap(), line);
124                    partial_data = None;
125                }
126                // Otherwise full data line
127                else {
128                    let Some(data) = line.strip_prefix("data: ") else {
129                        continue;
130                    };
131
132                    // Partial data, split somewhere in the middle
133                    if !line.ends_with("}") {
134                        partial_data = Some(data.to_string());
135                    } else {
136                        line = data.to_string();
137                    }
138                }
139
140                let data = serde_json::from_str::<StreamingCompletionChunk>(&line);
141
142                let Ok(data) = data else {
143                    let err = data.unwrap_err();
144                    debug!("Couldn't serialize data as StreamingCompletionChunk: {:?}", err);
145                    continue;
146                };
147
148
149                if let Some(choice) = data.choices.first() {
150
151                    let delta = &choice.delta;
152
153                    if !delta.tool_calls.is_empty() {
154                        for tool_call in &delta.tool_calls {
155                            let function = tool_call.function.clone();
156                            // Start of tool call
157                            // name: Some(String)
158                            // arguments: None
159                            if function.name.is_some() && function.arguments.is_empty() {
160                                let id = tool_call.id.clone().unwrap_or("".to_string());
161
162                                calls.insert(tool_call.index, (id, function.name.clone().unwrap(), "".to_string()));
163                            }
164                            // Part of tool call
165                            // name: None or Empty String
166                            // arguments: Some(String)
167                            else if function.name.clone().is_none_or(|s| s.is_empty()) && !function.arguments.is_empty() {
168                                let Some((id, name, arguments)) = calls.get(&tool_call.index) else {
169                                    debug!("Partial tool call received but tool call was never started.");
170                                    continue;
171                                };
172
173                                let new_arguments = &tool_call.function.arguments;
174                                let arguments = format!("{arguments}{new_arguments}");
175
176                                calls.insert(tool_call.index, (id.clone(), name.clone(), arguments));
177                            }
178                            // Entire tool call
179                            else {
180                                let id = tool_call.id.clone().unwrap_or("".to_string());
181                                let name = function.name.expect("function name should be present for complete tool call");
182                                let arguments = function.arguments;
183                                let Ok(arguments) = serde_json::from_str(&arguments) else {
184                                    debug!("Couldn't serialize '{}' as a json value", arguments);
185                                    continue;
186                                };
187
188                                yield Ok(streaming::RawStreamingChoice::ToolCall {id, name, arguments})
189                            }
190                        }
191                    }
192
193                    if let Some(content) = &choice.delta.content {
194                        yield Ok(streaming::RawStreamingChoice::Message(content.clone()))
195                    }
196                }
197
198
199                if let Some(usage) = data.usage {
200                    final_usage = usage.clone();
201                }
202            }
203        }
204
205        for (_, (id, name, arguments)) in calls {
206            let Ok(arguments) = serde_json::from_str(&arguments) else {
207                continue;
208            };
209
210            yield Ok(RawStreamingChoice::ToolCall {id, name, arguments});
211        }
212
213        yield Ok(RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
214            usage: final_usage.clone()
215        }))
216    });
217
218    Ok(streaming::StreamingCompletionResponse::stream(inner))
219}