1use std::collections::HashMap;
2
3use async_stream::stream;
4use futures::StreamExt;
5use http::Request;
6use serde::{Deserialize, Serialize};
7use serde_json::json;
8use tracing::{Level, enabled, info_span};
9use tracing_futures::Instrument;
10
11use crate::completion::{CompletionError, CompletionRequest, GetTokenUsage};
12use crate::http_client::HttpClientExt;
13use crate::http_client::sse::{Event, GenericEventSource};
14use crate::json_utils::{self, merge};
15use crate::providers::openai::completion::{CompletionModel, OpenAIRequestParams, Usage};
16use crate::streaming::{self, RawStreamingChoice};
17
/// Function payload carried by one streamed tool-call delta.
///
/// Both fields are optional, so a chunk may carry a name, an argument fragment,
/// both, or neither.
#[derive(Deserialize, Debug)]
pub(crate) struct StreamingFunction {
    /// Function name, when this chunk includes one.
    pub(crate) name: Option<String>,
    /// Argument text carried by this chunk. The streaming loop in this file
    /// concatenates these fragments per tool call.
    pub(crate) arguments: Option<String>,
}
26
/// One entry of a streamed delta's `tool_calls` array.
#[derive(Deserialize, Debug)]
pub(crate) struct StreamingToolCall {
    /// Slot index of the call; the streaming loop in this file keys its
    /// in-progress tool calls on this value.
    pub(crate) index: usize,
    /// Call id, when this chunk includes one.
    pub(crate) id: Option<String>,
    /// Name and/or argument fragment for this call.
    pub(crate) function: StreamingFunction,
}
33
/// Incremental payload of a single streamed choice.
#[derive(Deserialize, Debug)]
struct StreamingDelta {
    /// Assistant text fragment; forwarded as a message chunk when non-empty.
    #[serde(default)]
    content: Option<String>,
    /// Reasoning text fragment; forwarded as a reasoning delta when non-empty.
    #[serde(default)]
    reasoning_content: Option<String>,
    /// Tool-call deltas carried by this chunk; defaults to empty when the key is absent.
    // NOTE(review): `json_utils::null_or_vec` presumably maps an explicit JSON
    // `null` to an empty list — confirm against `json_utils`.
    #[serde(default, deserialize_with = "json_utils::null_or_vec")]
    tool_calls: Vec<StreamingToolCall>,
}
43
/// Reason a streamed choice stopped, deserialized from its snake_case wire
/// string (for example `"tool_calls"`).
#[derive(Deserialize, Debug, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum FinishReason {
    /// Wire value `"tool_calls"`; the streaming loop flushes buffered tool calls on it.
    ToolCalls,
    /// Wire value `"stop"`.
    Stop,
    /// Wire value `"content_filter"`.
    ContentFilter,
    /// Wire value `"length"`.
    Length,
    /// Catch-all for any other string value, kept verbatim.
    #[serde(untagged)]
    Other(String),
}
54
/// One element of a streamed chunk's `choices` array.
#[derive(Deserialize, Debug)]
struct StreamingChoice {
    /// Incremental content for this choice.
    delta: StreamingDelta,
    /// Set once the choice has finished; absent or `null` on intermediate chunks.
    finish_reason: Option<FinishReason>,
}
60
/// A single parsed SSE `data:` payload of the chat-completions stream.
#[derive(Deserialize, Debug)]
struct StreamingCompletionChunk {
    /// Choices in this chunk; may be empty (e.g. a usage-only chunk). The
    /// streaming loop in this file only reads the first entry.
    choices: Vec<StreamingChoice>,
    /// Token usage, when the chunk reports it.
    usage: Option<Usage>,
}
66
/// Final response yielded once the stream has ended, carrying the last usage
/// figures seen on the stream (or the default usage if none were reported).
#[derive(Clone, Serialize, Deserialize)]
pub struct StreamingCompletionResponse {
    /// Provider-reported token usage.
    pub usage: Usage,
}
71
72impl GetTokenUsage for StreamingCompletionResponse {
73 fn token_usage(&self) -> Option<crate::completion::Usage> {
74 let mut usage = crate::completion::Usage::new();
75 usage.input_tokens = self.usage.prompt_tokens as u64;
76 usage.output_tokens = self.usage.total_tokens as u64 - self.usage.prompt_tokens as u64;
77 usage.total_tokens = self.usage.total_tokens as u64;
78 Some(usage)
79 }
80}
81
82impl<T> CompletionModel<T>
83where
84 T: HttpClientExt + Clone + 'static,
85{
86 pub(crate) async fn stream(
87 &self,
88 completion_request: CompletionRequest,
89 ) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError>
90 {
91 let request = super::CompletionRequest::try_from(OpenAIRequestParams {
92 model: self.model.clone(),
93 request: completion_request,
94 strict_tools: self.strict_tools,
95 tool_result_array_content: self.tool_result_array_content,
96 })?;
97 let request_messages = serde_json::to_string(&request.messages)
98 .expect("Converting to JSON from a Rust struct shouldn't fail");
99 let mut request_as_json = serde_json::to_value(request).expect("this should never fail");
100
101 request_as_json = merge(
102 request_as_json,
103 json!({"stream": true, "stream_options": {"include_usage": true}}),
104 );
105
106 if enabled!(Level::TRACE) {
107 tracing::trace!(
108 target: "rig::completions",
109 "OpenAI Chat Completions streaming completion request: {}",
110 serde_json::to_string_pretty(&request_as_json)?
111 );
112 }
113
114 let req_body = serde_json::to_vec(&request_as_json)?;
115
116 let req = self
117 .client
118 .post("/chat/completions")?
119 .body(req_body)
120 .map_err(|e| CompletionError::HttpError(e.into()))?;
121
122 let span = if tracing::Span::current().is_disabled() {
123 info_span!(
124 target: "rig::completions",
125 "chat",
126 gen_ai.operation.name = "chat",
127 gen_ai.provider.name = "openai",
128 gen_ai.request.model = self.model,
129 gen_ai.response.id = tracing::field::Empty,
130 gen_ai.response.model = self.model,
131 gen_ai.usage.output_tokens = tracing::field::Empty,
132 gen_ai.usage.input_tokens = tracing::field::Empty,
133 gen_ai.input.messages = request_messages,
134 gen_ai.output.messages = tracing::field::Empty,
135 )
136 } else {
137 tracing::Span::current()
138 };
139
140 let client = self.client.clone();
141
142 tracing::Instrument::instrument(send_compatible_streaming_request(client, req), span).await
143 }
144}
145
146pub async fn send_compatible_streaming_request<T>(
147 http_client: T,
148 req: Request<Vec<u8>>,
149) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError>
150where
151 T: HttpClientExt + Clone + 'static,
152{
153 let span = tracing::Span::current();
154 let mut event_source = GenericEventSource::new(http_client, req);
156
157 let stream = stream! {
158 let span = tracing::Span::current();
159
160 let mut tool_calls: HashMap<usize, streaming::RawStreamingToolCall> = HashMap::new();
162 let mut text_content = String::new();
163 let mut final_usage = None;
164
165 while let Some(event_result) = event_source.next().await {
166 match event_result {
167 Ok(Event::Open) => {
168 tracing::trace!("SSE connection opened");
169 continue;
170 }
171
172 Ok(Event::Message(message)) => {
173 if message.data.trim().is_empty() || message.data == "[DONE]" {
174 continue;
175 }
176
177 let data = match serde_json::from_str::<StreamingCompletionChunk>(&message.data) {
178 Ok(data) => data,
179 Err(error) => {
180 tracing::error!(?error, message = message.data, "Failed to parse SSE message");
181 continue;
182 }
183 };
184
185 if let Some(usage) = data.usage {
187 final_usage = Some(usage);
188 }
189
190 let Some(choice) = data.choices.first() else {
192 tracing::debug!("There is no choice");
193 continue;
194 };
195 let delta = &choice.delta;
196
197 if !delta.tool_calls.is_empty() {
198 for tool_call in &delta.tool_calls {
199 let index = tool_call.index;
200
201 if let Some(new_id) = &tool_call.id
207 && !new_id.is_empty()
208 && let Some(existing) = tool_calls.get(&index)
209 && !existing.id.is_empty()
210 && existing.id != *new_id
211 {
212 let evicted = tool_calls.remove(&index).expect("checked above");
213 yield Ok(streaming::RawStreamingChoice::ToolCall(evicted));
214 }
215
216 let existing_tool_call = tool_calls.entry(index).or_insert_with(streaming::RawStreamingToolCall::empty);
217
218 if let Some(id) = &tool_call.id && !id.is_empty() {
219 existing_tool_call.id = id.clone();
220 }
221
222 if let Some(name) = &tool_call.function.name && !name.is_empty() {
223 existing_tool_call.name = name.clone();
224 yield Ok(streaming::RawStreamingChoice::ToolCallDelta {
225 id: existing_tool_call.id.clone(),
226 internal_call_id: existing_tool_call.internal_call_id.clone(),
227 content: streaming::ToolCallDeltaContent::Name(name.clone()),
228 });
229 }
230
231 if let Some(chunk) = &tool_call.function.arguments && !chunk.is_empty() {
233 let current_args = match &existing_tool_call.arguments {
234 serde_json::Value::Null => String::new(),
235 serde_json::Value::String(s) => s.clone(),
236 v => v.to_string(),
237 };
238
239 let combined = format!("{current_args}{chunk}");
241
242 if combined.trim_start().starts_with('{') && combined.trim_end().ends_with('}') {
244 match serde_json::from_str(&combined) {
245 Ok(parsed) => existing_tool_call.arguments = parsed,
246 Err(_) => existing_tool_call.arguments = serde_json::Value::String(combined),
247 }
248 } else {
249 existing_tool_call.arguments = serde_json::Value::String(combined);
250 }
251
252 yield Ok(streaming::RawStreamingChoice::ToolCallDelta {
254 id: existing_tool_call.id.clone(),
255 internal_call_id: existing_tool_call.internal_call_id.clone(),
256 content: streaming::ToolCallDeltaContent::Delta(chunk.clone()),
257 });
258 }
259 }
260 }
261
262 if let Some(reasoning) = &delta.reasoning_content && !reasoning.is_empty() {
264 yield Ok(streaming::RawStreamingChoice::ReasoningDelta {
265 id: None,
266 reasoning: reasoning.clone(),
267 });
268 }
269
270 if let Some(content) = &delta.content && !content.is_empty() {
272 text_content += content;
273 yield Ok(streaming::RawStreamingChoice::Message(content.clone()));
274 }
275
276 if let Some(finish_reason) = &choice.finish_reason && *finish_reason == FinishReason::ToolCalls {
278 for (_idx, tool_call) in tool_calls.into_iter() {
279 yield Ok(streaming::RawStreamingChoice::ToolCall(tool_call));
280 }
281 tool_calls = HashMap::new();
282 }
283 }
284 Err(crate::http_client::Error::StreamEnded) => {
285 break;
286 }
287 Err(error) => {
288 tracing::error!(?error, "SSE error");
289 yield Err(CompletionError::ProviderError(error.to_string()));
290 break;
291 }
292 }
293 }
294
295
296 event_source.close();
298
299 for (_idx, tool_call) in tool_calls.into_iter() {
301 yield Ok(streaming::RawStreamingChoice::ToolCall(tool_call));
302 }
303
304 let final_usage = final_usage.unwrap_or_default();
305 if !span.is_disabled() {
306 span.record("gen_ai.usage.input_tokens", final_usage.prompt_tokens);
307 span.record("gen_ai.usage.output_tokens", final_usage.total_tokens - final_usage.prompt_tokens);
308 }
309
310 yield Ok(RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
311 usage: final_usage
312 }));
313 }.instrument(span);
314
315 Ok(streaming::StreamingCompletionResponse::stream(Box::pin(
316 stream,
317 )))
318}
319
// Unit tests for the streaming wire types and for the SSE-to-stream adapter.
#[cfg(test)]
mod tests {
    use super::*;

    // A complete function payload deserializes with both name and arguments set.
    #[test]
    fn test_streaming_function_deserialization() {
        let json = r#"{"name": "get_weather", "arguments": "{\"location\":\"Paris\"}"}"#;
        let function: StreamingFunction = serde_json::from_str(json).unwrap();
        assert_eq!(function.name, Some("get_weather".to_string()));
        assert_eq!(
            function.arguments.as_ref().unwrap(),
            r#"{"location":"Paris"}"#
        );
    }

    // A tool call carrying index, id and function deserializes field by field.
    #[test]
    fn test_streaming_tool_call_deserialization() {
        let json = r#"{
            "index": 0,
            "id": "call_abc123",
            "function": {
                "name": "get_weather",
                "arguments": "{\"city\":\"London\"}"
            }
        }"#;
        let tool_call: StreamingToolCall = serde_json::from_str(json).unwrap();
        assert_eq!(tool_call.index, 0);
        assert_eq!(tool_call.id, Some("call_abc123".to_string()));
        assert_eq!(tool_call.function.name, Some("get_weather".to_string()));
    }

    // Continuation chunks send `null` for id and name; only the argument
    // fragment is present.
    #[test]
    fn test_streaming_tool_call_partial_deserialization() {
        let json = r#"{
            "index": 0,
            "id": null,
            "function": {
                "name": null,
                "arguments": "Paris"
            }
        }"#;
        let tool_call: StreamingToolCall = serde_json::from_str(json).unwrap();
        assert_eq!(tool_call.index, 0);
        assert!(tool_call.id.is_none());
        assert!(tool_call.function.name.is_none());
        assert_eq!(tool_call.function.arguments.as_ref().unwrap(), "Paris");
    }

    // A delta with `content: null` and one tool call parses into an empty
    // content and a single-element tool-call list.
    #[test]
    fn test_streaming_delta_with_tool_calls() {
        let json = r#"{
            "content": null,
            "tool_calls": [{
                "index": 0,
                "id": "call_xyz",
                "function": {
                    "name": "search",
                    "arguments": ""
                }
            }]
        }"#;
        let delta: StreamingDelta = serde_json::from_str(json).unwrap();
        assert!(delta.content.is_none());
        assert_eq!(delta.tool_calls.len(), 1);
        assert_eq!(delta.tool_calls[0].id, Some("call_xyz".to_string()));
    }

    // A chunk with one text choice and a usage object parses; the choice has
    // no `finish_reason` key at all.
    #[test]
    fn test_streaming_chunk_deserialization() {
        let json = r#"{
            "choices": [{
                "delta": {
                    "content": "Hello",
                    "tool_calls": []
                }
            }],
            "usage": {
                "prompt_tokens": 10,
                "completion_tokens": 5,
                "total_tokens": 15
            }
        }"#;
        let chunk: StreamingCompletionChunk = serde_json::from_str(json).unwrap();
        assert_eq!(chunk.choices.len(), 1);
        assert_eq!(chunk.choices[0].delta.content, Some("Hello".to_string()));
        assert!(chunk.usage.is_some());
    }

    // Three consecutive chunks of one tool call: the first names the function,
    // the next two carry the two halves of the JSON arguments.
    #[test]
    fn test_streaming_chunk_with_multiple_tool_call_deltas() {
        let json_start = r#"{
            "choices": [{
                "delta": {
                    "content": null,
                    "tool_calls": [{
                        "index": 0,
                        "id": "call_123",
                        "function": {
                            "name": "get_weather",
                            "arguments": ""
                        }
                    }]
                }
            }],
            "usage": null
        }"#;

        let json_chunk1 = r#"{
            "choices": [{
                "delta": {
                    "content": null,
                    "tool_calls": [{
                        "index": 0,
                        "id": null,
                        "function": {
                            "name": null,
                            "arguments": "{\"loc"
                        }
                    }]
                }
            }],
            "usage": null
        }"#;

        let json_chunk2 = r#"{
            "choices": [{
                "delta": {
                    "content": null,
                    "tool_calls": [{
                        "index": 0,
                        "id": null,
                        "function": {
                            "name": null,
                            "arguments": "ation\":\"NYC\"}"
                        }
                    }]
                }
            }],
            "usage": null
        }"#;

        let start_chunk: StreamingCompletionChunk = serde_json::from_str(json_start).unwrap();
        assert_eq!(start_chunk.choices[0].delta.tool_calls.len(), 1);
        assert_eq!(
            start_chunk.choices[0].delta.tool_calls[0]
                .function
                .name
                .as_ref()
                .unwrap(),
            "get_weather"
        );

        let chunk1: StreamingCompletionChunk = serde_json::from_str(json_chunk1).unwrap();
        assert_eq!(chunk1.choices[0].delta.tool_calls.len(), 1);
        assert_eq!(
            chunk1.choices[0].delta.tool_calls[0]
                .function
                .arguments
                .as_ref()
                .unwrap(),
            "{\"loc"
        );

        let chunk2: StreamingCompletionChunk = serde_json::from_str(json_chunk2).unwrap();
        assert_eq!(chunk2.choices[0].delta.tool_calls.len(), 1);
        assert_eq!(
            chunk2.choices[0].delta.tool_calls[0]
                .function
                .arguments
                .as_ref()
                .unwrap(),
            "ation\":\"NYC\"}"
        );
    }

    // A chunk with an empty `choices` array but a usage object must still
    // contribute its usage to the final response.
    #[tokio::test]
    async fn test_streaming_usage_only_chunk_is_not_ignored() {
        use crate::http_client::mock::MockStreamingClient;
        use bytes::Bytes;
        use futures::StreamExt;

        // One text chunk, then a usage-only chunk, then the terminal sentinel.
        let sse = concat!(
            "data: {\"choices\":[{\"delta\":{\"content\":\"Hello\",\"tool_calls\":[]}}],\"usage\":null}\n\n",
            "data: {\"choices\":[],\"usage\":{\"prompt_tokens\":10,\"completion_tokens\":5,\"total_tokens\":15}}\n\n",
            "data: [DONE]\n\n",
        );

        let client = MockStreamingClient {
            sse_bytes: Bytes::from(sse),
        };

        let req = http::Request::builder()
            .method("POST")
            .uri("http://localhost/v1/chat/completions")
            .body(Vec::new())
            .unwrap();

        let mut stream = send_compatible_streaming_request(client, req)
            .await
            .unwrap();

        // Drain the stream until the final response shows up.
        let mut final_usage = None;
        while let Some(chunk) = stream.next().await {
            if let streaming::StreamedAssistantContent::Final(res) = chunk.unwrap() {
                final_usage = Some(res.usage);
                break;
            }
        }

        let usage = final_usage.expect("expected a final response with usage");
        assert_eq!(usage.prompt_tokens, 10);
        assert_eq!(usage.total_tokens, 15);
    }

    // Two tool calls that reuse index 0 but carry different ids must be
    // reported as two separate calls, each with its own name and arguments.
    #[tokio::test]
    async fn test_duplicate_index_different_id_tool_calls() {
        use crate::http_client::mock::MockStreamingClient;
        use bytes::Bytes;
        use futures::StreamExt;

        // `call_aaa` (three chunks), then `call_bbb` at the same index (three
        // chunks), then the `tool_calls` finish reason, usage, and the sentinel.
        let sse = concat!(
            "data: {\"choices\":[{\"delta\":{\"tool_calls\":[{\"index\":0,\"id\":\"call_aaa\",\"function\":{\"name\":\"command\",\"arguments\":\"\"}}]},\"finish_reason\":null}],\"usage\":null}\n\n",
            "data: {\"choices\":[{\"delta\":{\"tool_calls\":[{\"index\":0,\"id\":null,\"function\":{\"name\":null,\"arguments\":\"{\\\"cmd\\\"\"}}]},\"finish_reason\":null}],\"usage\":null}\n\n",
            "data: {\"choices\":[{\"delta\":{\"tool_calls\":[{\"index\":0,\"id\":null,\"function\":{\"name\":null,\"arguments\":\":\\\"ls\\\"}\"}}]},\"finish_reason\":null}],\"usage\":null}\n\n",
            "data: {\"choices\":[{\"delta\":{\"tool_calls\":[{\"index\":0,\"id\":\"call_bbb\",\"function\":{\"name\":\"git\",\"arguments\":\"\"}}]},\"finish_reason\":null}],\"usage\":null}\n\n",
            "data: {\"choices\":[{\"delta\":{\"tool_calls\":[{\"index\":0,\"id\":null,\"function\":{\"name\":null,\"arguments\":\"{\\\"action\\\"\"}}]},\"finish_reason\":null}],\"usage\":null}\n\n",
            "data: {\"choices\":[{\"delta\":{\"tool_calls\":[{\"index\":0,\"id\":null,\"function\":{\"name\":null,\"arguments\":\":\\\"log\\\"}\"}}]},\"finish_reason\":null}],\"usage\":null}\n\n",
            "data: {\"choices\":[{\"delta\":{\"tool_calls\":[]},\"finish_reason\":\"tool_calls\"}],\"usage\":null}\n\n",
            "data: {\"choices\":[],\"usage\":{\"prompt_tokens\":20,\"completion_tokens\":10,\"total_tokens\":30}}\n\n",
            "data: [DONE]\n\n",
        );

        let client = MockStreamingClient {
            sse_bytes: Bytes::from(sse),
        };

        let req = http::Request::builder()
            .method("POST")
            .uri("http://localhost/v1/chat/completions")
            .body(Vec::new())
            .unwrap();

        let mut stream = send_compatible_streaming_request(client, req)
            .await
            .unwrap();

        // Keep only completed tool calls; deltas and the final response are ignored.
        let mut collected_tool_calls = Vec::new();
        while let Some(chunk) = stream.next().await {
            if let streaming::StreamedAssistantContent::ToolCall {
                tool_call,
                internal_call_id: _,
            } = chunk.unwrap()
            {
                collected_tool_calls.push(tool_call);
            }
        }

        assert_eq!(
            collected_tool_calls.len(),
            2,
            "expected 2 separate tool calls, got {collected_tool_calls:?}"
        );

        assert_eq!(collected_tool_calls[0].id, "call_aaa");
        assert_eq!(collected_tool_calls[0].function.name, "command");
        assert_eq!(
            collected_tool_calls[0].function.arguments,
            serde_json::json!({"cmd": "ls"})
        );

        assert_eq!(collected_tool_calls[1].id, "call_bbb");
        assert_eq!(collected_tool_calls[1].function.name, "git");
        assert_eq!(
            collected_tool_calls[1].function.arguments,
            serde_json::json!({"action": "log"})
        );
    }
}