rig/providers/anthropic/streaming.rs


use async_stream::stream;
use futures::StreamExt;
use serde::{Deserialize, Serialize};
use serde_json::json;
use tracing::{Level, enabled, info_span};
use tracing_futures::Instrument;

use super::completion::{
    CompletionModel, Content, Message, SystemContent, ToolChoice, ToolDefinition, Usage,
    apply_cache_control,
};
use crate::completion::{CompletionError, CompletionRequest, GetTokenUsage};
use crate::http_client::sse::{Event, GenericEventSource};
use crate::http_client::{self, HttpClientExt};
use crate::json_utils::merge_inplace;
use crate::message::ReasoningContent;
use crate::streaming::{
    self, RawStreamingChoice, RawStreamingToolCall, StreamingResult, ToolCallDeltaContent,
};
use crate::telemetry::SpanCombinator;

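/// Server-sent events emitted by the Anthropic Messages streaming API, tagged by their
/// `type` field. Event types this module does not recognize deserialize as `Unknown`
/// and are ignored.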
#[derive(Debug, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum StreamingEvent {
    MessageStart {
        message: MessageStart,
    },
    ContentBlockStart {
        index: usize,
        content_block: Content,
    },
    ContentBlockDelta {
        index: usize,
        delta: ContentDelta,
    },
    ContentBlockStop {
        index: usize,
    },
    MessageDelta {
        delta: MessageDelta,
        usage: PartialUsage,
    },
    MessageStop,
    Ping,
    #[serde(other)]
    Unknown,
}

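/// Payload of a `message_start` event: the initial message metadata, including the
/// input token count that is later folded into the final usage report.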
#[derive(Debug, Deserialize)]
pub struct MessageStart {
    pub id: String,
    pub role: String,
    pub content: Vec<Content>,
    pub model: String,
    pub stop_reason: Option<String>,
    pub stop_sequence: Option<String>,
    pub usage: Usage,
}

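/// Incremental deltas carried by `content_block_delta` events: assistant text, partial
/// tool-input JSON, thinking text, or a thinking signature.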
#[derive(Debug, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ContentDelta {
    TextDelta { text: String },
    InputJsonDelta { partial_json: String },
    ThinkingDelta { thinking: String },
    SignatureDelta { signature: String },
}

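/// Top-level message delta; a populated `stop_reason` marks the end of the response.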
#[derive(Debug, Deserialize)]
pub struct MessageDelta {
    pub stop_reason: Option<String>,
    pub stop_sequence: Option<String>,
}

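/// Token usage reported alongside `message_delta` events. `input_tokens` is optional in
/// the wire format and is backfilled here from the count seen in `message_start`.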
#[derive(Debug, Deserialize, Clone, Serialize, Default)]
pub struct PartialUsage {
    pub output_tokens: usize,
    #[serde(default)]
    pub input_tokens: Option<usize>,
}

impl GetTokenUsage for PartialUsage {
    fn token_usage(&self) -> Option<crate::completion::Usage> {
        let mut usage = crate::completion::Usage::new();

        usage.input_tokens = self.input_tokens.unwrap_or_default() as u64;
        usage.output_tokens = self.output_tokens as u64;
        usage.total_tokens = usage.input_tokens + usage.output_tokens;
        Some(usage)
    }
}

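/// Accumulates a streamed tool call (name, ids, and partial input JSON) across
/// `content_block_delta` events until the matching `content_block_stop` arrives.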
#[derive(Default)]
struct ToolCallState {
    name: String,
    id: String,
    internal_call_id: String,
    input_json: String,
}

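/// Accumulates streamed thinking text and its signature until the content block stops.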
#[derive(Default)]
struct ThinkingState {
    thinking: String,
    signature: String,
}

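/// Final payload yielded when the stream ends, carrying the token usage observed
/// while streaming.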
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct StreamingCompletionResponse {
    pub usage: PartialUsage,
}

impl GetTokenUsage for StreamingCompletionResponse {
    fn token_usage(&self) -> Option<crate::completion::Usage> {
        let mut usage = crate::completion::Usage::new();
        usage.input_tokens = self.usage.input_tokens.unwrap_or(0) as u64;
        usage.output_tokens = self.usage.output_tokens as u64;
        usage.total_tokens =
            self.usage.input_tokens.unwrap_or(0) as u64 + self.usage.output_tokens as u64;

        Some(usage)
    }
}

impl<T> CompletionModel<T>
where
    T: HttpClientExt + Clone + Default + 'static,
{
    pub(crate) async fn stream(
        &self,
        completion_request: CompletionRequest,
    ) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError>
    {
        let request_model = completion_request
            .model
            .clone()
            .unwrap_or_else(|| self.model.clone());
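        // Reuse the caller's span when one is already active; otherwise open a dedicated
        // `chat_streaming` span with the usual gen_ai fields.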
        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat_streaming",
                gen_ai.operation.name = "chat_streaming",
                gen_ai.provider.name = "anthropic",
                gen_ai.request.model = &request_model,
                gen_ai.system_instructions = &completion_request.preamble,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = &request_model,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.input.messages = tracing::field::Empty,
                gen_ai.output.messages = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };
        let max_tokens = if let Some(tokens) = completion_request.max_tokens {
            tokens
        } else if let Some(tokens) = self.default_max_tokens {
            tokens
        } else {
            return Err(CompletionError::RequestError(
                "`max_tokens` must be set for Anthropic".into(),
            ));
        };

        let mut full_history = vec![];
        if let Some(docs) = completion_request.normalized_documents() {
            full_history.push(docs);
        }
        full_history.extend(completion_request.chat_history);

        let mut messages = full_history
            .into_iter()
            .map(Message::try_from)
            .collect::<Result<Vec<Message>, _>>()?;

        // Convert system prompt to array format for cache_control support
        let mut system: Vec<SystemContent> =
            if let Some(preamble) = completion_request.preamble.as_ref() {
                if preamble.is_empty() {
                    vec![]
                } else {
                    vec![SystemContent::Text {
                        text: preamble.clone(),
                        cache_control: None,
                    }]
                }
            } else {
                vec![]
            };

        // Apply cache control breakpoints only if prompt_caching is enabled
        if self.prompt_caching {
            apply_cache_control(&mut system, &mut messages);
        }

        let mut body = json!({
            "model": request_model,
            "messages": messages,
            "max_tokens": max_tokens,
            "stream": true,
        });

        // Add system prompt if non-empty
        if !system.is_empty() {
            merge_inplace(&mut body, json!({ "system": system }));
        }

        if let Some(temperature) = completion_request.temperature {
            merge_inplace(&mut body, json!({ "temperature": temperature }));
        }

        if !completion_request.tools.is_empty() {
            merge_inplace(
                &mut body,
                json!({
                    "tools": completion_request
                        .tools
                        .into_iter()
                        .map(|tool| ToolDefinition {
                            name: tool.name,
                            description: Some(tool.description),
                            input_schema: tool.parameters,
                        })
                        .collect::<Vec<_>>(),
                    "tool_choice": ToolChoice::Auto,
                }),
            );
        }

        if let Some(ref params) = completion_request.additional_params {
            merge_inplace(&mut body, params.clone())
        }

        if enabled!(Level::TRACE) {
            tracing::trace!(
                target: "rig::completions",
                "Anthropic completion request: {}",
                serde_json::to_string_pretty(&body)?
            );
        }

        let body: Vec<u8> = serde_json::to_vec(&body)?;

        let req = self
            .client
            .post("/v1/messages")?
            .body(body)
            .map_err(http_client::Error::Protocol)?;

        let stream = GenericEventSource::new(self.client.clone(), req);

        // Use our SSE decoder to directly handle Server-Sent Events format
        let stream: StreamingResult<StreamingCompletionResponse> = Box::pin(stream! {
            let mut current_tool_call: Option<ToolCallState> = None;
            let mut current_thinking: Option<ThinkingState> = None;
            let mut sse_stream = Box::pin(stream);
            let mut input_tokens = 0;
            let mut final_usage = None;

            let mut text_content = String::new();

            while let Some(sse_result) = sse_stream.next().await {
                match sse_result {
                    Ok(Event::Open) => {}
                    Ok(Event::Message(sse)) => {
                        // Parse the SSE data as a StreamingEvent
                        match serde_json::from_str::<StreamingEvent>(&sse.data) {
                            Ok(event) => {
                                match &event {
                                    StreamingEvent::MessageStart { message } => {
                                        input_tokens = message.usage.input_tokens;

                                        let span = tracing::Span::current();
                                        span.record("gen_ai.response.id", &message.id);
                                        span.record("gen_ai.response.model", &message.model);
                                    },
                                    StreamingEvent::MessageDelta { delta, usage } => {
                                        if delta.stop_reason.is_some() {
                                            let usage = PartialUsage {
                                                output_tokens: usage.output_tokens,
                                                input_tokens: Some(input_tokens.try_into().expect("Failed to convert input_tokens to usize")),
                                            };

                                            let span = tracing::Span::current();
                                            span.record_token_usage(&usage);
                                            final_usage = Some(usage);
                                            break;
                                        }
                                    }
                                    _ => {}
                                }

                                if let Some(result) = handle_event(&event, &mut current_tool_call, &mut current_thinking) {
                                    if let Ok(RawStreamingChoice::Message(ref text)) = result {
                                        text_content += text;
                                    }
                                    yield result;
                                }
                            },
                            Err(e) => {
                                if !sse.data.trim().is_empty() {
                                    yield Err(CompletionError::ResponseError(
                                        format!("Failed to parse JSON: {} (Data: {})", e, sse.data)
                                    ));
                                }
                            }
                        }
                    },
                    Err(e) => {
                        yield Err(CompletionError::ProviderError(format!("SSE Error: {e}")));
                        break;
                    }
                }
            }

            // Ensure event source is closed when stream ends
            sse_stream.close();

            yield Ok(RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
                usage: final_usage.unwrap_or_default()
            }))
        }.instrument(span));

        Ok(streaming::StreamingCompletionResponse::stream(stream))
    }
}

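/// Maps a single streaming event to at most one streaming choice, updating the
/// in-progress tool-call and thinking state as it goes.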
fn handle_event(
    event: &StreamingEvent,
    current_tool_call: &mut Option<ToolCallState>,
    current_thinking: &mut Option<ThinkingState>,
) -> Option<Result<RawStreamingChoice<StreamingCompletionResponse>, CompletionError>> {
    match event {
        StreamingEvent::ContentBlockDelta { delta, .. } => match delta {
            ContentDelta::TextDelta { text } => {
                if current_tool_call.is_none() {
                    return Some(Ok(RawStreamingChoice::Message(text.clone())));
                }
                None
            }
            ContentDelta::InputJsonDelta { partial_json } => {
                if let Some(tool_call) = current_tool_call {
                    tool_call.input_json.push_str(partial_json);
                    // Emit the delta so UI can show progress
                    return Some(Ok(RawStreamingChoice::ToolCallDelta {
                        id: tool_call.id.clone(),
                        internal_call_id: tool_call.internal_call_id.clone(),
                        content: ToolCallDeltaContent::Delta(partial_json.clone()),
                    }));
                }
                None
            }
            ContentDelta::ThinkingDelta { thinking } => {
                current_thinking
                    .get_or_insert_with(ThinkingState::default)
                    .thinking
                    .push_str(thinking);

                Some(Ok(RawStreamingChoice::ReasoningDelta {
                    id: None,
                    reasoning: thinking.clone(),
                }))
            }
            ContentDelta::SignatureDelta { signature } => {
                current_thinking
                    .get_or_insert_with(ThinkingState::default)
                    .signature
                    .push_str(signature);

                // Don't yield signature chunks, they will be included in the final Reasoning
                None
            }
        },
        StreamingEvent::ContentBlockStart { content_block, .. } => match content_block {
            Content::ToolUse { id, name, .. } => {
                let internal_call_id = nanoid::nanoid!();
                *current_tool_call = Some(ToolCallState {
                    name: name.clone(),
                    id: id.clone(),
                    internal_call_id: internal_call_id.clone(),
                    input_json: String::new(),
                });
                Some(Ok(RawStreamingChoice::ToolCallDelta {
                    id: id.clone(),
                    internal_call_id,
                    content: ToolCallDeltaContent::Name(name.clone()),
                }))
            }
            Content::Thinking { .. } => {
                *current_thinking = Some(ThinkingState::default());
                None
            }
            Content::RedactedThinking { data } => Some(Ok(RawStreamingChoice::Reasoning {
                id: None,
                content: ReasoningContent::Redacted { data: data.clone() },
            })),
            // Handle other content types - they don't need special handling
            _ => None,
        },
        StreamingEvent::ContentBlockStop { .. } => {
            if let Some(thinking_state) = Option::take(current_thinking)
                && !thinking_state.thinking.is_empty()
            {
                let signature = if thinking_state.signature.is_empty() {
                    None
                } else {
                    Some(thinking_state.signature)
                };

                return Some(Ok(RawStreamingChoice::Reasoning {
                    id: None,
                    content: ReasoningContent::Text {
                        text: thinking_state.thinking,
                        signature,
                    },
                }));
            }

            if let Some(tool_call) = Option::take(current_tool_call) {
                let json_str = if tool_call.input_json.is_empty() {
                    "{}"
                } else {
                    &tool_call.input_json
                };
                match serde_json::from_str(json_str) {
                    Ok(json_value) => {
                        let raw_tool_call =
                            RawStreamingToolCall::new(tool_call.id, tool_call.name, json_value)
                                .with_internal_call_id(tool_call.internal_call_id);
                        Some(Ok(RawStreamingChoice::ToolCall(raw_tool_call)))
                    }
                    Err(e) => Some(Err(CompletionError::from(e))),
                }
            } else {
                None
            }
        }
        // Ignore other event types or handle as needed
        StreamingEvent::MessageStart { .. }
        | StreamingEvent::MessageDelta { .. }
        | StreamingEvent::MessageStop
        | StreamingEvent::Ping
        | StreamingEvent::Unknown => None,
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_thinking_delta_deserialization() {
        let json = r#"{"type": "thinking_delta", "thinking": "Let me think about this..."}"#;
        let delta: ContentDelta = serde_json::from_str(json).unwrap();

        match delta {
            ContentDelta::ThinkingDelta { thinking } => {
                assert_eq!(thinking, "Let me think about this...");
            }
            _ => panic!("Expected ThinkingDelta variant"),
        }
    }

    #[test]
    fn test_signature_delta_deserialization() {
        let json = r#"{"type": "signature_delta", "signature": "abc123def456"}"#;
        let delta: ContentDelta = serde_json::from_str(json).unwrap();

        match delta {
            ContentDelta::SignatureDelta { signature } => {
                assert_eq!(signature, "abc123def456");
            }
            _ => panic!("Expected SignatureDelta variant"),
        }
    }

    #[test]
    fn test_thinking_delta_streaming_event_deserialization() {
        let json = r#"{
            "type": "content_block_delta",
            "index": 0,
            "delta": {
                "type": "thinking_delta",
                "thinking": "First, I need to understand the problem."
            }
        }"#;

        let event: StreamingEvent = serde_json::from_str(json).unwrap();

        match event {
            StreamingEvent::ContentBlockDelta { index, delta } => {
                assert_eq!(index, 0);
                match delta {
                    ContentDelta::ThinkingDelta { thinking } => {
                        assert_eq!(thinking, "First, I need to understand the problem.");
                    }
                    _ => panic!("Expected ThinkingDelta"),
                }
            }
            _ => panic!("Expected ContentBlockDelta event"),
        }
    }

    #[test]
    fn test_signature_delta_streaming_event_deserialization() {
        let json = r#"{
            "type": "content_block_delta",
            "index": 0,
            "delta": {
                "type": "signature_delta",
                "signature": "ErUBCkYICBgCIkCaGbqC85F4"
            }
        }"#;

        let event: StreamingEvent = serde_json::from_str(json).unwrap();

        match event {
            StreamingEvent::ContentBlockDelta { index, delta } => {
                assert_eq!(index, 0);
                match delta {
                    ContentDelta::SignatureDelta { signature } => {
                        assert_eq!(signature, "ErUBCkYICBgCIkCaGbqC85F4");
                    }
                    _ => panic!("Expected SignatureDelta"),
                }
            }
            _ => panic!("Expected ContentBlockDelta event"),
        }
    }

    #[test]
    fn test_handle_thinking_delta_event() {
        let event = StreamingEvent::ContentBlockDelta {
            index: 0,
            delta: ContentDelta::ThinkingDelta {
                thinking: "Analyzing the request...".to_string(),
            },
        };

        let mut tool_call_state = None;
        let mut thinking_state = None;
        let result = handle_event(&event, &mut tool_call_state, &mut thinking_state);

        assert!(result.is_some());
        let choice = result.unwrap().unwrap();

        match choice {
            RawStreamingChoice::ReasoningDelta { id, reasoning, .. } => {
                assert_eq!(id, None);
                assert_eq!(reasoning, "Analyzing the request...");
            }
            _ => panic!("Expected ReasoningDelta choice"),
        }

        // Verify thinking state was updated
        assert!(thinking_state.is_some());
        assert_eq!(thinking_state.unwrap().thinking, "Analyzing the request...");
    }

    #[test]
    fn test_handle_signature_delta_event() {
        let event = StreamingEvent::ContentBlockDelta {
            index: 0,
            delta: ContentDelta::SignatureDelta {
                signature: "test_signature".to_string(),
            },
        };

        let mut tool_call_state = None;
        let mut thinking_state = None;
        let result = handle_event(&event, &mut tool_call_state, &mut thinking_state);

        // SignatureDelta should not yield anything (returns None)
        assert!(result.is_none());

        // But signature should be captured in thinking state
        assert!(thinking_state.is_some());
        assert_eq!(thinking_state.unwrap().signature, "test_signature");
    }

    #[test]
    fn test_handle_redacted_thinking_content_block_start_event() {
        let event = StreamingEvent::ContentBlockStart {
            index: 0,
            content_block: Content::RedactedThinking {
                data: "redacted_blob".to_string(),
            },
        };
        let mut tool_call_state = None;
        let mut thinking_state = None;
        let result = handle_event(&event, &mut tool_call_state, &mut thinking_state);

        assert!(result.is_some());
        match result.unwrap().unwrap() {
            RawStreamingChoice::Reasoning {
                content: ReasoningContent::Redacted { data },
                ..
            } => {
                assert_eq!(data, "redacted_blob");
            }
            _ => panic!("Expected Redacted reasoning chunk"),
        }
    }

    #[test]
    fn test_handle_text_delta_event() {
        let event = StreamingEvent::ContentBlockDelta {
            index: 0,
            delta: ContentDelta::TextDelta {
                text: "Hello, world!".to_string(),
            },
        };

        let mut tool_call_state = None;
        let mut thinking_state = None;
        let result = handle_event(&event, &mut tool_call_state, &mut thinking_state);

        assert!(result.is_some());
        let choice = result.unwrap().unwrap();

        match choice {
            RawStreamingChoice::Message(text) => {
                assert_eq!(text, "Hello, world!");
            }
            _ => panic!("Expected Message choice"),
        }
    }

    #[test]
    fn test_thinking_delta_does_not_interfere_with_tool_calls() {
        // Thinking deltas should still be processed even if a tool call is in progress
        let event = StreamingEvent::ContentBlockDelta {
            index: 0,
            delta: ContentDelta::ThinkingDelta {
                thinking: "Thinking while tool is active...".to_string(),
            },
        };

        let mut tool_call_state = Some(ToolCallState {
            name: "test_tool".to_string(),
            id: "tool_123".to_string(),
            internal_call_id: nanoid::nanoid!(),
            input_json: String::new(),
        });
        let mut thinking_state = None;

        let result = handle_event(&event, &mut tool_call_state, &mut thinking_state);

        assert!(result.is_some());
        let choice = result.unwrap().unwrap();

        match choice {
            RawStreamingChoice::ReasoningDelta { reasoning, .. } => {
                assert_eq!(reasoning, "Thinking while tool is active...");
            }
            _ => panic!("Expected ReasoningDelta choice"),
        }

        // Tool call state should remain unchanged
        assert!(tool_call_state.is_some());
    }

    #[test]
    fn test_handle_input_json_delta_event() {
        let event = StreamingEvent::ContentBlockDelta {
            index: 0,
            delta: ContentDelta::InputJsonDelta {
                partial_json: "{\"arg\":\"value".to_string(),
            },
        };

        let mut tool_call_state = Some(ToolCallState {
            name: "test_tool".to_string(),
            id: "tool_123".to_string(),
            internal_call_id: nanoid::nanoid!(),
            input_json: String::new(),
        });
        let mut thinking_state = None;

        let result = handle_event(&event, &mut tool_call_state, &mut thinking_state);

        // Should emit a ToolCallDelta
        assert!(result.is_some());
        let choice = result.unwrap().unwrap();

        match choice {
            RawStreamingChoice::ToolCallDelta {
                id,
                internal_call_id: _,
                content,
            } => {
                assert_eq!(id, "tool_123");
                match content {
                    ToolCallDeltaContent::Delta(delta) => assert_eq!(delta, "{\"arg\":\"value"),
                    _ => panic!("Expected Delta content"),
                }
            }
            _ => panic!("Expected ToolCallDelta choice, got {:?}", choice),
        }

        // Verify the input_json was accumulated
        assert!(tool_call_state.is_some());
        let state = tool_call_state.unwrap();
        assert_eq!(state.input_json, "{\"arg\":\"value");
    }

    #[test]
    fn test_tool_call_accumulation_with_multiple_deltas() {
        let mut tool_call_state = Some(ToolCallState {
            name: "test_tool".to_string(),
            id: "tool_123".to_string(),
            internal_call_id: nanoid::nanoid!(),
            input_json: String::new(),
        });
        let mut thinking_state = None;

        // First delta
        let event1 = StreamingEvent::ContentBlockDelta {
            index: 0,
            delta: ContentDelta::InputJsonDelta {
                partial_json: "{\"location\":".to_string(),
            },
        };
        let result1 = handle_event(&event1, &mut tool_call_state, &mut thinking_state);
        assert!(result1.is_some());

        // Second delta
        let event2 = StreamingEvent::ContentBlockDelta {
            index: 0,
            delta: ContentDelta::InputJsonDelta {
                partial_json: "\"Paris\",".to_string(),
            },
        };
        let result2 = handle_event(&event2, &mut tool_call_state, &mut thinking_state);
        assert!(result2.is_some());

        // Third delta
        let event3 = StreamingEvent::ContentBlockDelta {
            index: 0,
            delta: ContentDelta::InputJsonDelta {
                partial_json: "\"temp\":\"20C\"}".to_string(),
            },
        };
        let result3 = handle_event(&event3, &mut tool_call_state, &mut thinking_state);
        assert!(result3.is_some());

        // Verify accumulated JSON
        assert!(tool_call_state.is_some());
        let state = tool_call_state.as_ref().unwrap();
        assert_eq!(
            state.input_json,
            "{\"location\":\"Paris\",\"temp\":\"20C\"}"
        );

        // Final ContentBlockStop should emit complete tool call
        let stop_event = StreamingEvent::ContentBlockStop { index: 0 };
        let final_result = handle_event(&stop_event, &mut tool_call_state, &mut thinking_state);
        assert!(final_result.is_some());

        match final_result.unwrap().unwrap() {
            RawStreamingChoice::ToolCall(RawStreamingToolCall {
                id,
                name,
                arguments,
                ..
            }) => {
                assert_eq!(id, "tool_123");
                assert_eq!(name, "test_tool");
                assert_eq!(
                    arguments.get("location").unwrap().as_str().unwrap(),
                    "Paris"
                );
                assert_eq!(arguments.get("temp").unwrap().as_str().unwrap(), "20C");
            }
            other => panic!("Expected ToolCall, got {:?}", other),
        }

        // Tool call state should be taken
        assert!(tool_call_state.is_none());
    }
}