1pub mod audit;
4pub mod batch;
5pub mod blackboard;
6pub(crate) mod blackboard_tools;
7mod builder;
8pub mod cache;
9pub mod context;
10pub mod dag;
11pub mod debate;
12mod doom_loop;
13pub mod evaluator;
14pub mod events;
15pub mod guardrail;
16pub mod guardrails;
17pub mod handoff;
18pub mod instructions;
19pub mod mixture;
20pub mod observability;
21pub mod orchestrator;
22pub mod permission;
23pub mod prompts;
24pub mod pruner;
25pub mod routing;
26mod runner;
27pub mod tenant_tracker;
28pub mod token_estimator;
29pub mod tool_filter;
30pub mod voting;
31pub mod workflow;
32
33#[cfg(test)]
34pub(crate) mod test_helpers;
35
36pub use builder::AgentRunnerBuilder;
38pub use runner::{AgentOutput, AgentRunner, OnInput};
39#[cfg(test)]
41use crate::error::Error;
42#[cfg(test)]
43use crate::llm::LlmProvider;
44#[cfg(test)]
45use crate::llm::types::{Message, ToolCall, ToolDefinition};
46#[cfg(test)]
47use crate::tool::{Tool, ToolOutput};
48#[cfg(test)]
49use audit::AuditTrail;
50#[cfg(test)]
51use context::ContextStrategy;
52#[cfg(test)]
53use doom_loop::DoomLoopTracker;
54#[cfg(test)]
55use events::{AgentEvent, OnEvent};
56#[cfg(test)]
57use std::sync::Arc;
58#[cfg(test)]
59use std::time::Duration;
60
61#[cfg(test)]
62mod tests {
63 use super::*;
64 use crate::llm::types::{
65 CompletionRequest, CompletionResponse, ContentBlock, StopReason, TokenUsage,
66 };
67 use serde_json::json;
68 use std::sync::Mutex;
69
    /// Scripted LLM provider for tests: hands out pre-built responses one at a time.
    struct MockProvider {
        /// Remaining scripted responses; the front element is returned by the next `complete` call.
        responses: Mutex<Vec<CompletionResponse>>,
    }
75
76 impl MockProvider {
77 fn new(responses: Vec<CompletionResponse>) -> Self {
78 Self {
79 responses: Mutex::new(responses),
80 }
81 }
82 }
83
84 impl LlmProvider for MockProvider {
85 async fn complete(&self, _request: CompletionRequest) -> Result<CompletionResponse, Error> {
86 let mut responses = self.responses.lock().expect("mock lock poisoned");
87 if responses.is_empty() {
88 return Err(Error::Agent("no more mock responses".into()));
89 }
90 Ok(responses.remove(0))
91 }
92 }
93
    /// Test tool that ignores its input and always yields the same canned text.
    struct MockTool {
        /// Definition returned verbatim from `Tool::definition`.
        def: ToolDefinition,
        /// Text placed in the tool output (success payload or error message).
        response: String,
        /// When `true`, `execute` reports `response` as an error output instead of a success.
        is_error: bool,
    }
101
102 impl MockTool {
103 fn new(name: &str, response: &str) -> Self {
104 Self {
105 def: ToolDefinition {
106 name: name.into(),
107 description: format!("Mock tool {name}"),
108 input_schema: json!({"type": "object"}),
109 },
110 response: response.into(),
111 is_error: false,
112 }
113 }
114
115 fn failing(name: &str, error_msg: &str) -> Self {
116 Self {
117 def: ToolDefinition {
118 name: name.into(),
119 description: format!("Failing mock tool {name}"),
120 input_schema: json!({"type": "object"}),
121 },
122 response: error_msg.into(),
123 is_error: true,
124 }
125 }
126 }
127
128 impl Tool for MockTool {
129 fn definition(&self) -> ToolDefinition {
130 self.def.clone()
131 }
132
133 fn execute(
134 &self,
135 _ctx: &crate::ExecutionContext,
136 _input: serde_json::Value,
137 ) -> std::pin::Pin<
138 Box<dyn std::future::Future<Output = Result<ToolOutput, Error>> + Send + '_>,
139 > {
140 let response = self.response.clone();
141 let is_error = self.is_error;
142 Box::pin(async move {
143 if is_error {
144 Ok(ToolOutput::error(response))
145 } else {
146 Ok(ToolOutput::success(response))
147 }
148 })
149 }
150 }
151
152 #[tokio::test]
153 async fn agent_returns_text_on_end_turn() {
154 let provider = Arc::new(MockProvider::new(vec![CompletionResponse {
155 content: vec![ContentBlock::Text {
156 text: "Hello!".into(),
157 }],
158 stop_reason: StopReason::EndTurn,
159 usage: TokenUsage {
160 input_tokens: 10,
161 output_tokens: 5,
162 ..Default::default()
163 },
164 model: None,
165 }]));
166
167 let runner = AgentRunner::builder(provider)
168 .name("test")
169 .system_prompt("You are helpful.")
170 .build()
171 .unwrap();
172
173 let output = runner.execute("say hello").await.unwrap();
174 assert_eq!(output.result, "Hello!");
175 assert_eq!(output.tool_calls_made, 0);
176 assert_eq!(output.tokens_used.input_tokens, 10);
177 }
178
179 #[tokio::test]
180 async fn estimated_cost_usd_populated_for_known_model() {
181 struct CostMockProvider;
183 impl LlmProvider for CostMockProvider {
184 async fn complete(
185 &self,
186 _request: CompletionRequest,
187 ) -> Result<CompletionResponse, Error> {
188 Ok(CompletionResponse {
189 content: vec![ContentBlock::Text {
190 text: "response".into(),
191 }],
192 stop_reason: StopReason::EndTurn,
193 usage: TokenUsage {
194 input_tokens: 1000,
195 output_tokens: 500,
196 ..Default::default()
197 },
198 model: None,
199 })
200 }
201 fn model_name(&self) -> Option<&str> {
202 Some("claude-sonnet-4-20250514")
203 }
204 }
205
206 let provider = Arc::new(CostMockProvider);
207 let runner = AgentRunner::builder(provider)
208 .name("cost-test")
209 .system_prompt("sys")
210 .build()
211 .unwrap();
212
213 let output = runner.execute("task").await.unwrap();
214 assert!(
215 output.estimated_cost_usd.is_some(),
216 "expected cost estimate for known model"
217 );
218 let cost = output.estimated_cost_usd.unwrap();
219 assert!(
221 (cost - 0.0105).abs() < 0.001,
222 "expected ~$0.0105, got: {cost}"
223 );
224 }
225
226 #[tokio::test]
227 async fn estimated_cost_usd_none_for_unknown_model() {
228 let provider = Arc::new(MockProvider::new(vec![CompletionResponse {
229 content: vec![ContentBlock::Text { text: "hi".into() }],
230 stop_reason: StopReason::EndTurn,
231 usage: TokenUsage::default(),
232 model: None,
233 }]));
234
235 let runner = AgentRunner::builder(provider)
236 .name("test")
237 .system_prompt("sys")
238 .build()
239 .unwrap();
240
241 let output = runner.execute("task").await.unwrap();
242 assert!(
243 output.estimated_cost_usd.is_none(),
244 "expected None for mock provider without model_name"
245 );
246 }
247
248 #[tokio::test]
249 async fn agent_executes_tool_and_continues() {
250 let provider = Arc::new(MockProvider::new(vec![
251 CompletionResponse {
252 content: vec![ContentBlock::ToolUse {
253 id: "call-1".into(),
254 name: "search".into(),
255 input: json!({"q": "rust"}),
256 }],
257 stop_reason: StopReason::ToolUse,
258 usage: TokenUsage {
259 input_tokens: 20,
260 output_tokens: 10,
261 ..Default::default()
262 },
263 model: None,
264 },
265 CompletionResponse {
266 content: vec![ContentBlock::Text {
267 text: "Found it!".into(),
268 }],
269 stop_reason: StopReason::EndTurn,
270 usage: TokenUsage {
271 input_tokens: 30,
272 output_tokens: 15,
273 ..Default::default()
274 },
275 model: None,
276 },
277 ]));
278
279 let runner = AgentRunner::builder(provider)
280 .name("test")
281 .system_prompt("You are helpful.")
282 .tool(Arc::new(MockTool::new("search", "search results here")))
283 .build()
284 .unwrap();
285
286 let output = runner.execute("find rust info").await.unwrap();
287 assert_eq!(output.result, "Found it!");
288 assert_eq!(output.tool_calls_made, 1);
289 assert_eq!(output.tokens_used.input_tokens, 50);
290 assert_eq!(output.tokens_used.output_tokens, 25);
291 }
292
293 #[tokio::test]
294 async fn agent_errors_on_max_turns() {
295 let provider = Arc::new(MockProvider::new(vec![
296 CompletionResponse {
297 content: vec![ContentBlock::ToolUse {
298 id: "c1".into(),
299 name: "search".into(),
300 input: json!({}),
301 }],
302 stop_reason: StopReason::ToolUse,
303 usage: TokenUsage::default(),
304 model: None,
305 },
306 CompletionResponse {
307 content: vec![ContentBlock::ToolUse {
308 id: "c2".into(),
309 name: "search".into(),
310 input: json!({}),
311 }],
312 stop_reason: StopReason::ToolUse,
313 usage: TokenUsage::default(),
314 model: None,
315 },
316 ]));
317
318 let runner = AgentRunner::builder(provider)
319 .name("test")
320 .system_prompt("sys")
321 .tool(Arc::new(MockTool::new("search", "result")))
322 .max_turns(2)
323 .build()
324 .unwrap();
325
326 let err = runner.execute("loop forever").await.unwrap_err();
327 assert!(
328 matches!(
329 err,
330 Error::WithPartialUsage {
331 ref source,
332 ..
333 } if matches!(**source, Error::MaxTurnsExceeded(2))
334 ),
335 "expected MaxTurnsExceeded(2), got: {err:?}"
336 );
337 }
338
339 #[tokio::test]
340 async fn agent_error_carries_partial_token_usage() {
341 let provider = Arc::new(MockProvider::new(vec![
343 CompletionResponse {
345 content: vec![ContentBlock::ToolUse {
346 id: "c1".into(),
347 name: "search".into(),
348 input: json!({}),
349 }],
350 stop_reason: StopReason::ToolUse,
351 usage: TokenUsage {
352 input_tokens: 100,
353 output_tokens: 50,
354 cache_creation_input_tokens: 30,
355 cache_read_input_tokens: 0,
356 reasoning_tokens: 0,
357 },
358 model: None,
359 },
360 CompletionResponse {
362 content: vec![ContentBlock::ToolUse {
363 id: "c2".into(),
364 name: "search".into(),
365 input: json!({}),
366 }],
367 stop_reason: StopReason::ToolUse,
368 usage: TokenUsage {
369 input_tokens: 120,
370 output_tokens: 60,
371 cache_creation_input_tokens: 0,
372 cache_read_input_tokens: 25,
373 reasoning_tokens: 0,
374 },
375 model: None,
376 },
377 ]));
378
379 let runner = AgentRunner::builder(provider)
380 .name("test")
381 .system_prompt("sys")
382 .tool(Arc::new(MockTool::new("search", "result")))
383 .max_turns(2)
384 .build()
385 .unwrap();
386
387 let err = runner.execute("loop forever").await.unwrap_err();
388 let partial = err.partial_usage();
389 assert_eq!(partial.input_tokens, 220, "100 + 120");
390 assert_eq!(partial.output_tokens, 110, "50 + 60");
391 assert_eq!(partial.cache_creation_input_tokens, 30);
392 assert_eq!(partial.cache_read_input_tokens, 25);
393 }
394
395 #[tokio::test]
396 async fn agent_returns_error_for_unknown_tool() {
397 let provider = Arc::new(MockProvider::new(vec![
399 CompletionResponse {
400 content: vec![ContentBlock::ToolUse {
401 id: "c1".into(),
402 name: "nonexistent".into(),
403 input: json!({}),
404 }],
405 stop_reason: StopReason::ToolUse,
406 usage: TokenUsage::default(),
407 model: None,
408 },
409 CompletionResponse {
410 content: vec![ContentBlock::Text {
411 text: "Sorry about that.".into(),
412 }],
413 stop_reason: StopReason::EndTurn,
414 usage: TokenUsage::default(),
415 model: None,
416 },
417 ]));
418
419 let runner = AgentRunner::builder(provider)
420 .name("test")
421 .system_prompt("sys")
422 .build()
423 .unwrap();
424
425 let output = runner.execute("use unknown tool").await.unwrap();
427 assert_eq!(output.result, "Sorry about that.");
428 assert_eq!(output.tool_calls_made, 1);
429 }
430
431 #[tokio::test]
432 async fn agent_executes_parallel_tool_calls() {
433 let provider = Arc::new(MockProvider::new(vec![
434 CompletionResponse {
435 content: vec![
436 ContentBlock::ToolUse {
437 id: "c1".into(),
438 name: "search".into(),
439 input: json!({"q": "a"}),
440 },
441 ContentBlock::ToolUse {
442 id: "c2".into(),
443 name: "read".into(),
444 input: json!({"path": "/tmp"}),
445 },
446 ],
447 stop_reason: StopReason::ToolUse,
448 usage: TokenUsage::default(),
449 model: None,
450 },
451 CompletionResponse {
452 content: vec![ContentBlock::Text {
453 text: "Done!".into(),
454 }],
455 stop_reason: StopReason::EndTurn,
456 usage: TokenUsage::default(),
457 model: None,
458 },
459 ]));
460
461 let runner = AgentRunner::builder(provider)
462 .name("test")
463 .system_prompt("sys")
464 .tool(Arc::new(MockTool::new("search", "found")))
465 .tool(Arc::new(MockTool::new("read", "file content")))
466 .build()
467 .unwrap();
468
469 let output = runner.execute("do both").await.unwrap();
470 assert_eq!(output.result, "Done!");
471 assert_eq!(output.tool_calls_made, 2);
472 }
473
474 #[tokio::test]
475 async fn agent_errors_on_max_tokens() {
476 let provider = Arc::new(MockProvider::new(vec![CompletionResponse {
477 content: vec![ContentBlock::Text {
478 text: "truncated...".into(),
479 }],
480 stop_reason: StopReason::MaxTokens,
481 usage: TokenUsage::default(),
482 model: None,
483 }]));
484
485 let runner = AgentRunner::builder(provider)
486 .name("test")
487 .system_prompt("sys")
488 .build()
489 .unwrap();
490
491 let err = runner.execute("write a long essay").await.unwrap_err();
492 assert!(
493 matches!(
494 err,
495 Error::WithPartialUsage {
496 ref source,
497 ..
498 } if matches!(**source, Error::Truncated)
499 ),
500 "expected Truncated, got: {err:?}"
501 );
502 }
503
504 #[tokio::test]
505 async fn agent_handles_tool_error_result() {
506 let provider = Arc::new(MockProvider::new(vec![
507 CompletionResponse {
508 content: vec![ContentBlock::ToolUse {
509 id: "c1".into(),
510 name: "failing".into(),
511 input: json!({}),
512 }],
513 stop_reason: StopReason::ToolUse,
514 usage: TokenUsage::default(),
515 model: None,
516 },
517 CompletionResponse {
518 content: vec![ContentBlock::Text {
519 text: "Tool failed, but I recovered.".into(),
520 }],
521 stop_reason: StopReason::EndTurn,
522 usage: TokenUsage::default(),
523 model: None,
524 },
525 ]));
526
527 let runner = AgentRunner::builder(provider)
528 .name("test")
529 .system_prompt("sys")
530 .tool(Arc::new(MockTool::failing("failing", "something broke")))
531 .build()
532 .unwrap();
533
534 let output = runner.execute("try the tool").await.unwrap();
535 assert_eq!(output.result, "Tool failed, but I recovered.");
536 }
537
538 #[tokio::test]
539 async fn max_tokens_is_configurable() {
540 let provider = Arc::new(MockProvider::new(vec![CompletionResponse {
541 content: vec![ContentBlock::Text { text: "ok".into() }],
542 stop_reason: StopReason::EndTurn,
543 usage: TokenUsage::default(),
544 model: None,
545 }]));
546
547 let runner = AgentRunner::builder(provider)
548 .name("test")
549 .system_prompt("sys")
550 .max_tokens(8192)
551 .build()
552 .unwrap();
553
554 let output = runner.execute("test").await.unwrap();
556 assert_eq!(output.result, "ok");
557 }
558
559 #[test]
560 fn build_errors_on_explicit_empty_name() {
561 let provider = Arc::new(MockProvider::new(vec![]));
562 let result = AgentRunner::builder(provider)
563 .name("")
564 .system_prompt("sys")
565 .build();
566 assert!(result.is_err());
567 let err = result.err().unwrap();
568 assert!(
569 err.to_string().contains("agent name must not be empty"),
570 "error: {err}"
571 );
572 }
573
574 #[test]
575 fn build_succeeds_with_default_name() {
576 let provider = Arc::new(MockProvider::new(vec![]));
577 let runner = AgentRunner::builder(provider)
578 .system_prompt("sys")
579 .build()
580 .expect("minimal builder chain must succeed without an explicit name");
581 assert_eq!(runner.name(), "agent");
582 }
583
584 #[test]
585 fn build_errors_on_zero_max_turns() {
586 let provider = Arc::new(MockProvider::new(vec![]));
587 let result = AgentRunner::builder(provider)
588 .name("test")
589 .system_prompt("sys")
590 .max_turns(0)
591 .build();
592 assert!(result.is_err());
593 let err = result.err().unwrap();
594 assert!(
595 err.to_string().contains("max_turns must be at least 1"),
596 "error: {err}"
597 );
598 }
599
600 #[test]
601 fn build_errors_on_zero_max_tokens() {
602 let provider = Arc::new(MockProvider::new(vec![]));
603 let result = AgentRunner::builder(provider)
604 .name("test")
605 .system_prompt("sys")
606 .max_tokens(0)
607 .build();
608 assert!(result.is_err());
609 let err = result.err().unwrap();
610 assert!(
611 err.to_string().contains("max_tokens must be at least 1"),
612 "error: {err}"
613 );
614 }
615
616 #[test]
617 fn build_errors_on_sliding_window_with_summarize_threshold() {
618 let provider = Arc::new(MockProvider::new(vec![]));
619 let result = AgentRunner::builder(provider)
620 .name("test")
621 .system_prompt("sys")
622 .context_strategy(ContextStrategy::SlidingWindow { max_tokens: 50000 })
623 .summarize_threshold(8000)
624 .build();
625 assert!(result.is_err());
626 let err = result.err().unwrap();
627 assert!(
628 err.to_string()
629 .contains("cannot use summarize_threshold with SlidingWindow"),
630 "error: {err}"
631 );
632 }
633
634 #[test]
635 fn build_errors_on_on_input_with_structured_schema() {
636 let provider = Arc::new(MockProvider::new(vec![]));
637 let on_input: Arc<OnInput> = Arc::new(|| Box::pin(async { None }));
638 let result = AgentRunner::builder(provider)
639 .name("test")
640 .system_prompt("sys")
641 .on_input(on_input)
642 .structured_schema(serde_json::json!({"type": "object"}))
643 .build();
644 assert!(result.is_err());
645 let err = result.err().unwrap();
646 assert!(
647 err.to_string().contains(
648 "on_input (interactive mode) and structured_schema are mutually exclusive"
649 ),
650 "error: {err}"
651 );
652 }
653
654 #[tokio::test]
655 async fn instruction_text_prepended_to_system_prompt() {
656 struct CapturingProvider {
658 captured_system: Mutex<Option<String>>,
659 }
660 impl LlmProvider for CapturingProvider {
661 async fn complete(
662 &self,
663 request: CompletionRequest,
664 ) -> Result<CompletionResponse, Error> {
665 *self.captured_system.lock().expect("lock") = Some(request.system.clone());
666 Ok(CompletionResponse {
667 content: vec![ContentBlock::Text {
668 text: "done".into(),
669 }],
670 stop_reason: StopReason::EndTurn,
671 usage: TokenUsage::default(),
672 model: None,
673 })
674 }
675 }
676
677 let provider = Arc::new(CapturingProvider {
678 captured_system: Mutex::new(None),
679 });
680 let runner = AgentRunner::builder(provider.clone())
681 .name("test")
682 .system_prompt("You are an agent.")
683 .instruction_text("Be careful with files.")
684 .build()
685 .unwrap();
686 let _output = runner.execute("task").await.unwrap();
687 let system = provider
688 .captured_system
689 .lock()
690 .expect("lock")
691 .clone()
692 .expect("system prompt should have been captured");
693 assert!(
694 system.contains("# Project Instructions"),
695 "system prompt should contain instruction header: {system}"
696 );
697 assert!(
698 system.contains("Be careful with files."),
699 "system prompt should contain instruction text: {system}"
700 );
701 assert!(
702 system.contains("You are an agent."),
703 "system prompt should contain original prompt: {system}"
704 );
705 let instruction_pos = system.find("Be careful with files.").unwrap();
707 let prompt_pos = system.find("You are an agent.").unwrap();
708 assert!(
709 instruction_pos < prompt_pos,
710 "instructions should precede the original system prompt"
711 );
712 }
713
714 #[test]
715 fn instruction_text_empty_is_noop() {
716 let provider = Arc::new(MockProvider::new(vec![CompletionResponse {
717 content: vec![ContentBlock::Text {
718 text: "done".into(),
719 }],
720 stop_reason: StopReason::EndTurn,
721 usage: TokenUsage::default(),
722 model: None,
723 }]));
724 let builder = AgentRunner::builder(provider)
726 .name("test")
727 .system_prompt("You are an agent.")
728 .instruction_text(""); assert!(
731 builder.instruction_text.is_none(),
732 "empty instruction text should not be stored"
733 );
734 let _runner = builder.build().unwrap();
735 }
736
737 #[tokio::test]
738 async fn context_strategy_builder_sets_sliding_window() {
739 let provider = Arc::new(MockProvider::new(vec![CompletionResponse {
740 content: vec![ContentBlock::Text { text: "ok".into() }],
741 stop_reason: StopReason::EndTurn,
742 usage: TokenUsage::default(),
743 model: None,
744 }]));
745
746 let runner = AgentRunner::builder(provider)
747 .name("test")
748 .system_prompt("sys")
749 .context_strategy(ContextStrategy::SlidingWindow { max_tokens: 50000 })
750 .build()
751 .unwrap();
752
753 assert_eq!(
754 runner.context_strategy,
755 ContextStrategy::SlidingWindow { max_tokens: 50000 }
756 );
757 }
758
759 #[tokio::test]
760 async fn agent_uses_stream_complete_when_on_text_set() {
761 use std::sync::atomic::{AtomicBool, Ordering};
762
763 struct StreamTrackingProvider {
764 stream_called: Arc<AtomicBool>,
765 }
766
767 impl LlmProvider for StreamTrackingProvider {
768 async fn complete(
769 &self,
770 _request: CompletionRequest,
771 ) -> Result<CompletionResponse, Error> {
772 Ok(CompletionResponse {
773 content: vec![ContentBlock::Text {
774 text: "non-stream".into(),
775 }],
776 stop_reason: StopReason::EndTurn,
777 usage: TokenUsage::default(),
778 model: None,
779 })
780 }
781
782 async fn stream_complete(
783 &self,
784 _request: CompletionRequest,
785 on_text: &crate::llm::OnText,
786 ) -> Result<CompletionResponse, Error> {
787 self.stream_called.store(true, Ordering::SeqCst);
788 on_text("streamed ");
789 on_text("text");
790 Ok(CompletionResponse {
791 content: vec![ContentBlock::Text {
792 text: "streamed text".into(),
793 }],
794 stop_reason: StopReason::EndTurn,
795 usage: TokenUsage::default(),
796 model: None,
797 })
798 }
799 }
800
801 let stream_called = Arc::new(AtomicBool::new(false));
802 let provider = Arc::new(StreamTrackingProvider {
803 stream_called: stream_called.clone(),
804 });
805
806 let received = Arc::new(Mutex::new(Vec::<String>::new()));
807 let received_clone = received.clone();
808 let callback: Arc<crate::llm::OnText> = Arc::new(move |text: &str| {
809 received_clone.lock().expect("lock").push(text.to_string());
810 });
811
812 let runner = AgentRunner::builder(provider)
813 .name("test")
814 .system_prompt("sys")
815 .on_text(callback)
816 .build()
817 .unwrap();
818
819 let output = runner.execute("test").await.unwrap();
820 assert!(
821 stream_called.load(Ordering::SeqCst),
822 "stream_complete should have been called"
823 );
824 assert_eq!(output.result, "streamed text");
825
826 let texts = received.lock().expect("lock");
827 assert_eq!(*texts, vec!["streamed ", "text"]);
828 }
829
830 #[tokio::test]
831 async fn context_strategy_defaults_to_unlimited() {
832 let provider = Arc::new(MockProvider::new(vec![CompletionResponse {
833 content: vec![ContentBlock::Text { text: "ok".into() }],
834 stop_reason: StopReason::EndTurn,
835 usage: TokenUsage::default(),
836 model: None,
837 }]));
838
839 let runner = AgentRunner::builder(provider)
840 .name("test")
841 .system_prompt("sys")
842 .build()
843 .unwrap();
844
845 assert_eq!(runner.context_strategy, ContextStrategy::Unlimited);
846 }
847
848 #[tokio::test]
849 async fn approval_callback_approves_tool_execution() {
850 use std::sync::atomic::{AtomicBool, Ordering};
851
852 let approved = Arc::new(AtomicBool::new(false));
853 let approved_clone = approved.clone();
854
855 let provider = Arc::new(MockProvider::new(vec![
856 CompletionResponse {
857 content: vec![ContentBlock::ToolUse {
858 id: "c1".into(),
859 name: "search".into(),
860 input: json!({"q": "rust"}),
861 }],
862 stop_reason: StopReason::ToolUse,
863 usage: TokenUsage::default(),
864 model: None,
865 },
866 CompletionResponse {
867 content: vec![ContentBlock::Text {
868 text: "Found it!".into(),
869 }],
870 stop_reason: StopReason::EndTurn,
871 usage: TokenUsage::default(),
872 model: None,
873 },
874 ]));
875
876 let callback: Arc<crate::llm::OnApproval> = Arc::new(move |_calls| {
877 approved_clone.store(true, Ordering::SeqCst);
878 crate::llm::ApprovalDecision::Allow
879 });
880
881 let runner = AgentRunner::builder(provider)
882 .name("test")
883 .system_prompt("sys")
884 .tool(Arc::new(MockTool::new("search", "results")))
885 .on_approval(callback)
886 .build()
887 .unwrap();
888
889 let output = runner.execute("test").await.unwrap();
890 assert!(
891 approved.load(Ordering::SeqCst),
892 "approval callback was called"
893 );
894 assert_eq!(output.result, "Found it!");
895 assert_eq!(output.tool_calls_made, 1);
896 }
897
898 #[tokio::test]
899 async fn approval_callback_denies_tool_execution() {
900 let provider = Arc::new(MockProvider::new(vec![
901 CompletionResponse {
902 content: vec![ContentBlock::ToolUse {
903 id: "c1".into(),
904 name: "search".into(),
905 input: json!({"q": "rust"}),
906 }],
907 stop_reason: StopReason::ToolUse,
908 usage: TokenUsage::default(),
909 model: None,
910 },
911 CompletionResponse {
913 content: vec![ContentBlock::Text {
914 text: "I understand, I won't execute that.".into(),
915 }],
916 stop_reason: StopReason::EndTurn,
917 usage: TokenUsage::default(),
918 model: None,
919 },
920 ]));
921
922 let callback: Arc<crate::llm::OnApproval> =
923 Arc::new(|_calls| crate::llm::ApprovalDecision::Deny);
924
925 let runner = AgentRunner::builder(provider)
926 .name("test")
927 .system_prompt("sys")
928 .tool(Arc::new(MockTool::new("search", "results")))
929 .on_approval(callback)
930 .build()
931 .unwrap();
932
933 let output = runner.execute("test").await.unwrap();
934 assert_eq!(output.result, "I understand, I won't execute that.");
935 assert_eq!(output.tool_calls_made, 1);
937 }
938
939 #[tokio::test]
940 async fn approval_callback_receives_correct_tool_calls() {
941 let received_calls = Arc::new(Mutex::new(Vec::<String>::new()));
942 let received_clone = received_calls.clone();
943
944 let provider = Arc::new(MockProvider::new(vec![
945 CompletionResponse {
946 content: vec![
947 ContentBlock::ToolUse {
948 id: "c1".into(),
949 name: "search".into(),
950 input: json!({"q": "rust"}),
951 },
952 ContentBlock::ToolUse {
953 id: "c2".into(),
954 name: "read".into(),
955 input: json!({"path": "/tmp"}),
956 },
957 ],
958 stop_reason: StopReason::ToolUse,
959 usage: TokenUsage::default(),
960 model: None,
961 },
962 CompletionResponse {
963 content: vec![ContentBlock::Text {
964 text: "Done!".into(),
965 }],
966 stop_reason: StopReason::EndTurn,
967 usage: TokenUsage::default(),
968 model: None,
969 },
970 ]));
971
972 let callback: Arc<crate::llm::OnApproval> = Arc::new(move |calls| {
973 let names: Vec<String> = calls.iter().map(|c| c.name.clone()).collect();
974 received_clone.lock().expect("lock").extend(names);
975 crate::llm::ApprovalDecision::Allow
976 });
977
978 let runner = AgentRunner::builder(provider)
979 .name("test")
980 .system_prompt("sys")
981 .tool(Arc::new(MockTool::new("search", "found")))
982 .tool(Arc::new(MockTool::new("read", "content")))
983 .on_approval(callback)
984 .build()
985 .unwrap();
986
987 runner.execute("test").await.unwrap();
988
989 let calls = received_calls.lock().expect("lock");
990 assert_eq!(*calls, vec!["search", "read"]);
991 }
992
993 #[tokio::test]
994 async fn tool_timeout_returns_error_to_llm() {
995 struct SlowTool;
997 impl Tool for SlowTool {
998 fn definition(&self) -> ToolDefinition {
999 ToolDefinition {
1000 name: "slow_tool".into(),
1001 description: "Takes forever".into(),
1002 input_schema: json!({"type": "object"}),
1003 }
1004 }
1005 fn execute(
1006 &self,
1007 _ctx: &crate::ExecutionContext,
1008 _input: serde_json::Value,
1009 ) -> std::pin::Pin<
1010 Box<dyn std::future::Future<Output = Result<ToolOutput, Error>> + Send + '_>,
1011 > {
1012 Box::pin(async {
1013 tokio::time::sleep(std::time::Duration::from_secs(60)).await;
1014 Ok(ToolOutput::success("should never reach here"))
1015 })
1016 }
1017 }
1018
1019 let provider = Arc::new(MockProvider::new(vec![
1020 CompletionResponse {
1021 content: vec![ContentBlock::ToolUse {
1022 id: "c1".into(),
1023 name: "slow_tool".into(),
1024 input: json!({}),
1025 }],
1026 stop_reason: StopReason::ToolUse,
1027 usage: TokenUsage::default(),
1028 model: None,
1029 },
1030 CompletionResponse {
1031 content: vec![ContentBlock::Text {
1032 text: "Tool timed out, moving on.".into(),
1033 }],
1034 stop_reason: StopReason::EndTurn,
1035 usage: TokenUsage::default(),
1036 model: None,
1037 },
1038 ]));
1039
1040 let runner = AgentRunner::builder(provider)
1041 .name("test")
1042 .system_prompt("sys")
1043 .tool(Arc::new(SlowTool))
1044 .tool_timeout(std::time::Duration::from_millis(50))
1045 .build()
1046 .unwrap();
1047
1048 let output = runner.execute("run slow tool").await.unwrap();
1049 assert_eq!(output.result, "Tool timed out, moving on.");
1050 assert_eq!(output.tool_calls_made, 1);
1051 }
1052
1053 #[tokio::test]
1054 async fn tool_timeout_does_not_affect_fast_tools() {
1055 let provider = Arc::new(MockProvider::new(vec![
1057 CompletionResponse {
1058 content: vec![ContentBlock::ToolUse {
1059 id: "c1".into(),
1060 name: "search".into(),
1061 input: json!({}),
1062 }],
1063 stop_reason: StopReason::ToolUse,
1064 usage: TokenUsage::default(),
1065 model: None,
1066 },
1067 CompletionResponse {
1068 content: vec![ContentBlock::Text {
1069 text: "Got results!".into(),
1070 }],
1071 stop_reason: StopReason::EndTurn,
1072 usage: TokenUsage::default(),
1073 model: None,
1074 },
1075 ]));
1076
1077 let runner = AgentRunner::builder(provider)
1078 .name("test")
1079 .system_prompt("sys")
1080 .tool(Arc::new(MockTool::new("search", "search results")))
1081 .tool_timeout(std::time::Duration::from_secs(30))
1082 .build()
1083 .unwrap();
1084
1085 let output = runner.execute("search").await.unwrap();
1086 assert_eq!(output.result, "Got results!");
1087 assert_eq!(output.tool_calls_made, 1);
1088 }
1089
1090 #[tokio::test]
1091 async fn no_tool_timeout_allows_unlimited_execution() {
1092 let provider = Arc::new(MockProvider::new(vec![
1094 CompletionResponse {
1095 content: vec![ContentBlock::ToolUse {
1096 id: "c1".into(),
1097 name: "search".into(),
1098 input: json!({}),
1099 }],
1100 stop_reason: StopReason::ToolUse,
1101 usage: TokenUsage::default(),
1102 model: None,
1103 },
1104 CompletionResponse {
1105 content: vec![ContentBlock::Text {
1106 text: "Done!".into(),
1107 }],
1108 stop_reason: StopReason::EndTurn,
1109 usage: TokenUsage::default(),
1110 model: None,
1111 },
1112 ]));
1113
1114 let runner = AgentRunner::builder(provider)
1115 .name("test")
1116 .system_prompt("sys")
1117 .tool(Arc::new(MockTool::new("search", "result")))
1118 .build()
1119 .unwrap();
1120
1121 let output = runner.execute("test").await.unwrap();
1123 assert_eq!(output.result, "Done!");
1124 }
1125
1126 #[tokio::test]
1127 async fn no_approval_callback_executes_tools_directly() {
1128 let provider = Arc::new(MockProvider::new(vec![
1130 CompletionResponse {
1131 content: vec![ContentBlock::ToolUse {
1132 id: "c1".into(),
1133 name: "search".into(),
1134 input: json!({}),
1135 }],
1136 stop_reason: StopReason::ToolUse,
1137 usage: TokenUsage::default(),
1138 model: None,
1139 },
1140 CompletionResponse {
1141 content: vec![ContentBlock::Text {
1142 text: "Done!".into(),
1143 }],
1144 stop_reason: StopReason::EndTurn,
1145 usage: TokenUsage::default(),
1146 model: None,
1147 },
1148 ]));
1149
1150 let runner = AgentRunner::builder(provider)
1151 .name("test")
1152 .system_prompt("sys")
1153 .tool(Arc::new(MockTool::new("search", "result")))
1154 .build()
1155 .unwrap();
1156
1157 let output = runner.execute("test").await.unwrap();
1158 assert_eq!(output.result, "Done!");
1159 assert_eq!(output.tool_calls_made, 1);
1160 }
1161
1162 #[tokio::test]
1163 async fn schema_validation_rejects_bad_input() {
1164 struct StrictTool;
1166 impl Tool for StrictTool {
1167 fn definition(&self) -> ToolDefinition {
1168 ToolDefinition {
1169 name: "search".into(),
1170 description: "Search".into(),
1171 input_schema: json!({
1172 "type": "object",
1173 "properties": {
1174 "query": {"type": "string"}
1175 },
1176 "required": ["query"]
1177 }),
1178 }
1179 }
1180 fn execute(
1181 &self,
1182 _ctx: &crate::ExecutionContext,
1183 _input: serde_json::Value,
1184 ) -> std::pin::Pin<
1185 Box<dyn std::future::Future<Output = Result<ToolOutput, Error>> + Send + '_>,
1186 > {
1187 Box::pin(async { Ok(ToolOutput::success("should not be called")) })
1188 }
1189 }
1190
1191 let provider = Arc::new(MockProvider::new(vec![
1193 CompletionResponse {
1195 content: vec![ContentBlock::ToolUse {
1196 id: "c1".into(),
1197 name: "search".into(),
1198 input: json!({"wrong_field": 42}), }],
1200 stop_reason: StopReason::ToolUse,
1201 usage: TokenUsage::default(),
1202 model: None,
1203 },
1204 CompletionResponse {
1206 content: vec![ContentBlock::Text {
1207 text: "I see the validation error.".into(),
1208 }],
1209 stop_reason: StopReason::EndTurn,
1210 usage: TokenUsage::default(),
1211 model: None,
1212 },
1213 ]));
1214
1215 let runner = AgentRunner::builder(provider)
1216 .name("test")
1217 .system_prompt("sys")
1218 .tool(Arc::new(StrictTool))
1219 .build()
1220 .unwrap();
1221
1222 let output = runner.execute("search for something").await.unwrap();
1223 assert_eq!(output.result, "I see the validation error.");
1225 assert_eq!(output.tool_calls_made, 1); }
1227
1228 #[tokio::test]
1229 async fn large_tool_output_is_truncated() {
1230 struct BigTool;
1232 impl Tool for BigTool {
1233 fn definition(&self) -> ToolDefinition {
1234 ToolDefinition {
1235 name: "big".into(),
1236 description: "Returns big output".into(),
1237 input_schema: json!({"type": "object"}),
1238 }
1239 }
1240 fn execute(
1241 &self,
1242 _ctx: &crate::ExecutionContext,
1243 _input: serde_json::Value,
1244 ) -> std::pin::Pin<
1245 Box<dyn std::future::Future<Output = Result<ToolOutput, Error>> + Send + '_>,
1246 > {
1247 Box::pin(async { Ok(ToolOutput::success("x".repeat(10_000))) })
1248 }
1249 }
1250
1251 let provider = Arc::new(MockProvider::new(vec![
1253 CompletionResponse {
1254 content: vec![ContentBlock::ToolUse {
1255 id: "c1".into(),
1256 name: "big".into(),
1257 input: json!({}),
1258 }],
1259 stop_reason: StopReason::ToolUse,
1260 usage: TokenUsage::default(),
1261 model: None,
1262 },
1263 CompletionResponse {
1264 content: vec![ContentBlock::Text {
1265 text: "Got truncated result.".into(),
1266 }],
1267 stop_reason: StopReason::EndTurn,
1268 usage: TokenUsage::default(),
1269 model: None,
1270 },
1271 ]));
1272
1273 let runner = AgentRunner::builder(provider)
1274 .name("test")
1275 .system_prompt("sys")
1276 .tool(Arc::new(BigTool))
1277 .max_tool_output_bytes(500)
1278 .build()
1279 .unwrap();
1280
1281 let output = runner.execute("get big data").await.unwrap();
1282 assert_eq!(output.result, "Got truncated result.");
1283 assert_eq!(output.tool_calls_made, 1);
1284 }
1285
1286 #[tokio::test]
1287 async fn structured_output_extracts_respond_tool() {
1288 let schema = json!({
1290 "type": "object",
1291 "properties": {
1292 "answer": {"type": "string"},
1293 "confidence": {"type": "number"}
1294 },
1295 "required": ["answer", "confidence"]
1296 });
1297
1298 let provider = Arc::new(MockProvider::new(vec![CompletionResponse {
1299 content: vec![ContentBlock::ToolUse {
1300 id: "c1".into(),
1301 name: "__respond__".into(),
1302 input: json!({"answer": "42", "confidence": 0.95}),
1303 }],
1304 stop_reason: StopReason::ToolUse,
1305 usage: TokenUsage {
1306 input_tokens: 20,
1307 output_tokens: 15,
1308 ..Default::default()
1309 },
1310 model: None,
1311 }]));
1312
1313 let runner = AgentRunner::builder(provider)
1314 .name("test")
1315 .system_prompt("You are helpful.")
1316 .structured_schema(schema)
1317 .build()
1318 .unwrap();
1319
1320 let output = runner.execute("what is the answer?").await.unwrap();
1321 assert!(output.structured.is_some());
1322 let structured = output.structured.unwrap();
1323 assert_eq!(structured["answer"], "42");
1324 assert_eq!(structured["confidence"], 0.95);
1325 assert_eq!(output.tool_calls_made, 1);
1326 }
1327
1328 #[tokio::test]
1329 async fn structured_output_none_without_schema() {
1330 let provider = Arc::new(MockProvider::new(vec![CompletionResponse {
1332 content: vec![ContentBlock::Text {
1333 text: "Hello!".into(),
1334 }],
1335 stop_reason: StopReason::EndTurn,
1336 usage: TokenUsage::default(),
1337 model: None,
1338 }]));
1339
1340 let runner = AgentRunner::builder(provider)
1341 .name("test")
1342 .system_prompt("sys")
1343 .build()
1344 .unwrap();
1345
1346 let output = runner.execute("test").await.unwrap();
1347 assert!(output.structured.is_none());
1348 }
1349
1350 #[tokio::test]
1351 async fn structured_output_allows_real_tools_first() {
1352 let schema = json!({
1354 "type": "object",
1355 "properties": { "result": {"type": "string"} },
1356 "required": ["result"]
1357 });
1358
1359 let provider = Arc::new(MockProvider::new(vec![
1360 CompletionResponse {
1362 content: vec![ContentBlock::ToolUse {
1363 id: "c1".into(),
1364 name: "search".into(),
1365 input: json!({"q": "data"}),
1366 }],
1367 stop_reason: StopReason::ToolUse,
1368 usage: TokenUsage::default(),
1369 model: None,
1370 },
1371 CompletionResponse {
1373 content: vec![ContentBlock::ToolUse {
1374 id: "c2".into(),
1375 name: "__respond__".into(),
1376 input: json!({"result": "found it"}),
1377 }],
1378 stop_reason: StopReason::ToolUse,
1379 usage: TokenUsage::default(),
1380 model: None,
1381 },
1382 ]));
1383
1384 let runner = AgentRunner::builder(provider)
1385 .name("test")
1386 .system_prompt("sys")
1387 .tool(Arc::new(MockTool::new("search", "search results")))
1388 .structured_schema(schema)
1389 .build()
1390 .unwrap();
1391
1392 let output = runner.execute("find data").await.unwrap();
1393 assert!(output.structured.is_some());
1394 assert_eq!(output.structured.unwrap()["result"], "found it");
1395 assert_eq!(output.tool_calls_made, 2);
1397 }
1398
1399 #[test]
1400 fn structured_schema_injects_respond_tool_definition() {
1401 let schema = json!({
1402 "type": "object",
1403 "properties": { "answer": {"type": "string"} }
1404 });
1405
1406 let provider = Arc::new(MockProvider::new(vec![]));
1407 let runner = AgentRunner::builder(provider)
1408 .name("test")
1409 .system_prompt("sys")
1410 .structured_schema(schema.clone())
1411 .build()
1412 .unwrap();
1413
1414 assert!(runner.tool_defs.iter().any(|d| d.name == "__respond__"));
1416 assert!(!runner.tools.contains_key("__respond__"));
1417 let respond_def = runner
1418 .tool_defs
1419 .iter()
1420 .find(|d| d.name == "__respond__")
1421 .unwrap();
1422 assert_eq!(respond_def.input_schema, schema);
1423 }
1424
1425 #[tokio::test]
1426 async fn structured_output_counts_all_tool_calls_in_respond_turn() {
1427 let schema = json!({
1429 "type": "object",
1430 "properties": { "result": {"type": "string"} }
1431 });
1432
1433 let provider = Arc::new(MockProvider::new(vec![CompletionResponse {
1434 content: vec![
1435 ContentBlock::ToolUse {
1436 id: "c1".into(),
1437 name: "search".into(),
1438 input: json!({"q": "data"}),
1439 },
1440 ContentBlock::ToolUse {
1441 id: "c2".into(),
1442 name: "__respond__".into(),
1443 input: json!({"result": "done"}),
1444 },
1445 ],
1446 stop_reason: StopReason::ToolUse,
1447 usage: TokenUsage::default(),
1448 model: None,
1449 }]));
1450
1451 let runner = AgentRunner::builder(provider)
1452 .name("test")
1453 .system_prompt("sys")
1454 .tool(Arc::new(MockTool::new("search", "results")))
1455 .structured_schema(schema)
1456 .build()
1457 .unwrap();
1458
1459 let output = runner.execute("test").await.unwrap();
1460 assert!(output.structured.is_some());
1461 assert_eq!(output.tool_calls_made, 2);
1463 }
1464
1465 #[tokio::test]
1466 async fn structured_output_max_turns_when_respond_never_called() {
1467 let schema = json!({
1470 "type": "object",
1471 "properties": { "result": {"type": "string"} }
1472 });
1473
1474 let provider = Arc::new(MockProvider::new(vec![
1475 CompletionResponse {
1476 content: vec![ContentBlock::ToolUse {
1477 id: "c1".into(),
1478 name: "search".into(),
1479 input: json!({}),
1480 }],
1481 stop_reason: StopReason::ToolUse,
1482 usage: TokenUsage::default(),
1483 model: None,
1484 },
1485 CompletionResponse {
1486 content: vec![ContentBlock::ToolUse {
1487 id: "c2".into(),
1488 name: "search".into(),
1489 input: json!({}),
1490 }],
1491 stop_reason: StopReason::ToolUse,
1492 usage: TokenUsage::default(),
1493 model: None,
1494 },
1495 ]));
1496
1497 let runner = AgentRunner::builder(provider)
1498 .name("test")
1499 .system_prompt("sys")
1500 .tool(Arc::new(MockTool::new("search", "results")))
1501 .structured_schema(schema)
1502 .max_turns(2)
1503 .build()
1504 .unwrap();
1505
1506 let err = runner.execute("test").await.unwrap_err();
1507 assert!(
1508 matches!(
1509 err,
1510 Error::WithPartialUsage {
1511 ref source,
1512 ..
1513 } if matches!(**source, Error::MaxTurnsExceeded(2))
1514 ),
1515 "expected MaxTurnsExceeded(2), got: {err:?}"
1516 );
1517 }
1518
1519 #[test]
1520 fn no_respond_tool_without_schema() {
1521 let provider = Arc::new(MockProvider::new(vec![]));
1522 let runner = AgentRunner::builder(provider)
1523 .name("test")
1524 .system_prompt("sys")
1525 .build()
1526 .unwrap();
1527
1528 assert!(!runner.tool_defs.iter().any(|d| d.name == "__respond__"));
1529 }
1530
1531 #[tokio::test]
1532 async fn small_tool_output_not_truncated_with_limit() {
1533 let provider = Arc::new(MockProvider::new(vec![
1535 CompletionResponse {
1536 content: vec![ContentBlock::ToolUse {
1537 id: "c1".into(),
1538 name: "search".into(),
1539 input: json!({}),
1540 }],
1541 stop_reason: StopReason::ToolUse,
1542 usage: TokenUsage::default(),
1543 model: None,
1544 },
1545 CompletionResponse {
1546 content: vec![ContentBlock::Text {
1547 text: "Done!".into(),
1548 }],
1549 stop_reason: StopReason::EndTurn,
1550 usage: TokenUsage::default(),
1551 model: None,
1552 },
1553 ]));
1554
1555 let runner = AgentRunner::builder(provider)
1556 .name("test")
1557 .system_prompt("sys")
1558 .tool(Arc::new(MockTool::new("search", "small result")))
1559 .max_tool_output_bytes(1000)
1560 .build()
1561 .unwrap();
1562
1563 let output = runner.execute("search").await.unwrap();
1564 assert_eq!(output.result, "Done!");
1565 }
1566
1567 #[test]
1568 fn agent_output_roundtrips() {
1569 let output = AgentOutput {
1570 result: "Hello!".into(),
1571 tool_calls_made: 3,
1572 tokens_used: TokenUsage {
1573 input_tokens: 100,
1574 output_tokens: 50,
1575 ..Default::default()
1576 },
1577 structured: Some(json!({"answer": "42"})),
1578 estimated_cost_usd: Some(0.0342),
1579 model_name: Some("claude-sonnet-4-6-20250610".into()),
1580 };
1581 let json_str = serde_json::to_string(&output).unwrap();
1582 let parsed: AgentOutput = serde_json::from_str(&json_str).unwrap();
1583 assert_eq!(parsed.result, "Hello!");
1584 assert_eq!(parsed.tool_calls_made, 3);
1585 assert_eq!(parsed.tokens_used.input_tokens, 100);
1586 assert_eq!(parsed.structured, Some(json!({"answer": "42"})));
1587 assert_eq!(parsed.estimated_cost_usd, Some(0.0342));
1588 assert_eq!(
1589 parsed.model_name.as_deref(),
1590 Some("claude-sonnet-4-6-20250610")
1591 );
1592 }
1593
1594 #[test]
1595 fn agent_output_structured_none_serializes() {
1596 let output = AgentOutput {
1597 result: "ok".into(),
1598 tool_calls_made: 0,
1599 tokens_used: TokenUsage::default(),
1600 structured: None,
1601 estimated_cost_usd: None,
1602 model_name: None,
1603 };
1604 let json_str = serde_json::to_string(&output).unwrap();
1605 let parsed: AgentOutput = serde_json::from_str(&json_str).unwrap();
1606 assert!(parsed.structured.is_none());
1607 assert!(parsed.model_name.is_none());
1608 }
1609
1610 #[test]
1611 fn agent_output_backward_compat_no_model_name() {
1612 let json = r#"{"result":"ok","tool_calls_made":0,"tokens_used":{"input_tokens":0,"output_tokens":0,"cache_creation_input_tokens":0,"cache_read_input_tokens":0,"reasoning_tokens":0}}"#;
1614 let parsed: AgentOutput = serde_json::from_str(json).unwrap();
1615 assert!(parsed.model_name.is_none());
1616 assert_eq!(parsed.result, "ok");
1617 }
1618
1619 #[tokio::test]
1620 async fn structured_output_errors_when_llm_ignores_respond() {
1621 let schema = json!({
1625 "type": "object",
1626 "properties": { "answer": {"type": "string"} },
1627 "required": ["answer"]
1628 });
1629
1630 let provider = Arc::new(MockProvider::new(vec![
1631 CompletionResponse {
1633 content: vec![ContentBlock::Text {
1634 text: "Here is the answer.".into(),
1635 }],
1636 stop_reason: StopReason::EndTurn,
1637 usage: TokenUsage::default(),
1638 model: None,
1639 },
1640 ]));
1641
1642 let runner = AgentRunner::builder(provider)
1643 .name("test")
1644 .system_prompt("sys")
1645 .structured_schema(schema)
1646 .build()
1647 .unwrap();
1648
1649 let err = runner.execute("test").await.unwrap_err();
1650 assert!(
1651 err.to_string().contains("__respond__"),
1652 "error should mention __respond__: {err}"
1653 );
1654 }
1655
1656 #[tokio::test]
1657 async fn structured_output_does_not_force_tool_choice() {
1658 use std::sync::atomic::{AtomicBool, Ordering};
1662
1663 struct ToolChoiceTracker {
1664 tool_choice_any_seen: Arc<AtomicBool>,
1665 }
1666
1667 impl LlmProvider for ToolChoiceTracker {
1668 async fn complete(
1669 &self,
1670 request: CompletionRequest,
1671 ) -> Result<CompletionResponse, Error> {
1672 if request.tool_choice == Some(crate::llm::types::ToolChoice::Any) {
1673 self.tool_choice_any_seen.store(true, Ordering::SeqCst);
1674 }
1675 Ok(CompletionResponse {
1676 content: vec![ContentBlock::ToolUse {
1677 id: "c1".into(),
1678 name: "__respond__".into(),
1679 input: json!({"answer": "42"}),
1680 }],
1681 stop_reason: StopReason::ToolUse,
1682 usage: TokenUsage::default(),
1683 model: None,
1684 })
1685 }
1686 }
1687
1688 let seen = Arc::new(AtomicBool::new(false));
1689 let provider = Arc::new(ToolChoiceTracker {
1690 tool_choice_any_seen: seen.clone(),
1691 });
1692
1693 let schema = json!({
1694 "type": "object",
1695 "properties": { "answer": {"type": "string"} }
1696 });
1697
1698 let runner = AgentRunner::builder(provider)
1699 .name("test")
1700 .system_prompt("sys")
1701 .structured_schema(schema)
1702 .build()
1703 .unwrap();
1704
1705 let output = runner.execute("test").await.unwrap();
1706 assert!(
1707 !seen.load(Ordering::SeqCst),
1708 "tool_choice should NOT be forced to Any"
1709 );
1710 assert!(
1711 output.structured.is_some(),
1712 "structured output should still work"
1713 );
1714 }
1715
1716 #[tokio::test]
1717 async fn respond_tool_skips_co_submitted_real_tools() {
1718 use std::sync::atomic::{AtomicBool, Ordering};
1721
1722 let tool_executed = Arc::new(AtomicBool::new(false));
1723 let tool_executed_clone = tool_executed.clone();
1724
1725 struct TrackingTool {
1726 executed: Arc<AtomicBool>,
1727 }
1728 impl Tool for TrackingTool {
1729 fn definition(&self) -> ToolDefinition {
1730 ToolDefinition {
1731 name: "real_tool".into(),
1732 description: "A real tool".into(),
1733 input_schema: json!({"type": "object"}),
1734 }
1735 }
1736 fn execute(
1737 &self,
1738 _ctx: &crate::ExecutionContext,
1739 _input: serde_json::Value,
1740 ) -> std::pin::Pin<
1741 Box<dyn std::future::Future<Output = Result<ToolOutput, Error>> + Send + '_>,
1742 > {
1743 self.executed.store(true, Ordering::SeqCst);
1744 Box::pin(async { Ok(ToolOutput::success("done")) })
1745 }
1746 }
1747
1748 let provider = Arc::new(MockProvider::new(vec![
1749 CompletionResponse {
1751 content: vec![
1752 ContentBlock::ToolUse {
1753 id: "c1".into(),
1754 name: "real_tool".into(),
1755 input: json!({}),
1756 },
1757 ContentBlock::ToolUse {
1758 id: "c2".into(),
1759 name: "__respond__".into(),
1760 input: json!({"answer": "42"}),
1761 },
1762 ],
1763 stop_reason: StopReason::ToolUse,
1764 usage: TokenUsage::default(),
1765 model: None,
1766 },
1767 ]));
1768
1769 let schema = json!({
1770 "type": "object",
1771 "properties": { "answer": {"type": "string"} }
1772 });
1773
1774 let runner = AgentRunner::builder(provider)
1775 .name("test")
1776 .system_prompt("sys")
1777 .tool(Arc::new(TrackingTool {
1778 executed: tool_executed_clone,
1779 }))
1780 .structured_schema(schema)
1781 .build()
1782 .unwrap();
1783
1784 let output = runner.execute("test").await.unwrap();
1785
1786 assert!(
1787 output.structured.is_some(),
1788 "should return structured output"
1789 );
1790 assert_eq!(output.tool_calls_made, 2, "should count both tool calls");
1791 assert!(
1792 !tool_executed.load(Ordering::SeqCst),
1793 "real_tool should NOT have been executed when __respond__ is present"
1794 );
1795 }
1796
1797 #[tokio::test]
1798 async fn structured_output_validated_against_schema() {
1799 let schema = json!({
1802 "type": "object",
1803 "properties": {
1804 "answer": {"type": "string"},
1805 "confidence": {"type": "number"}
1806 },
1807 "required": ["answer", "confidence"]
1808 });
1809
1810 let provider = Arc::new(MockProvider::new(vec![
1811 CompletionResponse {
1813 content: vec![ContentBlock::ToolUse {
1814 id: "c1".into(),
1815 name: "__respond__".into(),
1816 input: json!({"answer": "42"}), }],
1818 stop_reason: StopReason::ToolUse,
1819 usage: TokenUsage::default(),
1820 model: None,
1821 },
1822 CompletionResponse {
1824 content: vec![ContentBlock::ToolUse {
1825 id: "c2".into(),
1826 name: "__respond__".into(),
1827 input: json!({"answer": "42", "confidence": 0.95}),
1828 }],
1829 stop_reason: StopReason::ToolUse,
1830 usage: TokenUsage::default(),
1831 model: None,
1832 },
1833 ]));
1834
1835 let runner = AgentRunner::builder(provider)
1836 .name("test")
1837 .system_prompt("sys")
1838 .structured_schema(schema)
1839 .build()
1840 .unwrap();
1841
1842 let output = runner.execute("test").await.unwrap();
1843 assert!(output.structured.is_some());
1844 assert_eq!(output.structured.unwrap()["confidence"], 0.95);
1845 assert_eq!(output.tool_calls_made, 2);
1847 }
1848
1849 #[tokio::test]
1850 async fn structured_output_validation_wrong_type() {
1851 let schema = json!({
1853 "type": "object",
1854 "properties": {
1855 "count": {"type": "integer"}
1856 },
1857 "required": ["count"]
1858 });
1859
1860 let provider = Arc::new(MockProvider::new(vec![
1861 CompletionResponse {
1863 content: vec![ContentBlock::ToolUse {
1864 id: "c1".into(),
1865 name: "__respond__".into(),
1866 input: json!({"count": "not a number"}),
1867 }],
1868 stop_reason: StopReason::ToolUse,
1869 usage: TokenUsage::default(),
1870 model: None,
1871 },
1872 CompletionResponse {
1874 content: vec![ContentBlock::ToolUse {
1875 id: "c2".into(),
1876 name: "__respond__".into(),
1877 input: json!({"count": 42}),
1878 }],
1879 stop_reason: StopReason::ToolUse,
1880 usage: TokenUsage::default(),
1881 model: None,
1882 },
1883 ]));
1884
1885 let runner = AgentRunner::builder(provider)
1886 .name("test")
1887 .system_prompt("sys")
1888 .structured_schema(schema)
1889 .build()
1890 .unwrap();
1891
1892 let output = runner.execute("test").await.unwrap();
1893 assert_eq!(output.structured.unwrap()["count"], 42);
1894 }
1895
1896 #[tokio::test]
1897 async fn structured_output_valid_on_first_try() {
1898 let schema = json!({
1900 "type": "object",
1901 "properties": {
1902 "result": {"type": "string"}
1903 },
1904 "required": ["result"]
1905 });
1906
1907 let provider = Arc::new(MockProvider::new(vec![CompletionResponse {
1908 content: vec![ContentBlock::ToolUse {
1909 id: "c1".into(),
1910 name: "__respond__".into(),
1911 input: json!({"result": "hello"}),
1912 }],
1913 stop_reason: StopReason::ToolUse,
1914 usage: TokenUsage::default(),
1915 model: None,
1916 }]));
1917
1918 let runner = AgentRunner::builder(provider)
1919 .name("test")
1920 .system_prompt("sys")
1921 .structured_schema(schema)
1922 .build()
1923 .unwrap();
1924
1925 let output = runner.execute("test").await.unwrap();
1926 assert_eq!(output.structured.unwrap()["result"], "hello");
1927 assert_eq!(output.tool_calls_made, 1);
1928 }
1929
1930 #[tokio::test]
1931 async fn summarization_tokens_accumulated_in_total_usage() {
1932 let provider = Arc::new(MockProvider::new(vec![
1949 CompletionResponse {
1951 content: vec![ContentBlock::ToolUse {
1952 id: "c1".into(),
1953 name: "search".into(),
1954 input: json!({}),
1955 }],
1956 stop_reason: StopReason::ToolUse,
1957 usage: TokenUsage {
1958 input_tokens: 10,
1959 output_tokens: 5,
1960 ..Default::default()
1961 },
1962 model: None,
1963 },
1964 CompletionResponse {
1966 content: vec![ContentBlock::ToolUse {
1967 id: "c2".into(),
1968 name: "search".into(),
1969 input: json!({}),
1970 }],
1971 stop_reason: StopReason::ToolUse,
1972 usage: TokenUsage {
1973 input_tokens: 10,
1974 output_tokens: 5,
1975 ..Default::default()
1976 },
1977 model: None,
1978 },
1979 CompletionResponse {
1981 content: vec![ContentBlock::ToolUse {
1982 id: "c3".into(),
1983 name: "search".into(),
1984 input: json!({}),
1985 }],
1986 stop_reason: StopReason::ToolUse,
1987 usage: TokenUsage {
1988 input_tokens: 10,
1989 output_tokens: 5,
1990 ..Default::default()
1991 },
1992 model: None,
1993 },
1994 CompletionResponse {
1996 content: vec![ContentBlock::Text {
1997 text: "Summary of conversation so far.".into(),
1998 }],
1999 stop_reason: StopReason::EndTurn,
2000 usage: TokenUsage {
2001 input_tokens: 100,
2002 output_tokens: 50,
2003 cache_creation_input_tokens: 25,
2004 cache_read_input_tokens: 10,
2005 reasoning_tokens: 0,
2006 },
2007 model: None,
2008 },
2009 CompletionResponse {
2011 content: vec![ContentBlock::Text {
2012 text: "Final answer.".into(),
2013 }],
2014 stop_reason: StopReason::EndTurn,
2015 usage: TokenUsage {
2016 input_tokens: 10,
2017 output_tokens: 5,
2018 ..Default::default()
2019 },
2020 model: None,
2021 },
2022 ]));
2023
2024 let runner = AgentRunner::builder(provider)
2025 .name("test")
2026 .system_prompt("sys")
2027 .tool(Arc::new(MockTool::new("search", "result")))
2028 .summarize_threshold(1) .max_turns(10)
2030 .build()
2031 .unwrap();
2032
2033 let output = runner.execute("test task").await.unwrap();
2034 assert_eq!(output.result, "Final answer.");
2035 assert_eq!(output.tokens_used.input_tokens, 10 + 10 + 10 + 100 + 10);
2037 assert_eq!(output.tokens_used.output_tokens, 5 + 5 + 5 + 50 + 5);
2039 assert_eq!(output.tokens_used.cache_creation_input_tokens, 25);
2041 assert_eq!(output.tokens_used.cache_read_input_tokens, 10);
2042 }
2043
2044 #[test]
2045 fn knowledge_base_adds_search_tool() {
2046 use crate::knowledge::in_memory::InMemoryKnowledgeBase;
2047
2048 let kb: Arc<dyn crate::knowledge::KnowledgeBase> = Arc::new(InMemoryKnowledgeBase::new());
2049 let provider = Arc::new(MockProvider::new(vec![]));
2050
2051 let runner = AgentRunner::builder(provider)
2052 .name("test")
2053 .system_prompt("sys")
2054 .knowledge(kb)
2055 .build()
2056 .unwrap();
2057
2058 assert!(
2059 runner
2060 .tool_defs
2061 .iter()
2062 .any(|d| d.name == "knowledge_search"),
2063 "agent should have knowledge_search tool"
2064 );
2065 }
2066
2067 #[tokio::test]
2068 async fn on_event_emits_run_started_and_completed() {
2069 let events: Arc<std::sync::Mutex<Vec<AgentEvent>>> =
2070 Arc::new(std::sync::Mutex::new(vec![]));
2071 let events_clone = events.clone();
2072
2073 let provider = Arc::new(MockProvider::new(vec![CompletionResponse {
2074 content: vec![ContentBlock::Text {
2075 text: "Done.".into(),
2076 }],
2077 stop_reason: StopReason::EndTurn,
2078 usage: TokenUsage {
2079 input_tokens: 10,
2080 output_tokens: 5,
2081 ..Default::default()
2082 },
2083 model: None,
2084 }]));
2085
2086 let runner = AgentRunner::builder(provider)
2087 .name("test-agent")
2088 .system_prompt("sys")
2089 .on_event(Arc::new(move |e| {
2090 events_clone.lock().unwrap().push(e);
2091 }))
2092 .build()
2093 .unwrap();
2094
2095 runner.execute("hello").await.unwrap();
2096
2097 let events = events.lock().unwrap();
2098 assert!(
2099 events.len() >= 4,
2100 "expected at least 4 events, got {}",
2101 events.len()
2102 );
2103
2104 match &events[0] {
2106 AgentEvent::RunStarted { agent, task } => {
2107 assert_eq!(agent, "test-agent");
2108 assert_eq!(task, "hello");
2109 }
2110 other => panic!("expected RunStarted, got: {other:?}"),
2111 }
2112
2113 match &events[1] {
2115 AgentEvent::TurnStarted { agent, turn, .. } => {
2116 assert_eq!(agent, "test-agent");
2117 assert_eq!(*turn, 1);
2118 }
2119 other => panic!("expected TurnStarted, got: {other:?}"),
2120 }
2121
2122 match &events[2] {
2124 AgentEvent::LlmResponse {
2125 agent,
2126 turn,
2127 tool_call_count,
2128 ..
2129 } => {
2130 assert_eq!(agent, "test-agent");
2131 assert_eq!(*turn, 1);
2132 assert_eq!(*tool_call_count, 0);
2133 }
2134 other => panic!("expected LlmResponse, got: {other:?}"),
2135 }
2136
2137 match events.last().unwrap() {
2139 AgentEvent::RunCompleted {
2140 agent,
2141 tool_calls_made,
2142 ..
2143 } => {
2144 assert_eq!(agent, "test-agent");
2145 assert_eq!(*tool_calls_made, 0);
2146 }
2147 other => panic!("expected RunCompleted, got: {other:?}"),
2148 }
2149 }
2150
2151 #[tokio::test]
2152 async fn on_event_emits_tool_call_events() {
2153 let events: Arc<std::sync::Mutex<Vec<AgentEvent>>> =
2154 Arc::new(std::sync::Mutex::new(vec![]));
2155 let events_clone = events.clone();
2156
2157 let provider = Arc::new(MockProvider::new(vec![
2158 CompletionResponse {
2159 content: vec![ContentBlock::ToolUse {
2160 id: "call-1".into(),
2161 name: "search".into(),
2162 input: json!({}),
2163 }],
2164 stop_reason: StopReason::ToolUse,
2165 usage: TokenUsage::default(),
2166 model: None,
2167 },
2168 CompletionResponse {
2169 content: vec![ContentBlock::Text {
2170 text: "Result.".into(),
2171 }],
2172 stop_reason: StopReason::EndTurn,
2173 usage: TokenUsage::default(),
2174 model: None,
2175 },
2176 ]));
2177
2178 let runner = AgentRunner::builder(provider)
2179 .name("worker")
2180 .system_prompt("sys")
2181 .tool(Arc::new(MockTool::new("search", "found it")))
2182 .on_event(Arc::new(move |e| {
2183 events_clone.lock().unwrap().push(e);
2184 }))
2185 .build()
2186 .unwrap();
2187
2188 runner.execute("find stuff").await.unwrap();
2189
2190 let events = events.lock().unwrap();
2191 let tool_started: Vec<_> = events
2192 .iter()
2193 .filter(|e| matches!(e, AgentEvent::ToolCallStarted { .. }))
2194 .collect();
2195 let tool_completed: Vec<_> = events
2196 .iter()
2197 .filter(|e| matches!(e, AgentEvent::ToolCallCompleted { .. }))
2198 .collect();
2199
2200 assert_eq!(tool_started.len(), 1, "expected 1 ToolCallStarted");
2201 assert_eq!(tool_completed.len(), 1, "expected 1 ToolCallCompleted");
2202
2203 match &tool_started[0] {
2204 AgentEvent::ToolCallStarted {
2205 tool_name,
2206 tool_call_id,
2207 ..
2208 } => {
2209 assert_eq!(tool_name, "search");
2210 assert_eq!(tool_call_id, "call-1");
2211 }
2212 _ => unreachable!(),
2213 }
2214
2215 match &tool_completed[0] {
2216 AgentEvent::ToolCallCompleted {
2217 tool_name,
2218 is_error,
2219 ..
2220 } => {
2221 assert_eq!(tool_name, "search");
2222 assert!(!is_error);
2223 }
2224 _ => unreachable!(),
2225 }
2226 }
2227
2228 #[tokio::test]
2229 async fn on_event_emits_run_failed_on_max_turns() {
2230 let events: Arc<std::sync::Mutex<Vec<AgentEvent>>> =
2231 Arc::new(std::sync::Mutex::new(vec![]));
2232 let events_clone = events.clone();
2233
2234 let provider = Arc::new(MockProvider::new(vec![CompletionResponse {
2236 content: vec![ContentBlock::ToolUse {
2237 id: "call-1".into(),
2238 name: "search".into(),
2239 input: json!({}),
2240 }],
2241 stop_reason: StopReason::ToolUse,
2242 usage: TokenUsage::default(),
2243 model: None,
2244 }]));
2245
2246 let runner = AgentRunner::builder(provider)
2247 .name("limited")
2248 .system_prompt("sys")
2249 .tool(Arc::new(MockTool::new("search", "found")))
2250 .max_turns(1)
2251 .on_event(Arc::new(move |e| {
2252 events_clone.lock().unwrap().push(e);
2253 }))
2254 .build()
2255 .unwrap();
2256
2257 let result = runner.execute("go").await;
2258 assert!(result.is_err());
2259
2260 let events = events.lock().unwrap();
2261 let run_failed: Vec<_> = events
2262 .iter()
2263 .filter(|e| matches!(e, AgentEvent::RunFailed { .. }))
2264 .collect();
2265 assert_eq!(run_failed.len(), 1, "expected 1 RunFailed event");
2266
2267 match &run_failed[0] {
2268 AgentEvent::RunFailed { agent, error, .. } => {
2269 assert_eq!(agent, "limited");
2270 assert!(error.contains("Max turns"), "error: {error}");
2271 }
2272 _ => unreachable!(),
2273 }
2274 }
2275
2276 #[tokio::test]
2277 async fn no_events_when_callback_not_set() {
2278 let provider = Arc::new(MockProvider::new(vec![CompletionResponse {
2280 content: vec![ContentBlock::Text {
2281 text: "Done.".into(),
2282 }],
2283 stop_reason: StopReason::EndTurn,
2284 usage: TokenUsage::default(),
2285 model: None,
2286 }]));
2287
2288 let runner = AgentRunner::builder(provider)
2289 .name("quiet")
2290 .system_prompt("sys")
2291 .build()
2292 .unwrap();
2293
2294 let output = runner.execute("hello").await.unwrap();
2295 assert_eq!(output.result, "Done.");
2296 }
2297
2298 use crate::agent::guardrail::{GuardAction, Guardrail};
2301
2302 struct SystemPromptInjector {
2303 suffix: String,
2304 }
2305
2306 impl Guardrail for SystemPromptInjector {
2307 fn pre_llm(
2308 &self,
2309 request: &mut CompletionRequest,
2310 ) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<(), Error>> + Send + '_>>
2311 {
2312 request.system = format!("{} {}", request.system, self.suffix);
2313 Box::pin(async { Ok(()) })
2314 }
2315 }
2316
2317 #[tokio::test]
2318 async fn pre_llm_guardrail_modifies_request() {
2319 struct CapturingProvider {
2320 captured_system: Mutex<Option<String>>,
2321 }
2322
2323 impl LlmProvider for CapturingProvider {
2324 async fn complete(
2325 &self,
2326 request: CompletionRequest,
2327 ) -> Result<CompletionResponse, Error> {
2328 *self.captured_system.lock().unwrap() = Some(request.system);
2329 Ok(CompletionResponse {
2330 content: vec![ContentBlock::Text { text: "ok".into() }],
2331 stop_reason: StopReason::EndTurn,
2332 usage: TokenUsage::default(),
2333 model: None,
2334 })
2335 }
2336 }
2337
2338 let provider = Arc::new(CapturingProvider {
2339 captured_system: Mutex::new(None),
2340 });
2341
2342 let guardrail: Arc<dyn Guardrail> = Arc::new(SystemPromptInjector {
2343 suffix: "SAFETY_NOTICE".into(),
2344 });
2345
2346 let runner = AgentRunner::builder(provider.clone())
2347 .name("test")
2348 .system_prompt("You are helpful.")
2349 .guardrail(guardrail)
2350 .build()
2351 .unwrap();
2352
2353 runner.execute("hello").await.unwrap();
2354
2355 let captured = provider.captured_system.lock().unwrap().clone().unwrap();
2356 assert!(
2357 captured.contains("SAFETY_NOTICE"),
2358 "system prompt should contain injected suffix: {captured}"
2359 );
2360 }
2361
2362 #[tokio::test]
2363 async fn post_llm_guardrail_denies_response() {
2364 struct CountingProvider {
2366 call_count: Mutex<usize>,
2367 }
2368
2369 impl LlmProvider for CountingProvider {
2370 async fn complete(
2371 &self,
2372 _request: CompletionRequest,
2373 ) -> Result<CompletionResponse, Error> {
2374 let mut count = self.call_count.lock().unwrap();
2375 *count += 1;
2376 Ok(CompletionResponse {
2377 content: vec![ContentBlock::Text {
2378 text: format!("Response #{count}"),
2379 }],
2380 stop_reason: StopReason::EndTurn,
2381 usage: TokenUsage::default(),
2382 model: None,
2383 })
2384 }
2385 }
2386
2387 struct DenyOnce {
2389 denied: Mutex<bool>,
2390 }
2391
2392 impl Guardrail for DenyOnce {
2393 fn post_llm(
2394 &self,
2395 _response: &mut crate::llm::types::CompletionResponse,
2396 ) -> std::pin::Pin<
2397 Box<dyn std::future::Future<Output = Result<GuardAction, Error>> + Send + '_>,
2398 > {
2399 Box::pin(async {
2400 let mut denied = self.denied.lock().unwrap();
2401 if !*denied {
2402 *denied = true;
2403 Ok(GuardAction::deny("unsafe content"))
2404 } else {
2405 Ok(GuardAction::Allow)
2406 }
2407 })
2408 }
2409 }
2410
2411 let provider = Arc::new(CountingProvider {
2412 call_count: Mutex::new(0),
2413 });
2414
2415 let runner = AgentRunner::builder(provider.clone())
2416 .name("test")
2417 .system_prompt("sys")
2418 .guardrail(Arc::new(DenyOnce {
2419 denied: Mutex::new(false),
2420 }))
2421 .max_turns(3)
2422 .build()
2423 .unwrap();
2424
2425 let output = runner.execute("hello").await.unwrap();
2426 assert_eq!(output.result, "Response #2");
2428 assert_eq!(*provider.call_count.lock().unwrap(), 2);
2430 }
2431
2432 #[tokio::test]
2433 async fn post_llm_denial_maintains_alternating_roles() {
2434 use crate::llm::types::{CompletionResponse, Role};
2437
2438 struct RecordingProvider {
2439 call_count: Mutex<usize>,
2440 last_messages: Mutex<Vec<Role>>,
2441 }
2442
2443 impl LlmProvider for RecordingProvider {
2444 async fn complete(
2445 &self,
2446 request: CompletionRequest,
2447 ) -> Result<CompletionResponse, Error> {
2448 let mut count = self.call_count.lock().unwrap();
2449 *count += 1;
2450 let roles: Vec<Role> = request.messages.iter().map(|m| m.role.clone()).collect();
2452 *self.last_messages.lock().unwrap() = roles;
2453 Ok(CompletionResponse {
2454 content: vec![ContentBlock::Text {
2455 text: format!("Response #{count}"),
2456 }],
2457 stop_reason: StopReason::EndTurn,
2458 usage: TokenUsage::default(),
2459 model: None,
2460 })
2461 }
2462 }
2463
2464 struct DenyOnce {
2465 denied: Mutex<bool>,
2466 }
2467
2468 impl Guardrail for DenyOnce {
2469 fn post_llm(
2470 &self,
2471 _response: &mut CompletionResponse,
2472 ) -> std::pin::Pin<
2473 Box<dyn std::future::Future<Output = Result<GuardAction, Error>> + Send + '_>,
2474 > {
2475 Box::pin(async {
2476 let mut denied = self.denied.lock().unwrap();
2477 if !*denied {
2478 *denied = true;
2479 Ok(GuardAction::deny("blocked"))
2480 } else {
2481 Ok(GuardAction::Allow)
2482 }
2483 })
2484 }
2485 }
2486
2487 let provider = Arc::new(RecordingProvider {
2488 call_count: Mutex::new(0),
2489 last_messages: Mutex::new(vec![]),
2490 });
2491
2492 let runner = AgentRunner::builder(provider.clone())
2493 .name("test")
2494 .system_prompt("sys")
2495 .guardrail(Arc::new(DenyOnce {
2496 denied: Mutex::new(false),
2497 }))
2498 .max_turns(3)
2499 .build()
2500 .unwrap();
2501
2502 let output = runner.execute("hello").await.unwrap();
2503 assert_eq!(output.result, "Response #2");
2504
2505 let roles = provider.last_messages.lock().unwrap();
2507 for pair in roles.windows(2) {
2508 assert_ne!(
2509 pair[0],
2510 pair[1],
2511 "Found consecutive messages with same role: {:?}",
2512 roles.as_slice()
2513 );
2514 }
2515 }
2516
2517 struct DenyingPreTool {
2518 blocked_tool: String,
2519 reason: String,
2520 }
2521
2522 impl Guardrail for DenyingPreTool {
2523 fn pre_tool(
2524 &self,
2525 call: &crate::llm::types::ToolCall,
2526 ) -> std::pin::Pin<
2527 Box<dyn std::future::Future<Output = Result<GuardAction, Error>> + Send + '_>,
2528 > {
2529 let result = if call.name == self.blocked_tool {
2530 GuardAction::deny(&self.reason)
2531 } else {
2532 GuardAction::Allow
2533 };
2534 Box::pin(async move { Ok(result) })
2535 }
2536 }
2537
2538 #[tokio::test]
2539 async fn pre_tool_guardrail_denies_specific_tool() {
2540 let provider = Arc::new(MockProvider::new(vec![
2542 CompletionResponse {
2543 content: vec![ContentBlock::ToolUse {
2544 id: "c1".into(),
2545 name: "dangerous".into(),
2546 input: json!({}),
2547 }],
2548 stop_reason: StopReason::ToolUse,
2549 usage: TokenUsage::default(),
2550 model: None,
2551 },
2552 CompletionResponse {
2553 content: vec![ContentBlock::Text {
2554 text: "OK, skipping dangerous tool.".into(),
2555 }],
2556 stop_reason: StopReason::EndTurn,
2557 usage: TokenUsage::default(),
2558 model: None,
2559 },
2560 ]));
2561
2562 let runner = AgentRunner::builder(provider)
2563 .name("test")
2564 .system_prompt("sys")
2565 .tool(Arc::new(MockTool::new("dangerous", "should not run")))
2566 .guardrail(Arc::new(DenyingPreTool {
2567 blocked_tool: "dangerous".into(),
2568 reason: "tool is blocked".into(),
2569 }))
2570 .build()
2571 .unwrap();
2572
2573 let output = runner.execute("do something").await.unwrap();
2574 assert_eq!(output.result, "OK, skipping dangerous tool.");
2575 assert_eq!(output.tool_calls_made, 1); }
2577
2578 struct RedactingPostTool;
2579
2580 impl Guardrail for RedactingPostTool {
2581 fn post_tool(
2582 &self,
2583 _call: &crate::llm::types::ToolCall,
2584 output: &mut crate::tool::ToolOutput,
2585 ) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<(), Error>> + Send + '_>>
2586 {
2587 output.content = output.content.replace("SECRET", "[REDACTED]");
2588 Box::pin(async { Ok(()) })
2589 }
2590 }
2591
2592 #[tokio::test]
2593 async fn post_tool_guardrail_redacts_output() {
2594 struct CapturingProvider {
2596 responses: Mutex<Vec<CompletionResponse>>,
2597 tool_results_seen: Mutex<Vec<String>>,
2598 }
2599
2600 impl LlmProvider for CapturingProvider {
2601 async fn complete(
2602 &self,
2603 request: CompletionRequest,
2604 ) -> Result<CompletionResponse, Error> {
2605 for msg in &request.messages {
2607 for block in &msg.content {
2608 if let ContentBlock::ToolResult { content, .. } = block {
2609 self.tool_results_seen.lock().unwrap().push(content.clone());
2610 }
2611 }
2612 }
2613
2614 let mut responses = self.responses.lock().unwrap();
2615 if responses.is_empty() {
2616 return Err(Error::Agent("no more responses".into()));
2617 }
2618 Ok(responses.remove(0))
2619 }
2620 }
2621
2622 let provider = Arc::new(CapturingProvider {
2623 responses: Mutex::new(vec![
2624 CompletionResponse {
2625 content: vec![ContentBlock::ToolUse {
2626 id: "c1".into(),
2627 name: "search".into(),
2628 input: json!({}),
2629 }],
2630 stop_reason: StopReason::ToolUse,
2631 usage: TokenUsage::default(),
2632 model: None,
2633 },
2634 CompletionResponse {
2635 content: vec![ContentBlock::Text {
2636 text: "Done.".into(),
2637 }],
2638 stop_reason: StopReason::EndTurn,
2639 usage: TokenUsage::default(),
2640 model: None,
2641 },
2642 ]),
2643 tool_results_seen: Mutex::new(vec![]),
2644 });
2645
2646 let runner = AgentRunner::builder(provider.clone())
2647 .name("test")
2648 .system_prompt("sys")
2649 .tool(Arc::new(MockTool::new("search", "Found SECRET data")))
2650 .guardrail(Arc::new(RedactingPostTool))
2651 .build()
2652 .unwrap();
2653
2654 runner.execute("search").await.unwrap();
2655
2656 let results = provider.tool_results_seen.lock().unwrap();
2657 assert!(
2658 results.iter().any(|r| r.contains("[REDACTED]")),
2659 "tool result should be redacted: {results:?}"
2660 );
2661 assert!(
2662 !results.iter().any(|r| r.contains("SECRET")),
2663 "tool result should not contain SECRET: {results:?}"
2664 );
2665 }
2666
2667 #[tokio::test]
2668 async fn multiple_guardrails_compose() {
2669 struct AllowGuardrail;
2671 impl Guardrail for AllowGuardrail {}
2672
2673 let provider = Arc::new(MockProvider::new(vec![
2674 CompletionResponse {
2675 content: vec![ContentBlock::ToolUse {
2676 id: "c1".into(),
2677 name: "search".into(),
2678 input: json!({}),
2679 }],
2680 stop_reason: StopReason::ToolUse,
2681 usage: TokenUsage::default(),
2682 model: None,
2683 },
2684 CompletionResponse {
2685 content: vec![ContentBlock::Text {
2686 text: "Denied.".into(),
2687 }],
2688 stop_reason: StopReason::EndTurn,
2689 usage: TokenUsage::default(),
2690 model: None,
2691 },
2692 ]));
2693
2694 let runner = AgentRunner::builder(provider)
2695 .name("test")
2696 .system_prompt("sys")
2697 .tool(Arc::new(MockTool::new("search", "result")))
2698 .guardrail(Arc::new(AllowGuardrail))
2699 .guardrail(Arc::new(DenyingPreTool {
2700 blocked_tool: "search".into(),
2701 reason: "blocked by second guardrail".into(),
2702 }))
2703 .build()
2704 .unwrap();
2705
2706 let output = runner.execute("search").await.unwrap();
2707 assert_eq!(output.result, "Denied.");
2708 }
2709
2710 #[tokio::test]
2711 async fn guardrail_error_aborts_run() {
2712 struct ErrorGuardrail;
2713 impl Guardrail for ErrorGuardrail {
2714 fn pre_llm(
2715 &self,
2716 _request: &mut CompletionRequest,
2717 ) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<(), Error>> + Send + '_>>
2718 {
2719 Box::pin(async { Err(Error::Guardrail("fatal check failed".into())) })
2720 }
2721 }
2722
2723 let provider = Arc::new(MockProvider::new(vec![CompletionResponse {
2724 content: vec![ContentBlock::Text {
2725 text: "should not reach".into(),
2726 }],
2727 stop_reason: StopReason::EndTurn,
2728 usage: TokenUsage::default(),
2729 model: None,
2730 }]));
2731
2732 let runner = AgentRunner::builder(provider)
2733 .name("test")
2734 .system_prompt("sys")
2735 .guardrail(Arc::new(ErrorGuardrail))
2736 .build()
2737 .unwrap();
2738
2739 let err = runner.execute("hello").await.unwrap_err();
2740 assert!(
2741 err.to_string().contains("fatal check failed"),
2742 "error should contain guardrail message: {err}"
2743 );
2744 }
2745
2746 #[tokio::test]
2747 async fn on_approval_and_pre_tool_compose() {
2748 let provider = Arc::new(MockProvider::new(vec![
2750 CompletionResponse {
2751 content: vec![
2752 ContentBlock::ToolUse {
2753 id: "c1".into(),
2754 name: "safe".into(),
2755 input: json!({}),
2756 },
2757 ContentBlock::ToolUse {
2758 id: "c2".into(),
2759 name: "dangerous".into(),
2760 input: json!({}),
2761 },
2762 ],
2763 stop_reason: StopReason::ToolUse,
2764 usage: TokenUsage::default(),
2765 model: None,
2766 },
2767 CompletionResponse {
2768 content: vec![ContentBlock::Text {
2769 text: "Used safe, dangerous blocked.".into(),
2770 }],
2771 stop_reason: StopReason::EndTurn,
2772 usage: TokenUsage::default(),
2773 model: None,
2774 },
2775 ]));
2776
2777 let approval: Arc<crate::llm::OnApproval> =
2778 Arc::new(|_calls: &[_]| crate::llm::ApprovalDecision::Allow);
2779
2780 let runner = AgentRunner::builder(provider)
2781 .name("test")
2782 .system_prompt("sys")
2783 .tool(Arc::new(MockTool::new("safe", "safe result")))
2784 .tool(Arc::new(MockTool::new("dangerous", "should not run")))
2785 .on_approval(approval)
2786 .guardrail(Arc::new(DenyingPreTool {
2787 blocked_tool: "dangerous".into(),
2788 reason: "blocked".into(),
2789 }))
2790 .build()
2791 .unwrap();
2792
2793 let output = runner.execute("do both").await.unwrap();
2794 assert_eq!(output.result, "Used safe, dangerous blocked.");
2795 assert_eq!(output.tool_calls_made, 2);
2796 }
2797
2798 #[tokio::test]
2799 async fn no_guardrails_unchanged_behavior() {
2800 let provider = Arc::new(MockProvider::new(vec![
2801 CompletionResponse {
2802 content: vec![ContentBlock::ToolUse {
2803 id: "c1".into(),
2804 name: "search".into(),
2805 input: json!({}),
2806 }],
2807 stop_reason: StopReason::ToolUse,
2808 usage: TokenUsage::default(),
2809 model: None,
2810 },
2811 CompletionResponse {
2812 content: vec![ContentBlock::Text {
2813 text: "Found it.".into(),
2814 }],
2815 stop_reason: StopReason::EndTurn,
2816 usage: TokenUsage::default(),
2817 model: None,
2818 },
2819 ]));
2820
2821 let runner = AgentRunner::builder(provider)
2822 .name("test")
2823 .system_prompt("sys")
2824 .tool(Arc::new(MockTool::new("search", "result")))
2825 .build()
2826 .unwrap();
2827
2828 let output = runner.execute("search").await.unwrap();
2829 assert_eq!(output.result, "Found it.");
2830 assert_eq!(output.tool_calls_made, 1);
2831 }
2832
2833 #[tokio::test]
2834 async fn on_input_continues_conversation() {
2835 let provider = Arc::new(MockProvider::new(vec![
2838 CompletionResponse {
2839 content: vec![ContentBlock::Text {
2840 text: "Hello! How can I help?".into(),
2841 }],
2842 stop_reason: StopReason::EndTurn,
2843 usage: TokenUsage::default(),
2844 model: None,
2845 },
2846 CompletionResponse {
2847 content: vec![ContentBlock::Text {
2848 text: "Sure, here you go.".into(),
2849 }],
2850 stop_reason: StopReason::EndTurn,
2851 usage: TokenUsage::default(),
2852 model: None,
2853 },
2854 ]));
2855
2856 let call_count = Arc::new(std::sync::atomic::AtomicUsize::new(0));
2857 let call_count_clone = call_count.clone();
2858
2859 let on_input: Arc<OnInput> = Arc::new(move || {
2860 let count = call_count_clone.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
2861 Box::pin(async move {
2862 match count {
2863 0 => Some("Tell me more.".into()),
2864 _ => None, }
2866 })
2867 });
2868
2869 let runner = AgentRunner::builder(provider)
2870 .name("test")
2871 .system_prompt("sys")
2872 .max_turns(10)
2873 .on_input(on_input)
2874 .build()
2875 .unwrap();
2876
2877 let output = runner.execute("Hi").await.unwrap();
2878 assert_eq!(output.result, "Sure, here you go.");
2880 assert_eq!(call_count.load(std::sync::atomic::Ordering::SeqCst), 2);
2882 }
2883
2884 #[tokio::test]
2885 async fn on_input_empty_string_ends_session() {
2886 let provider = Arc::new(MockProvider::new(vec![CompletionResponse {
2887 content: vec![ContentBlock::Text {
2888 text: "Response.".into(),
2889 }],
2890 stop_reason: StopReason::EndTurn,
2891 usage: TokenUsage::default(),
2892 model: None,
2893 }]));
2894
2895 let on_input: Arc<OnInput> = Arc::new(|| {
2896 Box::pin(async { Some(" ".into()) }) });
2898
2899 let runner = AgentRunner::builder(provider)
2900 .name("test")
2901 .system_prompt("sys")
2902 .max_turns(10)
2903 .on_input(on_input)
2904 .build()
2905 .unwrap();
2906
2907 let output = runner.execute("Hi").await.unwrap();
2908 assert_eq!(output.result, "Response.");
2909 }
2910
2911 #[tokio::test]
2912 async fn post_tool_guardrail_error_emits_event() {
2913 use std::sync::atomic::{AtomicBool, Ordering};
2914
2915 struct FailingPostTool;
2916 impl Guardrail for FailingPostTool {
2917 fn post_tool(
2918 &self,
2919 _call: &crate::llm::types::ToolCall,
2920 _output: &mut ToolOutput,
2921 ) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<(), Error>> + Send + '_>>
2922 {
2923 Box::pin(async { Err(Error::Guardrail("output too large".into())) })
2924 }
2925 }
2926
2927 let provider = Arc::new(MockProvider::new(vec![
2928 CompletionResponse {
2929 content: vec![ContentBlock::ToolUse {
2930 id: "c1".into(),
2931 name: "search".into(),
2932 input: json!({}),
2933 }],
2934 stop_reason: StopReason::ToolUse,
2935 usage: TokenUsage::default(),
2936 model: None,
2937 },
2938 CompletionResponse {
2939 content: vec![ContentBlock::Text { text: "OK.".into() }],
2940 stop_reason: StopReason::EndTurn,
2941 usage: TokenUsage::default(),
2942 model: None,
2943 },
2944 ]));
2945
2946 let saw_post_tool_event = Arc::new(AtomicBool::new(false));
2947 let saw_clone = saw_post_tool_event.clone();
2948 let on_event: Arc<OnEvent> = Arc::new(move |event| {
2949 if let AgentEvent::GuardrailDenied { hook, .. } = &event
2950 && hook == "post_tool"
2951 {
2952 saw_clone.store(true, Ordering::SeqCst);
2953 }
2954 });
2955
2956 let runner = AgentRunner::builder(provider)
2957 .name("test")
2958 .system_prompt("sys")
2959 .tool(Arc::new(MockTool::new("search", "result")))
2960 .guardrail(Arc::new(FailingPostTool))
2961 .on_event(on_event)
2962 .build()
2963 .unwrap();
2964
2965 runner.execute("search").await.unwrap();
2966 assert!(
2967 saw_post_tool_event.load(Ordering::SeqCst),
2968 "should have emitted GuardrailDenied event with hook=post_tool"
2969 );
2970 }
2971
2972 #[tokio::test]
2978 async fn post_tool_guardrail_error_audit_carries_tenant_user() {
2979 struct FailingPostTool;
2980 impl Guardrail for FailingPostTool {
2981 fn post_tool(
2982 &self,
2983 _call: &crate::llm::types::ToolCall,
2984 _output: &mut ToolOutput,
2985 ) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<(), Error>> + Send + '_>>
2986 {
2987 Box::pin(async { Err(Error::Guardrail("denied by policy".into())) })
2988 }
2989 }
2990
2991 let provider = Arc::new(MockProvider::new(vec![
2992 CompletionResponse {
2993 content: vec![ContentBlock::ToolUse {
2994 id: "c1".into(),
2995 name: "search".into(),
2996 input: json!({}),
2997 }],
2998 stop_reason: StopReason::ToolUse,
2999 usage: TokenUsage::default(),
3000 model: None,
3001 },
3002 CompletionResponse {
3003 content: vec![ContentBlock::Text { text: "OK.".into() }],
3004 stop_reason: StopReason::EndTurn,
3005 usage: TokenUsage::default(),
3006 model: None,
3007 },
3008 ]));
3009
3010 let trail = Arc::new(crate::agent::audit::InMemoryAuditTrail::new());
3011 let runner = AgentRunner::builder(provider)
3012 .name("test")
3013 .system_prompt("sys")
3014 .tool(Arc::new(MockTool::new("search", "result")))
3015 .guardrail(Arc::new(FailingPostTool))
3016 .audit_trail(trail.clone())
3017 .audit_user_context("alice", "tenant-1")
3018 .build()
3019 .unwrap();
3020
3021 runner.execute("search").await.unwrap();
3022
3023 let entries = trail.entries_unscoped(usize::MAX).await.unwrap();
3024 let denial = entries
3025 .iter()
3026 .find(|e| e.event_type == "guardrail_denied")
3027 .expect("expected a guardrail_denied audit record");
3028 assert_eq!(
3029 denial.user_id.as_deref(),
3030 Some("alice"),
3031 "audit record should carry user_id: {denial:?}"
3032 );
3033 assert_eq!(
3034 denial.tenant_id.as_deref(),
3035 Some("tenant-1"),
3036 "audit record should carry tenant_id: {denial:?}"
3037 );
3038 }
3039
3040 #[tokio::test]
3041 async fn without_on_input_returns_immediately() {
3042 let provider = Arc::new(MockProvider::new(vec![CompletionResponse {
3044 content: vec![ContentBlock::Text {
3045 text: "Done.".into(),
3046 }],
3047 stop_reason: StopReason::EndTurn,
3048 usage: TokenUsage::default(),
3049 model: None,
3050 }]));
3051
3052 let runner = AgentRunner::builder(provider)
3053 .name("test")
3054 .system_prompt("sys")
3055 .build()
3056 .unwrap();
3057
3058 let output = runner.execute("Hi").await.unwrap();
3059 assert_eq!(output.result, "Done.");
3060 }
3061
3062 #[tokio::test]
3063 async fn run_timeout_preserves_partial_usage() {
3064 struct SlowProvider;
3067 impl LlmProvider for SlowProvider {
3068 async fn complete(
3069 &self,
3070 request: CompletionRequest,
3071 ) -> Result<CompletionResponse, Error> {
3072 if request.messages.len() <= 1 {
3074 return Ok(CompletionResponse {
3075 content: vec![ContentBlock::ToolUse {
3076 id: "tc1".into(),
3077 name: "echo".into(),
3078 input: json!({}),
3079 }],
3080 stop_reason: StopReason::ToolUse,
3081 usage: TokenUsage {
3082 input_tokens: 100,
3083 output_tokens: 50,
3084 ..Default::default()
3085 },
3086 model: None,
3087 });
3088 }
3089 tokio::time::sleep(Duration::from_secs(3600)).await;
3091 unreachable!()
3092 }
3093 }
3094
3095 let provider = Arc::new(SlowProvider);
3096 let tool = Arc::new(MockTool::new("echo", "echoed"));
3097 let runner = AgentRunner::builder(provider)
3098 .name("timeout-test")
3099 .system_prompt("sys")
3100 .tool(tool)
3101 .max_turns(10)
3102 .run_timeout(Duration::from_millis(100))
3103 .build()
3104 .unwrap();
3105
3106 let err = runner.execute("go").await.unwrap_err();
3107 assert!(
3108 matches!(&err, Error::WithPartialUsage { source, .. }
3109 if matches!(**source, Error::RunTimeout(_))),
3110 "expected WithPartialUsage(RunTimeout), got: {err}"
3111 );
3112 let usage = err.partial_usage();
3113 assert_eq!(usage.input_tokens, 100, "should preserve input tokens");
3114 assert_eq!(usage.output_tokens, 50, "should preserve output tokens");
3115 }
3116
3117 #[tokio::test]
3118 async fn run_timeout_without_accumulated_usage() {
3119 struct ImmediatelySlowProvider;
3121 impl LlmProvider for ImmediatelySlowProvider {
3122 async fn complete(
3123 &self,
3124 _request: CompletionRequest,
3125 ) -> Result<CompletionResponse, Error> {
3126 tokio::time::sleep(Duration::from_secs(3600)).await;
3127 unreachable!()
3128 }
3129 }
3130
3131 let provider = Arc::new(ImmediatelySlowProvider);
3132 let runner = AgentRunner::builder(provider)
3133 .name("timeout-test")
3134 .system_prompt("sys")
3135 .run_timeout(Duration::from_millis(50))
3136 .build()
3137 .unwrap();
3138
3139 let err = runner.execute("go").await.unwrap_err();
3140 assert!(
3141 matches!(&err, Error::WithPartialUsage { source, .. }
3142 if matches!(**source, Error::RunTimeout(_))),
3143 "expected WithPartialUsage(RunTimeout), got: {err}"
3144 );
3145 let usage = err.partial_usage();
3146 assert_eq!(usage.input_tokens, 0);
3147 assert_eq!(usage.output_tokens, 0);
3148 }
3149
3150 #[tokio::test]
3151 async fn llm_error_mid_run_preserves_partial_usage() {
3152 struct FailOnSecondCall;
3155 impl LlmProvider for FailOnSecondCall {
3156 async fn complete(
3157 &self,
3158 request: CompletionRequest,
3159 ) -> Result<CompletionResponse, Error> {
3160 if request.messages.len() <= 1 {
3161 return Ok(CompletionResponse {
3162 content: vec![ContentBlock::ToolUse {
3163 id: "tc1".into(),
3164 name: "echo".into(),
3165 input: json!({}),
3166 }],
3167 stop_reason: StopReason::ToolUse,
3168 usage: TokenUsage {
3169 input_tokens: 200,
3170 output_tokens: 80,
3171 ..Default::default()
3172 },
3173 model: None,
3174 });
3175 }
3176 Err(Error::Api {
3177 status: 500,
3178 message: "internal server error".into(),
3179 })
3180 }
3181 }
3182
3183 let provider = Arc::new(FailOnSecondCall);
3184 let tool = Arc::new(MockTool::new("echo", "echoed"));
3185 let runner = AgentRunner::builder(provider)
3186 .name("mid-error-test")
3187 .system_prompt("sys")
3188 .tool(tool)
3189 .max_turns(10)
3190 .build()
3191 .unwrap();
3192
3193 let err = runner.execute("go").await.unwrap_err();
3194 assert!(
3195 matches!(&err, Error::WithPartialUsage { source, .. }
3196 if matches!(**source, Error::Api { status: 500, .. })),
3197 "expected WithPartialUsage(Api{{500}}), got: {err}"
3198 );
3199 let usage = err.partial_usage();
3200 assert_eq!(
3201 usage.input_tokens, 200,
3202 "should preserve input tokens from turn 1"
3203 );
3204 assert_eq!(
3205 usage.output_tokens, 80,
3206 "should preserve output tokens from turn 1"
3207 );
3208 }
3209
3210 #[tokio::test]
3213 async fn reflection_prompt_injected_after_tool_results() {
3214 struct ReflectionCapture {
3216 responses: Mutex<Vec<CompletionResponse>>,
3217 user_messages: Mutex<Vec<String>>,
3218 }
3219 impl LlmProvider for ReflectionCapture {
3220 async fn complete(
3221 &self,
3222 request: CompletionRequest,
3223 ) -> Result<CompletionResponse, Error> {
3224 for msg in &request.messages {
3225 if msg.role == crate::llm::types::Role::User {
3226 for block in &msg.content {
3227 if let ContentBlock::Text { text } = block {
3228 self.user_messages.lock().unwrap().push(text.clone());
3229 }
3230 }
3231 }
3232 }
3233 let mut responses = self.responses.lock().unwrap();
3234 if responses.is_empty() {
3235 return Err(Error::Agent("no more responses".into()));
3236 }
3237 Ok(responses.remove(0))
3238 }
3239 }
3240
3241 let provider = Arc::new(ReflectionCapture {
3242 responses: Mutex::new(vec![
3243 CompletionResponse {
3245 content: vec![ContentBlock::ToolUse {
3246 id: "t1".into(),
3247 name: "search".into(),
3248 input: json!({}),
3249 }],
3250 stop_reason: StopReason::ToolUse,
3251 usage: TokenUsage::default(),
3252 model: None,
3253 },
3254 CompletionResponse {
3256 content: vec![ContentBlock::Text {
3257 text: "Done.".into(),
3258 }],
3259 stop_reason: StopReason::EndTurn,
3260 usage: TokenUsage::default(),
3261 model: None,
3262 },
3263 ]),
3264 user_messages: Mutex::new(vec![]),
3265 });
3266
3267 let tool = Arc::new(MockTool::new("search", "found results"));
3268 let runner = AgentRunner::builder(provider.clone())
3269 .name("reflector")
3270 .system_prompt("sys")
3271 .tool(tool)
3272 .enable_reflection(true)
3273 .build()
3274 .unwrap();
3275
3276 let output = runner.execute("do something").await.unwrap();
3277 assert_eq!(output.result, "Done.");
3278
3279 let msgs = provider.user_messages.lock().unwrap();
3280 assert!(
3282 msgs.iter()
3283 .any(|m| m.contains("Before proceeding, briefly reflect")),
3284 "expected reflection prompt in user messages, got: {msgs:?}"
3285 );
3286 }
3287
3288 #[tokio::test]
3289 async fn reflection_not_injected_when_disabled() {
3290 struct ReflectionCapture {
3291 responses: Mutex<Vec<CompletionResponse>>,
3292 user_messages: Mutex<Vec<String>>,
3293 }
3294 impl LlmProvider for ReflectionCapture {
3295 async fn complete(
3296 &self,
3297 request: CompletionRequest,
3298 ) -> Result<CompletionResponse, Error> {
3299 for msg in &request.messages {
3300 if msg.role == crate::llm::types::Role::User {
3301 for block in &msg.content {
3302 if let ContentBlock::Text { text } = block {
3303 self.user_messages.lock().unwrap().push(text.clone());
3304 }
3305 }
3306 }
3307 }
3308 let mut responses = self.responses.lock().unwrap();
3309 if responses.is_empty() {
3310 return Err(Error::Agent("no more responses".into()));
3311 }
3312 Ok(responses.remove(0))
3313 }
3314 }
3315
3316 let provider = Arc::new(ReflectionCapture {
3317 responses: Mutex::new(vec![
3318 CompletionResponse {
3319 content: vec![ContentBlock::ToolUse {
3320 id: "t1".into(),
3321 name: "search".into(),
3322 input: json!({}),
3323 }],
3324 stop_reason: StopReason::ToolUse,
3325 usage: TokenUsage::default(),
3326 model: None,
3327 },
3328 CompletionResponse {
3329 content: vec![ContentBlock::Text {
3330 text: "Done.".into(),
3331 }],
3332 stop_reason: StopReason::EndTurn,
3333 usage: TokenUsage::default(),
3334 model: None,
3335 },
3336 ]),
3337 user_messages: Mutex::new(vec![]),
3338 });
3339
3340 let tool = Arc::new(MockTool::new("search", "found results"));
3341 let runner = AgentRunner::builder(provider.clone())
3343 .name("no-reflect")
3344 .system_prompt("sys")
3345 .tool(tool)
3346 .build()
3347 .unwrap();
3348
3349 let output = runner.execute("do something").await.unwrap();
3350 assert_eq!(output.result, "Done.");
3351
3352 let msgs = provider.user_messages.lock().unwrap();
3353 assert!(
3354 !msgs.iter().any(|m| m.contains("reflect")),
3355 "should not contain reflection prompt, got: {msgs:?}"
3356 );
3357 }
3358
3359 #[tokio::test]
3360 async fn reflection_not_injected_when_no_tool_calls() {
3361 struct ReflectionCapture {
3362 responses: Mutex<Vec<CompletionResponse>>,
3363 user_messages: Mutex<Vec<String>>,
3364 }
3365 impl LlmProvider for ReflectionCapture {
3366 async fn complete(
3367 &self,
3368 request: CompletionRequest,
3369 ) -> Result<CompletionResponse, Error> {
3370 for msg in &request.messages {
3371 if msg.role == crate::llm::types::Role::User {
3372 for block in &msg.content {
3373 if let ContentBlock::Text { text } = block {
3374 self.user_messages.lock().unwrap().push(text.clone());
3375 }
3376 }
3377 }
3378 }
3379 let mut responses = self.responses.lock().unwrap();
3380 if responses.is_empty() {
3381 return Err(Error::Agent("no more responses".into()));
3382 }
3383 Ok(responses.remove(0))
3384 }
3385 }
3386
3387 let provider = Arc::new(ReflectionCapture {
3388 responses: Mutex::new(vec![CompletionResponse {
3389 content: vec![ContentBlock::Text {
3390 text: "Direct answer.".into(),
3391 }],
3392 stop_reason: StopReason::EndTurn,
3393 usage: TokenUsage::default(),
3394 model: None,
3395 }]),
3396 user_messages: Mutex::new(vec![]),
3397 });
3398
3399 let runner = AgentRunner::builder(provider.clone())
3401 .name("no-tools")
3402 .system_prompt("sys")
3403 .enable_reflection(true)
3404 .build()
3405 .unwrap();
3406
3407 let output = runner.execute("just answer").await.unwrap();
3408 assert_eq!(output.result, "Direct answer.");
3409
3410 let msgs = provider.user_messages.lock().unwrap();
3411 assert!(
3412 !msgs.iter().any(|m| m.contains("reflect")),
3413 "no reflection when no tool calls, got: {msgs:?}"
3414 );
3415 }
3416
3417 #[tokio::test]
3420 async fn compress_short_output_unchanged() {
3421 let provider = Arc::new(MockProvider::new(vec![
3422 CompletionResponse {
3424 content: vec![ContentBlock::ToolUse {
3425 id: "t1".into(),
3426 name: "search".into(),
3427 input: json!({}),
3428 }],
3429 stop_reason: StopReason::ToolUse,
3430 usage: TokenUsage::default(),
3431 model: None,
3432 },
3433 CompletionResponse {
3435 content: vec![ContentBlock::Text {
3436 text: "Done.".into(),
3437 }],
3438 stop_reason: StopReason::EndTurn,
3439 usage: TokenUsage::default(),
3440 model: None,
3441 },
3442 ]));
3443
3444 let tool = Arc::new(MockTool::new("search", "short result"));
3445 let runner = AgentRunner::builder(provider)
3446 .name("compressor")
3447 .system_prompt("sys")
3448 .tool(tool)
3449 .tool_output_compression_threshold(10000)
3451 .build()
3452 .unwrap();
3453
3454 let output = runner.execute("search something").await.unwrap();
3455 assert_eq!(output.result, "Done.");
3456 assert_eq!(output.tool_calls_made, 1);
3458 }
3459
3460 #[tokio::test]
3461 async fn compress_long_output_calls_llm() {
3462 struct CompressionProvider {
3464 responses: Mutex<Vec<CompletionResponse>>,
3465 call_count: Mutex<usize>,
3466 }
3467 impl LlmProvider for CompressionProvider {
3468 async fn complete(
3469 &self,
3470 _request: CompletionRequest,
3471 ) -> Result<CompletionResponse, Error> {
3472 let mut count = self.call_count.lock().unwrap();
3473 *count += 1;
3474 let mut responses = self.responses.lock().unwrap();
3475 if responses.is_empty() {
3476 return Err(Error::Agent("no more responses".into()));
3477 }
3478 Ok(responses.remove(0))
3479 }
3480 }
3481
3482 let provider = Arc::new(CompressionProvider {
3483 responses: Mutex::new(vec![
3484 CompletionResponse {
3486 content: vec![ContentBlock::ToolUse {
3487 id: "t1".into(),
3488 name: "read".into(),
3489 input: json!({}),
3490 }],
3491 stop_reason: StopReason::ToolUse,
3492 usage: TokenUsage::default(),
3493 model: None,
3494 },
3495 CompletionResponse {
3497 content: vec![ContentBlock::Text {
3498 text: "Compressed summary of large file.".into(),
3499 }],
3500 stop_reason: StopReason::EndTurn,
3501 usage: TokenUsage {
3502 input_tokens: 50,
3503 output_tokens: 10,
3504 ..Default::default()
3505 },
3506 model: None,
3507 },
3508 CompletionResponse {
3510 content: vec![ContentBlock::Text {
3511 text: "Here's the result.".into(),
3512 }],
3513 stop_reason: StopReason::EndTurn,
3514 usage: TokenUsage::default(),
3515 model: None,
3516 },
3517 ]),
3518 call_count: Mutex::new(0),
3519 });
3520
3521 let large_output = "x".repeat(200);
3523 let tool = Arc::new(MockTool::new("read", &large_output));
3524 let runner = AgentRunner::builder(provider.clone())
3525 .name("compressor")
3526 .system_prompt("sys")
3527 .tool(tool)
3528 .tool_output_compression_threshold(50)
3529 .build()
3530 .unwrap();
3531
3532 let output = runner.execute("read the file").await.unwrap();
3533 assert_eq!(output.result, "Here's the result.");
3534 let calls = *provider.call_count.lock().unwrap();
3536 assert_eq!(calls, 3, "expected 3 LLM calls (tool + compress + answer)");
3537 assert_eq!(output.tokens_used.input_tokens, 50);
3539 assert_eq!(output.tokens_used.output_tokens, 10);
3540 }
3541
3542 #[tokio::test]
3543 async fn compression_preserves_error_status() {
3544 let provider = Arc::new(MockProvider::new(vec![
3546 CompletionResponse {
3547 content: vec![ContentBlock::ToolUse {
3548 id: "t1".into(),
3549 name: "failing_tool".into(),
3550 input: json!({}),
3551 }],
3552 stop_reason: StopReason::ToolUse,
3553 usage: TokenUsage::default(),
3554 model: None,
3555 },
3556 CompletionResponse {
3557 content: vec![ContentBlock::Text {
3558 text: "Tool failed.".into(),
3559 }],
3560 stop_reason: StopReason::EndTurn,
3561 usage: TokenUsage::default(),
3562 model: None,
3563 },
3564 ]));
3565
3566 let large_error = "e".repeat(200);
3567 let tool = Arc::new(MockTool::failing("failing_tool", &large_error));
3568 let runner = AgentRunner::builder(provider)
3569 .name("compressor")
3570 .system_prompt("sys")
3571 .tool(tool)
3572 .tool_output_compression_threshold(50)
3573 .build()
3574 .unwrap();
3575
3576 let output = runner.execute("try something").await.unwrap();
3577 assert_eq!(output.result, "Tool failed.");
3578 assert_eq!(output.tool_calls_made, 1);
3580 }
3581
/// With two tool definitions and a limit of ten, `select_tools_for_turn`
/// must return both.
#[test]
fn select_tools_returns_all_when_below_max() {
    let provider = Arc::new(MockProvider::new(vec![]));
    let runner = AgentRunner::builder(provider)
        .name("selector")
        .system_prompt("sys")
        .max_tools_per_turn(10)
        .build()
        .unwrap();

    let tool_a = ToolDefinition {
        name: "a".into(),
        description: "Tool A".into(),
        input_schema: json!({"type": "object"}),
    };
    let tool_b = ToolDefinition {
        name: "b".into(),
        description: "Tool B".into(),
        input_schema: json!({"type": "object"}),
    };
    let tools = vec![tool_a, tool_b];

    // No conversation history and no recently used tools.
    let selected = runner.select_tools_for_turn(&tools, &[], &[], 10);
    assert_eq!(selected.len(), 2, "should return all when below max");
}
3610
/// With five tools and a limit of two, the selection must hold exactly two
/// tools and include the recently used `tool_3`.
#[test]
fn select_tools_includes_recently_used() {
    let provider = Arc::new(MockProvider::new(vec![]));
    let runner = AgentRunner::builder(provider)
        .name("selector")
        .system_prompt("sys")
        .max_tools_per_turn(2)
        .build()
        .unwrap();

    // tool_0 .. tool_4
    let mut tools: Vec<ToolDefinition> = Vec::with_capacity(5);
    for i in 0..5 {
        tools.push(ToolDefinition {
            name: format!("tool_{i}"),
            description: format!("Tool number {i}"),
            input_schema: json!({"type": "object"}),
        });
    }

    let recently_used = [String::from("tool_3")];
    let selected = runner.select_tools_for_turn(&tools, &[], &recently_used, 2);

    assert_eq!(selected.len(), 2, "should cap at max");
    let includes_recent = selected.iter().any(|t| t.name == "tool_3");
    assert!(includes_recent, "recently used tool must be included");
}
3639
/// With four tools, a limit of two and a user message about searching the
/// web, the selection must hold two tools and include `web_search`.
#[test]
fn select_tools_keyword_match_ranking() {
    let provider = Arc::new(MockProvider::new(vec![]));
    let runner = AgentRunner::builder(provider)
        .name("selector")
        .system_prompt("sys")
        .max_tools_per_turn(2)
        .build()
        .unwrap();

    // (name, description) pairs, in the order the tools are offered.
    let catalog = [
        ("web_search", "Search the web"),
        ("read_file", "Read a file from disk"),
        ("write_file", "Write a file to disk"),
        ("run_command", "Run a shell command"),
    ];
    let tools: Vec<ToolDefinition> = catalog
        .into_iter()
        .map(|(name, description)| ToolDefinition {
            name: name.into(),
            description: description.into(),
            input_schema: json!({"type": "object"}),
        })
        .collect();

    let messages = vec![Message::user(
        "Please search the web for information.".to_string(),
    )];
    let selected = runner.select_tools_for_turn(&tools, &messages, &[], 2);

    assert_eq!(selected.len(), 2);
    let picked_web_search = selected.iter().any(|t| t.name == "web_search");
    assert!(
        picked_web_search,
        "web_search should be selected by keyword match, got: {:?}",
        selected.iter().map(|t| &t.name).collect::<Vec<_>>()
    );
}
3687
3688 #[test]
3689 fn select_tools_caps_at_max() {
3690 let provider = Arc::new(MockProvider::new(vec![]));
3691 let runner = AgentRunner::builder(provider)
3692 .name("selector")
3693 .system_prompt("sys")
3694 .max_tools_per_turn(3)
3695 .build()
3696 .unwrap();
3697
3698 let tools: Vec<ToolDefinition> = (0..10)
3699 .map(|i| ToolDefinition {
3700 name: format!("tool_{i}"),
3701 description: format!("Tool number {i}"),
3702 input_schema: json!({"type": "object"}),
3703 })
3704 .collect();
3705
3706 let selected = runner.select_tools_for_turn(&tools, &[], &[], 3);
3707 assert_eq!(selected.len(), 3, "should cap at max_tools");
3708 }
3709
3710 #[test]
3711 fn select_tools_caps_when_recently_used_exceeds_max() {
3712 let provider = Arc::new(MockProvider::new(vec![]));
3713 let runner = AgentRunner::builder(provider)
3714 .name("selector")
3715 .system_prompt("sys")
3716 .build()
3717 .unwrap();
3718
3719 let tools: Vec<ToolDefinition> = (0..5)
3720 .map(|i| ToolDefinition {
3721 name: format!("tool_{i}"),
3722 description: format!("Tool {i}"),
3723 input_schema: json!({"type": "object"}),
3724 })
3725 .collect();
3726
3727 let recently_used: Vec<String> = (0..4).map(|i| format!("tool_{i}")).collect();
3729 let selected = runner.select_tools_for_turn(&tools, &[], &recently_used, 2);
3730 assert_eq!(
3731 selected.len(),
3732 2,
3733 "should cap at max_tools even when recently_used exceeds it"
3734 );
3735 }
3736
3737 #[test]
3738 fn select_tools_preserves_respond_tool() {
3739 let provider = Arc::new(MockProvider::new(vec![]));
3740 let runner = AgentRunner::builder(provider)
3741 .name("test")
3742 .system_prompt("sys")
3743 .build()
3744 .unwrap();
3745
3746 let tools: Vec<ToolDefinition> = vec![
3747 ToolDefinition {
3748 name: "bash".into(),
3749 description: "Run commands".into(),
3750 input_schema: json!({"type": "object"}),
3751 },
3752 ToolDefinition {
3753 name: "read".into(),
3754 description: "Read files".into(),
3755 input_schema: json!({"type": "object"}),
3756 },
3757 ToolDefinition {
3758 name: "write".into(),
3759 description: "Write files".into(),
3760 input_schema: json!({"type": "object"}),
3761 },
3762 ToolDefinition {
3763 name: crate::llm::types::RESPOND_TOOL_NAME.into(),
3764 description: "Structured output".into(),
3765 input_schema: json!({"type": "object"}),
3766 },
3767 ];
3768
3769 let selected = runner.select_tools_for_turn(&tools, &[], &[], 2);
3771 assert!(
3772 selected.iter().any(|t| t.name == "__respond__"),
3773 "__respond__ must always survive select_tools_for_turn"
3774 );
3775 }
3776
3777 #[test]
3780 fn find_closest_tool_exact_match_returns_none() {
3781 let provider = Arc::new(MockProvider::new(vec![]));
3782 let runner = AgentRunner::builder(provider)
3783 .name("test")
3784 .system_prompt("sys")
3785 .tool(Arc::new(MockTool::new("read_file", "ok")))
3786 .build()
3787 .unwrap();
3788 assert!(runner.find_closest_tool("read_file", 2).is_none());
3789 }
3790
3791 #[test]
3792 fn find_closest_tool_within_distance() {
3793 let provider = Arc::new(MockProvider::new(vec![]));
3794 let runner = AgentRunner::builder(provider)
3795 .name("test")
3796 .system_prompt("sys")
3797 .tool(Arc::new(MockTool::new("read_file", "ok")))
3798 .build()
3799 .unwrap();
3800 assert_eq!(runner.find_closest_tool("reed_file", 2), Some("read_file"));
3801 }
3802
3803 #[test]
3804 fn find_closest_tool_too_far() {
3805 let provider = Arc::new(MockProvider::new(vec![]));
3806 let runner = AgentRunner::builder(provider)
3807 .name("test")
3808 .system_prompt("sys")
3809 .tool(Arc::new(MockTool::new("read_file", "ok")))
3810 .build()
3811 .unwrap();
3812 assert!(runner.find_closest_tool("completely_wrong", 2).is_none());
3813 }
3814
3815 #[test]
3816 fn find_closest_tool_prefers_closest() {
3817 let provider = Arc::new(MockProvider::new(vec![]));
3818 let runner = AgentRunner::builder(provider)
3819 .name("test")
3820 .system_prompt("sys")
3821 .tool(Arc::new(MockTool::new("read_fil", "ok")))
3822 .tool(Arc::new(MockTool::new("read_file", "ok")))
3823 .build()
3824 .unwrap();
3825 assert_eq!(runner.find_closest_tool("read_fi", 2), Some("read_fil"));
3826 }
3827
3828 #[tokio::test]
3829 async fn tool_name_repair_executes_correct_tool() {
3830 let provider = Arc::new(MockProvider::new(vec![
3831 CompletionResponse {
3832 content: vec![ContentBlock::ToolUse {
3833 id: "tc1".into(),
3834 name: "reed_file".into(),
3835 input: json!({}),
3836 }],
3837 stop_reason: StopReason::ToolUse,
3838 usage: TokenUsage {
3839 input_tokens: 10,
3840 output_tokens: 5,
3841 ..Default::default()
3842 },
3843 model: None,
3844 },
3845 CompletionResponse {
3846 content: vec![ContentBlock::Text {
3847 text: "Done!".into(),
3848 }],
3849 stop_reason: StopReason::EndTurn,
3850 usage: TokenUsage {
3851 input_tokens: 15,
3852 output_tokens: 3,
3853 ..Default::default()
3854 },
3855 model: None,
3856 },
3857 ]));
3858 let runner = AgentRunner::builder(provider)
3859 .name("repair-test")
3860 .system_prompt("sys")
3861 .tool(Arc::new(MockTool::new("read_file", "file contents here")))
3862 .build()
3863 .unwrap();
3864 let output = runner.execute("read the file").await.unwrap();
3865 assert_eq!(output.result, "Done!");
3866 assert_eq!(output.tool_calls_made, 1);
3867 }
3868
3869 #[tokio::test]
3870 async fn tool_name_too_far_returns_error() {
3871 let provider = Arc::new(MockProvider::new(vec![
3872 CompletionResponse {
3873 content: vec![ContentBlock::ToolUse {
3874 id: "tc1".into(),
3875 name: "completely_wrong".into(),
3876 input: json!({}),
3877 }],
3878 stop_reason: StopReason::ToolUse,
3879 usage: TokenUsage {
3880 input_tokens: 10,
3881 output_tokens: 5,
3882 ..Default::default()
3883 },
3884 model: None,
3885 },
3886 CompletionResponse {
3887 content: vec![ContentBlock::Text {
3888 text: "Error handled".into(),
3889 }],
3890 stop_reason: StopReason::EndTurn,
3891 usage: TokenUsage {
3892 input_tokens: 15,
3893 output_tokens: 3,
3894 ..Default::default()
3895 },
3896 model: None,
3897 },
3898 ]));
3899 let runner = AgentRunner::builder(provider)
3900 .name("repair-test")
3901 .system_prompt("sys")
3902 .tool(Arc::new(MockTool::new("read_file", "file contents here")))
3903 .build()
3904 .unwrap();
3905 let output = runner.execute("do something").await.unwrap();
3906 assert_eq!(output.result, "Error handled");
3907 assert_eq!(output.tool_calls_made, 1);
3908 }
3909
3910 struct FallibleMockProvider {
3913 responses: Mutex<Vec<Result<CompletionResponse, Error>>>,
3914 }
3915
3916 impl FallibleMockProvider {
3917 fn new(responses: Vec<Result<CompletionResponse, Error>>) -> Self {
3918 Self {
3919 responses: Mutex::new(responses),
3920 }
3921 }
3922 }
3923
3924 impl LlmProvider for FallibleMockProvider {
3925 async fn complete(&self, _request: CompletionRequest) -> Result<CompletionResponse, Error> {
3926 let mut responses = self.responses.lock().expect("mock lock poisoned");
3927 if responses.is_empty() {
3928 return Err(Error::Agent("no more mock responses".into()));
3929 }
3930 responses.remove(0)
3931 }
3932 }
3933
3934 fn overflow_error() -> Error {
3935 Error::Api {
3936 status: 400,
3937 message: "prompt is too long: 250000 tokens > 200000 maximum".into(),
3938 }
3939 }
3940
3941 fn success_response(text: &str) -> CompletionResponse {
3942 CompletionResponse {
3943 content: vec![ContentBlock::Text { text: text.into() }],
3944 stop_reason: StopReason::EndTurn,
3945 usage: TokenUsage {
3946 input_tokens: 10,
3947 output_tokens: 5,
3948 ..Default::default()
3949 },
3950 model: None,
3951 }
3952 }
3953
3954 fn tool_use_response(id: &str, tool_name: &str) -> CompletionResponse {
3955 CompletionResponse {
3956 content: vec![ContentBlock::ToolUse {
3957 id: id.into(),
3958 name: tool_name.into(),
3959 input: json!({}),
3960 }],
3961 stop_reason: StopReason::ToolUse,
3962 usage: TokenUsage {
3963 input_tokens: 10,
3964 output_tokens: 5,
3965 ..Default::default()
3966 },
3967 model: None,
3968 }
3969 }
3970
3971 #[tokio::test]
3972 async fn auto_compaction_on_context_overflow() {
3973 let events: Arc<std::sync::Mutex<Vec<AgentEvent>>> =
3984 Arc::new(std::sync::Mutex::new(vec![]));
3985 let events_clone = events.clone();
3986
3987 let provider = Arc::new(FallibleMockProvider::new(vec![
3988 Ok(tool_use_response("c1", "search")), Ok(tool_use_response("c2", "search")), Ok(tool_use_response("c3", "search")), Err(overflow_error()), Ok(success_response("Summary of conversation so far")), Ok(success_response("Final answer after compaction")), ]));
3995
3996 let runner = AgentRunner::builder(provider)
3997 .name("test-compact")
3998 .system_prompt("sys")
3999 .tool(Arc::new(MockTool::new("search", "result")))
4000 .max_turns(10)
4001 .on_event(Arc::new(move |e| {
4002 events_clone.lock().unwrap().push(e);
4003 }))
4004 .build()
4005 .unwrap();
4006
4007 let output = runner.execute("do something").await.unwrap();
4008 assert_eq!(output.result, "Final answer after compaction");
4009
4010 let events = events.lock().unwrap();
4011 let summarized = events
4012 .iter()
4013 .any(|e| matches!(e, AgentEvent::ContextSummarized { .. }));
4014 assert!(summarized, "expected ContextSummarized event");
4015 }
4016
4017 #[tokio::test]
4018 async fn auto_compaction_not_attempted_twice() {
4019 let events: Arc<std::sync::Mutex<Vec<AgentEvent>>> =
4022 Arc::new(std::sync::Mutex::new(vec![]));
4023 let events_clone = events.clone();
4024
4025 let provider = Arc::new(FallibleMockProvider::new(vec![
4026 Ok(tool_use_response("c1", "search")),
4027 Ok(tool_use_response("c2", "search")),
4028 Ok(tool_use_response("c3", "search")),
4029 Err(overflow_error()),
4030 Ok(success_response("Summary")),
4031 Err(overflow_error()), ]));
4033
4034 let runner = AgentRunner::builder(provider)
4035 .name("test-compact")
4036 .system_prompt("sys")
4037 .tool(Arc::new(MockTool::new("search", "result")))
4038 .max_turns(10)
4039 .on_event(Arc::new(move |e| {
4040 events_clone.lock().unwrap().push(e);
4041 }))
4042 .build()
4043 .unwrap();
4044
4045 let err = runner.execute("do something").await.unwrap_err();
4046 let inner = match &err {
4047 Error::WithPartialUsage { source, .. } => source.as_ref(),
4048 other => other,
4049 };
4050 assert!(
4051 matches!(inner, Error::Api { status: 400, .. }),
4052 "expected overflow error, got: {err:?}"
4053 );
4054
4055 let events = events.lock().unwrap();
4056 let count = events
4057 .iter()
4058 .filter(|e| matches!(e, AgentEvent::ContextSummarized { .. }))
4059 .count();
4060 assert_eq!(count, 1, "compaction attempted exactly once");
4061 }
4062
4063 #[tokio::test]
4064 async fn auto_compaction_skipped_when_too_few_messages() {
4065 let events: Arc<std::sync::Mutex<Vec<AgentEvent>>> =
4067 Arc::new(std::sync::Mutex::new(vec![]));
4068 let events_clone = events.clone();
4069
4070 let provider = Arc::new(FallibleMockProvider::new(vec![Err(overflow_error())]));
4071
4072 let runner = AgentRunner::builder(provider)
4073 .name("test-compact")
4074 .system_prompt("sys")
4075 .max_turns(10)
4076 .on_event(Arc::new(move |e| {
4077 events_clone.lock().unwrap().push(e);
4078 }))
4079 .build()
4080 .unwrap();
4081
4082 let err = runner.execute("short task").await.unwrap_err();
4083 let inner = match &err {
4084 Error::WithPartialUsage { source, .. } => source.as_ref(),
4085 other => other,
4086 };
4087 assert!(
4088 matches!(inner, Error::Api { status: 400, .. }),
4089 "expected overflow error, got: {err:?}"
4090 );
4091
4092 let events = events.lock().unwrap();
4093 let count = events
4094 .iter()
4095 .filter(|e| matches!(e, AgentEvent::ContextSummarized { .. }))
4096 .count();
4097 assert_eq!(count, 0, "no compaction with too few messages");
4098 }
4099
4100 #[test]
4103 fn doom_loop_tracker_detects_repeated_calls() {
4104 let mut tracker = DoomLoopTracker::new();
4105 let calls = vec![ToolCall {
4106 id: "call-1".into(),
4107 name: "search".into(),
4108 input: json!({"query": "rust"}),
4109 }];
4110 assert!(!tracker.record(&calls, 3, None).0);
4111 assert!(!tracker.record(&calls, 3, None).0);
4112 assert!(tracker.record(&calls, 3, None).0); }
4114
4115 #[test]
4116 fn doom_loop_tracker_resets_on_different_call() {
4117 let mut tracker = DoomLoopTracker::new();
4118 let calls_a = vec![ToolCall {
4119 id: "call-1".into(),
4120 name: "search".into(),
4121 input: json!({"query": "rust"}),
4122 }];
4123 let calls_b = vec![ToolCall {
4124 id: "call-2".into(),
4125 name: "search".into(),
4126 input: json!({"query": "python"}),
4127 }];
4128 assert!(!tracker.record(&calls_a, 3, None).0);
4129 assert!(!tracker.record(&calls_a, 3, None).0);
4130 assert!(!tracker.record(&calls_b, 3, None).0);
4132 assert!(!tracker.record(&calls_b, 3, None).0);
4133 assert!(tracker.record(&calls_b, 3, None).0); }
4135
4136 #[test]
4137 fn doom_loop_tracker_ignores_call_id_differences() {
4138 let mut tracker = DoomLoopTracker::new();
4140 let calls_1 = vec![ToolCall {
4141 id: "call-1".into(),
4142 name: "read".into(),
4143 input: json!({"file": "foo.txt"}),
4144 }];
4145 let calls_2 = vec![ToolCall {
4146 id: "call-2".into(),
4147 name: "read".into(),
4148 input: json!({"file": "foo.txt"}),
4149 }];
4150 assert!(!tracker.record(&calls_1, 2, None).0);
4151 assert!(tracker.record(&calls_2, 2, None).0); }
4153
4154 #[test]
4155 fn doom_loop_tracker_multi_tool_turn() {
4156 let mut tracker = DoomLoopTracker::new();
4157 let calls = vec![
4158 ToolCall {
4159 id: "a".into(),
4160 name: "search".into(),
4161 input: json!({"q": "x"}),
4162 },
4163 ToolCall {
4164 id: "b".into(),
4165 name: "read".into(),
4166 input: json!({"file": "y"}),
4167 },
4168 ];
4169 assert!(!tracker.record(&calls, 2, None).0);
4170 assert!(tracker.record(&calls, 2, None).0);
4171 }
4172
4173 #[test]
4174 fn fuzzy_doom_loop_same_tools_different_inputs() {
4175 let mut tracker = DoomLoopTracker::new();
4176 let calls_a = vec![ToolCall {
4177 id: "c1".into(),
4178 name: "search".into(),
4179 input: json!({"query": "rust"}),
4180 }];
4181 let calls_b = vec![ToolCall {
4182 id: "c2".into(),
4183 name: "search".into(),
4184 input: json!({"query": "python"}),
4185 }];
4186 let calls_c = vec![ToolCall {
4187 id: "c3".into(),
4188 name: "search".into(),
4189 input: json!({"query": "go"}),
4190 }];
4191 let (exact, fuzzy) = tracker.record(&calls_a, 5, Some(3));
4193 assert!(!exact && !fuzzy, "first call: no detection");
4194 let (exact, fuzzy) = tracker.record(&calls_b, 5, Some(3));
4195 assert!(!exact && !fuzzy, "second call: no detection yet");
4196 let (exact, fuzzy) = tracker.record(&calls_c, 5, Some(3));
4197 assert!(!exact && fuzzy, "third call: fuzzy triggered");
4198 }
4199
4200 #[test]
4201 fn fuzzy_doom_loop_different_tools_no_trigger() {
4202 let mut tracker = DoomLoopTracker::new();
4203 let calls_a = vec![ToolCall {
4204 id: "c1".into(),
4205 name: "search".into(),
4206 input: json!({"query": "rust"}),
4207 }];
4208 let calls_b = vec![ToolCall {
4209 id: "c2".into(),
4210 name: "read".into(),
4211 input: json!({"file": "foo.txt"}),
4212 }];
4213 let calls_c = vec![ToolCall {
4214 id: "c3".into(),
4215 name: "write".into(),
4216 input: json!({"file": "bar.txt"}),
4217 }];
4218 let (_, fuzzy) = tracker.record(&calls_a, 5, Some(3));
4220 assert!(!fuzzy);
4221 let (_, fuzzy) = tracker.record(&calls_b, 5, Some(3));
4222 assert!(!fuzzy);
4223 let (_, fuzzy) = tracker.record(&calls_c, 5, Some(3));
4224 assert!(!fuzzy);
4225 }
4226
4227 #[test]
4228 fn fuzzy_doom_loop_disabled_by_default() {
4229 let mut tracker = DoomLoopTracker::new();
4230 let calls_a = vec![ToolCall {
4231 id: "c1".into(),
4232 name: "search".into(),
4233 input: json!({"query": "rust"}),
4234 }];
4235 let calls_b = vec![ToolCall {
4236 id: "c2".into(),
4237 name: "search".into(),
4238 input: json!({"query": "python"}),
4239 }];
4240 let (_, fuzzy) = tracker.record(&calls_a, 5, None);
4242 assert!(!fuzzy);
4243 let (_, fuzzy) = tracker.record(&calls_b, 5, None);
4244 assert!(!fuzzy);
4245 }
4246
4247 #[test]
4248 fn exact_match_does_not_double_trigger_fuzzy() {
4249 let mut tracker = DoomLoopTracker::new();
4250 let calls = vec![ToolCall {
4251 id: "c1".into(),
4252 name: "search".into(),
4253 input: json!({"query": "rust"}),
4254 }];
4255 let (exact, fuzzy) = tracker.record(&calls, 3, Some(3));
4257 assert!(!exact && !fuzzy);
4258 let (exact, fuzzy) = tracker.record(&calls, 3, Some(3));
4259 assert!(!exact && !fuzzy);
4260 let (exact, fuzzy) = tracker.record(&calls, 3, Some(3));
4262 assert!(exact, "exact should trigger");
4263 assert!(!fuzzy, "fuzzy should not trigger when exact fires");
4264 }
4265
4266 #[test]
4267 fn exact_match_resets_fuzzy_count() {
4268 let mut tracker = DoomLoopTracker::new();
4269 let calls_a = vec![ToolCall {
4271 id: "c1".into(),
4272 name: "search".into(),
4273 input: json!({"query": "a"}),
4274 }];
4275 let calls_b = vec![ToolCall {
4276 id: "c2".into(),
4277 name: "search".into(),
4278 input: json!({"query": "b"}),
4279 }];
4280 let calls_c = vec![ToolCall {
4281 id: "c3".into(),
4282 name: "read".into(),
4283 input: json!({"file": "x"}),
4284 }];
4285 tracker.record(&calls_a, 5, Some(3));
4286 tracker.record(&calls_b, 5, Some(3));
4287 tracker.record(&calls_c, 5, Some(3));
4289 assert_eq!(
4290 tracker.fuzzy_count(),
4291 1,
4292 "fuzzy count reset on different tools"
4293 );
4294 }
4295
4296 #[test]
4297 fn builder_rejects_zero_max_fuzzy_identical_tool_calls() {
4298 let provider = Arc::new(MockProvider::new(vec![]));
4299 let result = AgentRunner::builder(provider)
4300 .name("test")
4301 .system_prompt("sys")
4302 .max_fuzzy_identical_tool_calls(0)
4303 .build();
4304 match result {
4305 Err(e) => {
4306 let msg = e.to_string();
4307 assert!(
4308 msg.contains("max_fuzzy_identical_tool_calls must be at least 1"),
4309 "error: {msg}"
4310 );
4311 }
4312 Ok(_) => panic!("expected error for max_fuzzy_identical_tool_calls(0)"),
4313 }
4314 }
4315
4316 #[tokio::test]
4317 async fn doom_loop_detected_after_threshold() {
4318 let tool_response = |id: &str| CompletionResponse {
4321 content: vec![ContentBlock::ToolUse {
4322 id: id.into(),
4323 name: "my_tool".into(),
4324 input: json!({"key": "same_value"}),
4325 }],
4326 stop_reason: StopReason::ToolUse,
4327 usage: TokenUsage::default(),
4328 model: None,
4329 };
4330
4331 let provider = Arc::new(MockProvider::new(vec![
4332 tool_response("c1"),
4333 tool_response("c2"),
4334 tool_response("c3"), CompletionResponse {
4337 content: vec![ContentBlock::Text {
4338 text: "I'll try something different.".into(),
4339 }],
4340 stop_reason: StopReason::EndTurn,
4341 usage: TokenUsage::default(),
4342 model: None,
4343 },
4344 ]));
4345
4346 let tool = MockTool::new("my_tool", "tool result");
4347 let runner = AgentRunner::builder(provider)
4348 .name("test")
4349 .system_prompt("sys")
4350 .tool(Arc::new(tool))
4351 .max_turns(10)
4352 .max_identical_tool_calls(3)
4353 .build()
4354 .unwrap();
4355
4356 let output = runner.execute("do something").await.unwrap();
4357 assert_eq!(output.result, "I'll try something different.");
4358 assert_eq!(output.tool_calls_made, 3);
4360 }
4361
4362 #[tokio::test]
4363 async fn doom_loop_resets_on_different_call() {
4364 let provider = Arc::new(MockProvider::new(vec![
4367 CompletionResponse {
4368 content: vec![ContentBlock::ToolUse {
4369 id: "c1".into(),
4370 name: "my_tool".into(),
4371 input: json!({"key": "value_a"}),
4372 }],
4373 stop_reason: StopReason::ToolUse,
4374 usage: TokenUsage::default(),
4375 model: None,
4376 },
4377 CompletionResponse {
4378 content: vec![ContentBlock::ToolUse {
4379 id: "c2".into(),
4380 name: "my_tool".into(),
4381 input: json!({"key": "value_a"}),
4382 }],
4383 stop_reason: StopReason::ToolUse,
4384 usage: TokenUsage::default(),
4385 model: None,
4386 },
4387 CompletionResponse {
4389 content: vec![ContentBlock::ToolUse {
4390 id: "c3".into(),
4391 name: "my_tool".into(),
4392 input: json!({"key": "value_b"}),
4393 }],
4394 stop_reason: StopReason::ToolUse,
4395 usage: TokenUsage::default(),
4396 model: None,
4397 },
4398 CompletionResponse {
4399 content: vec![ContentBlock::ToolUse {
4400 id: "c4".into(),
4401 name: "my_tool".into(),
4402 input: json!({"key": "value_b"}),
4403 }],
4404 stop_reason: StopReason::ToolUse,
4405 usage: TokenUsage::default(),
4406 model: None,
4407 },
4408 CompletionResponse {
4409 content: vec![ContentBlock::Text {
4410 text: "done".into(),
4411 }],
4412 stop_reason: StopReason::EndTurn,
4413 usage: TokenUsage::default(),
4414 model: None,
4415 },
4416 ]));
4417
4418 let tool = MockTool::new("my_tool", "result");
4419 let runner = AgentRunner::builder(provider)
4420 .name("test")
4421 .system_prompt("sys")
4422 .tool(Arc::new(tool))
4423 .max_turns(10)
4424 .max_identical_tool_calls(3)
4425 .build()
4426 .unwrap();
4427
4428 let output = runner.execute("task").await.unwrap();
4429 assert_eq!(output.result, "done");
4430 assert_eq!(output.tool_calls_made, 4);
4432 }
4433
4434 #[tokio::test]
4435 async fn doom_loop_disabled_by_default() {
4436 let tool_response = |id: &str| CompletionResponse {
4439 content: vec![ContentBlock::ToolUse {
4440 id: id.into(),
4441 name: "my_tool".into(),
4442 input: json!({"key": "same"}),
4443 }],
4444 stop_reason: StopReason::ToolUse,
4445 usage: TokenUsage::default(),
4446 model: None,
4447 };
4448
4449 let provider = Arc::new(MockProvider::new(vec![
4450 tool_response("c1"),
4451 tool_response("c2"),
4452 tool_response("c3"),
4453 tool_response("c4"),
4454 tool_response("c5"),
4455 CompletionResponse {
4456 content: vec![ContentBlock::Text {
4457 text: "done".into(),
4458 }],
4459 stop_reason: StopReason::EndTurn,
4460 usage: TokenUsage::default(),
4461 model: None,
4462 },
4463 ]));
4464
4465 let tool = MockTool::new("my_tool", "result");
4466 let runner = AgentRunner::builder(provider)
4467 .name("test")
4468 .system_prompt("sys")
4469 .tool(Arc::new(tool))
4470 .max_turns(10)
4471 .build()
4473 .unwrap();
4474
4475 let output = runner.execute("task").await.unwrap();
4476 assert_eq!(output.result, "done");
4477 assert_eq!(output.tool_calls_made, 5);
4479 }
4480
4481 #[test]
4482 fn builder_rejects_zero_max_identical_tool_calls() {
4483 let provider = Arc::new(MockProvider::new(vec![]));
4484 let result = AgentRunner::builder(provider)
4485 .name("test")
4486 .system_prompt("sys")
4487 .max_identical_tool_calls(0)
4488 .build();
4489 match result {
4490 Err(e) => {
4491 let msg = e.to_string();
4492 assert!(
4493 msg.contains("max_identical_tool_calls must be at least 1"),
4494 "error: {msg}"
4495 );
4496 }
4497 Ok(_) => panic!("expected error for max_identical_tool_calls(0)"),
4498 }
4499 }
4500
4501 #[test]
4502 fn builder_rejects_zero_max_total_tokens() {
4503 let provider = Arc::new(MockProvider::new(vec![]));
4504 let result = AgentRunner::builder(provider)
4505 .name("test")
4506 .system_prompt("sys")
4507 .max_total_tokens(0)
4508 .build();
4509 match result {
4510 Err(e) => {
4511 let msg = e.to_string();
4512 assert!(
4513 msg.contains("max_total_tokens must be at least 1"),
4514 "error: {msg}"
4515 );
4516 }
4517 Ok(_) => panic!("expected error for max_total_tokens(0)"),
4518 }
4519 }
4520
4521 #[test]
4522 fn builder_rejects_zero_response_cache_size() {
4523 let provider = Arc::new(MockProvider::new(vec![]));
4524 let result = AgentRunner::builder(provider)
4525 .name("test")
4526 .system_prompt("test")
4527 .response_cache_size(0)
4528 .build();
4529 match result {
4530 Err(e) => {
4531 let msg = e.to_string();
4532 assert!(
4533 msg.contains("response_cache_size must be at least 1"),
4534 "error: {msg}"
4535 );
4536 }
4537 Ok(_) => panic!("expected error for response_cache_size(0)"),
4538 }
4539 }
4540
4541 #[tokio::test]
4544 async fn permission_allow_bypasses_approval() {
4545 let provider = Arc::new(MockProvider::new(vec![
4547 CompletionResponse {
4548 content: vec![ContentBlock::ToolUse {
4549 id: "c1".into(),
4550 name: "read_file".into(),
4551 input: json!({"path": "src/main.rs"}),
4552 }],
4553 stop_reason: StopReason::ToolUse,
4554 usage: TokenUsage::default(),
4555 model: None,
4556 },
4557 CompletionResponse {
4558 content: vec![ContentBlock::Text {
4559 text: "done".into(),
4560 }],
4561 stop_reason: StopReason::EndTurn,
4562 usage: TokenUsage::default(),
4563 model: None,
4564 },
4565 ]));
4566
4567 let rules = permission::PermissionRuleset::new(vec![permission::PermissionRule {
4568 tool: "read_file".into(),
4569 pattern: "*".into(),
4570 action: permission::PermissionAction::Allow,
4571 }]);
4572
4573 let approval_called = Arc::new(std::sync::atomic::AtomicBool::new(false));
4574 let approval_called_clone = approval_called.clone();
4575
4576 let runner = AgentRunner::builder(provider)
4577 .name("perm-test")
4578 .system_prompt("sys")
4579 .tool(Arc::new(MockTool::new("read_file", "file contents")))
4580 .on_approval(Arc::new(move |_: &[ToolCall]| {
4581 approval_called_clone.store(true, std::sync::atomic::Ordering::SeqCst);
4582 crate::llm::ApprovalDecision::Deny }))
4584 .permission_rules(rules)
4585 .build()
4586 .unwrap();
4587
4588 let output = runner.execute("read something").await.unwrap();
4589 assert_eq!(output.result, "done");
4590 assert!(!approval_called.load(std::sync::atomic::Ordering::SeqCst));
4591 }
4592
4593 #[tokio::test]
4594 async fn permission_deny_returns_error_result() {
4595 let provider = Arc::new(MockProvider::new(vec![
4596 CompletionResponse {
4597 content: vec![ContentBlock::ToolUse {
4598 id: "c1".into(),
4599 name: "bash".into(),
4600 input: json!({"command": "rm -rf /"}),
4601 }],
4602 stop_reason: StopReason::ToolUse,
4603 usage: TokenUsage::default(),
4604 model: None,
4605 },
4606 CompletionResponse {
4607 content: vec![ContentBlock::Text {
4608 text: "ok i won't do that".into(),
4609 }],
4610 stop_reason: StopReason::EndTurn,
4611 usage: TokenUsage::default(),
4612 model: None,
4613 },
4614 ]));
4615
4616 let rules = permission::PermissionRuleset::new(vec![permission::PermissionRule {
4617 tool: "bash".into(),
4618 pattern: "rm *".into(),
4619 action: permission::PermissionAction::Deny,
4620 }]);
4621
4622 let runner = AgentRunner::builder(provider)
4623 .name("perm-test")
4624 .system_prompt("sys")
4625 .tool(Arc::new(MockTool::new("bash", "executed")))
4626 .permission_rules(rules)
4627 .build()
4628 .unwrap();
4629
4630 let output = runner.execute("delete everything").await.unwrap();
4631 assert_eq!(output.result, "ok i won't do that");
4632 assert_eq!(output.tool_calls_made, 1);
4634 }
4635
4636 #[tokio::test]
4637 async fn permission_ask_falls_through_to_approval() {
4638 let provider = Arc::new(MockProvider::new(vec![
4639 CompletionResponse {
4640 content: vec![ContentBlock::ToolUse {
4641 id: "c1".into(),
4642 name: "bash".into(),
4643 input: json!({"command": "cargo test"}),
4644 }],
4645 stop_reason: StopReason::ToolUse,
4646 usage: TokenUsage::default(),
4647 model: None,
4648 },
4649 CompletionResponse {
4650 content: vec![ContentBlock::Text {
4651 text: "tests passed".into(),
4652 }],
4653 stop_reason: StopReason::EndTurn,
4654 usage: TokenUsage::default(),
4655 model: None,
4656 },
4657 ]));
4658
4659 let rules = permission::PermissionRuleset::new(vec![
4660 permission::PermissionRule {
4661 tool: "bash".into(),
4662 pattern: "rm *".into(),
4663 action: permission::PermissionAction::Deny,
4664 },
4665 permission::PermissionRule {
4666 tool: "bash".into(),
4667 pattern: "*".into(),
4668 action: permission::PermissionAction::Ask,
4669 },
4670 ]);
4671
4672 let approval_called = Arc::new(std::sync::atomic::AtomicBool::new(false));
4673 let approval_called_clone = approval_called.clone();
4674
4675 let runner = AgentRunner::builder(provider)
4676 .name("perm-test")
4677 .system_prompt("sys")
4678 .tool(Arc::new(MockTool::new("bash", "ok")))
4679 .on_approval(Arc::new(move |_: &[ToolCall]| {
4680 approval_called_clone.store(true, std::sync::atomic::Ordering::SeqCst);
4681 crate::llm::ApprovalDecision::Allow }))
4683 .permission_rules(rules)
4684 .build()
4685 .unwrap();
4686
4687 let output = runner.execute("run tests").await.unwrap();
4688 assert_eq!(output.result, "tests passed");
4689 assert!(approval_called.load(std::sync::atomic::Ordering::SeqCst));
4690 }
4691
4692 #[tokio::test]
4693 async fn permission_mixed_allow_and_deny() {
4694 let provider = Arc::new(MockProvider::new(vec![
4696 CompletionResponse {
4697 content: vec![
4698 ContentBlock::ToolUse {
4699 id: "c1".into(),
4700 name: "read_file".into(),
4701 input: json!({"path": "src/main.rs"}),
4702 },
4703 ContentBlock::ToolUse {
4704 id: "c2".into(),
4705 name: "read_file".into(),
4706 input: json!({"path": ".env"}),
4707 },
4708 ],
4709 stop_reason: StopReason::ToolUse,
4710 usage: TokenUsage::default(),
4711 model: None,
4712 },
4713 CompletionResponse {
4714 content: vec![ContentBlock::Text {
4715 text: "got it".into(),
4716 }],
4717 stop_reason: StopReason::EndTurn,
4718 usage: TokenUsage::default(),
4719 model: None,
4720 },
4721 ]));
4722
4723 let rules = permission::PermissionRuleset::new(vec![
4724 permission::PermissionRule {
4725 tool: "*".into(),
4726 pattern: "*.env*".into(),
4727 action: permission::PermissionAction::Deny,
4728 },
4729 permission::PermissionRule {
4730 tool: "read_file".into(),
4731 pattern: "*".into(),
4732 action: permission::PermissionAction::Allow,
4733 },
4734 ]);
4735
4736 let runner = AgentRunner::builder(provider)
4737 .name("perm-test")
4738 .system_prompt("sys")
4739 .tool(Arc::new(MockTool::new("read_file", "contents")))
4740 .permission_rules(rules)
4741 .build()
4742 .unwrap();
4743
4744 let output = runner.execute("read files").await.unwrap();
4745 assert_eq!(output.result, "got it");
4746 assert_eq!(output.tool_calls_made, 2);
4748 }
4749
4750 #[tokio::test]
4751 async fn permission_no_rules_uses_legacy_approval() {
4752 let provider = Arc::new(MockProvider::new(vec![
4754 CompletionResponse {
4755 content: vec![ContentBlock::ToolUse {
4756 id: "c1".into(),
4757 name: "bash".into(),
4758 input: json!({"command": "ls"}),
4759 }],
4760 stop_reason: StopReason::ToolUse,
4761 usage: TokenUsage::default(),
4762 model: None,
4763 },
4764 CompletionResponse {
4765 content: vec![ContentBlock::Text {
4766 text: "denied".into(),
4767 }],
4768 stop_reason: StopReason::EndTurn,
4769 usage: TokenUsage::default(),
4770 model: None,
4771 },
4772 ]));
4773
4774 let runner = AgentRunner::builder(provider)
4775 .name("perm-test")
4776 .system_prompt("sys")
4777 .tool(Arc::new(MockTool::new("bash", "ok")))
4778 .on_approval(Arc::new(|_: &[ToolCall]| {
4779 crate::llm::ApprovalDecision::Deny
4780 }))
4781 .build()
4782 .unwrap();
4783
4784 let output = runner.execute("do something").await.unwrap();
4785 assert_eq!(output.result, "denied");
4786 }
4787
4788 #[tokio::test]
4789 async fn always_allow_injects_rule_into_live_ruleset() {
4790 let call_count = Arc::new(std::sync::atomic::AtomicUsize::new(0));
4793 let call_count_clone = call_count.clone();
4794
4795 let provider = Arc::new(MockProvider::new(vec![
4796 CompletionResponse {
4798 content: vec![ContentBlock::ToolUse {
4799 id: "c1".into(),
4800 name: "bash".into(),
4801 input: json!({"command": "ls"}),
4802 }],
4803 stop_reason: StopReason::ToolUse,
4804 usage: TokenUsage::default(),
4805 model: None,
4806 },
4807 CompletionResponse {
4809 content: vec![ContentBlock::ToolUse {
4810 id: "c2".into(),
4811 name: "bash".into(),
4812 input: json!({"command": "ls"}),
4813 }],
4814 stop_reason: StopReason::ToolUse,
4815 usage: TokenUsage::default(),
4816 model: None,
4817 },
4818 CompletionResponse {
4820 content: vec![ContentBlock::Text {
4821 text: "done".into(),
4822 }],
4823 stop_reason: StopReason::EndTurn,
4824 usage: TokenUsage::default(),
4825 model: None,
4826 },
4827 ]));
4828
4829 let runner = AgentRunner::builder(provider)
4830 .name("perm-test")
4831 .system_prompt("sys")
4832 .tool(Arc::new(MockTool::new("bash", "ok")))
4833 .on_approval(Arc::new(move |_: &[ToolCall]| {
4834 call_count_clone.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
4835 crate::llm::ApprovalDecision::AlwaysAllow
4836 }))
4837 .build()
4838 .unwrap();
4839
4840 let output = runner.execute("do something").await.unwrap();
4841 assert_eq!(output.result, "done");
4842 assert_eq!(call_count.load(std::sync::atomic::Ordering::SeqCst), 1);
4846 }
4847
4848 #[tokio::test]
4849 async fn always_deny_injects_rule_into_live_ruleset() {
4850 let call_count = Arc::new(std::sync::atomic::AtomicUsize::new(0));
4853 let call_count_clone = call_count.clone();
4854
4855 let provider = Arc::new(MockProvider::new(vec![
4856 CompletionResponse {
4858 content: vec![ContentBlock::ToolUse {
4859 id: "c1".into(),
4860 name: "bash".into(),
4861 input: json!({"command": "rm -rf /"}),
4862 }],
4863 stop_reason: StopReason::ToolUse,
4864 usage: TokenUsage::default(),
4865 model: None,
4866 },
4867 CompletionResponse {
4869 content: vec![ContentBlock::ToolUse {
4870 id: "c2".into(),
4871 name: "bash".into(),
4872 input: json!({"command": "rm -rf /"}),
4873 }],
4874 stop_reason: StopReason::ToolUse,
4875 usage: TokenUsage::default(),
4876 model: None,
4877 },
4878 CompletionResponse {
4880 content: vec![ContentBlock::Text {
4881 text: "gave up".into(),
4882 }],
4883 stop_reason: StopReason::EndTurn,
4884 usage: TokenUsage::default(),
4885 model: None,
4886 },
4887 ]));
4888
4889 let runner = AgentRunner::builder(provider)
4890 .name("perm-test")
4891 .system_prompt("sys")
4892 .tool(Arc::new(MockTool::new("bash", "ok")))
4893 .on_approval(Arc::new(move |_: &[ToolCall]| {
4894 call_count_clone.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
4895 crate::llm::ApprovalDecision::AlwaysDeny
4896 }))
4897 .build()
4898 .unwrap();
4899
4900 let output = runner.execute("do something").await.unwrap();
4901 assert_eq!(output.result, "gave up");
4902 assert_eq!(call_count.load(std::sync::atomic::Ordering::SeqCst), 1);
4905 }
4906
4907 #[tokio::test]
4908 async fn config_deny_overrides_learned_allow() {
4909 let provider = Arc::new(MockProvider::new(vec![
4914 CompletionResponse {
4916 content: vec![ContentBlock::ToolUse {
4917 id: "c1".into(),
4918 name: "bash".into(),
4919 input: json!({"command": "ls"}),
4920 }],
4921 stop_reason: StopReason::ToolUse,
4922 usage: TokenUsage::default(),
4923 model: None,
4924 },
4925 CompletionResponse {
4927 content: vec![ContentBlock::ToolUse {
4928 id: "c2".into(),
4929 name: "bash".into(),
4930 input: json!({"command": "rm -rf /"}),
4931 }],
4932 stop_reason: StopReason::ToolUse,
4933 usage: TokenUsage::default(),
4934 model: None,
4935 },
4936 CompletionResponse {
4938 content: vec![ContentBlock::Text {
4939 text: "blocked".into(),
4940 }],
4941 stop_reason: StopReason::EndTurn,
4942 usage: TokenUsage::default(),
4943 model: None,
4944 },
4945 ]));
4946
4947 let rules = permission::PermissionRuleset::new(vec![permission::PermissionRule {
4948 tool: "bash".into(),
4949 pattern: "rm *".into(),
4950 action: permission::PermissionAction::Deny,
4951 }]);
4952
4953 let runner = AgentRunner::builder(provider)
4954 .name("perm-test")
4955 .system_prompt("sys")
4956 .tool(Arc::new(MockTool::new("bash", "ok")))
4957 .on_approval(Arc::new(|_: &[ToolCall]| {
4958 crate::llm::ApprovalDecision::AlwaysAllow
4959 }))
4960 .permission_rules(rules)
4961 .build()
4962 .unwrap();
4963
4964 let output = runner.execute("do something").await.unwrap();
4965 assert_eq!(output.result, "blocked");
4966 }
4969
4970 #[tokio::test]
4971 async fn workspace_injects_system_prompt_hint() {
4972 let provider = MockProvider::new(vec![CompletionResponse {
4973 content: vec![ContentBlock::Text {
4974 text: "done".into(),
4975 }],
4976 stop_reason: StopReason::EndTurn,
4977 usage: TokenUsage::default(),
4978 model: None,
4979 }]);
4980 let runner = AgentRunner::builder(Arc::new(provider))
4981 .name("test")
4982 .system_prompt("base prompt")
4983 .workspace("/test/workspace")
4984 .build()
4985 .unwrap();
4986
4987 assert!(runner.system_prompt.contains("/test/workspace"));
4989 assert!(runner.system_prompt.contains("base prompt"));
4990 assert!(runner.system_prompt.contains("workspace directory"));
4991 }
4992
4993 #[tokio::test]
4994 async fn no_workspace_no_prompt_hint() {
4995 let provider = MockProvider::new(vec![CompletionResponse {
4996 content: vec![ContentBlock::Text {
4997 text: "done".into(),
4998 }],
4999 stop_reason: StopReason::EndTurn,
5000 usage: TokenUsage::default(),
5001 model: None,
5002 }]);
5003 let runner = AgentRunner::builder(Arc::new(provider))
5004 .name("test")
5005 .system_prompt("base prompt")
5006 .tool(Arc::new(MockTool::new("bash", "ok")))
5007 .build()
5008 .unwrap();
5009
5010 assert!(runner.system_prompt.starts_with("base prompt"));
5011 assert!(runner.system_prompt.contains("Resourcefulness"));
5012 assert!(!runner.system_prompt.contains("workspace"));
5013 }
5014
5015 #[test]
5016 fn resourcefulness_guidelines_included_with_power_tools() {
5017 let provider = Arc::new(MockProvider::new(vec![]));
5018 let runner = AgentRunner::builder(provider)
5019 .name("test")
5020 .system_prompt("prompt")
5021 .tool(Arc::new(MockTool::new("bash", "ok")))
5022 .build()
5023 .unwrap();
5024 assert!(
5025 runner.system_prompt.contains("Resourcefulness"),
5026 "should include guidelines when bash tool is present"
5027 );
5028 }
5029
5030 #[test]
5031 fn resourcefulness_guidelines_excluded_without_power_tools() {
5032 let provider = Arc::new(MockProvider::new(vec![]));
5033 let runner = AgentRunner::builder(provider)
5034 .name("test")
5035 .system_prompt("prompt")
5036 .tool(Arc::new(MockTool::new("memory_recall", "ok")))
5037 .build()
5038 .unwrap();
5039 assert!(
5040 !runner.system_prompt.contains("Resourcefulness"),
5041 "should not include guidelines when only memory tools are present"
5042 );
5043 }
5044
5045 #[test]
5046 fn system_prompt_contains_current_date() {
5047 let provider = Arc::new(MockProvider::new(vec![]));
5048 let runner = AgentRunner::builder(provider)
5049 .name("test")
5050 .system_prompt("prompt")
5051 .build()
5052 .unwrap();
5053 assert!(
5054 runner.system_prompt.contains("Current date and time:"),
5055 "system prompt should contain current date/time"
5056 );
5057 let year = chrono::Utc::now().format("%Y").to_string();
5059 assert!(
5060 runner.system_prompt.contains(&year),
5061 "system prompt should contain current year"
5062 );
5063 }
5064
5065 #[tokio::test]
5066 async fn budget_exceeded_returns_error() {
5067 let provider = Arc::new(MockProvider::new(vec![
5070 CompletionResponse {
5071 content: vec![ContentBlock::ToolUse {
5072 id: "call-1".into(),
5073 name: "echo".into(),
5074 input: json!({}),
5075 }],
5076 stop_reason: StopReason::ToolUse,
5077 usage: TokenUsage {
5078 input_tokens: 30000,
5079 output_tokens: 30000,
5080 ..Default::default()
5081 },
5082 model: None,
5083 },
5084 CompletionResponse {
5085 content: vec![ContentBlock::Text {
5086 text: "done".into(),
5087 }],
5088 stop_reason: StopReason::EndTurn,
5089 usage: TokenUsage {
5090 input_tokens: 30000,
5091 output_tokens: 30000,
5092 ..Default::default()
5093 },
5094 model: None,
5095 },
5096 ]));
5097 let tool = MockTool::new("echo", "ok");
5098 let runner = AgentRunner::builder(provider)
5099 .name("budget-test")
5100 .system_prompt("test")
5101 .tool(Arc::new(tool))
5102 .max_total_tokens(100000) .build()
5104 .unwrap();
5105
5106 let result = runner.execute("test task").await;
5107 match result {
5108 Err(Error::WithPartialUsage { source, usage }) => {
5109 assert!(
5110 matches!(
5111 *source,
5112 Error::BudgetExceeded {
5113 used: 120000,
5114 limit: 100000
5115 }
5116 ),
5117 "expected BudgetExceeded, got: {source}"
5118 );
5119 assert_eq!(usage.total(), 120000);
5120 }
5121 Err(e) => panic!("expected BudgetExceeded, got: {e}"),
5122 Ok(output) => panic!("expected error, got success: {}", output.result),
5123 }
5124 }
5125
5126 #[tokio::test]
5127 async fn budget_not_exceeded_when_under_limit() {
5128 let provider = Arc::new(MockProvider::new(vec![CompletionResponse {
5130 content: vec![ContentBlock::Text {
5131 text: "done".into(),
5132 }],
5133 stop_reason: StopReason::EndTurn,
5134 usage: TokenUsage {
5135 input_tokens: 50,
5136 output_tokens: 50,
5137 ..Default::default()
5138 },
5139 model: None,
5140 }]));
5141 let runner = AgentRunner::builder(provider)
5142 .name("budget-ok-test")
5143 .system_prompt("test")
5144 .max_total_tokens(1000)
5145 .build()
5146 .unwrap();
5147
5148 let output = runner.execute("test task").await.unwrap();
5149 assert_eq!(output.tokens_used.total(), 100);
5150 }
5151
5152 #[tokio::test]
5153 async fn budget_event_emitted_on_exceeded() {
5154 let events: Arc<Mutex<Vec<AgentEvent>>> = Arc::new(Mutex::new(Vec::new()));
5155 let events_clone = events.clone();
5156 let provider = Arc::new(MockProvider::new(vec![CompletionResponse {
5157 content: vec![ContentBlock::Text {
5158 text: "done".into(),
5159 }],
5160 stop_reason: StopReason::EndTurn,
5161 usage: TokenUsage {
5162 input_tokens: 100,
5163 output_tokens: 100,
5164 ..Default::default()
5165 },
5166 model: None,
5167 }]));
5168 let runner = AgentRunner::builder(provider)
5169 .name("budget-event-test")
5170 .system_prompt("test")
5171 .max_total_tokens(50) .on_event(Arc::new(move |event| {
5173 events_clone.lock().unwrap().push(event);
5174 }))
5175 .build()
5176 .unwrap();
5177
5178 let _ = runner.execute("test task").await;
5179 let events = events.lock().unwrap();
5180 let budget_events: Vec<_> = events
5181 .iter()
5182 .filter(|e| matches!(e, AgentEvent::BudgetExceeded { .. }))
5183 .collect();
5184 assert_eq!(
5185 budget_events.len(),
5186 1,
5187 "expected exactly one BudgetExceeded event"
5188 );
5189 match &budget_events[0] {
5190 AgentEvent::BudgetExceeded { used, limit, .. } => {
5191 assert_eq!(*used, 200);
5192 assert_eq!(*limit, 50);
5193 }
5194 _ => unreachable!(),
5195 }
5196 }
5197
5198 #[tokio::test]
5199 async fn agent_runner_records_audit_trail() {
5200 let provider = Arc::new(MockProvider::new(vec![CompletionResponse {
5201 content: vec![ContentBlock::Text {
5202 text: "Done!".into(),
5203 }],
5204 usage: TokenUsage {
5205 input_tokens: 10,
5206 output_tokens: 5,
5207 ..Default::default()
5208 },
5209 stop_reason: StopReason::EndTurn,
5210 model: Some("test-model".into()),
5211 }]));
5212
5213 let trail = Arc::new(crate::agent::audit::InMemoryAuditTrail::new());
5214 let runner = AgentRunner::builder(provider)
5215 .name("audit-test")
5216 .system_prompt("You help.")
5217 .max_turns(5)
5218 .audit_trail(trail.clone())
5219 .build()
5220 .unwrap();
5221
5222 let output = runner.execute("hello").await.unwrap();
5223 assert_eq!(output.result, "Done!");
5224
5225 let entries = trail.entries_unscoped(usize::MAX).await.unwrap();
5226 let event_types: Vec<&str> = entries.iter().map(|e| e.event_type.as_str()).collect();
5227 assert!(
5228 event_types.contains(&"llm_response"),
5229 "expected llm_response, got: {event_types:?}"
5230 );
5231 assert!(
5232 event_types.contains(&"run_completed"),
5233 "expected run_completed, got: {event_types:?}"
5234 );
5235 }
5236
5237 #[tokio::test]
5238 async fn audit_trail_captures_tool_calls() {
5239 let tool = Arc::new(MockTool::new("greet", "Hello!"));
5240 let provider = Arc::new(MockProvider::new(vec![
5241 CompletionResponse {
5243 content: vec![ContentBlock::ToolUse {
5244 id: "call-1".into(),
5245 name: "greet".into(),
5246 input: json!({"name": "world"}),
5247 }],
5248 usage: TokenUsage {
5249 input_tokens: 10,
5250 output_tokens: 5,
5251 ..Default::default()
5252 },
5253 stop_reason: StopReason::ToolUse,
5254 model: None,
5255 },
5256 CompletionResponse {
5258 content: vec![ContentBlock::Text {
5259 text: "All done.".into(),
5260 }],
5261 usage: TokenUsage {
5262 input_tokens: 15,
5263 output_tokens: 3,
5264 ..Default::default()
5265 },
5266 stop_reason: StopReason::EndTurn,
5267 model: None,
5268 },
5269 ]));
5270
5271 let trail = Arc::new(crate::agent::audit::InMemoryAuditTrail::new());
5272 let runner = AgentRunner::builder(provider)
5273 .name("tool-audit-test")
5274 .system_prompt("You help.")
5275 .tool(tool)
5276 .max_turns(5)
5277 .audit_trail(trail.clone())
5278 .build()
5279 .unwrap();
5280
5281 runner.execute("greet the world").await.unwrap();
5282
5283 let entries = trail.entries_unscoped(usize::MAX).await.unwrap();
5284 let event_types: Vec<&str> = entries.iter().map(|e| e.event_type.as_str()).collect();
5285 assert!(
5286 event_types.contains(&"tool_call"),
5287 "expected tool_call, got: {event_types:?}"
5288 );
5289 assert!(
5290 event_types.contains(&"tool_result"),
5291 "expected tool_result, got: {event_types:?}"
5292 );
5293
5294 let tool_result = entries
5296 .iter()
5297 .find(|e| e.event_type == "tool_result")
5298 .unwrap();
5299 assert_eq!(tool_result.payload["output"], "Hello!");
5300
5301 let tool_call_entry = entries
5303 .iter()
5304 .find(|e| e.event_type == "tool_call")
5305 .unwrap();
5306 assert!(
5307 tool_call_entry.turn > 0,
5308 "tool_call turn should be > 0, got: {}",
5309 tool_call_entry.turn
5310 );
5311 assert_eq!(tool_call_entry.payload["input"]["name"], "world");
5313 }
5314
5315 #[tokio::test]
5316 async fn audit_trail_none_by_default() {
5317 let provider = Arc::new(MockProvider::new(vec![CompletionResponse {
5318 content: vec![ContentBlock::Text { text: "OK".into() }],
5319 usage: TokenUsage::default(),
5320 stop_reason: StopReason::EndTurn,
5321 model: None,
5322 }]));
5323
5324 let runner = AgentRunner::builder(provider)
5326 .name("no-audit")
5327 .system_prompt("You help.")
5328 .max_turns(5)
5329 .build()
5330 .unwrap();
5331
5332 let output = runner.execute("hello").await.unwrap();
5333 assert_eq!(output.result, "OK");
5334 }
5335
5336 #[test]
5337 fn audit_user_context_builder_sets_fields() {
5338 let provider = Arc::new(MockProvider::new(vec![]));
5339 let runner = AgentRunner::builder(provider)
5340 .name("test-agent")
5341 .system_prompt("prompt")
5342 .max_turns(5)
5343 .audit_user_context("alice", "acme")
5344 .build()
5345 .unwrap();
5346
5347 assert_eq!(runner.audit_user_id.as_deref(), Some("alice"));
5348 assert_eq!(runner.audit_tenant_id.as_deref(), Some("acme"));
5349 }
5350
5351 #[test]
5352 fn audit_user_context_defaults_to_none() {
5353 let provider = Arc::new(MockProvider::new(vec![]));
5354 let runner = AgentRunner::builder(provider)
5355 .name("test-agent")
5356 .system_prompt("prompt")
5357 .max_turns(5)
5358 .build()
5359 .unwrap();
5360
5361 assert!(runner.audit_user_id.is_none());
5362 assert!(runner.audit_tenant_id.is_none());
5363 }
5364
5365 #[tokio::test]
5366 async fn post_llm_warn_does_not_block_execution() {
5367 use std::sync::atomic::{AtomicBool, Ordering};
5369
5370 struct WarnAlways;
5371 impl Guardrail for WarnAlways {
5372 fn post_llm(
5373 &self,
5374 _response: &mut crate::llm::types::CompletionResponse,
5375 ) -> std::pin::Pin<
5376 Box<dyn std::future::Future<Output = Result<GuardAction, Error>> + Send + '_>,
5377 > {
5378 Box::pin(async { Ok(GuardAction::warn("suspicious but allowed")) })
5379 }
5380 }
5381
5382 let warned = Arc::new(AtomicBool::new(false));
5383 let warned_clone = warned.clone();
5384
5385 let provider = Arc::new(MockProvider::new(vec![CompletionResponse {
5386 content: vec![ContentBlock::Text {
5387 text: "answer".into(),
5388 }],
5389 stop_reason: StopReason::EndTurn,
5390 usage: TokenUsage::default(),
5391 model: None,
5392 }]));
5393
5394 let runner = AgentRunner::builder(provider)
5395 .name("test")
5396 .system_prompt("sys")
5397 .guardrail(Arc::new(WarnAlways))
5398 .on_event(Arc::new(move |event| {
5399 if matches!(event, AgentEvent::GuardrailWarned { .. }) {
5400 warned_clone.store(true, Ordering::Relaxed);
5401 }
5402 }))
5403 .build()
5404 .unwrap();
5405
5406 let output = runner.execute("hello").await.unwrap();
5407 assert_eq!(output.result, "answer");
5409 assert!(
5411 warned.load(Ordering::Relaxed),
5412 "GuardrailWarned event should have fired"
5413 );
5414 }
5415
5416 #[tokio::test]
5417 async fn pre_tool_warn_does_not_block_tool_execution() {
5418 use std::sync::atomic::{AtomicBool, Ordering};
5420
5421 struct WarnPreTool;
5422 impl Guardrail for WarnPreTool {
5423 fn pre_tool(
5424 &self,
5425 _call: &crate::llm::types::ToolCall,
5426 ) -> std::pin::Pin<
5427 Box<dyn std::future::Future<Output = Result<GuardAction, Error>> + Send + '_>,
5428 > {
5429 Box::pin(async { Ok(GuardAction::warn("risky tool usage")) })
5430 }
5431 }
5432
5433 let warned = Arc::new(AtomicBool::new(false));
5434 let warned_clone = warned.clone();
5435
5436 let provider = Arc::new(MockProvider::new(vec![
5437 CompletionResponse {
5438 content: vec![ContentBlock::ToolUse {
5439 id: "c1".into(),
5440 name: "search".into(),
5441 input: json!({}),
5442 }],
5443 stop_reason: StopReason::ToolUse,
5444 usage: TokenUsage::default(),
5445 model: None,
5446 },
5447 CompletionResponse {
5448 content: vec![ContentBlock::Text {
5449 text: "Done with search.".into(),
5450 }],
5451 stop_reason: StopReason::EndTurn,
5452 usage: TokenUsage::default(),
5453 model: None,
5454 },
5455 ]));
5456
5457 let runner = AgentRunner::builder(provider)
5458 .name("test")
5459 .system_prompt("sys")
5460 .tool(Arc::new(MockTool::new("search", "search result")))
5461 .guardrail(Arc::new(WarnPreTool))
5462 .on_event(Arc::new(move |event| {
5463 if matches!(event, AgentEvent::GuardrailWarned { .. }) {
5464 warned_clone.store(true, Ordering::Relaxed);
5465 }
5466 }))
5467 .build()
5468 .unwrap();
5469
5470 let output = runner.execute("search something").await.unwrap();
5471 assert_eq!(output.result, "Done with search.");
5473 assert_eq!(output.tool_calls_made, 1);
5474 assert!(
5476 warned.load(Ordering::Relaxed),
5477 "GuardrailWarned event should have fired"
5478 );
5479 }
5480
5481 #[tokio::test]
5482 async fn max_tool_calls_per_turn_caps_excess_dispatch() {
5483 let provider = Arc::new(MockProvider::new(vec![CompletionResponse {
5485 content: vec![
5486 ContentBlock::ToolUse {
5487 id: "c1".into(),
5488 name: "a".into(),
5489 input: json!({}),
5490 },
5491 ContentBlock::ToolUse {
5492 id: "c2".into(),
5493 name: "b".into(),
5494 input: json!({}),
5495 },
5496 ContentBlock::ToolUse {
5497 id: "c3".into(),
5498 name: "c".into(),
5499 input: json!({}),
5500 },
5501 ],
5502 stop_reason: StopReason::ToolUse,
5503 usage: TokenUsage::default(),
5504 model: None,
5505 }]));
5506
5507 let runner = AgentRunner::builder(provider)
5508 .name("test")
5509 .system_prompt("sys")
5510 .tool(Arc::new(MockTool::new("a", "x")))
5511 .tool(Arc::new(MockTool::new("b", "y")))
5512 .tool(Arc::new(MockTool::new("c", "z")))
5513 .max_tool_calls_per_turn(2)
5514 .build()
5515 .unwrap();
5516
5517 let err = runner.execute("go").await.unwrap_err();
5518 let s = err.to_string();
5519 assert!(s.contains("tool-call cap exceeded"), "got: {s}");
5520 assert!(
5522 matches!(err, Error::WithPartialUsage { .. }),
5523 "got: {err:?}"
5524 );
5525 }
5526
5527 #[test]
5528 fn max_tool_calls_per_turn_zero_is_rejected_at_build() {
5529 let provider = Arc::new(MockProvider::new(vec![]));
5530 let result = AgentRunner::builder(provider)
5531 .name("t")
5532 .system_prompt("p")
5533 .max_tool_calls_per_turn(0)
5534 .build();
5535 assert!(result.is_err());
5536 let err = result.err().unwrap();
5537 assert!(
5538 err.to_string()
5539 .contains("max_tool_calls_per_turn must be > 0"),
5540 "got: {err}"
5541 );
5542 }
5543
5544 #[tokio::test]
5545 async fn max_tool_calls_per_turn_at_cap_is_allowed() {
5546 let provider = Arc::new(MockProvider::new(vec![
5548 CompletionResponse {
5549 content: vec![
5550 ContentBlock::ToolUse {
5551 id: "c1".into(),
5552 name: "a".into(),
5553 input: json!({}),
5554 },
5555 ContentBlock::ToolUse {
5556 id: "c2".into(),
5557 name: "b".into(),
5558 input: json!({}),
5559 },
5560 ],
5561 stop_reason: StopReason::ToolUse,
5562 usage: TokenUsage::default(),
5563 model: None,
5564 },
5565 CompletionResponse {
5566 content: vec![ContentBlock::Text {
5567 text: "done".into(),
5568 }],
5569 stop_reason: StopReason::EndTurn,
5570 usage: TokenUsage::default(),
5571 model: None,
5572 },
5573 ]));
5574
5575 let runner = AgentRunner::builder(provider)
5576 .name("test")
5577 .system_prompt("sys")
5578 .tool(Arc::new(MockTool::new("a", "x")))
5579 .tool(Arc::new(MockTool::new("b", "y")))
5580 .max_tool_calls_per_turn(2)
5581 .build()
5582 .unwrap();
5583
5584 let output = runner.execute("go").await.unwrap();
5585 assert_eq!(output.tool_calls_made, 2);
5586 }
5587
5588 #[tokio::test]
5593 async fn levenshtein_repair_runs_before_tool_policy() {
5594 use crate::agent::guardrails::tool_policy::{ToolPolicyGuardrail, ToolRule};
5595
5596 let provider = Arc::new(MockProvider::new(vec![
5597 CompletionResponse {
5598 content: vec![ContentBlock::ToolUse {
5599 id: "call-1".into(),
5600 name: "bask".into(),
5602 input: json!({}),
5603 }],
5604 stop_reason: StopReason::ToolUse,
5605 usage: TokenUsage::default(),
5606 model: None,
5607 },
5608 CompletionResponse {
5609 content: vec![ContentBlock::Text {
5610 text: "done".into(),
5611 }],
5612 stop_reason: StopReason::EndTurn,
5613 usage: TokenUsage::default(),
5614 model: None,
5615 },
5616 ]));
5617
5618 let policy = Arc::new(ToolPolicyGuardrail::new(
5621 vec![ToolRule {
5622 tool_pattern: "bash".into(),
5623 action: crate::GuardAction::deny("blocked"),
5624 input_constraints: vec![],
5625 }],
5626 crate::GuardAction::Allow,
5627 ));
5628
5629 let runner = AgentRunner::builder(provider)
5630 .name("test")
5631 .system_prompt("sys")
5632 .tool(Arc::new(MockTool::new("bash", "DANGEROUS_OUTPUT")))
5633 .guardrails(vec![policy])
5634 .build()
5635 .unwrap();
5636
5637 let output = runner.execute("run shell").await.unwrap();
5638
5639 assert!(
5642 !output.result.contains("DANGEROUS_OUTPUT"),
5643 "tool result leaked despite policy deny: {}",
5644 output.result
5645 );
5646 assert_eq!(
5647 output.tool_calls_made, 1,
5648 "exactly one tool call should be recorded (denied)"
5649 );
5650 }
5651}