1use async_stream::stream;
2use futures::StreamExt;
3use serde::{Deserialize, Serialize};
4use serde_json::{Value, json};
5use tracing::{Level, enabled, info_span};
6use tracing_futures::Instrument;
7
8use super::completion::{
9 AnthropicCompatibleProvider, CacheTtl, Content, GenericCompletionModel, Message, SystemContent,
10 ToolChoice, Usage, apply_prompt_cache_control, build_tool_definitions,
11 resolve_top_level_cache_control, split_system_messages_from_history,
12 supports_mid_conversation_system_messages,
13};
14use crate::completion::{CompletionError, CompletionRequest, GetTokenUsage};
15use crate::http_client::sse::{Event, GenericEventSource};
16use crate::http_client::{self, HttpClientExt};
17use crate::json_utils::merge_inplace;
18use crate::message::ReasoningContent;
19use crate::streaming::{
20 self, RawStreamingChoice, RawStreamingToolCall, StreamingResult, ToolCallDeltaContent,
21};
22use crate::telemetry::SpanCombinator;
23use crate::wasm_compat::{WasmCompatSend, WasmCompatSync};
24use std::collections::HashMap;
25
26fn create_streaming_request_body(
27 request_model: String,
28 completion_request: &mut CompletionRequest,
29 max_tokens: u64,
30 prompt_caching: bool,
31 automatic_caching: bool,
32 automatic_caching_ttl: Option<CacheTtl>,
33) -> Result<Value, CompletionError> {
34 let chat_history = completion_request.chat_history_with_documents();
35 let (history_system, chat_history) = split_system_messages_from_history(
36 chat_history,
37 supports_mid_conversation_system_messages(&request_model),
38 );
39 let mut full_history = vec![];
40 full_history.extend(chat_history);
41
42 let mut messages = full_history
43 .into_iter()
44 .map(Message::try_from)
45 .collect::<Result<Vec<Message>, _>>()?;
46
47 let mut system: Vec<SystemContent> =
49 if let Some(preamble) = completion_request.preamble.as_ref() {
50 if preamble.is_empty() {
51 vec![]
52 } else {
53 vec![SystemContent::Text {
54 text: preamble.clone(),
55 cache_control: None,
56 }]
57 }
58 } else {
59 vec![]
60 };
61 system.extend(history_system);
62
63 let mut additional_params_payload = completion_request
64 .additional_params
65 .take()
66 .unwrap_or(Value::Null);
67 let top_level_cache_control = resolve_top_level_cache_control(
68 automatic_caching,
69 automatic_caching_ttl,
70 &mut additional_params_payload,
71 )?;
72 let mut tools = build_tool_definitions(
73 std::mem::take(&mut completion_request.tools),
74 &mut additional_params_payload,
75 )?;
76
77 apply_prompt_cache_control(
78 &mut system,
79 &mut messages,
80 &mut tools,
81 prompt_caching,
82 top_level_cache_control.as_ref(),
83 )?;
84
85 let mut body = json!({
86 "model": request_model,
87 "messages": messages,
88 "max_tokens": max_tokens,
89 "stream": true,
90 });
91
92 if let Some(cache_control) = top_level_cache_control {
95 merge_inplace(
96 &mut body,
97 json!({ "cache_control": serde_json::to_value(&cache_control)? }),
98 );
99 }
100
101 if !system.is_empty() {
103 merge_inplace(&mut body, json!({ "system": system }));
104 }
105
106 if let Some(temperature) = completion_request.temperature {
107 merge_inplace(&mut body, json!({ "temperature": temperature }));
108 }
109
110 if !tools.is_empty() {
111 merge_inplace(
112 &mut body,
113 json!({
114 "tools": tools,
115 "tool_choice": ToolChoice::Auto,
116 }),
117 );
118 }
119
120 if !additional_params_payload.is_null() {
121 merge_inplace(&mut body, additional_params_payload)
122 }
123
124 Ok(body)
125}
126
127#[derive(Debug, Deserialize)]
128#[serde(tag = "type", rename_all = "snake_case")]
129pub enum StreamingEvent {
130 MessageStart {
131 message: MessageStart,
132 },
133 ContentBlockStart {
134 index: usize,
135 content_block: Content,
136 },
137 ContentBlockDelta {
138 index: usize,
139 delta: ContentDelta,
140 },
141 ContentBlockStop {
142 index: usize,
143 },
144 MessageDelta {
145 delta: MessageDelta,
146 usage: PartialUsage,
147 },
148 MessageStop,
149 Ping,
150 #[serde(other)]
151 Unknown,
152}
153
154#[derive(Debug, Deserialize)]
155pub struct MessageStart {
156 pub id: String,
157 pub role: String,
158 pub content: Vec<Content>,
159 pub model: String,
160 pub stop_reason: Option<String>,
161 pub stop_sequence: Option<String>,
162 pub usage: Usage,
163}
164
165#[derive(Debug, Deserialize)]
166#[serde(tag = "type", rename_all = "snake_case")]
167pub enum ContentDelta {
168 TextDelta {
169 text: String,
170 },
171 InputJsonDelta {
172 partial_json: String,
173 },
174 ThinkingDelta {
175 thinking: String,
176 },
177 SignatureDelta {
178 signature: String,
179 },
180 CitationsDelta {
181 citation: super::completion::Citation,
182 },
183 #[serde(other)]
187 Unknown,
188}
189
190#[derive(Debug, Deserialize)]
191pub struct MessageDelta {
192 pub stop_reason: Option<String>,
193 pub stop_sequence: Option<String>,
194}
195
196#[derive(Debug, Deserialize, Clone, Serialize, Default)]
197pub struct PartialUsage {
198 pub output_tokens: usize,
199 #[serde(default)]
200 pub input_tokens: Option<usize>,
201 #[serde(default)]
202 pub cache_creation_input_tokens: Option<u64>,
203 #[serde(default)]
204 pub cache_read_input_tokens: Option<u64>,
205}
206
207impl GetTokenUsage for PartialUsage {
208 fn token_usage(&self) -> crate::completion::Usage {
209 let mut usage = crate::completion::Usage::new();
210
211 usage.input_tokens = self.input_tokens.unwrap_or_default() as u64;
212 usage.output_tokens = self.output_tokens as u64;
213 usage.cached_input_tokens = self.cache_read_input_tokens.unwrap_or(0);
214 usage.cache_creation_input_tokens = self.cache_creation_input_tokens.unwrap_or(0);
215 usage.total_tokens = usage.input_tokens
216 + usage.cached_input_tokens
217 + usage.cache_creation_input_tokens
218 + usage.output_tokens;
219 usage
220 }
221}
222
/// Accumulator for the client-side tool call currently being streamed.
#[derive(Default)]
struct ToolCallState {
    /// Tool name taken from the opening content block.
    name: String,
    /// Tool-call id taken from the opening content block.
    id: String,
    /// Locally generated id attached to every chunk of this call.
    internal_call_id: String,
    /// Concatenation of all JSON input fragments received so far.
    input_json: String,
}
230
231struct ServerToolUseState {
232 name: String,
233 id: String,
234 initial_input: Value,
235 input_json: String,
236}
237
/// Accumulator for the thinking block currently being streamed.
#[derive(Default)]
struct ThinkingState {
    /// Concatenated thinking fragments.
    thinking: String,
    /// Concatenated signature fragments.
    signature: String,
}
243
244#[derive(Clone, Debug, Deserialize, Serialize)]
245pub struct StreamingCompletionResponse {
246 pub usage: PartialUsage,
247}
248
249impl GetTokenUsage for StreamingCompletionResponse {
250 fn token_usage(&self) -> crate::completion::Usage {
251 let mut usage = crate::completion::Usage::new();
252 usage.input_tokens = self.usage.input_tokens.unwrap_or(0) as u64;
253 usage.output_tokens = self.usage.output_tokens as u64;
254 usage.cached_input_tokens = self.usage.cache_read_input_tokens.unwrap_or(0);
255 usage.cache_creation_input_tokens = self.usage.cache_creation_input_tokens.unwrap_or(0);
256 usage.total_tokens = usage.input_tokens
257 + usage.cached_input_tokens
258 + usage.cache_creation_input_tokens
259 + usage.output_tokens;
260
261 usage
262 }
263}
264
265impl<Ext, T> GenericCompletionModel<Ext, T>
266where
267 T: HttpClientExt + Clone + Default + 'static,
268 Ext: AnthropicCompatibleProvider + Clone + WasmCompatSend + WasmCompatSync + 'static,
269{
270 pub(crate) async fn stream(
271 &self,
272 mut completion_request: CompletionRequest,
273 ) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError>
274 {
275 let request_model = completion_request
276 .model
277 .clone()
278 .unwrap_or_else(|| self.model.clone());
279 let span = if tracing::Span::current().is_disabled() {
280 info_span!(
281 target: "rig::completions",
282 "chat_streaming",
283 gen_ai.operation.name = "chat_streaming",
284 gen_ai.provider.name = Ext::PROVIDER_NAME,
285 gen_ai.request.model = &request_model,
286 gen_ai.system_instructions = &completion_request.preamble,
287 gen_ai.response.id = tracing::field::Empty,
288 gen_ai.response.model = &request_model,
289 gen_ai.usage.output_tokens = tracing::field::Empty,
290 gen_ai.usage.input_tokens = tracing::field::Empty,
291 gen_ai.usage.cache_read.input_tokens = tracing::field::Empty,
292 gen_ai.usage.cache_creation.input_tokens = tracing::field::Empty,
293 gen_ai.input.messages = tracing::field::Empty,
294 gen_ai.output.messages = tracing::field::Empty,
295 )
296 } else {
297 tracing::Span::current()
298 };
299 let max_tokens = if let Some(tokens) = completion_request.max_tokens {
300 tokens
301 } else if let Some(tokens) = self.default_max_tokens {
302 tokens
303 } else {
304 return Err(CompletionError::RequestError(
305 "`max_tokens` must be set for Anthropic".into(),
306 ));
307 };
308
309 let body = create_streaming_request_body(
310 request_model,
311 &mut completion_request,
312 max_tokens,
313 self.prompt_caching,
314 self.automatic_caching,
315 self.automatic_caching_ttl.clone(),
316 )?;
317
318 if enabled!(Level::TRACE) {
319 tracing::trace!(
320 target: "rig::completions",
321 "Anthropic completion request: {}",
322 serde_json::to_string_pretty(&body)?
323 );
324 }
325
326 let body: Vec<u8> = serde_json::to_vec(&body)?;
327
328 let req = self
329 .client
330 .post("/v1/messages")?
331 .body(body)
332 .map_err(http_client::Error::Protocol)?;
333
334 let stream = GenericEventSource::new(self.client.clone(), req);
335
336 let stream: StreamingResult<StreamingCompletionResponse> = Box::pin(stream! {
338 let mut current_tool_call: Option<ToolCallState> = None;
339 let mut server_tool_uses: HashMap<usize, ServerToolUseState> = HashMap::new();
340 let mut current_thinking: Option<ThinkingState> = None;
341 let mut sse_stream = Box::pin(stream);
342 let mut input_tokens = 0;
343 let mut final_usage = None;
344
345 let mut text_content = String::new();
346
347 while let Some(sse_result) = sse_stream.next().await {
348 match sse_result {
349 Ok(Event::Open) => {}
350 Ok(Event::Message(sse)) => {
351 match serde_json::from_str::<StreamingEvent>(&sse.data) {
353 Ok(event) => {
354 match &event {
355 StreamingEvent::MessageStart { message } => {
356 input_tokens = message.usage.input_tokens;
357
358 let span = tracing::Span::current();
359 span.record("gen_ai.response.id", &message.id);
360 span.record("gen_ai.response.model", &message.model);
361 },
362 StreamingEvent::MessageDelta { delta, usage } => {
363 if delta.stop_reason.is_some() {
364 let usage = PartialUsage {
368 output_tokens: usage.output_tokens,
369 input_tokens: usize::try_from(input_tokens).ok(),
370 cache_creation_input_tokens: usage.cache_creation_input_tokens,
371 cache_read_input_tokens: usage.cache_read_input_tokens
372 };
373
374 let span = tracing::Span::current();
375 span.record_token_usage(&usage);
376 final_usage = Some(usage);
377 break;
378 }
379 }
380 _ => {}
381 }
382
383 if let Some(result) = handle_event(
384 &event,
385 &mut current_tool_call,
386 &mut server_tool_uses,
387 &mut current_thinking,
388 ) {
389 if let Ok(RawStreamingChoice::Message(ref text)) = result {
390 text_content += text;
391 }
392 yield result;
393 }
394 },
395 Err(e) => {
396 if !sse.data.trim().is_empty() {
397 yield Err(CompletionError::ResponseError(
398 format!("Failed to parse JSON: {} (Data: {})", e, sse.data)
399 ));
400 }
401 }
402 }
403 },
404 Err(e) => {
405 yield Err(CompletionError::ProviderError(format!("SSE Error: {e}")));
406 break;
407 }
408 }
409 }
410
411 sse_stream.close();
413
414 yield Ok(RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
415 usage: final_usage.unwrap_or_default()
416 }))
417 }.instrument(span));
418
419 Ok(streaming::StreamingCompletionResponse::stream(stream))
420 }
421}
422
423fn handle_event(
424 event: &StreamingEvent,
425 current_tool_call: &mut Option<ToolCallState>,
426 server_tool_uses: &mut HashMap<usize, ServerToolUseState>,
427 current_thinking: &mut Option<ThinkingState>,
428) -> Option<Result<RawStreamingChoice<StreamingCompletionResponse>, CompletionError>> {
429 match event {
430 StreamingEvent::ContentBlockDelta { index, delta } => match delta {
431 ContentDelta::TextDelta { text } => {
432 if current_tool_call.is_none() {
433 return Some(Ok(RawStreamingChoice::Message(text.clone())));
434 }
435 None
436 }
437 ContentDelta::InputJsonDelta { partial_json } => {
438 if let Some(server_tool_use) = server_tool_uses.get_mut(index) {
439 server_tool_use.input_json.push_str(partial_json);
440 return None;
441 }
442
443 if let Some(tool_call) = current_tool_call {
444 tool_call.input_json.push_str(partial_json);
445 return Some(Ok(RawStreamingChoice::ToolCallDelta {
447 id: tool_call.id.clone(),
448 internal_call_id: tool_call.internal_call_id.clone(),
449 content: ToolCallDeltaContent::Delta(partial_json.clone()),
450 }));
451 }
452 None
453 }
454 ContentDelta::ThinkingDelta { thinking } => {
455 current_thinking
456 .get_or_insert_with(ThinkingState::default)
457 .thinking
458 .push_str(thinking);
459
460 Some(Ok(RawStreamingChoice::ReasoningDelta {
461 id: None,
462 reasoning: thinking.clone(),
463 }))
464 }
465 ContentDelta::SignatureDelta { signature } => {
466 current_thinking
467 .get_or_insert_with(ThinkingState::default)
468 .signature
469 .push_str(signature);
470
471 None
473 }
474 ContentDelta::CitationsDelta { citation } => {
475 Some(Ok(RawStreamingChoice::TextAdditionalParams(json!({
476 "citations": [citation]
477 }))))
478 }
479 ContentDelta::Unknown => None,
480 },
481 StreamingEvent::ContentBlockStart {
482 index,
483 content_block,
484 } => match content_block {
485 Content::Text { citations, .. } => {
486 let additional_params = (!citations.is_empty()).then(|| {
487 json!({
488 "citations": citations
489 })
490 });
491 Some(Ok(RawStreamingChoice::TextStart { additional_params }))
492 }
493 Content::ServerToolUse { id, name, input } => {
494 server_tool_uses.insert(
495 *index,
496 ServerToolUseState {
497 name: name.clone(),
498 id: id.clone(),
499 initial_input: input.clone(),
500 input_json: String::new(),
501 },
502 );
503 None
504 }
505 raw @ Content::WebSearchToolResult { .. } => Some(Ok(RawStreamingChoice::TextStart {
506 additional_params: Some(json!({
507 super::completion::ANTHROPIC_RAW_CONTENT_KEY: raw
508 })),
509 })),
510 Content::ToolUse { id, name, .. } => {
511 let internal_call_id = nanoid::nanoid!();
512 *current_tool_call = Some(ToolCallState {
513 name: name.clone(),
514 id: id.clone(),
515 internal_call_id: internal_call_id.clone(),
516 input_json: String::new(),
517 });
518 Some(Ok(RawStreamingChoice::ToolCallDelta {
519 id: id.clone(),
520 internal_call_id,
521 content: ToolCallDeltaContent::Name(name.clone()),
522 }))
523 }
524 Content::Thinking { .. } => {
525 *current_thinking = Some(ThinkingState::default());
526 None
527 }
528 Content::RedactedThinking { data } => Some(Ok(RawStreamingChoice::Reasoning {
529 id: None,
530 content: ReasoningContent::Redacted { data: data.clone() },
531 })),
532 _ => None,
534 },
535 StreamingEvent::ContentBlockStop { index } => {
536 if let Some(thinking_state) = Option::take(current_thinking)
537 && !thinking_state.thinking.is_empty()
538 {
539 let signature = if thinking_state.signature.is_empty() {
540 None
541 } else {
542 Some(thinking_state.signature)
543 };
544
545 return Some(Ok(RawStreamingChoice::Reasoning {
546 id: None,
547 content: ReasoningContent::Text {
548 text: thinking_state.thinking,
549 signature,
550 },
551 }));
552 }
553
554 if let Some(server_tool_use) = server_tool_uses.remove(index) {
555 let input = if server_tool_use.input_json.is_empty() {
556 if server_tool_use.initial_input.is_null() {
557 json!({})
558 } else {
559 server_tool_use.initial_input
560 }
561 } else {
562 match serde_json::from_str(&server_tool_use.input_json) {
563 Ok(json_value) => json_value,
564 Err(e) => return Some(Err(CompletionError::from(e))),
565 }
566 };
567
568 return Some(Ok(RawStreamingChoice::TextStart {
569 additional_params: Some(json!({
570 super::completion::ANTHROPIC_RAW_CONTENT_KEY: Content::ServerToolUse {
571 id: server_tool_use.id,
572 name: server_tool_use.name,
573 input,
574 }
575 })),
576 }));
577 }
578
579 if let Some(tool_call) = Option::take(current_tool_call) {
580 let json_str = if tool_call.input_json.is_empty() {
581 "{}"
582 } else {
583 &tool_call.input_json
584 };
585 match serde_json::from_str(json_str) {
586 Ok(json_value) => {
587 let raw_tool_call =
588 RawStreamingToolCall::new(tool_call.id, tool_call.name, json_value)
589 .with_internal_call_id(tool_call.internal_call_id);
590 Some(Ok(RawStreamingChoice::ToolCall(raw_tool_call)))
591 }
592 Err(e) => Some(Err(CompletionError::from(e))),
593 }
594 } else {
595 None
596 }
597 }
598 StreamingEvent::MessageStart { .. }
600 | StreamingEvent::MessageDelta { .. }
601 | StreamingEvent::MessageStop
602 | StreamingEvent::Ping
603 | StreamingEvent::Unknown => None,
604 }
605}
606
607#[cfg(test)]
608mod tests {
609 use super::super::completion::{CLAUDE_OPUS_4_8, CacheControl, CacheTtl};
610 use super::*;
611 use crate::OneOrMany;
612 use crate::completion::Message as RigMessage;
613 use crate::completion::request::Document as RigDocument;
614 use async_stream::stream;
615 use futures::StreamExt;
616
617 #[cfg(not(all(feature = "wasm", target_arch = "wasm32")))]
618 fn to_stream_result(
619 stream: impl futures::Stream<
620 Item = Result<RawStreamingChoice<StreamingCompletionResponse>, CompletionError>,
621 > + Send
622 + 'static,
623 ) -> crate::streaming::StreamingResult<StreamingCompletionResponse> {
624 Box::pin(stream)
625 }
626
/// Boxes and pins a stream of raw choices as a `StreamingResult`
/// (wasm builds; no `Send` bound).
#[cfg(all(feature = "wasm", target_arch = "wasm32"))]
fn to_stream_result<S>(
    stream: S,
) -> crate::streaming::StreamingResult<StreamingCompletionResponse>
where
    S: futures::Stream<
            Item = Result<RawStreamingChoice<StreamingCompletionResponse>, CompletionError>,
        > + 'static,
{
    Box::pin(stream)
}
635
/// With prompt caching on, only the last tool of the combined list
/// (request tools followed by provider tools) is marked for caching.
#[test]
fn test_streaming_tool_build_marks_final_combined_tool() {
    let mut additional_params = json!({
        "tools": [{
            "name": "provider_tool",
            "description": "Provider tool",
            "input_schema": {"type": "object"}
        }]
    });
    let rig_tool = crate::completion::ToolDefinition {
        name: "rig_tool".to_string(),
        description: "Rig tool".to_string(),
        parameters: json!({"type": "object", "properties": {}}),
    };

    let mut tools = build_tool_definitions(vec![rig_tool], &mut additional_params).unwrap();
    let mut system: Vec<SystemContent> = Vec::new();
    let mut messages: Vec<Message> = Vec::new();
    apply_prompt_cache_control(&mut system, &mut messages, &mut tools, true, None).unwrap();

    assert_eq!(tools.len(), 2);

    let (first, last) = (&tools[0], &tools[1]);
    assert!(first.get("cache_control").is_none());
    assert_eq!(last["name"], "provider_tool");
    assert_eq!(last["cache_control"]["type"], "ephemeral");
}
664
/// System messages from the history move to the top-level `system` array,
/// and the document message appears exactly once, as the first message.
#[test]
fn streaming_request_keeps_documents_after_leading_system_messages() {
    let history = vec![
        RigMessage::system("System prompt"),
        RigMessage::assistant("Earlier assistant turn"),
        RigMessage::system("Mid-conversation instruction"),
        RigMessage::user("Prompt"),
    ];
    let document = RigDocument {
        id: "doc1".to_string(),
        text: "Document text.".to_string(),
        additional_props: Default::default(),
    };
    let mut request = CompletionRequest {
        model: None,
        preamble: None,
        chat_history: OneOrMany::many(history).unwrap(),
        documents: vec![document],
        tools: vec![],
        temperature: None,
        max_tokens: Some(64),
        tool_choice: None,
        additional_params: None,
        output_schema: None,
    };

    let body = create_streaming_request_body(
        CLAUDE_OPUS_4_8.to_string(),
        &mut request,
        64,
        false,
        false,
        None,
    )
    .expect("streaming request body should build");

    let system = &body["system"];
    assert_eq!(system[0]["text"], "System prompt");
    assert_eq!(system[1]["text"], "Mid-conversation instruction");

    let messages = body["messages"]
        .as_array()
        .expect("messages should be array");
    let mentions_document = |message: &Value| message.to_string().contains("<file id: doc1>");

    assert_eq!(messages.len(), 3);
    assert_eq!(messages[0]["role"], "user");
    assert!(
        mentions_document(&messages[0]),
        "document message should follow top-level system: {messages:?}"
    );
    assert_eq!(messages[1]["role"], "assistant");
    assert_eq!(messages[2]["role"], "user");
    assert_eq!(
        messages
            .iter()
            .filter(|message| mentions_document(*message))
            .count(),
        1,
        "document message should appear exactly once: {messages:?}"
    );
}
722
/// A raw top-level `cache_control` with a TTL is removed from the additional
/// params and applied, TTL included, to the tool and system cache markers.
#[test]
fn test_streaming_prompt_cache_control_uses_raw_top_level_ttl() {
    let mut additional_params = json!({
        "cache_control": {"type": "ephemeral", "ttl": "1h"}
    });
    let top_level_cache_control =
        resolve_top_level_cache_control(false, None, &mut additional_params).unwrap();

    let rig_tool = crate::completion::ToolDefinition {
        name: "rig_tool".to_string(),
        description: "Rig tool".to_string(),
        parameters: json!({"type": "object", "properties": {}}),
    };
    let mut tools = build_tool_definitions(vec![rig_tool], &mut additional_params).unwrap();
    let mut messages: Vec<Message> = Vec::new();
    let mut system = vec![SystemContent::Text {
        text: "System prompt".to_string(),
        cache_control: None,
    }];

    apply_prompt_cache_control(
        &mut system,
        &mut messages,
        &mut tools,
        true,
        top_level_cache_control.as_ref(),
    )
    .unwrap();

    let tool_cache_control = &tools[0]["cache_control"];
    assert_eq!(tool_cache_control["type"], "ephemeral");
    assert_eq!(tool_cache_control["ttl"], "1h");

    match &system[0] {
        SystemContent::Text {
            cache_control: Some(CacheControl::Ephemeral { ttl }),
            ..
        } => assert_eq!(ttl.as_ref(), Some(&CacheTtl::OneHour)),
        other => panic!("expected system cache_control, got {other:?}"),
    }

    assert!(additional_params.get("cache_control").is_none());
}
765
766 fn handle_event(
767 event: &StreamingEvent,
768 current_tool_call: &mut Option<ToolCallState>,
769 current_thinking: &mut Option<ThinkingState>,
770 ) -> Option<Result<RawStreamingChoice<StreamingCompletionResponse>, CompletionError>> {
771 let mut server_tool_uses = HashMap::new();
772 super::handle_event(
773 event,
774 current_tool_call,
775 &mut server_tool_uses,
776 current_thinking,
777 )
778 }
779
/// A bare `thinking_delta` payload decodes into `ContentDelta::ThinkingDelta`.
#[test]
fn test_thinking_delta_deserialization() {
    let payload = r#"{"type": "thinking_delta", "thinking": "Let me think about this..."}"#;
    let delta: ContentDelta = serde_json::from_str(payload).unwrap();

    let ContentDelta::ThinkingDelta { thinking } = delta else {
        panic!("Expected ThinkingDelta variant")
    };
    assert_eq!(thinking, "Let me think about this...");
}
792
/// A bare `signature_delta` payload decodes into `ContentDelta::SignatureDelta`.
#[test]
fn test_signature_delta_deserialization() {
    let payload = r#"{"type": "signature_delta", "signature": "abc123def456"}"#;
    let delta: ContentDelta = serde_json::from_str(payload).unwrap();

    let ContentDelta::SignatureDelta { signature } = delta else {
        panic!("Expected SignatureDelta variant")
    };
    assert_eq!(signature, "abc123def456");
}
805
/// A `content_block_delta` event wrapping a thinking delta decodes with its
/// index and text intact.
#[test]
fn test_thinking_delta_streaming_event_deserialization() {
    let payload = r#"{
        "type": "content_block_delta",
        "index": 0,
        "delta": {
            "type": "thinking_delta",
            "thinking": "First, I need to understand the problem."
        }
    }"#;

    let event: StreamingEvent = serde_json::from_str(payload).unwrap();

    let StreamingEvent::ContentBlockDelta { index, delta } = event else {
        panic!("Expected ContentBlockDelta event")
    };
    assert_eq!(index, 0);

    let ContentDelta::ThinkingDelta { thinking } = delta else {
        panic!("Expected ThinkingDelta")
    };
    assert_eq!(thinking, "First, I need to understand the problem.");
}
832
/// A `content_block_delta` event wrapping a signature delta decodes with its
/// index and signature intact.
#[test]
fn test_signature_delta_streaming_event_deserialization() {
    let payload = r#"{
        "type": "content_block_delta",
        "index": 0,
        "delta": {
            "type": "signature_delta",
            "signature": "ErUBCkYICBgCIkCaGbqC85F4"
        }
    }"#;

    let event: StreamingEvent = serde_json::from_str(payload).unwrap();

    let StreamingEvent::ContentBlockDelta { index, delta } = event else {
        panic!("Expected ContentBlockDelta event")
    };
    assert_eq!(index, 0);

    let ContentDelta::SignatureDelta { signature } = delta else {
        panic!("Expected SignatureDelta")
    };
    assert_eq!(signature, "ErUBCkYICBgCIkCaGbqC85F4");
}
859
/// A thinking delta is emitted as a reasoning delta and also buffered in the
/// thinking state.
#[test]
fn test_handle_thinking_delta_event() {
    let mut tool_call_state = None;
    let mut thinking_state = None;
    let event = StreamingEvent::ContentBlockDelta {
        index: 0,
        delta: ContentDelta::ThinkingDelta {
            thinking: "Analyzing the request...".to_string(),
        },
    };

    let result = handle_event(&event, &mut tool_call_state, &mut thinking_state);
    assert!(result.is_some());

    let RawStreamingChoice::ReasoningDelta { id, reasoning, .. } = result.unwrap().unwrap()
    else {
        panic!("Expected ReasoningDelta choice")
    };
    assert_eq!(id, None);
    assert_eq!(reasoning, "Analyzing the request...");

    assert!(thinking_state.is_some());
    assert_eq!(thinking_state.unwrap().thinking, "Analyzing the request...");
}
888
/// A signature delta emits nothing but is buffered in the thinking state.
#[test]
fn test_handle_signature_delta_event() {
    let mut tool_call_state = None;
    let mut thinking_state = None;
    let event = StreamingEvent::ContentBlockDelta {
        index: 0,
        delta: ContentDelta::SignatureDelta {
            signature: "test_signature".to_string(),
        },
    };

    let result = handle_event(&event, &mut tool_call_state, &mut thinking_state);
    assert!(result.is_none());

    assert!(thinking_state.is_some());
    assert_eq!(thinking_state.unwrap().signature, "test_signature");
}
909
/// A redacted-thinking block start is emitted immediately as redacted
/// reasoning carrying the same data.
#[test]
fn test_handle_redacted_thinking_content_block_start_event() {
    let mut tool_call_state = None;
    let mut thinking_state = None;
    let event = StreamingEvent::ContentBlockStart {
        index: 0,
        content_block: Content::RedactedThinking {
            data: "redacted_blob".to_string(),
        },
    };

    let result = handle_event(&event, &mut tool_call_state, &mut thinking_state);
    assert!(result.is_some());

    let RawStreamingChoice::Reasoning {
        content: ReasoningContent::Redacted { data },
        ..
    } = result.unwrap().unwrap()
    else {
        panic!("Expected Redacted reasoning chunk")
    };
    assert_eq!(data, "redacted_blob");
}
933
/// With no tool call open, a text delta is emitted as a message chunk.
#[test]
fn test_handle_text_delta_event() {
    let mut tool_call_state = None;
    let mut thinking_state = None;
    let event = StreamingEvent::ContentBlockDelta {
        index: 0,
        delta: ContentDelta::TextDelta {
            text: "Hello, world!".to_string(),
        },
    };

    let result = handle_event(&event, &mut tool_call_state, &mut thinking_state);
    assert!(result.is_some());

    let RawStreamingChoice::Message(text) = result.unwrap().unwrap() else {
        panic!("Expected Message choice")
    };
    assert_eq!(text, "Hello, world!");
}
957
/// A text block start without citations yields `TextStart` with no
/// additional params.
#[test]
fn test_handle_text_block_start_event() {
    let mut tool_call_state = None;
    let mut thinking_state = None;
    let event = StreamingEvent::ContentBlockStart {
        index: 0,
        content_block: Content::Text {
            text: String::new(),
            citations: Vec::new(),
            cache_control: None,
        },
    };

    let result = handle_event(&event, &mut tool_call_state, &mut thinking_state);
    assert!(result.is_some());

    let is_plain_text_start = matches!(
        result.unwrap().unwrap(),
        RawStreamingChoice::TextStart {
            additional_params: None
        }
    );
    assert!(is_plain_text_start);
}
982
/// A thinking delta arriving while a tool call is open is still emitted and
/// leaves the tool-call state in place.
#[test]
fn test_thinking_delta_does_not_interfere_with_tool_calls() {
    let mut thinking_state = None;
    let mut tool_call_state = Some(ToolCallState {
        name: "test_tool".to_string(),
        id: "tool_123".to_string(),
        internal_call_id: nanoid::nanoid!(),
        input_json: String::new(),
    });
    let event = StreamingEvent::ContentBlockDelta {
        index: 0,
        delta: ContentDelta::ThinkingDelta {
            thinking: "Thinking while tool is active...".to_string(),
        },
    };

    let result = handle_event(&event, &mut tool_call_state, &mut thinking_state);
    assert!(result.is_some());

    let RawStreamingChoice::ReasoningDelta { reasoning, .. } = result.unwrap().unwrap() else {
        panic!("Expected ReasoningDelta choice")
    };
    assert_eq!(reasoning, "Thinking while tool is active...");

    assert!(tool_call_state.is_some());
}
1016
/// A JSON input fragment for an open tool call is emitted as a tool-call
/// delta and appended to the buffered input.
#[test]
fn test_handle_input_json_delta_event() {
    let mut thinking_state = None;
    let mut tool_call_state = Some(ToolCallState {
        name: "test_tool".to_string(),
        id: "tool_123".to_string(),
        internal_call_id: nanoid::nanoid!(),
        input_json: String::new(),
    });
    let event = StreamingEvent::ContentBlockDelta {
        index: 0,
        delta: ContentDelta::InputJsonDelta {
            partial_json: "{\"arg\":\"value".to_string(),
        },
    };

    let result = handle_event(&event, &mut tool_call_state, &mut thinking_state);
    assert!(result.is_some());

    let choice = result.unwrap().unwrap();
    match choice {
        RawStreamingChoice::ToolCallDelta {
            id,
            internal_call_id: _,
            content,
        } => {
            assert_eq!(id, "tool_123");
            let ToolCallDeltaContent::Delta(delta) = content else {
                panic!("Expected Delta content")
            };
            assert_eq!(delta, "{\"arg\":\"value");
        }
        _ => panic!("Expected ToolCallDelta choice, got {:?}", choice),
    }

    assert!(tool_call_state.is_some());
    let buffered = tool_call_state.unwrap();
    assert_eq!(buffered.input_json, "{\"arg\":\"value");
}
1060
/// Three JSON fragments are buffered in order, and the block stop turns them
/// into one parsed tool call while clearing the tool-call state.
#[test]
fn test_tool_call_accumulation_with_multiple_deltas() {
    let mut tool_call_state = Some(ToolCallState {
        name: "test_tool".to_string(),
        id: "tool_123".to_string(),
        internal_call_id: nanoid::nanoid!(),
        input_json: String::new(),
    });
    let mut thinking_state = None;

    // Every fragment must produce a delta chunk of its own.
    let fragments = ["{\"location\":", "\"Paris\",", "\"temp\":\"20C\"}"];
    for fragment in fragments {
        let event = StreamingEvent::ContentBlockDelta {
            index: 0,
            delta: ContentDelta::InputJsonDelta {
                partial_json: fragment.to_string(),
            },
        };
        let result = handle_event(&event, &mut tool_call_state, &mut thinking_state);
        assert!(result.is_some());
    }

    assert!(tool_call_state.is_some());
    let buffered = tool_call_state.as_ref().unwrap();
    assert_eq!(
        buffered.input_json,
        "{\"location\":\"Paris\",\"temp\":\"20C\"}"
    );

    let stop_event = StreamingEvent::ContentBlockStop { index: 0 };
    let final_result = handle_event(&stop_event, &mut tool_call_state, &mut thinking_state);
    assert!(final_result.is_some());

    match final_result.unwrap().unwrap() {
        RawStreamingChoice::ToolCall(RawStreamingToolCall {
            id,
            name,
            arguments,
            ..
        }) => {
            assert_eq!(id, "tool_123");
            assert_eq!(name, "test_tool");
            assert_eq!(
                arguments.get("location").unwrap().as_str().unwrap(),
                "Paris"
            );
            assert_eq!(arguments.get("temp").unwrap().as_str().unwrap(), "20C");
        }
        other => panic!("Expected ToolCall, got {:?}", other),
    }

    assert!(tool_call_state.is_none());
}
1135
/// A `citations_delta` carrying a `char_location` citation decodes with its
/// character range intact.
#[test]
fn test_citations_delta_streaming_event_deserialization() {
    let payload = r#"{
        "type": "content_block_delta",
        "index": 0,
        "delta": {
            "type": "citations_delta",
            "citation": {
                "type": "char_location",
                "cited_text": "The grass is green.",
                "document_index": 0,
                "document_title": "Example",
                "start_char_index": 0,
                "end_char_index": 20
            }
        }
    }"#;

    let event: StreamingEvent = serde_json::from_str(payload).unwrap();
    let (index, delta) = match event {
        StreamingEvent::ContentBlockDelta { index, delta } => (index, delta),
        _ => panic!("expected ContentBlockDelta"),
    };
    assert_eq!(index, 0);

    let citation = match delta {
        ContentDelta::CitationsDelta { citation } => citation,
        _ => panic!("expected CitationsDelta"),
    };
    let (start_char_index, end_char_index) = match citation {
        crate::providers::anthropic::completion::Citation::CharLocation {
            start_char_index,
            end_char_index,
            ..
        } => (start_char_index, end_char_index),
        _ => panic!("expected CharLocation"),
    };
    assert_eq!(start_char_index, 0);
    assert_eq!(end_char_index, 20);
}
1173
/// A `citations_delta` carrying a `search_result_location` citation must
/// deserialize into `Citation::SearchResultLocation` with the search-result
/// index and block range taken from the payload.
#[test]
fn test_search_result_citations_delta_streaming_event_deserialization() {
    use crate::providers::anthropic::completion::Citation;

    let payload = r#"{
        "type": "content_block_delta",
        "index": 0,
        "delta": {
            "type": "citations_delta",
            "citation": {
                "type": "search_result_location",
                "cited_text": "API requests require a key.",
                "source": "https://docs.example.com/api-reference",
                "title": "API Reference",
                "search_result_index": 0,
                "start_block_index": 0,
                "end_block_index": 1
            }
        }
    }"#;

    // One nested match covers both layers; each fallback arm keeps the
    // message of the layer that failed to match.
    let citation = match serde_json::from_str::<StreamingEvent>(payload).unwrap() {
        StreamingEvent::ContentBlockDelta {
            delta: ContentDelta::CitationsDelta { citation },
            ..
        } => citation,
        StreamingEvent::ContentBlockDelta { .. } => panic!("expected CitationsDelta"),
        _ => panic!("expected ContentBlockDelta"),
    };

    assert!(matches!(
        citation,
        Citation::SearchResultLocation {
            search_result_index: 0,
            start_block_index: 0,
            end_block_index: 1,
            ..
        }
    ));
}
1210
/// A `citations_delta` carrying a `web_search_result_location` citation must
/// deserialize into `Citation::WebSearchResultLocation` with the URL and
/// encrypted index from the payload.
#[test]
fn test_web_search_result_citations_delta_streaming_event_deserialization() {
    use crate::providers::anthropic::completion::Citation;

    let payload = r#"{
        "type": "content_block_delta",
        "index": 0,
        "delta": {
            "type": "citations_delta",
            "citation": {
                "type": "web_search_result_location",
                "cited_text": "Claude Shannon was a mathematician.",
                "url": "https://example.com/shannon",
                "title": "Claude Shannon",
                "encrypted_index": "encrypted-reference"
            }
        }
    }"#;

    let citation = match serde_json::from_str::<StreamingEvent>(payload).unwrap() {
        StreamingEvent::ContentBlockDelta {
            delta: ContentDelta::CitationsDelta { citation },
            ..
        } => citation,
        StreamingEvent::ContentBlockDelta { .. } => panic!("expected CitationsDelta"),
        _ => panic!("expected ContentBlockDelta"),
    };

    // Matching on a reference binds the fields by reference without `ref`.
    assert!(matches!(
        &citation,
        Citation::WebSearchResultLocation {
            url,
            encrypted_index,
            ..
        } if url == "https://example.com/shannon"
            && encrypted_index == "encrypted-reference"
    ));
}
1245
/// A `web_search_result_location` citation whose `title` is JSON `null` must
/// still deserialize, yielding `title: None`.
#[test]
fn test_web_search_result_citations_delta_allows_null_title() {
    use crate::providers::anthropic::completion::Citation;

    let payload = r#"{
        "type": "content_block_delta",
        "index": 0,
        "delta": {
            "type": "citations_delta",
            "citation": {
                "type": "web_search_result_location",
                "cited_text": "Claude Shannon was a mathematician.",
                "url": "https://example.com/shannon",
                "title": null,
                "encrypted_index": "encrypted-reference"
            }
        }
    }"#;

    let citation = match serde_json::from_str::<StreamingEvent>(payload).unwrap() {
        StreamingEvent::ContentBlockDelta {
            delta: ContentDelta::CitationsDelta { citation },
            ..
        } => citation,
        StreamingEvent::ContentBlockDelta { .. } => panic!("expected CitationsDelta"),
        _ => panic!("expected ContentBlockDelta"),
    };

    assert!(matches!(
        citation,
        Citation::WebSearchResultLocation { title: None, .. }
    ));
}
1278
/// `content_block_start` events for the two web-search block types
/// (`server_tool_use` and `web_search_tool_result`) must deserialize into
/// `Content::ServerToolUse` and `Content::WebSearchToolResult` with their
/// payload fields intact.
#[test]
fn test_web_search_content_block_start_events_deserialize() {
    // Part 1: the server-side tool invocation block.
    let tool_use_payload = r#"{
        "type": "content_block_start",
        "index": 1,
        "content_block": {
            "type": "server_tool_use",
            "id": "srvtoolu_01",
            "name": "web_search",
            "input": {
                "query": "claude shannon birth date"
            }
        }
    }"#;
    let tool_use_event = serde_json::from_str::<StreamingEvent>(tool_use_payload).unwrap();
    assert!(matches!(
        &tool_use_event,
        StreamingEvent::ContentBlockStart {
            content_block: Content::ServerToolUse { id, name, input },
            ..
        } if id == "srvtoolu_01"
            && name == "web_search"
            && input["query"] == "claude shannon birth date"
    ));

    // Part 2: the result block that refers back to the tool invocation.
    let tool_result_payload = r#"{
        "type": "content_block_start",
        "index": 2,
        "content_block": {
            "type": "web_search_tool_result",
            "tool_use_id": "srvtoolu_01",
            "content": [{
                "type": "web_search_result",
                "url": "https://example.com/shannon",
                "title": "Claude Shannon",
                "encrypted_content": "encrypted-content"
            }]
        }
    }"#;
    let tool_result_event = serde_json::from_str::<StreamingEvent>(tool_result_payload).unwrap();
    assert!(matches!(
        &tool_result_event,
        StreamingEvent::ContentBlockStart {
            content_block: Content::WebSearchToolResult {
                tool_use_id,
                content
            },
            ..
        } if tool_use_id == "srvtoolu_01"
            && content[0]["encrypted_content"] == "encrypted-content"
    ));
}
1335
1336 #[tokio::test]
1337 async fn test_streaming_web_search_blocks_are_preserved_on_final_choice() {
1338 let raw_stream = stream! {
1339 let mut tool_call_state = None;
1340 let mut server_tool_uses = HashMap::new();
1341 let mut thinking_state = None;
1342
1343 let server_tool_use_start = super::handle_event(
1344 &StreamingEvent::ContentBlockStart {
1345 index: 0,
1346 content_block: Content::ServerToolUse {
1347 id: "srvtoolu_01".to_string(),
1348 name: "web_search".to_string(),
1349 input: serde_json::Value::Null,
1350 },
1351 },
1352 &mut tool_call_state,
1353 &mut server_tool_uses,
1354 &mut thinking_state,
1355 );
1356 assert!(
1357 server_tool_use_start.is_none(),
1358 "server_tool_use start should be accumulated until its input JSON is complete"
1359 );
1360
1361 let server_tool_use_delta = super::handle_event(
1362 &StreamingEvent::ContentBlockDelta {
1363 index: 0,
1364 delta: ContentDelta::InputJsonDelta {
1365 partial_json: r#"{"query":"claude shannon birth date"}"#.to_string(),
1366 },
1367 },
1368 &mut tool_call_state,
1369 &mut server_tool_uses,
1370 &mut thinking_state,
1371 );
1372 assert!(
1373 server_tool_use_delta.is_none(),
1374 "server_tool_use input JSON should not be emitted as a Rig tool-call delta"
1375 );
1376
1377 yield super::handle_event(
1378 &StreamingEvent::ContentBlockStop { index: 0 },
1379 &mut tool_call_state,
1380 &mut server_tool_uses,
1381 &mut thinking_state,
1382 )
1383 .expect("server_tool_use stop should produce completed raw metadata");
1384
1385 yield super::handle_event(
1386 &StreamingEvent::ContentBlockStart {
1387 index: 1,
1388 content_block: Content::WebSearchToolResult {
1389 tool_use_id: "srvtoolu_01".to_string(),
1390 content: serde_json::json!([{
1391 "type": "web_search_result",
1392 "url": "https://example.com/shannon",
1393 "title": "Claude Shannon",
1394 "encrypted_content": "encrypted-content"
1395 }]),
1396 },
1397 },
1398 &mut tool_call_state,
1399 &mut server_tool_uses,
1400 &mut thinking_state,
1401 )
1402 .expect("web_search_tool_result block should produce raw metadata");
1403
1404 yield super::handle_event(
1405 &StreamingEvent::ContentBlockStart {
1406 index: 2,
1407 content_block: Content::Text {
1408 text: String::new(),
1409 citations: Vec::new(),
1410 cache_control: None,
1411 },
1412 },
1413 &mut tool_call_state,
1414 &mut server_tool_uses,
1415 &mut thinking_state,
1416 )
1417 .expect("text block start should produce a raw choice");
1418
1419 yield super::handle_event(
1420 &StreamingEvent::ContentBlockDelta {
1421 index: 2,
1422 delta: ContentDelta::TextDelta {
1423 text: "Claude Shannon was born on April 30, 1916.".to_string(),
1424 },
1425 },
1426 &mut tool_call_state,
1427 &mut server_tool_uses,
1428 &mut thinking_state,
1429 )
1430 .expect("text delta should produce a raw choice");
1431
1432 yield super::handle_event(
1433 &StreamingEvent::ContentBlockDelta {
1434 index: 2,
1435 delta: ContentDelta::CitationsDelta {
1436 citation: crate::providers::anthropic::completion::Citation::WebSearchResultLocation {
1437 cited_text: "Claude Shannon was born on April 30, 1916.".to_string(),
1438 url: "https://example.com/shannon".to_string(),
1439 title: Some("Claude Shannon".to_string()),
1440 encrypted_index: "encrypted-index".to_string(),
1441 },
1442 },
1443 },
1444 &mut tool_call_state,
1445 &mut server_tool_uses,
1446 &mut thinking_state,
1447 )
1448 .expect("citation delta should produce a raw choice");
1449
1450 yield Ok(RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
1451 usage: PartialUsage::default(),
1452 }));
1453 };
1454
1455 let mut stream =
1456 crate::streaming::StreamingCompletionResponse::stream(to_stream_result(raw_stream));
1457 while stream.next().await.is_some() {}
1458
1459 let choice_items: Vec<crate::message::AssistantContent> =
1460 stream.choice.clone().into_iter().collect();
1461 assert_eq!(choice_items.len(), 3);
1462 assert!(
1463 choice_items
1464 .iter()
1465 .all(|item| !matches!(item, crate::message::AssistantContent::ToolCall(_))),
1466 "provider-owned web-search blocks must not become Rig client tool calls"
1467 );
1468
1469 let Some(crate::message::AssistantContent::Text(server_tool_use)) = choice_items.first()
1470 else {
1471 panic!("expected raw server_tool_use metadata");
1472 };
1473 assert_eq!(
1474 server_tool_use.additional_params.as_ref().unwrap()
1475 [crate::providers::anthropic::completion::ANTHROPIC_RAW_CONTENT_KEY]["type"],
1476 "server_tool_use"
1477 );
1478 assert_eq!(
1479 server_tool_use.additional_params.as_ref().unwrap()
1480 [crate::providers::anthropic::completion::ANTHROPIC_RAW_CONTENT_KEY]["input"]["query"],
1481 "claude shannon birth date"
1482 );
1483
1484 let Some(crate::message::AssistantContent::Text(web_search_result)) = choice_items.get(1)
1485 else {
1486 panic!("expected raw web_search_tool_result metadata");
1487 };
1488 assert_eq!(
1489 web_search_result.additional_params.as_ref().unwrap()
1490 [crate::providers::anthropic::completion::ANTHROPIC_RAW_CONTENT_KEY]["content"][0]
1491 ["encrypted_content"],
1492 "encrypted-content"
1493 );
1494
1495 let Some(crate::message::AssistantContent::Text(answer)) = choice_items.get(2) else {
1496 panic!("expected answer text");
1497 };
1498 assert_eq!(answer.text, "Claude Shannon was born on April 30, 1916.");
1499 let citations = crate::providers::anthropic::completion::anthropic_citations(answer)
1500 .expect("expected preserved citations");
1501 assert!(matches!(
1502 citations.first(),
1503 Some(crate::providers::anthropic::completion::Citation::WebSearchResultLocation {
1504 encrypted_index,
1505 ..
1506 }) if encrypted_index == "encrypted-index"
1507 ));
1508 }
1509
/// A lone `citations_delta` event passed to `handle_event` must produce a
/// `RawStreamingChoice::TextAdditionalParams` whose `citations` array starts
/// with the `char_location` citation.
#[test]
fn test_handle_citations_delta_event_preserves_metadata() {
    use crate::providers::anthropic::completion::Citation;

    let citation = Citation::CharLocation {
        cited_text: "The grass is green.".to_string(),
        document_index: 0,
        document_title: Some("Example".to_string()),
        start_char_index: 0,
        end_char_index: 20,
    };
    let event = StreamingEvent::ContentBlockDelta {
        index: 0,
        delta: ContentDelta::CitationsDelta { citation },
    };

    let mut tool_call_state = None;
    let mut thinking_state = None;
    let outcome = handle_event(&event, &mut tool_call_state, &mut thinking_state);

    assert!(outcome.is_some());
    match outcome.unwrap().unwrap() {
        RawStreamingChoice::TextAdditionalParams(params) => {
            assert_eq!(params["citations"][0]["type"], "char_location");
        }
        _ => panic!("expected TextAdditionalParams choice"),
    }
}
1536
1537 #[tokio::test]
1538 async fn test_streaming_citation_deltas_are_preserved_on_final_text() {
1539 let citation = crate::providers::anthropic::completion::Citation::CharLocation {
1540 cited_text: "The grass is green.".to_string(),
1541 document_index: 0,
1542 document_title: Some("Example".to_string()),
1543 start_char_index: 0,
1544 end_char_index: 20,
1545 };
1546
1547 let raw_stream = stream! {
1548 let mut tool_call_state = None;
1549 let mut thinking_state = None;
1550
1551 yield handle_event(
1552 &StreamingEvent::ContentBlockStart {
1553 index: 0,
1554 content_block: Content::Text {
1555 text: String::new(),
1556 citations: Vec::new(),
1557 cache_control: None,
1558 },
1559 },
1560 &mut tool_call_state,
1561 &mut thinking_state,
1562 )
1563 .expect("text block start should produce a raw choice");
1564
1565 yield handle_event(
1566 &StreamingEvent::ContentBlockDelta {
1567 index: 0,
1568 delta: ContentDelta::TextDelta {
1569 text: "the grass is green".to_string(),
1570 },
1571 },
1572 &mut tool_call_state,
1573 &mut thinking_state,
1574 )
1575 .expect("text delta should produce a raw choice");
1576
1577 yield handle_event(
1578 &StreamingEvent::ContentBlockDelta {
1579 index: 0,
1580 delta: ContentDelta::CitationsDelta {
1581 citation: crate::providers::anthropic::completion::Citation::CharLocation {
1582 cited_text: "The grass is green.".to_string(),
1583 document_index: 0,
1584 document_title: Some("Example".to_string()),
1585 start_char_index: 0,
1586 end_char_index: 20,
1587 },
1588 },
1589 },
1590 &mut tool_call_state,
1591 &mut thinking_state,
1592 )
1593 .expect("citation delta should produce a raw choice");
1594
1595 yield Ok(RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
1596 usage: PartialUsage::default(),
1597 }));
1598 };
1599
1600 let mut stream =
1601 crate::streaming::StreamingCompletionResponse::stream(to_stream_result(raw_stream));
1602 while stream.next().await.is_some() {}
1603
1604 let choice_items: Vec<crate::message::AssistantContent> =
1605 stream.choice.clone().into_iter().collect();
1606 let Some(crate::message::AssistantContent::Text(text)) = choice_items.first() else {
1607 panic!("expected accumulated text item");
1608 };
1609
1610 assert_eq!(text.text, "the grass is green");
1611 let citations = crate::providers::anthropic::completion::anthropic_citations(text).unwrap();
1612 assert_eq!(citations, vec![citation]);
1613 }
1614
/// A delta with an unrecognized `type` tag must deserialize to
/// `ContentDelta::Unknown` rather than failing.
#[test]
fn test_unknown_content_delta_falls_back() {
    let unrecognized = r#"{"type": "something_new_from_anthropic", "field": "x"}"#;
    let parsed = serde_json::from_str::<ContentDelta>(unrecognized).unwrap();
    assert!(matches!(parsed, ContentDelta::Unknown));
}
1621}