1use async_stream::stream;
2use futures::StreamExt;
3use serde::{Deserialize, Serialize};
4use serde_json::{Value, json};
5use tracing::{Level, enabled, info_span};
6use tracing_futures::Instrument;
7
8use super::completion::{
9 AnthropicCompatibleProvider, CacheControl, Content, GenericCompletionModel, Message,
10 SystemContent, ToolChoice, ToolDefinition, Usage, apply_cache_control,
11 split_system_messages_from_history,
12};
13use crate::completion::{CompletionError, CompletionRequest, GetTokenUsage};
14use crate::http_client::sse::{Event, GenericEventSource};
15use crate::http_client::{self, HttpClientExt};
16use crate::json_utils::merge_inplace;
17use crate::message::ReasoningContent;
18use crate::streaming::{
19 self, RawStreamingChoice, RawStreamingToolCall, StreamingResult, ToolCallDeltaContent,
20};
21use crate::telemetry::SpanCombinator;
22use crate::wasm_compat::{WasmCompatSend, WasmCompatSync};
23
24#[derive(Debug, Deserialize)]
25#[serde(tag = "type", rename_all = "snake_case")]
26pub enum StreamingEvent {
27 MessageStart {
28 message: MessageStart,
29 },
30 ContentBlockStart {
31 index: usize,
32 content_block: Content,
33 },
34 ContentBlockDelta {
35 index: usize,
36 delta: ContentDelta,
37 },
38 ContentBlockStop {
39 index: usize,
40 },
41 MessageDelta {
42 delta: MessageDelta,
43 usage: PartialUsage,
44 },
45 MessageStop,
46 Ping,
47 #[serde(other)]
48 Unknown,
49}
50
51#[derive(Debug, Deserialize)]
52pub struct MessageStart {
53 pub id: String,
54 pub role: String,
55 pub content: Vec<Content>,
56 pub model: String,
57 pub stop_reason: Option<String>,
58 pub stop_sequence: Option<String>,
59 pub usage: Usage,
60}
61
62#[derive(Debug, Deserialize)]
63#[serde(tag = "type", rename_all = "snake_case")]
64pub enum ContentDelta {
65 TextDelta { text: String },
66 InputJsonDelta { partial_json: String },
67 ThinkingDelta { thinking: String },
68 SignatureDelta { signature: String },
69}
70
71#[derive(Debug, Deserialize)]
72pub struct MessageDelta {
73 pub stop_reason: Option<String>,
74 pub stop_sequence: Option<String>,
75}
76
77#[derive(Debug, Deserialize, Clone, Serialize, Default)]
78pub struct PartialUsage {
79 pub output_tokens: usize,
80 #[serde(default)]
81 pub input_tokens: Option<usize>,
82 #[serde(default)]
83 pub cache_creation_input_tokens: Option<u64>,
84 #[serde(default)]
85 pub cache_read_input_tokens: Option<u64>,
86}
87
88impl GetTokenUsage for PartialUsage {
89 fn token_usage(&self) -> Option<crate::completion::Usage> {
90 let mut usage = crate::completion::Usage::new();
91
92 usage.input_tokens = self.input_tokens.unwrap_or_default() as u64;
93 usage.output_tokens = self.output_tokens as u64;
94 usage.cached_input_tokens = self.cache_read_input_tokens.unwrap_or(0);
95 usage.cache_creation_input_tokens = self.cache_creation_input_tokens.unwrap_or(0);
96 usage.total_tokens = usage.input_tokens
97 + usage.cached_input_tokens
98 + usage.cache_creation_input_tokens
99 + usage.output_tokens;
100 Some(usage)
101 }
102}
103
/// Accumulator for the tool-use content block currently being streamed.
#[derive(Default)]
struct ToolCallState {
    /// Tool name taken from the block-start event.
    name: String,
    /// Provider-assigned tool-use id.
    id: String,
    /// Locally generated id attached to every delta and to the final tool call.
    internal_call_id: String,
    /// Concatenation of all `input_json_delta` fragments seen so far.
    input_json: String,
}
111
/// Accumulator for the thinking content block currently being streamed.
#[derive(Default)]
struct ThinkingState {
    /// Concatenation of all `thinking_delta` chunks.
    thinking: String,
    /// Concatenation of all `signature_delta` chunks.
    signature: String,
}
117
118#[derive(Clone, Debug, Deserialize, Serialize)]
119pub struct StreamingCompletionResponse {
120 pub usage: PartialUsage,
121}
122
123impl GetTokenUsage for StreamingCompletionResponse {
124 fn token_usage(&self) -> Option<crate::completion::Usage> {
125 let mut usage = crate::completion::Usage::new();
126 usage.input_tokens = self.usage.input_tokens.unwrap_or(0) as u64;
127 usage.output_tokens = self.usage.output_tokens as u64;
128 usage.cached_input_tokens = self.usage.cache_read_input_tokens.unwrap_or(0);
129 usage.cache_creation_input_tokens = self.usage.cache_creation_input_tokens.unwrap_or(0);
130 usage.total_tokens = usage.input_tokens
131 + usage.cached_input_tokens
132 + usage.cache_creation_input_tokens
133 + usage.output_tokens;
134
135 Some(usage)
136 }
137}
138
139impl<Ext, T> GenericCompletionModel<Ext, T>
140where
141 T: HttpClientExt + Clone + Default + 'static,
142 Ext: AnthropicCompatibleProvider + Clone + WasmCompatSend + WasmCompatSync + 'static,
143{
144 pub(crate) async fn stream(
145 &self,
146 mut completion_request: CompletionRequest,
147 ) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError>
148 {
149 let request_model = completion_request
150 .model
151 .clone()
152 .unwrap_or_else(|| self.model.clone());
153 let span = if tracing::Span::current().is_disabled() {
154 info_span!(
155 target: "rig::completions",
156 "chat_streaming",
157 gen_ai.operation.name = "chat_streaming",
158 gen_ai.provider.name = Ext::PROVIDER_NAME,
159 gen_ai.request.model = &request_model,
160 gen_ai.system_instructions = &completion_request.preamble,
161 gen_ai.response.id = tracing::field::Empty,
162 gen_ai.response.model = &request_model,
163 gen_ai.usage.output_tokens = tracing::field::Empty,
164 gen_ai.usage.input_tokens = tracing::field::Empty,
165 gen_ai.usage.cache_read.input_tokens = tracing::field::Empty,
166 gen_ai.usage.cache_creation.input_tokens = tracing::field::Empty,
167 gen_ai.input.messages = tracing::field::Empty,
168 gen_ai.output.messages = tracing::field::Empty,
169 )
170 } else {
171 tracing::Span::current()
172 };
173 let max_tokens = if let Some(tokens) = completion_request.max_tokens {
174 tokens
175 } else if let Some(tokens) = self.default_max_tokens {
176 tokens
177 } else {
178 return Err(CompletionError::RequestError(
179 "`max_tokens` must be set for Anthropic".into(),
180 ));
181 };
182
183 let mut full_history = vec![];
184 if let Some(docs) = completion_request.normalized_documents() {
185 full_history.push(docs);
186 }
187 full_history.extend(completion_request.chat_history);
188 let (history_system, full_history) = split_system_messages_from_history(full_history);
189
190 let mut messages = full_history
191 .into_iter()
192 .map(Message::try_from)
193 .collect::<Result<Vec<Message>, _>>()?;
194
195 let mut system: Vec<SystemContent> =
197 if let Some(preamble) = completion_request.preamble.as_ref() {
198 if preamble.is_empty() {
199 vec![]
200 } else {
201 vec![SystemContent::Text {
202 text: preamble.clone(),
203 cache_control: None,
204 }]
205 }
206 } else {
207 vec![]
208 };
209 system.extend(history_system);
210
211 if self.prompt_caching {
213 apply_cache_control(&mut system, &mut messages);
214 }
215
216 let mut body = json!({
217 "model": request_model,
218 "messages": messages,
219 "max_tokens": max_tokens,
220 "stream": true,
221 });
222
223 if self.automatic_caching {
226 let cc = CacheControl::Ephemeral {
227 ttl: self.automatic_caching_ttl.clone(),
228 };
229 merge_inplace(
230 &mut body,
231 json!({ "cache_control": serde_json::to_value(&cc)? }),
232 );
233 }
234
235 if !system.is_empty() {
237 merge_inplace(&mut body, json!({ "system": system }));
238 }
239
240 if let Some(temperature) = completion_request.temperature {
241 merge_inplace(&mut body, json!({ "temperature": temperature }));
242 }
243
244 let mut additional_params_payload = completion_request
245 .additional_params
246 .take()
247 .unwrap_or(Value::Null);
248 let mut additional_tools =
249 extract_tools_from_additional_params(&mut additional_params_payload)?;
250
251 let mut tools = completion_request
252 .tools
253 .into_iter()
254 .map(|tool| ToolDefinition {
255 name: tool.name,
256 description: Some(tool.description),
257 input_schema: tool.parameters,
258 })
259 .map(serde_json::to_value)
260 .collect::<Result<Vec<_>, _>>()?;
261 tools.append(&mut additional_tools);
262
263 if !tools.is_empty() {
264 merge_inplace(
265 &mut body,
266 json!({
267 "tools": tools,
268 "tool_choice": ToolChoice::Auto,
269 }),
270 );
271 }
272
273 if !additional_params_payload.is_null() {
274 merge_inplace(&mut body, additional_params_payload)
275 }
276
277 if enabled!(Level::TRACE) {
278 tracing::trace!(
279 target: "rig::completions",
280 "Anthropic completion request: {}",
281 serde_json::to_string_pretty(&body)?
282 );
283 }
284
285 let body: Vec<u8> = serde_json::to_vec(&body)?;
286
287 let req = self
288 .client
289 .post("/v1/messages")?
290 .body(body)
291 .map_err(http_client::Error::Protocol)?;
292
293 let stream = GenericEventSource::new(self.client.clone(), req);
294
295 let stream: StreamingResult<StreamingCompletionResponse> = Box::pin(stream! {
297 let mut current_tool_call: Option<ToolCallState> = None;
298 let mut current_thinking: Option<ThinkingState> = None;
299 let mut sse_stream = Box::pin(stream);
300 let mut input_tokens = 0;
301 let mut final_usage = None;
302
303 let mut text_content = String::new();
304
305 while let Some(sse_result) = sse_stream.next().await {
306 match sse_result {
307 Ok(Event::Open) => {}
308 Ok(Event::Message(sse)) => {
309 match serde_json::from_str::<StreamingEvent>(&sse.data) {
311 Ok(event) => {
312 match &event {
313 StreamingEvent::MessageStart { message } => {
314 input_tokens = message.usage.input_tokens;
315
316 let span = tracing::Span::current();
317 span.record("gen_ai.response.id", &message.id);
318 span.record("gen_ai.response.model", &message.model);
319 },
320 StreamingEvent::MessageDelta { delta, usage } => {
321 if delta.stop_reason.is_some() {
322 let usage = PartialUsage {
326 output_tokens: usage.output_tokens,
327 input_tokens: usize::try_from(input_tokens).ok(),
328 cache_creation_input_tokens: usage.cache_creation_input_tokens,
329 cache_read_input_tokens: usage.cache_read_input_tokens
330 };
331
332 let span = tracing::Span::current();
333 span.record_token_usage(&usage);
334 final_usage = Some(usage);
335 break;
336 }
337 }
338 _ => {}
339 }
340
341 if let Some(result) = handle_event(&event, &mut current_tool_call, &mut current_thinking) {
342 if let Ok(RawStreamingChoice::Message(ref text)) = result {
343 text_content += text;
344 }
345 yield result;
346 }
347 },
348 Err(e) => {
349 if !sse.data.trim().is_empty() {
350 yield Err(CompletionError::ResponseError(
351 format!("Failed to parse JSON: {} (Data: {})", e, sse.data)
352 ));
353 }
354 }
355 }
356 },
357 Err(e) => {
358 yield Err(CompletionError::ProviderError(format!("SSE Error: {e}")));
359 break;
360 }
361 }
362 }
363
364 sse_stream.close();
366
367 yield Ok(RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
368 usage: final_usage.unwrap_or_default()
369 }))
370 }.instrument(span));
371
372 Ok(streaming::StreamingCompletionResponse::stream(stream))
373 }
374}
375
376fn extract_tools_from_additional_params(
377 additional_params: &mut Value,
378) -> Result<Vec<Value>, CompletionError> {
379 if let Some(map) = additional_params.as_object_mut()
380 && let Some(raw_tools) = map.remove("tools")
381 {
382 return serde_json::from_value::<Vec<Value>>(raw_tools).map_err(|err| {
383 CompletionError::RequestError(
384 format!("Invalid Anthropic `additional_params.tools` payload: {err}").into(),
385 )
386 });
387 }
388
389 Ok(Vec::new())
390}
391
392fn handle_event(
393 event: &StreamingEvent,
394 current_tool_call: &mut Option<ToolCallState>,
395 current_thinking: &mut Option<ThinkingState>,
396) -> Option<Result<RawStreamingChoice<StreamingCompletionResponse>, CompletionError>> {
397 match event {
398 StreamingEvent::ContentBlockDelta { delta, .. } => match delta {
399 ContentDelta::TextDelta { text } => {
400 if current_tool_call.is_none() {
401 return Some(Ok(RawStreamingChoice::Message(text.clone())));
402 }
403 None
404 }
405 ContentDelta::InputJsonDelta { partial_json } => {
406 if let Some(tool_call) = current_tool_call {
407 tool_call.input_json.push_str(partial_json);
408 return Some(Ok(RawStreamingChoice::ToolCallDelta {
410 id: tool_call.id.clone(),
411 internal_call_id: tool_call.internal_call_id.clone(),
412 content: ToolCallDeltaContent::Delta(partial_json.clone()),
413 }));
414 }
415 None
416 }
417 ContentDelta::ThinkingDelta { thinking } => {
418 current_thinking
419 .get_or_insert_with(ThinkingState::default)
420 .thinking
421 .push_str(thinking);
422
423 Some(Ok(RawStreamingChoice::ReasoningDelta {
424 id: None,
425 reasoning: thinking.clone(),
426 }))
427 }
428 ContentDelta::SignatureDelta { signature } => {
429 current_thinking
430 .get_or_insert_with(ThinkingState::default)
431 .signature
432 .push_str(signature);
433
434 None
436 }
437 },
438 StreamingEvent::ContentBlockStart { content_block, .. } => match content_block {
439 Content::ToolUse { id, name, .. } => {
440 let internal_call_id = nanoid::nanoid!();
441 *current_tool_call = Some(ToolCallState {
442 name: name.clone(),
443 id: id.clone(),
444 internal_call_id: internal_call_id.clone(),
445 input_json: String::new(),
446 });
447 Some(Ok(RawStreamingChoice::ToolCallDelta {
448 id: id.clone(),
449 internal_call_id,
450 content: ToolCallDeltaContent::Name(name.clone()),
451 }))
452 }
453 Content::Thinking { .. } => {
454 *current_thinking = Some(ThinkingState::default());
455 None
456 }
457 Content::RedactedThinking { data } => Some(Ok(RawStreamingChoice::Reasoning {
458 id: None,
459 content: ReasoningContent::Redacted { data: data.clone() },
460 })),
461 _ => None,
463 },
464 StreamingEvent::ContentBlockStop { .. } => {
465 if let Some(thinking_state) = Option::take(current_thinking)
466 && !thinking_state.thinking.is_empty()
467 {
468 let signature = if thinking_state.signature.is_empty() {
469 None
470 } else {
471 Some(thinking_state.signature)
472 };
473
474 return Some(Ok(RawStreamingChoice::Reasoning {
475 id: None,
476 content: ReasoningContent::Text {
477 text: thinking_state.thinking,
478 signature,
479 },
480 }));
481 }
482
483 if let Some(tool_call) = Option::take(current_tool_call) {
484 let json_str = if tool_call.input_json.is_empty() {
485 "{}"
486 } else {
487 &tool_call.input_json
488 };
489 match serde_json::from_str(json_str) {
490 Ok(json_value) => {
491 let raw_tool_call =
492 RawStreamingToolCall::new(tool_call.id, tool_call.name, json_value)
493 .with_internal_call_id(tool_call.internal_call_id);
494 Some(Ok(RawStreamingChoice::ToolCall(raw_tool_call)))
495 }
496 Err(e) => Some(Err(CompletionError::from(e))),
497 }
498 } else {
499 None
500 }
501 }
502 StreamingEvent::MessageStart { .. }
504 | StreamingEvent::MessageDelta { .. }
505 | StreamingEvent::MessageStop
506 | StreamingEvent::Ping
507 | StreamingEvent::Unknown => None,
508 }
509}
510
#[cfg(test)]
mod tests {
    //! Unit tests for SSE event decoding, per-event translation in `handle_event`,
    //! and extraction of caller-supplied tools from `additional_params`.

    use super::*;

    /// Builds an in-progress tool call with no accumulated arguments.
    fn active_tool_call(name: &str, id: &str) -> ToolCallState {
        ToolCallState {
            name: name.to_string(),
            id: id.to_string(),
            internal_call_id: nanoid::nanoid!(),
            input_json: String::new(),
        }
    }

    #[test]
    fn test_thinking_delta_deserialization() {
        let json = r#"{"type": "thinking_delta", "thinking": "Let me think about this..."}"#;
        let delta: ContentDelta = serde_json::from_str(json).unwrap();

        match delta {
            ContentDelta::ThinkingDelta { thinking } => {
                assert_eq!(thinking, "Let me think about this...");
            }
            _ => panic!("Expected ThinkingDelta variant"),
        }
    }

    #[test]
    fn test_signature_delta_deserialization() {
        let json = r#"{"type": "signature_delta", "signature": "abc123def456"}"#;
        let delta: ContentDelta = serde_json::from_str(json).unwrap();

        match delta {
            ContentDelta::SignatureDelta { signature } => {
                assert_eq!(signature, "abc123def456");
            }
            _ => panic!("Expected SignatureDelta variant"),
        }
    }

    #[test]
    fn test_thinking_delta_streaming_event_deserialization() {
        let json = r#"{
            "type": "content_block_delta",
            "index": 0,
            "delta": {
                "type": "thinking_delta",
                "thinking": "First, I need to understand the problem."
            }
        }"#;

        let event: StreamingEvent = serde_json::from_str(json).unwrap();

        match event {
            StreamingEvent::ContentBlockDelta { index, delta } => {
                assert_eq!(index, 0);
                match delta {
                    ContentDelta::ThinkingDelta { thinking } => {
                        assert_eq!(thinking, "First, I need to understand the problem.");
                    }
                    _ => panic!("Expected ThinkingDelta"),
                }
            }
            _ => panic!("Expected ContentBlockDelta event"),
        }
    }

    #[test]
    fn test_signature_delta_streaming_event_deserialization() {
        let json = r#"{
            "type": "content_block_delta",
            "index": 0,
            "delta": {
                "type": "signature_delta",
                "signature": "ErUBCkYICBgCIkCaGbqC85F4"
            }
        }"#;

        let event: StreamingEvent = serde_json::from_str(json).unwrap();

        match event {
            StreamingEvent::ContentBlockDelta { index, delta } => {
                assert_eq!(index, 0);
                match delta {
                    ContentDelta::SignatureDelta { signature } => {
                        assert_eq!(signature, "ErUBCkYICBgCIkCaGbqC85F4");
                    }
                    _ => panic!("Expected SignatureDelta"),
                }
            }
            _ => panic!("Expected ContentBlockDelta event"),
        }
    }

    #[test]
    fn test_unrecognised_event_type_deserializes_to_unknown() {
        // Error events are not modelled as a variant; the stream loop inspects the
        // raw payload of `Unknown` events to surface them.
        let json = r#"{"type": "error", "error": {"type": "overloaded_error", "message": "Overloaded"}}"#;
        let event: StreamingEvent = serde_json::from_str(json).unwrap();
        assert!(matches!(event, StreamingEvent::Unknown));
    }

    #[test]
    fn test_handle_thinking_delta_event() {
        let event = StreamingEvent::ContentBlockDelta {
            index: 0,
            delta: ContentDelta::ThinkingDelta {
                thinking: "Analyzing the request...".to_string(),
            },
        };

        let mut tool_call_state = None;
        let mut thinking_state = None;
        let result = handle_event(&event, &mut tool_call_state, &mut thinking_state);

        assert!(result.is_some());
        let choice = result.unwrap().unwrap();

        match choice {
            RawStreamingChoice::ReasoningDelta { id, reasoning, .. } => {
                assert_eq!(id, None);
                assert_eq!(reasoning, "Analyzing the request...");
            }
            _ => panic!("Expected ReasoningDelta choice"),
        }

        assert!(thinking_state.is_some());
        assert_eq!(thinking_state.unwrap().thinking, "Analyzing the request...");
    }

    #[test]
    fn test_handle_signature_delta_event() {
        let event = StreamingEvent::ContentBlockDelta {
            index: 0,
            delta: ContentDelta::SignatureDelta {
                signature: "test_signature".to_string(),
            },
        };

        let mut tool_call_state = None;
        let mut thinking_state = None;
        let result = handle_event(&event, &mut tool_call_state, &mut thinking_state);

        assert!(result.is_none());

        assert!(thinking_state.is_some());
        assert_eq!(thinking_state.unwrap().signature, "test_signature");
    }

    #[test]
    fn test_handle_redacted_thinking_content_block_start_event() {
        let event = StreamingEvent::ContentBlockStart {
            index: 0,
            content_block: Content::RedactedThinking {
                data: "redacted_blob".to_string(),
            },
        };
        let mut tool_call_state = None;
        let mut thinking_state = None;
        let result = handle_event(&event, &mut tool_call_state, &mut thinking_state);

        assert!(result.is_some());
        match result.unwrap().unwrap() {
            RawStreamingChoice::Reasoning {
                content: ReasoningContent::Redacted { data },
                ..
            } => {
                assert_eq!(data, "redacted_blob");
            }
            _ => panic!("Expected Redacted reasoning chunk"),
        }
    }

    #[test]
    fn test_content_block_stop_emits_reasoning_with_signature() {
        let mut tool_call_state = None;
        let mut thinking_state = Some(ThinkingState {
            thinking: "step one".to_string(),
            signature: "sig".to_string(),
        });

        let stop = StreamingEvent::ContentBlockStop { index: 0 };
        let choice = handle_event(&stop, &mut tool_call_state, &mut thinking_state)
            .expect("a non-empty thinking block should produce a chunk")
            .expect("the chunk should not be an error");

        match choice {
            RawStreamingChoice::Reasoning {
                content: ReasoningContent::Text { text, signature, .. },
                ..
            } => {
                assert_eq!(text, "step one");
                assert_eq!(signature.as_deref(), Some("sig"));
            }
            other => panic!("Expected Reasoning text chunk, got {:?}", other),
        }
        assert!(thinking_state.is_none());
    }

    #[test]
    fn test_content_block_stop_discards_empty_thinking() {
        let mut tool_call_state = None;
        let mut thinking_state = Some(ThinkingState::default());

        let stop = StreamingEvent::ContentBlockStop { index: 0 };
        assert!(handle_event(&stop, &mut tool_call_state, &mut thinking_state).is_none());
        assert!(thinking_state.is_none());
    }

    #[test]
    fn test_handle_text_delta_event() {
        let event = StreamingEvent::ContentBlockDelta {
            index: 0,
            delta: ContentDelta::TextDelta {
                text: "Hello, world!".to_string(),
            },
        };

        let mut tool_call_state = None;
        let mut thinking_state = None;
        let result = handle_event(&event, &mut tool_call_state, &mut thinking_state);

        assert!(result.is_some());
        let choice = result.unwrap().unwrap();

        match choice {
            RawStreamingChoice::Message(text) => {
                assert_eq!(text, "Hello, world!");
            }
            _ => panic!("Expected Message choice"),
        }
    }

    #[test]
    fn test_text_delta_is_suppressed_while_tool_call_is_active() {
        let event = StreamingEvent::ContentBlockDelta {
            index: 1,
            delta: ContentDelta::TextDelta {
                text: "ignored".to_string(),
            },
        };
        let mut tool_call_state = Some(active_tool_call("test_tool", "tool_123"));
        let mut thinking_state = None;

        assert!(handle_event(&event, &mut tool_call_state, &mut thinking_state).is_none());
        assert!(tool_call_state.is_some());
    }

    #[test]
    fn test_thinking_delta_does_not_interfere_with_tool_calls() {
        let event = StreamingEvent::ContentBlockDelta {
            index: 0,
            delta: ContentDelta::ThinkingDelta {
                thinking: "Thinking while tool is active...".to_string(),
            },
        };

        let mut tool_call_state = Some(active_tool_call("test_tool", "tool_123"));
        let mut thinking_state = None;

        let result = handle_event(&event, &mut tool_call_state, &mut thinking_state);

        assert!(result.is_some());
        let choice = result.unwrap().unwrap();

        match choice {
            RawStreamingChoice::ReasoningDelta { reasoning, .. } => {
                assert_eq!(reasoning, "Thinking while tool is active...");
            }
            _ => panic!("Expected ReasoningDelta choice"),
        }

        assert!(tool_call_state.is_some());
    }

    #[test]
    fn test_handle_input_json_delta_event() {
        let event = StreamingEvent::ContentBlockDelta {
            index: 0,
            delta: ContentDelta::InputJsonDelta {
                partial_json: "{\"arg\":\"value".to_string(),
            },
        };

        let mut tool_call_state = Some(active_tool_call("test_tool", "tool_123"));
        let mut thinking_state = None;

        let result = handle_event(&event, &mut tool_call_state, &mut thinking_state);

        assert!(result.is_some());
        let choice = result.unwrap().unwrap();

        match choice {
            RawStreamingChoice::ToolCallDelta {
                id,
                internal_call_id: _,
                content,
            } => {
                assert_eq!(id, "tool_123");
                match content {
                    ToolCallDeltaContent::Delta(delta) => assert_eq!(delta, "{\"arg\":\"value"),
                    _ => panic!("Expected Delta content"),
                }
            }
            _ => panic!("Expected ToolCallDelta choice, got {:?}", choice),
        }

        assert!(tool_call_state.is_some());
        let state = tool_call_state.unwrap();
        assert_eq!(state.input_json, "{\"arg\":\"value");
    }

    #[test]
    fn test_tool_call_accumulation_with_multiple_deltas() {
        let mut tool_call_state = Some(active_tool_call("test_tool", "tool_123"));
        let mut thinking_state = None;

        let event1 = StreamingEvent::ContentBlockDelta {
            index: 0,
            delta: ContentDelta::InputJsonDelta {
                partial_json: "{\"location\":".to_string(),
            },
        };
        let result1 = handle_event(&event1, &mut tool_call_state, &mut thinking_state);
        assert!(result1.is_some());

        let event2 = StreamingEvent::ContentBlockDelta {
            index: 0,
            delta: ContentDelta::InputJsonDelta {
                partial_json: "\"Paris\",".to_string(),
            },
        };
        let result2 = handle_event(&event2, &mut tool_call_state, &mut thinking_state);
        assert!(result2.is_some());

        let event3 = StreamingEvent::ContentBlockDelta {
            index: 0,
            delta: ContentDelta::InputJsonDelta {
                partial_json: "\"temp\":\"20C\"}".to_string(),
            },
        };
        let result3 = handle_event(&event3, &mut tool_call_state, &mut thinking_state);
        assert!(result3.is_some());

        assert!(tool_call_state.is_some());
        let state = tool_call_state.as_ref().unwrap();
        assert_eq!(
            state.input_json,
            "{\"location\":\"Paris\",\"temp\":\"20C\"}"
        );

        let stop_event = StreamingEvent::ContentBlockStop { index: 0 };
        let final_result = handle_event(&stop_event, &mut tool_call_state, &mut thinking_state);
        assert!(final_result.is_some());

        match final_result.unwrap().unwrap() {
            RawStreamingChoice::ToolCall(RawStreamingToolCall {
                id,
                name,
                arguments,
                ..
            }) => {
                assert_eq!(id, "tool_123");
                assert_eq!(name, "test_tool");
                assert_eq!(
                    arguments.get("location").unwrap().as_str().unwrap(),
                    "Paris"
                );
                assert_eq!(arguments.get("temp").unwrap().as_str().unwrap(), "20C");
            }
            other => panic!("Expected ToolCall, got {:?}", other),
        }

        assert!(tool_call_state.is_none());
    }

    #[test]
    fn test_content_block_stop_with_empty_tool_input_yields_empty_object() {
        let mut tool_call_state = Some(active_tool_call("no_args", "tool_456"));
        let mut thinking_state = None;

        let stop = StreamingEvent::ContentBlockStop { index: 1 };
        match handle_event(&stop, &mut tool_call_state, &mut thinking_state) {
            Some(Ok(RawStreamingChoice::ToolCall(RawStreamingToolCall {
                id,
                name,
                arguments,
                ..
            }))) => {
                assert_eq!(id, "tool_456");
                assert_eq!(name, "no_args");
                assert_eq!(arguments, json!({}));
            }
            other => panic!("Expected ToolCall, got {:?}", other),
        }
        assert!(tool_call_state.is_none());
    }

    #[test]
    fn test_extract_tools_from_additional_params_moves_tools_out() {
        let mut params = json!({
            "tools": [{ "type": "web_search_20250305", "name": "web_search" }],
            "top_k": 5
        });

        let tools = extract_tools_from_additional_params(&mut params).unwrap();

        assert_eq!(tools.len(), 1);
        assert_eq!(tools[0]["name"], "web_search");
        assert_eq!(params, json!({ "top_k": 5 }));
    }

    #[test]
    fn test_extract_tools_from_additional_params_without_tools() {
        let mut params = Value::Null;
        assert!(
            extract_tools_from_additional_params(&mut params)
                .unwrap()
                .is_empty()
        );
        assert!(params.is_null());
    }

    #[test]
    fn test_extract_tools_from_additional_params_rejects_non_array() {
        let mut params = json!({ "tools": "not-a-list" });
        assert!(extract_tools_from_additional_params(&mut params).is_err());
    }
}