rig/providers/openai/completion/
streaming.rs1use crate::completion::{CompletionError, CompletionRequest, GetTokenUsage};
2use crate::http_client::HttpClientExt;
3use crate::http_client::sse::{Event, GenericEventSource};
4use crate::json_utils;
5use crate::json_utils::merge;
6use crate::providers::openai::completion::{CompletionModel, Usage};
7use crate::streaming;
8use crate::streaming::RawStreamingChoice;
9use async_stream::stream;
10use futures::StreamExt;
11use http::Request;
12use serde::{Deserialize, Serialize};
13use serde_json::json;
14use std::collections::HashMap;
15use tracing::{debug, info_span};
16use tracing_futures::Instrument;
17
/// Function portion of a streamed tool-call entry.
///
/// Both fields fall back to their defaults when absent from the payload, so a
/// chunk may carry only a name, only an argument fragment, or neither.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct StreamingFunction {
    /// Function name, when this chunk carries one.
    #[serde(default)]
    pub name: Option<String>,
    /// Raw argument text carried by this chunk; empty when the field is missing.
    #[serde(default)]
    pub arguments: String,
}
28
/// One tool-call entry inside a streamed delta.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct StreamingToolCall {
    /// Slot number of the call; the stream handler in this file uses it as the
    /// key for matching argument fragments to the call they extend.
    pub index: usize,
    /// Call identifier, when this chunk carries one.
    pub id: Option<String>,
    /// Function name and/or argument fragment for this entry.
    pub function: StreamingFunction,
}
35
/// Incremental payload of a streamed choice.
#[derive(Deserialize, Debug)]
struct StreamingDelta {
    /// Text fragment, when this chunk carries one.
    #[serde(default)]
    content: Option<String>,
    /// Tool-call entries in this chunk; empty when the field is missing.
    /// Deserialized through `json_utils::null_or_vec` — presumably so that an
    /// explicit `null` also yields an empty list (confirm against that helper).
    #[serde(default, deserialize_with = "json_utils::null_or_vec")]
    tool_calls: Vec<StreamingToolCall>,
}
43
/// A single choice inside a streamed completion chunk.
#[derive(Deserialize, Debug)]
struct StreamingChoice {
    /// Incremental content (text and/or tool-call entries) for this choice.
    delta: StreamingDelta,
}
48
/// One server-sent chunk of a streamed chat completion.
#[derive(Deserialize, Debug)]
struct StreamingCompletionChunk {
    /// Choices carried by this chunk; the stream handler in this file only
    /// reads the first one.
    choices: Vec<StreamingChoice>,
    /// Token usage, when the chunk reports it.
    usage: Option<Usage>,
}
54
55#[derive(Clone, Serialize, Deserialize)]
56pub struct StreamingCompletionResponse {
57 pub usage: Usage,
58}
59
60impl GetTokenUsage for StreamingCompletionResponse {
61 fn token_usage(&self) -> Option<crate::completion::Usage> {
62 let mut usage = crate::completion::Usage::new();
63 usage.input_tokens = self.usage.prompt_tokens as u64;
64 usage.output_tokens = self.usage.total_tokens as u64 - self.usage.prompt_tokens as u64;
65 usage.total_tokens = self.usage.total_tokens as u64;
66 Some(usage)
67 }
68}
69
70impl CompletionModel<reqwest::Client> {
71 pub(crate) async fn stream(
72 &self,
73 completion_request: CompletionRequest,
74 ) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError>
75 {
76 let request = super::CompletionRequest::try_from((self.model.clone(), completion_request))?;
77 let request_messages = serde_json::to_string(&request.messages)
78 .expect("Converting to JSON from a Rust struct shouldn't fail");
79 let mut request_as_json = serde_json::to_value(request).expect("this should never fail");
80
81 request_as_json = merge(
82 request_as_json,
83 json!({"stream": true, "stream_options": {"include_usage": true}}),
84 );
85
86 let req_body = serde_json::to_vec(&request_as_json)?;
87
88 let req = self
89 .client
90 .post("/chat/completions")?
91 .header("Content-Type", "application/json")
92 .body(req_body)
93 .map_err(|e| CompletionError::HttpError(e.into()))?;
94
95 let span = if tracing::Span::current().is_disabled() {
96 info_span!(
97 target: "rig::completions",
98 "chat",
99 gen_ai.operation.name = "chat",
100 gen_ai.provider.name = "openai",
101 gen_ai.request.model = self.model,
102 gen_ai.response.id = tracing::field::Empty,
103 gen_ai.response.model = self.model,
104 gen_ai.usage.output_tokens = tracing::field::Empty,
105 gen_ai.usage.input_tokens = tracing::field::Empty,
106 gen_ai.input.messages = request_messages,
107 gen_ai.output.messages = tracing::field::Empty,
108 )
109 } else {
110 tracing::Span::current()
111 };
112
113 tracing::Instrument::instrument(
114 send_compatible_streaming_request(self.client.http_client.clone(), req),
115 span,
116 )
117 .await
118 }
119}
120
121pub async fn send_compatible_streaming_request<T>(
122 http_client: T,
123 req: Request<Vec<u8>>,
124) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError>
125where
126 T: HttpClientExt + Clone + 'static,
127{
128 let span = tracing::Span::current();
129 let mut event_source = GenericEventSource::new(http_client, req);
131
132 let stream = stream! {
133 let span = tracing::Span::current();
134 let mut final_usage = Usage::new();
135
136 let mut tool_calls: HashMap<usize, (String, String, String)> = HashMap::new();
138
139 let mut text_content = String::new();
140
141 while let Some(event_result) = event_source.next().await {
142 match event_result {
143 Ok(Event::Open) => {
144 tracing::trace!("SSE connection opened");
145 continue;
146 }
147 Ok(Event::Message(message)) => {
148 if message.data.trim().is_empty() || message.data == "[DONE]" {
149 continue;
150 }
151
152 let data = serde_json::from_str::<StreamingCompletionChunk>(&message.data);
153 let Ok(data) = data else {
154 let err = data.unwrap_err();
155 debug!("Couldn't serialize data as StreamingCompletionChunk: {:?}", err);
156 continue;
157 };
158
159 if let Some(choice) = data.choices.first() {
160 let delta = &choice.delta;
161
162 if !delta.tool_calls.is_empty() {
164 for tool_call in &delta.tool_calls {
165 let function = tool_call.function.clone();
166
167 if function.name.is_some() && function.arguments.is_empty() {
169 let id = tool_call.id.clone().unwrap_or_default();
170 tool_calls.insert(
171 tool_call.index,
172 (id, function.name.clone().unwrap(), "".to_string()),
173 );
174 }
175 else if function.name.clone().is_none_or(|s| s.is_empty())
179 && !function.arguments.is_empty()
180 {
181 if let Some((id, name, arguments)) =
182 tool_calls.get(&tool_call.index).cloned()
183 {
184 let new_arguments = &tool_call.function.arguments;
185 let combined_arguments = format!("{arguments}{new_arguments}");
186 tool_calls.insert(
187 tool_call.index,
188 (id.clone(), name.clone(), combined_arguments),
189 );
190
191 yield Ok(streaming::RawStreamingChoice::ToolCallDelta {
193 id: id.clone(),
194 delta: new_arguments.clone(),
195 });
196 } else {
197 debug!("Partial tool call received but tool call was never started.");
198 }
199 }
200 else {
202 let id = tool_call.id.clone().unwrap_or_default();
203 let name = function.name.expect("tool call should have a name");
204 let arguments = function.arguments;
205 let Ok(arguments) = serde_json::from_str(&arguments) else {
206 debug!("Couldn't serialize '{arguments}' as JSON");
207 continue;
208 };
209
210 yield Ok(streaming::RawStreamingChoice::ToolCall {
211 id,
212 name,
213 arguments,
214 call_id: None,
215 });
216 }
217 }
218 }
219
220 if let Some(content) = &choice.delta.content {
222 text_content += content;
223 yield Ok(streaming::RawStreamingChoice::Message(content.clone()))
224 }
225 }
226
227 if let Some(usage) = data.usage {
229 final_usage = usage.clone();
230 }
231 }
232 Err(crate::http_client::Error::StreamEnded) => {
233 break;
234 }
235 Err(error) => {
236 tracing::error!(?error, "SSE error");
237 yield Err(CompletionError::ResponseError(error.to_string()));
238 break;
239 }
240 }
241 }
242
243 event_source.close();
245
246 let mut vec_toolcalls = vec![];
247
248 for (_, (id, name, arguments)) in tool_calls {
250 let Ok(arguments) = serde_json::from_str::<serde_json::Value>(&arguments) else {
251 continue;
252 };
253
254 vec_toolcalls.push(super::ToolCall {
255 r#type: super::ToolType::Function,
256 id: id.clone(),
257 function: super::Function {
258 name: name.clone(), arguments: arguments.clone()
259 },
260 });
261
262 yield Ok(RawStreamingChoice::ToolCall {
263 id,
264 name,
265 arguments,
266 call_id: None,
267 });
268 }
269
270 let message_output = super::Message::Assistant {
271 content: vec![super::AssistantContent::Text { text: text_content }],
272 refusal: None,
273 audio: None,
274 name: None,
275 tool_calls: vec_toolcalls
276 };
277
278 span.record("gen_ai.usage.input_tokens", final_usage.prompt_tokens);
279 span.record("gen_ai.usage.output_tokens", final_usage.total_tokens - final_usage.prompt_tokens);
280 span.record("gen_ai.output.messages", serde_json::to_string(&vec![message_output]).expect("Converting from a Rust struct should always convert to JSON without failing"));
281
282 yield Ok(RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
283 usage: final_usage.clone()
284 }));
285 }.instrument(span);
286
287 Ok(streaming::StreamingCompletionResponse::stream(Box::pin(
288 stream,
289 )))
290}
291
#[cfg(test)]
mod tests {
    use super::*;

    /// A function payload carrying both a name and complete arguments.
    #[test]
    fn test_streaming_function_deserialization() {
        let json = r#"{"name": "get_weather", "arguments": "{\"location\":\"Paris\"}"}"#;
        let function: StreamingFunction = serde_json::from_str(json).unwrap();
        assert_eq!(function.name, Some("get_weather".to_string()));
        assert_eq!(function.arguments, r#"{"location":"Paris"}"#.to_string());
    }

    /// Both function fields are optional on the wire and fall back to defaults.
    #[test]
    fn test_streaming_function_defaults_when_fields_missing() {
        let function: StreamingFunction = serde_json::from_str("{}").unwrap();
        assert!(function.name.is_none());
        assert!(function.arguments.is_empty());
    }

    /// A tool call that arrives complete in one payload.
    #[test]
    fn test_streaming_tool_call_deserialization() {
        let json = r#"{
            "index": 0,
            "id": "call_abc123",
            "function": {
                "name": "get_weather",
                "arguments": "{\"city\":\"London\"}"
            }
        }"#;
        let tool_call: StreamingToolCall = serde_json::from_str(json).unwrap();
        assert_eq!(tool_call.index, 0);
        assert_eq!(tool_call.id, Some("call_abc123".to_string()));
        assert_eq!(tool_call.function.name, Some("get_weather".to_string()));
        assert_eq!(tool_call.function.arguments, r#"{"city":"London"}"#);
    }

    /// A continuation payload: no id, no name, only an argument fragment.
    #[test]
    fn test_streaming_tool_call_partial_deserialization() {
        let json = r#"{
            "index": 0,
            "id": null,
            "function": {
                "name": null,
                "arguments": "Paris"
            }
        }"#;
        let tool_call: StreamingToolCall = serde_json::from_str(json).unwrap();
        assert_eq!(tool_call.index, 0);
        assert!(tool_call.id.is_none());
        assert!(tool_call.function.name.is_none());
        assert_eq!(tool_call.function.arguments, "Paris");
    }

    /// A delta that starts a tool call: name present, arguments still empty.
    #[test]
    fn test_streaming_delta_with_tool_calls() {
        let json = r#"{
            "content": null,
            "tool_calls": [{
                "index": 0,
                "id": "call_xyz",
                "function": {
                    "name": "search",
                    "arguments": ""
                }
            }]
        }"#;
        let delta: StreamingDelta = serde_json::from_str(json).unwrap();
        assert!(delta.content.is_none());
        assert_eq!(delta.tool_calls.len(), 1);
        assert_eq!(delta.tool_calls[0].id, Some("call_xyz".to_string()));
        assert_eq!(delta.tool_calls[0].function.name.as_deref(), Some("search"));
        assert!(delta.tool_calls[0].function.arguments.is_empty());
    }

    /// A text-only delta may omit `tool_calls` entirely; unknown fields are ignored.
    #[test]
    fn test_streaming_delta_without_tool_calls_field() {
        let json = r#"{"role": "assistant", "content": "Hi"}"#;
        let delta: StreamingDelta = serde_json::from_str(json).unwrap();
        assert_eq!(delta.content.as_deref(), Some("Hi"));
        assert!(delta.tool_calls.is_empty());
    }

    /// An explicit `null` for `tool_calls` is accepted and treated as empty.
    #[test]
    fn test_streaming_delta_with_null_tool_calls() {
        let json = r#"{"content": "Hi", "tool_calls": null}"#;
        let delta: StreamingDelta = serde_json::from_str(json).unwrap();
        assert_eq!(delta.content.as_deref(), Some("Hi"));
        assert!(delta.tool_calls.is_empty());
    }

    /// A chunk carrying both text content and usage.
    #[test]
    fn test_streaming_chunk_deserialization() {
        let json = r#"{
            "choices": [{
                "delta": {
                    "content": "Hello",
                    "tool_calls": []
                }
            }],
            "usage": {
                "prompt_tokens": 10,
                "completion_tokens": 5,
                "total_tokens": 15
            }
        }"#;
        let chunk: StreamingCompletionChunk = serde_json::from_str(json).unwrap();
        assert_eq!(chunk.choices.len(), 1);
        assert_eq!(chunk.choices[0].delta.content, Some("Hello".to_string()));
        assert!(chunk.usage.is_some());
    }

    /// The usage report can arrive in a chunk that has no choices at all.
    #[test]
    fn test_streaming_usage_only_chunk_deserialization() {
        let json = r#"{
            "choices": [],
            "usage": {
                "prompt_tokens": 10,
                "completion_tokens": 5,
                "total_tokens": 15
            }
        }"#;
        let chunk: StreamingCompletionChunk = serde_json::from_str(json).unwrap();
        assert!(chunk.choices.is_empty());
        let usage = chunk.usage.expect("usage should be present");
        assert_eq!(usage.prompt_tokens as u64, 10);
        assert_eq!(usage.total_tokens as u64, 15);
    }

    /// A tool call split over a start chunk and two argument fragments; the
    /// fragments must reassemble into valid JSON.
    #[test]
    fn test_streaming_chunk_with_multiple_tool_call_deltas() {
        let json_start = r#"{
            "choices": [{
                "delta": {
                    "content": null,
                    "tool_calls": [{
                        "index": 0,
                        "id": "call_123",
                        "function": {
                            "name": "get_weather",
                            "arguments": ""
                        }
                    }]
                }
            }],
            "usage": null
        }"#;

        let json_chunk1 = r#"{
            "choices": [{
                "delta": {
                    "content": null,
                    "tool_calls": [{
                        "index": 0,
                        "id": null,
                        "function": {
                            "name": null,
                            "arguments": "{\"loc"
                        }
                    }]
                }
            }],
            "usage": null
        }"#;

        let json_chunk2 = r#"{
            "choices": [{
                "delta": {
                    "content": null,
                    "tool_calls": [{
                        "index": 0,
                        "id": null,
                        "function": {
                            "name": null,
                            "arguments": "ation\":\"NYC\"}"
                        }
                    }]
                }
            }],
            "usage": null
        }"#;

        let start_chunk: StreamingCompletionChunk = serde_json::from_str(json_start).unwrap();
        assert!(start_chunk.usage.is_none());
        assert_eq!(start_chunk.choices[0].delta.tool_calls.len(), 1);
        assert_eq!(
            start_chunk.choices[0].delta.tool_calls[0]
                .function
                .name
                .as_ref()
                .unwrap(),
            "get_weather"
        );

        let chunk1: StreamingCompletionChunk = serde_json::from_str(json_chunk1).unwrap();
        assert_eq!(chunk1.choices[0].delta.tool_calls.len(), 1);
        assert_eq!(
            chunk1.choices[0].delta.tool_calls[0].function.arguments,
            "{\"loc"
        );

        let chunk2: StreamingCompletionChunk = serde_json::from_str(json_chunk2).unwrap();
        assert_eq!(chunk2.choices[0].delta.tool_calls.len(), 1);
        assert_eq!(
            chunk2.choices[0].delta.tool_calls[0].function.arguments,
            "ation\":\"NYC\"}"
        );

        // Concatenating the fragments in arrival order yields the full arguments.
        let combined = format!(
            "{}{}",
            chunk1.choices[0].delta.tool_calls[0].function.arguments,
            chunk2.choices[0].delta.tool_calls[0].function.arguments
        );
        let parsed: serde_json::Value = serde_json::from_str(&combined).unwrap();
        assert_eq!(parsed, serde_json::json!({"location": "NYC"}));
    }
}