rig/providers/openai/completion/streaming.rs

1use crate::completion::{CompletionError, CompletionRequest, GetTokenUsage};
2use crate::json_utils;
3use crate::json_utils::merge;
4use crate::providers::openai::completion::{CompletionModel, Usage};
5use crate::streaming;
6use crate::streaming::RawStreamingChoice;
7use async_stream::stream;
8use futures::StreamExt;
9use reqwest::RequestBuilder;
10use serde::{Deserialize, Serialize};
11use serde_json::json;
12use std::collections::HashMap;
13use tracing::debug;
14
15// ================================================================
16// OpenAI Completion Streaming API
17// ================================================================
18#[derive(Debug, Serialize, Deserialize, Clone)]
19pub struct StreamingFunction {
20    #[serde(default)]
21    pub name: Option<String>,
22    #[serde(default)]
23    pub arguments: String,
24}
25
26#[derive(Debug, Serialize, Deserialize, Clone)]
27pub struct StreamingToolCall {
28    pub index: usize,
29    pub id: Option<String>,
30    pub function: StreamingFunction,
31}
32
33#[derive(Deserialize, Debug)]
34struct StreamingDelta {
35    #[serde(default)]
36    content: Option<String>,
37    #[serde(default, deserialize_with = "json_utils::null_or_vec")]
38    tool_calls: Vec<StreamingToolCall>,
39}
40
41#[derive(Deserialize, Debug)]
42struct StreamingChoice {
43    delta: StreamingDelta,
44}
45
46#[derive(Deserialize, Debug)]
47struct StreamingCompletionChunk {
48    choices: Vec<StreamingChoice>,
49    usage: Option<Usage>,
50}
51
52#[derive(Clone, Serialize, Deserialize)]
53pub struct StreamingCompletionResponse {
54    pub usage: Usage,
55}
56
57impl GetTokenUsage for StreamingCompletionResponse {
58    fn token_usage(&self) -> Option<crate::completion::Usage> {
59        let mut usage = crate::completion::Usage::new();
60        usage.input_tokens = self.usage.prompt_tokens as u64;
61        usage.output_tokens = self.usage.total_tokens as u64 - self.usage.prompt_tokens as u64;
62        usage.total_tokens = self.usage.total_tokens as u64;
63        Some(usage)
64    }
65}
66
67impl CompletionModel {
68    pub(crate) async fn stream(
69        &self,
70        completion_request: CompletionRequest,
71    ) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError>
72    {
73        let mut request = self.create_completion_request(completion_request)?;
74        request = merge(
75            request,
76            json!({"stream": true, "stream_options": {"include_usage": true}}),
77        );
78
79        let builder = self.client.post("/chat/completions").json(&request);
80        send_compatible_streaming_request(builder).await
81    }
82}
83
84pub async fn send_compatible_streaming_request(
85    request_builder: RequestBuilder,
86) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError> {
87    let response = request_builder.send().await?;
88
89    if !response.status().is_success() {
90        return Err(CompletionError::ProviderError(format!(
91            "{}: {}",
92            response.status(),
93            response.text().await?
94        )));
95    }
96
97    // Handle OpenAI Compatible SSE chunks
98    let inner = Box::pin(stream! {
99        let mut stream = response.bytes_stream();
100
101        let mut final_usage = Usage {
102            prompt_tokens: 0,
103            total_tokens: 0
104        };
105
106        let mut partial_data = None;
107        let mut calls: HashMap<usize, (String, String, String)> = HashMap::new();
108
109        while let Some(chunk_result) = stream.next().await {
110            let chunk = match chunk_result {
111                Ok(c) => c,
112                Err(e) => {
113                    yield Err(CompletionError::from(e));
114                    break;
115                }
116            };
117
118            let text = match String::from_utf8(chunk.to_vec()) {
119                Ok(t) => t,
120                Err(e) => {
121                    yield Err(CompletionError::ResponseError(e.to_string()));
122                    break;
123                }
124            };
125
126
127            for line in text.lines() {
128                let mut line = line.to_string();
129
130                // If there was a remaining part, concat with current line
131                if partial_data.is_some() {
132                    line = format!("{}{}", partial_data.unwrap(), line);
133                    partial_data = None;
134                }
135                // Otherwise full data line
136                else {
137                    let Some(data) = line.strip_prefix("data:") else {
138                        continue;
139                    };
140
141                    let data = data.trim_start();
142
143                    if data == "[DONE]" {
144                        break
145                    }
146
147                    // Partial data, split somewhere in the middle
148                    if !line.ends_with("}") {
149                        partial_data = Some(data.to_string());
150                    } else {
151                        line = data.to_string();
152                    }
153                }
154
155                let data = serde_json::from_str::<StreamingCompletionChunk>(&line);
156
157                let Ok(data) = data else {
158                    let err = data.unwrap_err();
159                    debug!("Couldn't serialize data as StreamingCompletionChunk: {:?}", err);
160                    continue;
161                };
162
163
164                if let Some(choice) = data.choices.first() {
165
166                    let delta = &choice.delta;
167
168                    if !delta.tool_calls.is_empty() {
169                        for tool_call in &delta.tool_calls {
170                            let function = tool_call.function.clone();
171                            // Start of tool call
172                            // name: Some(String)
173                            // arguments: None
174                            if function.name.is_some() && function.arguments.is_empty() {
175                                let id = tool_call.id.clone().unwrap_or("".to_string());
176
177                                calls.insert(tool_call.index, (id, function.name.clone().unwrap(), "".to_string()));
178                            }
179                            // Part of tool call
180                            // name: None or Empty String
181                            // arguments: Some(String)
182                            else if function.name.clone().is_none_or(|s| s.is_empty()) && !function.arguments.is_empty() {
183                                let Some((id, name, arguments)) = calls.get(&tool_call.index) else {
184                                    debug!("Partial tool call received but tool call was never started.");
185                                    continue;
186                                };
187
188                                let new_arguments = &tool_call.function.arguments;
189                                let arguments = format!("{arguments}{new_arguments}");
190
191                                calls.insert(tool_call.index, (id.clone(), name.clone(), arguments));
192                            }
193                            // Entire tool call
194                            else {
195                                let id = tool_call.id.clone().unwrap_or("".to_string());
196                                let name = function.name.expect("function name should be present for complete tool call");
197                                let arguments = function.arguments;
198                                let Ok(arguments) = serde_json::from_str(&arguments) else {
199                                    debug!("Couldn't serialize '{}' as a json value", arguments);
200                                    continue;
201                                };
202
203                                yield Ok(streaming::RawStreamingChoice::ToolCall {id, name, arguments, call_id: None })
204                            }
205                        }
206                    }
207
208                    if let Some(content) = &choice.delta.content {
209                        yield Ok(streaming::RawStreamingChoice::Message(content.clone()))
210                    }
211                }
212
213
214                if let Some(usage) = data.usage {
215                    final_usage = usage.clone();
216                }
217            }
218        }
219
220        for (_, (id, name, arguments)) in calls {
221            let Ok(arguments) = serde_json::from_str(&arguments) else {
222                continue;
223            };
224
225            yield Ok(RawStreamingChoice::ToolCall {id, name, arguments, call_id: None });
226        }
227
228        yield Ok(RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
229            usage: final_usage.clone()
230        }))
231    });
232
233    Ok(streaming::StreamingCompletionResponse::stream(inner))
234}