1use async_stream::stream;
2use futures::StreamExt;
3use serde::{Deserialize, Serialize};
4use serde_json::{Value, json};
5use tracing::{Level, enabled, info_span};
6use tracing_futures::Instrument;
7
8use super::completion::{
9 CompletionModel, Content, Message, SystemContent, ToolChoice, ToolDefinition, Usage,
10 apply_cache_control, split_system_messages_from_history,
11};
12use crate::completion::{CompletionError, CompletionRequest, GetTokenUsage};
13use crate::http_client::sse::{Event, GenericEventSource};
14use crate::http_client::{self, HttpClientExt};
15use crate::json_utils::merge_inplace;
16use crate::message::ReasoningContent;
17use crate::streaming::{
18 self, RawStreamingChoice, RawStreamingToolCall, StreamingResult, ToolCallDeltaContent,
19};
20use crate::telemetry::SpanCombinator;
21
22#[derive(Debug, Deserialize)]
23#[serde(tag = "type", rename_all = "snake_case")]
24pub enum StreamingEvent {
25 MessageStart {
26 message: MessageStart,
27 },
28 ContentBlockStart {
29 index: usize,
30 content_block: Content,
31 },
32 ContentBlockDelta {
33 index: usize,
34 delta: ContentDelta,
35 },
36 ContentBlockStop {
37 index: usize,
38 },
39 MessageDelta {
40 delta: MessageDelta,
41 usage: PartialUsage,
42 },
43 MessageStop,
44 Ping,
45 #[serde(other)]
46 Unknown,
47}
48
49#[derive(Debug, Deserialize)]
50pub struct MessageStart {
51 pub id: String,
52 pub role: String,
53 pub content: Vec<Content>,
54 pub model: String,
55 pub stop_reason: Option<String>,
56 pub stop_sequence: Option<String>,
57 pub usage: Usage,
58}
59
60#[derive(Debug, Deserialize)]
61#[serde(tag = "type", rename_all = "snake_case")]
62pub enum ContentDelta {
63 TextDelta { text: String },
64 InputJsonDelta { partial_json: String },
65 ThinkingDelta { thinking: String },
66 SignatureDelta { signature: String },
67}
68
69#[derive(Debug, Deserialize)]
70pub struct MessageDelta {
71 pub stop_reason: Option<String>,
72 pub stop_sequence: Option<String>,
73}
74
75#[derive(Debug, Deserialize, Clone, Serialize, Default)]
76pub struct PartialUsage {
77 pub output_tokens: usize,
78 #[serde(default)]
79 pub input_tokens: Option<usize>,
80}
81
82impl GetTokenUsage for PartialUsage {
83 fn token_usage(&self) -> Option<crate::completion::Usage> {
84 let mut usage = crate::completion::Usage::new();
85
86 usage.input_tokens = self.input_tokens.unwrap_or_default() as u64;
87 usage.output_tokens = self.output_tokens as u64;
88 usage.total_tokens = usage.input_tokens + usage.output_tokens;
89 Some(usage)
90 }
91}
92
/// Accumulator for a tool call whose arguments are still being streamed.
#[derive(Default)]
struct ToolCallState {
    /// Tool name taken from the opening tool-use content block.
    name: String,
    /// Tool-use id taken from the opening tool-use content block.
    id: String,
    /// Locally generated id attached to every delta and to the final tool call.
    internal_call_id: String,
    /// Concatenation of all `partial_json` fragments received so far.
    input_json: String,
}
100
/// Accumulator for a thinking block that is still being streamed.
#[derive(Default)]
struct ThinkingState {
    /// Concatenation of all thinking fragments received so far.
    thinking: String,
    /// Concatenation of all signature fragments received so far.
    signature: String,
}
106
107#[derive(Clone, Debug, Deserialize, Serialize)]
108pub struct StreamingCompletionResponse {
109 pub usage: PartialUsage,
110}
111
112impl GetTokenUsage for StreamingCompletionResponse {
113 fn token_usage(&self) -> Option<crate::completion::Usage> {
114 let mut usage = crate::completion::Usage::new();
115 usage.input_tokens = self.usage.input_tokens.unwrap_or(0) as u64;
116 usage.output_tokens = self.usage.output_tokens as u64;
117 usage.total_tokens =
118 self.usage.input_tokens.unwrap_or(0) as u64 + self.usage.output_tokens as u64;
119
120 Some(usage)
121 }
122}
123
124impl<T> CompletionModel<T>
125where
126 T: HttpClientExt + Clone + Default + 'static,
127{
128 pub(crate) async fn stream(
129 &self,
130 mut completion_request: CompletionRequest,
131 ) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError>
132 {
133 let request_model = completion_request
134 .model
135 .clone()
136 .unwrap_or_else(|| self.model.clone());
137 let span = if tracing::Span::current().is_disabled() {
138 info_span!(
139 target: "rig::completions",
140 "chat_streaming",
141 gen_ai.operation.name = "chat_streaming",
142 gen_ai.provider.name = "anthropic",
143 gen_ai.request.model = &request_model,
144 gen_ai.system_instructions = &completion_request.preamble,
145 gen_ai.response.id = tracing::field::Empty,
146 gen_ai.response.model = &request_model,
147 gen_ai.usage.output_tokens = tracing::field::Empty,
148 gen_ai.usage.input_tokens = tracing::field::Empty,
149 gen_ai.usage.cached_tokens = tracing::field::Empty,
150 gen_ai.input.messages = tracing::field::Empty,
151 gen_ai.output.messages = tracing::field::Empty,
152 )
153 } else {
154 tracing::Span::current()
155 };
156 let max_tokens = if let Some(tokens) = completion_request.max_tokens {
157 tokens
158 } else if let Some(tokens) = self.default_max_tokens {
159 tokens
160 } else {
161 return Err(CompletionError::RequestError(
162 "`max_tokens` must be set for Anthropic".into(),
163 ));
164 };
165
166 let mut full_history = vec![];
167 if let Some(docs) = completion_request.normalized_documents() {
168 full_history.push(docs);
169 }
170 full_history.extend(completion_request.chat_history);
171 let (history_system, full_history) = split_system_messages_from_history(full_history);
172
173 let mut messages = full_history
174 .into_iter()
175 .map(Message::try_from)
176 .collect::<Result<Vec<Message>, _>>()?;
177
178 let mut system: Vec<SystemContent> =
180 if let Some(preamble) = completion_request.preamble.as_ref() {
181 if preamble.is_empty() {
182 vec![]
183 } else {
184 vec![SystemContent::Text {
185 text: preamble.clone(),
186 cache_control: None,
187 }]
188 }
189 } else {
190 vec![]
191 };
192 system.extend(history_system);
193
194 if self.prompt_caching {
196 apply_cache_control(&mut system, &mut messages);
197 }
198
199 let mut body = json!({
200 "model": request_model,
201 "messages": messages,
202 "max_tokens": max_tokens,
203 "stream": true,
204 });
205
206 if !system.is_empty() {
208 merge_inplace(&mut body, json!({ "system": system }));
209 }
210
211 if let Some(temperature) = completion_request.temperature {
212 merge_inplace(&mut body, json!({ "temperature": temperature }));
213 }
214
215 let mut additional_params_payload = completion_request
216 .additional_params
217 .take()
218 .unwrap_or(Value::Null);
219 let mut additional_tools =
220 extract_tools_from_additional_params(&mut additional_params_payload)?;
221
222 let mut tools = completion_request
223 .tools
224 .into_iter()
225 .map(|tool| ToolDefinition {
226 name: tool.name,
227 description: Some(tool.description),
228 input_schema: tool.parameters,
229 })
230 .map(serde_json::to_value)
231 .collect::<Result<Vec<_>, _>>()?;
232 tools.append(&mut additional_tools);
233
234 if !tools.is_empty() {
235 merge_inplace(
236 &mut body,
237 json!({
238 "tools": tools,
239 "tool_choice": ToolChoice::Auto,
240 }),
241 );
242 }
243
244 if !additional_params_payload.is_null() {
245 merge_inplace(&mut body, additional_params_payload)
246 }
247
248 if enabled!(Level::TRACE) {
249 tracing::trace!(
250 target: "rig::completions",
251 "Anthropic completion request: {}",
252 serde_json::to_string_pretty(&body)?
253 );
254 }
255
256 let body: Vec<u8> = serde_json::to_vec(&body)?;
257
258 let req = self
259 .client
260 .post("/v1/messages")?
261 .body(body)
262 .map_err(http_client::Error::Protocol)?;
263
264 let stream = GenericEventSource::new(self.client.clone(), req);
265
266 let stream: StreamingResult<StreamingCompletionResponse> = Box::pin(stream! {
268 let mut current_tool_call: Option<ToolCallState> = None;
269 let mut current_thinking: Option<ThinkingState> = None;
270 let mut sse_stream = Box::pin(stream);
271 let mut input_tokens = 0;
272 let mut final_usage = None;
273
274 let mut text_content = String::new();
275
276 while let Some(sse_result) = sse_stream.next().await {
277 match sse_result {
278 Ok(Event::Open) => {}
279 Ok(Event::Message(sse)) => {
280 match serde_json::from_str::<StreamingEvent>(&sse.data) {
282 Ok(event) => {
283 match &event {
284 StreamingEvent::MessageStart { message } => {
285 input_tokens = message.usage.input_tokens;
286
287 let span = tracing::Span::current();
288 span.record("gen_ai.response.id", &message.id);
289 span.record("gen_ai.response.model_name", &message.model);
290 },
291 StreamingEvent::MessageDelta { delta, usage } => {
292 if delta.stop_reason.is_some() {
293 let usage = PartialUsage {
294 output_tokens: usage.output_tokens,
295 input_tokens: Some(input_tokens.try_into().expect("Failed to convert input_tokens to usize")),
296 };
297
298 let span = tracing::Span::current();
299 span.record_token_usage(&usage);
300 final_usage = Some(usage);
301 break;
302 }
303 }
304 _ => {}
305 }
306
307 if let Some(result) = handle_event(&event, &mut current_tool_call, &mut current_thinking) {
308 if let Ok(RawStreamingChoice::Message(ref text)) = result {
309 text_content += text;
310 }
311 yield result;
312 }
313 },
314 Err(e) => {
315 if !sse.data.trim().is_empty() {
316 yield Err(CompletionError::ResponseError(
317 format!("Failed to parse JSON: {} (Data: {})", e, sse.data)
318 ));
319 }
320 }
321 }
322 },
323 Err(e) => {
324 yield Err(CompletionError::ProviderError(format!("SSE Error: {e}")));
325 break;
326 }
327 }
328 }
329
330 sse_stream.close();
332
333 yield Ok(RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
334 usage: final_usage.unwrap_or_default()
335 }))
336 }.instrument(span));
337
338 Ok(streaming::StreamingCompletionResponse::stream(stream))
339 }
340}
341
342fn extract_tools_from_additional_params(
343 additional_params: &mut Value,
344) -> Result<Vec<Value>, CompletionError> {
345 if let Some(map) = additional_params.as_object_mut()
346 && let Some(raw_tools) = map.remove("tools")
347 {
348 return serde_json::from_value::<Vec<Value>>(raw_tools).map_err(|err| {
349 CompletionError::RequestError(
350 format!("Invalid Anthropic `additional_params.tools` payload: {err}").into(),
351 )
352 });
353 }
354
355 Ok(Vec::new())
356}
357
358fn handle_event(
359 event: &StreamingEvent,
360 current_tool_call: &mut Option<ToolCallState>,
361 current_thinking: &mut Option<ThinkingState>,
362) -> Option<Result<RawStreamingChoice<StreamingCompletionResponse>, CompletionError>> {
363 match event {
364 StreamingEvent::ContentBlockDelta { delta, .. } => match delta {
365 ContentDelta::TextDelta { text } => {
366 if current_tool_call.is_none() {
367 return Some(Ok(RawStreamingChoice::Message(text.clone())));
368 }
369 None
370 }
371 ContentDelta::InputJsonDelta { partial_json } => {
372 if let Some(tool_call) = current_tool_call {
373 tool_call.input_json.push_str(partial_json);
374 return Some(Ok(RawStreamingChoice::ToolCallDelta {
376 id: tool_call.id.clone(),
377 internal_call_id: tool_call.internal_call_id.clone(),
378 content: ToolCallDeltaContent::Delta(partial_json.clone()),
379 }));
380 }
381 None
382 }
383 ContentDelta::ThinkingDelta { thinking } => {
384 current_thinking
385 .get_or_insert_with(ThinkingState::default)
386 .thinking
387 .push_str(thinking);
388
389 Some(Ok(RawStreamingChoice::ReasoningDelta {
390 id: None,
391 reasoning: thinking.clone(),
392 }))
393 }
394 ContentDelta::SignatureDelta { signature } => {
395 current_thinking
396 .get_or_insert_with(ThinkingState::default)
397 .signature
398 .push_str(signature);
399
400 None
402 }
403 },
404 StreamingEvent::ContentBlockStart { content_block, .. } => match content_block {
405 Content::ToolUse { id, name, .. } => {
406 let internal_call_id = nanoid::nanoid!();
407 *current_tool_call = Some(ToolCallState {
408 name: name.clone(),
409 id: id.clone(),
410 internal_call_id: internal_call_id.clone(),
411 input_json: String::new(),
412 });
413 Some(Ok(RawStreamingChoice::ToolCallDelta {
414 id: id.clone(),
415 internal_call_id,
416 content: ToolCallDeltaContent::Name(name.clone()),
417 }))
418 }
419 Content::Thinking { .. } => {
420 *current_thinking = Some(ThinkingState::default());
421 None
422 }
423 Content::RedactedThinking { data } => Some(Ok(RawStreamingChoice::Reasoning {
424 id: None,
425 content: ReasoningContent::Redacted { data: data.clone() },
426 })),
427 _ => None,
429 },
430 StreamingEvent::ContentBlockStop { .. } => {
431 if let Some(thinking_state) = Option::take(current_thinking)
432 && !thinking_state.thinking.is_empty()
433 {
434 let signature = if thinking_state.signature.is_empty() {
435 None
436 } else {
437 Some(thinking_state.signature)
438 };
439
440 return Some(Ok(RawStreamingChoice::Reasoning {
441 id: None,
442 content: ReasoningContent::Text {
443 text: thinking_state.thinking,
444 signature,
445 },
446 }));
447 }
448
449 if let Some(tool_call) = Option::take(current_tool_call) {
450 let json_str = if tool_call.input_json.is_empty() {
451 "{}"
452 } else {
453 &tool_call.input_json
454 };
455 match serde_json::from_str(json_str) {
456 Ok(json_value) => {
457 let raw_tool_call =
458 RawStreamingToolCall::new(tool_call.id, tool_call.name, json_value)
459 .with_internal_call_id(tool_call.internal_call_id);
460 Some(Ok(RawStreamingChoice::ToolCall(raw_tool_call)))
461 }
462 Err(e) => Some(Err(CompletionError::from(e))),
463 }
464 } else {
465 None
466 }
467 }
468 StreamingEvent::MessageStart { .. }
470 | StreamingEvent::MessageDelta { .. }
471 | StreamingEvent::MessageStop
472 | StreamingEvent::Ping
473 | StreamingEvent::Unknown => None,
474 }
475}
476
#[cfg(test)]
mod tests {
    use super::*;

    /// Wraps `delta` in a `content_block_delta` event for block index 0.
    fn delta_event(delta: ContentDelta) -> StreamingEvent {
        StreamingEvent::ContentBlockDelta { index: 0, delta }
    }

    /// State of a freshly opened `test_tool` call with id `tool_123` and no
    /// arguments buffered yet.
    fn open_tool_call() -> Option<ToolCallState> {
        Some(ToolCallState {
            name: "test_tool".to_string(),
            id: "tool_123".to_string(),
            internal_call_id: nanoid::nanoid!(),
            input_json: String::new(),
        })
    }

    #[test]
    fn test_thinking_delta_deserialization() {
        let payload = r#"{"type": "thinking_delta", "thinking": "Let me think about this..."}"#;
        let parsed: ContentDelta = serde_json::from_str(payload).unwrap();

        let ContentDelta::ThinkingDelta { thinking } = parsed else {
            panic!("Expected ThinkingDelta variant");
        };
        assert_eq!(thinking, "Let me think about this...");
    }

    #[test]
    fn test_signature_delta_deserialization() {
        let payload = r#"{"type": "signature_delta", "signature": "abc123def456"}"#;
        let parsed: ContentDelta = serde_json::from_str(payload).unwrap();

        let ContentDelta::SignatureDelta { signature } = parsed else {
            panic!("Expected SignatureDelta variant");
        };
        assert_eq!(signature, "abc123def456");
    }

    #[test]
    fn test_thinking_delta_streaming_event_deserialization() {
        let payload = r#"{
            "type": "content_block_delta",
            "index": 0,
            "delta": {
                "type": "thinking_delta",
                "thinking": "First, I need to understand the problem."
            }
        }"#;

        let parsed: StreamingEvent = serde_json::from_str(payload).unwrap();

        let StreamingEvent::ContentBlockDelta { index, delta } = parsed else {
            panic!("Expected ContentBlockDelta event");
        };
        assert_eq!(index, 0);

        let ContentDelta::ThinkingDelta { thinking } = delta else {
            panic!("Expected ThinkingDelta");
        };
        assert_eq!(thinking, "First, I need to understand the problem.");
    }

    #[test]
    fn test_signature_delta_streaming_event_deserialization() {
        let payload = r#"{
            "type": "content_block_delta",
            "index": 0,
            "delta": {
                "type": "signature_delta",
                "signature": "ErUBCkYICBgCIkCaGbqC85F4"
            }
        }"#;

        let parsed: StreamingEvent = serde_json::from_str(payload).unwrap();

        let StreamingEvent::ContentBlockDelta { index, delta } = parsed else {
            panic!("Expected ContentBlockDelta event");
        };
        assert_eq!(index, 0);

        let ContentDelta::SignatureDelta { signature } = delta else {
            panic!("Expected SignatureDelta");
        };
        assert_eq!(signature, "ErUBCkYICBgCIkCaGbqC85F4");
    }

    #[test]
    fn test_handle_thinking_delta_event() {
        let event = delta_event(ContentDelta::ThinkingDelta {
            thinking: "Analyzing the request...".to_string(),
        });
        let mut tool_call_state = None;
        let mut thinking_state = None;

        let outcome = handle_event(&event, &mut tool_call_state, &mut thinking_state);
        assert!(outcome.is_some());

        let RawStreamingChoice::ReasoningDelta { id, reasoning, .. } = outcome.unwrap().unwrap()
        else {
            panic!("Expected ReasoningDelta choice");
        };
        assert_eq!(id, None);
        assert_eq!(reasoning, "Analyzing the request...");

        // The fragment must also have been buffered for the final block.
        assert!(thinking_state.is_some());
        assert_eq!(thinking_state.unwrap().thinking, "Analyzing the request...");
    }

    #[test]
    fn test_handle_signature_delta_event() {
        let event = delta_event(ContentDelta::SignatureDelta {
            signature: "test_signature".to_string(),
        });
        let mut tool_call_state = None;
        let mut thinking_state = None;

        let outcome = handle_event(&event, &mut tool_call_state, &mut thinking_state);

        // Signature fragments are buffered without emitting anything.
        assert!(outcome.is_none());
        assert!(thinking_state.is_some());
        assert_eq!(thinking_state.unwrap().signature, "test_signature");
    }

    #[test]
    fn test_handle_redacted_thinking_content_block_start_event() {
        let event = StreamingEvent::ContentBlockStart {
            index: 0,
            content_block: Content::RedactedThinking {
                data: "redacted_blob".to_string(),
            },
        };
        let mut tool_call_state = None;
        let mut thinking_state = None;

        let outcome = handle_event(&event, &mut tool_call_state, &mut thinking_state);
        assert!(outcome.is_some());

        let RawStreamingChoice::Reasoning {
            content: ReasoningContent::Redacted { data },
            ..
        } = outcome.unwrap().unwrap()
        else {
            panic!("Expected Redacted reasoning chunk");
        };
        assert_eq!(data, "redacted_blob");
    }

    #[test]
    fn test_handle_text_delta_event() {
        let event = delta_event(ContentDelta::TextDelta {
            text: "Hello, world!".to_string(),
        });
        let mut tool_call_state = None;
        let mut thinking_state = None;

        let outcome = handle_event(&event, &mut tool_call_state, &mut thinking_state);
        assert!(outcome.is_some());

        let RawStreamingChoice::Message(text) = outcome.unwrap().unwrap() else {
            panic!("Expected Message choice");
        };
        assert_eq!(text, "Hello, world!");
    }

    #[test]
    fn test_thinking_delta_does_not_interfere_with_tool_calls() {
        let event = delta_event(ContentDelta::ThinkingDelta {
            thinking: "Thinking while tool is active...".to_string(),
        });
        let mut tool_call_state = open_tool_call();
        let mut thinking_state = None;

        let outcome = handle_event(&event, &mut tool_call_state, &mut thinking_state);
        assert!(outcome.is_some());

        let RawStreamingChoice::ReasoningDelta { reasoning, .. } = outcome.unwrap().unwrap()
        else {
            panic!("Expected ReasoningDelta choice");
        };
        assert_eq!(reasoning, "Thinking while tool is active...");

        // The pending tool call must survive the thinking delta.
        assert!(tool_call_state.is_some());
    }

    #[test]
    fn test_handle_input_json_delta_event() {
        let event = delta_event(ContentDelta::InputJsonDelta {
            partial_json: "{\"arg\":\"value".to_string(),
        });
        let mut tool_call_state = open_tool_call();
        let mut thinking_state = None;

        let outcome = handle_event(&event, &mut tool_call_state, &mut thinking_state);
        assert!(outcome.is_some());

        match outcome.unwrap().unwrap() {
            RawStreamingChoice::ToolCallDelta { id, content, .. } => {
                assert_eq!(id, "tool_123");
                let ToolCallDeltaContent::Delta(fragment) = content else {
                    panic!("Expected Delta content");
                };
                assert_eq!(fragment, "{\"arg\":\"value");
            }
            other => panic!("Expected ToolCallDelta choice, got {:?}", other),
        }

        // The fragment must have been appended to the pending call.
        assert!(tool_call_state.is_some());
        let pending = tool_call_state.unwrap();
        assert_eq!(pending.input_json, "{\"arg\":\"value");
    }

    #[test]
    fn test_tool_call_accumulation_with_multiple_deltas() {
        let mut tool_call_state = open_tool_call();
        let mut thinking_state = None;

        // Feed the arguments in three pieces; each one must yield a delta.
        for fragment in ["{\"location\":", "\"Paris\",", "\"temp\":\"20C\"}"] {
            let event = delta_event(ContentDelta::InputJsonDelta {
                partial_json: fragment.to_string(),
            });
            let outcome = handle_event(&event, &mut tool_call_state, &mut thinking_state);
            assert!(outcome.is_some());
        }

        assert!(tool_call_state.is_some());
        let pending = tool_call_state.as_ref().unwrap();
        assert_eq!(
            pending.input_json,
            "{\"location\":\"Paris\",\"temp\":\"20C\"}"
        );

        // Closing the block finalizes the call and parses the buffered JSON.
        let stop_event = StreamingEvent::ContentBlockStop { index: 0 };
        let final_result = handle_event(&stop_event, &mut tool_call_state, &mut thinking_state);
        assert!(final_result.is_some());

        match final_result.unwrap().unwrap() {
            RawStreamingChoice::ToolCall(RawStreamingToolCall {
                id,
                name,
                arguments,
                ..
            }) => {
                assert_eq!(id, "tool_123");
                assert_eq!(name, "test_tool");
                assert_eq!(
                    arguments.get("location").unwrap().as_str().unwrap(),
                    "Paris"
                );
                assert_eq!(arguments.get("temp").unwrap().as_str().unwrap(), "20C");
            }
            other => panic!("Expected ToolCall, got {:?}", other),
        }

        // The finished call must have been taken out of the state.
        assert!(tool_call_state.is_none());
    }
}