//! Streaming support for the Anthropic-compatible provider
//! (`rig/providers/anthropic/streaming.rs`): decodes the Messages API
//! server-sent events into rig's streaming completion types.
1use async_stream::stream;
2use futures::StreamExt;
3use serde::{Deserialize, Serialize};
4use serde_json::{Value, json};
5use tracing::{Level, enabled, info_span};
6use tracing_futures::Instrument;
7
8use super::completion::{
9    AnthropicCompatibleProvider, CacheControl, Content, GenericCompletionModel, Message,
10    SystemContent, ToolChoice, ToolDefinition, Usage, apply_cache_control,
11    split_system_messages_from_history,
12};
13use crate::completion::{CompletionError, CompletionRequest, GetTokenUsage};
14use crate::http_client::sse::{Event, GenericEventSource};
15use crate::http_client::{self, HttpClientExt};
16use crate::json_utils::merge_inplace;
17use crate::message::ReasoningContent;
18use crate::streaming::{
19    self, RawStreamingChoice, RawStreamingToolCall, StreamingResult, ToolCallDeltaContent,
20};
21use crate::telemetry::SpanCombinator;
22use crate::wasm_compat::{WasmCompatSend, WasmCompatSync};
23
/// Server-sent event payloads emitted by the Anthropic Messages streaming API,
/// discriminated by the JSON `type` field (snake_case tags).
#[derive(Debug, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum StreamingEvent {
    /// First event of a stream; carries the message shell and initial usage.
    MessageStart {
        message: MessageStart,
    },
    /// Opens a content block (text, tool_use, thinking, ...) at `index`.
    ContentBlockStart {
        index: usize,
        content_block: Content,
    },
    /// Incremental update to the content block at `index`.
    ContentBlockDelta {
        index: usize,
        delta: ContentDelta,
    },
    /// Closes the content block at `index`.
    ContentBlockStop {
        index: usize,
    },
    /// Top-level message update carrying the stop reason and output usage.
    MessageDelta {
        delta: MessageDelta,
        usage: PartialUsage,
    },
    /// Final event of a stream.
    MessageStop,
    /// Keep-alive event; carries no data.
    Ping,
    /// Forward-compatibility catch-all for event types this client does not know.
    #[serde(other)]
    Unknown,
}
50
/// Payload of a `message_start` event: metadata for the message being streamed.
#[derive(Debug, Deserialize)]
pub struct MessageStart {
    /// Provider-assigned response id.
    pub id: String,
    pub role: String,
    pub content: Vec<Content>,
    /// Model that is actually serving the response.
    pub model: String,
    pub stop_reason: Option<String>,
    pub stop_sequence: Option<String>,
    /// Initial usage snapshot; `input_tokens` is read from here in `stream()`.
    pub usage: Usage,
}
61
/// Delta payloads carried by `content_block_delta` events.
#[derive(Debug, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ContentDelta {
    /// Plain-text continuation of a text block.
    TextDelta { text: String },
    /// Fragment of a tool call's JSON input; fragments concatenate into one document.
    InputJsonDelta { partial_json: String },
    /// Continuation of a thinking (reasoning) block.
    ThinkingDelta { thinking: String },
    /// Signature fragment for a thinking block; accumulated, not emitted directly.
    SignatureDelta { signature: String },
}
70
/// Payload of a `message_delta` event; a non-`None` `stop_reason` marks the
/// logical end of the response in `stream()`.
#[derive(Debug, Deserialize)]
pub struct MessageDelta {
    pub stop_reason: Option<String>,
    pub stop_sequence: Option<String>,
}
76
/// Token usage as reported on streaming events. Only `output_tokens` is always
/// present; the remaining counters default to `None` when absent from the JSON.
#[derive(Debug, Deserialize, Clone, Serialize, Default)]
pub struct PartialUsage {
    pub output_tokens: usize,
    #[serde(default)]
    pub input_tokens: Option<usize>,
    #[serde(default)]
    pub cache_creation_input_tokens: Option<u64>,
    #[serde(default)]
    pub cache_read_input_tokens: Option<u64>,
}
87
88impl GetTokenUsage for PartialUsage {
89    fn token_usage(&self) -> Option<crate::completion::Usage> {
90        let mut usage = crate::completion::Usage::new();
91
92        usage.input_tokens = self.input_tokens.unwrap_or_default() as u64;
93        usage.output_tokens = self.output_tokens as u64;
94        usage.cached_input_tokens = self.cache_read_input_tokens.unwrap_or(0);
95        usage.cache_creation_input_tokens = self.cache_creation_input_tokens.unwrap_or(0);
96        usage.total_tokens = usage.input_tokens
97            + usage.cached_input_tokens
98            + usage.cache_creation_input_tokens
99            + usage.output_tokens;
100        Some(usage)
101    }
102}
103
/// Accumulator for the tool call currently being streamed.
#[derive(Default)]
struct ToolCallState {
    name: String,
    /// Provider-assigned tool-use id.
    id: String,
    /// Rig-internal correlation id (nanoid), distinct from the provider's `id`.
    internal_call_id: String,
    /// Concatenation of all `input_json_delta` fragments seen so far.
    input_json: String,
}
111
/// Accumulator for the thinking (reasoning) block currently being streamed.
#[derive(Default)]
struct ThinkingState {
    /// Concatenated `thinking_delta` text.
    thinking: String,
    /// Concatenated `signature_delta` fragments; attached to the final Reasoning chunk.
    signature: String,
}
117
/// Final response yielded once the stream completes; carries the aggregated usage.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct StreamingCompletionResponse {
    pub usage: PartialUsage,
}
122
123impl GetTokenUsage for StreamingCompletionResponse {
124    fn token_usage(&self) -> Option<crate::completion::Usage> {
125        let mut usage = crate::completion::Usage::new();
126        usage.input_tokens = self.usage.input_tokens.unwrap_or(0) as u64;
127        usage.output_tokens = self.usage.output_tokens as u64;
128        usage.cached_input_tokens = self.usage.cache_read_input_tokens.unwrap_or(0);
129        usage.cache_creation_input_tokens = self.usage.cache_creation_input_tokens.unwrap_or(0);
130        usage.total_tokens = usage.input_tokens
131            + usage.cached_input_tokens
132            + usage.cache_creation_input_tokens
133            + usage.output_tokens;
134
135        Some(usage)
136    }
137}
138
impl<Ext, T> GenericCompletionModel<Ext, T>
where
    T: HttpClientExt + Clone + Default + 'static,
    Ext: AnthropicCompatibleProvider + Clone + WasmCompatSend + WasmCompatSync + 'static,
{
    /// Sends a streaming request to the provider's `/v1/messages` endpoint and
    /// adapts the SSE response into rig's streaming completion types.
    ///
    /// Builds the JSON request body (model, messages, system prompt, tools,
    /// optional caching and temperature), opens an SSE connection, and returns
    /// a stream of `RawStreamingChoice` items ending with a `FinalResponse`
    /// that carries the aggregated token usage.
    ///
    /// # Errors
    /// - `CompletionError::RequestError` when `max_tokens` is unset (required
    ///   by Anthropic) or `additional_params.tools` is malformed.
    /// - Serialization, HTTP, and SSE failures are surfaced as the
    ///   corresponding `CompletionError` variants (some mid-stream).
    pub(crate) async fn stream(
        &self,
        mut completion_request: CompletionRequest,
    ) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError>
    {
        // Per-request model override falls back to the model configured on self.
        let request_model = completion_request
            .model
            .clone()
            .unwrap_or_else(|| self.model.clone());
        // Reuse the caller's span when one is active; otherwise open a new
        // OTel GenAI-convention span for this streaming call.
        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat_streaming",
                gen_ai.operation.name = "chat_streaming",
                gen_ai.provider.name = Ext::PROVIDER_NAME,
                gen_ai.request.model = &request_model,
                gen_ai.system_instructions = &completion_request.preamble,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = &request_model,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.usage.cache_read.input_tokens = tracing::field::Empty,
                gen_ai.usage.cache_creation.input_tokens = tracing::field::Empty,
                gen_ai.input.messages = tracing::field::Empty,
                gen_ai.output.messages = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };
        // Anthropic requires max_tokens: request value wins, then the model default.
        let max_tokens = if let Some(tokens) = completion_request.max_tokens {
            tokens
        } else if let Some(tokens) = self.default_max_tokens {
            tokens
        } else {
            return Err(CompletionError::RequestError(
                "`max_tokens` must be set for Anthropic".into(),
            ));
        };

        // Prepend normalized documents (if any) to the chat history, then pull
        // system-role messages out so they can go in the `system` field.
        let mut full_history = vec![];
        if let Some(docs) = completion_request.normalized_documents() {
            full_history.push(docs);
        }
        full_history.extend(completion_request.chat_history);
        let (history_system, full_history) = split_system_messages_from_history(full_history);

        let mut messages = full_history
            .into_iter()
            .map(Message::try_from)
            .collect::<Result<Vec<Message>, _>>()?;

        // Convert system prompt to array format for cache_control support
        let mut system: Vec<SystemContent> =
            if let Some(preamble) = completion_request.preamble.as_ref() {
                if preamble.is_empty() {
                    vec![]
                } else {
                    vec![SystemContent::Text {
                        text: preamble.clone(),
                        cache_control: None,
                    }]
                }
            } else {
                vec![]
            };
        system.extend(history_system);

        // Apply cache control breakpoints only if prompt_caching is enabled
        if self.prompt_caching {
            apply_cache_control(&mut system, &mut messages);
        }

        let mut body = json!({
            "model": request_model,
            "messages": messages,
            "max_tokens": max_tokens,
            "stream": true,
        });

        // Automatic caching: one top-level field; the API moves the breakpoint automatically.
        // No beta header is required.
        if self.automatic_caching {
            let cc = CacheControl::Ephemeral {
                ttl: self.automatic_caching_ttl.clone(),
            };
            merge_inplace(
                &mut body,
                json!({ "cache_control": serde_json::to_value(&cc)? }),
            );
        }

        // Add system prompt if non-empty
        if !system.is_empty() {
            merge_inplace(&mut body, json!({ "system": system }));
        }

        if let Some(temperature) = completion_request.temperature {
            merge_inplace(&mut body, json!({ "temperature": temperature }));
        }

        // Extract any raw tool definitions the caller smuggled through
        // `additional_params.tools` so they can be merged with typed tools.
        let mut additional_params_payload = completion_request
            .additional_params
            .take()
            .unwrap_or(Value::Null);
        let mut additional_tools =
            extract_tools_from_additional_params(&mut additional_params_payload)?;

        let mut tools = completion_request
            .tools
            .into_iter()
            .map(|tool| ToolDefinition {
                name: tool.name,
                description: Some(tool.description),
                input_schema: tool.parameters,
            })
            .map(serde_json::to_value)
            .collect::<Result<Vec<_>, _>>()?;
        tools.append(&mut additional_tools);

        if !tools.is_empty() {
            merge_inplace(
                &mut body,
                json!({
                    "tools": tools,
                    "tool_choice": ToolChoice::Auto,
                }),
            );
        }

        // Remaining additional params are merged last so they can override defaults.
        if !additional_params_payload.is_null() {
            merge_inplace(&mut body, additional_params_payload)
        }

        if enabled!(Level::TRACE) {
            tracing::trace!(
                target: "rig::completions",
                "Anthropic completion request: {}",
                serde_json::to_string_pretty(&body)?
            );
        }

        let body: Vec<u8> = serde_json::to_vec(&body)?;

        let req = self
            .client
            .post("/v1/messages")?
            .body(body)
            .map_err(http_client::Error::Protocol)?;

        let stream = GenericEventSource::new(self.client.clone(), req);

        // Use our SSE decoder to directly handle Server-Sent Events format
        let stream: StreamingResult<StreamingCompletionResponse> = Box::pin(stream! {
            let mut current_tool_call: Option<ToolCallState> = None;
            let mut current_thinking: Option<ThinkingState> = None;
            let mut sse_stream = Box::pin(stream);
            // Input tokens arrive on message_start; output usage on message_delta.
            let mut input_tokens = 0;
            let mut final_usage = None;

            // NOTE(review): accumulated below but not read afterwards in this
            // block — presumably intended for gen_ai.output.messages telemetry;
            // confirm before removing.
            let mut text_content = String::new();

            while let Some(sse_result) = sse_stream.next().await {
                match sse_result {
                    Ok(Event::Open) => {}
                    Ok(Event::Message(sse)) => {
                        // Parse the SSE data as a StreamingEvent
                        match serde_json::from_str::<StreamingEvent>(&sse.data) {
                            Ok(event) => {
                                match &event {
                                    StreamingEvent::MessageStart { message } => {
                                        input_tokens = message.usage.input_tokens;

                                        let span = tracing::Span::current();
                                        span.record("gen_ai.response.id", &message.id);
                                        span.record("gen_ai.response.model", &message.model);
                                    },
                                    StreamingEvent::MessageDelta { delta, usage } => {
                                        if delta.stop_reason.is_some() {
                                            // cache_creation_input_tokens and cache_read_input_tokens
                                            // are cumulative totals on message_delta.usage per the
                                            // Anthropic streaming API spec — use them directly.
                                            let usage = PartialUsage {
                                                 output_tokens: usage.output_tokens,
                                                 input_tokens: usize::try_from(input_tokens).ok(),
                                                 cache_creation_input_tokens: usage.cache_creation_input_tokens,
                                                 cache_read_input_tokens: usage.cache_read_input_tokens
                                            };

                                            let span = tracing::Span::current();
                                            span.record_token_usage(&usage);
                                            final_usage = Some(usage);
                                            // Stop reason seen: exit the read loop; the
                                            // FinalResponse is yielded after the loop.
                                            break;
                                        }
                                    }
                                    _ => {}
                                }

                                if let Some(result) = handle_event(&event, &mut current_tool_call, &mut current_thinking) {
                                    if let Ok(RawStreamingChoice::Message(ref text)) = result {
                                        text_content += text;
                                    }
                                    yield result;
                                }
                            },
                            Err(e) => {
                                // Ignore empty keep-alive data; surface real parse failures.
                                if !sse.data.trim().is_empty() {
                                    yield Err(CompletionError::ResponseError(
                                        format!("Failed to parse JSON: {} (Data: {})", e, sse.data)
                                    ));
                                }
                            }
                        }
                    },
                    Err(e) => {
                        yield Err(CompletionError::ProviderError(format!("SSE Error: {e}")));
                        break;
                    }
                }
            }

            // Ensure event source is closed when stream ends
            sse_stream.close();

            yield Ok(RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
                usage: final_usage.unwrap_or_default()
            }))
        }.instrument(span));

        Ok(streaming::StreamingCompletionResponse::stream(stream))
    }
}
375
376fn extract_tools_from_additional_params(
377    additional_params: &mut Value,
378) -> Result<Vec<Value>, CompletionError> {
379    if let Some(map) = additional_params.as_object_mut()
380        && let Some(raw_tools) = map.remove("tools")
381    {
382        return serde_json::from_value::<Vec<Value>>(raw_tools).map_err(|err| {
383            CompletionError::RequestError(
384                format!("Invalid Anthropic `additional_params.tools` payload: {err}").into(),
385            )
386        });
387    }
388
389    Ok(Vec::new())
390}
391
/// Translates one parsed [`StreamingEvent`] into at most one
/// [`RawStreamingChoice`], mutating the in-flight tool-call and thinking
/// accumulators as a side effect.
///
/// Returns `None` for events that produce no downstream chunk (pings,
/// signature fragments, buffered text while a tool call is open, etc.).
fn handle_event(
    event: &StreamingEvent,
    current_tool_call: &mut Option<ToolCallState>,
    current_thinking: &mut Option<ThinkingState>,
) -> Option<Result<RawStreamingChoice<StreamingCompletionResponse>, CompletionError>> {
    match event {
        StreamingEvent::ContentBlockDelta { delta, .. } => match delta {
            ContentDelta::TextDelta { text } => {
                // Text is only forwarded when no tool call is in progress.
                if current_tool_call.is_none() {
                    return Some(Ok(RawStreamingChoice::Message(text.clone())));
                }
                None
            }
            ContentDelta::InputJsonDelta { partial_json } => {
                // Accumulate the fragment; ignored if no tool call is open.
                if let Some(tool_call) = current_tool_call {
                    tool_call.input_json.push_str(partial_json);
                    // Emit the delta so UI can show progress
                    return Some(Ok(RawStreamingChoice::ToolCallDelta {
                        id: tool_call.id.clone(),
                        internal_call_id: tool_call.internal_call_id.clone(),
                        content: ToolCallDeltaContent::Delta(partial_json.clone()),
                    }));
                }
                None
            }
            ContentDelta::ThinkingDelta { thinking } => {
                // Accumulate AND forward: thinking streams even during tool calls.
                current_thinking
                    .get_or_insert_with(ThinkingState::default)
                    .thinking
                    .push_str(thinking);

                Some(Ok(RawStreamingChoice::ReasoningDelta {
                    id: None,
                    reasoning: thinking.clone(),
                }))
            }
            ContentDelta::SignatureDelta { signature } => {
                current_thinking
                    .get_or_insert_with(ThinkingState::default)
                    .signature
                    .push_str(signature);

                // Don't yield signature chunks, they will be included in the final Reasoning
                None
            }
        },
        StreamingEvent::ContentBlockStart { content_block, .. } => match content_block {
            Content::ToolUse { id, name, .. } => {
                // Open a fresh tool-call accumulator and announce the tool name.
                let internal_call_id = nanoid::nanoid!();
                *current_tool_call = Some(ToolCallState {
                    name: name.clone(),
                    id: id.clone(),
                    internal_call_id: internal_call_id.clone(),
                    input_json: String::new(),
                });
                Some(Ok(RawStreamingChoice::ToolCallDelta {
                    id: id.clone(),
                    internal_call_id,
                    content: ToolCallDeltaContent::Name(name.clone()),
                }))
            }
            Content::Thinking { .. } => {
                *current_thinking = Some(ThinkingState::default());
                None
            }
            // Redacted thinking arrives whole; emit it immediately.
            Content::RedactedThinking { data } => Some(Ok(RawStreamingChoice::Reasoning {
                id: None,
                content: ReasoningContent::Redacted { data: data.clone() },
            })),
            // Handle other content types - they don't need special handling
            _ => None,
        },
        StreamingEvent::ContentBlockStop { .. } => {
            // Flush a completed (non-empty) thinking block first, with its
            // signature attached when one was accumulated.
            if let Some(thinking_state) = Option::take(current_thinking)
                && !thinking_state.thinking.is_empty()
            {
                let signature = if thinking_state.signature.is_empty() {
                    None
                } else {
                    Some(thinking_state.signature)
                };

                return Some(Ok(RawStreamingChoice::Reasoning {
                    id: None,
                    content: ReasoningContent::Text {
                        text: thinking_state.thinking,
                        signature,
                    },
                }));
            }

            // Then flush a completed tool call: parse the accumulated JSON
            // (empty input becomes `{}`) into the final ToolCall chunk.
            if let Some(tool_call) = Option::take(current_tool_call) {
                let json_str = if tool_call.input_json.is_empty() {
                    "{}"
                } else {
                    &tool_call.input_json
                };
                match serde_json::from_str(json_str) {
                    Ok(json_value) => {
                        let raw_tool_call =
                            RawStreamingToolCall::new(tool_call.id, tool_call.name, json_value)
                                .with_internal_call_id(tool_call.internal_call_id);
                        Some(Ok(RawStreamingChoice::ToolCall(raw_tool_call)))
                    }
                    Err(e) => Some(Err(CompletionError::from(e))),
                }
            } else {
                None
            }
        }
        // Ignore other event types or handle as needed
        StreamingEvent::MessageStart { .. }
        | StreamingEvent::MessageDelta { .. }
        | StreamingEvent::MessageStop
        | StreamingEvent::Ping
        | StreamingEvent::Unknown => None,
    }
}
510
#[cfg(test)]
mod tests {
    //! Unit tests for streaming-event deserialization and the `handle_event`
    //! state machine (thinking, signatures, and tool-call accumulation).
    use super::*;

    #[test]
    fn test_thinking_delta_deserialization() {
        let json = r#"{"type": "thinking_delta", "thinking": "Let me think about this..."}"#;
        let delta: ContentDelta = serde_json::from_str(json).unwrap();

        match delta {
            ContentDelta::ThinkingDelta { thinking } => {
                assert_eq!(thinking, "Let me think about this...");
            }
            _ => panic!("Expected ThinkingDelta variant"),
        }
    }

    #[test]
    fn test_signature_delta_deserialization() {
        let json = r#"{"type": "signature_delta", "signature": "abc123def456"}"#;
        let delta: ContentDelta = serde_json::from_str(json).unwrap();

        match delta {
            ContentDelta::SignatureDelta { signature } => {
                assert_eq!(signature, "abc123def456");
            }
            _ => panic!("Expected SignatureDelta variant"),
        }
    }

    #[test]
    fn test_thinking_delta_streaming_event_deserialization() {
        let json = r#"{
            "type": "content_block_delta",
            "index": 0,
            "delta": {
                "type": "thinking_delta",
                "thinking": "First, I need to understand the problem."
            }
        }"#;

        let event: StreamingEvent = serde_json::from_str(json).unwrap();

        match event {
            StreamingEvent::ContentBlockDelta { index, delta } => {
                assert_eq!(index, 0);
                match delta {
                    ContentDelta::ThinkingDelta { thinking } => {
                        assert_eq!(thinking, "First, I need to understand the problem.");
                    }
                    _ => panic!("Expected ThinkingDelta"),
                }
            }
            _ => panic!("Expected ContentBlockDelta event"),
        }
    }

    #[test]
    fn test_signature_delta_streaming_event_deserialization() {
        let json = r#"{
            "type": "content_block_delta",
            "index": 0,
            "delta": {
                "type": "signature_delta",
                "signature": "ErUBCkYICBgCIkCaGbqC85F4"
            }
        }"#;

        let event: StreamingEvent = serde_json::from_str(json).unwrap();

        match event {
            StreamingEvent::ContentBlockDelta { index, delta } => {
                assert_eq!(index, 0);
                match delta {
                    ContentDelta::SignatureDelta { signature } => {
                        assert_eq!(signature, "ErUBCkYICBgCIkCaGbqC85F4");
                    }
                    _ => panic!("Expected SignatureDelta"),
                }
            }
            _ => panic!("Expected ContentBlockDelta event"),
        }
    }

    #[test]
    fn test_handle_thinking_delta_event() {
        let event = StreamingEvent::ContentBlockDelta {
            index: 0,
            delta: ContentDelta::ThinkingDelta {
                thinking: "Analyzing the request...".to_string(),
            },
        };

        let mut tool_call_state = None;
        let mut thinking_state = None;
        let result = handle_event(&event, &mut tool_call_state, &mut thinking_state);

        assert!(result.is_some());
        let choice = result.unwrap().unwrap();

        match choice {
            RawStreamingChoice::ReasoningDelta { id, reasoning, .. } => {
                assert_eq!(id, None);
                assert_eq!(reasoning, "Analyzing the request...");
            }
            _ => panic!("Expected ReasoningDelta choice"),
        }

        // Verify thinking state was updated
        assert!(thinking_state.is_some());
        assert_eq!(thinking_state.unwrap().thinking, "Analyzing the request...");
    }

    #[test]
    fn test_handle_signature_delta_event() {
        let event = StreamingEvent::ContentBlockDelta {
            index: 0,
            delta: ContentDelta::SignatureDelta {
                signature: "test_signature".to_string(),
            },
        };

        let mut tool_call_state = None;
        let mut thinking_state = None;
        let result = handle_event(&event, &mut tool_call_state, &mut thinking_state);

        // SignatureDelta should not yield anything (returns None)
        assert!(result.is_none());

        // But signature should be captured in thinking state
        assert!(thinking_state.is_some());
        assert_eq!(thinking_state.unwrap().signature, "test_signature");
    }

    #[test]
    fn test_handle_redacted_thinking_content_block_start_event() {
        let event = StreamingEvent::ContentBlockStart {
            index: 0,
            content_block: Content::RedactedThinking {
                data: "redacted_blob".to_string(),
            },
        };
        let mut tool_call_state = None;
        let mut thinking_state = None;
        let result = handle_event(&event, &mut tool_call_state, &mut thinking_state);

        assert!(result.is_some());
        match result.unwrap().unwrap() {
            RawStreamingChoice::Reasoning {
                content: ReasoningContent::Redacted { data },
                ..
            } => {
                assert_eq!(data, "redacted_blob");
            }
            _ => panic!("Expected Redacted reasoning chunk"),
        }
    }

    #[test]
    fn test_handle_text_delta_event() {
        let event = StreamingEvent::ContentBlockDelta {
            index: 0,
            delta: ContentDelta::TextDelta {
                text: "Hello, world!".to_string(),
            },
        };

        let mut tool_call_state = None;
        let mut thinking_state = None;
        let result = handle_event(&event, &mut tool_call_state, &mut thinking_state);

        assert!(result.is_some());
        let choice = result.unwrap().unwrap();

        match choice {
            RawStreamingChoice::Message(text) => {
                assert_eq!(text, "Hello, world!");
            }
            _ => panic!("Expected Message choice"),
        }
    }

    #[test]
    fn test_thinking_delta_does_not_interfere_with_tool_calls() {
        // Thinking deltas should still be processed even if a tool call is in progress
        let event = StreamingEvent::ContentBlockDelta {
            index: 0,
            delta: ContentDelta::ThinkingDelta {
                thinking: "Thinking while tool is active...".to_string(),
            },
        };

        let mut tool_call_state = Some(ToolCallState {
            name: "test_tool".to_string(),
            id: "tool_123".to_string(),
            internal_call_id: nanoid::nanoid!(),
            input_json: String::new(),
        });
        let mut thinking_state = None;

        let result = handle_event(&event, &mut tool_call_state, &mut thinking_state);

        assert!(result.is_some());
        let choice = result.unwrap().unwrap();

        match choice {
            RawStreamingChoice::ReasoningDelta { reasoning, .. } => {
                assert_eq!(reasoning, "Thinking while tool is active...");
            }
            _ => panic!("Expected ReasoningDelta choice"),
        }

        // Tool call state should remain unchanged
        assert!(tool_call_state.is_some());
    }

    #[test]
    fn test_handle_input_json_delta_event() {
        let event = StreamingEvent::ContentBlockDelta {
            index: 0,
            delta: ContentDelta::InputJsonDelta {
                partial_json: "{\"arg\":\"value".to_string(),
            },
        };

        let mut tool_call_state = Some(ToolCallState {
            name: "test_tool".to_string(),
            id: "tool_123".to_string(),
            internal_call_id: nanoid::nanoid!(),
            input_json: String::new(),
        });
        let mut thinking_state = None;

        let result = handle_event(&event, &mut tool_call_state, &mut thinking_state);

        // Should emit a ToolCallDelta
        assert!(result.is_some());
        let choice = result.unwrap().unwrap();

        match choice {
            RawStreamingChoice::ToolCallDelta {
                id,
                internal_call_id: _,
                content,
            } => {
                assert_eq!(id, "tool_123");
                match content {
                    ToolCallDeltaContent::Delta(delta) => assert_eq!(delta, "{\"arg\":\"value"),
                    _ => panic!("Expected Delta content"),
                }
            }
            _ => panic!("Expected ToolCallDelta choice, got {:?}", choice),
        }

        // Verify the input_json was accumulated
        assert!(tool_call_state.is_some());
        let state = tool_call_state.unwrap();
        assert_eq!(state.input_json, "{\"arg\":\"value");
    }

    #[test]
    fn test_tool_call_accumulation_with_multiple_deltas() {
        let mut tool_call_state = Some(ToolCallState {
            name: "test_tool".to_string(),
            id: "tool_123".to_string(),
            internal_call_id: nanoid::nanoid!(),
            input_json: String::new(),
        });
        let mut thinking_state = None;

        // First delta
        let event1 = StreamingEvent::ContentBlockDelta {
            index: 0,
            delta: ContentDelta::InputJsonDelta {
                partial_json: "{\"location\":".to_string(),
            },
        };
        let result1 = handle_event(&event1, &mut tool_call_state, &mut thinking_state);
        assert!(result1.is_some());

        // Second delta
        let event2 = StreamingEvent::ContentBlockDelta {
            index: 0,
            delta: ContentDelta::InputJsonDelta {
                partial_json: "\"Paris\",".to_string(),
            },
        };
        let result2 = handle_event(&event2, &mut tool_call_state, &mut thinking_state);
        assert!(result2.is_some());

        // Third delta
        let event3 = StreamingEvent::ContentBlockDelta {
            index: 0,
            delta: ContentDelta::InputJsonDelta {
                partial_json: "\"temp\":\"20C\"}".to_string(),
            },
        };
        let result3 = handle_event(&event3, &mut tool_call_state, &mut thinking_state);
        assert!(result3.is_some());

        // Verify accumulated JSON
        assert!(tool_call_state.is_some());
        let state = tool_call_state.as_ref().unwrap();
        assert_eq!(
            state.input_json,
            "{\"location\":\"Paris\",\"temp\":\"20C\"}"
        );

        // Final ContentBlockStop should emit complete tool call
        let stop_event = StreamingEvent::ContentBlockStop { index: 0 };
        let final_result = handle_event(&stop_event, &mut tool_call_state, &mut thinking_state);
        assert!(final_result.is_some());

        match final_result.unwrap().unwrap() {
            RawStreamingChoice::ToolCall(RawStreamingToolCall {
                id,
                name,
                arguments,
                ..
            }) => {
                assert_eq!(id, "tool_123");
                assert_eq!(name, "test_tool");
                assert_eq!(
                    arguments.get("location").unwrap().as_str().unwrap(),
                    "Paris"
                );
                assert_eq!(arguments.get("temp").unwrap().as_str().unwrap(), "20C");
            }
            other => panic!("Expected ToolCall, got {:?}", other),
        }

        // Tool call state should be taken
        assert!(tool_call_state.is_none());
    }
}
845}