1use std::collections::HashMap;
2
3use async_stream::stream;
4use futures::StreamExt;
5use http::Request;
6use serde::{Deserialize, Serialize};
7use serde_json::json;
8use tracing::{Level, enabled, info_span};
9use tracing_futures::Instrument;
10
11use crate::completion::{CompletionError, CompletionRequest, GetTokenUsage};
12use crate::http_client::HttpClientExt;
13use crate::http_client::sse::{Event, GenericEventSource};
14use crate::json_utils::{self, merge};
15use crate::message::{ToolCall, ToolFunction};
16use crate::providers::openai::completion::{self, CompletionModel, OpenAIRequestParams, Usage};
17use crate::streaming::{self, RawStreamingChoice};
18
19#[derive(Deserialize, Debug)]
23pub(crate) struct StreamingFunction {
24 pub(crate) name: Option<String>,
25 pub(crate) arguments: Option<String>,
26}
27
28#[derive(Deserialize, Debug)]
29pub(crate) struct StreamingToolCall {
30 pub(crate) index: usize,
31 pub(crate) id: Option<String>,
32 pub(crate) function: StreamingFunction,
33}
34
35#[derive(Deserialize, Debug)]
36struct StreamingDelta {
37 #[serde(default)]
38 content: Option<String>,
39 #[serde(default, deserialize_with = "json_utils::null_or_vec")]
40 tool_calls: Vec<StreamingToolCall>,
41}
42
43#[derive(Deserialize, Debug, PartialEq)]
44#[serde(rename_all = "snake_case")]
45pub enum FinishReason {
46 ToolCalls,
47 Stop,
48 ContentFilter,
49 Length,
50 #[serde(untagged)]
51 Other(String), }
53
54#[derive(Deserialize, Debug)]
55struct StreamingChoice {
56 delta: StreamingDelta,
57 finish_reason: Option<FinishReason>,
58}
59
60#[derive(Deserialize, Debug)]
61struct StreamingCompletionChunk {
62 choices: Vec<StreamingChoice>,
63 usage: Option<Usage>,
64}
65
66#[derive(Clone, Serialize, Deserialize)]
67pub struct StreamingCompletionResponse {
68 pub usage: Usage,
69}
70
71impl GetTokenUsage for StreamingCompletionResponse {
72 fn token_usage(&self) -> Option<crate::completion::Usage> {
73 let mut usage = crate::completion::Usage::new();
74 usage.input_tokens = self.usage.prompt_tokens as u64;
75 usage.output_tokens = self.usage.total_tokens as u64 - self.usage.prompt_tokens as u64;
76 usage.total_tokens = self.usage.total_tokens as u64;
77 Some(usage)
78 }
79}
80
81impl<T> CompletionModel<T>
82where
83 T: HttpClientExt + Clone + 'static,
84{
85 pub(crate) async fn stream(
86 &self,
87 completion_request: CompletionRequest,
88 ) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError>
89 {
90 let request = super::CompletionRequest::try_from(OpenAIRequestParams {
91 model: self.model.clone(),
92 request: completion_request,
93 strict_tools: self.strict_tools,
94 tool_result_array_content: self.tool_result_array_content,
95 })?;
96 let request_messages = serde_json::to_string(&request.messages)
97 .expect("Converting to JSON from a Rust struct shouldn't fail");
98 let mut request_as_json = serde_json::to_value(request).expect("this should never fail");
99
100 request_as_json = merge(
101 request_as_json,
102 json!({"stream": true, "stream_options": {"include_usage": true}}),
103 );
104
105 if enabled!(Level::TRACE) {
106 tracing::trace!(
107 target: "rig::completions",
108 "OpenAI Chat Completions streaming completion request: {}",
109 serde_json::to_string_pretty(&request_as_json)?
110 );
111 }
112
113 let req_body = serde_json::to_vec(&request_as_json)?;
114
115 let req = self
116 .client
117 .post("/chat/completions")?
118 .body(req_body)
119 .map_err(|e| CompletionError::HttpError(e.into()))?;
120
121 let span = if tracing::Span::current().is_disabled() {
122 info_span!(
123 target: "rig::completions",
124 "chat",
125 gen_ai.operation.name = "chat",
126 gen_ai.provider.name = "openai",
127 gen_ai.request.model = self.model,
128 gen_ai.response.id = tracing::field::Empty,
129 gen_ai.response.model = self.model,
130 gen_ai.usage.output_tokens = tracing::field::Empty,
131 gen_ai.usage.input_tokens = tracing::field::Empty,
132 gen_ai.input.messages = request_messages,
133 gen_ai.output.messages = tracing::field::Empty,
134 )
135 } else {
136 tracing::Span::current()
137 };
138
139 let client = self.client.clone();
140
141 tracing::Instrument::instrument(send_compatible_streaming_request(client, req), span).await
142 }
143}
144
145pub async fn send_compatible_streaming_request<T>(
146 http_client: T,
147 req: Request<Vec<u8>>,
148) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError>
149where
150 T: HttpClientExt + Clone + 'static,
151{
152 let span = tracing::Span::current();
153 let mut event_source = GenericEventSource::new(http_client, req);
155
156 let stream = stream! {
157 let span = tracing::Span::current();
158
159 let mut tool_calls: HashMap<usize, ToolCall> = HashMap::new();
161 let mut text_content = String::new();
162 let mut final_tool_calls: Vec<completion::ToolCall> = Vec::new();
163 let mut final_usage = None;
164
165 while let Some(event_result) = event_source.next().await {
166 match event_result {
167 Ok(Event::Open) => {
168 tracing::trace!("SSE connection opened");
169 continue;
170 }
171
172 Ok(Event::Message(message)) => {
173 if message.data.trim().is_empty() || message.data == "[DONE]" {
174 continue;
175 }
176
177 let data = match serde_json::from_str::<StreamingCompletionChunk>(&message.data) {
178 Ok(data) => data,
179 Err(error) => {
180 tracing::error!(?error, message = message.data, "Failed to parse SSE message");
181 continue;
182 }
183 };
184
185 if let Some(usage) = data.usage {
187 final_usage = Some(usage);
188 }
189
190 let Some(choice) = data.choices.first() else {
192 tracing::debug!("There is no choice");
193 continue;
194 };
195 let delta = &choice.delta;
196
197 if !delta.tool_calls.is_empty() {
198 for tool_call in &delta.tool_calls {
199 let index = tool_call.index;
200
201 let existing_tool_call = tool_calls.entry(index).or_insert_with(|| ToolCall {
203 id: String::new(),
204 call_id: None,
205 function: ToolFunction {
206 name: String::new(),
207 arguments: serde_json::Value::Null,
208 },
209 signature: None,
210 additional_params: None,
211 });
212
213 if let Some(id) = &tool_call.id && !id.is_empty() {
215 existing_tool_call.id = id.clone();
216 }
217
218 if let Some(name) = &tool_call.function.name && !name.is_empty() {
219 existing_tool_call.function.name = name.clone();
220 yield Ok(streaming::RawStreamingChoice::ToolCallDelta {
221 id: existing_tool_call.id.clone(),
222 content: streaming::ToolCallDeltaContent::Name(name.clone()),
223 });
224 }
225
226 if let Some(chunk) = &tool_call.function.arguments && !chunk.is_empty() {
228 let current_args = match &existing_tool_call.function.arguments {
229 serde_json::Value::Null => String::new(),
230 serde_json::Value::String(s) => s.clone(),
231 v => v.to_string(),
232 };
233
234 let combined = format!("{current_args}{chunk}");
236
237 if combined.trim_start().starts_with('{') && combined.trim_end().ends_with('}') {
239 match serde_json::from_str(&combined) {
240 Ok(parsed) => existing_tool_call.function.arguments = parsed,
241 Err(_) => existing_tool_call.function.arguments = serde_json::Value::String(combined),
242 }
243 } else {
244 existing_tool_call.function.arguments = serde_json::Value::String(combined);
245 }
246
247 yield Ok(streaming::RawStreamingChoice::ToolCallDelta {
249 id: existing_tool_call.id.clone(),
250 content: streaming::ToolCallDeltaContent::Delta(chunk.clone()),
251 });
252 }
253 }
254 }
255
256 if let Some(content) = &delta.content && !content.is_empty() {
258 text_content += content;
259 yield Ok(streaming::RawStreamingChoice::Message(content.clone()));
260 }
261
262 if let Some(finish_reason) = &choice.finish_reason && *finish_reason == FinishReason::ToolCalls {
264 for (_idx, tool_call) in tool_calls.into_iter() {
265 final_tool_calls.push(completion::ToolCall {
266 id: tool_call.id.clone(),
267 r#type: completion::ToolType::Function,
268 function: completion::Function {
269 name: tool_call.function.name.clone(),
270 arguments: tool_call.function.arguments.clone(),
271 },
272 });
273 yield Ok(streaming::RawStreamingChoice::ToolCall(
274 streaming::RawStreamingToolCall::new(
275 tool_call.id,
276 tool_call.function.name,
277 tool_call.function.arguments,
278 )
279 ));
280 }
281 tool_calls = HashMap::new();
282 }
283 }
284 Err(crate::http_client::Error::StreamEnded) => {
285 break;
286 }
287 Err(error) => {
288 tracing::error!(?error, "SSE error");
289 yield Err(CompletionError::ProviderError(error.to_string()));
290 break;
291 }
292 }
293 }
294
295
296 event_source.close();
298
299 for (_idx, tool_call) in tool_calls.into_iter() {
301 yield Ok(streaming::RawStreamingChoice::ToolCall(
302 streaming::RawStreamingToolCall::new(
303 tool_call.id,
304 tool_call.function.name,
305 tool_call.function.arguments,
306 )
307 ));
308 }
309
310 let final_usage = final_usage.unwrap_or_default();
311 if !span.is_disabled() {
312 span.record("gen_ai.usage.input_tokens", final_usage.prompt_tokens);
313 span.record("gen_ai.usage.output_tokens", final_usage.total_tokens - final_usage.prompt_tokens);
314 }
315
316 yield Ok(RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
317 usage: final_usage
318 }));
319 }.instrument(span);
320
321 Ok(streaming::StreamingCompletionResponse::stream(Box::pin(
322 stream,
323 )))
324}
325
#[cfg(test)]
mod tests {
    use super::*;

    /// A complete function payload keeps both the name and the raw argument string.
    #[test]
    fn test_streaming_function_deserialization() {
        let payload = r#"{"name": "get_weather", "arguments": "{\"location\":\"Paris\"}"}"#;
        let parsed: StreamingFunction = serde_json::from_str(payload).unwrap();

        assert_eq!(parsed.name.as_deref(), Some("get_weather"));
        assert_eq!(parsed.arguments.as_deref(), Some(r#"{"location":"Paris"}"#));
    }

    /// A fully populated tool call exposes its index, id and function name.
    #[test]
    fn test_streaming_tool_call_deserialization() {
        let payload = r#"{
            "index": 0,
            "id": "call_abc123",
            "function": {
                "name": "get_weather",
                "arguments": "{\"city\":\"London\"}"
            }
        }"#;
        let parsed: StreamingToolCall = serde_json::from_str(payload).unwrap();

        assert_eq!(parsed.index, 0);
        assert_eq!(parsed.id.as_deref(), Some("call_abc123"));
        assert_eq!(parsed.function.name.as_deref(), Some("get_weather"));
    }

    /// `null` id and name deserialize to `None` while the argument fragment is kept.
    #[test]
    fn test_streaming_tool_call_partial_deserialization() {
        let payload = r#"{
            "index": 0,
            "id": null,
            "function": {
                "name": null,
                "arguments": "Paris"
            }
        }"#;
        let parsed: StreamingToolCall = serde_json::from_str(payload).unwrap();

        assert_eq!(parsed.index, 0);
        assert_eq!(parsed.id, None);
        assert_eq!(parsed.function.name, None);
        assert_eq!(parsed.function.arguments.as_deref(), Some("Paris"));
    }

    /// A delta with `null` content and one tool call parses both fields.
    #[test]
    fn test_streaming_delta_with_tool_calls() {
        let payload = r#"{
            "content": null,
            "tool_calls": [{
                "index": 0,
                "id": "call_xyz",
                "function": {
                    "name": "search",
                    "arguments": ""
                }
            }]
        }"#;
        let delta: StreamingDelta = serde_json::from_str(payload).unwrap();

        assert_eq!(delta.content, None);
        assert_eq!(delta.tool_calls.len(), 1);
        assert_eq!(delta.tool_calls[0].id.as_deref(), Some("call_xyz"));
    }

    /// A chunk carrying text and usage (and no `finish_reason`) deserializes.
    #[test]
    fn test_streaming_chunk_deserialization() {
        let payload = r#"{
            "choices": [{
                "delta": {
                    "content": "Hello",
                    "tool_calls": []
                }
            }],
            "usage": {
                "prompt_tokens": 10,
                "completion_tokens": 5,
                "total_tokens": 15
            }
        }"#;
        let chunk: StreamingCompletionChunk = serde_json::from_str(payload).unwrap();

        assert_eq!(chunk.choices.len(), 1);
        assert_eq!(chunk.choices[0].delta.content.as_deref(), Some("Hello"));
        assert!(chunk.usage.is_some());
    }

    /// Three consecutive chunks of one tool call each parse to the expected fragment.
    #[test]
    fn test_streaming_chunk_with_multiple_tool_call_deltas() {
        /// Parses `raw` and returns the single tool call of its first choice.
        fn sole_tool_call(raw: &str) -> StreamingToolCall {
            let mut chunk: StreamingCompletionChunk = serde_json::from_str(raw).unwrap();
            let mut calls = chunk.choices.swap_remove(0).delta.tool_calls;
            assert_eq!(calls.len(), 1);
            calls.remove(0)
        }

        let json_start = r#"{
            "choices": [{
                "delta": {
                    "content": null,
                    "tool_calls": [{
                        "index": 0,
                        "id": "call_123",
                        "function": {
                            "name": "get_weather",
                            "arguments": ""
                        }
                    }]
                }
            }],
            "usage": null
        }"#;

        let json_chunk1 = r#"{
            "choices": [{
                "delta": {
                    "content": null,
                    "tool_calls": [{
                        "index": 0,
                        "id": null,
                        "function": {
                            "name": null,
                            "arguments": "{\"loc"
                        }
                    }]
                }
            }],
            "usage": null
        }"#;

        let json_chunk2 = r#"{
            "choices": [{
                "delta": {
                    "content": null,
                    "tool_calls": [{
                        "index": 0,
                        "id": null,
                        "function": {
                            "name": null,
                            "arguments": "ation\":\"NYC\"}"
                        }
                    }]
                }
            }],
            "usage": null
        }"#;

        let opening = sole_tool_call(json_start);
        assert_eq!(opening.function.name.as_deref(), Some("get_weather"));

        let first_fragment = sole_tool_call(json_chunk1);
        assert_eq!(first_fragment.function.arguments.as_deref(), Some("{\"loc"));

        let second_fragment = sole_tool_call(json_chunk2);
        assert_eq!(
            second_fragment.function.arguments.as_deref(),
            Some("ation\":\"NYC\"}")
        );
    }

    /// A trailing chunk with empty `choices` but populated `usage` must still
    /// reach the final response.
    #[tokio::test]
    async fn test_streaming_usage_only_chunk_is_not_ignored() {
        use bytes::Bytes;
        use futures::StreamExt;

        /// HTTP client stub: only `send_streaming` succeeds, replaying the
        /// canned SSE body as a single byte chunk.
        #[derive(Clone)]
        struct MockHttpClient {
            sse_bytes: Bytes,
        }

        impl crate::http_client::HttpClientExt for MockHttpClient {
            fn send<T, U>(
                &self,
                _req: http::Request<T>,
            ) -> impl std::future::Future<
                Output = crate::http_client::Result<
                    http::Response<crate::http_client::LazyBody<U>>,
                >,
            > + crate::wasm_compat::WasmCompatSend
            + 'static
            where
                T: Into<Bytes>,
                T: crate::wasm_compat::WasmCompatSend,
                U: From<Bytes>,
                U: crate::wasm_compat::WasmCompatSend + 'static,
            {
                // Not exercised by this test.
                std::future::ready(Err(crate::http_client::Error::InvalidStatusCode(
                    http::StatusCode::NOT_IMPLEMENTED,
                )))
            }

            fn send_multipart<U>(
                &self,
                _req: http::Request<crate::http_client::MultipartForm>,
            ) -> impl std::future::Future<
                Output = crate::http_client::Result<
                    http::Response<crate::http_client::LazyBody<U>>,
                >,
            > + crate::wasm_compat::WasmCompatSend
            + 'static
            where
                U: From<Bytes>,
                U: crate::wasm_compat::WasmCompatSend + 'static,
            {
                // Not exercised by this test.
                std::future::ready(Err(crate::http_client::Error::InvalidStatusCode(
                    http::StatusCode::NOT_IMPLEMENTED,
                )))
            }

            fn send_streaming<T>(
                &self,
                _req: http::Request<T>,
            ) -> impl std::future::Future<
                Output = crate::http_client::Result<crate::http_client::StreamingResponse>,
            > + crate::wasm_compat::WasmCompatSend
            where
                T: Into<Bytes>,
            {
                let body = self.sse_bytes.clone();
                async move {
                    let single_chunk = std::iter::once(Ok::<Bytes, crate::http_client::Error>(body));
                    let boxed_stream: crate::http_client::sse::BoxedStream =
                        Box::pin(futures::stream::iter(single_chunk));

                    http::Response::builder()
                        .status(http::StatusCode::OK)
                        .header(http::header::CONTENT_TYPE, "text/event-stream")
                        .body(boxed_stream)
                        .map_err(crate::http_client::Error::Protocol)
                }
            }
        }

        let sse = concat!(
            "data: {\"choices\":[{\"delta\":{\"content\":\"Hello\",\"tool_calls\":[]}}],\"usage\":null}\n\n",
            "data: {\"choices\":[],\"usage\":{\"prompt_tokens\":10,\"completion_tokens\":5,\"total_tokens\":15}}\n\n",
            "data: [DONE]\n\n",
        );

        let client = MockHttpClient {
            sse_bytes: Bytes::from(sse),
        };

        let req = http::Request::builder()
            .method("POST")
            .uri("http://localhost/v1/chat/completions")
            .body(Vec::new())
            .unwrap();

        let mut stream = send_compatible_streaming_request(client, req)
            .await
            .unwrap();

        // Drain until the final response; running out of items first is a failure.
        let usage = loop {
            let item = stream
                .next()
                .await
                .expect("expected a final response with usage");
            if let streaming::StreamedAssistantContent::Final(response) = item.unwrap() {
                break response.usage;
            }
        };

        assert_eq!(usage.prompt_tokens, 10);
        assert_eq!(usage.total_tokens, 15);
    }
}