rig/providers/openai/
streaming.rs1use super::completion::CompletionModel;
2use crate::completion::{CompletionError, CompletionRequest};
3use crate::json_utils;
4use crate::json_utils::merge;
5use crate::providers::openai::Usage;
6use crate::streaming;
7use crate::streaming::RawStreamingChoice;
8use async_stream::stream;
9use futures::StreamExt;
10use reqwest::RequestBuilder;
11use serde::{Deserialize, Serialize};
12use serde_json::json;
13use std::collections::HashMap;
14use tracing::debug;
15
16#[derive(Debug, Serialize, Deserialize, Clone)]
20pub struct StreamingFunction {
21 #[serde(default)]
22 name: Option<String>,
23 #[serde(default)]
24 arguments: String,
25}
26
27#[derive(Debug, Serialize, Deserialize, Clone)]
28pub struct StreamingToolCall {
29 pub index: usize,
30 pub id: Option<String>,
31 pub function: StreamingFunction,
32}
33
34#[derive(Deserialize, Debug)]
35struct StreamingDelta {
36 #[serde(default)]
37 content: Option<String>,
38 #[serde(default, deserialize_with = "json_utils::null_or_vec")]
39 tool_calls: Vec<StreamingToolCall>,
40}
41
42#[derive(Deserialize, Debug)]
43struct StreamingChoice {
44 delta: StreamingDelta,
45}
46
47#[derive(Deserialize, Debug)]
48struct StreamingCompletionChunk {
49 choices: Vec<StreamingChoice>,
50 usage: Option<Usage>,
51}
52
53#[derive(Clone)]
54pub struct StreamingCompletionResponse {
55 pub usage: Usage,
56}
57
58impl CompletionModel {
59 pub(crate) async fn stream(
60 &self,
61 completion_request: CompletionRequest,
62 ) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError>
63 {
64 let mut request = self.create_completion_request(completion_request)?;
65 request = merge(
66 request,
67 json!({"stream": true, "stream_options": {"include_usage": true}}),
68 );
69
70 let builder = self.client.post("/chat/completions").json(&request);
71 send_compatible_streaming_request(builder).await
72 }
73}
74
75pub async fn send_compatible_streaming_request(
76 request_builder: RequestBuilder,
77) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError> {
78 let response = request_builder.send().await?;
79
80 if !response.status().is_success() {
81 return Err(CompletionError::ProviderError(format!(
82 "{}: {}",
83 response.status(),
84 response.text().await?
85 )));
86 }
87
88 let inner = Box::pin(stream! {
90 let mut stream = response.bytes_stream();
91
92 let mut final_usage = Usage {
93 prompt_tokens: 0,
94 total_tokens: 0
95 };
96
97 let mut partial_data = None;
98 let mut calls: HashMap<usize, (String, String, String)> = HashMap::new();
99
100 while let Some(chunk_result) = stream.next().await {
101 let chunk = match chunk_result {
102 Ok(c) => c,
103 Err(e) => {
104 yield Err(CompletionError::from(e));
105 break;
106 }
107 };
108
109 let text = match String::from_utf8(chunk.to_vec()) {
110 Ok(t) => t,
111 Err(e) => {
112 yield Err(CompletionError::ResponseError(e.to_string()));
113 break;
114 }
115 };
116
117
118 for line in text.lines() {
119 let mut line = line.to_string();
120
121 if partial_data.is_some() {
123 line = format!("{}{}", partial_data.unwrap(), line);
124 partial_data = None;
125 }
126 else {
128 let Some(data) = line.strip_prefix("data: ") else {
129 continue;
130 };
131
132 if !line.ends_with("}") {
134 partial_data = Some(data.to_string());
135 } else {
136 line = data.to_string();
137 }
138 }
139
140 let data = serde_json::from_str::<StreamingCompletionChunk>(&line);
141
142 let Ok(data) = data else {
143 let err = data.unwrap_err();
144 debug!("Couldn't serialize data as StreamingCompletionChunk: {:?}", err);
145 continue;
146 };
147
148
149 if let Some(choice) = data.choices.first() {
150
151 let delta = &choice.delta;
152
153 if !delta.tool_calls.is_empty() {
154 for tool_call in &delta.tool_calls {
155 let function = tool_call.function.clone();
156 if function.name.is_some() && function.arguments.is_empty() {
160 let id = tool_call.id.clone().unwrap_or("".to_string());
161
162 calls.insert(tool_call.index, (id, function.name.clone().unwrap(), "".to_string()));
163 }
164 else if function.name.clone().is_none_or(|s| s.is_empty()) && !function.arguments.is_empty() {
168 let Some((id, name, arguments)) = calls.get(&tool_call.index) else {
169 debug!("Partial tool call received but tool call was never started.");
170 continue;
171 };
172
173 let new_arguments = &tool_call.function.arguments;
174 let arguments = format!("{arguments}{new_arguments}");
175
176 calls.insert(tool_call.index, (id.clone(), name.clone(), arguments));
177 }
178 else {
180 let id = tool_call.id.clone().unwrap_or("".to_string());
181 let name = function.name.expect("function name should be present for complete tool call");
182 let arguments = function.arguments;
183 let Ok(arguments) = serde_json::from_str(&arguments) else {
184 debug!("Couldn't serialize '{}' as a json value", arguments);
185 continue;
186 };
187
188 yield Ok(streaming::RawStreamingChoice::ToolCall {id, name, arguments})
189 }
190 }
191 }
192
193 if let Some(content) = &choice.delta.content {
194 yield Ok(streaming::RawStreamingChoice::Message(content.clone()))
195 }
196 }
197
198
199 if let Some(usage) = data.usage {
200 final_usage = usage.clone();
201 }
202 }
203 }
204
205 for (_, (id, name, arguments)) in calls {
206 let Ok(arguments) = serde_json::from_str(&arguments) else {
207 continue;
208 };
209
210 yield Ok(RawStreamingChoice::ToolCall {id, name, arguments});
211 }
212
213 yield Ok(RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
214 usage: final_usage.clone()
215 }))
216 });
217
218 Ok(streaming::StreamingCompletionResponse::stream(inner))
219}