rig/providers/openai/completion/
streaming.rs1use crate::completion::{CompletionError, CompletionRequest};
2use crate::json_utils;
3use crate::json_utils::merge;
4use crate::providers::openai::completion::{CompletionModel, Usage};
5use crate::streaming;
6use crate::streaming::RawStreamingChoice;
7use async_stream::stream;
8use futures::StreamExt;
9use reqwest::RequestBuilder;
10use serde::{Deserialize, Serialize};
11use serde_json::json;
12use std::collections::HashMap;
13use tracing::debug;
14
15#[derive(Debug, Serialize, Deserialize, Clone)]
19pub struct StreamingFunction {
20 #[serde(default)]
21 name: Option<String>,
22 #[serde(default)]
23 arguments: String,
24}
25
26#[derive(Debug, Serialize, Deserialize, Clone)]
27pub struct StreamingToolCall {
28 pub index: usize,
29 pub id: Option<String>,
30 pub function: StreamingFunction,
31}
32
33#[derive(Deserialize, Debug)]
34struct StreamingDelta {
35 #[serde(default)]
36 content: Option<String>,
37 #[serde(default, deserialize_with = "json_utils::null_or_vec")]
38 tool_calls: Vec<StreamingToolCall>,
39}
40
41#[derive(Deserialize, Debug)]
42struct StreamingChoice {
43 delta: StreamingDelta,
44}
45
46#[derive(Deserialize, Debug)]
47struct StreamingCompletionChunk {
48 choices: Vec<StreamingChoice>,
49 usage: Option<Usage>,
50}
51
52#[derive(Clone)]
53pub struct StreamingCompletionResponse {
54 pub usage: Usage,
55}
56
57impl CompletionModel {
58 pub(crate) async fn stream(
59 &self,
60 completion_request: CompletionRequest,
61 ) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError>
62 {
63 let mut request = self.create_completion_request(completion_request)?;
64 request = merge(
65 request,
66 json!({"stream": true, "stream_options": {"include_usage": true}}),
67 );
68
69 let builder = self.client.post("/chat/completions").json(&request);
70 send_compatible_streaming_request(builder).await
71 }
72}
73
74pub async fn send_compatible_streaming_request(
75 request_builder: RequestBuilder,
76) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError> {
77 let response = request_builder.send().await?;
78
79 if !response.status().is_success() {
80 return Err(CompletionError::ProviderError(format!(
81 "{}: {}",
82 response.status(),
83 response.text().await?
84 )));
85 }
86
87 let inner = Box::pin(stream! {
89 let mut stream = response.bytes_stream();
90
91 let mut final_usage = Usage {
92 prompt_tokens: 0,
93 total_tokens: 0
94 };
95
96 let mut partial_data = None;
97 let mut calls: HashMap<usize, (String, String, String)> = HashMap::new();
98
99 while let Some(chunk_result) = stream.next().await {
100 let chunk = match chunk_result {
101 Ok(c) => c,
102 Err(e) => {
103 yield Err(CompletionError::from(e));
104 break;
105 }
106 };
107
108 let text = match String::from_utf8(chunk.to_vec()) {
109 Ok(t) => t,
110 Err(e) => {
111 yield Err(CompletionError::ResponseError(e.to_string()));
112 break;
113 }
114 };
115
116
117 for line in text.lines() {
118 let mut line = line.to_string();
119
120 if partial_data.is_some() {
122 line = format!("{}{}", partial_data.unwrap(), line);
123 partial_data = None;
124 }
125 else {
127 let Some(data) = line.strip_prefix("data:") else {
128 continue;
129 };
130
131 let data = data.trim_start();
132
133 if !line.ends_with("}") {
135 partial_data = Some(data.to_string());
136 } else {
137 line = data.to_string();
138 }
139 }
140
141 let data = serde_json::from_str::<StreamingCompletionChunk>(&line);
142
143 let Ok(data) = data else {
144 let err = data.unwrap_err();
145 debug!("Couldn't serialize data as StreamingCompletionChunk: {:?}", err);
146 continue;
147 };
148
149
150 if let Some(choice) = data.choices.first() {
151
152 let delta = &choice.delta;
153
154 if !delta.tool_calls.is_empty() {
155 for tool_call in &delta.tool_calls {
156 let function = tool_call.function.clone();
157 if function.name.is_some() && function.arguments.is_empty() {
161 let id = tool_call.id.clone().unwrap_or("".to_string());
162
163 calls.insert(tool_call.index, (id, function.name.clone().unwrap(), "".to_string()));
164 }
165 else if function.name.clone().is_none_or(|s| s.is_empty()) && !function.arguments.is_empty() {
169 let Some((id, name, arguments)) = calls.get(&tool_call.index) else {
170 debug!("Partial tool call received but tool call was never started.");
171 continue;
172 };
173
174 let new_arguments = &tool_call.function.arguments;
175 let arguments = format!("{arguments}{new_arguments}");
176
177 calls.insert(tool_call.index, (id.clone(), name.clone(), arguments));
178 }
179 else {
181 let id = tool_call.id.clone().unwrap_or("".to_string());
182 let name = function.name.expect("function name should be present for complete tool call");
183 let arguments = function.arguments;
184 let Ok(arguments) = serde_json::from_str(&arguments) else {
185 debug!("Couldn't serialize '{}' as a json value", arguments);
186 continue;
187 };
188
189 yield Ok(streaming::RawStreamingChoice::ToolCall {id, name, arguments, call_id: None })
190 }
191 }
192 }
193
194 if let Some(content) = &choice.delta.content {
195 yield Ok(streaming::RawStreamingChoice::Message(content.clone()))
196 }
197 }
198
199
200 if let Some(usage) = data.usage {
201 final_usage = usage.clone();
202 }
203 }
204 }
205
206 for (_, (id, name, arguments)) in calls {
207 let Ok(arguments) = serde_json::from_str(&arguments) else {
208 continue;
209 };
210
211 yield Ok(RawStreamingChoice::ToolCall {id, name, arguments, call_id: None });
212 }
213
214 yield Ok(RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
215 usage: final_usage.clone()
216 }))
217 });
218
219 Ok(streaming::StreamingCompletionResponse::stream(inner))
220}