rig/providers/openai/completion/
streaming.rs

1use crate::completion::{CompletionError, CompletionRequest};
2use crate::json_utils;
3use crate::json_utils::merge;
4use crate::providers::openai::completion::{CompletionModel, Usage};
5use crate::streaming;
6use crate::streaming::RawStreamingChoice;
7use async_stream::stream;
8use futures::StreamExt;
9use reqwest::RequestBuilder;
10use serde::{Deserialize, Serialize};
11use serde_json::json;
12use std::collections::HashMap;
13use tracing::debug;
14
15// ================================================================
16// OpenAI Completion Streaming API
17// ================================================================
18#[derive(Debug, Serialize, Deserialize, Clone)]
19pub struct StreamingFunction {
20    #[serde(default)]
21    pub name: Option<String>,
22    #[serde(default)]
23    pub arguments: String,
24}
25
26#[derive(Debug, Serialize, Deserialize, Clone)]
27pub struct StreamingToolCall {
28    pub index: usize,
29    pub id: Option<String>,
30    pub function: StreamingFunction,
31}
32
33#[derive(Deserialize, Debug)]
34struct StreamingDelta {
35    #[serde(default)]
36    content: Option<String>,
37    #[serde(default, deserialize_with = "json_utils::null_or_vec")]
38    tool_calls: Vec<StreamingToolCall>,
39}
40
41#[derive(Deserialize, Debug)]
42struct StreamingChoice {
43    delta: StreamingDelta,
44}
45
46#[derive(Deserialize, Debug)]
47struct StreamingCompletionChunk {
48    choices: Vec<StreamingChoice>,
49    usage: Option<Usage>,
50}
51
52#[derive(Clone, Serialize, Deserialize)]
53pub struct StreamingCompletionResponse {
54    pub usage: Usage,
55}
56
57impl CompletionModel {
58    pub(crate) async fn stream(
59        &self,
60        completion_request: CompletionRequest,
61    ) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError>
62    {
63        let mut request = self.create_completion_request(completion_request)?;
64        request = merge(
65            request,
66            json!({"stream": true, "stream_options": {"include_usage": true}}),
67        );
68
69        let builder = self.client.post("/chat/completions").json(&request);
70        send_compatible_streaming_request(builder).await
71    }
72}
73
74pub async fn send_compatible_streaming_request(
75    request_builder: RequestBuilder,
76) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError> {
77    let response = request_builder.send().await?;
78
79    if !response.status().is_success() {
80        return Err(CompletionError::ProviderError(format!(
81            "{}: {}",
82            response.status(),
83            response.text().await?
84        )));
85    }
86
87    // Handle OpenAI Compatible SSE chunks
88    let inner = Box::pin(stream! {
89        let mut stream = response.bytes_stream();
90
91        let mut final_usage = Usage {
92            prompt_tokens: 0,
93            total_tokens: 0
94        };
95
96        let mut partial_data = None;
97        let mut calls: HashMap<usize, (String, String, String)> = HashMap::new();
98
99        while let Some(chunk_result) = stream.next().await {
100            let chunk = match chunk_result {
101                Ok(c) => c,
102                Err(e) => {
103                    yield Err(CompletionError::from(e));
104                    break;
105                }
106            };
107
108            let text = match String::from_utf8(chunk.to_vec()) {
109                Ok(t) => t,
110                Err(e) => {
111                    yield Err(CompletionError::ResponseError(e.to_string()));
112                    break;
113                }
114            };
115
116
117            for line in text.lines() {
118                let mut line = line.to_string();
119
120                // If there was a remaining part, concat with current line
121                if partial_data.is_some() {
122                    line = format!("{}{}", partial_data.unwrap(), line);
123                    partial_data = None;
124                }
125                // Otherwise full data line
126                else {
127                    let Some(data) = line.strip_prefix("data:") else {
128                        continue;
129                    };
130
131                    let data = data.trim_start();
132
133                    if data == "[DONE]" {
134                        break
135                    }
136
137                    // Partial data, split somewhere in the middle
138                    if !line.ends_with("}") {
139                        partial_data = Some(data.to_string());
140                    } else {
141                        line = data.to_string();
142                    }
143                }
144
145                let data = serde_json::from_str::<StreamingCompletionChunk>(&line);
146
147                let Ok(data) = data else {
148                    let err = data.unwrap_err();
149                    debug!("Couldn't serialize data as StreamingCompletionChunk: {:?}", err);
150                    continue;
151                };
152
153
154                if let Some(choice) = data.choices.first() {
155
156                    let delta = &choice.delta;
157
158                    if !delta.tool_calls.is_empty() {
159                        for tool_call in &delta.tool_calls {
160                            let function = tool_call.function.clone();
161                            // Start of tool call
162                            // name: Some(String)
163                            // arguments: None
164                            if function.name.is_some() && function.arguments.is_empty() {
165                                let id = tool_call.id.clone().unwrap_or("".to_string());
166
167                                calls.insert(tool_call.index, (id, function.name.clone().unwrap(), "".to_string()));
168                            }
169                            // Part of tool call
170                            // name: None or Empty String
171                            // arguments: Some(String)
172                            else if function.name.clone().is_none_or(|s| s.is_empty()) && !function.arguments.is_empty() {
173                                let Some((id, name, arguments)) = calls.get(&tool_call.index) else {
174                                    debug!("Partial tool call received but tool call was never started.");
175                                    continue;
176                                };
177
178                                let new_arguments = &tool_call.function.arguments;
179                                let arguments = format!("{arguments}{new_arguments}");
180
181                                calls.insert(tool_call.index, (id.clone(), name.clone(), arguments));
182                            }
183                            // Entire tool call
184                            else {
185                                let id = tool_call.id.clone().unwrap_or("".to_string());
186                                let name = function.name.expect("function name should be present for complete tool call");
187                                let arguments = function.arguments;
188                                let Ok(arguments) = serde_json::from_str(&arguments) else {
189                                    debug!("Couldn't serialize '{}' as a json value", arguments);
190                                    continue;
191                                };
192
193                                yield Ok(streaming::RawStreamingChoice::ToolCall {id, name, arguments, call_id: None })
194                            }
195                        }
196                    }
197
198                    if let Some(content) = &choice.delta.content {
199                        yield Ok(streaming::RawStreamingChoice::Message(content.clone()))
200                    }
201                }
202
203
204                if let Some(usage) = data.usage {
205                    final_usage = usage.clone();
206                }
207            }
208        }
209
210        for (_, (id, name, arguments)) in calls {
211            let Ok(arguments) = serde_json::from_str(&arguments) else {
212                continue;
213            };
214
215            yield Ok(RawStreamingChoice::ToolCall {id, name, arguments, call_id: None });
216        }
217
218        yield Ok(RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
219            usage: final_usage.clone()
220        }))
221    });
222
223    Ok(streaming::StreamingCompletionResponse::stream(inner))
224}