1use std::collections::HashMap;
2
3use async_stream::stream;
4use futures::StreamExt;
5use http::Request;
6use serde::{Deserialize, Serialize};
7use serde_json::json;
8use tracing::{Level, enabled, info_span};
9use tracing_futures::Instrument;
10
11use crate::completion::{CompletionError, CompletionRequest, GetTokenUsage};
12use crate::http_client::HttpClientExt;
13use crate::http_client::sse::{Event, GenericEventSource};
14use crate::json_utils::{self, merge};
15use crate::providers::openai::completion::{CompletionModel, OpenAIRequestParams, Usage};
16use crate::streaming::{self, RawStreamingChoice};
17
18#[derive(Deserialize, Debug)]
22pub(crate) struct StreamingFunction {
23 pub(crate) name: Option<String>,
24 pub(crate) arguments: Option<String>,
25}
26
27#[derive(Deserialize, Debug)]
28pub(crate) struct StreamingToolCall {
29 pub(crate) index: usize,
30 pub(crate) id: Option<String>,
31 pub(crate) function: StreamingFunction,
32}
33
34#[derive(Deserialize, Debug)]
35struct StreamingDelta {
36 #[serde(default)]
37 content: Option<String>,
38 #[serde(default)]
39 reasoning_content: Option<String>, #[serde(default, deserialize_with = "json_utils::null_or_vec")]
41 tool_calls: Vec<StreamingToolCall>,
42}
43
44#[derive(Deserialize, Debug, PartialEq)]
45#[serde(rename_all = "snake_case")]
46pub enum FinishReason {
47 ToolCalls,
48 Stop,
49 ContentFilter,
50 Length,
51 #[serde(untagged)]
52 Other(String), }
54
55#[derive(Deserialize, Debug)]
56struct StreamingChoice {
57 delta: StreamingDelta,
58 finish_reason: Option<FinishReason>,
59}
60
61#[derive(Deserialize, Debug)]
62struct StreamingCompletionChunk {
63 choices: Vec<StreamingChoice>,
64 usage: Option<Usage>,
65}
66
67#[derive(Clone, Serialize, Deserialize)]
68pub struct StreamingCompletionResponse {
69 pub usage: Usage,
70}
71
72impl GetTokenUsage for StreamingCompletionResponse {
73 fn token_usage(&self) -> Option<crate::completion::Usage> {
74 let mut usage = crate::completion::Usage::new();
75 usage.input_tokens = self.usage.prompt_tokens as u64;
76 usage.output_tokens = self.usage.total_tokens as u64 - self.usage.prompt_tokens as u64;
77 usage.total_tokens = self.usage.total_tokens as u64;
78 usage.cached_input_tokens = self
79 .usage
80 .prompt_tokens_details
81 .as_ref()
82 .map_or(0, |d| d.cached_tokens as u64);
83 Some(usage)
84 }
85}
86
87impl<T> CompletionModel<T>
88where
89 T: HttpClientExt + Clone + 'static,
90{
91 pub(crate) async fn stream(
92 &self,
93 completion_request: CompletionRequest,
94 ) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError>
95 {
96 let request = super::CompletionRequest::try_from(OpenAIRequestParams {
97 model: self.model.clone(),
98 request: completion_request,
99 strict_tools: self.strict_tools,
100 tool_result_array_content: self.tool_result_array_content,
101 })?;
102 let request_messages = serde_json::to_string(&request.messages)
103 .expect("Converting to JSON from a Rust struct shouldn't fail");
104 let mut request_as_json = serde_json::to_value(request).expect("this should never fail");
105
106 request_as_json = merge(
107 request_as_json,
108 json!({"stream": true, "stream_options": {"include_usage": true}}),
109 );
110
111 if enabled!(Level::TRACE) {
112 tracing::trace!(
113 target: "rig::completions",
114 "OpenAI Chat Completions streaming completion request: {}",
115 serde_json::to_string_pretty(&request_as_json)?
116 );
117 }
118
119 let req_body = serde_json::to_vec(&request_as_json)?;
120
121 let req = self
122 .client
123 .post("/chat/completions")?
124 .body(req_body)
125 .map_err(|e| CompletionError::HttpError(e.into()))?;
126
127 let span = if tracing::Span::current().is_disabled() {
128 info_span!(
129 target: "rig::completions",
130 "chat",
131 gen_ai.operation.name = "chat",
132 gen_ai.provider.name = "openai",
133 gen_ai.request.model = self.model,
134 gen_ai.response.id = tracing::field::Empty,
135 gen_ai.response.model = self.model,
136 gen_ai.usage.output_tokens = tracing::field::Empty,
137 gen_ai.usage.input_tokens = tracing::field::Empty,
138 gen_ai.usage.cached_tokens = tracing::field::Empty,
139 gen_ai.input.messages = request_messages,
140 gen_ai.output.messages = tracing::field::Empty,
141 )
142 } else {
143 tracing::Span::current()
144 };
145
146 let client = self.client.clone();
147
148 tracing::Instrument::instrument(send_compatible_streaming_request(client, req), span).await
149 }
150}
151
152pub async fn send_compatible_streaming_request<T>(
153 http_client: T,
154 req: Request<Vec<u8>>,
155) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError>
156where
157 T: HttpClientExt + Clone + 'static,
158{
159 let span = tracing::Span::current();
160 let mut event_source = GenericEventSource::new(http_client, req);
162
163 let stream = stream! {
164 let span = tracing::Span::current();
165
166 let mut tool_calls: HashMap<usize, streaming::RawStreamingToolCall> = HashMap::new();
168 let mut text_content = String::new();
169 let mut final_usage = None;
170
171 while let Some(event_result) = event_source.next().await {
172 match event_result {
173 Ok(Event::Open) => {
174 tracing::trace!("SSE connection opened");
175 continue;
176 }
177
178 Ok(Event::Message(message)) => {
179 if message.data.trim().is_empty() || message.data == "[DONE]" {
180 continue;
181 }
182
183 let data = match serde_json::from_str::<StreamingCompletionChunk>(&message.data) {
184 Ok(data) => data,
185 Err(error) => {
186 tracing::error!(?error, message = message.data, "Failed to parse SSE message");
187 continue;
188 }
189 };
190
191 if let Some(usage) = data.usage {
193 final_usage = Some(usage);
194 }
195
196 let Some(choice) = data.choices.first() else {
198 tracing::debug!("There is no choice");
199 continue;
200 };
201 let delta = &choice.delta;
202
203 if !delta.tool_calls.is_empty() {
204 for tool_call in &delta.tool_calls {
205 let index = tool_call.index;
206
207 if let Some(new_id) = &tool_call.id
215 && !new_id.is_empty()
216 && let Some(new_name) = &tool_call.function.name
217 && !new_name.is_empty()
218 && let Some(existing) = tool_calls.get(&index)
219 && !existing.id.is_empty()
220 && existing.id != *new_id
221 && !existing.name.is_empty()
222 && existing.name != *new_name
223 {
224 let evicted = tool_calls.remove(&index).expect("checked above");
225 yield Ok(streaming::RawStreamingChoice::ToolCall(
226 finalize_completed_streaming_tool_call(evicted),
227 ));
228 }
229
230 let existing_tool_call = tool_calls.entry(index).or_insert_with(streaming::RawStreamingToolCall::empty);
231
232 if let Some(id) = &tool_call.id && !id.is_empty() {
233 existing_tool_call.id = id.clone();
234 }
235
236 if let Some(name) = &tool_call.function.name && !name.is_empty() {
237 existing_tool_call.name = name.clone();
238 yield Ok(streaming::RawStreamingChoice::ToolCallDelta {
239 id: existing_tool_call.id.clone(),
240 internal_call_id: existing_tool_call.internal_call_id.clone(),
241 content: streaming::ToolCallDeltaContent::Name(name.clone()),
242 });
243 }
244
245 if let Some(chunk) = &tool_call.function.arguments && !chunk.is_empty() {
247 let current_args = match &existing_tool_call.arguments {
248 serde_json::Value::Null => String::new(),
249 serde_json::Value::String(s) => s.clone(),
250 v => v.to_string(),
251 };
252
253 let combined = format!("{current_args}{chunk}");
255
256 if combined.trim_start().starts_with('{') && combined.trim_end().ends_with('}') {
258 match serde_json::from_str(&combined) {
259 Ok(parsed) => existing_tool_call.arguments = parsed,
260 Err(_) => existing_tool_call.arguments = serde_json::Value::String(combined),
261 }
262 } else {
263 existing_tool_call.arguments = serde_json::Value::String(combined);
264 }
265
266 yield Ok(streaming::RawStreamingChoice::ToolCallDelta {
268 id: existing_tool_call.id.clone(),
269 internal_call_id: existing_tool_call.internal_call_id.clone(),
270 content: streaming::ToolCallDeltaContent::Delta(chunk.clone()),
271 });
272 }
273 }
274 }
275
276 if let Some(reasoning) = &delta.reasoning_content && !reasoning.is_empty() {
278 yield Ok(streaming::RawStreamingChoice::ReasoningDelta {
279 id: None,
280 reasoning: reasoning.clone(),
281 });
282 }
283
284 if let Some(content) = &delta.content && !content.is_empty() {
286 text_content += content;
287 yield Ok(streaming::RawStreamingChoice::Message(content.clone()));
288 }
289
290 if let Some(finish_reason) = &choice.finish_reason && *finish_reason == FinishReason::ToolCalls {
292 for (_idx, tool_call) in tool_calls.into_iter() {
293 yield Ok(streaming::RawStreamingChoice::ToolCall(
294 finalize_completed_streaming_tool_call(tool_call),
295 ));
296 }
297 tool_calls = HashMap::new();
298 }
299 }
300 Err(crate::http_client::Error::StreamEnded) => {
301 break;
302 }
303 Err(error) => {
304 tracing::error!(?error, "SSE error");
305 yield Err(CompletionError::ProviderError(error.to_string()));
306 break;
307 }
308 }
309 }
310
311
312 event_source.close();
314
315 for (_idx, tool_call) in tool_calls.into_iter() {
317 yield Ok(streaming::RawStreamingChoice::ToolCall(tool_call));
318 }
319
320 let final_usage = final_usage.unwrap_or_default();
321 if !span.is_disabled() {
322 span.record("gen_ai.usage.input_tokens", final_usage.prompt_tokens);
323 span.record("gen_ai.usage.output_tokens", final_usage.total_tokens - final_usage.prompt_tokens);
324 span.record(
325 "gen_ai.usage.cached_tokens",
326 final_usage
327 .prompt_tokens_details
328 .as_ref()
329 .map(|d| d.cached_tokens)
330 .unwrap_or(0),
331 );
332 }
333
334 yield Ok(RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
335 usage: final_usage
336 }));
337 }.instrument(span);
338
339 Ok(streaming::StreamingCompletionResponse::stream(Box::pin(
340 stream,
341 )))
342}
343
344fn finalize_completed_streaming_tool_call(
345 mut tool_call: streaming::RawStreamingToolCall,
346) -> streaming::RawStreamingToolCall {
347 if tool_call.arguments.is_null() {
348 tool_call.arguments = serde_json::Value::Object(serde_json::Map::new());
349 }
350
351 tool_call
352}
353
#[cfg(test)]
mod tests {
    use super::*;
    use crate::http_client::mock::MockStreamingClient;
    use bytes::Bytes;
    use futures::StreamExt;

    /// Runs `send_compatible_streaming_request` against a mock client whose response
    /// body is exactly `sse`.
    async fn stream_from_sse(
        sse: &'static str,
    ) -> streaming::StreamingCompletionResponse<StreamingCompletionResponse> {
        let client = MockStreamingClient {
            sse_bytes: Bytes::from(sse),
        };

        let req = http::Request::builder()
            .method("POST")
            .uri("http://localhost/v1/chat/completions")
            .body(Vec::new())
            .unwrap();

        send_compatible_streaming_request(client, req)
            .await
            .unwrap()
    }

    #[test]
    fn test_streaming_function_deserialization() {
        let json = r#"{"name": "get_weather", "arguments": "{\"location\":\"Paris\"}"}"#;
        let function: StreamingFunction = serde_json::from_str(json).unwrap();
        assert_eq!(function.name, Some("get_weather".to_string()));
        assert_eq!(
            function.arguments.as_ref().unwrap(),
            r#"{"location":"Paris"}"#
        );
    }

    #[test]
    fn test_streaming_tool_call_deserialization() {
        let json = r#"{
            "index": 0,
            "id": "call_abc123",
            "function": {
                "name": "get_weather",
                "arguments": "{\"city\":\"London\"}"
            }
        }"#;
        let tool_call: StreamingToolCall = serde_json::from_str(json).unwrap();
        assert_eq!(tool_call.index, 0);
        assert_eq!(tool_call.id, Some("call_abc123".to_string()));
        assert_eq!(tool_call.function.name, Some("get_weather".to_string()));
    }

    #[test]
    fn test_streaming_tool_call_partial_deserialization() {
        // Follow-up fragments carry `null` for id and name.
        let json = r#"{
            "index": 0,
            "id": null,
            "function": {
                "name": null,
                "arguments": "Paris"
            }
        }"#;
        let tool_call: StreamingToolCall = serde_json::from_str(json).unwrap();
        assert_eq!(tool_call.index, 0);
        assert!(tool_call.id.is_none());
        assert!(tool_call.function.name.is_none());
        assert_eq!(tool_call.function.arguments.as_ref().unwrap(), "Paris");
    }

    #[test]
    fn test_streaming_delta_with_tool_calls() {
        let json = r#"{
            "content": null,
            "tool_calls": [{
                "index": 0,
                "id": "call_xyz",
                "function": {
                    "name": "search",
                    "arguments": ""
                }
            }]
        }"#;
        let delta: StreamingDelta = serde_json::from_str(json).unwrap();
        assert!(delta.content.is_none());
        assert_eq!(delta.tool_calls.len(), 1);
        assert_eq!(delta.tool_calls[0].id, Some("call_xyz".to_string()));
    }

    #[test]
    fn test_streaming_chunk_deserialization() {
        let json = r#"{
            "choices": [{
                "delta": {
                    "content": "Hello",
                    "tool_calls": []
                }
            }],
            "usage": {
                "prompt_tokens": 10,
                "completion_tokens": 5,
                "total_tokens": 15
            }
        }"#;
        let chunk: StreamingCompletionChunk = serde_json::from_str(json).unwrap();
        assert_eq!(chunk.choices.len(), 1);
        assert_eq!(chunk.choices[0].delta.content, Some("Hello".to_string()));
        assert!(chunk.usage.is_some());
    }

    #[test]
    fn test_streaming_chunk_with_multiple_tool_call_deltas() {
        // First fragment: id and name, empty arguments.
        let json_start = r#"{
            "choices": [{
                "delta": {
                    "content": null,
                    "tool_calls": [{
                        "index": 0,
                        "id": "call_123",
                        "function": {
                            "name": "get_weather",
                            "arguments": ""
                        }
                    }]
                }
            }],
            "usage": null
        }"#;

        let json_chunk1 = r#"{
            "choices": [{
                "delta": {
                    "content": null,
                    "tool_calls": [{
                        "index": 0,
                        "id": null,
                        "function": {
                            "name": null,
                            "arguments": "{\"loc"
                        }
                    }]
                }
            }],
            "usage": null
        }"#;

        let json_chunk2 = r#"{
            "choices": [{
                "delta": {
                    "content": null,
                    "tool_calls": [{
                        "index": 0,
                        "id": null,
                        "function": {
                            "name": null,
                            "arguments": "ation\":\"NYC\"}"
                        }
                    }]
                }
            }],
            "usage": null
        }"#;

        let start_chunk: StreamingCompletionChunk = serde_json::from_str(json_start).unwrap();
        assert_eq!(start_chunk.choices[0].delta.tool_calls.len(), 1);
        assert_eq!(
            start_chunk.choices[0].delta.tool_calls[0]
                .function
                .name
                .as_ref()
                .unwrap(),
            "get_weather"
        );

        let chunk1: StreamingCompletionChunk = serde_json::from_str(json_chunk1).unwrap();
        assert_eq!(chunk1.choices[0].delta.tool_calls.len(), 1);
        assert_eq!(
            chunk1.choices[0].delta.tool_calls[0]
                .function
                .arguments
                .as_ref()
                .unwrap(),
            "{\"loc"
        );

        let chunk2: StreamingCompletionChunk = serde_json::from_str(json_chunk2).unwrap();
        assert_eq!(chunk2.choices[0].delta.tool_calls.len(), 1);
        assert_eq!(
            chunk2.choices[0].delta.tool_calls[0]
                .function
                .arguments
                .as_ref()
                .unwrap(),
            "ation\":\"NYC\"}"
        );
    }

    #[tokio::test]
    async fn test_streaming_usage_only_chunk_is_not_ignored() {
        // Usage arrives on its own chunk with an empty `choices` array.
        let sse = concat!(
            "data: {\"choices\":[{\"delta\":{\"content\":\"Hello\",\"tool_calls\":[]}}],\"usage\":null}\n\n",
            "data: {\"choices\":[],\"usage\":{\"prompt_tokens\":10,\"completion_tokens\":5,\"total_tokens\":15}}\n\n",
            "data: [DONE]\n\n",
        );

        let mut stream = stream_from_sse(sse).await;

        let mut final_usage = None;
        while let Some(chunk) = stream.next().await {
            if let streaming::StreamedAssistantContent::Final(res) = chunk.unwrap() {
                final_usage = Some(res.usage);
                break;
            }
        }

        let usage = final_usage.expect("expected a final response with usage");
        assert_eq!(usage.prompt_tokens, 10);
        assert_eq!(usage.total_tokens, 15);
    }

    #[tokio::test]
    async fn test_streaming_cached_input_tokens_populated() {
        // Cached prompt tokens are reported under `prompt_tokens_details`.
        let sse = concat!(
            "data: {\"choices\":[{\"delta\":{\"content\":\"Hi\",\"tool_calls\":[]}}],\"usage\":null}\n\n",
            "data: {\"choices\":[],\"usage\":{\"prompt_tokens\":100,\"completion_tokens\":10,\"total_tokens\":110,\"prompt_tokens_details\":{\"cached_tokens\":80}}}\n\n",
            "data: [DONE]\n\n",
        );

        let mut stream = stream_from_sse(sse).await;

        let mut final_response = None;
        while let Some(chunk) = stream.next().await {
            if let streaming::StreamedAssistantContent::Final(res) = chunk.unwrap() {
                final_response = Some(res);
                break;
            }
        }

        let res = final_response.expect("expected a final response");

        assert_eq!(
            res.usage
                .prompt_tokens_details
                .as_ref()
                .unwrap()
                .cached_tokens,
            80
        );

        let core_usage = res.token_usage().expect("token_usage should return Some");
        assert_eq!(core_usage.cached_input_tokens, 80);
        assert_eq!(core_usage.input_tokens, 100);
        assert_eq!(core_usage.total_tokens, 110);
    }

    #[tokio::test]
    async fn test_duplicate_index_different_id_tool_calls() {
        // Two distinct calls (different id AND name) reuse index 0 back to back.
        let sse = concat!(
            "data: {\"choices\":[{\"delta\":{\"tool_calls\":[{\"index\":0,\"id\":\"call_aaa\",\"function\":{\"name\":\"command\",\"arguments\":\"\"}}]},\"finish_reason\":null}],\"usage\":null}\n\n",
            "data: {\"choices\":[{\"delta\":{\"tool_calls\":[{\"index\":0,\"id\":null,\"function\":{\"name\":null,\"arguments\":\"{\\\"cmd\\\"\"}}]},\"finish_reason\":null}],\"usage\":null}\n\n",
            "data: {\"choices\":[{\"delta\":{\"tool_calls\":[{\"index\":0,\"id\":null,\"function\":{\"name\":null,\"arguments\":\":\\\"ls\\\"}\"}}]},\"finish_reason\":null}],\"usage\":null}\n\n",
            "data: {\"choices\":[{\"delta\":{\"tool_calls\":[{\"index\":0,\"id\":\"call_bbb\",\"function\":{\"name\":\"git\",\"arguments\":\"\"}}]},\"finish_reason\":null}],\"usage\":null}\n\n",
            "data: {\"choices\":[{\"delta\":{\"tool_calls\":[{\"index\":0,\"id\":null,\"function\":{\"name\":null,\"arguments\":\"{\\\"action\\\"\"}}]},\"finish_reason\":null}],\"usage\":null}\n\n",
            "data: {\"choices\":[{\"delta\":{\"tool_calls\":[{\"index\":0,\"id\":null,\"function\":{\"name\":null,\"arguments\":\":\\\"log\\\"}\"}}]},\"finish_reason\":null}],\"usage\":null}\n\n",
            "data: {\"choices\":[{\"delta\":{\"tool_calls\":[]},\"finish_reason\":\"tool_calls\"}],\"usage\":null}\n\n",
            "data: {\"choices\":[],\"usage\":{\"prompt_tokens\":20,\"completion_tokens\":10,\"total_tokens\":30}}\n\n",
            "data: [DONE]\n\n",
        );

        let mut stream = stream_from_sse(sse).await;

        let mut collected_tool_calls = Vec::new();
        while let Some(chunk) = stream.next().await {
            if let streaming::StreamedAssistantContent::ToolCall {
                tool_call,
                internal_call_id: _,
            } = chunk.unwrap()
            {
                collected_tool_calls.push(tool_call);
            }
        }

        assert_eq!(
            collected_tool_calls.len(),
            2,
            "expected 2 separate tool calls, got {collected_tool_calls:?}"
        );

        assert_eq!(collected_tool_calls[0].id, "call_aaa");
        assert_eq!(collected_tool_calls[0].function.name, "command");
        assert_eq!(
            collected_tool_calls[0].function.arguments,
            serde_json::json!({"cmd": "ls"})
        );

        assert_eq!(collected_tool_calls[1].id, "call_bbb");
        assert_eq!(collected_tool_calls[1].function.name, "git");
        assert_eq!(
            collected_tool_calls[1].function.arguments,
            serde_json::json!({"action": "log"})
        );
    }

    #[tokio::test]
    async fn test_unique_id_per_chunk_single_tool_call() {
        // Every fragment carries a fresh id but an empty name: still one call.
        let sse = concat!(
            "data: {\"choices\":[{\"delta\":{\"tool_calls\":[{\"index\":0,\"id\":\"chatcmpl-tool-aaa\",\"function\":{\"name\":\"web_search\",\"arguments\":\"null\"}}]},\"finish_reason\":null}],\"usage\":null}\n\n",
            "data: {\"choices\":[{\"delta\":{\"tool_calls\":[{\"index\":0,\"id\":\"chatcmpl-tool-bbb\",\"function\":{\"name\":\"\",\"arguments\":\"{\\\"query\\\": \\\"META\"}}]},\"finish_reason\":null}],\"usage\":null}\n\n",
            "data: {\"choices\":[{\"delta\":{\"tool_calls\":[{\"index\":0,\"id\":\"chatcmpl-tool-ccc\",\"function\":{\"name\":\"\",\"arguments\":\" Platforms news\\\"}\"}}]},\"finish_reason\":null}],\"usage\":null}\n\n",
            "data: {\"choices\":[{\"delta\":{\"tool_calls\":[]},\"finish_reason\":\"tool_calls\"}],\"usage\":null}\n\n",
            "data: {\"choices\":[],\"usage\":{\"prompt_tokens\":15,\"completion_tokens\":8,\"total_tokens\":23}}\n\n",
            "data: [DONE]\n\n",
        );

        let mut stream = stream_from_sse(sse).await;

        let mut collected_tool_calls = Vec::new();
        while let Some(chunk) = stream.next().await {
            if let streaming::StreamedAssistantContent::ToolCall {
                tool_call,
                internal_call_id: _,
            } = chunk.unwrap()
            {
                collected_tool_calls.push(tool_call);
            }
        }

        assert_eq!(
            collected_tool_calls.len(),
            1,
            "expected 1 tool call (all chunks are fragments of the same call), got {collected_tool_calls:?}"
        );

        assert_eq!(collected_tool_calls[0].function.name, "web_search");
        let args_str = match &collected_tool_calls[0].function.arguments {
            serde_json::Value::String(s) => s.clone(),
            v => v.to_string(),
        };
        assert!(
            args_str.contains("META Platforms news"),
            "expected accumulated arguments containing the full query, got: {args_str}"
        );
    }

    #[tokio::test]
    async fn test_parallel_tool_calls_at_distinct_indexes_are_all_emitted() {
        // Two calls streamed in parallel under different indexes, with their argument
        // fragments arriving out of index order. Sorted by id before asserting so the
        // test does not depend on emission order.
        let sse = concat!(
            "data: {\"choices\":[{\"delta\":{\"tool_calls\":[{\"index\":0,\"id\":\"call_a\",\"function\":{\"name\":\"alpha\",\"arguments\":\"\"}},{\"index\":1,\"id\":\"call_b\",\"function\":{\"name\":\"beta\",\"arguments\":\"\"}}]},\"finish_reason\":null}],\"usage\":null}\n\n",
            "data: {\"choices\":[{\"delta\":{\"tool_calls\":[{\"index\":1,\"id\":null,\"function\":{\"name\":null,\"arguments\":\"{\\\"b\\\":2}\"}}]},\"finish_reason\":null}],\"usage\":null}\n\n",
            "data: {\"choices\":[{\"delta\":{\"tool_calls\":[{\"index\":0,\"id\":null,\"function\":{\"name\":null,\"arguments\":\"{\\\"a\\\":1}\"}}]},\"finish_reason\":null}],\"usage\":null}\n\n",
            "data: {\"choices\":[{\"delta\":{\"tool_calls\":[]},\"finish_reason\":\"tool_calls\"}],\"usage\":null}\n\n",
            "data: [DONE]\n\n",
        );

        let mut stream = stream_from_sse(sse).await;

        let mut collected_tool_calls = Vec::new();
        while let Some(chunk) = stream.next().await {
            if let streaming::StreamedAssistantContent::ToolCall {
                tool_call,
                internal_call_id: _,
            } = chunk.unwrap()
            {
                collected_tool_calls.push(tool_call);
            }
        }

        assert_eq!(
            collected_tool_calls.len(),
            2,
            "expected 2 parallel tool calls, got {collected_tool_calls:?}"
        );
        collected_tool_calls.sort_by(|a, b| a.id.cmp(&b.id));

        assert_eq!(collected_tool_calls[0].id, "call_a");
        assert_eq!(collected_tool_calls[0].function.name, "alpha");
        assert_eq!(
            collected_tool_calls[0].function.arguments,
            serde_json::json!({"a": 1})
        );

        assert_eq!(collected_tool_calls[1].id, "call_b");
        assert_eq!(collected_tool_calls[1].function.name, "beta");
        assert_eq!(
            collected_tool_calls[1].function.arguments,
            serde_json::json!({"b": 2})
        );
    }

    #[tokio::test]
    async fn test_zero_arg_tool_call_normalized_on_finish_reason() {
        let sse = concat!(
            "data: {\"choices\":[{\"delta\":{\"tool_calls\":[{\"index\":0,\"id\":\"call_123\",\"function\":{\"name\":\"ping\",\"arguments\":\"\"}}]},\"finish_reason\":null}],\"usage\":null}\n\n",
            "data: {\"choices\":[{\"delta\":{\"tool_calls\":[]},\"finish_reason\":\"tool_calls\"}],\"usage\":null}\n\n",
            "data: [DONE]\n\n",
        );

        let mut stream = stream_from_sse(sse).await;

        let mut collected_tool_calls = Vec::new();
        while let Some(chunk) = stream.next().await {
            if let streaming::StreamedAssistantContent::ToolCall {
                tool_call,
                internal_call_id: _,
            } = chunk.unwrap()
            {
                collected_tool_calls.push(tool_call);
            }
        }

        assert_eq!(collected_tool_calls.len(), 1);
        assert_eq!(collected_tool_calls[0].id, "call_123");
        assert_eq!(collected_tool_calls[0].function.name, "ping");
        assert_eq!(
            collected_tool_calls[0].function.arguments,
            serde_json::json!({})
        );
    }

    #[tokio::test]
    async fn test_incomplete_zero_arg_tool_call_preserves_null_on_cleanup_flush() {
        // No `finish_reason`: the call is only flushed when the stream ends.
        let sse = "data: {\"choices\":[{\"delta\":{\"tool_calls\":[{\"index\":0,\"id\":\"call_123\",\"function\":{\"name\":\"ping\",\"arguments\":\"\"}}]},\"finish_reason\":null}],\"usage\":null}\n\n";

        let mut stream = stream_from_sse(sse).await;

        let mut collected_tool_calls = Vec::new();
        while let Some(chunk) = stream.next().await {
            if let streaming::StreamedAssistantContent::ToolCall {
                tool_call,
                internal_call_id: _,
            } = chunk.unwrap()
            {
                collected_tool_calls.push(tool_call);
            }
        }

        assert_eq!(collected_tool_calls.len(), 1);
        assert_eq!(collected_tool_calls[0].id, "call_123");
        assert_eq!(collected_tool_calls[0].function.name, "ping");
        assert_eq!(
            collected_tool_calls[0].function.arguments,
            serde_json::Value::Null
        );
    }
}