// rig/providers/openai/completion/streaming.rs
1use std::collections::HashMap;
2
3use async_stream::stream;
4use futures::StreamExt;
5use http::Request;
6use serde::{Deserialize, Serialize};
7use serde_json::json;
8use tracing::{Level, enabled, info_span};
9use tracing_futures::Instrument;
10
11use crate::completion::{CompletionError, CompletionRequest, GetTokenUsage};
12use crate::http_client::HttpClientExt;
13use crate::http_client::sse::{Event, GenericEventSource};
14use crate::json_utils::{self, merge};
15use crate::providers::openai::completion::{CompletionModel, OpenAIRequestParams, Usage};
16use crate::streaming::{self, RawStreamingChoice};
17
18// ================================================================
19// OpenAI Completion Streaming API
20// ================================================================
/// Streamed fragment of a tool call's function payload.
#[derive(Deserialize, Debug)]
pub(crate) struct StreamingFunction {
    /// Function name; absent on argument-only fragments.
    pub(crate) name: Option<String>,
    /// Partial JSON-arguments string, accumulated across chunks.
    pub(crate) arguments: Option<String>,
}
26
/// One streamed tool-call fragment within a delta.
#[derive(Deserialize, Debug)]
pub(crate) struct StreamingToolCall {
    /// Slot index used to correlate fragments of the same tool call.
    pub(crate) index: usize,
    /// Call id; typically only on the first fragment, but some providers
    /// send a fresh id on every chunk (see eviction logic in the stream loop).
    pub(crate) id: Option<String>,
    /// Name/arguments fragment for this chunk.
    pub(crate) function: StreamingFunction,
}
33
/// Incremental payload of a single streamed choice.
#[derive(Deserialize, Debug)]
struct StreamingDelta {
    /// Text fragment of the assistant message, if any.
    #[serde(default)]
    content: Option<String>,
    /// Streamed reasoning/"thinking" tokens.
    #[serde(default)]
    reasoning_content: Option<String>, // This is not part of the official OpenAI API
    /// Partial tool-call fragments; a JSON `null` is tolerated and mapped to
    /// an empty vec by `json_utils::null_or_vec`.
    #[serde(default, deserialize_with = "json_utils::null_or_vec")]
    tool_calls: Vec<StreamingToolCall>,
}
43
/// Why the model stopped emitting tokens for a choice.
#[derive(Deserialize, Debug, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum FinishReason {
    ToolCalls,
    Stop,
    ContentFilter,
    Length,
    /// Catch-all for non-standard values via serde's untagged fallback —
    /// this will handle the deprecated `function_call`, among others.
    #[serde(untagged)]
    Other(String),
}
54
/// A single choice within a streamed chunk.
#[derive(Deserialize, Debug)]
struct StreamingChoice {
    /// Incremental content for this choice.
    delta: StreamingDelta,
    /// Set on the terminal chunk for this choice.
    finish_reason: Option<FinishReason>,
}
60
/// One parsed SSE `data:` payload from the chat-completions stream.
#[derive(Deserialize, Debug)]
struct StreamingCompletionChunk {
    /// May be empty on "usage-only" chunks emitted by some providers.
    choices: Vec<StreamingChoice>,
    /// Token usage; typically present only on the final (or usage-only) chunk.
    usage: Option<Usage>,
}
66
/// Final response surfaced once the stream completes; carries the token
/// usage aggregated from the provider's usage chunk(s).
#[derive(Clone, Serialize, Deserialize)]
pub struct StreamingCompletionResponse {
    /// Provider-reported token usage for the whole streamed completion.
    pub usage: Usage,
}
71
72impl GetTokenUsage for StreamingCompletionResponse {
73    fn token_usage(&self) -> Option<crate::completion::Usage> {
74        let mut usage = crate::completion::Usage::new();
75        usage.input_tokens = self.usage.prompt_tokens as u64;
76        usage.output_tokens = self.usage.total_tokens as u64 - self.usage.prompt_tokens as u64;
77        usage.total_tokens = self.usage.total_tokens as u64;
78        usage.cached_input_tokens = self
79            .usage
80            .prompt_tokens_details
81            .as_ref()
82            .map_or(0, |d| d.cached_tokens as u64);
83        Some(usage)
84    }
85}
86
impl<T> CompletionModel<T>
where
    T: HttpClientExt + Clone + 'static,
{
    /// Starts a streaming chat completion against `/chat/completions`.
    ///
    /// Converts the generic [`CompletionRequest`] into the OpenAI wire format,
    /// forces `stream: true` with `stream_options.include_usage` so the
    /// provider reports token usage in a trailing chunk, and instruments the
    /// request with a `gen_ai.*` tracing span (unless the caller already has
    /// an active span, which is reused instead).
    ///
    /// # Errors
    /// Returns [`CompletionError`] if request conversion, serialization, or
    /// HTTP request construction fails; stream-level errors are yielded by
    /// the returned stream itself.
    pub(crate) async fn stream(
        &self,
        completion_request: CompletionRequest,
    ) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError>
    {
        let request = super::CompletionRequest::try_from(OpenAIRequestParams {
            model: self.model.clone(),
            request: completion_request,
            strict_tools: self.strict_tools,
            tool_result_array_content: self.tool_result_array_content,
        })?;
        // Serialized messages are captured up front for the tracing span below.
        let request_messages = serde_json::to_string(&request.messages)
            .expect("Converting to JSON from a Rust struct shouldn't fail");
        let mut request_as_json = serde_json::to_value(request).expect("this should never fail");

        // Force streaming mode and ask the provider to include usage stats.
        request_as_json = merge(
            request_as_json,
            json!({"stream": true, "stream_options": {"include_usage": true}}),
        );

        // Only pay for pretty-printing when TRACE is actually enabled.
        if enabled!(Level::TRACE) {
            tracing::trace!(
                target: "rig::completions",
                "OpenAI Chat Completions streaming completion request: {}",
                serde_json::to_string_pretty(&request_as_json)?
            );
        }

        let req_body = serde_json::to_vec(&request_as_json)?;

        let req = self
            .client
            .post("/chat/completions")?
            .body(req_body)
            .map_err(|e| CompletionError::HttpError(e.into()))?;

        // Reuse the caller's span when one is active; otherwise open a fresh
        // gen_ai-conventioned span for this chat operation. Usage/output
        // fields start Empty and are recorded when the stream finishes.
        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat",
                gen_ai.operation.name = "chat",
                gen_ai.provider.name = "openai",
                gen_ai.request.model = self.model,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = self.model,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.usage.cached_tokens = tracing::field::Empty,
                gen_ai.input.messages = request_messages,
                gen_ai.output.messages = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        let client = self.client.clone();

        tracing::Instrument::instrument(send_compatible_streaming_request(client, req), span).await
    }
}
151
152pub async fn send_compatible_streaming_request<T>(
153    http_client: T,
154    req: Request<Vec<u8>>,
155) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError>
156where
157    T: HttpClientExt + Clone + 'static,
158{
159    let span = tracing::Span::current();
160    // Build the request with proper headers for SSE
161    let mut event_source = GenericEventSource::new(http_client, req);
162
163    let stream = stream! {
164        let span = tracing::Span::current();
165
166        // Accumulate tool calls by index while streaming
167        let mut tool_calls: HashMap<usize, streaming::RawStreamingToolCall> = HashMap::new();
168        let mut text_content = String::new();
169        let mut final_usage = None;
170
171        while let Some(event_result) = event_source.next().await {
172            match event_result {
173                Ok(Event::Open) => {
174                    tracing::trace!("SSE connection opened");
175                    continue;
176                }
177
178                Ok(Event::Message(message)) => {
179                    if message.data.trim().is_empty() || message.data == "[DONE]" {
180                        continue;
181                    }
182
183                    let data = match serde_json::from_str::<StreamingCompletionChunk>(&message.data) {
184                        Ok(data) => data,
185                        Err(error) => {
186                            tracing::error!(?error, message = message.data, "Failed to parse SSE message");
187                            continue;
188                        }
189                    };
190
191                    // Usage updates (some providers send a final "usage-only" chunk with empty choices)
192                    if let Some(usage) = data.usage {
193                        final_usage = Some(usage);
194                    }
195
196                    // Expect at least one choice
197                     let Some(choice) = data.choices.first() else {
198                        tracing::debug!("There is no choice");
199                        continue;
200                    };
201                    let delta = &choice.delta;
202
203                    if !delta.tool_calls.is_empty() {
204                        for tool_call in &delta.tool_calls {
205                            let index = tool_call.index;
206
207                            // Some API gateways (e.g. LiteLLM, OneAPI) emit multiple
208                            // distinct tool calls all sharing index 0.  Detect this by
209                            // comparing both the `id` and `name`: only evict when a new
210                            // chunk carries a different non-empty id AND a different
211                            // non-empty name.  Checking the name prevents false evictions
212                            // from providers (e.g. GLM-4) that send a unique id on every
213                            // SSE chunk for the same logical tool call.
214                            if let Some(new_id) = &tool_call.id
215                                && !new_id.is_empty()
216                                && let Some(new_name) = &tool_call.function.name
217                                && !new_name.is_empty()
218                                && let Some(existing) = tool_calls.get(&index)
219                                && !existing.id.is_empty()
220                                && existing.id != *new_id
221                                && !existing.name.is_empty()
222                                && existing.name != *new_name
223                            {
224                                let evicted = tool_calls.remove(&index).expect("checked above");
225                                yield Ok(streaming::RawStreamingChoice::ToolCall(
226                                    finalize_completed_streaming_tool_call(evicted),
227                                ));
228                            }
229
230                            let existing_tool_call = tool_calls.entry(index).or_insert_with(streaming::RawStreamingToolCall::empty);
231
232                            if let Some(id) = &tool_call.id && !id.is_empty() {
233                                existing_tool_call.id = id.clone();
234                            }
235
236                            if let Some(name) = &tool_call.function.name && !name.is_empty() {
237                                existing_tool_call.name = name.clone();
238                                yield Ok(streaming::RawStreamingChoice::ToolCallDelta {
239                                    id: existing_tool_call.id.clone(),
240                                    internal_call_id: existing_tool_call.internal_call_id.clone(),
241                                    content: streaming::ToolCallDeltaContent::Name(name.clone()),
242                                });
243                            }
244
245                            // Convert current arguments to string if needed
246                            if let Some(chunk) = &tool_call.function.arguments && !chunk.is_empty() {
247                                let current_args = match &existing_tool_call.arguments {
248                                    serde_json::Value::Null => String::new(),
249                                    serde_json::Value::String(s) => s.clone(),
250                                    v => v.to_string(),
251                                };
252
253                                // Concatenate the new chunk
254                                let combined = format!("{current_args}{chunk}");
255
256                                // Try to parse as JSON if it looks complete
257                                if combined.trim_start().starts_with('{') && combined.trim_end().ends_with('}') {
258                                    match serde_json::from_str(&combined) {
259                                        Ok(parsed) => existing_tool_call.arguments = parsed,
260                                        Err(_) => existing_tool_call.arguments = serde_json::Value::String(combined),
261                                    }
262                                } else {
263                                    existing_tool_call.arguments = serde_json::Value::String(combined);
264                                }
265
266                                // Emit the delta so UI can show progress
267                                yield Ok(streaming::RawStreamingChoice::ToolCallDelta {
268                                    id: existing_tool_call.id.clone(),
269                                    internal_call_id: existing_tool_call.internal_call_id.clone(),
270                                    content: streaming::ToolCallDeltaContent::Delta(chunk.clone()),
271                                });
272                            }
273                        }
274                    }
275
276                    // Streamed reasoning/thinking content (e.g. GLM-4, DeepSeek via compatible endpoint)
277                    if let Some(reasoning) = &delta.reasoning_content && !reasoning.is_empty() {
278                        yield Ok(streaming::RawStreamingChoice::ReasoningDelta {
279                            id: None,
280                            reasoning: reasoning.clone(),
281                        });
282                    }
283
284                    // Streamed text content
285                    if let Some(content) = &delta.content && !content.is_empty() {
286                        text_content += content;
287                        yield Ok(streaming::RawStreamingChoice::Message(content.clone()));
288                    }
289
290                    // Finish reason
291                    if let Some(finish_reason) = &choice.finish_reason && *finish_reason == FinishReason::ToolCalls {
292                        for (_idx, tool_call) in tool_calls.into_iter() {
293                            yield Ok(streaming::RawStreamingChoice::ToolCall(
294                                finalize_completed_streaming_tool_call(tool_call),
295                            ));
296                        }
297                        tool_calls = HashMap::new();
298                    }
299                }
300                Err(crate::http_client::Error::StreamEnded) => {
301                    break;
302                }
303                Err(error) => {
304                    tracing::error!(?error, "SSE error");
305                    yield Err(CompletionError::ProviderError(error.to_string()));
306                    break;
307                }
308            }
309        }
310
311
312        // Ensure event source is closed when stream ends
313        event_source.close();
314
315        // Flush any accumulated tool calls (that weren't emitted as ToolCall earlier)
316        for (_idx, tool_call) in tool_calls.into_iter() {
317            yield Ok(streaming::RawStreamingChoice::ToolCall(tool_call));
318        }
319
320        let final_usage = final_usage.unwrap_or_default();
321        if !span.is_disabled() {
322            span.record("gen_ai.usage.input_tokens", final_usage.prompt_tokens);
323            span.record("gen_ai.usage.output_tokens", final_usage.total_tokens - final_usage.prompt_tokens);
324            span.record(
325                "gen_ai.usage.cached_tokens",
326                final_usage
327                    .prompt_tokens_details
328                    .as_ref()
329                    .map(|d| d.cached_tokens)
330                    .unwrap_or(0),
331            );
332        }
333
334        yield Ok(RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
335            usage: final_usage
336        }));
337    }.instrument(span);
338
339    Ok(streaming::StreamingCompletionResponse::stream(Box::pin(
340        stream,
341    )))
342}
343
344fn finalize_completed_streaming_tool_call(
345    mut tool_call: streaming::RawStreamingToolCall,
346) -> streaming::RawStreamingToolCall {
347    if tool_call.arguments.is_null() {
348        tool_call.arguments = serde_json::Value::Object(serde_json::Map::new());
349    }
350
351    tool_call
352}
353
354#[cfg(test)]
355mod tests {
356    use super::*;
357
    /// A complete function fragment (name + full arguments) deserializes intact.
    #[test]
    fn test_streaming_function_deserialization() {
        let json = r#"{"name": "get_weather", "arguments": "{\"location\":\"Paris\"}"}"#;
        let function: StreamingFunction = serde_json::from_str(json).unwrap();
        assert_eq!(function.name, Some("get_weather".to_string()));
        // Arguments remain the raw JSON string as sent by the provider.
        assert_eq!(
            function.arguments.as_ref().unwrap(),
            r#"{"location":"Paris"}"#
        );
    }
368
    /// A full tool-call fragment (index, id, name, arguments) deserializes.
    #[test]
    fn test_streaming_tool_call_deserialization() {
        let json = r#"{
            "index": 0,
            "id": "call_abc123",
            "function": {
                "name": "get_weather",
                "arguments": "{\"city\":\"London\"}"
            }
        }"#;
        let tool_call: StreamingToolCall = serde_json::from_str(json).unwrap();
        assert_eq!(tool_call.index, 0);
        assert_eq!(tool_call.id, Some("call_abc123".to_string()));
        assert_eq!(tool_call.function.name, Some("get_weather".to_string()));
    }
384
    /// Argument-only fragments (null id/name) must still deserialize.
    #[test]
    fn test_streaming_tool_call_partial_deserialization() {
        // Partial tool calls have no name and partial arguments
        let json = r#"{
            "index": 0,
            "id": null,
            "function": {
                "name": null,
                "arguments": "Paris"
            }
        }"#;
        let tool_call: StreamingToolCall = serde_json::from_str(json).unwrap();
        assert_eq!(tool_call.index, 0);
        assert!(tool_call.id.is_none());
        assert!(tool_call.function.name.is_none());
        assert_eq!(tool_call.function.arguments.as_ref().unwrap(), "Paris");
    }
402
    /// A delta with null content but a populated tool_calls array deserializes.
    #[test]
    fn test_streaming_delta_with_tool_calls() {
        let json = r#"{
            "content": null,
            "tool_calls": [{
                "index": 0,
                "id": "call_xyz",
                "function": {
                    "name": "search",
                    "arguments": ""
                }
            }]
        }"#;
        let delta: StreamingDelta = serde_json::from_str(json).unwrap();
        assert!(delta.content.is_none());
        assert_eq!(delta.tool_calls.len(), 1);
        assert_eq!(delta.tool_calls[0].id, Some("call_xyz".to_string()));
    }
421
    /// A chunk with both a content delta and a usage block deserializes.
    /// Note `tool_calls` is required in the delta here (it has no fixture
    /// default in this payload) but may be an empty array.
    #[test]
    fn test_streaming_chunk_deserialization() {
        let json = r#"{
            "choices": [{
                "delta": {
                    "content": "Hello",
                    "tool_calls": []
                }
            }],
            "usage": {
                "prompt_tokens": 10,
                "completion_tokens": 5,
                "total_tokens": 15
            }
        }"#;
        let chunk: StreamingCompletionChunk = serde_json::from_str(json).unwrap();
        assert_eq!(chunk.choices.len(), 1);
        assert_eq!(chunk.choices[0].delta.content, Some("Hello".to_string()));
        assert!(chunk.usage.is_some());
    }
442
    /// Each partial chunk of a multi-chunk tool call (start, then two
    /// argument fragments) deserializes independently; accumulation is
    /// exercised by the end-to-end streaming tests below.
    #[test]
    fn test_streaming_chunk_with_multiple_tool_call_deltas() {
        // Simulates multiple partial tool call chunks arriving
        let json_start = r#"{
            "choices": [{
                "delta": {
                    "content": null,
                    "tool_calls": [{
                        "index": 0,
                        "id": "call_123",
                        "function": {
                            "name": "get_weather",
                            "arguments": ""
                        }
                    }]
                }
            }],
            "usage": null
        }"#;

        let json_chunk1 = r#"{
            "choices": [{
                "delta": {
                    "content": null,
                    "tool_calls": [{
                        "index": 0,
                        "id": null,
                        "function": {
                            "name": null,
                            "arguments": "{\"loc"
                        }
                    }]
                }
            }],
            "usage": null
        }"#;

        let json_chunk2 = r#"{
            "choices": [{
                "delta": {
                    "content": null,
                    "tool_calls": [{
                        "index": 0,
                        "id": null,
                        "function": {
                            "name": null,
                            "arguments": "ation\":\"NYC\"}"
                        }
                    }]
                }
            }],
            "usage": null
        }"#;

        // Verify each chunk deserializes correctly
        let start_chunk: StreamingCompletionChunk = serde_json::from_str(json_start).unwrap();
        assert_eq!(start_chunk.choices[0].delta.tool_calls.len(), 1);
        assert_eq!(
            start_chunk.choices[0].delta.tool_calls[0]
                .function
                .name
                .as_ref()
                .unwrap(),
            "get_weather"
        );

        let chunk1: StreamingCompletionChunk = serde_json::from_str(json_chunk1).unwrap();
        assert_eq!(chunk1.choices[0].delta.tool_calls.len(), 1);
        assert_eq!(
            chunk1.choices[0].delta.tool_calls[0]
                .function
                .arguments
                .as_ref()
                .unwrap(),
            "{\"loc"
        );

        let chunk2: StreamingCompletionChunk = serde_json::from_str(json_chunk2).unwrap();
        assert_eq!(chunk2.choices[0].delta.tool_calls.len(), 1);
        assert_eq!(
            chunk2.choices[0].delta.tool_calls[0]
                .function
                .arguments
                .as_ref()
                .unwrap(),
            "ation\":\"NYC\"}"
        );
    }
531
    /// A trailing "usage-only" chunk (empty `choices`) must still populate
    /// the final response's usage.
    #[tokio::test]
    async fn test_streaming_usage_only_chunk_is_not_ignored() {
        use crate::http_client::mock::MockStreamingClient;
        use bytes::Bytes;
        use futures::StreamExt;

        // Some providers emit a final "usage-only" chunk where `choices` is empty.
        let sse = concat!(
            "data: {\"choices\":[{\"delta\":{\"content\":\"Hello\",\"tool_calls\":[]}}],\"usage\":null}\n\n",
            "data: {\"choices\":[],\"usage\":{\"prompt_tokens\":10,\"completion_tokens\":5,\"total_tokens\":15}}\n\n",
            "data: [DONE]\n\n",
        );

        let client = MockStreamingClient {
            sse_bytes: Bytes::from(sse),
        };

        let req = http::Request::builder()
            .method("POST")
            .uri("http://localhost/v1/chat/completions")
            .body(Vec::new())
            .unwrap();

        let mut stream = send_compatible_streaming_request(client, req)
            .await
            .unwrap();

        // Drain the stream until the final response arrives.
        let mut final_usage = None;
        while let Some(chunk) = stream.next().await {
            if let streaming::StreamedAssistantContent::Final(res) = chunk.unwrap() {
                final_usage = Some(res.usage);
                break;
            }
        }

        let usage = final_usage.expect("expected a final response with usage");
        assert_eq!(usage.prompt_tokens, 10);
        assert_eq!(usage.total_tokens, 15);
    }
571
    /// `prompt_tokens_details.cached_tokens` must flow through both the
    /// provider-level usage and the core `GetTokenUsage` mapping.
    #[tokio::test]
    async fn test_streaming_cached_input_tokens_populated() {
        use crate::http_client::mock::MockStreamingClient;
        use bytes::Bytes;
        use futures::StreamExt;

        // Usage chunk includes prompt_tokens_details with cached_tokens.
        let sse = concat!(
            "data: {\"choices\":[{\"delta\":{\"content\":\"Hi\",\"tool_calls\":[]}}],\"usage\":null}\n\n",
            "data: {\"choices\":[],\"usage\":{\"prompt_tokens\":100,\"completion_tokens\":10,\"total_tokens\":110,\"prompt_tokens_details\":{\"cached_tokens\":80}}}\n\n",
            "data: [DONE]\n\n",
        );

        let client = MockStreamingClient {
            sse_bytes: Bytes::from(sse),
        };

        let req = http::Request::builder()
            .method("POST")
            .uri("http://localhost/v1/chat/completions")
            .body(Vec::new())
            .unwrap();

        let mut stream = send_compatible_streaming_request(client, req)
            .await
            .unwrap();

        // Drain the stream until the final response arrives.
        let mut final_response = None;
        while let Some(chunk) = stream.next().await {
            if let streaming::StreamedAssistantContent::Final(res) = chunk.unwrap() {
                final_response = Some(res);
                break;
            }
        }

        let res = final_response.expect("expected a final response");

        // Verify provider-level usage has the cached_tokens
        assert_eq!(
            res.usage
                .prompt_tokens_details
                .as_ref()
                .unwrap()
                .cached_tokens,
            80
        );

        // Verify core Usage also has cached_input_tokens via GetTokenUsage
        let core_usage = res.token_usage().expect("token_usage should return Some");
        assert_eq!(core_usage.cached_input_tokens, 80);
        assert_eq!(core_usage.input_tokens, 100);
        assert_eq!(core_usage.total_tokens, 110);
    }
625
    /// Reproduces the bug where a proxy/gateway sends multiple parallel tool
    /// calls all sharing `index: 0` but with distinct `id` values.  Without
    /// the fix, rig merges both calls into one corrupted entry.
    #[tokio::test]
    async fn test_duplicate_index_different_id_tool_calls() {
        use crate::http_client::mock::MockStreamingClient;
        use bytes::Bytes;
        use futures::StreamExt;

        // Simulate a gateway that sends two tool calls both at index 0.
        // First tool call: id="call_aaa", name="command", args={"cmd":"ls"}
        // Second tool call: id="call_bbb", name="git", args={"action":"log"}
        let sse = concat!(
            // First tool call starts
            "data: {\"choices\":[{\"delta\":{\"tool_calls\":[{\"index\":0,\"id\":\"call_aaa\",\"function\":{\"name\":\"command\",\"arguments\":\"\"}}]},\"finish_reason\":null}],\"usage\":null}\n\n",
            // First tool call argument chunks
            "data: {\"choices\":[{\"delta\":{\"tool_calls\":[{\"index\":0,\"id\":null,\"function\":{\"name\":null,\"arguments\":\"{\\\"cmd\\\"\"}}]},\"finish_reason\":null}],\"usage\":null}\n\n",
            "data: {\"choices\":[{\"delta\":{\"tool_calls\":[{\"index\":0,\"id\":null,\"function\":{\"name\":null,\"arguments\":\":\\\"ls\\\"}\"}}]},\"finish_reason\":null}],\"usage\":null}\n\n",
            // Second tool call starts AT THE SAME index 0 but with a NEW id
            "data: {\"choices\":[{\"delta\":{\"tool_calls\":[{\"index\":0,\"id\":\"call_bbb\",\"function\":{\"name\":\"git\",\"arguments\":\"\"}}]},\"finish_reason\":null}],\"usage\":null}\n\n",
            // Second tool call argument chunks
            "data: {\"choices\":[{\"delta\":{\"tool_calls\":[{\"index\":0,\"id\":null,\"function\":{\"name\":null,\"arguments\":\"{\\\"action\\\"\"}}]},\"finish_reason\":null}],\"usage\":null}\n\n",
            "data: {\"choices\":[{\"delta\":{\"tool_calls\":[{\"index\":0,\"id\":null,\"function\":{\"name\":null,\"arguments\":\":\\\"log\\\"}\"}}]},\"finish_reason\":null}],\"usage\":null}\n\n",
            // Finish with tool_calls reason
            "data: {\"choices\":[{\"delta\":{\"tool_calls\":[]},\"finish_reason\":\"tool_calls\"}],\"usage\":null}\n\n",
            // Usage chunk
            "data: {\"choices\":[],\"usage\":{\"prompt_tokens\":20,\"completion_tokens\":10,\"total_tokens\":30}}\n\n",
            "data: [DONE]\n\n",
        );

        let client = MockStreamingClient {
            sse_bytes: Bytes::from(sse),
        };

        let req = http::Request::builder()
            .method("POST")
            .uri("http://localhost/v1/chat/completions")
            .body(Vec::new())
            .unwrap();

        let mut stream = send_compatible_streaming_request(client, req)
            .await
            .unwrap();

        // Collect only the completed ToolCall events (deltas are ignored).
        let mut collected_tool_calls = Vec::new();
        while let Some(chunk) = stream.next().await {
            if let streaming::StreamedAssistantContent::ToolCall {
                tool_call,
                internal_call_id: _,
            } = chunk.unwrap()
            {
                collected_tool_calls.push(tool_call);
            }
        }

        assert_eq!(
            collected_tool_calls.len(),
            2,
            "expected 2 separate tool calls, got {collected_tool_calls:?}"
        );

        // The first call is emitted at eviction time, the second at finish.
        assert_eq!(collected_tool_calls[0].id, "call_aaa");
        assert_eq!(collected_tool_calls[0].function.name, "command");
        assert_eq!(
            collected_tool_calls[0].function.arguments,
            serde_json::json!({"cmd": "ls"})
        );

        assert_eq!(collected_tool_calls[1].id, "call_bbb");
        assert_eq!(collected_tool_calls[1].function.name, "git");
        assert_eq!(
            collected_tool_calls[1].function.arguments,
            serde_json::json!({"action": "log"})
        );
    }
701
    /// Reproduces the bug where a provider (e.g. GLM-4 via OpenAI-compatible
    /// endpoint) sends a unique `id` on every SSE delta chunk for the same
    /// logical tool call.  Without the fix, each chunk triggers an eviction,
    /// yielding incomplete fragments as "completed" tool calls.
    #[tokio::test]
    async fn test_unique_id_per_chunk_single_tool_call() {
        use crate::http_client::mock::MockStreamingClient;
        use bytes::Bytes;
        use futures::StreamExt;

        // Each chunk carries a different id but they all represent delta
        // fragments of the SAME tool call at index 0.
        let sse = concat!(
            "data: {\"choices\":[{\"delta\":{\"tool_calls\":[{\"index\":0,\"id\":\"chatcmpl-tool-aaa\",\"function\":{\"name\":\"web_search\",\"arguments\":\"null\"}}]},\"finish_reason\":null}],\"usage\":null}\n\n",
            "data: {\"choices\":[{\"delta\":{\"tool_calls\":[{\"index\":0,\"id\":\"chatcmpl-tool-bbb\",\"function\":{\"name\":\"\",\"arguments\":\"{\\\"query\\\": \\\"META\"}}]},\"finish_reason\":null}],\"usage\":null}\n\n",
            "data: {\"choices\":[{\"delta\":{\"tool_calls\":[{\"index\":0,\"id\":\"chatcmpl-tool-ccc\",\"function\":{\"name\":\"\",\"arguments\":\" Platforms news\\\"}\"}}]},\"finish_reason\":null}],\"usage\":null}\n\n",
            "data: {\"choices\":[{\"delta\":{\"tool_calls\":[]},\"finish_reason\":\"tool_calls\"}],\"usage\":null}\n\n",
            "data: {\"choices\":[],\"usage\":{\"prompt_tokens\":15,\"completion_tokens\":8,\"total_tokens\":23}}\n\n",
            "data: [DONE]\n\n",
        );

        let client = MockStreamingClient {
            sse_bytes: Bytes::from(sse),
        };

        let req = http::Request::builder()
            .method("POST")
            .uri("http://localhost/v1/chat/completions")
            .body(Vec::new())
            .unwrap();

        let mut stream = send_compatible_streaming_request(client, req)
            .await
            .unwrap();

        // Collect only the completed ToolCall events (deltas are ignored).
        let mut collected_tool_calls = Vec::new();
        while let Some(chunk) = stream.next().await {
            if let streaming::StreamedAssistantContent::ToolCall {
                tool_call,
                internal_call_id: _,
            } = chunk.unwrap()
            {
                collected_tool_calls.push(tool_call);
            }
        }

        assert_eq!(
            collected_tool_calls.len(),
            1,
            "expected 1 tool call (all chunks are fragments of the same call), got {collected_tool_calls:?}"
        );

        assert_eq!(collected_tool_calls[0].function.name, "web_search");
        // The arguments should be the fully accumulated string, not fragments
        let args_str = match &collected_tool_calls[0].function.arguments {
            serde_json::Value::String(s) => s.clone(),
            v => v.to_string(),
        };
        assert!(
            args_str.contains("META Platforms news"),
            "expected accumulated arguments containing the full query, got: {args_str}"
        );
    }
765
766    #[tokio::test]
767    async fn test_zero_arg_tool_call_normalized_on_finish_reason() {
768        use crate::http_client::mock::MockStreamingClient;
769        use bytes::Bytes;
770        use futures::StreamExt;
771
772        let sse = concat!(
773            "data: {\"choices\":[{\"delta\":{\"tool_calls\":[{\"index\":0,\"id\":\"call_123\",\"function\":{\"name\":\"ping\",\"arguments\":\"\"}}]},\"finish_reason\":null}],\"usage\":null}\n\n",
774            "data: {\"choices\":[{\"delta\":{\"tool_calls\":[]},\"finish_reason\":\"tool_calls\"}],\"usage\":null}\n\n",
775            "data: [DONE]\n\n",
776        );
777
778        let client = MockStreamingClient {
779            sse_bytes: Bytes::from(sse),
780        };
781
782        let req = http::Request::builder()
783            .method("POST")
784            .uri("http://localhost/v1/chat/completions")
785            .body(Vec::new())
786            .unwrap();
787
788        let mut stream = send_compatible_streaming_request(client, req)
789            .await
790            .unwrap();
791
792        let mut collected_tool_calls = Vec::new();
793        while let Some(chunk) = stream.next().await {
794            if let streaming::StreamedAssistantContent::ToolCall {
795                tool_call,
796                internal_call_id: _,
797            } = chunk.unwrap()
798            {
799                collected_tool_calls.push(tool_call);
800            }
801        }
802
803        assert_eq!(collected_tool_calls.len(), 1);
804        assert_eq!(collected_tool_calls[0].id, "call_123");
805        assert_eq!(collected_tool_calls[0].function.name, "ping");
806        assert_eq!(
807            collected_tool_calls[0].function.arguments,
808            serde_json::json!({})
809        );
810    }
811
812    #[tokio::test]
813    async fn test_incomplete_zero_arg_tool_call_preserves_null_on_cleanup_flush() {
814        use crate::http_client::mock::MockStreamingClient;
815        use bytes::Bytes;
816        use futures::StreamExt;
817
818        let sse = "data: {\"choices\":[{\"delta\":{\"tool_calls\":[{\"index\":0,\"id\":\"call_123\",\"function\":{\"name\":\"ping\",\"arguments\":\"\"}}]},\"finish_reason\":null}],\"usage\":null}\n\n";
819
820        let client = MockStreamingClient {
821            sse_bytes: Bytes::from(sse),
822        };
823
824        let req = http::Request::builder()
825            .method("POST")
826            .uri("http://localhost/v1/chat/completions")
827            .body(Vec::new())
828            .unwrap();
829
830        let mut stream = send_compatible_streaming_request(client, req)
831            .await
832            .unwrap();
833
834        let mut collected_tool_calls = Vec::new();
835        while let Some(chunk) = stream.next().await {
836            if let streaming::StreamedAssistantContent::ToolCall {
837                tool_call,
838                internal_call_id: _,
839            } = chunk.unwrap()
840            {
841                collected_tool_calls.push(tool_call);
842            }
843        }
844
845        assert_eq!(collected_tool_calls.len(), 1);
846        assert_eq!(collected_tool_calls[0].id, "call_123");
847        assert_eq!(collected_tool_calls[0].function.name, "ping");
848        assert_eq!(
849            collected_tool_calls[0].function.arguments,
850            serde_json::Value::Null
851        );
852    }
853}