rig/providers/openai/completion/
streaming.rs1use crate::completion::{CompletionError, CompletionRequest, GetTokenUsage};
2use crate::http_client::HttpClientExt;
3use crate::http_client::sse::{Event, GenericEventSource};
4use crate::json_utils;
5use crate::json_utils::merge;
6use crate::providers::openai::completion::{CompletionModel, Usage};
7use crate::streaming;
8use crate::streaming::RawStreamingChoice;
9use async_stream::stream;
10use futures::StreamExt;
11use http::Request;
12use serde::{Deserialize, Serialize};
13use serde_json::json;
14use std::collections::HashMap;
15use tracing::{debug, info_span};
16use tracing_futures::Instrument;
17
18#[derive(Debug, Serialize, Deserialize, Clone)]
22pub struct StreamingFunction {
23 #[serde(default)]
24 pub name: Option<String>,
25 #[serde(default)]
26 pub arguments: String,
27}
28
29#[derive(Debug, Serialize, Deserialize, Clone)]
30pub struct StreamingToolCall {
31 pub index: usize,
32 pub id: Option<String>,
33 pub function: StreamingFunction,
34}
35
36#[derive(Deserialize, Debug)]
37struct StreamingDelta {
38 #[serde(default)]
39 content: Option<String>,
40 #[serde(default, deserialize_with = "json_utils::null_or_vec")]
41 tool_calls: Vec<StreamingToolCall>,
42}
43
44#[derive(Deserialize, Debug)]
45struct StreamingChoice {
46 delta: StreamingDelta,
47}
48
49#[derive(Deserialize, Debug)]
50struct StreamingCompletionChunk {
51 choices: Vec<StreamingChoice>,
52 usage: Option<Usage>,
53}
54
55#[derive(Clone, Serialize, Deserialize)]
56pub struct StreamingCompletionResponse {
57 pub usage: Usage,
58}
59
60impl GetTokenUsage for StreamingCompletionResponse {
61 fn token_usage(&self) -> Option<crate::completion::Usage> {
62 let mut usage = crate::completion::Usage::new();
63 usage.input_tokens = self.usage.prompt_tokens as u64;
64 usage.output_tokens = self.usage.total_tokens as u64 - self.usage.prompt_tokens as u64;
65 usage.total_tokens = self.usage.total_tokens as u64;
66 Some(usage)
67 }
68}
69
70impl CompletionModel<reqwest::Client> {
71 pub(crate) async fn stream(
72 &self,
73 completion_request: CompletionRequest,
74 ) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError>
75 {
76 let request = super::CompletionRequest::try_from((self.model.clone(), completion_request))?;
77 let request_messages = serde_json::to_string(&request.messages)
78 .expect("Converting to JSON from a Rust struct shouldn't fail");
79 let mut request_as_json = serde_json::to_value(request).expect("this should never fail");
80
81 request_as_json = merge(
82 request_as_json,
83 json!({"stream": true, "stream_options": {"include_usage": true}}),
84 );
85
86 let req_body = serde_json::to_vec(&request_as_json)?;
87
88 let req = self
89 .client
90 .post("/chat/completions")?
91 .body(req_body)
92 .map_err(|e| CompletionError::HttpError(e.into()))?;
93
94 let span = if tracing::Span::current().is_disabled() {
95 info_span!(
96 target: "rig::completions",
97 "chat",
98 gen_ai.operation.name = "chat",
99 gen_ai.provider.name = "openai",
100 gen_ai.request.model = self.model,
101 gen_ai.response.id = tracing::field::Empty,
102 gen_ai.response.model = self.model,
103 gen_ai.usage.output_tokens = tracing::field::Empty,
104 gen_ai.usage.input_tokens = tracing::field::Empty,
105 gen_ai.input.messages = request_messages,
106 gen_ai.output.messages = tracing::field::Empty,
107 )
108 } else {
109 tracing::Span::current()
110 };
111
112 tracing::Instrument::instrument(
113 send_compatible_streaming_request(self.client.http_client.clone(), req),
114 span,
115 )
116 .await
117 }
118}
119
120pub async fn send_compatible_streaming_request<T>(
121 http_client: T,
122 req: Request<Vec<u8>>,
123) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError>
124where
125 T: HttpClientExt + Clone + 'static,
126{
127 let span = tracing::Span::current();
128 let mut event_source = GenericEventSource::new(http_client, req);
130
131 let stream = stream! {
132 let span = tracing::Span::current();
133 let mut final_usage = Usage::new();
134
135 let mut tool_calls: HashMap<usize, (String, String, String)> = HashMap::new();
137
138 let mut text_content = String::new();
139
140 while let Some(event_result) = event_source.next().await {
141 match event_result {
142 Ok(Event::Open) => {
143 tracing::trace!("SSE connection opened");
144 continue;
145 }
146 Ok(Event::Message(message)) => {
147 if message.data.trim().is_empty() || message.data == "[DONE]" {
148 continue;
149 }
150
151 let data = serde_json::from_str::<StreamingCompletionChunk>(&message.data);
152 let Ok(data) = data else {
153 let err = data.unwrap_err();
154 debug!("Couldn't serialize data as StreamingCompletionChunk: {:?}", err);
155 continue;
156 };
157
158 if let Some(choice) = data.choices.first() {
159 let delta = &choice.delta;
160
161 if !delta.tool_calls.is_empty() {
163 for tool_call in &delta.tool_calls {
164 let function = tool_call.function.clone();
165
166 if function.name.is_some() && function.arguments.is_empty() {
168 let id = tool_call.id.clone().unwrap_or_default();
169 tool_calls.insert(
170 tool_call.index,
171 (id, function.name.clone().unwrap(), "".to_string()),
172 );
173 }
174 else if function.name.clone().is_none_or(|s| s.is_empty())
178 && !function.arguments.is_empty()
179 {
180 if let Some((id, name, arguments)) =
181 tool_calls.get(&tool_call.index)
182 {
183 let new_arguments = &tool_call.function.arguments;
184 let arguments = format!("{arguments}{new_arguments}");
185 tool_calls.insert(
186 tool_call.index,
187 (id.clone(), name.clone(), arguments),
188 );
189 } else {
190 debug!("Partial tool call received but tool call was never started.");
191 }
192 }
193 else {
195 let id = tool_call.id.clone().unwrap_or_default();
196 let name = function.name.expect("tool call should have a name");
197 let arguments = function.arguments;
198 let Ok(arguments) = serde_json::from_str(&arguments) else {
199 debug!("Couldn't serialize '{arguments}' as JSON");
200 continue;
201 };
202
203 yield Ok(streaming::RawStreamingChoice::ToolCall {
204 id,
205 name,
206 arguments,
207 call_id: None,
208 });
209 }
210 }
211 }
212
213 if let Some(content) = &choice.delta.content {
215 text_content += content;
216 yield Ok(streaming::RawStreamingChoice::Message(content.clone()))
217 }
218 }
219
220 if let Some(usage) = data.usage {
222 final_usage = usage.clone();
223 }
224 }
225 Err(crate::http_client::Error::StreamEnded) => {
226 break;
227 }
228 Err(error) => {
229 tracing::error!(?error, "SSE error");
230 yield Err(CompletionError::ResponseError(error.to_string()));
231 break;
232 }
233 }
234 }
235
236 event_source.close();
238
239 let mut vec_toolcalls = vec![];
240
241 for (_, (id, name, arguments)) in tool_calls {
243 let Ok(arguments) = serde_json::from_str::<serde_json::Value>(&arguments) else {
244 continue;
245 };
246
247 vec_toolcalls.push(super::ToolCall {
248 r#type: super::ToolType::Function,
249 id: id.clone(),
250 function: super::Function {
251 name: name.clone(), arguments: arguments.clone()
252 },
253 });
254
255 yield Ok(RawStreamingChoice::ToolCall {
256 id,
257 name,
258 arguments,
259 call_id: None,
260 });
261 }
262
263 let message_output = super::Message::Assistant {
264 content: vec![super::AssistantContent::Text { text: text_content }],
265 refusal: None,
266 audio: None,
267 name: None,
268 tool_calls: vec_toolcalls
269 };
270
271 span.record("gen_ai.usage.input_tokens", final_usage.prompt_tokens);
272 span.record("gen_ai.usage.output_tokens", final_usage.total_tokens - final_usage.prompt_tokens);
273 span.record("gen_ai.output.messages", serde_json::to_string(&vec![message_output]).expect("Converting from a Rust struct should always convert to JSON without failing"));
274
275 yield Ok(RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
276 usage: final_usage.clone()
277 }));
278 }.instrument(span);
279
280 Ok(streaming::StreamingCompletionResponse::stream(Box::pin(
281 stream,
282 )))
283}