Skip to main content

rig/providers/anthropic/
streaming.rs

1use async_stream::stream;
2use futures::StreamExt;
3use serde::{Deserialize, Serialize};
4use serde_json::{Value, json};
5use tracing::{Level, enabled, info_span};
6use tracing_futures::Instrument;
7
8use super::completion::{
9    CompletionModel, Content, Message, SystemContent, ToolChoice, ToolDefinition, Usage,
10    apply_cache_control, split_system_messages_from_history,
11};
12use crate::completion::{CompletionError, CompletionRequest, GetTokenUsage};
13use crate::http_client::sse::{Event, GenericEventSource};
14use crate::http_client::{self, HttpClientExt};
15use crate::json_utils::merge_inplace;
16use crate::message::ReasoningContent;
17use crate::streaming::{
18    self, RawStreamingChoice, RawStreamingToolCall, StreamingResult, ToolCallDeltaContent,
19};
20use crate::telemetry::SpanCombinator;
21
22#[derive(Debug, Deserialize)]
23#[serde(tag = "type", rename_all = "snake_case")]
24pub enum StreamingEvent {
25    MessageStart {
26        message: MessageStart,
27    },
28    ContentBlockStart {
29        index: usize,
30        content_block: Content,
31    },
32    ContentBlockDelta {
33        index: usize,
34        delta: ContentDelta,
35    },
36    ContentBlockStop {
37        index: usize,
38    },
39    MessageDelta {
40        delta: MessageDelta,
41        usage: PartialUsage,
42    },
43    MessageStop,
44    Ping,
45    #[serde(other)]
46    Unknown,
47}
48
49#[derive(Debug, Deserialize)]
50pub struct MessageStart {
51    pub id: String,
52    pub role: String,
53    pub content: Vec<Content>,
54    pub model: String,
55    pub stop_reason: Option<String>,
56    pub stop_sequence: Option<String>,
57    pub usage: Usage,
58}
59
60#[derive(Debug, Deserialize)]
61#[serde(tag = "type", rename_all = "snake_case")]
62pub enum ContentDelta {
63    TextDelta { text: String },
64    InputJsonDelta { partial_json: String },
65    ThinkingDelta { thinking: String },
66    SignatureDelta { signature: String },
67}
68
69#[derive(Debug, Deserialize)]
70pub struct MessageDelta {
71    pub stop_reason: Option<String>,
72    pub stop_sequence: Option<String>,
73}
74
75#[derive(Debug, Deserialize, Clone, Serialize, Default)]
76pub struct PartialUsage {
77    pub output_tokens: usize,
78    #[serde(default)]
79    pub input_tokens: Option<usize>,
80}
81
82impl GetTokenUsage for PartialUsage {
83    fn token_usage(&self) -> Option<crate::completion::Usage> {
84        let mut usage = crate::completion::Usage::new();
85
86        usage.input_tokens = self.input_tokens.unwrap_or_default() as u64;
87        usage.output_tokens = self.output_tokens as u64;
88        usage.total_tokens = usage.input_tokens + usage.output_tokens;
89        Some(usage)
90    }
91}
92
/// Accumulates a `tool_use` content block while its input JSON is streamed in.
#[derive(Default)]
struct ToolCallState {
    /// Tool name taken from the block start.
    name: String,
    /// Provider-assigned tool-use id.
    id: String,
    /// Locally generated id that ties the emitted deltas to the final tool call.
    internal_call_id: String,
    /// Concatenation of every `input_json_delta` fragment received so far.
    input_json: String,
}
100
/// Accumulates an extended-thinking content block until it stops.
#[derive(Default)]
struct ThinkingState {
    /// Concatenated `thinking_delta` text.
    thinking: String,
    /// Concatenated `signature_delta` fragments.
    signature: String,
}
106
107#[derive(Clone, Debug, Deserialize, Serialize)]
108pub struct StreamingCompletionResponse {
109    pub usage: PartialUsage,
110}
111
112impl GetTokenUsage for StreamingCompletionResponse {
113    fn token_usage(&self) -> Option<crate::completion::Usage> {
114        let mut usage = crate::completion::Usage::new();
115        usage.input_tokens = self.usage.input_tokens.unwrap_or(0) as u64;
116        usage.output_tokens = self.usage.output_tokens as u64;
117        usage.total_tokens =
118            self.usage.input_tokens.unwrap_or(0) as u64 + self.usage.output_tokens as u64;
119
120        Some(usage)
121    }
122}
123
124impl<T> CompletionModel<T>
125where
126    T: HttpClientExt + Clone + Default + 'static,
127{
128    pub(crate) async fn stream(
129        &self,
130        mut completion_request: CompletionRequest,
131    ) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError>
132    {
133        let request_model = completion_request
134            .model
135            .clone()
136            .unwrap_or_else(|| self.model.clone());
137        let span = if tracing::Span::current().is_disabled() {
138            info_span!(
139                target: "rig::completions",
140                "chat_streaming",
141                gen_ai.operation.name = "chat_streaming",
142                gen_ai.provider.name = "anthropic",
143                gen_ai.request.model = &request_model,
144                gen_ai.system_instructions = &completion_request.preamble,
145                gen_ai.response.id = tracing::field::Empty,
146                gen_ai.response.model = &request_model,
147                gen_ai.usage.output_tokens = tracing::field::Empty,
148                gen_ai.usage.input_tokens = tracing::field::Empty,
149                gen_ai.usage.cached_tokens = tracing::field::Empty,
150                gen_ai.input.messages = tracing::field::Empty,
151                gen_ai.output.messages = tracing::field::Empty,
152            )
153        } else {
154            tracing::Span::current()
155        };
156        let max_tokens = if let Some(tokens) = completion_request.max_tokens {
157            tokens
158        } else if let Some(tokens) = self.default_max_tokens {
159            tokens
160        } else {
161            return Err(CompletionError::RequestError(
162                "`max_tokens` must be set for Anthropic".into(),
163            ));
164        };
165
166        let mut full_history = vec![];
167        if let Some(docs) = completion_request.normalized_documents() {
168            full_history.push(docs);
169        }
170        full_history.extend(completion_request.chat_history);
171        let (history_system, full_history) = split_system_messages_from_history(full_history);
172
173        let mut messages = full_history
174            .into_iter()
175            .map(Message::try_from)
176            .collect::<Result<Vec<Message>, _>>()?;
177
178        // Convert system prompt to array format for cache_control support
179        let mut system: Vec<SystemContent> =
180            if let Some(preamble) = completion_request.preamble.as_ref() {
181                if preamble.is_empty() {
182                    vec![]
183                } else {
184                    vec![SystemContent::Text {
185                        text: preamble.clone(),
186                        cache_control: None,
187                    }]
188                }
189            } else {
190                vec![]
191            };
192        system.extend(history_system);
193
194        // Apply cache control breakpoints only if prompt_caching is enabled
195        if self.prompt_caching {
196            apply_cache_control(&mut system, &mut messages);
197        }
198
199        let mut body = json!({
200            "model": request_model,
201            "messages": messages,
202            "max_tokens": max_tokens,
203            "stream": true,
204        });
205
206        // Add system prompt if non-empty
207        if !system.is_empty() {
208            merge_inplace(&mut body, json!({ "system": system }));
209        }
210
211        if let Some(temperature) = completion_request.temperature {
212            merge_inplace(&mut body, json!({ "temperature": temperature }));
213        }
214
215        let mut additional_params_payload = completion_request
216            .additional_params
217            .take()
218            .unwrap_or(Value::Null);
219        let mut additional_tools =
220            extract_tools_from_additional_params(&mut additional_params_payload)?;
221
222        let mut tools = completion_request
223            .tools
224            .into_iter()
225            .map(|tool| ToolDefinition {
226                name: tool.name,
227                description: Some(tool.description),
228                input_schema: tool.parameters,
229            })
230            .map(serde_json::to_value)
231            .collect::<Result<Vec<_>, _>>()?;
232        tools.append(&mut additional_tools);
233
234        if !tools.is_empty() {
235            merge_inplace(
236                &mut body,
237                json!({
238                    "tools": tools,
239                    "tool_choice": ToolChoice::Auto,
240                }),
241            );
242        }
243
244        if !additional_params_payload.is_null() {
245            merge_inplace(&mut body, additional_params_payload)
246        }
247
248        if enabled!(Level::TRACE) {
249            tracing::trace!(
250                target: "rig::completions",
251                "Anthropic completion request: {}",
252                serde_json::to_string_pretty(&body)?
253            );
254        }
255
256        let body: Vec<u8> = serde_json::to_vec(&body)?;
257
258        let req = self
259            .client
260            .post("/v1/messages")?
261            .body(body)
262            .map_err(http_client::Error::Protocol)?;
263
264        let stream = GenericEventSource::new(self.client.clone(), req);
265
266        // Use our SSE decoder to directly handle Server-Sent Events format
267        let stream: StreamingResult<StreamingCompletionResponse> = Box::pin(stream! {
268            let mut current_tool_call: Option<ToolCallState> = None;
269            let mut current_thinking: Option<ThinkingState> = None;
270            let mut sse_stream = Box::pin(stream);
271            let mut input_tokens = 0;
272            let mut final_usage = None;
273
274            let mut text_content = String::new();
275
276            while let Some(sse_result) = sse_stream.next().await {
277                match sse_result {
278                    Ok(Event::Open) => {}
279                    Ok(Event::Message(sse)) => {
280                        // Parse the SSE data as a StreamingEvent
281                        match serde_json::from_str::<StreamingEvent>(&sse.data) {
282                            Ok(event) => {
283                                match &event {
284                                    StreamingEvent::MessageStart { message } => {
285                                        input_tokens = message.usage.input_tokens;
286
287                                        let span = tracing::Span::current();
288                                        span.record("gen_ai.response.id", &message.id);
289                                        span.record("gen_ai.response.model_name", &message.model);
290                                    },
291                                    StreamingEvent::MessageDelta { delta, usage } => {
292                                        if delta.stop_reason.is_some() {
293                                            let usage = PartialUsage {
294                                                 output_tokens: usage.output_tokens,
295                                                 input_tokens: Some(input_tokens.try_into().expect("Failed to convert input_tokens to usize")),
296                                            };
297
298                                            let span = tracing::Span::current();
299                                            span.record_token_usage(&usage);
300                                            final_usage = Some(usage);
301                                            break;
302                                        }
303                                    }
304                                    _ => {}
305                                }
306
307                                if let Some(result) = handle_event(&event, &mut current_tool_call, &mut current_thinking) {
308                                    if let Ok(RawStreamingChoice::Message(ref text)) = result {
309                                        text_content += text;
310                                    }
311                                    yield result;
312                                }
313                            },
314                            Err(e) => {
315                                if !sse.data.trim().is_empty() {
316                                    yield Err(CompletionError::ResponseError(
317                                        format!("Failed to parse JSON: {} (Data: {})", e, sse.data)
318                                    ));
319                                }
320                            }
321                        }
322                    },
323                    Err(e) => {
324                        yield Err(CompletionError::ProviderError(format!("SSE Error: {e}")));
325                        break;
326                    }
327                }
328            }
329
330            // Ensure event source is closed when stream ends
331            sse_stream.close();
332
333            yield Ok(RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
334                usage: final_usage.unwrap_or_default()
335            }))
336        }.instrument(span));
337
338        Ok(streaming::StreamingCompletionResponse::stream(stream))
339    }
340}
341
342fn extract_tools_from_additional_params(
343    additional_params: &mut Value,
344) -> Result<Vec<Value>, CompletionError> {
345    if let Some(map) = additional_params.as_object_mut()
346        && let Some(raw_tools) = map.remove("tools")
347    {
348        return serde_json::from_value::<Vec<Value>>(raw_tools).map_err(|err| {
349            CompletionError::RequestError(
350                format!("Invalid Anthropic `additional_params.tools` payload: {err}").into(),
351            )
352        });
353    }
354
355    Ok(Vec::new())
356}
357
358fn handle_event(
359    event: &StreamingEvent,
360    current_tool_call: &mut Option<ToolCallState>,
361    current_thinking: &mut Option<ThinkingState>,
362) -> Option<Result<RawStreamingChoice<StreamingCompletionResponse>, CompletionError>> {
363    match event {
364        StreamingEvent::ContentBlockDelta { delta, .. } => match delta {
365            ContentDelta::TextDelta { text } => {
366                if current_tool_call.is_none() {
367                    return Some(Ok(RawStreamingChoice::Message(text.clone())));
368                }
369                None
370            }
371            ContentDelta::InputJsonDelta { partial_json } => {
372                if let Some(tool_call) = current_tool_call {
373                    tool_call.input_json.push_str(partial_json);
374                    // Emit the delta so UI can show progress
375                    return Some(Ok(RawStreamingChoice::ToolCallDelta {
376                        id: tool_call.id.clone(),
377                        internal_call_id: tool_call.internal_call_id.clone(),
378                        content: ToolCallDeltaContent::Delta(partial_json.clone()),
379                    }));
380                }
381                None
382            }
383            ContentDelta::ThinkingDelta { thinking } => {
384                current_thinking
385                    .get_or_insert_with(ThinkingState::default)
386                    .thinking
387                    .push_str(thinking);
388
389                Some(Ok(RawStreamingChoice::ReasoningDelta {
390                    id: None,
391                    reasoning: thinking.clone(),
392                }))
393            }
394            ContentDelta::SignatureDelta { signature } => {
395                current_thinking
396                    .get_or_insert_with(ThinkingState::default)
397                    .signature
398                    .push_str(signature);
399
400                // Don't yield signature chunks, they will be included in the final Reasoning
401                None
402            }
403        },
404        StreamingEvent::ContentBlockStart { content_block, .. } => match content_block {
405            Content::ToolUse { id, name, .. } => {
406                let internal_call_id = nanoid::nanoid!();
407                *current_tool_call = Some(ToolCallState {
408                    name: name.clone(),
409                    id: id.clone(),
410                    internal_call_id: internal_call_id.clone(),
411                    input_json: String::new(),
412                });
413                Some(Ok(RawStreamingChoice::ToolCallDelta {
414                    id: id.clone(),
415                    internal_call_id,
416                    content: ToolCallDeltaContent::Name(name.clone()),
417                }))
418            }
419            Content::Thinking { .. } => {
420                *current_thinking = Some(ThinkingState::default());
421                None
422            }
423            Content::RedactedThinking { data } => Some(Ok(RawStreamingChoice::Reasoning {
424                id: None,
425                content: ReasoningContent::Redacted { data: data.clone() },
426            })),
427            // Handle other content types - they don't need special handling
428            _ => None,
429        },
430        StreamingEvent::ContentBlockStop { .. } => {
431            if let Some(thinking_state) = Option::take(current_thinking)
432                && !thinking_state.thinking.is_empty()
433            {
434                let signature = if thinking_state.signature.is_empty() {
435                    None
436                } else {
437                    Some(thinking_state.signature)
438                };
439
440                return Some(Ok(RawStreamingChoice::Reasoning {
441                    id: None,
442                    content: ReasoningContent::Text {
443                        text: thinking_state.thinking,
444                        signature,
445                    },
446                }));
447            }
448
449            if let Some(tool_call) = Option::take(current_tool_call) {
450                let json_str = if tool_call.input_json.is_empty() {
451                    "{}"
452                } else {
453                    &tool_call.input_json
454                };
455                match serde_json::from_str(json_str) {
456                    Ok(json_value) => {
457                        let raw_tool_call =
458                            RawStreamingToolCall::new(tool_call.id, tool_call.name, json_value)
459                                .with_internal_call_id(tool_call.internal_call_id);
460                        Some(Ok(RawStreamingChoice::ToolCall(raw_tool_call)))
461                    }
462                    Err(e) => Some(Err(CompletionError::from(e))),
463                }
464            } else {
465                None
466            }
467        }
468        // Ignore other event types or handle as needed
469        StreamingEvent::MessageStart { .. }
470        | StreamingEvent::MessageDelta { .. }
471        | StreamingEvent::MessageStop
472        | StreamingEvent::Ping
473        | StreamingEvent::Unknown => None,
474    }
475}
476
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an open tool-call state named `test_tool` / `tool_123` with the given input.
    fn open_tool_call(input_json: &str) -> Option<ToolCallState> {
        Some(ToolCallState {
            name: "test_tool".to_string(),
            id: "tool_123".to_string(),
            internal_call_id: nanoid::nanoid!(),
            input_json: input_json.to_string(),
        })
    }

    #[test]
    fn test_thinking_delta_deserialization() {
        let json = r#"{"type": "thinking_delta", "thinking": "Let me think about this..."}"#;
        let delta: ContentDelta = serde_json::from_str(json).unwrap();

        match delta {
            ContentDelta::ThinkingDelta { thinking } => {
                assert_eq!(thinking, "Let me think about this...");
            }
            _ => panic!("Expected ThinkingDelta variant"),
        }
    }

    #[test]
    fn test_signature_delta_deserialization() {
        let json = r#"{"type": "signature_delta", "signature": "abc123def456"}"#;
        let delta: ContentDelta = serde_json::from_str(json).unwrap();

        match delta {
            ContentDelta::SignatureDelta { signature } => {
                assert_eq!(signature, "abc123def456");
            }
            _ => panic!("Expected SignatureDelta variant"),
        }
    }

    #[test]
    fn test_thinking_delta_streaming_event_deserialization() {
        let json = r#"{
            "type": "content_block_delta",
            "index": 0,
            "delta": {
                "type": "thinking_delta",
                "thinking": "First, I need to understand the problem."
            }
        }"#;

        let event: StreamingEvent = serde_json::from_str(json).unwrap();

        match event {
            StreamingEvent::ContentBlockDelta { index, delta } => {
                assert_eq!(index, 0);
                match delta {
                    ContentDelta::ThinkingDelta { thinking } => {
                        assert_eq!(thinking, "First, I need to understand the problem.");
                    }
                    _ => panic!("Expected ThinkingDelta"),
                }
            }
            _ => panic!("Expected ContentBlockDelta event"),
        }
    }

    #[test]
    fn test_signature_delta_streaming_event_deserialization() {
        let json = r#"{
            "type": "content_block_delta",
            "index": 0,
            "delta": {
                "type": "signature_delta",
                "signature": "ErUBCkYICBgCIkCaGbqC85F4"
            }
        }"#;

        let event: StreamingEvent = serde_json::from_str(json).unwrap();

        match event {
            StreamingEvent::ContentBlockDelta { index, delta } => {
                assert_eq!(index, 0);
                match delta {
                    ContentDelta::SignatureDelta { signature } => {
                        assert_eq!(signature, "ErUBCkYICBgCIkCaGbqC85F4");
                    }
                    _ => panic!("Expected SignatureDelta"),
                }
            }
            _ => panic!("Expected ContentBlockDelta event"),
        }
    }

    #[test]
    fn test_error_event_deserializes_as_unknown() {
        let json = r#"{"type":"error","error":{"type":"overloaded_error","message":"Overloaded"}}"#;
        let event: StreamingEvent = serde_json::from_str(json).unwrap();
        assert!(matches!(event, StreamingEvent::Unknown));
    }

    #[test]
    fn test_message_delta_deserialization() {
        let json = r#"{
            "type": "message_delta",
            "delta": {"stop_reason": "end_turn", "stop_sequence": null},
            "usage": {"output_tokens": 15}
        }"#;

        match serde_json::from_str::<StreamingEvent>(json).unwrap() {
            StreamingEvent::MessageDelta { delta, usage } => {
                assert_eq!(delta.stop_reason.as_deref(), Some("end_turn"));
                assert_eq!(delta.stop_sequence, None);
                assert_eq!(usage.output_tokens, 15);
                assert_eq!(usage.input_tokens, None);
            }
            other => panic!("Expected MessageDelta event, got {other:?}"),
        }
    }

    #[test]
    fn test_token_usage_conversion() {
        let partial = PartialUsage {
            output_tokens: 7,
            input_tokens: None,
        };
        let converted = partial.token_usage().unwrap();
        assert_eq!(converted.input_tokens, 0);
        assert_eq!(converted.output_tokens, 7);
        assert_eq!(converted.total_tokens, 7);

        let response = StreamingCompletionResponse {
            usage: PartialUsage {
                output_tokens: 3,
                input_tokens: Some(10),
            },
        };
        let converted = response.token_usage().unwrap();
        assert_eq!(converted.input_tokens, 10);
        assert_eq!(converted.output_tokens, 3);
        assert_eq!(converted.total_tokens, 13);
    }

    #[test]
    fn test_handle_thinking_delta_event() {
        let event = StreamingEvent::ContentBlockDelta {
            index: 0,
            delta: ContentDelta::ThinkingDelta {
                thinking: "Analyzing the request...".to_string(),
            },
        };

        let mut tool_call_state = None;
        let mut thinking_state = None;
        let result = handle_event(&event, &mut tool_call_state, &mut thinking_state);

        assert!(result.is_some());
        match result.unwrap().unwrap() {
            RawStreamingChoice::ReasoningDelta { id, reasoning, .. } => {
                assert_eq!(id, None);
                assert_eq!(reasoning, "Analyzing the request...");
            }
            _ => panic!("Expected ReasoningDelta choice"),
        }

        // The thinking state accumulates the delta.
        assert!(thinking_state.is_some());
        assert_eq!(thinking_state.unwrap().thinking, "Analyzing the request...");
    }

    #[test]
    fn test_handle_signature_delta_event() {
        let event = StreamingEvent::ContentBlockDelta {
            index: 0,
            delta: ContentDelta::SignatureDelta {
                signature: "test_signature".to_string(),
            },
        };

        let mut tool_call_state = None;
        let mut thinking_state = None;
        let result = handle_event(&event, &mut tool_call_state, &mut thinking_state);

        // Signature deltas are not emitted...
        assert!(result.is_none());

        // ...but are captured in the thinking state.
        assert!(thinking_state.is_some());
        assert_eq!(thinking_state.unwrap().signature, "test_signature");
    }

    #[test]
    fn test_handle_redacted_thinking_content_block_start_event() {
        let event = StreamingEvent::ContentBlockStart {
            index: 0,
            content_block: Content::RedactedThinking {
                data: "redacted_blob".to_string(),
            },
        };
        let mut tool_call_state = None;
        let mut thinking_state = None;
        let result = handle_event(&event, &mut tool_call_state, &mut thinking_state);

        assert!(result.is_some());
        match result.unwrap().unwrap() {
            RawStreamingChoice::Reasoning {
                content: ReasoningContent::Redacted { data },
                ..
            } => {
                assert_eq!(data, "redacted_blob");
            }
            _ => panic!("Expected Redacted reasoning chunk"),
        }
    }

    #[test]
    fn test_thinking_block_stop_emits_reasoning_with_signature() {
        let mut tool_call_state = None;
        let mut thinking_state = Some(ThinkingState {
            thinking: "step 1".to_string(),
            signature: "sig".to_string(),
        });

        let stop = StreamingEvent::ContentBlockStop { index: 0 };
        let result = handle_event(&stop, &mut tool_call_state, &mut thinking_state);

        match result.unwrap().unwrap() {
            RawStreamingChoice::Reasoning {
                content: ReasoningContent::Text { text, signature, .. },
                ..
            } => {
                assert_eq!(text, "step 1");
                assert_eq!(signature.as_deref(), Some("sig"));
            }
            other => panic!("Expected Text reasoning, got {other:?}"),
        }
        assert!(thinking_state.is_none());
    }

    #[test]
    fn test_thinking_block_stop_without_signature() {
        let mut tool_call_state = None;
        let mut thinking_state = Some(ThinkingState {
            thinking: "unsigned".to_string(),
            signature: String::new(),
        });

        let stop = StreamingEvent::ContentBlockStop { index: 0 };
        let result = handle_event(&stop, &mut tool_call_state, &mut thinking_state);

        match result.unwrap().unwrap() {
            RawStreamingChoice::Reasoning {
                content: ReasoningContent::Text { text, signature, .. },
                ..
            } => {
                assert_eq!(text, "unsigned");
                assert_eq!(signature, None);
            }
            other => panic!("Expected Text reasoning, got {other:?}"),
        }
    }

    #[test]
    fn test_handle_text_delta_event() {
        let event = StreamingEvent::ContentBlockDelta {
            index: 0,
            delta: ContentDelta::TextDelta {
                text: "Hello, world!".to_string(),
            },
        };

        let mut tool_call_state = None;
        let mut thinking_state = None;
        let result = handle_event(&event, &mut tool_call_state, &mut thinking_state);

        assert!(result.is_some());
        match result.unwrap().unwrap() {
            RawStreamingChoice::Message(text) => {
                assert_eq!(text, "Hello, world!");
            }
            _ => panic!("Expected Message choice"),
        }
    }

    #[test]
    fn test_text_delta_is_dropped_while_tool_call_is_open() {
        let event = StreamingEvent::ContentBlockDelta {
            index: 1,
            delta: ContentDelta::TextDelta {
                text: "ignored".to_string(),
            },
        };

        let mut tool_call_state = open_tool_call("");
        let mut thinking_state = None;
        let result = handle_event(&event, &mut tool_call_state, &mut thinking_state);

        assert!(result.is_none());
        assert!(tool_call_state.is_some());
    }

    #[test]
    fn test_input_json_delta_without_tool_call_is_ignored() {
        let event = StreamingEvent::ContentBlockDelta {
            index: 0,
            delta: ContentDelta::InputJsonDelta {
                partial_json: "{}".to_string(),
            },
        };

        let mut tool_call_state = None;
        let mut thinking_state = None;
        let result = handle_event(&event, &mut tool_call_state, &mut thinking_state);

        assert!(result.is_none());
        assert!(tool_call_state.is_none());
    }

    #[test]
    fn test_tool_use_block_start_opens_tool_call() {
        let json = r#"{
            "type": "content_block_start",
            "index": 1,
            "content_block": {"type": "tool_use", "id": "toolu_01", "name": "get_weather", "input": {}}
        }"#;
        let event: StreamingEvent = serde_json::from_str(json).unwrap();

        let mut tool_call_state = None;
        let mut thinking_state = None;
        let result = handle_event(&event, &mut tool_call_state, &mut thinking_state);

        match result.unwrap().unwrap() {
            RawStreamingChoice::ToolCallDelta { id, content, .. } => {
                assert_eq!(id, "toolu_01");
                match content {
                    ToolCallDeltaContent::Name(name) => assert_eq!(name, "get_weather"),
                    _ => panic!("Expected Name content"),
                }
            }
            other => panic!("Expected ToolCallDelta choice, got {other:?}"),
        }

        let state = tool_call_state.expect("tool call should be open");
        assert_eq!(state.id, "toolu_01");
        assert_eq!(state.name, "get_weather");
        assert!(state.input_json.is_empty());
    }

    #[test]
    fn test_thinking_delta_does_not_interfere_with_tool_calls() {
        // Thinking deltas are still processed while a tool call is in progress.
        let event = StreamingEvent::ContentBlockDelta {
            index: 0,
            delta: ContentDelta::ThinkingDelta {
                thinking: "Thinking while tool is active...".to_string(),
            },
        };

        let mut tool_call_state = open_tool_call("");
        let mut thinking_state = None;

        let result = handle_event(&event, &mut tool_call_state, &mut thinking_state);

        assert!(result.is_some());
        match result.unwrap().unwrap() {
            RawStreamingChoice::ReasoningDelta { reasoning, .. } => {
                assert_eq!(reasoning, "Thinking while tool is active...");
            }
            _ => panic!("Expected ReasoningDelta choice"),
        }

        // The tool call state is left untouched.
        assert!(tool_call_state.is_some());
    }

    #[test]
    fn test_handle_input_json_delta_event() {
        let event = StreamingEvent::ContentBlockDelta {
            index: 0,
            delta: ContentDelta::InputJsonDelta {
                partial_json: "{\"arg\":\"value".to_string(),
            },
        };

        let mut tool_call_state = open_tool_call("");
        let mut thinking_state = None;

        let result = handle_event(&event, &mut tool_call_state, &mut thinking_state);

        // A ToolCallDelta echoes the fragment.
        assert!(result.is_some());
        let choice = result.unwrap().unwrap();

        match choice {
            RawStreamingChoice::ToolCallDelta {
                id,
                internal_call_id: _,
                content,
            } => {
                assert_eq!(id, "tool_123");
                match content {
                    ToolCallDeltaContent::Delta(delta) => assert_eq!(delta, "{\"arg\":\"value"),
                    _ => panic!("Expected Delta content"),
                }
            }
            _ => panic!("Expected ToolCallDelta choice, got {:?}", choice),
        }

        // The fragment is accumulated.
        let state = tool_call_state.expect("tool call should still be open");
        assert_eq!(state.input_json, "{\"arg\":\"value");
    }

    #[test]
    fn test_tool_call_accumulation_with_multiple_deltas() {
        let mut tool_call_state = open_tool_call("");
        let mut thinking_state = None;

        for fragment in ["{\"location\":", "\"Paris\",", "\"temp\":\"20C\"}"] {
            let event = StreamingEvent::ContentBlockDelta {
                index: 0,
                delta: ContentDelta::InputJsonDelta {
                    partial_json: fragment.to_string(),
                },
            };
            let result = handle_event(&event, &mut tool_call_state, &mut thinking_state);
            assert!(result.is_some());
        }

        // The accumulated JSON is the concatenation of every fragment.
        let state = tool_call_state.as_ref().unwrap();
        assert_eq!(
            state.input_json,
            "{\"location\":\"Paris\",\"temp\":\"20C\"}"
        );

        // ContentBlockStop emits the complete tool call.
        let stop_event = StreamingEvent::ContentBlockStop { index: 0 };
        let final_result = handle_event(&stop_event, &mut tool_call_state, &mut thinking_state);
        assert!(final_result.is_some());

        match final_result.unwrap().unwrap() {
            RawStreamingChoice::ToolCall(RawStreamingToolCall {
                id,
                name,
                arguments,
                ..
            }) => {
                assert_eq!(id, "tool_123");
                assert_eq!(name, "test_tool");
                assert_eq!(
                    arguments.get("location").unwrap().as_str().unwrap(),
                    "Paris"
                );
                assert_eq!(arguments.get("temp").unwrap().as_str().unwrap(), "20C");
            }
            other => panic!("Expected ToolCall, got {:?}", other),
        }

        // The tool call state has been consumed.
        assert!(tool_call_state.is_none());
    }

    #[test]
    fn test_tool_call_with_empty_input_defaults_to_empty_object() {
        let mut tool_call_state = open_tool_call("");
        let mut thinking_state = None;

        let stop = StreamingEvent::ContentBlockStop { index: 0 };
        let result = handle_event(&stop, &mut tool_call_state, &mut thinking_state);

        match result.unwrap().unwrap() {
            RawStreamingChoice::ToolCall(RawStreamingToolCall { arguments, .. }) => {
                assert_eq!(arguments, json!({}));
            }
            other => panic!("Expected ToolCall, got {other:?}"),
        }
        assert!(tool_call_state.is_none());
    }

    #[test]
    fn test_tool_call_with_invalid_json_yields_error() {
        let mut tool_call_state = open_tool_call("{\"unterminated\":");
        let mut thinking_state = None;

        let stop = StreamingEvent::ContentBlockStop { index: 0 };
        let result = handle_event(&stop, &mut tool_call_state, &mut thinking_state);

        assert!(matches!(result, Some(Err(_))));
        assert!(tool_call_state.is_none());
    }
}