1use std::collections::HashMap;
2
3use async_stream::stream;
4use futures::StreamExt;
5use http::Request;
6use serde::{Deserialize, Serialize};
7use serde_json::json;
8use tracing::{Level, enabled, info_span};
9use tracing_futures::Instrument;
10
11use crate::completion::{CompletionError, CompletionRequest, GetTokenUsage};
12use crate::http_client::HttpClientExt;
13use crate::http_client::sse::{Event, GenericEventSource};
14use crate::json_utils::{self, merge};
15use crate::providers::openai::completion::{self, CompletionModel, OpenAIRequestParams, Usage};
16use crate::streaming::{self, RawStreamingChoice};
17
18#[derive(Deserialize, Debug)]
22pub(crate) struct StreamingFunction {
23 pub(crate) name: Option<String>,
24 pub(crate) arguments: Option<String>,
25}
26
27#[derive(Deserialize, Debug)]
28pub(crate) struct StreamingToolCall {
29 pub(crate) index: usize,
30 pub(crate) id: Option<String>,
31 pub(crate) function: StreamingFunction,
32}
33
34#[derive(Deserialize, Debug)]
35struct StreamingDelta {
36 #[serde(default)]
37 content: Option<String>,
38 #[serde(default, deserialize_with = "json_utils::null_or_vec")]
39 tool_calls: Vec<StreamingToolCall>,
40}
41
42#[derive(Deserialize, Debug, PartialEq)]
43#[serde(rename_all = "snake_case")]
44pub enum FinishReason {
45 ToolCalls,
46 Stop,
47 ContentFilter,
48 Length,
49 #[serde(untagged)]
50 Other(String), }
52
53#[derive(Deserialize, Debug)]
54struct StreamingChoice {
55 delta: StreamingDelta,
56 finish_reason: Option<FinishReason>,
57}
58
59#[derive(Deserialize, Debug)]
60struct StreamingCompletionChunk {
61 choices: Vec<StreamingChoice>,
62 usage: Option<Usage>,
63}
64
65#[derive(Clone, Serialize, Deserialize)]
66pub struct StreamingCompletionResponse {
67 pub usage: Usage,
68}
69
70impl GetTokenUsage for StreamingCompletionResponse {
71 fn token_usage(&self) -> Option<crate::completion::Usage> {
72 let mut usage = crate::completion::Usage::new();
73 usage.input_tokens = self.usage.prompt_tokens as u64;
74 usage.output_tokens = self.usage.total_tokens as u64 - self.usage.prompt_tokens as u64;
75 usage.total_tokens = self.usage.total_tokens as u64;
76 Some(usage)
77 }
78}
79
80impl<T> CompletionModel<T>
81where
82 T: HttpClientExt + Clone + 'static,
83{
84 pub(crate) async fn stream(
85 &self,
86 completion_request: CompletionRequest,
87 ) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError>
88 {
89 let request = super::CompletionRequest::try_from(OpenAIRequestParams {
90 model: self.model.clone(),
91 request: completion_request,
92 strict_tools: self.strict_tools,
93 tool_result_array_content: self.tool_result_array_content,
94 })?;
95 let request_messages = serde_json::to_string(&request.messages)
96 .expect("Converting to JSON from a Rust struct shouldn't fail");
97 let mut request_as_json = serde_json::to_value(request).expect("this should never fail");
98
99 request_as_json = merge(
100 request_as_json,
101 json!({"stream": true, "stream_options": {"include_usage": true}}),
102 );
103
104 if enabled!(Level::TRACE) {
105 tracing::trace!(
106 target: "rig::completions",
107 "OpenAI Chat Completions streaming completion request: {}",
108 serde_json::to_string_pretty(&request_as_json)?
109 );
110 }
111
112 let req_body = serde_json::to_vec(&request_as_json)?;
113
114 let req = self
115 .client
116 .post("/chat/completions")?
117 .body(req_body)
118 .map_err(|e| CompletionError::HttpError(e.into()))?;
119
120 let span = if tracing::Span::current().is_disabled() {
121 info_span!(
122 target: "rig::completions",
123 "chat",
124 gen_ai.operation.name = "chat",
125 gen_ai.provider.name = "openai",
126 gen_ai.request.model = self.model,
127 gen_ai.response.id = tracing::field::Empty,
128 gen_ai.response.model = self.model,
129 gen_ai.usage.output_tokens = tracing::field::Empty,
130 gen_ai.usage.input_tokens = tracing::field::Empty,
131 gen_ai.input.messages = request_messages,
132 gen_ai.output.messages = tracing::field::Empty,
133 )
134 } else {
135 tracing::Span::current()
136 };
137
138 let client = self.client.clone();
139
140 tracing::Instrument::instrument(send_compatible_streaming_request(client, req), span).await
141 }
142}
143
144pub async fn send_compatible_streaming_request<T>(
145 http_client: T,
146 req: Request<Vec<u8>>,
147) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError>
148where
149 T: HttpClientExt + Clone + 'static,
150{
151 let span = tracing::Span::current();
152 let mut event_source = GenericEventSource::new(http_client, req);
154
155 let stream = stream! {
156 let span = tracing::Span::current();
157
158 let mut tool_calls: HashMap<usize, streaming::RawStreamingToolCall> = HashMap::new();
160 let mut text_content = String::new();
161 let mut final_tool_calls: Vec<completion::ToolCall> = Vec::new();
162 let mut final_usage = None;
163
164 while let Some(event_result) = event_source.next().await {
165 match event_result {
166 Ok(Event::Open) => {
167 tracing::trace!("SSE connection opened");
168 continue;
169 }
170
171 Ok(Event::Message(message)) => {
172 if message.data.trim().is_empty() || message.data == "[DONE]" {
173 continue;
174 }
175
176 let data = match serde_json::from_str::<StreamingCompletionChunk>(&message.data) {
177 Ok(data) => data,
178 Err(error) => {
179 tracing::error!(?error, message = message.data, "Failed to parse SSE message");
180 continue;
181 }
182 };
183
184 if let Some(usage) = data.usage {
186 final_usage = Some(usage);
187 }
188
189 let Some(choice) = data.choices.first() else {
191 tracing::debug!("There is no choice");
192 continue;
193 };
194 let delta = &choice.delta;
195
196 if !delta.tool_calls.is_empty() {
197 for tool_call in &delta.tool_calls {
198 let index = tool_call.index;
199
200 let existing_tool_call = tool_calls.entry(index).or_insert_with(streaming::RawStreamingToolCall::empty);
202
203 if let Some(id) = &tool_call.id && !id.is_empty() {
205 existing_tool_call.id = id.clone();
206 }
207
208 if let Some(name) = &tool_call.function.name && !name.is_empty() {
209 existing_tool_call.name = name.clone();
210 yield Ok(streaming::RawStreamingChoice::ToolCallDelta {
211 id: existing_tool_call.id.clone(),
212 internal_call_id: existing_tool_call.internal_call_id.clone(),
213 content: streaming::ToolCallDeltaContent::Name(name.clone()),
214 });
215 }
216
217 if let Some(chunk) = &tool_call.function.arguments && !chunk.is_empty() {
219 let current_args = match &existing_tool_call.arguments {
220 serde_json::Value::Null => String::new(),
221 serde_json::Value::String(s) => s.clone(),
222 v => v.to_string(),
223 };
224
225 let combined = format!("{current_args}{chunk}");
227
228 if combined.trim_start().starts_with('{') && combined.trim_end().ends_with('}') {
230 match serde_json::from_str(&combined) {
231 Ok(parsed) => existing_tool_call.arguments = parsed,
232 Err(_) => existing_tool_call.arguments = serde_json::Value::String(combined),
233 }
234 } else {
235 existing_tool_call.arguments = serde_json::Value::String(combined);
236 }
237
238 yield Ok(streaming::RawStreamingChoice::ToolCallDelta {
240 id: existing_tool_call.id.clone(),
241 internal_call_id: existing_tool_call.internal_call_id.clone(),
242 content: streaming::ToolCallDeltaContent::Delta(chunk.clone()),
243 });
244 }
245 }
246 }
247
248 if let Some(content) = &delta.content && !content.is_empty() {
250 text_content += content;
251 yield Ok(streaming::RawStreamingChoice::Message(content.clone()));
252 }
253
254 if let Some(finish_reason) = &choice.finish_reason && *finish_reason == FinishReason::ToolCalls {
256 for (_idx, tool_call) in tool_calls.into_iter() {
257 final_tool_calls.push(completion::ToolCall {
258 id: tool_call.id.clone(),
259 r#type: completion::ToolType::Function,
260 function: completion::Function {
261 name: tool_call.name.clone(),
262 arguments: tool_call.arguments.clone(),
263 },
264 });
265 yield Ok(streaming::RawStreamingChoice::ToolCall(tool_call));
266 }
267 tool_calls = HashMap::new();
268 }
269 }
270 Err(crate::http_client::Error::StreamEnded) => {
271 break;
272 }
273 Err(error) => {
274 tracing::error!(?error, "SSE error");
275 yield Err(CompletionError::ProviderError(error.to_string()));
276 break;
277 }
278 }
279 }
280
281
282 event_source.close();
284
285 for (_idx, tool_call) in tool_calls.into_iter() {
287 yield Ok(streaming::RawStreamingChoice::ToolCall(tool_call));
288 }
289
290 let final_usage = final_usage.unwrap_or_default();
291 if !span.is_disabled() {
292 span.record("gen_ai.usage.input_tokens", final_usage.prompt_tokens);
293 span.record("gen_ai.usage.output_tokens", final_usage.total_tokens - final_usage.prompt_tokens);
294 }
295
296 yield Ok(RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
297 usage: final_usage
298 }));
299 }.instrument(span);
300
301 Ok(streaming::StreamingCompletionResponse::stream(Box::pin(
302 stream,
303 )))
304}
305
#[cfg(test)]
mod tests {
    use super::*;

    /// A complete function fragment exposes both its name and argument text.
    #[test]
    fn test_streaming_function_deserialization() {
        let payload = r#"{"name": "get_weather", "arguments": "{\"location\":\"Paris\"}"}"#;
        let parsed = serde_json::from_str::<StreamingFunction>(payload).unwrap();

        assert_eq!(parsed.name, Some("get_weather".to_string()));
        assert_eq!(
            parsed.arguments.as_deref(),
            Some(r#"{"location":"Paris"}"#)
        );
    }

    /// A fully populated tool-call fragment deserializes index, id and name.
    #[test]
    fn test_streaming_tool_call_deserialization() {
        let payload = r#"{
            "index": 0,
            "id": "call_abc123",
            "function": {
                "name": "get_weather",
                "arguments": "{\"city\":\"London\"}"
            }
        }"#;
        let parsed = serde_json::from_str::<StreamingToolCall>(payload).unwrap();

        assert_eq!(parsed.index, 0);
        assert_eq!(parsed.id, Some("call_abc123".to_string()));
        assert_eq!(parsed.function.name, Some("get_weather".to_string()));
    }

    /// Null `id` and `name` become `None` while the argument text is kept.
    #[test]
    fn test_streaming_tool_call_partial_deserialization() {
        let payload = r#"{
            "index": 0,
            "id": null,
            "function": {
                "name": null,
                "arguments": "Paris"
            }
        }"#;
        let parsed = serde_json::from_str::<StreamingToolCall>(payload).unwrap();

        assert_eq!(parsed.index, 0);
        assert!(parsed.id.is_none());
        assert!(parsed.function.name.is_none());
        assert_eq!(parsed.function.arguments.as_deref(), Some("Paris"));
    }

    /// A delta with null content still exposes its tool-call fragments.
    #[test]
    fn test_streaming_delta_with_tool_calls() {
        let payload = r#"{
            "content": null,
            "tool_calls": [{
                "index": 0,
                "id": "call_xyz",
                "function": {
                    "name": "search",
                    "arguments": ""
                }
            }]
        }"#;
        let parsed = serde_json::from_str::<StreamingDelta>(payload).unwrap();

        assert!(parsed.content.is_none());
        assert_eq!(parsed.tool_calls.len(), 1);
        assert_eq!(parsed.tool_calls[0].id, Some("call_xyz".to_string()));
    }

    /// A chunk carrying text and usage deserializes both parts.
    #[test]
    fn test_streaming_chunk_deserialization() {
        let payload = r#"{
            "choices": [{
                "delta": {
                    "content": "Hello",
                    "tool_calls": []
                }
            }],
            "usage": {
                "prompt_tokens": 10,
                "completion_tokens": 5,
                "total_tokens": 15
            }
        }"#;
        let parsed = serde_json::from_str::<StreamingCompletionChunk>(payload).unwrap();

        assert_eq!(parsed.choices.len(), 1);
        assert_eq!(parsed.choices[0].delta.content, Some("Hello".to_string()));
        assert!(parsed.usage.is_some());
    }

    /// A tool call split over several chunks deserializes fragment by fragment.
    #[test]
    fn test_streaming_chunk_with_multiple_tool_call_deltas() {
        /// Asserts that the first choice holds exactly one tool-call fragment
        /// and returns it.
        fn sole_fragment(chunk: &StreamingCompletionChunk) -> &StreamingToolCall {
            let fragments = &chunk.choices[0].delta.tool_calls;
            assert_eq!(fragments.len(), 1);
            &fragments[0]
        }

        let opening_json = r#"{
            "choices": [{
                "delta": {
                    "content": null,
                    "tool_calls": [{
                        "index": 0,
                        "id": "call_123",
                        "function": {
                            "name": "get_weather",
                            "arguments": ""
                        }
                    }]
                }
            }],
            "usage": null
        }"#;

        let first_args_json = r#"{
            "choices": [{
                "delta": {
                    "content": null,
                    "tool_calls": [{
                        "index": 0,
                        "id": null,
                        "function": {
                            "name": null,
                            "arguments": "{\"loc"
                        }
                    }]
                }
            }],
            "usage": null
        }"#;

        let second_args_json = r#"{
            "choices": [{
                "delta": {
                    "content": null,
                    "tool_calls": [{
                        "index": 0,
                        "id": null,
                        "function": {
                            "name": null,
                            "arguments": "ation\":\"NYC\"}"
                        }
                    }]
                }
            }],
            "usage": null
        }"#;

        let opening = serde_json::from_str::<StreamingCompletionChunk>(opening_json).unwrap();
        assert_eq!(
            sole_fragment(&opening).function.name.as_ref().unwrap(),
            "get_weather"
        );

        let first_args = serde_json::from_str::<StreamingCompletionChunk>(first_args_json).unwrap();
        assert_eq!(
            sole_fragment(&first_args)
                .function
                .arguments
                .as_ref()
                .unwrap(),
            "{\"loc"
        );

        let second_args =
            serde_json::from_str::<StreamingCompletionChunk>(second_args_json).unwrap();
        assert_eq!(
            sole_fragment(&second_args)
                .function
                .arguments
                .as_ref()
                .unwrap(),
            "ation\":\"NYC\"}"
        );
    }

    /// Usage reported on a chunk without choices must reach the final response.
    #[tokio::test]
    async fn test_streaming_usage_only_chunk_is_not_ignored() {
        use bytes::Bytes;
        use futures::StreamExt;

        /// HTTP client stub that only supports streaming and replays a fixed
        /// SSE body.
        #[derive(Clone)]
        struct MockHttpClient {
            sse_bytes: Bytes,
        }

        impl crate::http_client::HttpClientExt for MockHttpClient {
            // Non-streaming requests are not part of this test.
            fn send<T, U>(
                &self,
                _req: http::Request<T>,
            ) -> impl std::future::Future<
                Output = crate::http_client::Result<
                    http::Response<crate::http_client::LazyBody<U>>,
                >,
            > + crate::wasm_compat::WasmCompatSend
            + 'static
            where
                T: Into<Bytes>,
                T: crate::wasm_compat::WasmCompatSend,
                U: From<Bytes>,
                U: crate::wasm_compat::WasmCompatSend + 'static,
            {
                std::future::ready(Err(crate::http_client::Error::InvalidStatusCode(
                    http::StatusCode::NOT_IMPLEMENTED,
                )))
            }

            // Multipart requests are not part of this test either.
            fn send_multipart<U>(
                &self,
                _req: http::Request<crate::http_client::MultipartForm>,
            ) -> impl std::future::Future<
                Output = crate::http_client::Result<
                    http::Response<crate::http_client::LazyBody<U>>,
                >,
            > + crate::wasm_compat::WasmCompatSend
            + 'static
            where
                U: From<Bytes>,
                U: crate::wasm_compat::WasmCompatSend + 'static,
            {
                std::future::ready(Err(crate::http_client::Error::InvalidStatusCode(
                    http::StatusCode::NOT_IMPLEMENTED,
                )))
            }

            // Replays the canned body as a single-item byte stream.
            fn send_streaming<T>(
                &self,
                _req: http::Request<T>,
            ) -> impl std::future::Future<
                Output = crate::http_client::Result<crate::http_client::StreamingResponse>,
            > + crate::wasm_compat::WasmCompatSend
            where
                T: Into<Bytes>,
            {
                let canned_body = self.sse_bytes.clone();
                async move {
                    let items = vec![Ok::<Bytes, crate::http_client::Error>(canned_body)];
                    let body: crate::http_client::sse::BoxedStream =
                        Box::pin(futures::stream::iter(items));

                    http::Response::builder()
                        .status(http::StatusCode::OK)
                        .header(reqwest::header::CONTENT_TYPE, "text/event-stream")
                        .body(body)
                        .map_err(crate::http_client::Error::Protocol)
                }
            }
        }

        let sse_body = concat!(
            "data: {\"choices\":[{\"delta\":{\"content\":\"Hello\",\"tool_calls\":[]}}],\"usage\":null}\n\n",
            "data: {\"choices\":[],\"usage\":{\"prompt_tokens\":10,\"completion_tokens\":5,\"total_tokens\":15}}\n\n",
            "data: [DONE]\n\n",
        );

        let mock = MockHttpClient {
            sse_bytes: Bytes::from(sse_body),
        };

        let request = http::Request::builder()
            .method("POST")
            .uri("http://localhost/v1/chat/completions")
            .body(Vec::new())
            .unwrap();

        let mut events = send_compatible_streaming_request(mock, request)
            .await
            .unwrap();

        // Walk the stream until the final response shows up (or it ends).
        let reported_usage = loop {
            match events.next().await {
                Some(item) => {
                    if let streaming::StreamedAssistantContent::Final(response) = item.unwrap() {
                        break Some(response.usage);
                    }
                }
                None => break None,
            }
        };

        let usage = reported_usage.expect("expected a final response with usage");
        assert_eq!(usage.prompt_tokens, 10);
        assert_eq!(usage.total_tokens, 15);
    }
}