// rig/providers/anthropic/streaming.rs

1use async_stream::stream;
2use futures::StreamExt;
3use serde::{Deserialize, Serialize};
4use serde_json::{Value, json};
5use tracing::{Level, enabled, info_span};
6use tracing_futures::Instrument;
7
8use super::completion::{
9    CacheControl, CompletionModel, Content, Message, SystemContent, ToolChoice, ToolDefinition,
10    Usage, apply_cache_control, split_system_messages_from_history,
11};
12use crate::completion::{CompletionError, CompletionRequest, GetTokenUsage};
13use crate::http_client::sse::{Event, GenericEventSource};
14use crate::http_client::{self, HttpClientExt};
15use crate::json_utils::merge_inplace;
16use crate::message::ReasoningContent;
17use crate::streaming::{
18    self, RawStreamingChoice, RawStreamingToolCall, StreamingResult, ToolCallDeltaContent,
19};
20use crate::telemetry::SpanCombinator;
21
/// One event of an Anthropic streaming response, deserialized from the JSON
/// payload of a server-sent event.
///
/// The JSON `type` field selects the variant; variant names are matched in
/// `snake_case` (for example `content_block_delta`).
#[derive(Debug, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum StreamingEvent {
    /// Carries the initial message envelope. The streaming loop in this file
    /// reads the message id, model and input-token count from it.
    MessageStart {
        message: MessageStart,
    },
    /// Opens the content block at `index`; `content_block` is the block as it
    /// looks when it starts.
    ContentBlockStart {
        index: usize,
        content_block: Content,
    },
    /// Incremental update for the content block at `index`.
    ContentBlockDelta {
        index: usize,
        delta: ContentDelta,
    },
    /// Closes the content block at `index`.
    ContentBlockStop {
        index: usize,
    },
    /// Message-level update carrying the stop reason and usage counters. The
    /// streaming loop in this file stops reading once `delta.stop_reason` is set.
    MessageDelta {
        delta: MessageDelta,
        usage: PartialUsage,
    },
    /// `message_stop` event. No handler in this file acts on it.
    MessageStop,
    /// `ping` event. Ignored by the handlers in this file.
    Ping,
    /// Fallback for any `type` value that is not listed above.
    #[serde(other)]
    Unknown,
}
48
/// Payload of the `message` field of a [`StreamingEvent::MessageStart`] event.
#[derive(Debug, Deserialize)]
pub struct MessageStart {
    /// Message identifier; the streaming loop records it on the tracing span
    /// as `gen_ai.response.id`.
    pub id: String,
    pub role: String,
    pub content: Vec<Content>,
    /// Model name reported in the event; the streaming loop records it on the
    /// tracing span.
    pub model: String,
    pub stop_reason: Option<String>,
    pub stop_sequence: Option<String>,
    /// Usage counters; the streaming loop reads `input_tokens` from here.
    pub usage: Usage,
}
59
/// The `delta` payload of a [`StreamingEvent::ContentBlockDelta`] event.
///
/// The JSON `type` field selects the variant, matched in `snake_case`
/// (for example `thinking_delta`).
#[derive(Debug, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ContentDelta {
    /// A fragment of assistant text.
    TextDelta { text: String },
    /// A fragment of a tool call's JSON input; fragments are concatenated by
    /// `handle_event` and parsed when the block stops.
    InputJsonDelta { partial_json: String },
    /// A fragment of thinking (reasoning) text.
    ThinkingDelta { thinking: String },
    /// A fragment of the signature attached to a thinking block.
    SignatureDelta { signature: String },
}
68
/// The `delta` payload of a [`StreamingEvent::MessageDelta`] event.
#[derive(Debug, Deserialize)]
pub struct MessageDelta {
    /// Set once the message has a stop reason; the streaming loop treats a
    /// present value as the signal to capture final usage and stop reading.
    pub stop_reason: Option<String>,
    pub stop_sequence: Option<String>,
}
74
/// Token counters as they appear on a `message_delta` event, also used as the
/// usage record of the final streaming response.
///
/// Only `output_tokens` is required when deserializing; the other counters
/// fall back to `None` when absent.
#[derive(Debug, Deserialize, Clone, Serialize, Default)]
pub struct PartialUsage {
    /// Output token count.
    pub output_tokens: usize,
    /// Input token count. The streaming loop fills this from the
    /// `message_start` event when it builds the final usage.
    #[serde(default)]
    pub input_tokens: Option<usize>,
    /// Input tokens written to the prompt cache, if reported.
    #[serde(default)]
    pub cache_creation_input_tokens: Option<u64>,
    /// Input tokens read from the prompt cache, if reported.
    #[serde(default)]
    pub cache_read_input_tokens: Option<u64>,
}
85
86impl GetTokenUsage for PartialUsage {
87    fn token_usage(&self) -> Option<crate::completion::Usage> {
88        let mut usage = crate::completion::Usage::new();
89
90        usage.input_tokens = self.input_tokens.unwrap_or_default() as u64;
91        usage.output_tokens = self.output_tokens as u64;
92        usage.cached_input_tokens = self.cache_read_input_tokens.unwrap_or(0);
93        usage.cache_creation_input_tokens = self.cache_creation_input_tokens.unwrap_or(0);
94        usage.total_tokens = usage.input_tokens
95            + usage.cached_input_tokens
96            + usage.cache_creation_input_tokens
97            + usage.output_tokens;
98        Some(usage)
99    }
100}
101
/// Accumulator for the tool call whose content block is currently open.
#[derive(Default)]
struct ToolCallState {
    /// Tool name taken from the `tool_use` content block.
    name: String,
    /// Tool-call id taken from the `tool_use` content block.
    id: String,
    /// Locally generated id (via `nanoid`) attached to every item emitted for
    /// this call.
    internal_call_id: String,
    /// Concatenation of all `input_json_delta` fragments received so far.
    input_json: String,
}
109
/// Accumulator for the thinking block that is currently being streamed.
#[derive(Default)]
struct ThinkingState {
    /// Concatenation of all `thinking_delta` fragments received so far.
    thinking: String,
    /// Concatenation of all `signature_delta` fragments received so far.
    signature: String,
}
115
/// Final item of an Anthropic completion stream, yielded once after the event
/// loop ends.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct StreamingCompletionResponse {
    /// Token usage captured from the stream; left at its default value when no
    /// `message_delta` with a stop reason was seen.
    pub usage: PartialUsage,
}
120
121impl GetTokenUsage for StreamingCompletionResponse {
122    fn token_usage(&self) -> Option<crate::completion::Usage> {
123        let mut usage = crate::completion::Usage::new();
124        usage.input_tokens = self.usage.input_tokens.unwrap_or(0) as u64;
125        usage.output_tokens = self.usage.output_tokens as u64;
126        usage.cached_input_tokens = self.usage.cache_read_input_tokens.unwrap_or(0);
127        usage.cache_creation_input_tokens = self.usage.cache_creation_input_tokens.unwrap_or(0);
128        usage.total_tokens = usage.input_tokens
129            + usage.cached_input_tokens
130            + usage.cache_creation_input_tokens
131            + usage.output_tokens;
132
133        Some(usage)
134    }
135}
136
137impl<T> CompletionModel<T>
138where
139    T: HttpClientExt + Clone + Default + 'static,
140{
141    pub(crate) async fn stream(
142        &self,
143        mut completion_request: CompletionRequest,
144    ) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError>
145    {
146        let request_model = completion_request
147            .model
148            .clone()
149            .unwrap_or_else(|| self.model.clone());
150        let span = if tracing::Span::current().is_disabled() {
151            info_span!(
152                target: "rig::completions",
153                "chat_streaming",
154                gen_ai.operation.name = "chat_streaming",
155                gen_ai.provider.name = "anthropic",
156                gen_ai.request.model = &request_model,
157                gen_ai.system_instructions = &completion_request.preamble,
158                gen_ai.response.id = tracing::field::Empty,
159                gen_ai.response.model = &request_model,
160                gen_ai.usage.output_tokens = tracing::field::Empty,
161                gen_ai.usage.input_tokens = tracing::field::Empty,
162                gen_ai.usage.cache_read.input_tokens = tracing::field::Empty,
163                gen_ai.usage.cache_creation.input_tokens = tracing::field::Empty,
164                gen_ai.input.messages = tracing::field::Empty,
165                gen_ai.output.messages = tracing::field::Empty,
166            )
167        } else {
168            tracing::Span::current()
169        };
170        let max_tokens = if let Some(tokens) = completion_request.max_tokens {
171            tokens
172        } else if let Some(tokens) = self.default_max_tokens {
173            tokens
174        } else {
175            return Err(CompletionError::RequestError(
176                "`max_tokens` must be set for Anthropic".into(),
177            ));
178        };
179
180        let mut full_history = vec![];
181        if let Some(docs) = completion_request.normalized_documents() {
182            full_history.push(docs);
183        }
184        full_history.extend(completion_request.chat_history);
185        let (history_system, full_history) = split_system_messages_from_history(full_history);
186
187        let mut messages = full_history
188            .into_iter()
189            .map(Message::try_from)
190            .collect::<Result<Vec<Message>, _>>()?;
191
192        // Convert system prompt to array format for cache_control support
193        let mut system: Vec<SystemContent> =
194            if let Some(preamble) = completion_request.preamble.as_ref() {
195                if preamble.is_empty() {
196                    vec![]
197                } else {
198                    vec![SystemContent::Text {
199                        text: preamble.clone(),
200                        cache_control: None,
201                    }]
202                }
203            } else {
204                vec![]
205            };
206        system.extend(history_system);
207
208        // Apply cache control breakpoints only if prompt_caching is enabled
209        if self.prompt_caching {
210            apply_cache_control(&mut system, &mut messages);
211        }
212
213        let mut body = json!({
214            "model": request_model,
215            "messages": messages,
216            "max_tokens": max_tokens,
217            "stream": true,
218        });
219
220        // Automatic caching: one top-level field; the API moves the breakpoint automatically.
221        // No beta header is required.
222        if self.automatic_caching {
223            let cc = CacheControl::Ephemeral {
224                ttl: self.automatic_caching_ttl.clone(),
225            };
226            merge_inplace(
227                &mut body,
228                json!({ "cache_control": serde_json::to_value(&cc)? }),
229            );
230        }
231
232        // Add system prompt if non-empty
233        if !system.is_empty() {
234            merge_inplace(&mut body, json!({ "system": system }));
235        }
236
237        if let Some(temperature) = completion_request.temperature {
238            merge_inplace(&mut body, json!({ "temperature": temperature }));
239        }
240
241        let mut additional_params_payload = completion_request
242            .additional_params
243            .take()
244            .unwrap_or(Value::Null);
245        let mut additional_tools =
246            extract_tools_from_additional_params(&mut additional_params_payload)?;
247
248        let mut tools = completion_request
249            .tools
250            .into_iter()
251            .map(|tool| ToolDefinition {
252                name: tool.name,
253                description: Some(tool.description),
254                input_schema: tool.parameters,
255            })
256            .map(serde_json::to_value)
257            .collect::<Result<Vec<_>, _>>()?;
258        tools.append(&mut additional_tools);
259
260        if !tools.is_empty() {
261            merge_inplace(
262                &mut body,
263                json!({
264                    "tools": tools,
265                    "tool_choice": ToolChoice::Auto,
266                }),
267            );
268        }
269
270        if !additional_params_payload.is_null() {
271            merge_inplace(&mut body, additional_params_payload)
272        }
273
274        if enabled!(Level::TRACE) {
275            tracing::trace!(
276                target: "rig::completions",
277                "Anthropic completion request: {}",
278                serde_json::to_string_pretty(&body)?
279            );
280        }
281
282        let body: Vec<u8> = serde_json::to_vec(&body)?;
283
284        let req = self
285            .client
286            .post("/v1/messages")?
287            .body(body)
288            .map_err(http_client::Error::Protocol)?;
289
290        let stream = GenericEventSource::new(self.client.clone(), req);
291
292        // Use our SSE decoder to directly handle Server-Sent Events format
293        let stream: StreamingResult<StreamingCompletionResponse> = Box::pin(stream! {
294            let mut current_tool_call: Option<ToolCallState> = None;
295            let mut current_thinking: Option<ThinkingState> = None;
296            let mut sse_stream = Box::pin(stream);
297            let mut input_tokens = 0;
298            let mut final_usage = None;
299
300            let mut text_content = String::new();
301
302            while let Some(sse_result) = sse_stream.next().await {
303                match sse_result {
304                    Ok(Event::Open) => {}
305                    Ok(Event::Message(sse)) => {
306                        // Parse the SSE data as a StreamingEvent
307                        match serde_json::from_str::<StreamingEvent>(&sse.data) {
308                            Ok(event) => {
309                                match &event {
310                                    StreamingEvent::MessageStart { message } => {
311                                        input_tokens = message.usage.input_tokens;
312
313                                        let span = tracing::Span::current();
314                                        span.record("gen_ai.response.id", &message.id);
315                                        span.record("gen_ai.response.model_name", &message.model);
316                                    },
317                                    StreamingEvent::MessageDelta { delta, usage } => {
318                                        if delta.stop_reason.is_some() {
319                                            // cache_creation_input_tokens and cache_read_input_tokens
320                                            // are cumulative totals on message_delta.usage per the
321                                            // Anthropic streaming API spec — use them directly.
322                                            let usage = PartialUsage {
323                                                 output_tokens: usage.output_tokens,
324                                                 input_tokens: Some(input_tokens.try_into().expect("Failed to convert input_tokens to usize")),
325                                                 cache_creation_input_tokens: usage.cache_creation_input_tokens,
326                                                 cache_read_input_tokens: usage.cache_read_input_tokens
327                                            };
328
329                                            let span = tracing::Span::current();
330                                            span.record_token_usage(&usage);
331                                            final_usage = Some(usage);
332                                            break;
333                                        }
334                                    }
335                                    _ => {}
336                                }
337
338                                if let Some(result) = handle_event(&event, &mut current_tool_call, &mut current_thinking) {
339                                    if let Ok(RawStreamingChoice::Message(ref text)) = result {
340                                        text_content += text;
341                                    }
342                                    yield result;
343                                }
344                            },
345                            Err(e) => {
346                                if !sse.data.trim().is_empty() {
347                                    yield Err(CompletionError::ResponseError(
348                                        format!("Failed to parse JSON: {} (Data: {})", e, sse.data)
349                                    ));
350                                }
351                            }
352                        }
353                    },
354                    Err(e) => {
355                        yield Err(CompletionError::ProviderError(format!("SSE Error: {e}")));
356                        break;
357                    }
358                }
359            }
360
361            // Ensure event source is closed when stream ends
362            sse_stream.close();
363
364            yield Ok(RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
365                usage: final_usage.unwrap_or_default()
366            }))
367        }.instrument(span));
368
369        Ok(streaming::StreamingCompletionResponse::stream(stream))
370    }
371}
372
373fn extract_tools_from_additional_params(
374    additional_params: &mut Value,
375) -> Result<Vec<Value>, CompletionError> {
376    if let Some(map) = additional_params.as_object_mut()
377        && let Some(raw_tools) = map.remove("tools")
378    {
379        return serde_json::from_value::<Vec<Value>>(raw_tools).map_err(|err| {
380            CompletionError::RequestError(
381                format!("Invalid Anthropic `additional_params.tools` payload: {err}").into(),
382            )
383        });
384    }
385
386    Ok(Vec::new())
387}
388
389fn handle_event(
390    event: &StreamingEvent,
391    current_tool_call: &mut Option<ToolCallState>,
392    current_thinking: &mut Option<ThinkingState>,
393) -> Option<Result<RawStreamingChoice<StreamingCompletionResponse>, CompletionError>> {
394    match event {
395        StreamingEvent::ContentBlockDelta { delta, .. } => match delta {
396            ContentDelta::TextDelta { text } => {
397                if current_tool_call.is_none() {
398                    return Some(Ok(RawStreamingChoice::Message(text.clone())));
399                }
400                None
401            }
402            ContentDelta::InputJsonDelta { partial_json } => {
403                if let Some(tool_call) = current_tool_call {
404                    tool_call.input_json.push_str(partial_json);
405                    // Emit the delta so UI can show progress
406                    return Some(Ok(RawStreamingChoice::ToolCallDelta {
407                        id: tool_call.id.clone(),
408                        internal_call_id: tool_call.internal_call_id.clone(),
409                        content: ToolCallDeltaContent::Delta(partial_json.clone()),
410                    }));
411                }
412                None
413            }
414            ContentDelta::ThinkingDelta { thinking } => {
415                current_thinking
416                    .get_or_insert_with(ThinkingState::default)
417                    .thinking
418                    .push_str(thinking);
419
420                Some(Ok(RawStreamingChoice::ReasoningDelta {
421                    id: None,
422                    reasoning: thinking.clone(),
423                }))
424            }
425            ContentDelta::SignatureDelta { signature } => {
426                current_thinking
427                    .get_or_insert_with(ThinkingState::default)
428                    .signature
429                    .push_str(signature);
430
431                // Don't yield signature chunks, they will be included in the final Reasoning
432                None
433            }
434        },
435        StreamingEvent::ContentBlockStart { content_block, .. } => match content_block {
436            Content::ToolUse { id, name, .. } => {
437                let internal_call_id = nanoid::nanoid!();
438                *current_tool_call = Some(ToolCallState {
439                    name: name.clone(),
440                    id: id.clone(),
441                    internal_call_id: internal_call_id.clone(),
442                    input_json: String::new(),
443                });
444                Some(Ok(RawStreamingChoice::ToolCallDelta {
445                    id: id.clone(),
446                    internal_call_id,
447                    content: ToolCallDeltaContent::Name(name.clone()),
448                }))
449            }
450            Content::Thinking { .. } => {
451                *current_thinking = Some(ThinkingState::default());
452                None
453            }
454            Content::RedactedThinking { data } => Some(Ok(RawStreamingChoice::Reasoning {
455                id: None,
456                content: ReasoningContent::Redacted { data: data.clone() },
457            })),
458            // Handle other content types - they don't need special handling
459            _ => None,
460        },
461        StreamingEvent::ContentBlockStop { .. } => {
462            if let Some(thinking_state) = Option::take(current_thinking)
463                && !thinking_state.thinking.is_empty()
464            {
465                let signature = if thinking_state.signature.is_empty() {
466                    None
467                } else {
468                    Some(thinking_state.signature)
469                };
470
471                return Some(Ok(RawStreamingChoice::Reasoning {
472                    id: None,
473                    content: ReasoningContent::Text {
474                        text: thinking_state.thinking,
475                        signature,
476                    },
477                }));
478            }
479
480            if let Some(tool_call) = Option::take(current_tool_call) {
481                let json_str = if tool_call.input_json.is_empty() {
482                    "{}"
483                } else {
484                    &tool_call.input_json
485                };
486                match serde_json::from_str(json_str) {
487                    Ok(json_value) => {
488                        let raw_tool_call =
489                            RawStreamingToolCall::new(tool_call.id, tool_call.name, json_value)
490                                .with_internal_call_id(tool_call.internal_call_id);
491                        Some(Ok(RawStreamingChoice::ToolCall(raw_tool_call)))
492                    }
493                    Err(e) => Some(Err(CompletionError::from(e))),
494                }
495            } else {
496                None
497            }
498        }
499        // Ignore other event types or handle as needed
500        StreamingEvent::MessageStart { .. }
501        | StreamingEvent::MessageDelta { .. }
502        | StreamingEvent::MessageStop
503        | StreamingEvent::Ping
504        | StreamingEvent::Unknown => None,
505    }
506}
507
#[cfg(test)]
mod tests {
    use super::*;

    /// Wraps `delta` in a `content_block_delta` event for block 0.
    fn delta_event(delta: ContentDelta) -> StreamingEvent {
        StreamingEvent::ContentBlockDelta { index: 0, delta }
    }

    /// Builds an `input_json_delta` event for block 0.
    fn json_delta_event(partial_json: &str) -> StreamingEvent {
        delta_event(ContentDelta::InputJsonDelta {
            partial_json: partial_json.to_string(),
        })
    }

    /// An in-flight call to `test_tool` (id `tool_123`) with no input yet.
    fn active_tool_call() -> Option<ToolCallState> {
        Some(ToolCallState {
            name: "test_tool".to_string(),
            id: "tool_123".to_string(),
            internal_call_id: nanoid::nanoid!(),
            input_json: String::new(),
        })
    }

    #[test]
    fn test_thinking_delta_deserialization() {
        let json = r#"{"type": "thinking_delta", "thinking": "Let me think about this..."}"#;
        let delta: ContentDelta = serde_json::from_str(json).unwrap();

        let ContentDelta::ThinkingDelta { thinking } = delta else {
            panic!("Expected ThinkingDelta variant");
        };
        assert_eq!(thinking, "Let me think about this...");
    }

    #[test]
    fn test_signature_delta_deserialization() {
        let json = r#"{"type": "signature_delta", "signature": "abc123def456"}"#;
        let delta: ContentDelta = serde_json::from_str(json).unwrap();

        let ContentDelta::SignatureDelta { signature } = delta else {
            panic!("Expected SignatureDelta variant");
        };
        assert_eq!(signature, "abc123def456");
    }

    #[test]
    fn test_thinking_delta_streaming_event_deserialization() {
        let json = r#"{
            "type": "content_block_delta",
            "index": 0,
            "delta": {
                "type": "thinking_delta",
                "thinking": "First, I need to understand the problem."
            }
        }"#;

        let event: StreamingEvent = serde_json::from_str(json).unwrap();

        let StreamingEvent::ContentBlockDelta { index, delta } = event else {
            panic!("Expected ContentBlockDelta event");
        };
        assert_eq!(index, 0);

        let ContentDelta::ThinkingDelta { thinking } = delta else {
            panic!("Expected ThinkingDelta");
        };
        assert_eq!(thinking, "First, I need to understand the problem.");
    }

    #[test]
    fn test_signature_delta_streaming_event_deserialization() {
        let json = r#"{
            "type": "content_block_delta",
            "index": 0,
            "delta": {
                "type": "signature_delta",
                "signature": "ErUBCkYICBgCIkCaGbqC85F4"
            }
        }"#;

        let event: StreamingEvent = serde_json::from_str(json).unwrap();

        let StreamingEvent::ContentBlockDelta { index, delta } = event else {
            panic!("Expected ContentBlockDelta event");
        };
        assert_eq!(index, 0);

        let ContentDelta::SignatureDelta { signature } = delta else {
            panic!("Expected SignatureDelta");
        };
        assert_eq!(signature, "ErUBCkYICBgCIkCaGbqC85F4");
    }

    #[test]
    fn test_handle_thinking_delta_event() {
        let event = delta_event(ContentDelta::ThinkingDelta {
            thinking: "Analyzing the request...".to_string(),
        });

        let mut tool_call_state = None;
        let mut thinking_state = None;
        let result = handle_event(&event, &mut tool_call_state, &mut thinking_state);

        assert!(result.is_some());
        match result.unwrap().unwrap() {
            RawStreamingChoice::ReasoningDelta { id, reasoning, .. } => {
                assert_eq!(id, None);
                assert_eq!(reasoning, "Analyzing the request...");
            }
            _ => panic!("Expected ReasoningDelta choice"),
        }

        // The fragment must also be accumulated in the thinking state.
        assert!(thinking_state.is_some());
        let accumulated = thinking_state.unwrap();
        assert_eq!(accumulated.thinking, "Analyzing the request...");
    }

    #[test]
    fn test_handle_signature_delta_event() {
        let event = delta_event(ContentDelta::SignatureDelta {
            signature: "test_signature".to_string(),
        });

        let mut tool_call_state = None;
        let mut thinking_state = None;
        let result = handle_event(&event, &mut tool_call_state, &mut thinking_state);

        // Signature fragments are never emitted on their own...
        assert!(result.is_none());

        // ...but they are captured in the thinking state.
        assert!(thinking_state.is_some());
        let accumulated = thinking_state.unwrap();
        assert_eq!(accumulated.signature, "test_signature");
    }

    #[test]
    fn test_handle_redacted_thinking_content_block_start_event() {
        let event = StreamingEvent::ContentBlockStart {
            index: 0,
            content_block: Content::RedactedThinking {
                data: "redacted_blob".to_string(),
            },
        };
        let mut tool_call_state = None;
        let mut thinking_state = None;
        let result = handle_event(&event, &mut tool_call_state, &mut thinking_state);

        assert!(result.is_some());
        let RawStreamingChoice::Reasoning {
            content: ReasoningContent::Redacted { data },
            ..
        } = result.unwrap().unwrap()
        else {
            panic!("Expected Redacted reasoning chunk");
        };
        assert_eq!(data, "redacted_blob");
    }

    #[test]
    fn test_handle_text_delta_event() {
        let event = delta_event(ContentDelta::TextDelta {
            text: "Hello, world!".to_string(),
        });

        let mut tool_call_state = None;
        let mut thinking_state = None;
        let result = handle_event(&event, &mut tool_call_state, &mut thinking_state);

        assert!(result.is_some());
        let RawStreamingChoice::Message(text) = result.unwrap().unwrap() else {
            panic!("Expected Message choice");
        };
        assert_eq!(text, "Hello, world!");
    }

    #[test]
    fn test_thinking_delta_does_not_interfere_with_tool_calls() {
        // A thinking delta is still processed while a tool call is in flight.
        let event = delta_event(ContentDelta::ThinkingDelta {
            thinking: "Thinking while tool is active...".to_string(),
        });

        let mut tool_call_state = active_tool_call();
        let mut thinking_state = None;

        let result = handle_event(&event, &mut tool_call_state, &mut thinking_state);

        assert!(result.is_some());
        let RawStreamingChoice::ReasoningDelta { reasoning, .. } = result.unwrap().unwrap()
        else {
            panic!("Expected ReasoningDelta choice");
        };
        assert_eq!(reasoning, "Thinking while tool is active...");

        // The in-flight tool call is left alone.
        assert!(tool_call_state.is_some());
    }

    #[test]
    fn test_handle_input_json_delta_event() {
        let event = json_delta_event("{\"arg\":\"value");

        let mut tool_call_state = active_tool_call();
        let mut thinking_state = None;

        let result = handle_event(&event, &mut tool_call_state, &mut thinking_state);

        // The fragment is forwarded as a ToolCallDelta.
        assert!(result.is_some());
        let choice = result.unwrap().unwrap();

        match choice {
            RawStreamingChoice::ToolCallDelta {
                id,
                internal_call_id: _,
                content,
            } => {
                assert_eq!(id, "tool_123");
                let ToolCallDeltaContent::Delta(delta) = content else {
                    panic!("Expected Delta content");
                };
                assert_eq!(delta, "{\"arg\":\"value");
            }
            _ => panic!("Expected ToolCallDelta choice, got {:?}", choice),
        }

        // The fragment is also appended to the accumulated input.
        assert!(tool_call_state.is_some());
        let state = tool_call_state.unwrap();
        assert_eq!(state.input_json, "{\"arg\":\"value");
    }

    #[test]
    fn test_tool_call_accumulation_with_multiple_deltas() {
        let mut tool_call_state = active_tool_call();
        let mut thinking_state = None;

        // Feed the arguments in three fragments; each one must be forwarded.
        for fragment in ["{\"location\":", "\"Paris\",", "\"temp\":\"20C\"}"] {
            let event = json_delta_event(fragment);
            let result = handle_event(&event, &mut tool_call_state, &mut thinking_state);
            assert!(result.is_some());
        }

        // The fragments are concatenated in arrival order.
        assert!(tool_call_state.is_some());
        let state = tool_call_state.as_ref().unwrap();
        assert_eq!(
            state.input_json,
            "{\"location\":\"Paris\",\"temp\":\"20C\"}"
        );

        // Closing the block emits the complete, parsed tool call.
        let stop_event = StreamingEvent::ContentBlockStop { index: 0 };
        let final_result = handle_event(&stop_event, &mut tool_call_state, &mut thinking_state);
        assert!(final_result.is_some());

        match final_result.unwrap().unwrap() {
            RawStreamingChoice::ToolCall(RawStreamingToolCall {
                id,
                name,
                arguments,
                ..
            }) => {
                assert_eq!(id, "tool_123");
                assert_eq!(name, "test_tool");
                assert_eq!(
                    arguments.get("location").unwrap().as_str().unwrap(),
                    "Paris"
                );
                assert_eq!(arguments.get("temp").unwrap().as_str().unwrap(), "20C");
            }
            other => panic!("Expected ToolCall, got {:?}", other),
        }

        // The tool call state is consumed by the stop event.
        assert!(tool_call_state.is_none());
    }
}
842}