Skip to main content

rig_core/providers/anthropic/
streaming.rs

1use async_stream::stream;
2use futures::StreamExt;
3use serde::{Deserialize, Serialize};
4use serde_json::{Value, json};
5use tracing::{Level, enabled, info_span};
6use tracing_futures::Instrument;
7
8use super::completion::{
9    AnthropicCompatibleProvider, CacheTtl, Content, GenericCompletionModel, Message, SystemContent,
10    ToolChoice, Usage, apply_prompt_cache_control, build_tool_definitions,
11    resolve_top_level_cache_control, split_system_messages_from_history,
12    supports_mid_conversation_system_messages,
13};
14use crate::completion::{CompletionError, CompletionRequest, GetTokenUsage};
15use crate::http_client::sse::{Event, GenericEventSource};
16use crate::http_client::{self, HttpClientExt};
17use crate::json_utils::merge_inplace;
18use crate::message::ReasoningContent;
19use crate::streaming::{
20    self, RawStreamingChoice, RawStreamingToolCall, StreamingResult, ToolCallDeltaContent,
21};
22use crate::telemetry::SpanCombinator;
23use crate::wasm_compat::{WasmCompatSend, WasmCompatSync};
24use std::collections::HashMap;
25
26fn create_streaming_request_body(
27    request_model: String,
28    completion_request: &mut CompletionRequest,
29    max_tokens: u64,
30    prompt_caching: bool,
31    automatic_caching: bool,
32    automatic_caching_ttl: Option<CacheTtl>,
33) -> Result<Value, CompletionError> {
34    let chat_history = completion_request.chat_history_with_documents();
35    let (history_system, chat_history) = split_system_messages_from_history(
36        chat_history,
37        supports_mid_conversation_system_messages(&request_model),
38    );
39    let mut full_history = vec![];
40    full_history.extend(chat_history);
41
42    let mut messages = full_history
43        .into_iter()
44        .map(Message::try_from)
45        .collect::<Result<Vec<Message>, _>>()?;
46
47    // Convert system prompt to array format for cache_control support.
48    let mut system: Vec<SystemContent> =
49        if let Some(preamble) = completion_request.preamble.as_ref() {
50            if preamble.is_empty() {
51                vec![]
52            } else {
53                vec![SystemContent::Text {
54                    text: preamble.clone(),
55                    cache_control: None,
56                }]
57            }
58        } else {
59            vec![]
60        };
61    system.extend(history_system);
62
63    let mut additional_params_payload = completion_request
64        .additional_params
65        .take()
66        .unwrap_or(Value::Null);
67    let top_level_cache_control = resolve_top_level_cache_control(
68        automatic_caching,
69        automatic_caching_ttl,
70        &mut additional_params_payload,
71    )?;
72    let mut tools = build_tool_definitions(
73        std::mem::take(&mut completion_request.tools),
74        &mut additional_params_payload,
75    )?;
76
77    apply_prompt_cache_control(
78        &mut system,
79        &mut messages,
80        &mut tools,
81        prompt_caching,
82        top_level_cache_control.as_ref(),
83    )?;
84
85    let mut body = json!({
86        "model": request_model,
87        "messages": messages,
88        "max_tokens": max_tokens,
89        "stream": true,
90    });
91
92    // Automatic caching: one top-level field; the API moves the breakpoint automatically.
93    // No beta header is required.
94    if let Some(cache_control) = top_level_cache_control {
95        merge_inplace(
96            &mut body,
97            json!({ "cache_control": serde_json::to_value(&cache_control)? }),
98        );
99    }
100
101    // Add system prompt if non-empty.
102    if !system.is_empty() {
103        merge_inplace(&mut body, json!({ "system": system }));
104    }
105
106    if let Some(temperature) = completion_request.temperature {
107        merge_inplace(&mut body, json!({ "temperature": temperature }));
108    }
109
110    if !tools.is_empty() {
111        merge_inplace(
112            &mut body,
113            json!({
114                "tools": tools,
115                "tool_choice": ToolChoice::Auto,
116            }),
117        );
118    }
119
120    if !additional_params_payload.is_null() {
121        merge_inplace(&mut body, additional_params_payload)
122    }
123
124    Ok(body)
125}
126
/// One decoded event from an Anthropic Messages streaming (SSE) response.
///
/// The JSON `type` field selects the variant, written in snake_case
/// (e.g. `"content_block_delta"` -> `ContentBlockDelta`).
/// Any `type` not listed here deserializes as `Unknown`, so new event kinds do not break
/// parsing.
#[derive(Debug, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum StreamingEvent {
    /// `message_start`: carries the message id, model and initial usage.
    MessageStart {
        message: MessageStart,
    },
    /// `content_block_start`: opens the content block at `index` with its initial payload.
    ContentBlockStart {
        index: usize,
        content_block: Content,
    },
    /// `content_block_delta`: incremental data for the content block at `index`.
    ContentBlockDelta {
        index: usize,
        delta: ContentDelta,
    },
    /// `content_block_stop`: closes the content block at `index`.
    ContentBlockStop {
        index: usize,
    },
    /// `message_delta`: message-level changes (e.g. stop reason) plus usage counters.
    MessageDelta {
        delta: MessageDelta,
        usage: PartialUsage,
    },
    /// `message_stop`; produces no output item.
    MessageStop,
    /// `ping`; produces no output item.
    Ping,
    /// Fallback for any event `type` not modelled above.
    #[serde(other)]
    Unknown,
}
153
/// Payload of a `message_start` event.
///
/// During streaming, `id` and `model` are recorded on the tracing span. `usage.input_tokens`
/// is kept as the input token count reported in the final usage.
#[derive(Debug, Deserialize)]
pub struct MessageStart {
    pub id: String,
    pub role: String,
    pub content: Vec<Content>,
    pub model: String,
    pub stop_reason: Option<String>,
    pub stop_sequence: Option<String>,
    pub usage: Usage,
}
164
/// The `delta` payload of a `content_block_delta` event, tagged by its `type` field.
#[derive(Debug, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ContentDelta {
    /// `text_delta`: a chunk of assistant text.
    TextDelta {
        text: String,
    },
    /// `input_json_delta`: a fragment of a tool call's JSON input. Fragments are
    /// concatenated and parsed only when the block stops.
    InputJsonDelta {
        partial_json: String,
    },
    /// `thinking_delta`: a chunk of extended-thinking text.
    ThinkingDelta {
        thinking: String,
    },
    /// `signature_delta`: a chunk of the signature for the current thinking block.
    SignatureDelta {
        signature: String,
    },
    /// `citations_delta`: a citation attached to the current text block.
    CitationsDelta {
        citation: super::completion::Citation,
    },
    /// Forward-compatibility fallback. Any delta type Anthropic adds in the
    /// future that this crate does not yet model deserializes here so the
    /// surrounding [`StreamingEvent`] still parses.
    #[serde(other)]
    Unknown,
}
189
/// Message-level fields carried by a `message_delta` event.
///
/// The streaming loop treats a `Some` `stop_reason` as the end of the response and takes
/// the final usage from that event.
#[derive(Debug, Deserialize)]
pub struct MessageDelta {
    pub stop_reason: Option<String>,
    pub stop_sequence: Option<String>,
}
195
/// Token usage reported on a `message_delta` event.
///
/// This is also the usage carried by the final streamed response. Only `output_tokens` is
/// required when deserializing; the other counts become `None` when absent.
#[derive(Debug, Deserialize, Clone, Serialize, Default)]
pub struct PartialUsage {
    pub output_tokens: usize,
    #[serde(default)]
    pub input_tokens: Option<usize>,
    #[serde(default)]
    pub cache_creation_input_tokens: Option<u64>,
    #[serde(default)]
    pub cache_read_input_tokens: Option<u64>,
}
206
207impl GetTokenUsage for PartialUsage {
208    fn token_usage(&self) -> crate::completion::Usage {
209        let mut usage = crate::completion::Usage::new();
210
211        usage.input_tokens = self.input_tokens.unwrap_or_default() as u64;
212        usage.output_tokens = self.output_tokens as u64;
213        usage.cached_input_tokens = self.cache_read_input_tokens.unwrap_or(0);
214        usage.cache_creation_input_tokens = self.cache_creation_input_tokens.unwrap_or(0);
215        usage.total_tokens = usage.input_tokens
216            + usage.cached_input_tokens
217            + usage.cache_creation_input_tokens
218            + usage.output_tokens;
219        usage
220    }
221}
222
/// Accumulator for a client tool call (a `tool_use` content block) while its input JSON
/// streams in through `input_json_delta` events.
#[derive(Default)]
struct ToolCallState {
    /// Tool name from the `content_block_start` event.
    name: String,
    /// Tool-use id from the `content_block_start` event.
    id: String,
    /// Locally generated id (via `nanoid!`). It is attached to every delta and to the
    /// final tool call.
    internal_call_id: String,
    /// Concatenated `partial_json` fragments, parsed at `content_block_stop`.
    input_json: String,
}
230
/// Accumulator for a server-executed tool call (a `server_tool_use` content block).
///
/// Stored keyed by content-block index in `handle_event`.
struct ServerToolUseState {
    name: String,
    id: String,
    /// Input from the `content_block_start` event. Used when no JSON deltas arrived;
    /// a `null` value is replaced by `{}`.
    initial_input: Value,
    /// Concatenated `partial_json` fragments. When non-empty, this is parsed and used
    /// instead of `initial_input`.
    input_json: String,
}
237
/// Accumulator for an extended-thinking block.
///
/// Holds the streamed thinking text and signature, which are emitted together as one
/// `Reasoning` item when the block stops.
#[derive(Default)]
struct ThinkingState {
    thinking: String,
    signature: String,
}
243
/// Payload of the `FinalResponse` item yielded at the end of an Anthropic stream.
///
/// `usage` comes from the `message_delta` that reported a stop reason. If the stream ended
/// without one, it is `PartialUsage::default()`.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct StreamingCompletionResponse {
    pub usage: PartialUsage,
}
248
249impl GetTokenUsage for StreamingCompletionResponse {
250    fn token_usage(&self) -> crate::completion::Usage {
251        let mut usage = crate::completion::Usage::new();
252        usage.input_tokens = self.usage.input_tokens.unwrap_or(0) as u64;
253        usage.output_tokens = self.usage.output_tokens as u64;
254        usage.cached_input_tokens = self.usage.cache_read_input_tokens.unwrap_or(0);
255        usage.cache_creation_input_tokens = self.usage.cache_creation_input_tokens.unwrap_or(0);
256        usage.total_tokens = usage.input_tokens
257            + usage.cached_input_tokens
258            + usage.cache_creation_input_tokens
259            + usage.output_tokens;
260
261        usage
262    }
263}
264
265impl<Ext, T> GenericCompletionModel<Ext, T>
266where
267    T: HttpClientExt + Clone + Default + 'static,
268    Ext: AnthropicCompatibleProvider + Clone + WasmCompatSend + WasmCompatSync + 'static,
269{
270    pub(crate) async fn stream(
271        &self,
272        mut completion_request: CompletionRequest,
273    ) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError>
274    {
275        let request_model = completion_request
276            .model
277            .clone()
278            .unwrap_or_else(|| self.model.clone());
279        let span = if tracing::Span::current().is_disabled() {
280            info_span!(
281                target: "rig::completions",
282                "chat_streaming",
283                gen_ai.operation.name = "chat_streaming",
284                gen_ai.provider.name = Ext::PROVIDER_NAME,
285                gen_ai.request.model = &request_model,
286                gen_ai.system_instructions = &completion_request.preamble,
287                gen_ai.response.id = tracing::field::Empty,
288                gen_ai.response.model = &request_model,
289                gen_ai.usage.output_tokens = tracing::field::Empty,
290                gen_ai.usage.input_tokens = tracing::field::Empty,
291                gen_ai.usage.cache_read.input_tokens = tracing::field::Empty,
292                gen_ai.usage.cache_creation.input_tokens = tracing::field::Empty,
293                gen_ai.input.messages = tracing::field::Empty,
294                gen_ai.output.messages = tracing::field::Empty,
295            )
296        } else {
297            tracing::Span::current()
298        };
299        let max_tokens = if let Some(tokens) = completion_request.max_tokens {
300            tokens
301        } else if let Some(tokens) = self.default_max_tokens {
302            tokens
303        } else {
304            return Err(CompletionError::RequestError(
305                "`max_tokens` must be set for Anthropic".into(),
306            ));
307        };
308
309        let body = create_streaming_request_body(
310            request_model,
311            &mut completion_request,
312            max_tokens,
313            self.prompt_caching,
314            self.automatic_caching,
315            self.automatic_caching_ttl.clone(),
316        )?;
317
318        if enabled!(Level::TRACE) {
319            tracing::trace!(
320                target: "rig::completions",
321                "Anthropic completion request: {}",
322                serde_json::to_string_pretty(&body)?
323            );
324        }
325
326        let body: Vec<u8> = serde_json::to_vec(&body)?;
327
328        let req = self
329            .client
330            .post("/v1/messages")?
331            .body(body)
332            .map_err(http_client::Error::Protocol)?;
333
334        let stream = GenericEventSource::new(self.client.clone(), req);
335
336        // Use our SSE decoder to directly handle Server-Sent Events format
337        let stream: StreamingResult<StreamingCompletionResponse> = Box::pin(stream! {
338            let mut current_tool_call: Option<ToolCallState> = None;
339            let mut server_tool_uses: HashMap<usize, ServerToolUseState> = HashMap::new();
340            let mut current_thinking: Option<ThinkingState> = None;
341            let mut sse_stream = Box::pin(stream);
342            let mut input_tokens = 0;
343            let mut final_usage = None;
344
345            let mut text_content = String::new();
346
347            while let Some(sse_result) = sse_stream.next().await {
348                match sse_result {
349                    Ok(Event::Open) => {}
350                    Ok(Event::Message(sse)) => {
351                        // Parse the SSE data as a StreamingEvent
352                        match serde_json::from_str::<StreamingEvent>(&sse.data) {
353                            Ok(event) => {
354                                match &event {
355                                    StreamingEvent::MessageStart { message } => {
356                                        input_tokens = message.usage.input_tokens;
357
358                                        let span = tracing::Span::current();
359                                        span.record("gen_ai.response.id", &message.id);
360                                        span.record("gen_ai.response.model", &message.model);
361                                    },
362                                    StreamingEvent::MessageDelta { delta, usage } => {
363                                        if delta.stop_reason.is_some() {
364                                            // cache_creation_input_tokens and cache_read_input_tokens
365                                            // are cumulative totals on message_delta.usage per the
366                                            // Anthropic streaming API spec — use them directly.
367                                            let usage = PartialUsage {
368                                                 output_tokens: usage.output_tokens,
369                                                 input_tokens: usize::try_from(input_tokens).ok(),
370                                                 cache_creation_input_tokens: usage.cache_creation_input_tokens,
371                                                 cache_read_input_tokens: usage.cache_read_input_tokens
372                                            };
373
374                                            let span = tracing::Span::current();
375                                            span.record_token_usage(&usage);
376                                            final_usage = Some(usage);
377                                            break;
378                                        }
379                                    }
380                                    _ => {}
381                                }
382
383                                if let Some(result) = handle_event(
384                                    &event,
385                                    &mut current_tool_call,
386                                    &mut server_tool_uses,
387                                    &mut current_thinking,
388                                ) {
389                                    if let Ok(RawStreamingChoice::Message(ref text)) = result {
390                                        text_content += text;
391                                    }
392                                    yield result;
393                                }
394                            },
395                            Err(e) => {
396                                if !sse.data.trim().is_empty() {
397                                    yield Err(CompletionError::ResponseError(
398                                        format!("Failed to parse JSON: {} (Data: {})", e, sse.data)
399                                    ));
400                                }
401                            }
402                        }
403                    },
404                    Err(e) => {
405                        yield Err(CompletionError::ProviderError(format!("SSE Error: {e}")));
406                        break;
407                    }
408                }
409            }
410
411            // Ensure event source is closed when stream ends
412            sse_stream.close();
413
414            yield Ok(RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
415                usage: final_usage.unwrap_or_default()
416            }))
417        }.instrument(span));
418
419        Ok(streaming::StreamingCompletionResponse::stream(stream))
420    }
421}
422
/// Translates one parsed [`StreamingEvent`] into at most one output-stream item.
///
/// State is accumulated across calls in the three mutable arguments:
/// - `current_tool_call`: the client `tool_use` block whose input JSON is streaming;
/// - `server_tool_uses`: open `server_tool_use` blocks, keyed by content-block index;
/// - `current_thinking`: thinking text and signature for the open thinking block.
///
/// Returns `None` when the event yields nothing. Examples are signature chunks,
/// server-tool input fragments, and message-level events (`message_start`,
/// `message_delta`, `message_stop`, `ping`), which the caller handles itself.
/// Returns `Some(Err(..))` when accumulated tool input is not valid JSON at block stop.
fn handle_event(
    event: &StreamingEvent,
    current_tool_call: &mut Option<ToolCallState>,
    server_tool_uses: &mut HashMap<usize, ServerToolUseState>,
    current_thinking: &mut Option<ThinkingState>,
) -> Option<Result<RawStreamingChoice<StreamingCompletionResponse>, CompletionError>> {
    match event {
        StreamingEvent::ContentBlockDelta { index, delta } => match delta {
            ContentDelta::TextDelta { text } => {
                // Text deltas are suppressed while a client tool call is open.
                if current_tool_call.is_none() {
                    return Some(Ok(RawStreamingChoice::Message(text.clone())));
                }
                None
            }
            ContentDelta::InputJsonDelta { partial_json } => {
                // Server tool input is buffered silently and emitted whole at block stop.
                if let Some(server_tool_use) = server_tool_uses.get_mut(index) {
                    server_tool_use.input_json.push_str(partial_json);
                    return None;
                }

                if let Some(tool_call) = current_tool_call {
                    tool_call.input_json.push_str(partial_json);
                    // Emit the delta so UI can show progress
                    return Some(Ok(RawStreamingChoice::ToolCallDelta {
                        id: tool_call.id.clone(),
                        internal_call_id: tool_call.internal_call_id.clone(),
                        content: ToolCallDeltaContent::Delta(partial_json.clone()),
                    }));
                }
                // A JSON fragment with no open tool call of either kind is dropped.
                None
            }
            ContentDelta::ThinkingDelta { thinking } => {
                // Tolerate a missing `content_block_start` by creating the state on demand.
                current_thinking
                    .get_or_insert_with(ThinkingState::default)
                    .thinking
                    .push_str(thinking);

                // Each chunk is streamed immediately and also kept for the final Reasoning item.
                Some(Ok(RawStreamingChoice::ReasoningDelta {
                    id: None,
                    reasoning: thinking.clone(),
                }))
            }
            ContentDelta::SignatureDelta { signature } => {
                current_thinking
                    .get_or_insert_with(ThinkingState::default)
                    .signature
                    .push_str(signature);

                // Don't yield signature chunks, they will be included in the final Reasoning
                None
            }
            ContentDelta::CitationsDelta { citation } => {
                Some(Ok(RawStreamingChoice::TextAdditionalParams(json!({
                    "citations": [citation]
                }))))
            }
            ContentDelta::Unknown => None,
        },
        StreamingEvent::ContentBlockStart {
            index,
            content_block,
        } => match content_block {
            // The initial text of the block is not emitted here; only its citations
            // (if any) are forwarded as additional params.
            Content::Text { citations, .. } => {
                let additional_params = (!citations.is_empty()).then(|| {
                    json!({
                        "citations": citations
                    })
                });
                Some(Ok(RawStreamingChoice::TextStart { additional_params }))
            }
            // Server tool calls are tracked by block index until their stop event.
            Content::ServerToolUse { id, name, input } => {
                server_tool_uses.insert(
                    *index,
                    ServerToolUseState {
                        name: name.clone(),
                        id: id.clone(),
                        initial_input: input.clone(),
                        input_json: String::new(),
                    },
                );
                None
            }
            // Web search results are passed through verbatim under the raw-content key.
            raw @ Content::WebSearchToolResult { .. } => Some(Ok(RawStreamingChoice::TextStart {
                additional_params: Some(json!({
                    super::completion::ANTHROPIC_RAW_CONTENT_KEY: raw
                })),
            })),
            // Any input on the start event is ignored; input comes from JSON deltas.
            Content::ToolUse { id, name, .. } => {
                let internal_call_id = nanoid::nanoid!();
                *current_tool_call = Some(ToolCallState {
                    name: name.clone(),
                    id: id.clone(),
                    internal_call_id: internal_call_id.clone(),
                    input_json: String::new(),
                });
                Some(Ok(RawStreamingChoice::ToolCallDelta {
                    id: id.clone(),
                    internal_call_id,
                    content: ToolCallDeltaContent::Name(name.clone()),
                }))
            }
            // Start a fresh accumulator; content on the start event is ignored.
            Content::Thinking { .. } => {
                *current_thinking = Some(ThinkingState::default());
                None
            }
            Content::RedactedThinking { data } => Some(Ok(RawStreamingChoice::Reasoning {
                id: None,
                content: ReasoningContent::Redacted { data: data.clone() },
            })),
            // Handle other content types - they don't need special handling
            _ => None,
        },
        StreamingEvent::ContentBlockStop { index } => {
            // Checked first. The thinking state is taken unconditionally, so a thinking
            // block with empty text (e.g. only a signature) is discarded here and control
            // falls through to the checks below.
            // NOTE(review): only server tool uses are matched by `index`. The thinking and
            // client tool-call states are closed by whichever block stops next.
            if let Some(thinking_state) = Option::take(current_thinking)
                && !thinking_state.thinking.is_empty()
            {
                let signature = if thinking_state.signature.is_empty() {
                    None
                } else {
                    Some(thinking_state.signature)
                };

                return Some(Ok(RawStreamingChoice::Reasoning {
                    id: None,
                    content: ReasoningContent::Text {
                        text: thinking_state.thinking,
                        signature,
                    },
                }));
            }

            if let Some(server_tool_use) = server_tool_uses.remove(index) {
                // Prefer streamed JSON. Otherwise use the start-event input, or `{}` if
                // that was null.
                let input = if server_tool_use.input_json.is_empty() {
                    if server_tool_use.initial_input.is_null() {
                        json!({})
                    } else {
                        server_tool_use.initial_input
                    }
                } else {
                    match serde_json::from_str(&server_tool_use.input_json) {
                        Ok(json_value) => json_value,
                        Err(e) => return Some(Err(CompletionError::from(e))),
                    }
                };

                return Some(Ok(RawStreamingChoice::TextStart {
                    additional_params: Some(json!({
                        super::completion::ANTHROPIC_RAW_CONTENT_KEY: Content::ServerToolUse {
                            id: server_tool_use.id,
                            name: server_tool_use.name,
                            input,
                        }
                    })),
                }));
            }

            // A client tool call with no streamed input is reported with `{}` as arguments.
            if let Some(tool_call) = Option::take(current_tool_call) {
                let json_str = if tool_call.input_json.is_empty() {
                    "{}"
                } else {
                    &tool_call.input_json
                };
                match serde_json::from_str(json_str) {
                    Ok(json_value) => {
                        let raw_tool_call =
                            RawStreamingToolCall::new(tool_call.id, tool_call.name, json_value)
                                .with_internal_call_id(tool_call.internal_call_id);
                        Some(Ok(RawStreamingChoice::ToolCall(raw_tool_call)))
                    }
                    Err(e) => Some(Err(CompletionError::from(e))),
                }
            } else {
                None
            }
        }
        // Ignore other event types or handle as needed
        StreamingEvent::MessageStart { .. }
        | StreamingEvent::MessageDelta { .. }
        | StreamingEvent::MessageStop
        | StreamingEvent::Ping
        | StreamingEvent::Unknown => None,
    }
}
606
607#[cfg(test)]
608mod tests {
609    use super::super::completion::{CLAUDE_OPUS_4_8, CacheControl, CacheTtl};
610    use super::*;
611    use crate::OneOrMany;
612    use crate::completion::Message as RigMessage;
613    use crate::completion::request::Document as RigDocument;
614    use async_stream::stream;
615    use futures::StreamExt;
616
617    #[cfg(not(all(feature = "wasm", target_arch = "wasm32")))]
618    fn to_stream_result(
619        stream: impl futures::Stream<
620            Item = Result<RawStreamingChoice<StreamingCompletionResponse>, CompletionError>,
621        > + Send
622        + 'static,
623    ) -> crate::streaming::StreamingResult<StreamingCompletionResponse> {
624        Box::pin(stream)
625    }
626
627    #[cfg(all(feature = "wasm", target_arch = "wasm32"))]
628    fn to_stream_result(
629        stream: impl futures::Stream<
630            Item = Result<RawStreamingChoice<StreamingCompletionResponse>, CompletionError>,
631        > + 'static,
632    ) -> crate::streaming::StreamingResult<StreamingCompletionResponse> {
633        Box::pin(stream)
634    }
635
636    #[test]
637    fn test_streaming_tool_build_marks_final_combined_tool() {
638        let mut additional_params = json!({
639            "tools": [{
640                "name": "provider_tool",
641                "description": "Provider tool",
642                "input_schema": {"type": "object"}
643            }]
644        });
645
646        let mut tools = build_tool_definitions(
647            vec![crate::completion::ToolDefinition {
648                name: "rig_tool".to_string(),
649                description: "Rig tool".to_string(),
650                parameters: json!({"type": "object", "properties": {}}),
651            }],
652            &mut additional_params,
653        )
654        .unwrap();
655        let mut system: Vec<SystemContent> = Vec::new();
656        let mut messages: Vec<Message> = Vec::new();
657        apply_prompt_cache_control(&mut system, &mut messages, &mut tools, true, None).unwrap();
658
659        assert_eq!(tools.len(), 2);
660        assert!(tools[0].get("cache_control").is_none());
661        assert_eq!(tools[1]["name"], "provider_tool");
662        assert_eq!(tools[1]["cache_control"]["type"], "ephemeral");
663    }
664
665    #[test]
666    fn streaming_request_keeps_documents_after_leading_system_messages() {
667        let mut request = CompletionRequest {
668            model: None,
669            preamble: None,
670            chat_history: OneOrMany::many(vec![
671                RigMessage::system("System prompt"),
672                RigMessage::assistant("Earlier assistant turn"),
673                RigMessage::system("Mid-conversation instruction"),
674                RigMessage::user("Prompt"),
675            ])
676            .unwrap(),
677            documents: vec![RigDocument {
678                id: "doc1".to_string(),
679                text: "Document text.".to_string(),
680                additional_props: Default::default(),
681            }],
682            tools: vec![],
683            temperature: None,
684            max_tokens: Some(64),
685            tool_choice: None,
686            additional_params: None,
687            output_schema: None,
688        };
689
690        let body = create_streaming_request_body(
691            CLAUDE_OPUS_4_8.to_string(),
692            &mut request,
693            64,
694            false,
695            false,
696            None,
697        )
698        .expect("streaming request body should build");
699
700        assert_eq!(body["system"][0]["text"], "System prompt");
701        assert_eq!(body["system"][1]["text"], "Mid-conversation instruction");
702        let messages = body["messages"]
703            .as_array()
704            .expect("messages should be array");
705        assert_eq!(messages.len(), 3);
706        assert_eq!(messages[0]["role"], "user");
707        assert!(
708            messages[0].to_string().contains("<file id: doc1>"),
709            "document message should follow top-level system: {messages:?}"
710        );
711        assert_eq!(messages[1]["role"], "assistant");
712        assert_eq!(messages[2]["role"], "user");
713        assert_eq!(
714            messages
715                .iter()
716                .filter(|message| message.to_string().contains("<file id: doc1>"))
717                .count(),
718            1,
719            "document message should appear exactly once: {messages:?}"
720        );
721    }
722
723    #[test]
724    fn test_streaming_prompt_cache_control_uses_raw_top_level_ttl() {
725        let mut additional_params = json!({
726            "cache_control": {"type": "ephemeral", "ttl": "1h"}
727        });
728        let top_level_cache_control =
729            resolve_top_level_cache_control(false, None, &mut additional_params).unwrap();
730        let mut tools = build_tool_definitions(
731            vec![crate::completion::ToolDefinition {
732                name: "rig_tool".to_string(),
733                description: "Rig tool".to_string(),
734                parameters: json!({"type": "object", "properties": {}}),
735            }],
736            &mut additional_params,
737        )
738        .unwrap();
739        let mut system = vec![SystemContent::Text {
740            text: "System prompt".to_string(),
741            cache_control: None,
742        }];
743        let mut messages: Vec<Message> = Vec::new();
744
745        apply_prompt_cache_control(
746            &mut system,
747            &mut messages,
748            &mut tools,
749            true,
750            top_level_cache_control.as_ref(),
751        )
752        .unwrap();
753
754        assert_eq!(tools[0]["cache_control"]["type"], "ephemeral");
755        assert_eq!(tools[0]["cache_control"]["ttl"], "1h");
756        match &system[0] {
757            SystemContent::Text {
758                cache_control: Some(CacheControl::Ephemeral { ttl }),
759                ..
760            } => assert_eq!(ttl.as_ref(), Some(&CacheTtl::OneHour)),
761            other => panic!("expected system cache_control, got {other:?}"),
762        }
763        assert!(additional_params.get("cache_control").is_none());
764    }
765
766    fn handle_event(
767        event: &StreamingEvent,
768        current_tool_call: &mut Option<ToolCallState>,
769        current_thinking: &mut Option<ThinkingState>,
770    ) -> Option<Result<RawStreamingChoice<StreamingCompletionResponse>, CompletionError>> {
771        let mut server_tool_uses = HashMap::new();
772        super::handle_event(
773            event,
774            current_tool_call,
775            &mut server_tool_uses,
776            current_thinking,
777        )
778    }
779
780    #[test]
781    fn test_thinking_delta_deserialization() {
782        let json = r#"{"type": "thinking_delta", "thinking": "Let me think about this..."}"#;
783        let delta: ContentDelta = serde_json::from_str(json).unwrap();
784
785        match delta {
786            ContentDelta::ThinkingDelta { thinking } => {
787                assert_eq!(thinking, "Let me think about this...");
788            }
789            _ => panic!("Expected ThinkingDelta variant"),
790        }
791    }
792
793    #[test]
794    fn test_signature_delta_deserialization() {
795        let json = r#"{"type": "signature_delta", "signature": "abc123def456"}"#;
796        let delta: ContentDelta = serde_json::from_str(json).unwrap();
797
798        match delta {
799            ContentDelta::SignatureDelta { signature } => {
800                assert_eq!(signature, "abc123def456");
801            }
802            _ => panic!("Expected SignatureDelta variant"),
803        }
804    }
805
806    #[test]
807    fn test_thinking_delta_streaming_event_deserialization() {
808        let json = r#"{
809            "type": "content_block_delta",
810            "index": 0,
811            "delta": {
812                "type": "thinking_delta",
813                "thinking": "First, I need to understand the problem."
814            }
815        }"#;
816
817        let event: StreamingEvent = serde_json::from_str(json).unwrap();
818
819        match event {
820            StreamingEvent::ContentBlockDelta { index, delta } => {
821                assert_eq!(index, 0);
822                match delta {
823                    ContentDelta::ThinkingDelta { thinking } => {
824                        assert_eq!(thinking, "First, I need to understand the problem.");
825                    }
826                    _ => panic!("Expected ThinkingDelta"),
827                }
828            }
829            _ => panic!("Expected ContentBlockDelta event"),
830        }
831    }
832
833    #[test]
834    fn test_signature_delta_streaming_event_deserialization() {
835        let json = r#"{
836            "type": "content_block_delta",
837            "index": 0,
838            "delta": {
839                "type": "signature_delta",
840                "signature": "ErUBCkYICBgCIkCaGbqC85F4"
841            }
842        }"#;
843
844        let event: StreamingEvent = serde_json::from_str(json).unwrap();
845
846        match event {
847            StreamingEvent::ContentBlockDelta { index, delta } => {
848                assert_eq!(index, 0);
849                match delta {
850                    ContentDelta::SignatureDelta { signature } => {
851                        assert_eq!(signature, "ErUBCkYICBgCIkCaGbqC85F4");
852                    }
853                    _ => panic!("Expected SignatureDelta"),
854                }
855            }
856            _ => panic!("Expected ContentBlockDelta event"),
857        }
858    }
859
860    #[test]
861    fn test_handle_thinking_delta_event() {
862        let event = StreamingEvent::ContentBlockDelta {
863            index: 0,
864            delta: ContentDelta::ThinkingDelta {
865                thinking: "Analyzing the request...".to_string(),
866            },
867        };
868
869        let mut tool_call_state = None;
870        let mut thinking_state = None;
871        let result = handle_event(&event, &mut tool_call_state, &mut thinking_state);
872
873        assert!(result.is_some());
874        let choice = result.unwrap().unwrap();
875
876        match choice {
877            RawStreamingChoice::ReasoningDelta { id, reasoning, .. } => {
878                assert_eq!(id, None);
879                assert_eq!(reasoning, "Analyzing the request...");
880            }
881            _ => panic!("Expected ReasoningDelta choice"),
882        }
883
884        // Verify thinking state was updated
885        assert!(thinking_state.is_some());
886        assert_eq!(thinking_state.unwrap().thinking, "Analyzing the request...");
887    }
888
889    #[test]
890    fn test_handle_signature_delta_event() {
891        let event = StreamingEvent::ContentBlockDelta {
892            index: 0,
893            delta: ContentDelta::SignatureDelta {
894                signature: "test_signature".to_string(),
895            },
896        };
897
898        let mut tool_call_state = None;
899        let mut thinking_state = None;
900        let result = handle_event(&event, &mut tool_call_state, &mut thinking_state);
901
902        // SignatureDelta should not yield anything (returns None)
903        assert!(result.is_none());
904
905        // But signature should be captured in thinking state
906        assert!(thinking_state.is_some());
907        assert_eq!(thinking_state.unwrap().signature, "test_signature");
908    }
909
910    #[test]
911    fn test_handle_redacted_thinking_content_block_start_event() {
912        let event = StreamingEvent::ContentBlockStart {
913            index: 0,
914            content_block: Content::RedactedThinking {
915                data: "redacted_blob".to_string(),
916            },
917        };
918        let mut tool_call_state = None;
919        let mut thinking_state = None;
920        let result = handle_event(&event, &mut tool_call_state, &mut thinking_state);
921
922        assert!(result.is_some());
923        match result.unwrap().unwrap() {
924            RawStreamingChoice::Reasoning {
925                content: ReasoningContent::Redacted { data },
926                ..
927            } => {
928                assert_eq!(data, "redacted_blob");
929            }
930            _ => panic!("Expected Redacted reasoning chunk"),
931        }
932    }
933
934    #[test]
935    fn test_handle_text_delta_event() {
936        let event = StreamingEvent::ContentBlockDelta {
937            index: 0,
938            delta: ContentDelta::TextDelta {
939                text: "Hello, world!".to_string(),
940            },
941        };
942
943        let mut tool_call_state = None;
944        let mut thinking_state = None;
945        let result = handle_event(&event, &mut tool_call_state, &mut thinking_state);
946
947        assert!(result.is_some());
948        let choice = result.unwrap().unwrap();
949
950        match choice {
951            RawStreamingChoice::Message(text) => {
952                assert_eq!(text, "Hello, world!");
953            }
954            _ => panic!("Expected Message choice"),
955        }
956    }
957
958    #[test]
959    fn test_handle_text_block_start_event() {
960        let event = StreamingEvent::ContentBlockStart {
961            index: 0,
962            content_block: Content::Text {
963                text: String::new(),
964                citations: Vec::new(),
965                cache_control: None,
966            },
967        };
968
969        let mut tool_call_state = None;
970        let mut thinking_state = None;
971        let result = handle_event(&event, &mut tool_call_state, &mut thinking_state);
972
973        assert!(result.is_some());
974        let choice = result.unwrap().unwrap();
975        assert!(matches!(
976            choice,
977            RawStreamingChoice::TextStart {
978                additional_params: None
979            }
980        ));
981    }
982
983    #[test]
984    fn test_thinking_delta_does_not_interfere_with_tool_calls() {
985        // Thinking deltas should still be processed even if a tool call is in progress
986        let event = StreamingEvent::ContentBlockDelta {
987            index: 0,
988            delta: ContentDelta::ThinkingDelta {
989                thinking: "Thinking while tool is active...".to_string(),
990            },
991        };
992
993        let mut tool_call_state = Some(ToolCallState {
994            name: "test_tool".to_string(),
995            id: "tool_123".to_string(),
996            internal_call_id: nanoid::nanoid!(),
997            input_json: String::new(),
998        });
999        let mut thinking_state = None;
1000
1001        let result = handle_event(&event, &mut tool_call_state, &mut thinking_state);
1002
1003        assert!(result.is_some());
1004        let choice = result.unwrap().unwrap();
1005
1006        match choice {
1007            RawStreamingChoice::ReasoningDelta { reasoning, .. } => {
1008                assert_eq!(reasoning, "Thinking while tool is active...");
1009            }
1010            _ => panic!("Expected ReasoningDelta choice"),
1011        }
1012
1013        // Tool call state should remain unchanged
1014        assert!(tool_call_state.is_some());
1015    }
1016
1017    #[test]
1018    fn test_handle_input_json_delta_event() {
1019        let event = StreamingEvent::ContentBlockDelta {
1020            index: 0,
1021            delta: ContentDelta::InputJsonDelta {
1022                partial_json: "{\"arg\":\"value".to_string(),
1023            },
1024        };
1025
1026        let mut tool_call_state = Some(ToolCallState {
1027            name: "test_tool".to_string(),
1028            id: "tool_123".to_string(),
1029            internal_call_id: nanoid::nanoid!(),
1030            input_json: String::new(),
1031        });
1032        let mut thinking_state = None;
1033
1034        let result = handle_event(&event, &mut tool_call_state, &mut thinking_state);
1035
1036        // Should emit a ToolCallDelta
1037        assert!(result.is_some());
1038        let choice = result.unwrap().unwrap();
1039
1040        match choice {
1041            RawStreamingChoice::ToolCallDelta {
1042                id,
1043                internal_call_id: _,
1044                content,
1045            } => {
1046                assert_eq!(id, "tool_123");
1047                match content {
1048                    ToolCallDeltaContent::Delta(delta) => assert_eq!(delta, "{\"arg\":\"value"),
1049                    _ => panic!("Expected Delta content"),
1050                }
1051            }
1052            _ => panic!("Expected ToolCallDelta choice, got {:?}", choice),
1053        }
1054
1055        // Verify the input_json was accumulated
1056        assert!(tool_call_state.is_some());
1057        let state = tool_call_state.unwrap();
1058        assert_eq!(state.input_json, "{\"arg\":\"value");
1059    }
1060
1061    #[test]
1062    fn test_tool_call_accumulation_with_multiple_deltas() {
1063        let mut tool_call_state = Some(ToolCallState {
1064            name: "test_tool".to_string(),
1065            id: "tool_123".to_string(),
1066            internal_call_id: nanoid::nanoid!(),
1067            input_json: String::new(),
1068        });
1069        let mut thinking_state = None;
1070
1071        // First delta
1072        let event1 = StreamingEvent::ContentBlockDelta {
1073            index: 0,
1074            delta: ContentDelta::InputJsonDelta {
1075                partial_json: "{\"location\":".to_string(),
1076            },
1077        };
1078        let result1 = handle_event(&event1, &mut tool_call_state, &mut thinking_state);
1079        assert!(result1.is_some());
1080
1081        // Second delta
1082        let event2 = StreamingEvent::ContentBlockDelta {
1083            index: 0,
1084            delta: ContentDelta::InputJsonDelta {
1085                partial_json: "\"Paris\",".to_string(),
1086            },
1087        };
1088        let result2 = handle_event(&event2, &mut tool_call_state, &mut thinking_state);
1089        assert!(result2.is_some());
1090
1091        // Third delta
1092        let event3 = StreamingEvent::ContentBlockDelta {
1093            index: 0,
1094            delta: ContentDelta::InputJsonDelta {
1095                partial_json: "\"temp\":\"20C\"}".to_string(),
1096            },
1097        };
1098        let result3 = handle_event(&event3, &mut tool_call_state, &mut thinking_state);
1099        assert!(result3.is_some());
1100
1101        // Verify accumulated JSON
1102        assert!(tool_call_state.is_some());
1103        let state = tool_call_state.as_ref().unwrap();
1104        assert_eq!(
1105            state.input_json,
1106            "{\"location\":\"Paris\",\"temp\":\"20C\"}"
1107        );
1108
1109        // Final ContentBlockStop should emit complete tool call
1110        let stop_event = StreamingEvent::ContentBlockStop { index: 0 };
1111        let final_result = handle_event(&stop_event, &mut tool_call_state, &mut thinking_state);
1112        assert!(final_result.is_some());
1113
1114        match final_result.unwrap().unwrap() {
1115            RawStreamingChoice::ToolCall(RawStreamingToolCall {
1116                id,
1117                name,
1118                arguments,
1119                ..
1120            }) => {
1121                assert_eq!(id, "tool_123");
1122                assert_eq!(name, "test_tool");
1123                assert_eq!(
1124                    arguments.get("location").unwrap().as_str().unwrap(),
1125                    "Paris"
1126                );
1127                assert_eq!(arguments.get("temp").unwrap().as_str().unwrap(), "20C");
1128            }
1129            other => panic!("Expected ToolCall, got {:?}", other),
1130        }
1131
1132        // Tool call state should be taken
1133        assert!(tool_call_state.is_none());
1134    }
1135
1136    #[test]
1137    fn test_citations_delta_streaming_event_deserialization() {
1138        let json = r#"{
1139            "type": "content_block_delta",
1140            "index": 0,
1141            "delta": {
1142                "type": "citations_delta",
1143                "citation": {
1144                    "type": "char_location",
1145                    "cited_text": "The grass is green.",
1146                    "document_index": 0,
1147                    "document_title": "Example",
1148                    "start_char_index": 0,
1149                    "end_char_index": 20
1150                }
1151            }
1152        }"#;
1153
1154        let event: StreamingEvent = serde_json::from_str(json).unwrap();
1155        let StreamingEvent::ContentBlockDelta { index, delta } = event else {
1156            panic!("expected ContentBlockDelta");
1157        };
1158        assert_eq!(index, 0);
1159        let ContentDelta::CitationsDelta { citation } = delta else {
1160            panic!("expected CitationsDelta");
1161        };
1162        let crate::providers::anthropic::completion::Citation::CharLocation {
1163            start_char_index,
1164            end_char_index,
1165            ..
1166        } = citation
1167        else {
1168            panic!("expected CharLocation");
1169        };
1170        assert_eq!(start_char_index, 0);
1171        assert_eq!(end_char_index, 20);
1172    }
1173
1174    #[test]
1175    fn test_search_result_citations_delta_streaming_event_deserialization() {
1176        let json = r#"{
1177            "type": "content_block_delta",
1178            "index": 0,
1179            "delta": {
1180                "type": "citations_delta",
1181                "citation": {
1182                    "type": "search_result_location",
1183                    "cited_text": "API requests require a key.",
1184                    "source": "https://docs.example.com/api-reference",
1185                    "title": "API Reference",
1186                    "search_result_index": 0,
1187                    "start_block_index": 0,
1188                    "end_block_index": 1
1189                }
1190            }
1191        }"#;
1192
1193        let event: StreamingEvent = serde_json::from_str(json).unwrap();
1194        let StreamingEvent::ContentBlockDelta { delta, .. } = event else {
1195            panic!("expected ContentBlockDelta");
1196        };
1197        let ContentDelta::CitationsDelta { citation } = delta else {
1198            panic!("expected CitationsDelta");
1199        };
1200        assert!(matches!(
1201            citation,
1202            crate::providers::anthropic::completion::Citation::SearchResultLocation {
1203                search_result_index: 0,
1204                start_block_index: 0,
1205                end_block_index: 1,
1206                ..
1207            }
1208        ));
1209    }
1210
1211    #[test]
1212    fn test_web_search_result_citations_delta_streaming_event_deserialization() {
1213        let json = r#"{
1214            "type": "content_block_delta",
1215            "index": 0,
1216            "delta": {
1217                "type": "citations_delta",
1218                "citation": {
1219                    "type": "web_search_result_location",
1220                    "cited_text": "Claude Shannon was a mathematician.",
1221                    "url": "https://example.com/shannon",
1222                    "title": "Claude Shannon",
1223                    "encrypted_index": "encrypted-reference"
1224                }
1225            }
1226        }"#;
1227
1228        let event: StreamingEvent = serde_json::from_str(json).unwrap();
1229        let StreamingEvent::ContentBlockDelta { delta, .. } = event else {
1230            panic!("expected ContentBlockDelta");
1231        };
1232        let ContentDelta::CitationsDelta { citation } = delta else {
1233            panic!("expected CitationsDelta");
1234        };
1235        assert!(matches!(
1236            citation,
1237            crate::providers::anthropic::completion::Citation::WebSearchResultLocation {
1238                ref url,
1239                ref encrypted_index,
1240                ..
1241            } if url == "https://example.com/shannon"
1242                && encrypted_index == "encrypted-reference"
1243        ));
1244    }
1245
1246    #[test]
1247    fn test_web_search_result_citations_delta_allows_null_title() {
1248        let json = r#"{
1249            "type": "content_block_delta",
1250            "index": 0,
1251            "delta": {
1252                "type": "citations_delta",
1253                "citation": {
1254                    "type": "web_search_result_location",
1255                    "cited_text": "Claude Shannon was a mathematician.",
1256                    "url": "https://example.com/shannon",
1257                    "title": null,
1258                    "encrypted_index": "encrypted-reference"
1259                }
1260            }
1261        }"#;
1262
1263        let event: StreamingEvent = serde_json::from_str(json).unwrap();
1264        let StreamingEvent::ContentBlockDelta { delta, .. } = event else {
1265            panic!("expected ContentBlockDelta");
1266        };
1267        let ContentDelta::CitationsDelta { citation } = delta else {
1268            panic!("expected CitationsDelta");
1269        };
1270        assert!(matches!(
1271            citation,
1272            crate::providers::anthropic::completion::Citation::WebSearchResultLocation {
1273                title: None,
1274                ..
1275            }
1276        ));
1277    }
1278
1279    #[test]
1280    fn test_web_search_content_block_start_events_deserialize() {
1281        let server_tool_use = r#"{
1282            "type": "content_block_start",
1283            "index": 1,
1284            "content_block": {
1285                "type": "server_tool_use",
1286                "id": "srvtoolu_01",
1287                "name": "web_search",
1288                "input": {
1289                    "query": "claude shannon birth date"
1290                }
1291            }
1292        }"#;
1293        let event: StreamingEvent = serde_json::from_str(server_tool_use).unwrap();
1294        assert!(matches!(
1295            event,
1296            StreamingEvent::ContentBlockStart {
1297                content_block: Content::ServerToolUse {
1298                    ref id,
1299                    ref name,
1300                    ref input
1301                },
1302                ..
1303            } if id == "srvtoolu_01"
1304                && name == "web_search"
1305                && input["query"] == "claude shannon birth date"
1306        ));
1307
1308        let web_search_tool_result = r#"{
1309            "type": "content_block_start",
1310            "index": 2,
1311            "content_block": {
1312                "type": "web_search_tool_result",
1313                "tool_use_id": "srvtoolu_01",
1314                "content": [{
1315                    "type": "web_search_result",
1316                    "url": "https://example.com/shannon",
1317                    "title": "Claude Shannon",
1318                    "encrypted_content": "encrypted-content"
1319                }]
1320            }
1321        }"#;
1322        let event: StreamingEvent = serde_json::from_str(web_search_tool_result).unwrap();
1323        assert!(matches!(
1324            event,
1325            StreamingEvent::ContentBlockStart {
1326                content_block: Content::WebSearchToolResult {
1327                    ref tool_use_id,
1328                    ref content
1329                },
1330                ..
1331            } if tool_use_id == "srvtoolu_01"
1332                && content[0]["encrypted_content"] == "encrypted-content"
1333        ));
1334    }
1335
    /// End-to-end check for provider-owned web-search blocks.
    ///
    /// Drives a `server_tool_use` block, its `web_search_tool_result`, and a cited
    /// answer through `super::handle_event`. It then aggregates the stream and checks
    /// that the two web-search blocks are kept as raw-metadata text items rather than
    /// Rig tool calls, followed by the answer text with its citation preserved.
    #[tokio::test]
    async fn test_streaming_web_search_blocks_are_preserved_on_final_choice() {
        let raw_stream = stream! {
            let mut tool_call_state = None;
            // One map shared by every event below, so the server_tool_use block can be
            // assembled from its start, input-JSON delta and stop events.
            let mut server_tool_uses = HashMap::new();
            let mut thinking_state = None;

            // server_tool_use start: buffered, nothing emitted yet.
            let server_tool_use_start = super::handle_event(
                &StreamingEvent::ContentBlockStart {
                    index: 0,
                    content_block: Content::ServerToolUse {
                        id: "srvtoolu_01".to_string(),
                        name: "web_search".to_string(),
                        input: serde_json::Value::Null,
                    },
                },
                &mut tool_call_state,
                &mut server_tool_uses,
                &mut thinking_state,
            );
            assert!(
                server_tool_use_start.is_none(),
                "server_tool_use start should be accumulated until its input JSON is complete"
            );

            // Its input JSON arrives as a delta and is likewise buffered, not
            // surfaced as a client tool-call delta.
            let server_tool_use_delta = super::handle_event(
                &StreamingEvent::ContentBlockDelta {
                    index: 0,
                    delta: ContentDelta::InputJsonDelta {
                        partial_json: r#"{"query":"claude shannon birth date"}"#.to_string(),
                    },
                },
                &mut tool_call_state,
                &mut server_tool_uses,
                &mut thinking_state,
            );
            assert!(
                server_tool_use_delta.is_none(),
                "server_tool_use input JSON should not be emitted as a Rig tool-call delta"
            );

            // Block stop emits the completed server_tool_use as raw metadata.
            yield super::handle_event(
                &StreamingEvent::ContentBlockStop { index: 0 },
                &mut tool_call_state,
                &mut server_tool_uses,
                &mut thinking_state,
            )
            .expect("server_tool_use stop should produce completed raw metadata");

            // The web_search_tool_result arrives complete in its start event.
            yield super::handle_event(
                &StreamingEvent::ContentBlockStart {
                    index: 1,
                    content_block: Content::WebSearchToolResult {
                        tool_use_id: "srvtoolu_01".to_string(),
                        content: serde_json::json!([{
                            "type": "web_search_result",
                            "url": "https://example.com/shannon",
                            "title": "Claude Shannon",
                            "encrypted_content": "encrypted-content"
                        }]),
                    },
                },
                &mut tool_call_state,
                &mut server_tool_uses,
                &mut thinking_state,
            )
            .expect("web_search_tool_result block should produce raw metadata");

            // The answer: text block start, its text, then a web-search citation.
            yield super::handle_event(
                &StreamingEvent::ContentBlockStart {
                    index: 2,
                    content_block: Content::Text {
                        text: String::new(),
                        citations: Vec::new(),
                        cache_control: None,
                    },
                },
                &mut tool_call_state,
                &mut server_tool_uses,
                &mut thinking_state,
            )
            .expect("text block start should produce a raw choice");

            yield super::handle_event(
                &StreamingEvent::ContentBlockDelta {
                    index: 2,
                    delta: ContentDelta::TextDelta {
                        text: "Claude Shannon was born on April 30, 1916.".to_string(),
                    },
                },
                &mut tool_call_state,
                &mut server_tool_uses,
                &mut thinking_state,
            )
            .expect("text delta should produce a raw choice");

            yield super::handle_event(
                &StreamingEvent::ContentBlockDelta {
                    index: 2,
                    delta: ContentDelta::CitationsDelta {
                        citation: crate::providers::anthropic::completion::Citation::WebSearchResultLocation {
                            cited_text: "Claude Shannon was born on April 30, 1916.".to_string(),
                            url: "https://example.com/shannon".to_string(),
                            title: Some("Claude Shannon".to_string()),
                            encrypted_index: "encrypted-index".to_string(),
                        },
                    },
                },
                &mut tool_call_state,
                &mut server_tool_uses,
                &mut thinking_state,
            )
            .expect("citation delta should produce a raw choice");

            yield Ok(RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
                usage: PartialUsage::default(),
            }));
        };

        // Drain the aggregated stream so `stream.choice` holds the final content.
        let mut stream =
            crate::streaming::StreamingCompletionResponse::stream(to_stream_result(raw_stream));
        while stream.next().await.is_some() {}

        // Expected order: server_tool_use metadata, web_search_tool_result metadata,
        // then the answer text. None of them may be a ToolCall.
        let choice_items: Vec<crate::message::AssistantContent> =
            stream.choice.clone().into_iter().collect();
        assert_eq!(choice_items.len(), 3);
        assert!(
            choice_items
                .iter()
                .all(|item| !matches!(item, crate::message::AssistantContent::ToolCall(_))),
            "provider-owned web-search blocks must not become Rig client tool calls"
        );

        // Item 0: the raw server_tool_use block, including its accumulated input.
        let Some(crate::message::AssistantContent::Text(server_tool_use)) = choice_items.first()
        else {
            panic!("expected raw server_tool_use metadata");
        };
        assert_eq!(
            server_tool_use.additional_params.as_ref().unwrap()
                [crate::providers::anthropic::completion::ANTHROPIC_RAW_CONTENT_KEY]["type"],
            "server_tool_use"
        );
        assert_eq!(
            server_tool_use.additional_params.as_ref().unwrap()
                [crate::providers::anthropic::completion::ANTHROPIC_RAW_CONTENT_KEY]["input"]["query"],
            "claude shannon birth date"
        );

        // Item 1: the raw web_search_tool_result block with its encrypted content.
        let Some(crate::message::AssistantContent::Text(web_search_result)) = choice_items.get(1)
        else {
            panic!("expected raw web_search_tool_result metadata");
        };
        assert_eq!(
            web_search_result.additional_params.as_ref().unwrap()
                [crate::providers::anthropic::completion::ANTHROPIC_RAW_CONTENT_KEY]["content"][0]
                ["encrypted_content"],
            "encrypted-content"
        );

        // Item 2: the answer text, with the streamed citation still attached.
        let Some(crate::message::AssistantContent::Text(answer)) = choice_items.get(2) else {
            panic!("expected answer text");
        };
        assert_eq!(answer.text, "Claude Shannon was born on April 30, 1916.");
        let citations = crate::providers::anthropic::completion::anthropic_citations(answer)
            .expect("expected preserved citations");
        assert!(matches!(
            citations.first(),
            Some(crate::providers::anthropic::completion::Citation::WebSearchResultLocation {
                encrypted_index,
                ..
            }) if encrypted_index == "encrypted-index"
        ));
    }
1509
1510    #[test]
1511    fn test_handle_citations_delta_event_preserves_metadata() {
1512        let event = StreamingEvent::ContentBlockDelta {
1513            index: 0,
1514            delta: ContentDelta::CitationsDelta {
1515                citation: crate::providers::anthropic::completion::Citation::CharLocation {
1516                    cited_text: "The grass is green.".to_string(),
1517                    document_index: 0,
1518                    document_title: Some("Example".to_string()),
1519                    start_char_index: 0,
1520                    end_char_index: 20,
1521                },
1522            },
1523        };
1524
1525        let mut tool_call_state = None;
1526        let mut thinking_state = None;
1527        let result = handle_event(&event, &mut tool_call_state, &mut thinking_state);
1528
1529        assert!(result.is_some());
1530        let choice = result.unwrap().unwrap();
1531        let RawStreamingChoice::TextAdditionalParams(additional_params) = choice else {
1532            panic!("expected TextAdditionalParams choice");
1533        };
1534        assert_eq!(additional_params["citations"][0]["type"], "char_location");
1535    }
1536
1537    #[tokio::test]
1538    async fn test_streaming_citation_deltas_are_preserved_on_final_text() {
1539        let citation = crate::providers::anthropic::completion::Citation::CharLocation {
1540            cited_text: "The grass is green.".to_string(),
1541            document_index: 0,
1542            document_title: Some("Example".to_string()),
1543            start_char_index: 0,
1544            end_char_index: 20,
1545        };
1546
1547        let raw_stream = stream! {
1548            let mut tool_call_state = None;
1549            let mut thinking_state = None;
1550
1551            yield handle_event(
1552                &StreamingEvent::ContentBlockStart {
1553                    index: 0,
1554                    content_block: Content::Text {
1555                        text: String::new(),
1556                        citations: Vec::new(),
1557                        cache_control: None,
1558                    },
1559                },
1560                &mut tool_call_state,
1561                &mut thinking_state,
1562            )
1563            .expect("text block start should produce a raw choice");
1564
1565            yield handle_event(
1566                &StreamingEvent::ContentBlockDelta {
1567                    index: 0,
1568                    delta: ContentDelta::TextDelta {
1569                        text: "the grass is green".to_string(),
1570                    },
1571                },
1572                &mut tool_call_state,
1573                &mut thinking_state,
1574            )
1575            .expect("text delta should produce a raw choice");
1576
1577            yield handle_event(
1578                &StreamingEvent::ContentBlockDelta {
1579                    index: 0,
1580                    delta: ContentDelta::CitationsDelta {
1581                        citation: crate::providers::anthropic::completion::Citation::CharLocation {
1582                            cited_text: "The grass is green.".to_string(),
1583                            document_index: 0,
1584                            document_title: Some("Example".to_string()),
1585                            start_char_index: 0,
1586                            end_char_index: 20,
1587                        },
1588                    },
1589                },
1590                &mut tool_call_state,
1591                &mut thinking_state,
1592            )
1593            .expect("citation delta should produce a raw choice");
1594
1595            yield Ok(RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
1596                usage: PartialUsage::default(),
1597            }));
1598        };
1599
1600        let mut stream =
1601            crate::streaming::StreamingCompletionResponse::stream(to_stream_result(raw_stream));
1602        while stream.next().await.is_some() {}
1603
1604        let choice_items: Vec<crate::message::AssistantContent> =
1605            stream.choice.clone().into_iter().collect();
1606        let Some(crate::message::AssistantContent::Text(text)) = choice_items.first() else {
1607            panic!("expected accumulated text item");
1608        };
1609
1610        assert_eq!(text.text, "the grass is green");
1611        let citations = crate::providers::anthropic::completion::anthropic_citations(text).unwrap();
1612        assert_eq!(citations, vec![citation]);
1613    }
1614
1615    #[test]
1616    fn test_unknown_content_delta_falls_back() {
1617        let json = r#"{"type": "something_new_from_anthropic", "field": "x"}"#;
1618        let delta: ContentDelta = serde_json::from_str(json).unwrap();
1619        assert!(matches!(delta, ContentDelta::Unknown));
1620    }
1621}