rig/providers/openai/completion/
streaming.rs1use crate::completion::{CompletionError, CompletionRequest, GetTokenUsage};
2use crate::json_utils;
3use crate::json_utils::merge;
4use crate::providers::openai::completion::{CompletionModel, Usage};
5use crate::streaming;
6use crate::streaming::RawStreamingChoice;
7use async_stream::stream;
8use futures::StreamExt;
9use reqwest::RequestBuilder;
10use serde::{Deserialize, Serialize};
11use serde_json::json;
12use std::collections::HashMap;
13use tracing::debug;
14
15#[derive(Debug, Serialize, Deserialize, Clone)]
19pub struct StreamingFunction {
20 #[serde(default)]
21 pub name: Option<String>,
22 #[serde(default)]
23 pub arguments: String,
24}
25
26#[derive(Debug, Serialize, Deserialize, Clone)]
27pub struct StreamingToolCall {
28 pub index: usize,
29 pub id: Option<String>,
30 pub function: StreamingFunction,
31}
32
33#[derive(Deserialize, Debug)]
34struct StreamingDelta {
35 #[serde(default)]
36 content: Option<String>,
37 #[serde(default, deserialize_with = "json_utils::null_or_vec")]
38 tool_calls: Vec<StreamingToolCall>,
39}
40
41#[derive(Deserialize, Debug)]
42struct StreamingChoice {
43 delta: StreamingDelta,
44}
45
46#[derive(Deserialize, Debug)]
47struct StreamingCompletionChunk {
48 choices: Vec<StreamingChoice>,
49 usage: Option<Usage>,
50}
51
52#[derive(Clone, Serialize, Deserialize)]
53pub struct StreamingCompletionResponse {
54 pub usage: Usage,
55}
56
57impl GetTokenUsage for StreamingCompletionResponse {
58 fn token_usage(&self) -> Option<crate::completion::Usage> {
59 let mut usage = crate::completion::Usage::new();
60 usage.input_tokens = self.usage.prompt_tokens as u64;
61 usage.output_tokens = self.usage.total_tokens as u64 - self.usage.prompt_tokens as u64;
62 usage.total_tokens = self.usage.total_tokens as u64;
63 Some(usage)
64 }
65}
66
67impl CompletionModel {
68 pub(crate) async fn stream(
69 &self,
70 completion_request: CompletionRequest,
71 ) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError>
72 {
73 let mut request = self.create_completion_request(completion_request)?;
74 request = merge(
75 request,
76 json!({"stream": true, "stream_options": {"include_usage": true}}),
77 );
78
79 let builder = self.client.post("/chat/completions").json(&request);
80 send_compatible_streaming_request(builder).await
81 }
82}
83
84pub async fn send_compatible_streaming_request(
85 request_builder: RequestBuilder,
86) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError> {
87 let response = request_builder.send().await?;
88
89 if !response.status().is_success() {
90 return Err(CompletionError::ProviderError(format!(
91 "{}: {}",
92 response.status(),
93 response.text().await?
94 )));
95 }
96
97 let inner = Box::pin(stream! {
99 let mut stream = response.bytes_stream();
100
101 let mut final_usage = Usage {
102 prompt_tokens: 0,
103 total_tokens: 0
104 };
105
106 let mut partial_data = None;
107 let mut calls: HashMap<usize, (String, String, String)> = HashMap::new();
108
109 while let Some(chunk_result) = stream.next().await {
110 let chunk = match chunk_result {
111 Ok(c) => c,
112 Err(e) => {
113 yield Err(CompletionError::from(e));
114 break;
115 }
116 };
117
118 let text = match String::from_utf8(chunk.to_vec()) {
119 Ok(t) => t,
120 Err(e) => {
121 yield Err(CompletionError::ResponseError(e.to_string()));
122 break;
123 }
124 };
125
126
127 for line in text.lines() {
128 let mut line = line.to_string();
129
130 if partial_data.is_some() {
132 line = format!("{}{}", partial_data.unwrap(), line);
133 partial_data = None;
134 }
135 else {
137 let Some(data) = line.strip_prefix("data:") else {
138 continue;
139 };
140
141 let data = data.trim_start();
142
143 if data == "[DONE]" {
144 break
145 }
146
147 if !line.ends_with("}") {
149 partial_data = Some(data.to_string());
150 } else {
151 line = data.to_string();
152 }
153 }
154
155 let data = serde_json::from_str::<StreamingCompletionChunk>(&line);
156
157 let Ok(data) = data else {
158 let err = data.unwrap_err();
159 debug!("Couldn't serialize data as StreamingCompletionChunk: {:?}", err);
160 continue;
161 };
162
163
164 if let Some(choice) = data.choices.first() {
165
166 let delta = &choice.delta;
167
168 if !delta.tool_calls.is_empty() {
169 for tool_call in &delta.tool_calls {
170 let function = tool_call.function.clone();
171 if function.name.is_some() && function.arguments.is_empty() {
175 let id = tool_call.id.clone().unwrap_or("".to_string());
176
177 calls.insert(tool_call.index, (id, function.name.clone().unwrap(), "".to_string()));
178 }
179 else if function.name.clone().is_none_or(|s| s.is_empty()) && !function.arguments.is_empty() {
183 let Some((id, name, arguments)) = calls.get(&tool_call.index) else {
184 debug!("Partial tool call received but tool call was never started.");
185 continue;
186 };
187
188 let new_arguments = &tool_call.function.arguments;
189 let arguments = format!("{arguments}{new_arguments}");
190
191 calls.insert(tool_call.index, (id.clone(), name.clone(), arguments));
192 }
193 else {
195 let id = tool_call.id.clone().unwrap_or("".to_string());
196 let name = function.name.expect("function name should be present for complete tool call");
197 let arguments = function.arguments;
198 let Ok(arguments) = serde_json::from_str(&arguments) else {
199 debug!("Couldn't serialize '{}' as a json value", arguments);
200 continue;
201 };
202
203 yield Ok(streaming::RawStreamingChoice::ToolCall {id, name, arguments, call_id: None })
204 }
205 }
206 }
207
208 if let Some(content) = &choice.delta.content {
209 yield Ok(streaming::RawStreamingChoice::Message(content.clone()))
210 }
211 }
212
213
214 if let Some(usage) = data.usage {
215 final_usage = usage.clone();
216 }
217 }
218 }
219
220 for (_, (id, name, arguments)) in calls {
221 let Ok(arguments) = serde_json::from_str(&arguments) else {
222 continue;
223 };
224
225 yield Ok(RawStreamingChoice::ToolCall {id, name, arguments, call_id: None });
226 }
227
228 yield Ok(RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
229 usage: final_usage.clone()
230 }))
231 });
232
233 Ok(streaming::StreamingCompletionResponse::stream(inner))
234}