1#![cfg(feature = "conversation")]
41
42use std::fmt;
43
44use crate::client::Client;
45use crate::conversation::{Conversation, UsageRecord};
46use crate::error::{Error, Result};
47use crate::messages::content::{ContentBlock, KnownBlock, ToolResultContent};
48use crate::messages::input::MessageInput;
49use crate::messages::response::Message;
50use crate::tool_dispatch::registry::ToolRegistry;
51use crate::types::StopReason;
52
/// Boxed callback stored in [`RunOptions`] and invoked by [`Client::run`] with
/// each model response and the 1-based iteration number.
type IterationHook = Box<dyn Fn(&Message, u32) + Send + Sync + 'static>;
55
/// Boxed callback stored in [`RunOptions`] and invoked by [`Client::run`] with
/// the current conversation state at checkpoint moments.
type CheckpointHook = Box<dyn Fn(&Conversation) + Send + Sync + 'static>;
58
/// Spending cap for an agent run.
///
/// After every model response, [`Client::run`] computes the conversation's
/// cost with `pricing` and aborts with [`Error::CostBudgetExceeded`] once the
/// total is strictly greater than `max_usd`.
#[cfg(feature = "pricing")]
#[cfg_attr(docsrs, doc(cfg(feature = "pricing")))]
pub struct CostBudget {
    /// Maximum allowed spend, in US dollars.
    pub max_usd: f64,
    /// Pricing table passed to `Conversation::cost` to compute the spend.
    pub pricing: crate::pricing::PricingTable,
}
69
/// Configuration for [`Client::run`], assembled with the builder-style
/// methods on this type.
pub struct RunOptions {
    /// Upper bound on loop iterations (one model request per iteration).
    max_iterations: u32,
    /// Optional hook called with every response and its iteration number.
    on_iteration: Option<IterationHook>,
    /// Optional hook called with the conversation at checkpoint moments.
    on_checkpoint: Option<CheckpointHook>,
    /// Whether tool calls from a single response are dispatched concurrently.
    parallel_tool_dispatch: bool,
    /// Optional spending cap checked after every response.
    #[cfg(feature = "pricing")]
    cost_budget: Option<CostBudget>,
    /// Optional token checked at the start of every iteration.
    cancel_token: Option<tokio_util::sync::CancellationToken>,
    /// Optional approver consulted for every tool call before dispatch.
    approver: Option<std::sync::Arc<dyn crate::tool_dispatch::ToolApprover>>,
}
83
84impl Default for RunOptions {
85 fn default() -> Self {
86 Self {
87 max_iterations: 16,
88 on_iteration: None,
89 on_checkpoint: None,
90 parallel_tool_dispatch: true,
91 #[cfg(feature = "pricing")]
92 cost_budget: None,
93 cancel_token: None,
94 approver: None,
95 }
96 }
97}
98
99impl RunOptions {
100 #[must_use]
102 pub fn new() -> Self {
103 Self::default()
104 }
105
106 #[must_use]
108 pub fn max_iterations(mut self, max: u32) -> Self {
109 self.max_iterations = max;
110 self
111 }
112
113 #[must_use]
117 pub fn on_iteration<F>(mut self, hook: F) -> Self
118 where
119 F: Fn(&Message, u32) + Send + Sync + 'static,
120 {
121 self.on_iteration = Some(Box::new(hook));
122 self
123 }
124
125 #[must_use]
136 pub fn on_checkpoint<F>(mut self, hook: F) -> Self
137 where
138 F: Fn(&Conversation) + Send + Sync + 'static,
139 {
140 self.on_checkpoint = Some(Box::new(hook));
141 self
142 }
143
144 #[must_use]
149 pub fn parallel_tool_dispatch(mut self, parallel: bool) -> Self {
150 self.parallel_tool_dispatch = parallel;
151 self
152 }
153
154 #[cfg(feature = "pricing")]
159 #[cfg_attr(docsrs, doc(cfg(feature = "pricing")))]
160 #[must_use]
161 pub fn cost_budget(mut self, max_usd: f64, pricing: crate::pricing::PricingTable) -> Self {
162 self.cost_budget = Some(CostBudget { max_usd, pricing });
163 self
164 }
165
166 #[must_use]
170 pub fn cancel_token(mut self, token: tokio_util::sync::CancellationToken) -> Self {
171 self.cancel_token = Some(token);
172 self
173 }
174
175 #[must_use]
181 pub fn with_approver(
182 mut self,
183 approver: std::sync::Arc<dyn crate::tool_dispatch::ToolApprover>,
184 ) -> Self {
185 self.approver = Some(approver);
186 self
187 }
188
189 #[must_use]
191 pub fn with_approver_fn<F, Fut>(self, handler: F) -> Self
192 where
193 F: Fn(&str, &serde_json::Value) -> Fut + Send + Sync + 'static,
194 Fut: std::future::Future<Output = crate::tool_dispatch::ApprovalDecision> + Send + 'static,
195 {
196 self.with_approver(crate::tool_dispatch::fn_approver(handler))
197 }
198
199 #[must_use]
201 pub fn max_iterations_value(&self) -> u32 {
202 self.max_iterations
203 }
204}
205
206impl fmt::Debug for RunOptions {
207 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
208 let mut s = f.debug_struct("RunOptions");
209 s.field("max_iterations", &self.max_iterations)
210 .field(
211 "on_iteration",
212 &self.on_iteration.as_ref().map(|_| "<closure>"),
213 )
214 .field(
215 "on_checkpoint",
216 &self.on_checkpoint.as_ref().map(|_| "<closure>"),
217 )
218 .field("parallel_tool_dispatch", &self.parallel_tool_dispatch)
219 .field("cancel_token", &self.cancel_token.is_some())
220 .field("approver", &self.approver.as_ref().map(|_| "<approver>"));
221 #[cfg(feature = "pricing")]
222 s.field("cost_budget", &self.cost_budget.as_ref().map(|b| b.max_usd));
223 s.finish()
224 }
225}
226
227impl Client {
228 #[allow(clippy::too_many_lines)] #[allow(clippy::missing_panics_doc)] pub async fn run(
250 &self,
251 conversation: &mut Conversation,
252 registry: &ToolRegistry,
253 options: RunOptions,
254 ) -> Result<Message> {
255 for iteration in 1..=options.max_iterations {
256 let span = tracing::info_span!("agent_iteration", iteration);
257 let _enter = span.enter();
258
259 if let Some(token) = &options.cancel_token
261 && token.is_cancelled()
262 {
263 tracing::info!(iteration, "claude-api: agent loop cancelled");
264 return Err(Error::Cancelled);
265 }
266
267 conversation.compact_if_needed();
271
272 let mut request = conversation.build_request();
276 request.tools = registry.to_messages_tools();
277
278 let response = self.messages().create(request).await?;
279
280 conversation.usage_history.push(UsageRecord {
282 model: conversation.model.clone(),
283 usage: response.usage.clone(),
284 });
285 conversation
286 .messages
287 .push(MessageInput::assistant(response.content.clone()));
288
289 if let Some(hook) = &options.on_iteration {
290 hook(&response, iteration);
291 }
292
293 #[cfg(feature = "pricing")]
295 if let Some(budget) = &options.cost_budget {
296 let spent = conversation.cost(&budget.pricing);
297 if spent > budget.max_usd {
298 tracing::warn!(
299 iteration,
300 spent_usd = spent,
301 budget_usd = budget.max_usd,
302 "claude-api: agent loop exceeded cost budget",
303 );
304 return Err(Error::CostBudgetExceeded {
305 budget_usd: budget.max_usd,
306 spent_usd: spent,
307 });
308 }
309 }
310
311 if response.stop_reason != Some(StopReason::ToolUse) {
312 if let Some(hook) = &options.on_checkpoint {
313 hook(conversation);
314 }
315 return Ok(response);
316 }
317
318 let tool_uses: Vec<(String, String, serde_json::Value)> = response
320 .content
321 .iter()
322 .filter_map(|b| {
323 if let ContentBlock::Known(KnownBlock::ToolUse { id, name, input }) = b {
324 Some((id.clone(), name.clone(), input.clone()))
325 } else {
326 None
327 }
328 })
329 .collect();
330
331 if tool_uses.is_empty() {
333 return Ok(response);
334 }
335
336 let mut plans: Vec<DispatchPlan> = Vec::with_capacity(tool_uses.len());
347 for (id, name, input) in &tool_uses {
348 let plan = if let Some(approver) = &options.approver {
349 match approver.approve(name, input).await {
350 crate::tool_dispatch::ApprovalDecision::Approve => DispatchPlan::Run {
351 id: id.clone(),
352 name: name.clone(),
353 input: input.clone(),
354 },
355 crate::tool_dispatch::ApprovalDecision::ApproveWithInput(new_input) => {
356 tracing::debug!(
357 tool = %name,
358 "claude-api: approver rewrote tool input"
359 );
360 DispatchPlan::Run {
361 id: id.clone(),
362 name: name.clone(),
363 input: new_input,
364 }
365 }
366 crate::tool_dispatch::ApprovalDecision::Substitute(value) => {
367 tracing::debug!(
368 tool = %name,
369 "claude-api: approver substituted result without dispatch"
370 );
371 DispatchPlan::ResultDirect {
372 id: id.clone(),
373 content: value_to_tool_result(value),
374 is_error: None,
375 }
376 }
377 crate::tool_dispatch::ApprovalDecision::Deny(reason) => {
378 tracing::info!(
379 tool = %name,
380 reason = %reason,
381 "claude-api: approver denied tool dispatch"
382 );
383 DispatchPlan::ResultDirect {
384 id: id.clone(),
385 content: ToolResultContent::Text(reason),
386 is_error: Some(true),
387 }
388 }
389 crate::tool_dispatch::ApprovalDecision::Stop(reason) => {
390 tracing::warn!(
391 tool = %name,
392 reason = %reason,
393 "claude-api: approver stopped the agent loop"
394 );
395 return Err(Error::ToolApprovalStopped {
396 tool_name: name.clone(),
397 reason,
398 });
399 }
400 }
401 } else {
402 DispatchPlan::Run {
403 id: id.clone(),
404 name: name.clone(),
405 input: input.clone(),
406 }
407 };
408 plans.push(plan);
409 }
410
411 let dispatched: Vec<(String, String, Result<serde_json::Value, _>)> =
415 if options.parallel_tool_dispatch {
416 let futures = plans
417 .iter()
418 .filter_map(|p| {
419 if let DispatchPlan::Run { id, name, input } = p {
420 Some((id.clone(), name.clone(), input.clone()))
421 } else {
422 None
423 }
424 })
425 .map(|(id, name, input)| async move {
426 let result = registry.dispatch(&name, input).await;
427 (id, name, result)
428 });
429 futures_util::future::join_all(futures).await
430 } else {
431 let mut out = Vec::new();
432 for p in &plans {
433 if let DispatchPlan::Run { id, name, input } = p {
434 let result = registry.dispatch(name, input.clone()).await;
435 out.push((id.clone(), name.clone(), result));
436 }
437 }
438 out
439 };
440
441 let mut dispatched_iter = dispatched.into_iter();
444 let mut tool_results: Vec<ContentBlock> = Vec::with_capacity(plans.len());
445 for plan in plans {
446 let (id, content, is_error) = match plan {
447 DispatchPlan::ResultDirect {
448 id,
449 content,
450 is_error,
451 } => (id, content, is_error),
452 DispatchPlan::Run { .. } => {
453 let (id, name, result) = dispatched_iter
456 .next()
457 .expect("dispatched/plans length mismatch");
458 match result {
459 Ok(value) => (id, value_to_tool_result(value), None),
460 Err(e) => {
461 tracing::warn!(
462 tool = %name,
463 error = %e,
464 "claude-api: tool dispatch error -- surfacing to model as is_error",
465 );
466 (id, ToolResultContent::Text(format!("{e}")), Some(true))
467 }
468 }
469 }
470 };
471 tool_results.push(ContentBlock::Known(KnownBlock::ToolResult {
472 tool_use_id: id,
473 content,
474 is_error,
475 cache_control: None,
476 }));
477 }
478
479 conversation.messages.push(MessageInput::user(tool_results));
480
481 if let Some(hook) = &options.on_checkpoint {
486 hook(conversation);
487 }
488 }
489
490 Err(Error::MaxIterationsExceeded {
491 max: options.max_iterations,
492 })
493 }
494}
495
/// What the agent loop will do for a single tool call once the approver (if
/// any) has been consulted.
enum DispatchPlan {
    /// Dispatch the tool through the registry with this (possibly rewritten)
    /// input.
    Run {
        id: String,
        name: String,
        input: serde_json::Value,
    },
    /// Skip dispatch and report this content as the tool result directly
    /// (used for substituted and denied calls).
    ResultDirect {
        id: String,
        content: ToolResultContent,
        is_error: Option<bool>,
    },
}
511
512fn value_to_tool_result(value: serde_json::Value) -> ToolResultContent {
513 match value {
516 serde_json::Value::String(s) => ToolResultContent::Text(s),
517 other => ToolResultContent::Text(other.to_string()),
518 }
519}
520
521#[cfg(test)]
522mod tests {
523 use super::*;
524 use crate::conversation::Conversation;
525 use crate::messages::tools::Tool as MessagesTool;
526 use crate::tool_dispatch::ApprovalDecision;
527 use crate::tool_dispatch::tool::ToolError;
528 use crate::types::ModelId;
529 use pretty_assertions::assert_eq;
530 use serde_json::{Value, json};
531 use std::sync::Arc;
532 use std::sync::atomic::{AtomicU32, Ordering};
533 use wiremock::matchers::{body_partial_json, method, path};
534 use wiremock::{Mock, MockServer, ResponseTemplate};
535
536 fn client_for(mock: &MockServer) -> Client {
537 Client::builder()
538 .api_key("sk-ant-test")
539 .base_url(mock.uri())
540 .build()
541 .unwrap()
542 }
543
544 fn echo_registry() -> ToolRegistry {
545 let mut r = ToolRegistry::new();
546 r.register(
547 "echo",
548 json!({"type": "object", "properties": {"text": {"type": "string"}}}),
549 |input| async move { Ok(input) },
550 );
551 r
552 }
553
554 fn assistant_text(text: &str, stop: &str) -> Value {
555 json!({
556 "id": "msg_t",
557 "type": "message",
558 "role": "assistant",
559 "content": [{"type": "text", "text": text}],
560 "model": "claude-sonnet-4-6",
561 "stop_reason": stop,
562 "usage": {"input_tokens": 5, "output_tokens": 3}
563 })
564 }
565
566 #[allow(clippy::needless_pass_by_value)]
567 fn assistant_tool_use(id: &str, name: &str, input: Value) -> Value {
568 json!({
569 "id": "msg_t",
570 "type": "message",
571 "role": "assistant",
572 "content": [
573 {"type": "text", "text": "calling tool"},
574 {"type": "tool_use", "id": id, "name": name, "input": input}
575 ],
576 "model": "claude-sonnet-4-6",
577 "stop_reason": "tool_use",
578 "usage": {"input_tokens": 10, "output_tokens": 5}
579 })
580 }
581
582 #[tokio::test]
583 async fn single_turn_no_tools_returns_immediately() {
584 let mock = MockServer::start().await;
585 Mock::given(method("POST"))
586 .and(path("/v1/messages"))
587 .respond_with(
588 ResponseTemplate::new(200).set_body_json(assistant_text("done", "end_turn")),
589 )
590 .expect(1)
591 .mount(&mock)
592 .await;
593
594 let client = client_for(&mock);
595 let registry = ToolRegistry::new();
596 let mut convo = Conversation::new(ModelId::SONNET_4_6, 64);
597 convo.push_user("hi");
598
599 let resp = client
600 .run(&mut convo, ®istry, RunOptions::default())
601 .await
602 .unwrap();
603 assert_eq!(resp.stop_reason, Some(StopReason::EndTurn));
604 assert_eq!(convo.turn_count(), 1);
605 assert_eq!(convo.messages.len(), 2);
607 }
608
    /// A tool-use response followed by a final response: the second request
    /// must carry the assistant turn and the echoed tool result.
    #[tokio::test]
    async fn two_turn_tool_use_loop_completes() {
        let mock = MockServer::start().await;
        // First request: the model asks for the `echo` tool (served once).
        Mock::given(method("POST"))
            .and(path("/v1/messages"))
            .respond_with(ResponseTemplate::new(200).set_body_json(assistant_tool_use(
                "toolu_1",
                "echo",
                json!({"text":"hello"}),
            )))
            .up_to_n_times(1)
            .mount(&mock)
            .await;
        // Second request: only matches if the full history, including the
        // serialized tool result, was sent.
        Mock::given(method("POST"))
            .and(path("/v1/messages"))
            .and(body_partial_json(json!({
                "messages": [
                    {"role": "user", "content": "say hello"},
                    {"role": "assistant", "content": [
                        {"type": "text", "text": "calling tool"},
                        {"type": "tool_use", "id": "toolu_1", "name": "echo", "input": {"text":"hello"}}
                    ]},
                    {"role": "user", "content": [
                        {"type": "tool_result", "tool_use_id": "toolu_1", "content": "{\"text\":\"hello\"}"}
                    ]}
                ]
            })))
            .respond_with(ResponseTemplate::new(200).set_body_json(assistant_text("said hello!", "end_turn")))
            .mount(&mock)
            .await;

        let client = client_for(&mock);
        let mut convo = Conversation::new(ModelId::SONNET_4_6, 256);
        convo.push_user("say hello");

        let resp = client
            .run(&mut convo, &echo_registry(), RunOptions::default())
            .await
            .unwrap();

        assert_eq!(resp.stop_reason, Some(StopReason::EndTurn));
        assert_eq!(convo.turn_count(), 2);
        // user, assistant (tool use), user (tool result), assistant (final).
        assert_eq!(convo.messages.len(), 4);
    }
657
    /// When the model keeps requesting tools, the loop stops with
    /// `MaxIterationsExceeded` and every completed turn stays recorded.
    #[tokio::test]
    async fn max_iterations_returns_error_and_records_each_turn() {
        let mock = MockServer::start().await;
        // Every request gets a tool-use response, so the loop never finishes.
        Mock::given(method("POST"))
            .and(path("/v1/messages"))
            .respond_with(ResponseTemplate::new(200).set_body_json(assistant_tool_use(
                "toolu_x",
                "echo",
                json!({"text":"x"}),
            )))
            .mount(&mock)
            .await;

        let client = client_for(&mock);
        let mut convo = Conversation::new(ModelId::SONNET_4_6, 64);
        convo.push_user("loop");

        let err = client
            .run(
                &mut convo,
                &echo_registry(),
                RunOptions::default().max_iterations(3),
            )
            .await
            .unwrap_err();

        let Error::MaxIterationsExceeded { max } = err else {
            panic!("expected MaxIterationsExceeded, got {err:?}");
        };
        assert_eq!(max, 3);
        assert_eq!(convo.turn_count(), 3);
        // Initial user message plus (assistant, tool-result) per iteration.
        assert_eq!(convo.messages.len(), 1 + 3 * 2);
    }
693
694 #[tokio::test]
695 async fn tool_error_becomes_is_error_tool_result() {
696 let mock = MockServer::start().await;
697 Mock::given(method("POST"))
699 .and(path("/v1/messages"))
700 .respond_with(ResponseTemplate::new(200).set_body_json(assistant_tool_use(
701 "toolu_e",
702 "boom",
703 json!({}),
704 )))
705 .up_to_n_times(1)
706 .mount(&mock)
707 .await;
708 Mock::given(method("POST"))
710 .and(path("/v1/messages"))
711 .and(body_partial_json(json!({
712 "messages": [
713 {"role": "user", "content": "fail"},
714 {"role": "assistant"},
715 {"role": "user", "content": [{
716 "type": "tool_result",
717 "tool_use_id": "toolu_e",
718 "is_error": true
719 }]}
720 ]
721 })))
722 .respond_with(
723 ResponseTemplate::new(200).set_body_json(assistant_text("recovered", "end_turn")),
724 )
725 .mount(&mock)
726 .await;
727
728 let client = client_for(&mock);
729 let mut registry = ToolRegistry::new();
730 registry.register("boom", json!({}), |_input| async move {
731 Err(ToolError::execution(std::io::Error::other("kaboom")))
732 });
733
734 let mut convo = Conversation::new(ModelId::SONNET_4_6, 64);
735 convo.push_user("fail");
736 let resp = client
737 .run(&mut convo, ®istry, RunOptions::default())
738 .await
739 .unwrap();
740 assert_eq!(resp.stop_reason, Some(StopReason::EndTurn));
741 }
742
    /// Calling a tool that is not registered produces an `is_error` tool
    /// result naming the missing tool.
    #[tokio::test]
    async fn unknown_tool_becomes_is_error_with_unknown_message() {
        let mock = MockServer::start().await;
        // First request: the model calls a tool the registry does not know.
        Mock::given(method("POST"))
            .and(path("/v1/messages"))
            .respond_with(ResponseTemplate::new(200).set_body_json(assistant_tool_use(
                "toolu_u",
                "missing",
                json!({}),
            )))
            .up_to_n_times(1)
            .mount(&mock)
            .await;
        Mock::given(method("POST"))
            .and(path("/v1/messages"))
            .respond_with(
                ResponseTemplate::new(200).set_body_json(assistant_text("ok", "end_turn")),
            )
            .mount(&mock)
            .await;

        let client = client_for(&mock);
        let mut convo = Conversation::new(ModelId::SONNET_4_6, 64);
        convo.push_user("call missing");

        let _ = client
            .run(&mut convo, &ToolRegistry::new(), RunOptions::default())
            .await
            .unwrap();

        // Index 2 is the tool-result user turn (after user and assistant).
        let user_turn = &convo.messages[2];
        let serialized = serde_json::to_string(&user_turn.content).unwrap();
        assert!(
            serialized.contains("no tool registered with name 'missing'"),
            "{serialized}"
        );
        assert!(serialized.contains("\"is_error\":true"));
    }
782
    /// The request's tool list comes from the registry, not from tools
    /// configured on the conversation.
    #[tokio::test]
    async fn run_uses_registry_tools_not_conversation_tools() {
        let mock = MockServer::start().await;
        // The only mock requires the registry's `echo` tool in the request
        // body; a non-matching request would make `run` fail below.
        Mock::given(method("POST"))
            .and(path("/v1/messages"))
            .and(body_partial_json(json!({
                "tools": [{"name": "echo"}]
            })))
            .respond_with(
                ResponseTemplate::new(200).set_body_json(assistant_text("ok", "end_turn")),
            )
            .mount(&mock)
            .await;

        let client = client_for(&mock);
        // The conversation carries a different tool that must be overridden.
        let mut convo =
            Conversation::new(ModelId::SONNET_4_6, 64).with_tools(vec![MessagesTool::Custom(
                crate::messages::tools::CustomTool::new("stale", json!({"type": "object"})),
            )]);
        convo.push_user("hi");

        let _ = client
            .run(&mut convo, &echo_registry(), RunOptions::default())
            .await
            .unwrap();
    }
813
    /// The `on_iteration` hook runs once per model response, with a 1-based
    /// iteration number.
    #[tokio::test]
    async fn on_iteration_callback_fires_per_iteration() {
        let mock = MockServer::start().await;
        // One tool-use response, then a final response: two iterations.
        Mock::given(method("POST"))
            .and(path("/v1/messages"))
            .respond_with(ResponseTemplate::new(200).set_body_json(assistant_tool_use(
                "toolu_h",
                "echo",
                json!({"text":"x"}),
            )))
            .up_to_n_times(1)
            .mount(&mock)
            .await;
        Mock::given(method("POST"))
            .and(path("/v1/messages"))
            .respond_with(
                ResponseTemplate::new(200).set_body_json(assistant_text("done", "end_turn")),
            )
            .mount(&mock)
            .await;

        let counter = Arc::new(AtomicU32::new(0));
        let counter_clone = Arc::clone(&counter);
        let options = RunOptions::default().on_iteration(move |_msg, n| {
            counter_clone.fetch_add(1, Ordering::SeqCst);
            assert!(n >= 1);
        });

        let client = client_for(&mock);
        let mut convo = Conversation::new(ModelId::SONNET_4_6, 64);
        convo.push_user("hi");

        let _ = client
            .run(&mut convo, &echo_registry(), options)
            .await
            .unwrap();
        assert_eq!(counter.load(Ordering::SeqCst), 2);
    }
853
854 #[tokio::test]
857 async fn parallel_tool_dispatch_runs_concurrently() {
858 let mock = MockServer::start().await;
862 Mock::given(method("POST"))
863 .and(path("/v1/messages"))
864 .respond_with(ResponseTemplate::new(200).set_body_json(json!({
865 "id": "msg_p",
866 "type": "message",
867 "role": "assistant",
868 "content": [
869 {"type": "tool_use", "id": "t1", "name": "slow", "input": {"k": 1}},
870 {"type": "tool_use", "id": "t2", "name": "slow", "input": {"k": 2}},
871 ],
872 "model": "claude-sonnet-4-6",
873 "stop_reason": "tool_use",
874 "usage": {"input_tokens": 10, "output_tokens": 5}
875 })))
876 .up_to_n_times(1)
877 .mount(&mock)
878 .await;
879 Mock::given(method("POST"))
880 .and(path("/v1/messages"))
881 .respond_with(
882 ResponseTemplate::new(200).set_body_json(assistant_text("done", "end_turn")),
883 )
884 .mount(&mock)
885 .await;
886
887 let mut registry = ToolRegistry::new();
888 registry.register("slow", json!({}), |input| async move {
889 tokio::time::sleep(std::time::Duration::from_millis(80)).await;
890 Ok(input)
891 });
892
893 let client = client_for(&mock);
894 let mut convo = Conversation::new(ModelId::SONNET_4_6, 64);
895 convo.push_user("call slow tools");
896
897 let started = std::time::Instant::now();
898 let _ = client
899 .run(&mut convo, ®istry, RunOptions::default())
900 .await
901 .unwrap();
902 let elapsed = started.elapsed();
903
904 assert!(
905 elapsed.as_millis() < 500,
906 "parallel dispatch should be fast; got {elapsed:?}"
907 );
908 assert!(
909 elapsed.as_millis() > 50,
910 "tools didn't actually run; got {elapsed:?}"
911 );
912 }
913
914 #[tokio::test]
915 async fn parallel_dispatch_can_be_disabled() {
916 let mock = MockServer::start().await;
920 Mock::given(method("POST"))
921 .and(path("/v1/messages"))
922 .respond_with(ResponseTemplate::new(200).set_body_json(json!({
923 "id": "msg_seq",
924 "type": "message",
925 "role": "assistant",
926 "content": [
927 {"type": "tool_use", "id": "t1", "name": "echo", "input": {"v": "first"}},
928 {"type": "tool_use", "id": "t2", "name": "echo", "input": {"v": "second"}},
929 ],
930 "model": "claude-sonnet-4-6",
931 "stop_reason": "tool_use",
932 "usage": {"input_tokens": 10, "output_tokens": 5}
933 })))
934 .up_to_n_times(1)
935 .mount(&mock)
936 .await;
937 Mock::given(method("POST"))
938 .and(path("/v1/messages"))
939 .and(body_partial_json(json!({
940 "messages": [
941 {"role": "user"},
942 {"role": "assistant"},
943 {"role": "user", "content": [
944 {"type": "tool_result", "tool_use_id": "t1"},
945 {"type": "tool_result", "tool_use_id": "t2"}
946 ]}
947 ]
948 })))
949 .respond_with(
950 ResponseTemplate::new(200).set_body_json(assistant_text("ok", "end_turn")),
951 )
952 .mount(&mock)
953 .await;
954
955 let mut registry = ToolRegistry::new();
956 registry.register("echo", json!({}), |input| async move { Ok(input) });
957
958 let client = client_for(&mock);
959 let mut convo = Conversation::new(ModelId::SONNET_4_6, 64);
960 convo.push_user("two tools");
961 let _ = client
962 .run(
963 &mut convo,
964 ®istry,
965 RunOptions::default().parallel_tool_dispatch(false),
966 )
967 .await
968 .unwrap();
969 }
970
971 #[cfg(feature = "pricing")]
972 #[tokio::test]
973 async fn cost_budget_aborts_loop_when_exceeded() {
974 let mock = MockServer::start().await;
976 Mock::given(method("POST"))
977 .and(path("/v1/messages"))
978 .respond_with(ResponseTemplate::new(200).set_body_json(json!({
979 "id": "msg_b",
980 "type": "message",
981 "role": "assistant",
982 "content": [
983 {"type": "tool_use", "id": "t1", "name": "noop", "input": {}}
984 ],
985 "model": "claude-sonnet-4-6",
986 "stop_reason": "tool_use",
987 "usage": {"input_tokens": 1_000_000, "output_tokens": 0}
988 })))
989 .mount(&mock)
990 .await;
991
992 let mut registry = ToolRegistry::new();
993 registry.register("noop", json!({}), |_input| async move { Ok(json!({})) });
994
995 let client = client_for(&mock);
996 let mut convo = Conversation::new(ModelId::SONNET_4_6, 64);
997 convo.push_user("burn money");
998
999 let err = client
1000 .run(
1001 &mut convo,
1002 ®istry,
1003 RunOptions::default()
1004 .max_iterations(8)
1005 .cost_budget(1.00, crate::pricing::PricingTable::default()),
1006 )
1007 .await
1008 .unwrap_err();
1009 let Error::CostBudgetExceeded {
1010 budget_usd,
1011 spent_usd,
1012 } = err
1013 else {
1014 panic!("expected CostBudgetExceeded, got {err:?}");
1015 };
1016 assert!((budget_usd - 1.00).abs() < 1e-9);
1018 assert!(
1019 spent_usd > 1.00,
1020 "spent_usd ({spent_usd}) should exceed budget"
1021 );
1022 }
1023
    /// A token cancelled before `run` starts aborts the loop without sending
    /// any request.
    #[tokio::test]
    async fn cancel_token_aborts_before_first_request() {
        let mock = MockServer::start().await;
        // `.expect(0)`: the mock server must never be hit.
        Mock::given(method("POST"))
            .and(path("/v1/messages"))
            .respond_with(
                ResponseTemplate::new(200).set_body_json(assistant_text("ok", "end_turn")),
            )
            .expect(0)
            .mount(&mock)
            .await;

        let token = tokio_util::sync::CancellationToken::new();
        // Cancel up front so the very first iteration sees it.
        token.cancel();
        let client = client_for(&mock);
        let mut convo = Conversation::new(ModelId::SONNET_4_6, 64);
        convo.push_user("hi");

        let err = client
            .run(
                &mut convo,
                &ToolRegistry::new(),
                RunOptions::default().cancel_token(token),
            )
            .await
            .unwrap_err();
        assert!(matches!(err, Error::Cancelled), "got {err:?}");
    }
1055
1056 #[tokio::test]
1057 async fn cancel_token_aborts_between_iterations() {
1058 let mock = MockServer::start().await;
1059 Mock::given(method("POST"))
1061 .and(path("/v1/messages"))
1062 .respond_with(ResponseTemplate::new(200).set_body_json(assistant_tool_use(
1063 "t1",
1064 "noop",
1065 json!({}),
1066 )))
1067 .up_to_n_times(1)
1068 .mount(&mock)
1069 .await;
1070 Mock::given(method("POST"))
1072 .and(path("/v1/messages"))
1073 .respond_with(
1074 ResponseTemplate::new(200).set_body_json(assistant_text("won't run", "end_turn")),
1075 )
1076 .expect(0)
1077 .mount(&mock)
1078 .await;
1079
1080 let token = tokio_util::sync::CancellationToken::new();
1081 let token_for_hook = token.clone();
1082
1083 let mut registry = ToolRegistry::new();
1084 registry.register("noop", json!({}), |_| async move { Ok(json!({})) });
1085
1086 let client = client_for(&mock);
1087 let mut convo = Conversation::new(ModelId::SONNET_4_6, 64);
1088 convo.push_user("hi");
1089
1090 let err = client
1091 .run(
1092 &mut convo,
1093 ®istry,
1094 RunOptions::default()
1095 .cancel_token(token)
1096 .on_iteration(move |_msg, _n| token_for_hook.cancel()),
1097 )
1098 .await
1099 .unwrap_err();
1100 assert!(matches!(err, Error::Cancelled), "got {err:?}");
1101 }
1102
    /// The checkpoint hook fires after the tool-result turn is appended and
    /// again when the final response is returned.
    #[tokio::test]
    async fn on_checkpoint_fires_after_each_tool_result_turn_and_at_finish() {
        let mock = MockServer::start().await;
        // One tool-use response, then a final response.
        Mock::given(method("POST"))
            .and(path("/v1/messages"))
            .respond_with(ResponseTemplate::new(200).set_body_json(assistant_tool_use(
                "toolu_1",
                "echo",
                json!({"text": "hi"}),
            )))
            .up_to_n_times(1)
            .mount(&mock)
            .await;
        Mock::given(method("POST"))
            .and(path("/v1/messages"))
            .respond_with(
                ResponseTemplate::new(200).set_body_json(assistant_text("done", "end_turn")),
            )
            .mount(&mock)
            .await;

        // Record the message count seen at each checkpoint.
        let captured: Arc<std::sync::Mutex<Vec<usize>>> =
            Arc::new(std::sync::Mutex::new(Vec::new()));
        let sink = Arc::clone(&captured);
        let opts = RunOptions::default().on_checkpoint(move |c| {
            sink.lock().unwrap().push(c.messages.len());
        });

        let client = client_for(&mock);
        let mut convo = Conversation::new(ModelId::SONNET_4_6, 64);
        convo.push_user("go");
        client
            .run(&mut convo, &echo_registry(), opts)
            .await
            .unwrap();

        let snapshots = captured.lock().unwrap();
        // 3 = user + assistant + tool result; 4 = plus the final assistant turn.
        assert_eq!(*snapshots, vec![3, 4]);
    }
1149
    /// Smoke test: running without a checkpoint hook completes successfully.
    #[tokio::test]
    async fn on_checkpoint_does_not_fire_when_unset() {
        let mock = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/v1/messages"))
            .respond_with(
                ResponseTemplate::new(200).set_body_json(assistant_text("done", "end_turn")),
            )
            .mount(&mock)
            .await;
        let client = client_for(&mock);
        let mut convo = Conversation::new(ModelId::SONNET_4_6, 64);
        convo.push_user("hi");
        client
            .run(&mut convo, &ToolRegistry::new(), RunOptions::default())
            .await
            .unwrap();
    }
1170
    /// A conversation serialized inside the checkpoint hook can be
    /// deserialized later and passed to a fresh `run` call to finish the work.
    #[tokio::test]
    async fn checkpoint_supports_resume_via_serde() {
        let mock = MockServer::start().await;
        // The first request gets a tool-use response (served once); every
        // later request gets the final response.
        Mock::given(method("POST"))
            .and(path("/v1/messages"))
            .respond_with(ResponseTemplate::new(200).set_body_json(assistant_tool_use(
                "toolu_1",
                "echo",
                json!({"text": "first"}),
            )))
            .up_to_n_times(1)
            .mount(&mock)
            .await;
        Mock::given(method("POST"))
            .and(path("/v1/messages"))
            .respond_with(
                ResponseTemplate::new(200).set_body_json(assistant_text("done", "end_turn")),
            )
            .mount(&mock)
            .await;

        let snapshot: Arc<std::sync::Mutex<Option<String>>> = Arc::new(std::sync::Mutex::new(None));
        let sink = Arc::clone(&snapshot);
        // A single iteration: the run stops right after the first checkpoint.
        let opts = RunOptions::default()
            .max_iterations(1)
            .on_checkpoint(move |c| {
                *sink.lock().unwrap() = Some(serde_json::to_string(c).unwrap());
            });
        let client = client_for(&mock);
        let mut convo = Conversation::new(ModelId::SONNET_4_6, 64);
        convo.push_user("go");

        // The result is deliberately ignored; only the snapshot matters here.
        let _ = client.run(&mut convo, &echo_registry(), opts).await;
        let json = snapshot.lock().unwrap().clone().expect("checkpoint fired");

        // Discard the original so the resume uses only the serialized state.
        drop(convo);
        let mut resumed: Conversation = serde_json::from_str(&json).unwrap();
        let final_msg = client
            .run(
                &mut resumed,
                &echo_registry(),
                RunOptions::default().max_iterations(4),
            )
            .await
            .unwrap();
        assert_eq!(final_msg.stop_reason, Some(StopReason::EndTurn));
        assert!(resumed.messages.len() >= 4);
    }
1229
    /// An approver returning `Approve` leaves the tool call untouched: the
    /// tool runs with the original input.
    #[tokio::test]
    async fn approver_approve_passes_through_to_dispatch() {
        let mock = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/v1/messages"))
            .respond_with(ResponseTemplate::new(200).set_body_json(assistant_tool_use(
                "toolu_1",
                "echo",
                json!({"text": "hi"}),
            )))
            .up_to_n_times(1)
            .mount(&mock)
            .await;
        // Second request: only matches if the echoed original input came back.
        Mock::given(method("POST"))
            .and(path("/v1/messages"))
            .and(body_partial_json(json!({
                "messages": [
                    {"role": "user", "content": "go"},
                    {"role": "assistant"},
                    {"role": "user", "content": [
                        {"type": "tool_result", "tool_use_id": "toolu_1", "content": "{\"text\":\"hi\"}"}
                    ]}
                ]
            })))
            .respond_with(
                ResponseTemplate::new(200).set_body_json(assistant_text("done", "end_turn")),
            )
            .mount(&mock)
            .await;

        let client = client_for(&mock);
        let mut convo = Conversation::new(ModelId::SONNET_4_6, 64);
        convo.push_user("go");

        let opts = RunOptions::default()
            .with_approver_fn(|_name, _input| async { ApprovalDecision::Approve });
        let resp = client
            .run(&mut convo, &echo_registry(), opts)
            .await
            .unwrap();
        assert_eq!(resp.stop_reason, Some(StopReason::EndTurn));
    }
1272
    /// An approver returning `ApproveWithInput` makes the tool run with the
    /// replacement input instead of the model's input.
    #[tokio::test]
    async fn approver_approve_with_input_rewrites_dispatch_payload() {
        let mock = MockServer::start().await;
        // The model asks to echo "secret".
        Mock::given(method("POST"))
            .and(path("/v1/messages"))
            .respond_with(ResponseTemplate::new(200).set_body_json(assistant_tool_use(
                "toolu_1",
                "echo",
                json!({"text": "secret"}),
            )))
            .up_to_n_times(1)
            .mount(&mock)
            .await;
        // Second request: only matches if the echo tool saw the rewritten input.
        Mock::given(method("POST"))
            .and(path("/v1/messages"))
            .and(body_partial_json(json!({
                "messages": [
                    {"role": "user", "content": "go"},
                    {"role": "assistant"},
                    {"role": "user", "content": [
                        {"type": "tool_result", "tool_use_id": "toolu_1", "content": "{\"text\":\"REDACTED\"}"}
                    ]}
                ]
            })))
            .respond_with(
                ResponseTemplate::new(200).set_body_json(assistant_text("ok", "end_turn")),
            )
            .mount(&mock)
            .await;

        let client = client_for(&mock);
        let mut convo = Conversation::new(ModelId::SONNET_4_6, 64);
        convo.push_user("go");

        let opts = RunOptions::default().with_approver_fn(|_name, _input| async {
            ApprovalDecision::ApproveWithInput(json!({"text": "REDACTED"}))
        });
        client
            .run(&mut convo, &echo_registry(), opts)
            .await
            .unwrap();
    }
1316
1317 #[tokio::test]
1318 async fn approver_substitute_skips_dispatch_and_returns_value() {
1319 let mut registry = ToolRegistry::new();
1321 registry.register("dangerous", json!({}), |_| async {
1322 panic!("dispatch should have been skipped by Substitute")
1323 });
1324
1325 let mock = MockServer::start().await;
1326 Mock::given(method("POST"))
1327 .and(path("/v1/messages"))
1328 .respond_with(ResponseTemplate::new(200).set_body_json(assistant_tool_use(
1329 "toolu_1",
1330 "dangerous",
1331 json!({"arg": 1}),
1332 )))
1333 .up_to_n_times(1)
1334 .mount(&mock)
1335 .await;
1336 Mock::given(method("POST"))
1337 .and(path("/v1/messages"))
1338 .and(body_partial_json(json!({
1339 "messages": [
1340 {"role": "user", "content": "go"},
1341 {"role": "assistant"},
1342 {"role": "user", "content": [
1343 {"type": "tool_result", "tool_use_id": "toolu_1", "content": "stubbed"}
1344 ]}
1345 ]
1346 })))
1347 .respond_with(
1348 ResponseTemplate::new(200).set_body_json(assistant_text("ok", "end_turn")),
1349 )
1350 .mount(&mock)
1351 .await;
1352
1353 let client = client_for(&mock);
1354 let mut convo = Conversation::new(ModelId::SONNET_4_6, 64);
1355 convo.push_user("go");
1356
1357 let opts = RunOptions::default().with_approver_fn(|_name, _input| async {
1358 ApprovalDecision::Substitute(json!("stubbed"))
1359 });
1360 client.run(&mut convo, ®istry, opts).await.unwrap();
1361 }
1362
1363 #[tokio::test]
1364 async fn approver_deny_returns_is_error_tool_result_and_loop_continues() {
1365 let mock = MockServer::start().await;
1366 Mock::given(method("POST"))
1367 .and(path("/v1/messages"))
1368 .respond_with(ResponseTemplate::new(200).set_body_json(assistant_tool_use(
1369 "toolu_1",
1370 "echo",
1371 json!({"text": "hi"}),
1372 )))
1373 .up_to_n_times(1)
1374 .mount(&mock)
1375 .await;
1376 Mock::given(method("POST"))
1377 .and(path("/v1/messages"))
1378 .and(body_partial_json(json!({
1379 "messages": [
1380 {"role": "user", "content": "go"},
1381 {"role": "assistant"},
1382 {"role": "user", "content": [
1383 {"type": "tool_result", "tool_use_id": "toolu_1", "content": "policy violation: no echo today", "is_error": true}
1384 ]}
1385 ]
1386 })))
1387 .respond_with(
1388 ResponseTemplate::new(200).set_body_json(assistant_text("ack", "end_turn")),
1389 )
1390 .mount(&mock)
1391 .await;
1392
1393 let client = client_for(&mock);
1394 let mut convo = Conversation::new(ModelId::SONNET_4_6, 64);
1395 convo.push_user("go");
1396
1397 let opts = RunOptions::default().with_approver_fn(|_name, _input| async {
1398 ApprovalDecision::Deny("policy violation: no echo today".into())
1399 });
1400 let resp = client
1401 .run(&mut convo, &echo_registry(), opts)
1402 .await
1403 .unwrap();
1404 assert_eq!(resp.stop_reason, Some(StopReason::EndTurn));
1405 }
1406
1407 #[tokio::test]
1408 async fn approver_stop_aborts_loop_with_typed_error() {
1409 let mock = MockServer::start().await;
1410 Mock::given(method("POST"))
1411 .and(path("/v1/messages"))
1412 .respond_with(ResponseTemplate::new(200).set_body_json(assistant_tool_use(
1413 "toolu_1",
1414 "echo",
1415 json!({"text": "hi"}),
1416 )))
1417 .mount(&mock)
1418 .await;
1419
1420 let client = client_for(&mock);
1421 let mut convo = Conversation::new(ModelId::SONNET_4_6, 64);
1422 convo.push_user("go");
1423
1424 let opts = RunOptions::default().with_approver_fn(|_name, _input| async {
1425 ApprovalDecision::Stop("user cancelled".into())
1426 });
1427 let err = client
1428 .run(&mut convo, &echo_registry(), opts)
1429 .await
1430 .unwrap_err();
1431 match err {
1432 Error::ToolApprovalStopped { tool_name, reason } => {
1433 assert_eq!(tool_name, "echo");
1434 assert_eq!(reason, "user cancelled");
1435 }
1436 other => panic!("expected ToolApprovalStopped, got {other:?}"),
1437 }
1438 }
1439
1440 #[tokio::test]
1441 async fn approver_per_call_decision_can_mix_approve_and_deny() {
1442 let mock = MockServer::start().await;
1444 let dual_tool_use = json!({
1445 "id": "msg_t",
1446 "type": "message",
1447 "role": "assistant",
1448 "content": [
1449 {"type": "tool_use", "id": "toolu_1", "name": "echo", "input": {"text": "ok"}},
1450 {"type": "tool_use", "id": "toolu_2", "name": "echo", "input": {"text": "block"}},
1451 ],
1452 "model": "claude-sonnet-4-6",
1453 "stop_reason": "tool_use",
1454 "usage": {"input_tokens": 10, "output_tokens": 5}
1455 });
1456 Mock::given(method("POST"))
1457 .and(path("/v1/messages"))
1458 .respond_with(ResponseTemplate::new(200).set_body_json(dual_tool_use))
1459 .up_to_n_times(1)
1460 .mount(&mock)
1461 .await;
1462 Mock::given(method("POST"))
1463 .and(path("/v1/messages"))
1464 .respond_with(
1465 ResponseTemplate::new(200).set_body_json(assistant_text("done", "end_turn")),
1466 )
1467 .mount(&mock)
1468 .await;
1469
1470 let client = client_for(&mock);
1471 let mut convo = Conversation::new(ModelId::SONNET_4_6, 64);
1472 convo.push_user("go");
1473
1474 let opts = RunOptions::default().with_approver_fn(|_name, input| {
1475 let blocked = input.get("text").and_then(Value::as_str) == Some("block");
1476 async move {
1477 if blocked {
1478 ApprovalDecision::Deny("blocked".into())
1479 } else {
1480 ApprovalDecision::Approve
1481 }
1482 }
1483 });
1484 client
1485 .run(&mut convo, &echo_registry(), opts)
1486 .await
1487 .unwrap();
1488
1489 let tool_result_turn = &convo.messages[2];
1493 let serialized = serde_json::to_value(tool_result_turn).unwrap();
1494 let blocks = serialized
1495 .get("content")
1496 .and_then(Value::as_array)
1497 .expect("content array");
1498 assert_eq!(blocks.len(), 2);
1499 assert_eq!(blocks[0]["tool_use_id"], "toolu_1");
1500 assert!(blocks[0].get("is_error").is_none());
1501 assert_eq!(blocks[1]["tool_use_id"], "toolu_2");
1502 assert_eq!(blocks[1]["is_error"], true);
1503 assert_eq!(blocks[1]["content"], "blocked");
1504 }
1505
1506 #[tokio::test]
1507 async fn tool_returning_string_value_passes_through_cleanly() {
1508 let result = value_to_tool_result(json!("plain text"));
1510 let ToolResultContent::Text(t) = result else {
1511 panic!("expected Text");
1512 };
1513 assert_eq!(t, "plain text");
1514 }
1515
1516 #[tokio::test]
1517 async fn tool_returning_object_value_serializes_to_json_string() {
1518 let result = value_to_tool_result(json!({"k": 42}));
1519 let ToolResultContent::Text(t) = result else {
1520 panic!("expected Text");
1521 };
1522 let parsed: Value = serde_json::from_str(&t).unwrap();
1524 assert_eq!(parsed, json!({"k": 42}));
1525 }
1526}