rig/providers/openai/completion/
streaming.rs1use crate::completion::{CompletionError, CompletionRequest};
2use crate::json_utils;
3use crate::json_utils::merge;
4use crate::providers::openai::completion::{CompletionModel, Usage};
5use crate::streaming;
6use crate::streaming::RawStreamingChoice;
7use async_stream::stream;
8use futures::StreamExt;
9use reqwest::RequestBuilder;
10use serde::{Deserialize, Serialize};
11use serde_json::json;
12use std::collections::HashMap;
13use tracing::debug;
14
15#[derive(Debug, Serialize, Deserialize, Clone)]
19pub struct StreamingFunction {
20 #[serde(default)]
21 pub name: Option<String>,
22 #[serde(default)]
23 pub arguments: String,
24}
25
26#[derive(Debug, Serialize, Deserialize, Clone)]
27pub struct StreamingToolCall {
28 pub index: usize,
29 pub id: Option<String>,
30 pub function: StreamingFunction,
31}
32
33#[derive(Deserialize, Debug)]
34struct StreamingDelta {
35 #[serde(default)]
36 content: Option<String>,
37 #[serde(default, deserialize_with = "json_utils::null_or_vec")]
38 tool_calls: Vec<StreamingToolCall>,
39}
40
41#[derive(Deserialize, Debug)]
42struct StreamingChoice {
43 delta: StreamingDelta,
44}
45
46#[derive(Deserialize, Debug)]
47struct StreamingCompletionChunk {
48 choices: Vec<StreamingChoice>,
49 usage: Option<Usage>,
50}
51
52#[derive(Clone, Serialize, Deserialize)]
53pub struct StreamingCompletionResponse {
54 pub usage: Usage,
55}
56
57impl CompletionModel {
58 pub(crate) async fn stream(
59 &self,
60 completion_request: CompletionRequest,
61 ) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError>
62 {
63 let mut request = self.create_completion_request(completion_request)?;
64 request = merge(
65 request,
66 json!({"stream": true, "stream_options": {"include_usage": true}}),
67 );
68
69 let builder = self.client.post("/chat/completions").json(&request);
70 send_compatible_streaming_request(builder).await
71 }
72}
73
74pub async fn send_compatible_streaming_request(
75 request_builder: RequestBuilder,
76) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError> {
77 let response = request_builder.send().await?;
78
79 if !response.status().is_success() {
80 return Err(CompletionError::ProviderError(format!(
81 "{}: {}",
82 response.status(),
83 response.text().await?
84 )));
85 }
86
87 let inner = Box::pin(stream! {
89 let mut stream = response.bytes_stream();
90
91 let mut final_usage = Usage {
92 prompt_tokens: 0,
93 total_tokens: 0
94 };
95
96 let mut partial_data = None;
97 let mut calls: HashMap<usize, (String, String, String)> = HashMap::new();
98
99 while let Some(chunk_result) = stream.next().await {
100 let chunk = match chunk_result {
101 Ok(c) => c,
102 Err(e) => {
103 yield Err(CompletionError::from(e));
104 break;
105 }
106 };
107
108 let text = match String::from_utf8(chunk.to_vec()) {
109 Ok(t) => t,
110 Err(e) => {
111 yield Err(CompletionError::ResponseError(e.to_string()));
112 break;
113 }
114 };
115
116
117 for line in text.lines() {
118 let mut line = line.to_string();
119
120 if partial_data.is_some() {
122 line = format!("{}{}", partial_data.unwrap(), line);
123 partial_data = None;
124 }
125 else {
127 let Some(data) = line.strip_prefix("data:") else {
128 continue;
129 };
130
131 let data = data.trim_start();
132
133 if data == "[DONE]" {
134 break
135 }
136
137 if !line.ends_with("}") {
139 partial_data = Some(data.to_string());
140 } else {
141 line = data.to_string();
142 }
143 }
144
145 let data = serde_json::from_str::<StreamingCompletionChunk>(&line);
146
147 let Ok(data) = data else {
148 let err = data.unwrap_err();
149 debug!("Couldn't serialize data as StreamingCompletionChunk: {:?}", err);
150 continue;
151 };
152
153
154 if let Some(choice) = data.choices.first() {
155
156 let delta = &choice.delta;
157
158 if !delta.tool_calls.is_empty() {
159 for tool_call in &delta.tool_calls {
160 let function = tool_call.function.clone();
161 if function.name.is_some() && function.arguments.is_empty() {
165 let id = tool_call.id.clone().unwrap_or("".to_string());
166
167 calls.insert(tool_call.index, (id, function.name.clone().unwrap(), "".to_string()));
168 }
169 else if function.name.clone().is_none_or(|s| s.is_empty()) && !function.arguments.is_empty() {
173 let Some((id, name, arguments)) = calls.get(&tool_call.index) else {
174 debug!("Partial tool call received but tool call was never started.");
175 continue;
176 };
177
178 let new_arguments = &tool_call.function.arguments;
179 let arguments = format!("{arguments}{new_arguments}");
180
181 calls.insert(tool_call.index, (id.clone(), name.clone(), arguments));
182 }
183 else {
185 let id = tool_call.id.clone().unwrap_or("".to_string());
186 let name = function.name.expect("function name should be present for complete tool call");
187 let arguments = function.arguments;
188 let Ok(arguments) = serde_json::from_str(&arguments) else {
189 debug!("Couldn't serialize '{}' as a json value", arguments);
190 continue;
191 };
192
193 yield Ok(streaming::RawStreamingChoice::ToolCall {id, name, arguments, call_id: None })
194 }
195 }
196 }
197
198 if let Some(content) = &choice.delta.content {
199 yield Ok(streaming::RawStreamingChoice::Message(content.clone()))
200 }
201 }
202
203
204 if let Some(usage) = data.usage {
205 final_usage = usage.clone();
206 }
207 }
208 }
209
210 for (_, (id, name, arguments)) in calls {
211 let Ok(arguments) = serde_json::from_str(&arguments) else {
212 continue;
213 };
214
215 yield Ok(RawStreamingChoice::ToolCall {id, name, arguments, call_id: None });
216 }
217
218 yield Ok(RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
219 usage: final_usage.clone()
220 }))
221 });
222
223 Ok(streaming::StreamingCompletionResponse::stream(inner))
224}