1use http::Request;
2use serde::{Deserialize, Serialize};
3use serde_json::json;
4use tracing::{Level, enabled, info_span};
5
6use crate::completion::{CompletionError, CompletionRequest, GetTokenUsage};
7use crate::http_client::HttpClientExt;
8use crate::json_utils::{self, merge};
9use crate::providers::internal::openai_chat_completions_compatible::{
10 self, CompatibleChoiceData, CompatibleChunk, CompatibleFinishReason, CompatibleStreamProfile,
11 CompatibleToolCallChunk,
12};
13use crate::providers::openai::completion::{GenericCompletionModel, OpenAIRequestParams, Usage};
14use crate::streaming;
15
/// Incremental fragment of a tool call's function emitted inside a streaming
/// `delta`. Both fields are optional because the name and the argument text
/// arrive spread across multiple SSE chunks.
#[derive(Deserialize, Debug)]
pub(crate) struct StreamingFunction {
    pub(crate) name: Option<String>,
    pub(crate) arguments: Option<String>,
}
24
/// One tool-call fragment from a streaming `delta`, addressed by `index`.
#[derive(Deserialize, Debug)]
pub(crate) struct StreamingToolCall {
    /// Position of the tool call within the assistant turn; fragments sharing
    /// an index belong to the same logical call.
    pub(crate) index: usize,
    /// Call id; typically only present on the first fragment of a call and
    /// `null` on continuations (see the tests below) — TODO confirm for all
    /// compatible providers.
    pub(crate) id: Option<String>,
    pub(crate) function: StreamingFunction,
}
31
32impl From<&StreamingToolCall> for CompatibleToolCallChunk {
33 fn from(value: &StreamingToolCall) -> Self {
34 Self {
35 index: value.index,
36 id: value.id.clone(),
37 name: value.function.name.clone(),
38 arguments: value.function.arguments.clone(),
39 }
40 }
41}
42
/// A single `choices[n].delta` payload from an SSE chunk.
#[derive(Deserialize, Debug)]
struct StreamingDelta {
    /// Plain assistant text for this increment, if any.
    #[serde(default)]
    content: Option<String>,
    /// Reasoning/"thinking" text for this increment, if any.
    #[serde(default)]
    reasoning_content: Option<String>,
    /// Tool-call fragments; a wire-level `null` is normalized to an empty vec
    /// by `json_utils::null_or_vec`.
    #[serde(default, deserialize_with = "json_utils::null_or_vec")]
    tool_calls: Vec<StreamingToolCall>,
}
52
/// Reason the provider reports for terminating a streamed choice.
///
/// Unknown wire values are captured verbatim by the untagged `Other` fallback
/// instead of failing deserialization, so new provider-side reasons degrade
/// gracefully.
#[derive(Deserialize, Debug, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum FinishReason {
    ToolCalls,
    Stop,
    ContentFilter,
    Length,
    #[serde(untagged)]
    Other(String),
}
63
/// One entry of the `choices` array in a streaming chunk.
#[derive(Deserialize, Debug)]
struct StreamingChoice {
    delta: StreamingDelta,
    /// Present only on the terminal chunk of the choice; `None` mid-stream.
    finish_reason: Option<FinishReason>,
}
69
/// Top-level SSE `data:` payload for one streaming completion chunk.
#[derive(Deserialize, Debug)]
struct StreamingCompletionChunk {
    id: Option<String>,
    model: Option<String>,
    choices: Vec<StreamingChoice>,
    /// Token accounting; per the tests, this arrives on a trailing
    /// usage-only chunk when `stream_options.include_usage` is requested.
    usage: Option<Usage>,
}
77
/// Final response surfaced to callers once the stream completes; carries the
/// aggregated token usage reported by the provider.
#[derive(Clone, Serialize, Deserialize)]
pub struct StreamingCompletionResponse {
    pub usage: Usage,
}
82
83impl GetTokenUsage for StreamingCompletionResponse {
84 fn token_usage(&self) -> Option<crate::completion::Usage> {
85 self.usage.token_usage()
86 }
87}
88
impl<Ext, H> GenericCompletionModel<Ext, H>
where
    crate::client::Client<Ext, H>: HttpClientExt + Clone + 'static,
    Ext: crate::client::Provider + Clone + 'static,
{
    /// Streams a chat completion from an OpenAI-compatible
    /// `/chat/completions` endpoint.
    ///
    /// Builds the OpenAI wire request from `completion_request`, forces SSE
    /// streaming (including a trailing usage-only chunk), instruments the
    /// call with a GenAI tracing span, and returns the adapted stream.
    ///
    /// # Errors
    /// Returns [`CompletionError`] when request conversion or serialization
    /// fails, or when the HTTP request cannot be built or sent.
    pub(crate) async fn stream(
        &self,
        completion_request: CompletionRequest,
    ) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError>
    {
        // Translate the provider-agnostic request into OpenAI's request shape.
        let request = super::CompletionRequest::try_from(OpenAIRequestParams {
            model: self.model.clone(),
            request: completion_request,
            strict_tools: self.strict_tools,
            tool_result_array_content: self.tool_result_array_content,
        })?;
        // Serialized before `request` is consumed below; recorded on the span
        // as `gen_ai.input.messages`.
        let request_messages = serde_json::to_string(&request.messages)?;
        let mut request_as_json = serde_json::to_value(request)?;

        // Force SSE streaming and ask the provider for a final usage chunk.
        request_as_json = merge(
            request_as_json,
            json!({"stream": true, "stream_options": {"include_usage": true}}),
        );

        // Only pay for pretty-printing when TRACE logging is actually enabled.
        if enabled!(Level::TRACE) {
            tracing::trace!(
                target: "rig::completions",
                "OpenAI Chat Completions streaming completion request: {}",
                serde_json::to_string_pretty(&request_as_json)?
            );
        }

        let req_body = serde_json::to_vec(&request_as_json)?;

        let req = self
            .client
            .post("/chat/completions")?
            .body(req_body)
            .map_err(|e| CompletionError::HttpError(e.into()))?;

        // Reuse the caller's span when one is active; otherwise start a fresh
        // GenAI-semantic-convention "chat" span whose response/usage fields
        // are filled in as the stream progresses.
        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat",
                gen_ai.operation.name = "chat",
                gen_ai.provider.name = "openai",
                gen_ai.request.model = self.model,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.usage.cache_read.input_tokens = tracing::field::Empty,
                gen_ai.input.messages = request_messages,
                gen_ai.output.messages = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        let client = self.client.clone();

        tracing::Instrument::instrument(send_compatible_streaming_request(client, req), span).await
    }
}
153
/// Marker profile selecting vanilla OpenAI behavior in the shared
/// OpenAI-compatible streaming implementation.
#[derive(Clone, Copy)]
struct OpenAICompatibleProfile;
156
157impl CompatibleStreamProfile for OpenAICompatibleProfile {
158 type Usage = Usage;
159 type Detail = ();
160 type FinalResponse = StreamingCompletionResponse;
161
162 fn normalize_chunk(
163 &self,
164 data: &str,
165 ) -> Result<Option<CompatibleChunk<Self::Usage, Self::Detail>>, CompletionError> {
166 let data = match serde_json::from_str::<StreamingCompletionChunk>(data) {
167 Ok(data) => data,
168 Err(error) => {
169 tracing::error!(?error, message = data, "Failed to parse SSE message");
170 return Ok(None);
171 }
172 };
173
174 Ok(Some(
175 openai_chat_completions_compatible::normalize_first_choice_chunk(
176 data.id,
177 data.model,
178 data.usage,
179 &data.choices,
180 |choice| CompatibleChoiceData {
181 finish_reason: if choice.finish_reason == Some(FinishReason::ToolCalls) {
182 CompatibleFinishReason::ToolCalls
183 } else {
184 CompatibleFinishReason::Other
185 },
186 text: choice.delta.content.clone(),
187 reasoning: choice.delta.reasoning_content.clone(),
188 tool_calls: openai_chat_completions_compatible::tool_call_chunks(
189 &choice.delta.tool_calls,
190 ),
191 details: Vec::new(),
192 },
193 ),
194 ))
195 }
196
197 fn build_final_response(&self, usage: Self::Usage) -> Self::FinalResponse {
198 StreamingCompletionResponse { usage }
199 }
200
201 fn uses_distinct_tool_call_eviction(&self) -> bool {
202 true
203 }
204}
205
206pub async fn send_compatible_streaming_request<T>(
207 http_client: T,
208 req: Request<Vec<u8>>,
209) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError>
210where
211 T: HttpClientExt + Clone + 'static,
212{
213 openai_chat_completions_compatible::send_compatible_streaming_request(
214 http_client,
215 req,
216 OpenAICompatibleProfile,
217 )
218 .await
219}
220
221#[cfg(test)]
222mod tests {
223 use super::*;
224 use crate::providers::internal::openai_chat_completions_compatible::test_support::{
225 assert_zero_arg_tool_call_is_emitted, sse_bytes_from_data_lines,
226 };
227
228 #[test]
229 fn test_streaming_function_deserialization() {
230 let json = r#"{"name": "get_weather", "arguments": "{\"location\":\"Paris\"}"}"#;
231 let function: StreamingFunction = serde_json::from_str(json).unwrap();
232 assert_eq!(function.name, Some("get_weather".to_string()));
233 assert_eq!(
234 function.arguments.as_ref().unwrap(),
235 r#"{"location":"Paris"}"#
236 );
237 }
238
239 #[test]
240 fn test_streaming_tool_call_deserialization() {
241 let json = r#"{
242 "index": 0,
243 "id": "call_abc123",
244 "function": {
245 "name": "get_weather",
246 "arguments": "{\"city\":\"London\"}"
247 }
248 }"#;
249 let tool_call: StreamingToolCall = serde_json::from_str(json).unwrap();
250 assert_eq!(tool_call.index, 0);
251 assert_eq!(tool_call.id, Some("call_abc123".to_string()));
252 assert_eq!(tool_call.function.name, Some("get_weather".to_string()));
253 }
254
255 #[test]
256 fn test_streaming_tool_call_partial_deserialization() {
257 let json = r#"{
259 "index": 0,
260 "id": null,
261 "function": {
262 "name": null,
263 "arguments": "Paris"
264 }
265 }"#;
266 let tool_call: StreamingToolCall = serde_json::from_str(json).unwrap();
267 assert_eq!(tool_call.index, 0);
268 assert!(tool_call.id.is_none());
269 assert!(tool_call.function.name.is_none());
270 assert_eq!(tool_call.function.arguments.as_ref().unwrap(), "Paris");
271 }
272
273 #[test]
274 fn test_streaming_delta_with_tool_calls() {
275 let json = r#"{
276 "content": null,
277 "tool_calls": [{
278 "index": 0,
279 "id": "call_xyz",
280 "function": {
281 "name": "search",
282 "arguments": ""
283 }
284 }]
285 }"#;
286 let delta: StreamingDelta = serde_json::from_str(json).unwrap();
287 assert!(delta.content.is_none());
288 assert_eq!(delta.tool_calls.len(), 1);
289 assert_eq!(delta.tool_calls[0].id, Some("call_xyz".to_string()));
290 }
291
292 #[test]
293 fn test_streaming_chunk_deserialization() {
294 let json = r#"{
295 "choices": [{
296 "delta": {
297 "content": "Hello",
298 "tool_calls": []
299 }
300 }],
301 "usage": {
302 "prompt_tokens": 10,
303 "completion_tokens": 5,
304 "total_tokens": 15
305 }
306 }"#;
307 let chunk: StreamingCompletionChunk = serde_json::from_str(json).unwrap();
308 assert_eq!(chunk.choices.len(), 1);
309 assert_eq!(chunk.choices[0].delta.content, Some("Hello".to_string()));
310 assert!(chunk.usage.is_some());
311 }
312
    // Three consecutive chunks for the same tool call (index 0): the first
    // names the function and carries the id, the next two stream argument
    // fragments with null id/name. Each chunk must deserialize with exactly
    // the partial data it carries.
    #[test]
    fn test_streaming_chunk_with_multiple_tool_call_deltas() {
        let json_start = r#"{
            "choices": [{
                "delta": {
                    "content": null,
                    "tool_calls": [{
                        "index": 0,
                        "id": "call_123",
                        "function": {
                            "name": "get_weather",
                            "arguments": ""
                        }
                    }]
                }
            }],
            "usage": null
        }"#;

        let json_chunk1 = r#"{
            "choices": [{
                "delta": {
                    "content": null,
                    "tool_calls": [{
                        "index": 0,
                        "id": null,
                        "function": {
                            "name": null,
                            "arguments": "{\"loc"
                        }
                    }]
                }
            }],
            "usage": null
        }"#;

        let json_chunk2 = r#"{
            "choices": [{
                "delta": {
                    "content": null,
                    "tool_calls": [{
                        "index": 0,
                        "id": null,
                        "function": {
                            "name": null,
                            "arguments": "ation\":\"NYC\"}"
                        }
                    }]
                }
            }],
            "usage": null
        }"#;

        let start_chunk: StreamingCompletionChunk = serde_json::from_str(json_start).unwrap();
        assert_eq!(start_chunk.choices[0].delta.tool_calls.len(), 1);
        assert_eq!(
            start_chunk.choices[0].delta.tool_calls[0]
                .function
                .name
                .as_ref()
                .unwrap(),
            "get_weather"
        );

        let chunk1: StreamingCompletionChunk = serde_json::from_str(json_chunk1).unwrap();
        assert_eq!(chunk1.choices[0].delta.tool_calls.len(), 1);
        assert_eq!(
            chunk1.choices[0].delta.tool_calls[0]
                .function
                .arguments
                .as_ref()
                .unwrap(),
            "{\"loc"
        );

        let chunk2: StreamingCompletionChunk = serde_json::from_str(json_chunk2).unwrap();
        assert_eq!(chunk2.choices[0].delta.tool_calls.len(), 1);
        assert_eq!(
            chunk2.choices[0].delta.tool_calls[0]
                .function
                .arguments
                .as_ref()
                .unwrap(),
            "ation\":\"NYC\"}"
        );
    }
401
    // A trailing usage-only chunk (empty `choices`) must still be surfaced so
    // the final response carries the provider's token counts.
    #[tokio::test]
    async fn test_streaming_usage_only_chunk_is_not_ignored() {
        use crate::http_client::mock::MockStreamingClient;
        use futures::StreamExt;

        let client = MockStreamingClient {
            sse_bytes: sse_bytes_from_data_lines([
                "{\"choices\":[{\"delta\":{\"content\":\"Hello\",\"tool_calls\":[]}}],\"usage\":null}",
                "{\"choices\":[],\"usage\":{\"prompt_tokens\":10,\"completion_tokens\":5,\"total_tokens\":15}}",
                "[DONE]",
            ]),
        };

        let req = http::Request::builder()
            .method("POST")
            .uri("http://localhost/v1/chat/completions")
            .body(Vec::new())
            .unwrap();

        let mut stream = send_compatible_streaming_request(client, req)
            .await
            .unwrap();

        let mut final_usage = None;
        while let Some(chunk) = stream.next().await {
            if let streaming::StreamedAssistantContent::Final(res) = chunk.unwrap() {
                final_usage = Some(res.usage);
                break;
            }
        }

        let usage = final_usage.expect("expected a final response with usage");
        assert_eq!(usage.prompt_tokens, 10);
        assert_eq!(usage.total_tokens, 15);
    }
438
    // `prompt_tokens_details.cached_tokens` from the usage chunk must survive
    // into both the provider usage struct and the crate-level token usage.
    #[tokio::test]
    async fn test_streaming_cached_input_tokens_populated() {
        use crate::http_client::mock::MockStreamingClient;
        use futures::StreamExt;

        let client = MockStreamingClient {
            sse_bytes: sse_bytes_from_data_lines([
                "{\"choices\":[{\"delta\":{\"content\":\"Hi\",\"tool_calls\":[]}}],\"usage\":null}",
                "{\"choices\":[],\"usage\":{\"prompt_tokens\":100,\"completion_tokens\":10,\"total_tokens\":110,\"prompt_tokens_details\":{\"cached_tokens\":80}}}",
                "[DONE]",
            ]),
        };

        let req = http::Request::builder()
            .method("POST")
            .uri("http://localhost/v1/chat/completions")
            .body(Vec::new())
            .unwrap();

        let mut stream = send_compatible_streaming_request(client, req)
            .await
            .unwrap();

        let mut final_response = None;
        while let Some(chunk) = stream.next().await {
            if let streaming::StreamedAssistantContent::Final(res) = chunk.unwrap() {
                final_response = Some(res);
                break;
            }
        }

        let res = final_response.expect("expected a final response");

        assert_eq!(
            res.usage
                .prompt_tokens_details
                .as_ref()
                .unwrap()
                .cached_tokens,
            80
        );

        let core_usage = res.token_usage().expect("token_usage should return Some");
        assert_eq!(core_usage.cached_input_tokens, 80);
        assert_eq!(core_usage.input_tokens, 100);
        assert_eq!(core_usage.total_tokens, 110);
    }
489
    // Some providers reuse index 0 for back-to-back tool calls and signal the
    // boundary with a fresh non-null id. A new id at the same index must evict
    // (emit) the previous call, yielding two distinct tool calls here.
    #[tokio::test]
    async fn test_duplicate_index_different_id_tool_calls() {
        use crate::http_client::mock::MockStreamingClient;
        use futures::StreamExt;

        let client = MockStreamingClient {
            sse_bytes: sse_bytes_from_data_lines([
                "{\"choices\":[{\"delta\":{\"tool_calls\":[{\"index\":0,\"id\":\"call_aaa\",\"function\":{\"name\":\"command\",\"arguments\":\"\"}}]},\"finish_reason\":null}],\"usage\":null}",
                "{\"choices\":[{\"delta\":{\"tool_calls\":[{\"index\":0,\"id\":null,\"function\":{\"name\":null,\"arguments\":\"{\\\"cmd\\\"\"}}]},\"finish_reason\":null}],\"usage\":null}",
                "{\"choices\":[{\"delta\":{\"tool_calls\":[{\"index\":0,\"id\":null,\"function\":{\"name\":null,\"arguments\":\":\\\"ls\\\"}\"}}]},\"finish_reason\":null}],\"usage\":null}",
                "{\"choices\":[{\"delta\":{\"tool_calls\":[{\"index\":0,\"id\":\"call_bbb\",\"function\":{\"name\":\"git\",\"arguments\":\"\"}}]},\"finish_reason\":null}],\"usage\":null}",
                "{\"choices\":[{\"delta\":{\"tool_calls\":[{\"index\":0,\"id\":null,\"function\":{\"name\":null,\"arguments\":\"{\\\"action\\\"\"}}]},\"finish_reason\":null}],\"usage\":null}",
                "{\"choices\":[{\"delta\":{\"tool_calls\":[{\"index\":0,\"id\":null,\"function\":{\"name\":null,\"arguments\":\":\\\"log\\\"}\"}}]},\"finish_reason\":null}],\"usage\":null}",
                "{\"choices\":[{\"delta\":{\"tool_calls\":[]},\"finish_reason\":\"tool_calls\"}],\"usage\":null}",
                "{\"choices\":[],\"usage\":{\"prompt_tokens\":20,\"completion_tokens\":10,\"total_tokens\":30}}",
                "[DONE]",
            ]),
        };

        let req = http::Request::builder()
            .method("POST")
            .uri("http://localhost/v1/chat/completions")
            .body(Vec::new())
            .unwrap();

        let mut stream = send_compatible_streaming_request(client, req)
            .await
            .unwrap();

        let mut collected_tool_calls = Vec::new();
        while let Some(chunk) = stream.next().await {
            if let streaming::StreamedAssistantContent::ToolCall {
                tool_call,
                internal_call_id: _,
            } = chunk.unwrap()
            {
                collected_tool_calls.push(tool_call);
            }
        }

        assert_eq!(
            collected_tool_calls.len(),
            2,
            "expected 2 separate tool calls, got {collected_tool_calls:?}"
        );

        assert_eq!(collected_tool_calls[0].id, "call_aaa");
        assert_eq!(collected_tool_calls[0].function.name, "command");
        assert_eq!(
            collected_tool_calls[0].function.arguments,
            serde_json::json!({"cmd": "ls"})
        );

        assert_eq!(collected_tool_calls[1].id, "call_bbb");
        assert_eq!(collected_tool_calls[1].function.name, "git");
        assert_eq!(
            collected_tool_calls[1].function.arguments,
            serde_json::json!({"action": "log"})
        );
    }
556
    // Other providers emit a *different* id on every fragment of the same
    // call while only the first fragment names the function (later names are
    // ""). These must be accumulated into a single tool call, not three.
    #[tokio::test]
    async fn test_unique_id_per_chunk_single_tool_call() {
        use crate::http_client::mock::MockStreamingClient;
        use futures::StreamExt;

        let client = MockStreamingClient {
            sse_bytes: sse_bytes_from_data_lines([
                "{\"choices\":[{\"delta\":{\"tool_calls\":[{\"index\":0,\"id\":\"chatcmpl-tool-aaa\",\"function\":{\"name\":\"web_search\",\"arguments\":\"null\"}}]},\"finish_reason\":null}],\"usage\":null}",
                "{\"choices\":[{\"delta\":{\"tool_calls\":[{\"index\":0,\"id\":\"chatcmpl-tool-bbb\",\"function\":{\"name\":\"\",\"arguments\":\"{\\\"query\\\": \\\"META\"}}]},\"finish_reason\":null}],\"usage\":null}",
                "{\"choices\":[{\"delta\":{\"tool_calls\":[{\"index\":0,\"id\":\"chatcmpl-tool-ccc\",\"function\":{\"name\":\"\",\"arguments\":\" Platforms news\\\"}\"}}]},\"finish_reason\":null}],\"usage\":null}",
                "{\"choices\":[{\"delta\":{\"tool_calls\":[]},\"finish_reason\":\"tool_calls\"}],\"usage\":null}",
                "{\"choices\":[],\"usage\":{\"prompt_tokens\":15,\"completion_tokens\":8,\"total_tokens\":23}}",
                "[DONE]",
            ]),
        };

        let req = http::Request::builder()
            .method("POST")
            .uri("http://localhost/v1/chat/completions")
            .body(Vec::new())
            .unwrap();

        let mut stream = send_compatible_streaming_request(client, req)
            .await
            .unwrap();

        let mut collected_tool_calls = Vec::new();
        while let Some(chunk) = stream.next().await {
            if let streaming::StreamedAssistantContent::ToolCall {
                tool_call,
                internal_call_id: _,
            } = chunk.unwrap()
            {
                collected_tool_calls.push(tool_call);
            }
        }

        assert_eq!(
            collected_tool_calls.len(),
            1,
            "expected 1 tool call (all chunks are fragments of the same call), got {collected_tool_calls:?}"
        );

        assert_eq!(collected_tool_calls[0].function.name, "web_search");
        // Arguments may surface as a JSON string or a parsed value; normalize
        // to text before checking the accumulated query.
        let args_str = match &collected_tool_calls[0].function.arguments {
            serde_json::Value::String(s) => s.clone(),
            v => v.to_string(),
        };
        assert!(
            args_str.contains("META Platforms news"),
            "expected accumulated arguments containing the full query, got: {args_str}"
        );
    }
617
    // A tool call with empty arguments must be flushed when the stream ends
    // with an explicit `finish_reason: "tool_calls"` chunk.
    #[tokio::test]
    async fn test_zero_arg_tool_call_normalized_on_finish_reason() {
        use crate::http_client::mock::MockStreamingClient;

        let client = MockStreamingClient {
            sse_bytes: sse_bytes_from_data_lines([
                "{\"choices\":[{\"delta\":{\"tool_calls\":[{\"index\":0,\"id\":\"call_123\",\"function\":{\"name\":\"ping\",\"arguments\":\"\"}}]},\"finish_reason\":null}],\"usage\":null}",
                "{\"choices\":[{\"delta\":{\"tool_calls\":[]},\"finish_reason\":\"tool_calls\"}],\"usage\":null}",
                "[DONE]",
            ]),
        };

        let req = http::Request::builder()
            .method("POST")
            .uri("http://localhost/v1/chat/completions")
            .body(Vec::new())
            .unwrap();

        let stream = send_compatible_streaming_request(client, req)
            .await
            .unwrap();

        assert_zero_arg_tool_call_is_emitted(stream, "call_123", "ping", true).await;
    }
642
    // Same zero-argument call, but the stream ends abruptly (no finish_reason
    // chunk, no [DONE]); the pending call must still be emitted at EOF.
    #[tokio::test]
    async fn test_zero_arg_tool_call_is_preserved_at_eof() {
        use crate::http_client::mock::MockStreamingClient;

        let client = MockStreamingClient {
            sse_bytes: sse_bytes_from_data_lines([
                "{\"choices\":[{\"delta\":{\"tool_calls\":[{\"index\":0,\"id\":\"call_123\",\"function\":{\"name\":\"ping\",\"arguments\":\"\"}}]},\"finish_reason\":null}],\"usage\":null}",
            ]),
        };

        let req = http::Request::builder()
            .method("POST")
            .uri("http://localhost/v1/chat/completions")
            .body(Vec::new())
            .unwrap();

        let stream = send_compatible_streaming_request(client, req)
            .await
            .unwrap();

        assert_zero_arg_tool_call_is_emitted(stream, "call_123", "ping", true).await;
    }
665}