1use async_stream::stream;
2use futures::StreamExt;
3use serde::{Deserialize, Serialize};
4use serde_json::{Value, json};
5use tracing::{Level, enabled, info_span};
6use tracing_futures::Instrument;
7
8use super::completion::{
9 CacheControl, CompletionModel, Content, Message, SystemContent, ToolChoice, ToolDefinition,
10 Usage, apply_cache_control, split_system_messages_from_history,
11};
12use crate::completion::{CompletionError, CompletionRequest, GetTokenUsage};
13use crate::http_client::sse::{Event, GenericEventSource};
14use crate::http_client::{self, HttpClientExt};
15use crate::json_utils::merge_inplace;
16use crate::message::ReasoningContent;
17use crate::streaming::{
18 self, RawStreamingChoice, RawStreamingToolCall, StreamingResult, ToolCallDeltaContent,
19};
20use crate::telemetry::SpanCombinator;
21
/// One server-sent event from Anthropic's streaming Messages API.
///
/// Events are discriminated by their `"type"` field, using the snake_case form of the
/// variant name (e.g. `"content_block_delta"`). Any `"type"` not listed here deserializes
/// to [`StreamingEvent::Unknown`] instead of failing.
#[derive(Debug, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum StreamingEvent {
    /// `message_start`: carries the message metadata (id, model, initial usage).
    MessageStart {
        message: MessageStart,
    },
    /// `content_block_start`: opens the content block at `index` (text, tool use, thinking, ...).
    ContentBlockStart {
        index: usize,
        content_block: Content,
    },
    /// `content_block_delta`: incremental content for the block at `index`.
    ContentBlockDelta {
        index: usize,
        delta: ContentDelta,
    },
    /// `content_block_stop`: closes the block at `index`.
    ContentBlockStop {
        index: usize,
    },
    /// `message_delta`: message-level changes (stop reason) plus token usage.
    MessageDelta {
        delta: MessageDelta,
        usage: PartialUsage,
    },
    /// `message_stop`: carries no data in this model.
    MessageStop,
    /// `ping`: carries no data in this model.
    Ping,
    /// Catch-all for any other `"type"` value (via `#[serde(other)]`); its payload is discarded.
    #[serde(other)]
    Unknown,
}
48
/// Message metadata delivered by a `message_start` event.
#[derive(Debug, Deserialize)]
pub struct MessageStart {
    /// Provider-assigned message id (recorded on the tracing span as the response id).
    pub id: String,
    pub role: String,
    pub content: Vec<Content>,
    /// Model that actually served the request.
    pub model: String,
    pub stop_reason: Option<String>,
    pub stop_sequence: Option<String>,
    /// Initial usage; the streaming loop reads its `input_tokens` from here.
    pub usage: Usage,
}
59
/// Payload of a `content_block_delta` event, discriminated by its `"type"` field.
#[derive(Debug, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ContentDelta {
    /// `text_delta`: a fragment of assistant text.
    TextDelta { text: String },
    /// `input_json_delta`: a fragment of a tool call's JSON arguments (not valid JSON on its own).
    InputJsonDelta { partial_json: String },
    /// `thinking_delta`: a fragment of extended-thinking text.
    ThinkingDelta { thinking: String },
    /// `signature_delta`: a fragment of the signature attached to a thinking block.
    SignatureDelta { signature: String },
}
68
/// Message-level changes delivered by a `message_delta` event.
#[derive(Debug, Deserialize)]
pub struct MessageDelta {
    /// The streaming loop in `stream` treats a present value as the end of generation.
    pub stop_reason: Option<String>,
    pub stop_sequence: Option<String>,
}
74
/// Token usage as reported in a `message_delta` event; also used to carry the final usage
/// of a streamed response.
///
/// Only `output_tokens` is required when deserializing; the other counters become `None`
/// when absent from the payload.
#[derive(Debug, Deserialize, Clone, Serialize, Default)]
pub struct PartialUsage {
    pub output_tokens: usize,
    #[serde(default)]
    pub input_tokens: Option<usize>,
    #[serde(default)]
    pub cache_creation_input_tokens: Option<u64>,
    #[serde(default)]
    pub cache_read_input_tokens: Option<u64>,
}
85
86impl GetTokenUsage for PartialUsage {
87 fn token_usage(&self) -> Option<crate::completion::Usage> {
88 let mut usage = crate::completion::Usage::new();
89
90 usage.input_tokens = self.input_tokens.unwrap_or_default() as u64;
91 usage.output_tokens = self.output_tokens as u64;
92 usage.cached_input_tokens = self.cache_read_input_tokens.unwrap_or(0);
93 usage.cache_creation_input_tokens = self.cache_creation_input_tokens.unwrap_or(0);
94 usage.total_tokens = usage.input_tokens
95 + usage.cached_input_tokens
96 + usage.cache_creation_input_tokens
97 + usage.output_tokens;
98 Some(usage)
99 }
100}
101
/// Accumulates the tool-use content block that is currently being streamed.
#[derive(Default)]
struct ToolCallState {
    /// Tool name taken from the block's `content_block_start` event.
    name: String,
    /// Provider-assigned tool-use id.
    id: String,
    /// Locally generated id (a `nanoid`) attached to every delta and to the final tool call.
    internal_call_id: String,
    /// Concatenation of the `input_json_delta` fragments received so far.
    input_json: String,
}
109
/// Accumulates the thinking content block that is currently being streamed.
#[derive(Default)]
struct ThinkingState {
    /// Concatenated `thinking_delta` text.
    thinking: String,
    /// Concatenated `signature_delta` fragments; empty when none arrived.
    signature: String,
}
115
/// Final payload yielded at the end of an Anthropic stream (wrapped in
/// `RawStreamingChoice::FinalResponse`), carrying the token usage that was observed.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct StreamingCompletionResponse {
    /// Final usage; all-default when the stream ended before a stop reason arrived.
    pub usage: PartialUsage,
}
120
121impl GetTokenUsage for StreamingCompletionResponse {
122 fn token_usage(&self) -> Option<crate::completion::Usage> {
123 let mut usage = crate::completion::Usage::new();
124 usage.input_tokens = self.usage.input_tokens.unwrap_or(0) as u64;
125 usage.output_tokens = self.usage.output_tokens as u64;
126 usage.cached_input_tokens = self.usage.cache_read_input_tokens.unwrap_or(0);
127 usage.cache_creation_input_tokens = self.usage.cache_creation_input_tokens.unwrap_or(0);
128 usage.total_tokens = usage.input_tokens
129 + usage.cached_input_tokens
130 + usage.cache_creation_input_tokens
131 + usage.output_tokens;
132
133 Some(usage)
134 }
135}
136
137impl<T> CompletionModel<T>
138where
139 T: HttpClientExt + Clone + Default + 'static,
140{
141 pub(crate) async fn stream(
142 &self,
143 mut completion_request: CompletionRequest,
144 ) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError>
145 {
146 let request_model = completion_request
147 .model
148 .clone()
149 .unwrap_or_else(|| self.model.clone());
150 let span = if tracing::Span::current().is_disabled() {
151 info_span!(
152 target: "rig::completions",
153 "chat_streaming",
154 gen_ai.operation.name = "chat_streaming",
155 gen_ai.provider.name = "anthropic",
156 gen_ai.request.model = &request_model,
157 gen_ai.system_instructions = &completion_request.preamble,
158 gen_ai.response.id = tracing::field::Empty,
159 gen_ai.response.model = &request_model,
160 gen_ai.usage.output_tokens = tracing::field::Empty,
161 gen_ai.usage.input_tokens = tracing::field::Empty,
162 gen_ai.usage.cache_read.input_tokens = tracing::field::Empty,
163 gen_ai.usage.cache_creation.input_tokens = tracing::field::Empty,
164 gen_ai.input.messages = tracing::field::Empty,
165 gen_ai.output.messages = tracing::field::Empty,
166 )
167 } else {
168 tracing::Span::current()
169 };
170 let max_tokens = if let Some(tokens) = completion_request.max_tokens {
171 tokens
172 } else if let Some(tokens) = self.default_max_tokens {
173 tokens
174 } else {
175 return Err(CompletionError::RequestError(
176 "`max_tokens` must be set for Anthropic".into(),
177 ));
178 };
179
180 let mut full_history = vec![];
181 if let Some(docs) = completion_request.normalized_documents() {
182 full_history.push(docs);
183 }
184 full_history.extend(completion_request.chat_history);
185 let (history_system, full_history) = split_system_messages_from_history(full_history);
186
187 let mut messages = full_history
188 .into_iter()
189 .map(Message::try_from)
190 .collect::<Result<Vec<Message>, _>>()?;
191
192 let mut system: Vec<SystemContent> =
194 if let Some(preamble) = completion_request.preamble.as_ref() {
195 if preamble.is_empty() {
196 vec![]
197 } else {
198 vec![SystemContent::Text {
199 text: preamble.clone(),
200 cache_control: None,
201 }]
202 }
203 } else {
204 vec![]
205 };
206 system.extend(history_system);
207
208 if self.prompt_caching {
210 apply_cache_control(&mut system, &mut messages);
211 }
212
213 let mut body = json!({
214 "model": request_model,
215 "messages": messages,
216 "max_tokens": max_tokens,
217 "stream": true,
218 });
219
220 if self.automatic_caching {
223 let cc = CacheControl::Ephemeral {
224 ttl: self.automatic_caching_ttl.clone(),
225 };
226 merge_inplace(
227 &mut body,
228 json!({ "cache_control": serde_json::to_value(&cc)? }),
229 );
230 }
231
232 if !system.is_empty() {
234 merge_inplace(&mut body, json!({ "system": system }));
235 }
236
237 if let Some(temperature) = completion_request.temperature {
238 merge_inplace(&mut body, json!({ "temperature": temperature }));
239 }
240
241 let mut additional_params_payload = completion_request
242 .additional_params
243 .take()
244 .unwrap_or(Value::Null);
245 let mut additional_tools =
246 extract_tools_from_additional_params(&mut additional_params_payload)?;
247
248 let mut tools = completion_request
249 .tools
250 .into_iter()
251 .map(|tool| ToolDefinition {
252 name: tool.name,
253 description: Some(tool.description),
254 input_schema: tool.parameters,
255 })
256 .map(serde_json::to_value)
257 .collect::<Result<Vec<_>, _>>()?;
258 tools.append(&mut additional_tools);
259
260 if !tools.is_empty() {
261 merge_inplace(
262 &mut body,
263 json!({
264 "tools": tools,
265 "tool_choice": ToolChoice::Auto,
266 }),
267 );
268 }
269
270 if !additional_params_payload.is_null() {
271 merge_inplace(&mut body, additional_params_payload)
272 }
273
274 if enabled!(Level::TRACE) {
275 tracing::trace!(
276 target: "rig::completions",
277 "Anthropic completion request: {}",
278 serde_json::to_string_pretty(&body)?
279 );
280 }
281
282 let body: Vec<u8> = serde_json::to_vec(&body)?;
283
284 let req = self
285 .client
286 .post("/v1/messages")?
287 .body(body)
288 .map_err(http_client::Error::Protocol)?;
289
290 let stream = GenericEventSource::new(self.client.clone(), req);
291
292 let stream: StreamingResult<StreamingCompletionResponse> = Box::pin(stream! {
294 let mut current_tool_call: Option<ToolCallState> = None;
295 let mut current_thinking: Option<ThinkingState> = None;
296 let mut sse_stream = Box::pin(stream);
297 let mut input_tokens = 0;
298 let mut final_usage = None;
299
300 let mut text_content = String::new();
301
302 while let Some(sse_result) = sse_stream.next().await {
303 match sse_result {
304 Ok(Event::Open) => {}
305 Ok(Event::Message(sse)) => {
306 match serde_json::from_str::<StreamingEvent>(&sse.data) {
308 Ok(event) => {
309 match &event {
310 StreamingEvent::MessageStart { message } => {
311 input_tokens = message.usage.input_tokens;
312
313 let span = tracing::Span::current();
314 span.record("gen_ai.response.id", &message.id);
315 span.record("gen_ai.response.model_name", &message.model);
316 },
317 StreamingEvent::MessageDelta { delta, usage } => {
318 if delta.stop_reason.is_some() {
319 let usage = PartialUsage {
323 output_tokens: usage.output_tokens,
324 input_tokens: Some(input_tokens.try_into().expect("Failed to convert input_tokens to usize")),
325 cache_creation_input_tokens: usage.cache_creation_input_tokens,
326 cache_read_input_tokens: usage.cache_read_input_tokens
327 };
328
329 let span = tracing::Span::current();
330 span.record_token_usage(&usage);
331 final_usage = Some(usage);
332 break;
333 }
334 }
335 _ => {}
336 }
337
338 if let Some(result) = handle_event(&event, &mut current_tool_call, &mut current_thinking) {
339 if let Ok(RawStreamingChoice::Message(ref text)) = result {
340 text_content += text;
341 }
342 yield result;
343 }
344 },
345 Err(e) => {
346 if !sse.data.trim().is_empty() {
347 yield Err(CompletionError::ResponseError(
348 format!("Failed to parse JSON: {} (Data: {})", e, sse.data)
349 ));
350 }
351 }
352 }
353 },
354 Err(e) => {
355 yield Err(CompletionError::ProviderError(format!("SSE Error: {e}")));
356 break;
357 }
358 }
359 }
360
361 sse_stream.close();
363
364 yield Ok(RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
365 usage: final_usage.unwrap_or_default()
366 }))
367 }.instrument(span));
368
369 Ok(streaming::StreamingCompletionResponse::stream(stream))
370 }
371}
372
373fn extract_tools_from_additional_params(
374 additional_params: &mut Value,
375) -> Result<Vec<Value>, CompletionError> {
376 if let Some(map) = additional_params.as_object_mut()
377 && let Some(raw_tools) = map.remove("tools")
378 {
379 return serde_json::from_value::<Vec<Value>>(raw_tools).map_err(|err| {
380 CompletionError::RequestError(
381 format!("Invalid Anthropic `additional_params.tools` payload: {err}").into(),
382 )
383 });
384 }
385
386 Ok(Vec::new())
387}
388
389fn handle_event(
390 event: &StreamingEvent,
391 current_tool_call: &mut Option<ToolCallState>,
392 current_thinking: &mut Option<ThinkingState>,
393) -> Option<Result<RawStreamingChoice<StreamingCompletionResponse>, CompletionError>> {
394 match event {
395 StreamingEvent::ContentBlockDelta { delta, .. } => match delta {
396 ContentDelta::TextDelta { text } => {
397 if current_tool_call.is_none() {
398 return Some(Ok(RawStreamingChoice::Message(text.clone())));
399 }
400 None
401 }
402 ContentDelta::InputJsonDelta { partial_json } => {
403 if let Some(tool_call) = current_tool_call {
404 tool_call.input_json.push_str(partial_json);
405 return Some(Ok(RawStreamingChoice::ToolCallDelta {
407 id: tool_call.id.clone(),
408 internal_call_id: tool_call.internal_call_id.clone(),
409 content: ToolCallDeltaContent::Delta(partial_json.clone()),
410 }));
411 }
412 None
413 }
414 ContentDelta::ThinkingDelta { thinking } => {
415 current_thinking
416 .get_or_insert_with(ThinkingState::default)
417 .thinking
418 .push_str(thinking);
419
420 Some(Ok(RawStreamingChoice::ReasoningDelta {
421 id: None,
422 reasoning: thinking.clone(),
423 }))
424 }
425 ContentDelta::SignatureDelta { signature } => {
426 current_thinking
427 .get_or_insert_with(ThinkingState::default)
428 .signature
429 .push_str(signature);
430
431 None
433 }
434 },
435 StreamingEvent::ContentBlockStart { content_block, .. } => match content_block {
436 Content::ToolUse { id, name, .. } => {
437 let internal_call_id = nanoid::nanoid!();
438 *current_tool_call = Some(ToolCallState {
439 name: name.clone(),
440 id: id.clone(),
441 internal_call_id: internal_call_id.clone(),
442 input_json: String::new(),
443 });
444 Some(Ok(RawStreamingChoice::ToolCallDelta {
445 id: id.clone(),
446 internal_call_id,
447 content: ToolCallDeltaContent::Name(name.clone()),
448 }))
449 }
450 Content::Thinking { .. } => {
451 *current_thinking = Some(ThinkingState::default());
452 None
453 }
454 Content::RedactedThinking { data } => Some(Ok(RawStreamingChoice::Reasoning {
455 id: None,
456 content: ReasoningContent::Redacted { data: data.clone() },
457 })),
458 _ => None,
460 },
461 StreamingEvent::ContentBlockStop { .. } => {
462 if let Some(thinking_state) = Option::take(current_thinking)
463 && !thinking_state.thinking.is_empty()
464 {
465 let signature = if thinking_state.signature.is_empty() {
466 None
467 } else {
468 Some(thinking_state.signature)
469 };
470
471 return Some(Ok(RawStreamingChoice::Reasoning {
472 id: None,
473 content: ReasoningContent::Text {
474 text: thinking_state.thinking,
475 signature,
476 },
477 }));
478 }
479
480 if let Some(tool_call) = Option::take(current_tool_call) {
481 let json_str = if tool_call.input_json.is_empty() {
482 "{}"
483 } else {
484 &tool_call.input_json
485 };
486 match serde_json::from_str(json_str) {
487 Ok(json_value) => {
488 let raw_tool_call =
489 RawStreamingToolCall::new(tool_call.id, tool_call.name, json_value)
490 .with_internal_call_id(tool_call.internal_call_id);
491 Some(Ok(RawStreamingChoice::ToolCall(raw_tool_call)))
492 }
493 Err(e) => Some(Err(CompletionError::from(e))),
494 }
495 } else {
496 None
497 }
498 }
499 StreamingEvent::MessageStart { .. }
501 | StreamingEvent::MessageDelta { .. }
502 | StreamingEvent::MessageStop
503 | StreamingEvent::Ping
504 | StreamingEvent::Unknown => None,
505 }
506}
507
/// Unit tests for event deserialization, `handle_event` state handling and usage conversion.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_thinking_delta_deserialization() {
        let json = r#"{"type": "thinking_delta", "thinking": "Let me think about this..."}"#;
        let delta: ContentDelta = serde_json::from_str(json).unwrap();

        match delta {
            ContentDelta::ThinkingDelta { thinking } => {
                assert_eq!(thinking, "Let me think about this...");
            }
            _ => panic!("Expected ThinkingDelta variant"),
        }
    }

    #[test]
    fn test_signature_delta_deserialization() {
        let json = r#"{"type": "signature_delta", "signature": "abc123def456"}"#;
        let delta: ContentDelta = serde_json::from_str(json).unwrap();

        match delta {
            ContentDelta::SignatureDelta { signature } => {
                assert_eq!(signature, "abc123def456");
            }
            _ => panic!("Expected SignatureDelta variant"),
        }
    }

    #[test]
    fn test_thinking_delta_streaming_event_deserialization() {
        let json = r#"{
            "type": "content_block_delta",
            "index": 0,
            "delta": {
                "type": "thinking_delta",
                "thinking": "First, I need to understand the problem."
            }
        }"#;

        let event: StreamingEvent = serde_json::from_str(json).unwrap();

        match event {
            StreamingEvent::ContentBlockDelta { index, delta } => {
                assert_eq!(index, 0);
                match delta {
                    ContentDelta::ThinkingDelta { thinking } => {
                        assert_eq!(thinking, "First, I need to understand the problem.");
                    }
                    _ => panic!("Expected ThinkingDelta"),
                }
            }
            _ => panic!("Expected ContentBlockDelta event"),
        }
    }

    #[test]
    fn test_signature_delta_streaming_event_deserialization() {
        let json = r#"{
            "type": "content_block_delta",
            "index": 0,
            "delta": {
                "type": "signature_delta",
                "signature": "ErUBCkYICBgCIkCaGbqC85F4"
            }
        }"#;

        let event: StreamingEvent = serde_json::from_str(json).unwrap();

        match event {
            StreamingEvent::ContentBlockDelta { index, delta } => {
                assert_eq!(index, 0);
                match delta {
                    ContentDelta::SignatureDelta { signature } => {
                        assert_eq!(signature, "ErUBCkYICBgCIkCaGbqC85F4");
                    }
                    _ => panic!("Expected SignatureDelta"),
                }
            }
            _ => panic!("Expected ContentBlockDelta event"),
        }
    }

    /// Anthropic's mid-stream `error` events are not modeled explicitly; the streaming loop
    /// relies on them landing in the `Unknown` catch-all (extra fields ignored).
    #[test]
    fn test_error_event_deserializes_as_unknown() {
        let json = r#"{"type": "error", "error": {"type": "overloaded_error", "message": "Overloaded"}}"#;
        let event: StreamingEvent = serde_json::from_str(json).unwrap();
        assert!(matches!(event, StreamingEvent::Unknown));
    }

    #[test]
    fn test_handle_thinking_delta_event() {
        let event = StreamingEvent::ContentBlockDelta {
            index: 0,
            delta: ContentDelta::ThinkingDelta {
                thinking: "Analyzing the request...".to_string(),
            },
        };

        let mut tool_call_state = None;
        let mut thinking_state = None;
        let result = handle_event(&event, &mut tool_call_state, &mut thinking_state);

        assert!(result.is_some());
        let choice = result.unwrap().unwrap();

        match choice {
            RawStreamingChoice::ReasoningDelta { id, reasoning, .. } => {
                assert_eq!(id, None);
                assert_eq!(reasoning, "Analyzing the request...");
            }
            _ => panic!("Expected ReasoningDelta choice"),
        }

        assert!(thinking_state.is_some());
        assert_eq!(thinking_state.unwrap().thinking, "Analyzing the request...");
    }

    #[test]
    fn test_handle_signature_delta_event() {
        let event = StreamingEvent::ContentBlockDelta {
            index: 0,
            delta: ContentDelta::SignatureDelta {
                signature: "test_signature".to_string(),
            },
        };

        let mut tool_call_state = None;
        let mut thinking_state = None;
        let result = handle_event(&event, &mut tool_call_state, &mut thinking_state);

        assert!(result.is_none());

        assert!(thinking_state.is_some());
        assert_eq!(thinking_state.unwrap().signature, "test_signature");
    }

    #[test]
    fn test_handle_redacted_thinking_content_block_start_event() {
        let event = StreamingEvent::ContentBlockStart {
            index: 0,
            content_block: Content::RedactedThinking {
                data: "redacted_blob".to_string(),
            },
        };
        let mut tool_call_state = None;
        let mut thinking_state = None;
        let result = handle_event(&event, &mut tool_call_state, &mut thinking_state);

        assert!(result.is_some());
        match result.unwrap().unwrap() {
            RawStreamingChoice::Reasoning {
                content: ReasoningContent::Redacted { data },
                ..
            } => {
                assert_eq!(data, "redacted_blob");
            }
            _ => panic!("Expected Redacted reasoning chunk"),
        }
    }

    #[test]
    fn test_thinking_block_stop_emits_reasoning_with_signature() {
        let mut tool_call_state = None;
        let mut thinking_state = Some(ThinkingState {
            thinking: "step one".to_string(),
            signature: "sig".to_string(),
        });

        let result = handle_event(
            &StreamingEvent::ContentBlockStop { index: 0 },
            &mut tool_call_state,
            &mut thinking_state,
        );

        match result.unwrap().unwrap() {
            RawStreamingChoice::Reasoning {
                content: ReasoningContent::Text {
                    text, signature, ..
                },
                ..
            } => {
                assert_eq!(text, "step one");
                assert_eq!(signature.as_deref(), Some("sig"));
            }
            other => panic!("Expected Text reasoning chunk, got {:?}", other),
        }
        assert!(thinking_state.is_none());
    }

    #[test]
    fn test_thinking_block_stop_without_signature_yields_none_signature() {
        let mut tool_call_state = None;
        let mut thinking_state = Some(ThinkingState {
            thinking: "unsigned".to_string(),
            signature: String::new(),
        });

        let result = handle_event(
            &StreamingEvent::ContentBlockStop { index: 0 },
            &mut tool_call_state,
            &mut thinking_state,
        );

        match result.unwrap().unwrap() {
            RawStreamingChoice::Reasoning {
                content: ReasoningContent::Text {
                    text, signature, ..
                },
                ..
            } => {
                assert_eq!(text, "unsigned");
                assert!(signature.is_none());
            }
            other => panic!("Expected Text reasoning chunk, got {:?}", other),
        }
    }

    #[test]
    fn test_handle_text_delta_event() {
        let event = StreamingEvent::ContentBlockDelta {
            index: 0,
            delta: ContentDelta::TextDelta {
                text: "Hello, world!".to_string(),
            },
        };

        let mut tool_call_state = None;
        let mut thinking_state = None;
        let result = handle_event(&event, &mut tool_call_state, &mut thinking_state);

        assert!(result.is_some());
        let choice = result.unwrap().unwrap();

        match choice {
            RawStreamingChoice::Message(text) => {
                assert_eq!(text, "Hello, world!");
            }
            _ => panic!("Expected Message choice"),
        }
    }

    #[test]
    fn test_text_delta_ignored_while_tool_call_open() {
        let event = StreamingEvent::ContentBlockDelta {
            index: 0,
            delta: ContentDelta::TextDelta {
                text: "stray text".to_string(),
            },
        };
        let mut tool_call_state = Some(ToolCallState {
            name: "test_tool".to_string(),
            id: "tool_123".to_string(),
            internal_call_id: nanoid::nanoid!(),
            input_json: String::new(),
        });
        let mut thinking_state = None;

        let result = handle_event(&event, &mut tool_call_state, &mut thinking_state);

        assert!(result.is_none());
        assert!(tool_call_state.is_some());
    }

    #[test]
    fn test_thinking_delta_does_not_interfere_with_tool_calls() {
        let event = StreamingEvent::ContentBlockDelta {
            index: 0,
            delta: ContentDelta::ThinkingDelta {
                thinking: "Thinking while tool is active...".to_string(),
            },
        };

        let mut tool_call_state = Some(ToolCallState {
            name: "test_tool".to_string(),
            id: "tool_123".to_string(),
            internal_call_id: nanoid::nanoid!(),
            input_json: String::new(),
        });
        let mut thinking_state = None;

        let result = handle_event(&event, &mut tool_call_state, &mut thinking_state);

        assert!(result.is_some());
        let choice = result.unwrap().unwrap();

        match choice {
            RawStreamingChoice::ReasoningDelta { reasoning, .. } => {
                assert_eq!(reasoning, "Thinking while tool is active...");
            }
            _ => panic!("Expected ReasoningDelta choice"),
        }

        assert!(tool_call_state.is_some());
    }

    #[test]
    fn test_handle_input_json_delta_event() {
        let event = StreamingEvent::ContentBlockDelta {
            index: 0,
            delta: ContentDelta::InputJsonDelta {
                partial_json: "{\"arg\":\"value".to_string(),
            },
        };

        let mut tool_call_state = Some(ToolCallState {
            name: "test_tool".to_string(),
            id: "tool_123".to_string(),
            internal_call_id: nanoid::nanoid!(),
            input_json: String::new(),
        });
        let mut thinking_state = None;

        let result = handle_event(&event, &mut tool_call_state, &mut thinking_state);

        assert!(result.is_some());
        let choice = result.unwrap().unwrap();

        match choice {
            RawStreamingChoice::ToolCallDelta {
                id,
                internal_call_id: _,
                content,
            } => {
                assert_eq!(id, "tool_123");
                match content {
                    ToolCallDeltaContent::Delta(delta) => assert_eq!(delta, "{\"arg\":\"value"),
                    _ => panic!("Expected Delta content"),
                }
            }
            _ => panic!("Expected ToolCallDelta choice, got {:?}", choice),
        }

        assert!(tool_call_state.is_some());
        let state = tool_call_state.unwrap();
        assert_eq!(state.input_json, "{\"arg\":\"value");
    }

    #[test]
    fn test_tool_call_accumulation_with_multiple_deltas() {
        let mut tool_call_state = Some(ToolCallState {
            name: "test_tool".to_string(),
            id: "tool_123".to_string(),
            internal_call_id: nanoid::nanoid!(),
            input_json: String::new(),
        });
        let mut thinking_state = None;

        let event1 = StreamingEvent::ContentBlockDelta {
            index: 0,
            delta: ContentDelta::InputJsonDelta {
                partial_json: "{\"location\":".to_string(),
            },
        };
        let result1 = handle_event(&event1, &mut tool_call_state, &mut thinking_state);
        assert!(result1.is_some());

        let event2 = StreamingEvent::ContentBlockDelta {
            index: 0,
            delta: ContentDelta::InputJsonDelta {
                partial_json: "\"Paris\",".to_string(),
            },
        };
        let result2 = handle_event(&event2, &mut tool_call_state, &mut thinking_state);
        assert!(result2.is_some());

        let event3 = StreamingEvent::ContentBlockDelta {
            index: 0,
            delta: ContentDelta::InputJsonDelta {
                partial_json: "\"temp\":\"20C\"}".to_string(),
            },
        };
        let result3 = handle_event(&event3, &mut tool_call_state, &mut thinking_state);
        assert!(result3.is_some());

        assert!(tool_call_state.is_some());
        let state = tool_call_state.as_ref().unwrap();
        assert_eq!(
            state.input_json,
            "{\"location\":\"Paris\",\"temp\":\"20C\"}"
        );

        let stop_event = StreamingEvent::ContentBlockStop { index: 0 };
        let final_result = handle_event(&stop_event, &mut tool_call_state, &mut thinking_state);
        assert!(final_result.is_some());

        match final_result.unwrap().unwrap() {
            RawStreamingChoice::ToolCall(RawStreamingToolCall {
                id,
                name,
                arguments,
                ..
            }) => {
                assert_eq!(id, "tool_123");
                assert_eq!(name, "test_tool");
                assert_eq!(
                    arguments.get("location").unwrap().as_str().unwrap(),
                    "Paris"
                );
                assert_eq!(arguments.get("temp").unwrap().as_str().unwrap(), "20C");
            }
            other => panic!("Expected ToolCall, got {:?}", other),
        }

        assert!(tool_call_state.is_none());
    }

    #[test]
    fn test_tool_call_without_input_deltas_defaults_to_empty_object() {
        let mut tool_call_state = Some(ToolCallState {
            name: "no_args_tool".to_string(),
            id: "tool_456".to_string(),
            internal_call_id: nanoid::nanoid!(),
            input_json: String::new(),
        });
        let mut thinking_state = None;

        let result = handle_event(
            &StreamingEvent::ContentBlockStop { index: 1 },
            &mut tool_call_state,
            &mut thinking_state,
        );

        match result.unwrap().unwrap() {
            RawStreamingChoice::ToolCall(RawStreamingToolCall { id, arguments, .. }) => {
                assert_eq!(id, "tool_456");
                assert_eq!(arguments, serde_json::json!({}));
            }
            other => panic!("Expected ToolCall, got {:?}", other),
        }
        assert!(tool_call_state.is_none());
    }

    #[test]
    fn test_tool_call_with_invalid_json_yields_error() {
        let mut tool_call_state = Some(ToolCallState {
            name: "broken_tool".to_string(),
            id: "tool_789".to_string(),
            internal_call_id: nanoid::nanoid!(),
            input_json: "{not json".to_string(),
        });
        let mut thinking_state = None;

        let result = handle_event(
            &StreamingEvent::ContentBlockStop { index: 0 },
            &mut tool_call_state,
            &mut thinking_state,
        );

        assert!(matches!(result, Some(Err(_))));
        assert!(tool_call_state.is_none());
    }

    #[test]
    fn test_partial_usage_token_usage_sums_all_counters() {
        let usage = PartialUsage {
            output_tokens: 5,
            input_tokens: Some(10),
            cache_creation_input_tokens: Some(2),
            cache_read_input_tokens: Some(3),
        }
        .token_usage()
        .unwrap();

        assert_eq!(usage.input_tokens, 10);
        assert_eq!(usage.output_tokens, 5);
        assert_eq!(usage.cached_input_tokens, 3);
        assert_eq!(usage.cache_creation_input_tokens, 2);
        assert_eq!(usage.total_tokens, 20);
    }

    #[test]
    fn test_streaming_response_token_usage_treats_missing_counters_as_zero() {
        let response = StreamingCompletionResponse {
            usage: PartialUsage {
                output_tokens: 7,
                ..Default::default()
            },
        };

        let usage = response.token_usage().unwrap();

        assert_eq!(usage.input_tokens, 0);
        assert_eq!(usage.output_tokens, 7);
        assert_eq!(usage.cached_input_tokens, 0);
        assert_eq!(usage.cache_creation_input_tokens, 0);
        assert_eq!(usage.total_tokens, 7);
    }
}