1#![cfg(feature = "conversation")]
11
12use std::fmt;
13
14use crate::client::Client;
15use crate::conversation::{Conversation, UsageRecord};
16use crate::error::{Error, Result};
17use crate::messages::content::{ContentBlock, KnownBlock, ToolResultContent};
18use crate::messages::input::MessageInput;
19use crate::messages::response::Message;
20use crate::tool_dispatch::registry::ToolRegistry;
21use crate::types::StopReason;
22
23type IterationHook = Box<dyn Fn(&Message, u32) + Send + Sync + 'static>;
25
26type CheckpointHook = Box<dyn Fn(&Conversation) + Send + Sync + 'static>;
28
/// Spending cap for an agent loop, checked after every model response.
///
/// When the conversation's cumulative cost, as computed with `pricing`, becomes
/// strictly greater than `max_usd`, the loop stops with
/// `Error::CostBudgetExceeded`.
#[cfg(feature = "pricing")]
#[cfg_attr(docsrs, doc(cfg(feature = "pricing")))]
pub struct CostBudget {
    /// Maximum allowed spend, in US dollars.
    pub max_usd: f64,
    /// Price table passed to `Conversation::cost` to compute the current spend.
    pub pricing: crate::pricing::PricingTable,
}
39
40pub struct RunOptions {
44 max_iterations: u32,
45 on_iteration: Option<IterationHook>,
46 on_checkpoint: Option<CheckpointHook>,
47 parallel_tool_dispatch: bool,
48 #[cfg(feature = "pricing")]
49 cost_budget: Option<CostBudget>,
50 cancel_token: Option<tokio_util::sync::CancellationToken>,
51 approver: Option<std::sync::Arc<dyn crate::tool_dispatch::ToolApprover>>,
52}
53
54impl Default for RunOptions {
55 fn default() -> Self {
56 Self {
57 max_iterations: 16,
58 on_iteration: None,
59 on_checkpoint: None,
60 parallel_tool_dispatch: true,
61 #[cfg(feature = "pricing")]
62 cost_budget: None,
63 cancel_token: None,
64 approver: None,
65 }
66 }
67}
68
69impl RunOptions {
70 #[must_use]
72 pub fn new() -> Self {
73 Self::default()
74 }
75
76 #[must_use]
78 pub fn max_iterations(mut self, max: u32) -> Self {
79 self.max_iterations = max;
80 self
81 }
82
83 #[must_use]
87 pub fn on_iteration<F>(mut self, hook: F) -> Self
88 where
89 F: Fn(&Message, u32) + Send + Sync + 'static,
90 {
91 self.on_iteration = Some(Box::new(hook));
92 self
93 }
94
95 #[must_use]
106 pub fn on_checkpoint<F>(mut self, hook: F) -> Self
107 where
108 F: Fn(&Conversation) + Send + Sync + 'static,
109 {
110 self.on_checkpoint = Some(Box::new(hook));
111 self
112 }
113
114 #[must_use]
119 pub fn parallel_tool_dispatch(mut self, parallel: bool) -> Self {
120 self.parallel_tool_dispatch = parallel;
121 self
122 }
123
124 #[cfg(feature = "pricing")]
129 #[cfg_attr(docsrs, doc(cfg(feature = "pricing")))]
130 #[must_use]
131 pub fn cost_budget(mut self, max_usd: f64, pricing: crate::pricing::PricingTable) -> Self {
132 self.cost_budget = Some(CostBudget { max_usd, pricing });
133 self
134 }
135
136 #[must_use]
140 pub fn cancel_token(mut self, token: tokio_util::sync::CancellationToken) -> Self {
141 self.cancel_token = Some(token);
142 self
143 }
144
145 #[must_use]
151 pub fn with_approver(
152 mut self,
153 approver: std::sync::Arc<dyn crate::tool_dispatch::ToolApprover>,
154 ) -> Self {
155 self.approver = Some(approver);
156 self
157 }
158
159 #[must_use]
161 pub fn with_approver_fn<F, Fut>(self, handler: F) -> Self
162 where
163 F: Fn(&str, &serde_json::Value) -> Fut + Send + Sync + 'static,
164 Fut: std::future::Future<Output = crate::tool_dispatch::ApprovalDecision> + Send + 'static,
165 {
166 self.with_approver(crate::tool_dispatch::fn_approver(handler))
167 }
168
169 #[must_use]
171 pub fn max_iterations_value(&self) -> u32 {
172 self.max_iterations
173 }
174}
175
176impl fmt::Debug for RunOptions {
177 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
178 let mut s = f.debug_struct("RunOptions");
179 s.field("max_iterations", &self.max_iterations)
180 .field(
181 "on_iteration",
182 &self.on_iteration.as_ref().map(|_| "<closure>"),
183 )
184 .field(
185 "on_checkpoint",
186 &self.on_checkpoint.as_ref().map(|_| "<closure>"),
187 )
188 .field("parallel_tool_dispatch", &self.parallel_tool_dispatch)
189 .field("cancel_token", &self.cancel_token.is_some())
190 .field("approver", &self.approver.as_ref().map(|_| "<approver>"));
191 #[cfg(feature = "pricing")]
192 s.field("cost_budget", &self.cost_budget.as_ref().map(|b| b.max_usd));
193 s.finish()
194 }
195}
196
197impl Client {
198 #[allow(clippy::too_many_lines)] #[allow(clippy::missing_panics_doc)] pub async fn run(
220 &self,
221 conversation: &mut Conversation,
222 registry: &ToolRegistry,
223 options: RunOptions,
224 ) -> Result<Message> {
225 for iteration in 1..=options.max_iterations {
226 let span = tracing::info_span!("agent_iteration", iteration);
227 let _enter = span.enter();
228
229 if let Some(token) = &options.cancel_token
231 && token.is_cancelled()
232 {
233 tracing::info!(iteration, "claude-api: agent loop cancelled");
234 return Err(Error::Cancelled);
235 }
236
237 conversation.compact_if_needed();
241
242 let mut request = conversation.build_request();
246 request.tools = registry.to_messages_tools();
247
248 let response = self.messages().create(request).await?;
249
250 conversation.usage_history.push(UsageRecord {
252 model: conversation.model.clone(),
253 usage: response.usage.clone(),
254 });
255 conversation
256 .messages
257 .push(MessageInput::assistant(response.content.clone()));
258
259 if let Some(hook) = &options.on_iteration {
260 hook(&response, iteration);
261 }
262
263 #[cfg(feature = "pricing")]
265 if let Some(budget) = &options.cost_budget {
266 let spent = conversation.cost(&budget.pricing);
267 if spent > budget.max_usd {
268 tracing::warn!(
269 iteration,
270 spent_usd = spent,
271 budget_usd = budget.max_usd,
272 "claude-api: agent loop exceeded cost budget",
273 );
274 return Err(Error::CostBudgetExceeded {
275 budget_usd: budget.max_usd,
276 spent_usd: spent,
277 });
278 }
279 }
280
281 if response.stop_reason != Some(StopReason::ToolUse) {
282 if let Some(hook) = &options.on_checkpoint {
283 hook(conversation);
284 }
285 return Ok(response);
286 }
287
288 let tool_uses: Vec<(String, String, serde_json::Value)> = response
290 .content
291 .iter()
292 .filter_map(|b| {
293 if let ContentBlock::Known(KnownBlock::ToolUse { id, name, input }) = b {
294 Some((id.clone(), name.clone(), input.clone()))
295 } else {
296 None
297 }
298 })
299 .collect();
300
301 if tool_uses.is_empty() {
303 return Ok(response);
304 }
305
306 let mut plans: Vec<DispatchPlan> = Vec::with_capacity(tool_uses.len());
317 for (id, name, input) in &tool_uses {
318 let plan = if let Some(approver) = &options.approver {
319 match approver.approve(name, input).await {
320 crate::tool_dispatch::ApprovalDecision::Approve => DispatchPlan::Run {
321 id: id.clone(),
322 name: name.clone(),
323 input: input.clone(),
324 },
325 crate::tool_dispatch::ApprovalDecision::ApproveWithInput(new_input) => {
326 tracing::debug!(
327 tool = %name,
328 "claude-api: approver rewrote tool input"
329 );
330 DispatchPlan::Run {
331 id: id.clone(),
332 name: name.clone(),
333 input: new_input,
334 }
335 }
336 crate::tool_dispatch::ApprovalDecision::Substitute(value) => {
337 tracing::debug!(
338 tool = %name,
339 "claude-api: approver substituted result without dispatch"
340 );
341 DispatchPlan::ResultDirect {
342 id: id.clone(),
343 content: value_to_tool_result(value),
344 is_error: None,
345 }
346 }
347 crate::tool_dispatch::ApprovalDecision::Deny(reason) => {
348 tracing::info!(
349 tool = %name,
350 reason = %reason,
351 "claude-api: approver denied tool dispatch"
352 );
353 DispatchPlan::ResultDirect {
354 id: id.clone(),
355 content: ToolResultContent::Text(reason),
356 is_error: Some(true),
357 }
358 }
359 crate::tool_dispatch::ApprovalDecision::Stop(reason) => {
360 tracing::warn!(
361 tool = %name,
362 reason = %reason,
363 "claude-api: approver stopped the agent loop"
364 );
365 return Err(Error::ToolApprovalStopped {
366 tool_name: name.clone(),
367 reason,
368 });
369 }
370 }
371 } else {
372 DispatchPlan::Run {
373 id: id.clone(),
374 name: name.clone(),
375 input: input.clone(),
376 }
377 };
378 plans.push(plan);
379 }
380
381 let dispatched: Vec<(String, String, Result<serde_json::Value, _>)> =
385 if options.parallel_tool_dispatch {
386 let futures = plans
387 .iter()
388 .filter_map(|p| {
389 if let DispatchPlan::Run { id, name, input } = p {
390 Some((id.clone(), name.clone(), input.clone()))
391 } else {
392 None
393 }
394 })
395 .map(|(id, name, input)| async move {
396 let result = registry.dispatch(&name, input).await;
397 (id, name, result)
398 });
399 futures_util::future::join_all(futures).await
400 } else {
401 let mut out = Vec::new();
402 for p in &plans {
403 if let DispatchPlan::Run { id, name, input } = p {
404 let result = registry.dispatch(name, input.clone()).await;
405 out.push((id.clone(), name.clone(), result));
406 }
407 }
408 out
409 };
410
411 let mut dispatched_iter = dispatched.into_iter();
414 let mut tool_results: Vec<ContentBlock> = Vec::with_capacity(plans.len());
415 for plan in plans {
416 let (id, content, is_error) = match plan {
417 DispatchPlan::ResultDirect {
418 id,
419 content,
420 is_error,
421 } => (id, content, is_error),
422 DispatchPlan::Run { .. } => {
423 let (id, name, result) = dispatched_iter
426 .next()
427 .expect("dispatched/plans length mismatch");
428 match result {
429 Ok(value) => (id, value_to_tool_result(value), None),
430 Err(e) => {
431 tracing::warn!(
432 tool = %name,
433 error = %e,
434 "claude-api: tool dispatch error -- surfacing to model as is_error",
435 );
436 (id, ToolResultContent::Text(format!("{e}")), Some(true))
437 }
438 }
439 }
440 };
441 tool_results.push(ContentBlock::Known(KnownBlock::ToolResult {
442 tool_use_id: id,
443 content,
444 is_error,
445 cache_control: None,
446 }));
447 }
448
449 conversation.messages.push(MessageInput::user(tool_results));
450
451 if let Some(hook) = &options.on_checkpoint {
456 hook(conversation);
457 }
458 }
459
460 Err(Error::MaxIterationsExceeded {
461 max: options.max_iterations,
462 })
463 }
464}
465
466enum DispatchPlan {
470 Run {
471 id: String,
472 name: String,
473 input: serde_json::Value,
474 },
475 ResultDirect {
476 id: String,
477 content: ToolResultContent,
478 is_error: Option<bool>,
479 },
480}
481
482fn value_to_tool_result(value: serde_json::Value) -> ToolResultContent {
483 match value {
486 serde_json::Value::String(s) => ToolResultContent::Text(s),
487 other => ToolResultContent::Text(other.to_string()),
488 }
489}
490
491#[cfg(test)]
492mod tests {
493 use super::*;
494 use crate::conversation::Conversation;
495 use crate::messages::tools::Tool as MessagesTool;
496 use crate::tool_dispatch::ApprovalDecision;
497 use crate::tool_dispatch::tool::ToolError;
498 use crate::types::ModelId;
499 use pretty_assertions::assert_eq;
500 use serde_json::{Value, json};
501 use std::sync::Arc;
502 use std::sync::atomic::{AtomicU32, Ordering};
503 use wiremock::matchers::{body_partial_json, method, path};
504 use wiremock::{Mock, MockServer, ResponseTemplate};
505
506 fn client_for(mock: &MockServer) -> Client {
507 Client::builder()
508 .api_key("sk-ant-test")
509 .base_url(mock.uri())
510 .build()
511 .unwrap()
512 }
513
514 fn echo_registry() -> ToolRegistry {
515 let mut r = ToolRegistry::new();
516 r.register(
517 "echo",
518 json!({"type": "object", "properties": {"text": {"type": "string"}}}),
519 |input| async move { Ok(input) },
520 );
521 r
522 }
523
524 fn assistant_text(text: &str, stop: &str) -> Value {
525 json!({
526 "id": "msg_t",
527 "type": "message",
528 "role": "assistant",
529 "content": [{"type": "text", "text": text}],
530 "model": "claude-sonnet-4-6",
531 "stop_reason": stop,
532 "usage": {"input_tokens": 5, "output_tokens": 3}
533 })
534 }
535
536 #[allow(clippy::needless_pass_by_value)]
537 fn assistant_tool_use(id: &str, name: &str, input: Value) -> Value {
538 json!({
539 "id": "msg_t",
540 "type": "message",
541 "role": "assistant",
542 "content": [
543 {"type": "text", "text": "calling tool"},
544 {"type": "tool_use", "id": id, "name": name, "input": input}
545 ],
546 "model": "claude-sonnet-4-6",
547 "stop_reason": "tool_use",
548 "usage": {"input_tokens": 10, "output_tokens": 5}
549 })
550 }
551
552 #[tokio::test]
553 async fn single_turn_no_tools_returns_immediately() {
554 let mock = MockServer::start().await;
555 Mock::given(method("POST"))
556 .and(path("/v1/messages"))
557 .respond_with(
558 ResponseTemplate::new(200).set_body_json(assistant_text("done", "end_turn")),
559 )
560 .expect(1)
561 .mount(&mock)
562 .await;
563
564 let client = client_for(&mock);
565 let registry = ToolRegistry::new();
566 let mut convo = Conversation::new(ModelId::SONNET_4_6, 64);
567 convo.push_user("hi");
568
569 let resp = client
570 .run(&mut convo, ®istry, RunOptions::default())
571 .await
572 .unwrap();
573 assert_eq!(resp.stop_reason, Some(StopReason::EndTurn));
574 assert_eq!(convo.turn_count(), 1);
575 assert_eq!(convo.messages.len(), 2);
577 }
578
579 #[tokio::test]
580 async fn two_turn_tool_use_loop_completes() {
581 let mock = MockServer::start().await;
582 Mock::given(method("POST"))
584 .and(path("/v1/messages"))
585 .respond_with(ResponseTemplate::new(200).set_body_json(assistant_tool_use(
586 "toolu_1",
587 "echo",
588 json!({"text":"hello"}),
589 )))
590 .up_to_n_times(1)
591 .mount(&mock)
592 .await;
593 Mock::given(method("POST"))
595 .and(path("/v1/messages"))
596 .and(body_partial_json(json!({
597 "messages": [
598 {"role": "user", "content": "say hello"},
599 {"role": "assistant", "content": [
600 {"type": "text", "text": "calling tool"},
601 {"type": "tool_use", "id": "toolu_1", "name": "echo", "input": {"text":"hello"}}
602 ]},
603 {"role": "user", "content": [
604 {"type": "tool_result", "tool_use_id": "toolu_1", "content": "{\"text\":\"hello\"}"}
605 ]}
606 ]
607 })))
608 .respond_with(ResponseTemplate::new(200).set_body_json(assistant_text("said hello!", "end_turn")))
609 .mount(&mock)
610 .await;
611
612 let client = client_for(&mock);
613 let mut convo = Conversation::new(ModelId::SONNET_4_6, 256);
614 convo.push_user("say hello");
615
616 let resp = client
617 .run(&mut convo, &echo_registry(), RunOptions::default())
618 .await
619 .unwrap();
620
621 assert_eq!(resp.stop_reason, Some(StopReason::EndTurn));
622 assert_eq!(convo.turn_count(), 2);
624 assert_eq!(convo.messages.len(), 4);
626 }
627
628 #[tokio::test]
629 async fn max_iterations_returns_error_and_records_each_turn() {
630 let mock = MockServer::start().await;
631 Mock::given(method("POST"))
633 .and(path("/v1/messages"))
634 .respond_with(ResponseTemplate::new(200).set_body_json(assistant_tool_use(
635 "toolu_x",
636 "echo",
637 json!({"text":"x"}),
638 )))
639 .mount(&mock)
640 .await;
641
642 let client = client_for(&mock);
643 let mut convo = Conversation::new(ModelId::SONNET_4_6, 64);
644 convo.push_user("loop");
645
646 let err = client
647 .run(
648 &mut convo,
649 &echo_registry(),
650 RunOptions::default().max_iterations(3),
651 )
652 .await
653 .unwrap_err();
654
655 let Error::MaxIterationsExceeded { max } = err else {
656 panic!("expected MaxIterationsExceeded, got {err:?}");
657 };
658 assert_eq!(max, 3);
659 assert_eq!(convo.turn_count(), 3);
660 assert_eq!(convo.messages.len(), 1 + 3 * 2);
662 }
663
664 #[tokio::test]
665 async fn tool_error_becomes_is_error_tool_result() {
666 let mock = MockServer::start().await;
667 Mock::given(method("POST"))
669 .and(path("/v1/messages"))
670 .respond_with(ResponseTemplate::new(200).set_body_json(assistant_tool_use(
671 "toolu_e",
672 "boom",
673 json!({}),
674 )))
675 .up_to_n_times(1)
676 .mount(&mock)
677 .await;
678 Mock::given(method("POST"))
680 .and(path("/v1/messages"))
681 .and(body_partial_json(json!({
682 "messages": [
683 {"role": "user", "content": "fail"},
684 {"role": "assistant"},
685 {"role": "user", "content": [{
686 "type": "tool_result",
687 "tool_use_id": "toolu_e",
688 "is_error": true
689 }]}
690 ]
691 })))
692 .respond_with(
693 ResponseTemplate::new(200).set_body_json(assistant_text("recovered", "end_turn")),
694 )
695 .mount(&mock)
696 .await;
697
698 let client = client_for(&mock);
699 let mut registry = ToolRegistry::new();
700 registry.register("boom", json!({}), |_input| async move {
701 Err(ToolError::execution(std::io::Error::other("kaboom")))
702 });
703
704 let mut convo = Conversation::new(ModelId::SONNET_4_6, 64);
705 convo.push_user("fail");
706 let resp = client
707 .run(&mut convo, ®istry, RunOptions::default())
708 .await
709 .unwrap();
710 assert_eq!(resp.stop_reason, Some(StopReason::EndTurn));
711 }
712
713 #[tokio::test]
714 async fn unknown_tool_becomes_is_error_with_unknown_message() {
715 let mock = MockServer::start().await;
716 Mock::given(method("POST"))
717 .and(path("/v1/messages"))
718 .respond_with(ResponseTemplate::new(200).set_body_json(assistant_tool_use(
719 "toolu_u",
720 "missing",
721 json!({}),
722 )))
723 .up_to_n_times(1)
724 .mount(&mock)
725 .await;
726 Mock::given(method("POST"))
727 .and(path("/v1/messages"))
728 .respond_with(
729 ResponseTemplate::new(200).set_body_json(assistant_text("ok", "end_turn")),
730 )
731 .mount(&mock)
732 .await;
733
734 let client = client_for(&mock);
735 let mut convo = Conversation::new(ModelId::SONNET_4_6, 64);
736 convo.push_user("call missing");
737
738 let _ = client
739 .run(&mut convo, &ToolRegistry::new(), RunOptions::default())
740 .await
741 .unwrap();
742
743 let user_turn = &convo.messages[2];
745 let serialized = serde_json::to_string(&user_turn.content).unwrap();
746 assert!(
747 serialized.contains("no tool registered with name 'missing'"),
748 "{serialized}"
749 );
750 assert!(serialized.contains("\"is_error\":true"));
751 }
752
753 #[tokio::test]
754 async fn run_uses_registry_tools_not_conversation_tools() {
755 let mock = MockServer::start().await;
759 Mock::given(method("POST"))
760 .and(path("/v1/messages"))
761 .and(body_partial_json(json!({
762 "tools": [{"name": "echo"}]
763 })))
764 .respond_with(
765 ResponseTemplate::new(200).set_body_json(assistant_text("ok", "end_turn")),
766 )
767 .mount(&mock)
768 .await;
769
770 let client = client_for(&mock);
771 let mut convo =
773 Conversation::new(ModelId::SONNET_4_6, 64).with_tools(vec![MessagesTool::Custom(
774 crate::messages::tools::CustomTool::new("stale", json!({"type": "object"})),
775 )]);
776 convo.push_user("hi");
777
778 let _ = client
779 .run(&mut convo, &echo_registry(), RunOptions::default())
780 .await
781 .unwrap();
782 }
783
784 #[tokio::test]
785 async fn on_iteration_callback_fires_per_iteration() {
786 let mock = MockServer::start().await;
787 Mock::given(method("POST"))
788 .and(path("/v1/messages"))
789 .respond_with(ResponseTemplate::new(200).set_body_json(assistant_tool_use(
790 "toolu_h",
791 "echo",
792 json!({"text":"x"}),
793 )))
794 .up_to_n_times(1)
795 .mount(&mock)
796 .await;
797 Mock::given(method("POST"))
798 .and(path("/v1/messages"))
799 .respond_with(
800 ResponseTemplate::new(200).set_body_json(assistant_text("done", "end_turn")),
801 )
802 .mount(&mock)
803 .await;
804
805 let counter = Arc::new(AtomicU32::new(0));
806 let counter_clone = Arc::clone(&counter);
807 let options = RunOptions::default().on_iteration(move |_msg, n| {
808 counter_clone.fetch_add(1, Ordering::SeqCst);
809 assert!(n >= 1);
811 });
812
813 let client = client_for(&mock);
814 let mut convo = Conversation::new(ModelId::SONNET_4_6, 64);
815 convo.push_user("hi");
816
817 let _ = client
818 .run(&mut convo, &echo_registry(), options)
819 .await
820 .unwrap();
821 assert_eq!(counter.load(Ordering::SeqCst), 2);
822 }
823
824 #[tokio::test]
827 async fn parallel_tool_dispatch_runs_concurrently() {
828 let mock = MockServer::start().await;
832 Mock::given(method("POST"))
833 .and(path("/v1/messages"))
834 .respond_with(ResponseTemplate::new(200).set_body_json(json!({
835 "id": "msg_p",
836 "type": "message",
837 "role": "assistant",
838 "content": [
839 {"type": "tool_use", "id": "t1", "name": "slow", "input": {"k": 1}},
840 {"type": "tool_use", "id": "t2", "name": "slow", "input": {"k": 2}},
841 ],
842 "model": "claude-sonnet-4-6",
843 "stop_reason": "tool_use",
844 "usage": {"input_tokens": 10, "output_tokens": 5}
845 })))
846 .up_to_n_times(1)
847 .mount(&mock)
848 .await;
849 Mock::given(method("POST"))
850 .and(path("/v1/messages"))
851 .respond_with(
852 ResponseTemplate::new(200).set_body_json(assistant_text("done", "end_turn")),
853 )
854 .mount(&mock)
855 .await;
856
857 let mut registry = ToolRegistry::new();
858 registry.register("slow", json!({}), |input| async move {
859 tokio::time::sleep(std::time::Duration::from_millis(80)).await;
860 Ok(input)
861 });
862
863 let client = client_for(&mock);
864 let mut convo = Conversation::new(ModelId::SONNET_4_6, 64);
865 convo.push_user("call slow tools");
866
867 let started = std::time::Instant::now();
868 let _ = client
869 .run(&mut convo, ®istry, RunOptions::default())
870 .await
871 .unwrap();
872 let elapsed = started.elapsed();
873
874 assert!(
875 elapsed.as_millis() < 500,
876 "parallel dispatch should be fast; got {elapsed:?}"
877 );
878 assert!(
879 elapsed.as_millis() > 50,
880 "tools didn't actually run; got {elapsed:?}"
881 );
882 }
883
884 #[tokio::test]
885 async fn parallel_dispatch_can_be_disabled() {
886 let mock = MockServer::start().await;
890 Mock::given(method("POST"))
891 .and(path("/v1/messages"))
892 .respond_with(ResponseTemplate::new(200).set_body_json(json!({
893 "id": "msg_seq",
894 "type": "message",
895 "role": "assistant",
896 "content": [
897 {"type": "tool_use", "id": "t1", "name": "echo", "input": {"v": "first"}},
898 {"type": "tool_use", "id": "t2", "name": "echo", "input": {"v": "second"}},
899 ],
900 "model": "claude-sonnet-4-6",
901 "stop_reason": "tool_use",
902 "usage": {"input_tokens": 10, "output_tokens": 5}
903 })))
904 .up_to_n_times(1)
905 .mount(&mock)
906 .await;
907 Mock::given(method("POST"))
908 .and(path("/v1/messages"))
909 .and(body_partial_json(json!({
910 "messages": [
911 {"role": "user"},
912 {"role": "assistant"},
913 {"role": "user", "content": [
914 {"type": "tool_result", "tool_use_id": "t1"},
915 {"type": "tool_result", "tool_use_id": "t2"}
916 ]}
917 ]
918 })))
919 .respond_with(
920 ResponseTemplate::new(200).set_body_json(assistant_text("ok", "end_turn")),
921 )
922 .mount(&mock)
923 .await;
924
925 let mut registry = ToolRegistry::new();
926 registry.register("echo", json!({}), |input| async move { Ok(input) });
927
928 let client = client_for(&mock);
929 let mut convo = Conversation::new(ModelId::SONNET_4_6, 64);
930 convo.push_user("two tools");
931 let _ = client
932 .run(
933 &mut convo,
934 ®istry,
935 RunOptions::default().parallel_tool_dispatch(false),
936 )
937 .await
938 .unwrap();
939 }
940
941 #[cfg(feature = "pricing")]
942 #[tokio::test]
943 async fn cost_budget_aborts_loop_when_exceeded() {
944 let mock = MockServer::start().await;
946 Mock::given(method("POST"))
947 .and(path("/v1/messages"))
948 .respond_with(ResponseTemplate::new(200).set_body_json(json!({
949 "id": "msg_b",
950 "type": "message",
951 "role": "assistant",
952 "content": [
953 {"type": "tool_use", "id": "t1", "name": "noop", "input": {}}
954 ],
955 "model": "claude-sonnet-4-6",
956 "stop_reason": "tool_use",
957 "usage": {"input_tokens": 1_000_000, "output_tokens": 0}
958 })))
959 .mount(&mock)
960 .await;
961
962 let mut registry = ToolRegistry::new();
963 registry.register("noop", json!({}), |_input| async move { Ok(json!({})) });
964
965 let client = client_for(&mock);
966 let mut convo = Conversation::new(ModelId::SONNET_4_6, 64);
967 convo.push_user("burn money");
968
969 let err = client
970 .run(
971 &mut convo,
972 ®istry,
973 RunOptions::default()
974 .max_iterations(8)
975 .cost_budget(1.00, crate::pricing::PricingTable::default()),
976 )
977 .await
978 .unwrap_err();
979 let Error::CostBudgetExceeded {
980 budget_usd,
981 spent_usd,
982 } = err
983 else {
984 panic!("expected CostBudgetExceeded, got {err:?}");
985 };
986 assert!((budget_usd - 1.00).abs() < 1e-9);
988 assert!(
989 spent_usd > 1.00,
990 "spent_usd ({spent_usd}) should exceed budget"
991 );
992 }
993
994 #[tokio::test]
995 async fn cancel_token_aborts_before_first_request() {
996 let mock = MockServer::start().await;
997 Mock::given(method("POST"))
1000 .and(path("/v1/messages"))
1001 .respond_with(
1002 ResponseTemplate::new(200).set_body_json(assistant_text("ok", "end_turn")),
1003 )
1004 .expect(0)
1005 .mount(&mock)
1006 .await;
1007
1008 let token = tokio_util::sync::CancellationToken::new();
1009 token.cancel(); let client = client_for(&mock);
1012 let mut convo = Conversation::new(ModelId::SONNET_4_6, 64);
1013 convo.push_user("hi");
1014
1015 let err = client
1016 .run(
1017 &mut convo,
1018 &ToolRegistry::new(),
1019 RunOptions::default().cancel_token(token),
1020 )
1021 .await
1022 .unwrap_err();
1023 assert!(matches!(err, Error::Cancelled), "got {err:?}");
1024 }
1025
1026 #[tokio::test]
1027 async fn cancel_token_aborts_between_iterations() {
1028 let mock = MockServer::start().await;
1029 Mock::given(method("POST"))
1031 .and(path("/v1/messages"))
1032 .respond_with(ResponseTemplate::new(200).set_body_json(assistant_tool_use(
1033 "t1",
1034 "noop",
1035 json!({}),
1036 )))
1037 .up_to_n_times(1)
1038 .mount(&mock)
1039 .await;
1040 Mock::given(method("POST"))
1042 .and(path("/v1/messages"))
1043 .respond_with(
1044 ResponseTemplate::new(200).set_body_json(assistant_text("won't run", "end_turn")),
1045 )
1046 .expect(0)
1047 .mount(&mock)
1048 .await;
1049
1050 let token = tokio_util::sync::CancellationToken::new();
1051 let token_for_hook = token.clone();
1052
1053 let mut registry = ToolRegistry::new();
1054 registry.register("noop", json!({}), |_| async move { Ok(json!({})) });
1055
1056 let client = client_for(&mock);
1057 let mut convo = Conversation::new(ModelId::SONNET_4_6, 64);
1058 convo.push_user("hi");
1059
1060 let err = client
1061 .run(
1062 &mut convo,
1063 ®istry,
1064 RunOptions::default()
1065 .cancel_token(token)
1066 .on_iteration(move |_msg, _n| token_for_hook.cancel()),
1067 )
1068 .await
1069 .unwrap_err();
1070 assert!(matches!(err, Error::Cancelled), "got {err:?}");
1071 }
1072
1073 #[tokio::test]
1074 async fn on_checkpoint_fires_after_each_tool_result_turn_and_at_finish() {
1075 let mock = MockServer::start().await;
1076 Mock::given(method("POST"))
1078 .and(path("/v1/messages"))
1079 .respond_with(ResponseTemplate::new(200).set_body_json(assistant_tool_use(
1080 "toolu_1",
1081 "echo",
1082 json!({"text": "hi"}),
1083 )))
1084 .up_to_n_times(1)
1085 .mount(&mock)
1086 .await;
1087 Mock::given(method("POST"))
1089 .and(path("/v1/messages"))
1090 .respond_with(
1091 ResponseTemplate::new(200).set_body_json(assistant_text("done", "end_turn")),
1092 )
1093 .mount(&mock)
1094 .await;
1095
1096 let captured: Arc<std::sync::Mutex<Vec<usize>>> =
1097 Arc::new(std::sync::Mutex::new(Vec::new()));
1098 let sink = Arc::clone(&captured);
1099 let opts = RunOptions::default().on_checkpoint(move |c| {
1100 sink.lock().unwrap().push(c.messages.len());
1101 });
1102
1103 let client = client_for(&mock);
1104 let mut convo = Conversation::new(ModelId::SONNET_4_6, 64);
1105 convo.push_user("go");
1106 client
1107 .run(&mut convo, &echo_registry(), opts)
1108 .await
1109 .unwrap();
1110
1111 let snapshots = captured.lock().unwrap();
1117 assert_eq!(*snapshots, vec![3, 4]);
1118 }
1119
1120 #[tokio::test]
1121 async fn on_checkpoint_does_not_fire_when_unset() {
1122 let mock = MockServer::start().await;
1125 Mock::given(method("POST"))
1126 .and(path("/v1/messages"))
1127 .respond_with(
1128 ResponseTemplate::new(200).set_body_json(assistant_text("done", "end_turn")),
1129 )
1130 .mount(&mock)
1131 .await;
1132 let client = client_for(&mock);
1133 let mut convo = Conversation::new(ModelId::SONNET_4_6, 64);
1134 convo.push_user("hi");
1135 client
1136 .run(&mut convo, &ToolRegistry::new(), RunOptions::default())
1137 .await
1138 .unwrap();
1139 }
1140
1141 #[tokio::test]
1142 async fn checkpoint_supports_resume_via_serde() {
1143 let mock = MockServer::start().await;
1149 Mock::given(method("POST"))
1150 .and(path("/v1/messages"))
1151 .respond_with(ResponseTemplate::new(200).set_body_json(assistant_tool_use(
1152 "toolu_1",
1153 "echo",
1154 json!({"text": "first"}),
1155 )))
1156 .up_to_n_times(1)
1157 .mount(&mock)
1158 .await;
1159 Mock::given(method("POST"))
1160 .and(path("/v1/messages"))
1161 .respond_with(
1162 ResponseTemplate::new(200).set_body_json(assistant_text("done", "end_turn")),
1163 )
1164 .mount(&mock)
1165 .await;
1166
1167 let snapshot: Arc<std::sync::Mutex<Option<String>>> = Arc::new(std::sync::Mutex::new(None));
1168 let sink = Arc::clone(&snapshot);
1169 let opts = RunOptions::default()
1170 .max_iterations(1)
1171 .on_checkpoint(move |c| {
1172 *sink.lock().unwrap() = Some(serde_json::to_string(c).unwrap());
1173 });
1174 let client = client_for(&mock);
1175 let mut convo = Conversation::new(ModelId::SONNET_4_6, 64);
1176 convo.push_user("go");
1177
1178 let _ = client.run(&mut convo, &echo_registry(), opts).await;
1182 let json = snapshot.lock().unwrap().clone().expect("checkpoint fired");
1183
1184 drop(convo);
1186 let mut resumed: Conversation = serde_json::from_str(&json).unwrap();
1187 let final_msg = client
1188 .run(
1189 &mut resumed,
1190 &echo_registry(),
1191 RunOptions::default().max_iterations(4),
1192 )
1193 .await
1194 .unwrap();
1195 assert_eq!(final_msg.stop_reason, Some(StopReason::EndTurn));
1196 assert!(resumed.messages.len() >= 4);
1198 }
1199
1200 #[tokio::test]
1201 async fn approver_approve_passes_through_to_dispatch() {
1202 let mock = MockServer::start().await;
1203 Mock::given(method("POST"))
1204 .and(path("/v1/messages"))
1205 .respond_with(ResponseTemplate::new(200).set_body_json(assistant_tool_use(
1206 "toolu_1",
1207 "echo",
1208 json!({"text": "hi"}),
1209 )))
1210 .up_to_n_times(1)
1211 .mount(&mock)
1212 .await;
1213 Mock::given(method("POST"))
1214 .and(path("/v1/messages"))
1215 .and(body_partial_json(json!({
1216 "messages": [
1217 {"role": "user", "content": "go"},
1218 {"role": "assistant"},
1219 {"role": "user", "content": [
1220 {"type": "tool_result", "tool_use_id": "toolu_1", "content": "{\"text\":\"hi\"}"}
1221 ]}
1222 ]
1223 })))
1224 .respond_with(
1225 ResponseTemplate::new(200).set_body_json(assistant_text("done", "end_turn")),
1226 )
1227 .mount(&mock)
1228 .await;
1229
1230 let client = client_for(&mock);
1231 let mut convo = Conversation::new(ModelId::SONNET_4_6, 64);
1232 convo.push_user("go");
1233
1234 let opts = RunOptions::default()
1235 .with_approver_fn(|_name, _input| async { ApprovalDecision::Approve });
1236 let resp = client
1237 .run(&mut convo, &echo_registry(), opts)
1238 .await
1239 .unwrap();
1240 assert_eq!(resp.stop_reason, Some(StopReason::EndTurn));
1241 }
1242
1243 #[tokio::test]
1244 async fn approver_approve_with_input_rewrites_dispatch_payload() {
1245 let mock = MockServer::start().await;
1246 Mock::given(method("POST"))
1247 .and(path("/v1/messages"))
1248 .respond_with(ResponseTemplate::new(200).set_body_json(assistant_tool_use(
1249 "toolu_1",
1250 "echo",
1251 json!({"text": "secret"}),
1252 )))
1253 .up_to_n_times(1)
1254 .mount(&mock)
1255 .await;
1256 Mock::given(method("POST"))
1258 .and(path("/v1/messages"))
1259 .and(body_partial_json(json!({
1260 "messages": [
1261 {"role": "user", "content": "go"},
1262 {"role": "assistant"},
1263 {"role": "user", "content": [
1264 {"type": "tool_result", "tool_use_id": "toolu_1", "content": "{\"text\":\"REDACTED\"}"}
1265 ]}
1266 ]
1267 })))
1268 .respond_with(
1269 ResponseTemplate::new(200).set_body_json(assistant_text("ok", "end_turn")),
1270 )
1271 .mount(&mock)
1272 .await;
1273
1274 let client = client_for(&mock);
1275 let mut convo = Conversation::new(ModelId::SONNET_4_6, 64);
1276 convo.push_user("go");
1277
1278 let opts = RunOptions::default().with_approver_fn(|_name, _input| async {
1279 ApprovalDecision::ApproveWithInput(json!({"text": "REDACTED"}))
1280 });
1281 client
1282 .run(&mut convo, &echo_registry(), opts)
1283 .await
1284 .unwrap();
1285 }
1286
1287 #[tokio::test]
1288 async fn approver_substitute_skips_dispatch_and_returns_value() {
1289 let mut registry = ToolRegistry::new();
1291 registry.register("dangerous", json!({}), |_| async {
1292 panic!("dispatch should have been skipped by Substitute")
1293 });
1294
1295 let mock = MockServer::start().await;
1296 Mock::given(method("POST"))
1297 .and(path("/v1/messages"))
1298 .respond_with(ResponseTemplate::new(200).set_body_json(assistant_tool_use(
1299 "toolu_1",
1300 "dangerous",
1301 json!({"arg": 1}),
1302 )))
1303 .up_to_n_times(1)
1304 .mount(&mock)
1305 .await;
1306 Mock::given(method("POST"))
1307 .and(path("/v1/messages"))
1308 .and(body_partial_json(json!({
1309 "messages": [
1310 {"role": "user", "content": "go"},
1311 {"role": "assistant"},
1312 {"role": "user", "content": [
1313 {"type": "tool_result", "tool_use_id": "toolu_1", "content": "stubbed"}
1314 ]}
1315 ]
1316 })))
1317 .respond_with(
1318 ResponseTemplate::new(200).set_body_json(assistant_text("ok", "end_turn")),
1319 )
1320 .mount(&mock)
1321 .await;
1322
1323 let client = client_for(&mock);
1324 let mut convo = Conversation::new(ModelId::SONNET_4_6, 64);
1325 convo.push_user("go");
1326
1327 let opts = RunOptions::default().with_approver_fn(|_name, _input| async {
1328 ApprovalDecision::Substitute(json!("stubbed"))
1329 });
1330 client.run(&mut convo, ®istry, opts).await.unwrap();
1331 }
1332
1333 #[tokio::test]
1334 async fn approver_deny_returns_is_error_tool_result_and_loop_continues() {
1335 let mock = MockServer::start().await;
1336 Mock::given(method("POST"))
1337 .and(path("/v1/messages"))
1338 .respond_with(ResponseTemplate::new(200).set_body_json(assistant_tool_use(
1339 "toolu_1",
1340 "echo",
1341 json!({"text": "hi"}),
1342 )))
1343 .up_to_n_times(1)
1344 .mount(&mock)
1345 .await;
1346 Mock::given(method("POST"))
1347 .and(path("/v1/messages"))
1348 .and(body_partial_json(json!({
1349 "messages": [
1350 {"role": "user", "content": "go"},
1351 {"role": "assistant"},
1352 {"role": "user", "content": [
1353 {"type": "tool_result", "tool_use_id": "toolu_1", "content": "policy violation: no echo today", "is_error": true}
1354 ]}
1355 ]
1356 })))
1357 .respond_with(
1358 ResponseTemplate::new(200).set_body_json(assistant_text("ack", "end_turn")),
1359 )
1360 .mount(&mock)
1361 .await;
1362
1363 let client = client_for(&mock);
1364 let mut convo = Conversation::new(ModelId::SONNET_4_6, 64);
1365 convo.push_user("go");
1366
1367 let opts = RunOptions::default().with_approver_fn(|_name, _input| async {
1368 ApprovalDecision::Deny("policy violation: no echo today".into())
1369 });
1370 let resp = client
1371 .run(&mut convo, &echo_registry(), opts)
1372 .await
1373 .unwrap();
1374 assert_eq!(resp.stop_reason, Some(StopReason::EndTurn));
1375 }
1376
1377 #[tokio::test]
1378 async fn approver_stop_aborts_loop_with_typed_error() {
1379 let mock = MockServer::start().await;
1380 Mock::given(method("POST"))
1381 .and(path("/v1/messages"))
1382 .respond_with(ResponseTemplate::new(200).set_body_json(assistant_tool_use(
1383 "toolu_1",
1384 "echo",
1385 json!({"text": "hi"}),
1386 )))
1387 .mount(&mock)
1388 .await;
1389
1390 let client = client_for(&mock);
1391 let mut convo = Conversation::new(ModelId::SONNET_4_6, 64);
1392 convo.push_user("go");
1393
1394 let opts = RunOptions::default().with_approver_fn(|_name, _input| async {
1395 ApprovalDecision::Stop("user cancelled".into())
1396 });
1397 let err = client
1398 .run(&mut convo, &echo_registry(), opts)
1399 .await
1400 .unwrap_err();
1401 match err {
1402 Error::ToolApprovalStopped { tool_name, reason } => {
1403 assert_eq!(tool_name, "echo");
1404 assert_eq!(reason, "user cancelled");
1405 }
1406 other => panic!("expected ToolApprovalStopped, got {other:?}"),
1407 }
1408 }
1409
1410 #[tokio::test]
1411 async fn approver_per_call_decision_can_mix_approve_and_deny() {
1412 let mock = MockServer::start().await;
1414 let dual_tool_use = json!({
1415 "id": "msg_t",
1416 "type": "message",
1417 "role": "assistant",
1418 "content": [
1419 {"type": "tool_use", "id": "toolu_1", "name": "echo", "input": {"text": "ok"}},
1420 {"type": "tool_use", "id": "toolu_2", "name": "echo", "input": {"text": "block"}},
1421 ],
1422 "model": "claude-sonnet-4-6",
1423 "stop_reason": "tool_use",
1424 "usage": {"input_tokens": 10, "output_tokens": 5}
1425 });
1426 Mock::given(method("POST"))
1427 .and(path("/v1/messages"))
1428 .respond_with(ResponseTemplate::new(200).set_body_json(dual_tool_use))
1429 .up_to_n_times(1)
1430 .mount(&mock)
1431 .await;
1432 Mock::given(method("POST"))
1433 .and(path("/v1/messages"))
1434 .respond_with(
1435 ResponseTemplate::new(200).set_body_json(assistant_text("done", "end_turn")),
1436 )
1437 .mount(&mock)
1438 .await;
1439
1440 let client = client_for(&mock);
1441 let mut convo = Conversation::new(ModelId::SONNET_4_6, 64);
1442 convo.push_user("go");
1443
1444 let opts = RunOptions::default().with_approver_fn(|_name, input| {
1445 let blocked = input.get("text").and_then(Value::as_str) == Some("block");
1446 async move {
1447 if blocked {
1448 ApprovalDecision::Deny("blocked".into())
1449 } else {
1450 ApprovalDecision::Approve
1451 }
1452 }
1453 });
1454 client
1455 .run(&mut convo, &echo_registry(), opts)
1456 .await
1457 .unwrap();
1458
1459 let tool_result_turn = &convo.messages[2];
1463 let serialized = serde_json::to_value(tool_result_turn).unwrap();
1464 let blocks = serialized
1465 .get("content")
1466 .and_then(Value::as_array)
1467 .expect("content array");
1468 assert_eq!(blocks.len(), 2);
1469 assert_eq!(blocks[0]["tool_use_id"], "toolu_1");
1470 assert!(blocks[0].get("is_error").is_none());
1471 assert_eq!(blocks[1]["tool_use_id"], "toolu_2");
1472 assert_eq!(blocks[1]["is_error"], true);
1473 assert_eq!(blocks[1]["content"], "blocked");
1474 }
1475
/// A JSON string returned by a tool is forwarded as text verbatim, not
/// re-encoded as a quoted JSON string literal.
///
/// `value_to_tool_result` is synchronous, so this is a plain `#[test]`; no
/// async runtime is needed.
#[test]
fn tool_returning_string_value_passes_through_cleanly() {
    let result = value_to_tool_result(json!("plain text"));
    let ToolResultContent::Text(t) = result else {
        panic!("expected Text");
    };
    assert_eq!(t, "plain text");
}
1485
/// A non-string tool result becomes text holding its JSON serialization.
/// The test parses the text back and compares values, so it does not depend
/// on exact formatting such as whitespace or key order.
///
/// `value_to_tool_result` is synchronous, so this is a plain `#[test]`; no
/// async runtime is needed.
#[test]
fn tool_returning_object_value_serializes_to_json_string() {
    let result = value_to_tool_result(json!({"k": 42}));
    let ToolResultContent::Text(t) = result else {
        panic!("expected Text");
    };
    let parsed: Value = serde_json::from_str(&t).unwrap();
    assert_eq!(parsed, json!({"k": 42}));
}
1496}