rig/providers/openai/completion/
streaming.rs1use std::collections::HashMap;
2
3use async_stream::stream;
4use futures::StreamExt;
5use http::Request;
6use serde::{Deserialize, Serialize};
7use serde_json::json;
8use tracing::info_span;
9use tracing_futures::Instrument;
10
11use crate::completion::{CompletionError, CompletionRequest, GetTokenUsage};
12use crate::http_client::HttpClientExt;
13use crate::http_client::sse::{Event, GenericEventSource};
14use crate::json_utils::{self, merge};
15use crate::message::{ToolCall, ToolFunction};
16use crate::providers::openai::completion::{self, CompletionModel, Usage};
17use crate::streaming::{self, RawStreamingChoice};
18
19#[derive(Deserialize, Debug)]
23pub(crate) struct StreamingFunction {
24 pub(crate) name: Option<String>,
25 pub(crate) arguments: Option<String>,
26}
27
28#[derive(Deserialize, Debug)]
29pub(crate) struct StreamingToolCall {
30 pub(crate) index: usize,
31 pub(crate) id: Option<String>,
32 pub(crate) function: StreamingFunction,
33}
34
35#[derive(Deserialize, Debug)]
36struct StreamingDelta {
37 #[serde(default)]
38 content: Option<String>,
39 #[serde(default, deserialize_with = "json_utils::null_or_vec")]
40 tool_calls: Vec<StreamingToolCall>,
41}
42
43#[derive(Deserialize, Debug, PartialEq)]
44#[serde(rename_all = "snake_case")]
45pub enum FinishReason {
46 ToolCalls,
47 Stop,
48 ContentFilter,
49 Length,
50 #[serde(untagged)]
51 Other(String), }
53
54#[derive(Deserialize, Debug)]
55struct StreamingChoice {
56 delta: StreamingDelta,
57 finish_reason: Option<FinishReason>,
58}
59
60#[derive(Deserialize, Debug)]
61struct StreamingCompletionChunk {
62 choices: Vec<StreamingChoice>,
63 usage: Option<Usage>,
64}
65
66#[derive(Clone, Serialize, Deserialize)]
67pub struct StreamingCompletionResponse {
68 pub usage: Usage,
69}
70
71impl GetTokenUsage for StreamingCompletionResponse {
72 fn token_usage(&self) -> Option<crate::completion::Usage> {
73 let mut usage = crate::completion::Usage::new();
74 usage.input_tokens = self.usage.prompt_tokens as u64;
75 usage.output_tokens = self.usage.total_tokens as u64 - self.usage.prompt_tokens as u64;
76 usage.total_tokens = self.usage.total_tokens as u64;
77 Some(usage)
78 }
79}
80
81impl<T> CompletionModel<T>
82where
83 T: HttpClientExt + Clone + 'static,
84{
85 pub(crate) async fn stream(
86 &self,
87 completion_request: CompletionRequest,
88 ) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError>
89 {
90 let request = super::CompletionRequest::try_from((self.model.clone(), completion_request))?;
91 let request_messages = serde_json::to_string(&request.messages)
92 .expect("Converting to JSON from a Rust struct shouldn't fail");
93 let mut request_as_json = serde_json::to_value(request).expect("this should never fail");
94
95 request_as_json = merge(
96 request_as_json,
97 json!({"stream": true, "stream_options": {"include_usage": true}}),
98 );
99
100 let req_body = serde_json::to_vec(&request_as_json)?;
101
102 let req = self
103 .client
104 .post("/chat/completions")?
105 .body(req_body)
106 .map_err(|e| CompletionError::HttpError(e.into()))?;
107
108 let span = if tracing::Span::current().is_disabled() {
109 info_span!(
110 target: "rig::completions",
111 "chat",
112 gen_ai.operation.name = "chat",
113 gen_ai.provider.name = "openai",
114 gen_ai.request.model = self.model,
115 gen_ai.response.id = tracing::field::Empty,
116 gen_ai.response.model = self.model,
117 gen_ai.usage.output_tokens = tracing::field::Empty,
118 gen_ai.usage.input_tokens = tracing::field::Empty,
119 gen_ai.input.messages = request_messages,
120 gen_ai.output.messages = tracing::field::Empty,
121 )
122 } else {
123 tracing::Span::current()
124 };
125
126 let client = self.client.http_client().clone();
127
128 tracing::Instrument::instrument(send_compatible_streaming_request(client, req), span).await
129 }
130}
131
132pub async fn send_compatible_streaming_request<T>(
133 http_client: T,
134 req: Request<Vec<u8>>,
135) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError>
136where
137 T: HttpClientExt + Clone + 'static,
138{
139 let span = tracing::Span::current();
140 let mut event_source = GenericEventSource::new(http_client, req);
142
143 let stream = stream! {
144 let span = tracing::Span::current();
145
146 let mut tool_calls: HashMap<usize, ToolCall> = HashMap::new();
148 let mut text_content = String::new();
149 let mut final_tool_calls: Vec<completion::ToolCall> = Vec::new();
150 let mut final_usage = None;
151
152 while let Some(event_result) = event_source.next().await {
153 match event_result {
154 Ok(Event::Open) => {
155 tracing::trace!("SSE connection opened");
156 continue;
157 }
158
159 Ok(Event::Message(message)) => {
160 if message.data.trim().is_empty() || message.data == "[DONE]" {
161 continue;
162 }
163
164 let data = match serde_json::from_str::<StreamingCompletionChunk>(&message.data) {
165 Ok(data) => data,
166 Err(error) => {
167 tracing::error!(?error, message = message.data, "Failed to parse SSE message");
168 continue;
169 }
170 };
171
172 let Some(choice) = data.choices.first() else {
174 tracing::debug!("There is no choice");
175 continue;
176 };
177 let delta = &choice.delta;
178
179 if !delta.tool_calls.is_empty() {
180 for tool_call in &delta.tool_calls {
181 let index = tool_call.index;
182
183 let existing_tool_call = tool_calls.entry(index).or_insert_with(|| ToolCall {
185 id: String::new(),
186 call_id: None,
187 function: ToolFunction {
188 name: String::new(),
189 arguments: serde_json::Value::Null,
190 },
191 });
192
193 if let Some(id) = &tool_call.id && !id.is_empty() {
195 existing_tool_call.id = id.clone();
196 }
197
198 if let Some(name) = &tool_call.function.name && !name.is_empty() {
199 existing_tool_call.function.name = name.clone();
200 }
201
202 if let Some(chunk) = &tool_call.function.arguments {
203 let current_args = match &existing_tool_call.function.arguments {
205 serde_json::Value::Null => String::new(),
206 serde_json::Value::String(s) => s.clone(),
207 v => v.to_string(),
208 };
209
210 let combined = format!("{current_args}{chunk}");
212
213 if combined.trim_start().starts_with('{') && combined.trim_end().ends_with('}') {
215 match serde_json::from_str(&combined) {
216 Ok(parsed) => existing_tool_call.function.arguments = parsed,
217 Err(_) => existing_tool_call.function.arguments = serde_json::Value::String(combined),
218 }
219 } else {
220 existing_tool_call.function.arguments = serde_json::Value::String(combined);
221 }
222
223 yield Ok(streaming::RawStreamingChoice::ToolCallDelta {
225 id: existing_tool_call.id.clone(),
226 delta: chunk.clone(),
227 });
228 }
229 }
230 }
231
232 if let Some(content) = &delta.content && !content.is_empty() {
234 text_content += content;
235 yield Ok(streaming::RawStreamingChoice::Message(content.clone()));
236 }
237
238 if let Some(usage) = data.usage {
240 final_usage = Some(usage);
241 }
242
243 if let Some(finish_reason) = &choice.finish_reason && *finish_reason == FinishReason::ToolCalls {
245 for (_idx, tool_call) in tool_calls.into_iter() {
246 final_tool_calls.push(completion::ToolCall {
247 id: tool_call.id.clone(),
248 r#type: completion::ToolType::Function,
249 function: completion::Function {
250 name: tool_call.function.name.clone(),
251 arguments: tool_call.function.arguments.clone(),
252 },
253 });
254 yield Ok(streaming::RawStreamingChoice::ToolCall {
255 name: tool_call.function.name,
256 id: tool_call.id,
257 arguments: tool_call.function.arguments,
258 call_id: None,
259 });
260 }
261 tool_calls = HashMap::new();
262 }
263 }
264 Err(crate::http_client::Error::StreamEnded) => {
265 break;
266 }
267 Err(error) => {
268 tracing::error!(?error, "SSE error");
269 yield Err(CompletionError::ProviderError(error.to_string()));
270 break;
271 }
272 }
273 }
274
275
276 event_source.close();
278
279 for (_idx, tool_call) in tool_calls.into_iter() {
281 yield Ok(streaming::RawStreamingChoice::ToolCall {
282 name: tool_call.function.name,
283 id: tool_call.id,
284 arguments: tool_call.function.arguments,
285 call_id: None,
286 });
287 }
288
289 let final_usage = final_usage.unwrap_or_default();
290 if !span.is_disabled() {
291 let message_output = super::Message::Assistant {
292 content: vec![super::AssistantContent::Text { text: text_content }],
293 refusal: None,
294 audio: None,
295 name: None,
296 tool_calls: final_tool_calls
297 };
298 span.record("gen_ai.usage.input_tokens", final_usage.prompt_tokens);
299 span.record("gen_ai.usage.output_tokens", final_usage.total_tokens - final_usage.prompt_tokens);
300 span.record("gen_ai.output.messages", serde_json::to_string(&vec![message_output]).expect("Converting from a Rust struct should always convert to JSON without failing"));
301 }
302
303 yield Ok(RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
304 usage: final_usage
305 }));
306 }.instrument(span);
307
308 Ok(streaming::StreamingCompletionResponse::stream(Box::pin(
309 stream,
310 )))
311}
312
#[cfg(test)]
mod tests {
    use super::*;

    /// A complete function payload yields both the name and the raw arguments string.
    #[test]
    fn test_streaming_function_deserialization() {
        let raw = r#"{"name": "get_weather", "arguments": "{\"location\":\"Paris\"}"}"#;
        let parsed: StreamingFunction = serde_json::from_str(raw).unwrap();

        assert_eq!(parsed.name.as_deref(), Some("get_weather"));
        assert_eq!(parsed.arguments.as_deref(), Some(r#"{"location":"Paris"}"#));
    }

    /// A fully populated tool call exposes its index, id and function name.
    #[test]
    fn test_streaming_tool_call_deserialization() {
        let raw = r#"{
            "index": 0,
            "id": "call_abc123",
            "function": {
                "name": "get_weather",
                "arguments": "{\"city\":\"London\"}"
            }
        }"#;
        let parsed: StreamingToolCall = serde_json::from_str(raw).unwrap();

        assert_eq!(parsed.index, 0);
        assert_eq!(parsed.id.as_deref(), Some("call_abc123"));
        assert_eq!(parsed.function.name.as_deref(), Some("get_weather"));
    }

    /// `null` id/name deserialize to `None` while the argument fragment is kept.
    #[test]
    fn test_streaming_tool_call_partial_deserialization() {
        let raw = r#"{
            "index": 0,
            "id": null,
            "function": {
                "name": null,
                "arguments": "Paris"
            }
        }"#;
        let parsed: StreamingToolCall = serde_json::from_str(raw).unwrap();

        assert_eq!(parsed.index, 0);
        assert_eq!(parsed.id, None);
        assert_eq!(parsed.function.name, None);
        assert_eq!(parsed.function.arguments.as_deref(), Some("Paris"));
    }

    /// A delta with `null` content and one tool call keeps that tool call.
    #[test]
    fn test_streaming_delta_with_tool_calls() {
        let raw = r#"{
            "content": null,
            "tool_calls": [{
                "index": 0,
                "id": "call_xyz",
                "function": {
                    "name": "search",
                    "arguments": ""
                }
            }]
        }"#;
        let parsed: StreamingDelta = serde_json::from_str(raw).unwrap();

        assert_eq!(parsed.content, None);
        assert_eq!(parsed.tool_calls.len(), 1);
        assert_eq!(parsed.tool_calls[0].id.as_deref(), Some("call_xyz"));
    }

    /// A chunk carrying text content and usage deserializes both.
    #[test]
    fn test_streaming_chunk_deserialization() {
        let raw = r#"{
            "choices": [{
                "delta": {
                    "content": "Hello",
                    "tool_calls": []
                }
            }],
            "usage": {
                "prompt_tokens": 10,
                "completion_tokens": 5,
                "total_tokens": 15
            }
        }"#;
        let parsed: StreamingCompletionChunk = serde_json::from_str(raw).unwrap();

        assert_eq!(parsed.choices.len(), 1);
        assert_eq!(parsed.choices[0].delta.content.as_deref(), Some("Hello"));
        assert!(parsed.usage.is_some());
    }

    /// A tool call split over three chunks: header first, then two argument fragments.
    #[test]
    fn test_streaming_chunk_with_multiple_tool_call_deltas() {
        let header_json = r#"{
            "choices": [{
                "delta": {
                    "content": null,
                    "tool_calls": [{
                        "index": 0,
                        "id": "call_123",
                        "function": {
                            "name": "get_weather",
                            "arguments": ""
                        }
                    }]
                }
            }],
            "usage": null
        }"#;

        let first_fragment_json = r#"{
            "choices": [{
                "delta": {
                    "content": null,
                    "tool_calls": [{
                        "index": 0,
                        "id": null,
                        "function": {
                            "name": null,
                            "arguments": "{\"loc"
                        }
                    }]
                }
            }],
            "usage": null
        }"#;

        let second_fragment_json = r#"{
            "choices": [{
                "delta": {
                    "content": null,
                    "tool_calls": [{
                        "index": 0,
                        "id": null,
                        "function": {
                            "name": null,
                            "arguments": "ation\":\"NYC\"}"
                        }
                    }]
                }
            }],
            "usage": null
        }"#;

        let parse =
            |raw: &str| -> StreamingCompletionChunk { serde_json::from_str(raw).unwrap() };

        let header = parse(header_json);
        let header_calls = &header.choices[0].delta.tool_calls;
        assert_eq!(header_calls.len(), 1);
        assert_eq!(header_calls[0].function.name.as_deref(), Some("get_weather"));

        let first = parse(first_fragment_json);
        let first_calls = &first.choices[0].delta.tool_calls;
        assert_eq!(first_calls.len(), 1);
        assert_eq!(first_calls[0].function.arguments.as_deref(), Some("{\"loc"));

        let second = parse(second_fragment_json);
        let second_calls = &second.choices[0].delta.tool_calls;
        assert_eq!(second_calls.len(), 1);
        assert_eq!(
            second_calls[0].function.arguments.as_deref(),
            Some("ation\":\"NYC\"}")
        );
    }
}