1use std::any;
29
30use schemars::{JsonSchema, schema_for};
31use serde::de::DeserializeOwned;
32use serde::{Deserialize, Serialize};
33use serde_json::{Value, from_value, to_string};
34use tokio::time;
35use tracing::{info, warn};
36
37use crate::error::OperationError;
38#[cfg(feature = "prometheus")]
39use crate::metric_names;
40use crate::provider::{AgentConfig, AgentOutput, AgentProvider};
41use crate::retry::RetryPolicy;
42
/// Namespace of well-known model identifiers for [`Agent::model`].
///
/// These are plain string constants. [`Agent::model`] accepts any string, so
/// identifiers that are not listed here can be passed directly.
pub struct Model;

impl Model {
    // --- Unversioned family names ---

    /// Family name `"haiku"`.
    pub const HAIKU: &str = "haiku";
    /// Family name `"sonnet"`.
    pub const SONNET: &str = "sonnet";
    /// Family name `"opus"`.
    pub const OPUS: &str = "opus";

    // --- Versioned identifiers ---

    /// Identifier `"claude-haiku-4-5-20251001"`.
    pub const HAIKU_45: &str = "claude-haiku-4-5-20251001";
    /// Identifier `"claude-sonnet-4-6"`.
    pub const SONNET_46: &str = "claude-sonnet-4-6";
    /// Identifier `"claude-opus-4-6"`.
    pub const OPUS_46: &str = "claude-opus-4-6";

    // --- Versioned identifiers carrying the `[1m]` suffix ---
    // NOTE(review): the suffix presumably selects a 1M-token context variant;
    // confirm against the provider's documentation.

    /// [`Model::SONNET_46`] with the `[1m]` suffix.
    pub const SONNET_46_1M: &str = "claude-sonnet-4-6[1m]";
    /// [`Model::OPUS_46`] with the `[1m]` suffix.
    pub const OPUS_46_1M: &str = "claude-opus-4-6[1m]";
}
104
105#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
110pub enum PermissionMode {
111 Default,
113 Auto,
115 DontAsk,
117 BypassPermissions,
121}
122
123#[must_use = "an Agent does nothing until .run() is awaited"]
153pub struct Agent {
154 config: AgentConfig,
155 dry_run: Option<bool>,
156 retry_policy: Option<RetryPolicy>,
157}
158
159impl Agent {
160 pub fn new() -> Self {
165 Self {
166 config: AgentConfig::new(""),
167 dry_run: None,
168 retry_policy: None,
169 }
170 }
171
172 pub fn system_prompt(mut self, prompt: &str) -> Self {
174 self.config.system_prompt = Some(prompt.to_string());
175 self
176 }
177
178 pub fn prompt(mut self, prompt: &str) -> Self {
180 self.config.prompt = prompt.to_string();
181 self
182 }
183
184 pub fn model(mut self, model: impl Into<String>) -> Self {
191 self.config.model = model.into();
192 self
193 }
194
195 pub fn allowed_tools(mut self, tools: &[&str]) -> Self {
200 self.config.allowed_tools = tools.iter().map(|s| s.to_string()).collect();
201 self
202 }
203
204 pub fn max_turns(mut self, turns: u32) -> Self {
210 assert!(turns > 0, "max_turns must be greater than 0");
211 self.config.max_turns = Some(turns);
212 self
213 }
214
215 pub fn max_budget_usd(mut self, budget: f64) -> Self {
221 assert!(
222 budget.is_finite() && budget > 0.0,
223 "budget must be a positive finite number, got {budget}"
224 );
225 self.config.max_budget_usd = Some(budget);
226 self
227 }
228
229 pub fn working_dir(mut self, dir: &str) -> Self {
231 self.config.working_dir = Some(dir.to_string());
232 self
233 }
234
235 pub fn mcp_config(mut self, config: &str) -> Self {
237 self.config.mcp_config = Some(config.to_string());
238 self
239 }
240
241 pub fn permission_mode(mut self, mode: PermissionMode) -> Self {
245 self.config.permission_mode = mode;
246 self
247 }
248
249 pub fn output<T: JsonSchema>(mut self) -> Self {
280 let schema = schema_for!(T);
281 self.config.json_schema = match to_string(&schema) {
282 Ok(s) => Some(s),
283 Err(e) => {
284 warn!(error = %e, type_name = any::type_name::<T>(), "failed to serialize JSON schema, structured output disabled");
285 None
286 }
287 };
288 self
289 }
290
291 pub fn output_schema_raw(mut self, schema: &str) -> Self {
315 self.config.json_schema = Some(schema.to_string());
316 self
317 }
318
319 pub fn retry(mut self, max_retries: u32) -> Self {
347 self.retry_policy = Some(RetryPolicy::new(max_retries));
348 self
349 }
350
351 pub fn retry_policy(mut self, policy: RetryPolicy) -> Self {
378 self.retry_policy = Some(policy);
379 self
380 }
381
382 pub fn dry_run(mut self, enabled: bool) -> Self {
391 self.dry_run = Some(enabled);
392 self
393 }
394
395 pub fn resume(mut self, session_id: &str) -> Self {
431 assert!(!session_id.is_empty(), "session_id must not be empty");
432 assert!(
433 session_id
434 .chars()
435 .all(|c| c.is_ascii_alphanumeric() || c == '-' || c == '_'),
436 "session_id must only contain alphanumeric characters, hyphens, or underscores, got: {session_id}"
437 );
438 self.config.resume_session_id = Some(session_id.to_string());
439 self
440 }
441
442 #[tracing::instrument(name = "agent", skip_all, fields(model = %self.config.model, prompt_len = self.config.prompt.len()))]
459 pub async fn run(self, provider: &dyn AgentProvider) -> Result<AgentResult, OperationError> {
460 assert!(
461 !self.config.prompt.trim().is_empty(),
462 "prompt must not be empty - call .prompt(\"...\") before .run()"
463 );
464
465 if crate::dry_run::effective_dry_run(self.dry_run) {
466 info!(
467 prompt_len = self.config.prompt.len(),
468 "[dry-run] agent call skipped"
469 );
470 let mut output =
471 AgentOutput::new(Value::String("[dry-run] agent call skipped".to_string()));
472 output.cost_usd = Some(0.0);
473 output.input_tokens = Some(0);
474 output.output_tokens = Some(0);
475 return Ok(AgentResult { output });
476 }
477
478 let result = self.invoke_once(provider).await;
479
480 let policy = match &self.retry_policy {
481 Some(p) => p,
482 None => return result,
483 };
484
485 if let Err(ref err) = result {
487 if !crate::retry::is_retryable(err) {
488 return result;
489 }
490 } else {
491 return result;
492 }
493
494 let mut last_result = result;
495
496 for attempt in 0..policy.max_retries {
497 let delay = policy.delay_for_attempt(attempt);
498 warn!(
499 attempt = attempt + 1,
500 max_retries = policy.max_retries,
501 delay_ms = delay.as_millis() as u64,
502 "retrying agent invocation"
503 );
504 time::sleep(delay).await;
505
506 last_result = self.invoke_once(provider).await;
507
508 match &last_result {
509 Ok(_) => return last_result,
510 Err(err) if !crate::retry::is_retryable(err) => return last_result,
511 _ => {}
512 }
513 }
514
515 last_result
516 }
517
518 async fn invoke_once(
520 &self,
521 provider: &dyn AgentProvider,
522 ) -> Result<AgentResult, OperationError> {
523 #[cfg(feature = "prometheus")]
524 let model_label = self.config.model.to_string();
525
526 let output = match provider.invoke(&self.config).await {
527 Ok(output) => output,
528 Err(e) => {
529 #[cfg(feature = "prometheus")]
530 {
531 metrics::counter!(metric_names::AGENT_TOTAL, "model" => model_label.clone(), "status" => metric_names::STATUS_ERROR).increment(1);
532 }
533 return Err(OperationError::Agent(e));
534 }
535 };
536
537 info!(
538 duration_ms = output.duration_ms,
539 cost_usd = output.cost_usd,
540 input_tokens = output.input_tokens,
541 output_tokens = output.output_tokens,
542 model = output.model,
543 "agent completed"
544 );
545
546 #[cfg(feature = "prometheus")]
547 {
548 metrics::counter!(metric_names::AGENT_TOTAL, "model" => model_label.clone(), "status" => metric_names::STATUS_SUCCESS).increment(1);
549 metrics::histogram!(metric_names::AGENT_DURATION_SECONDS, "model" => model_label.clone())
550 .record(output.duration_ms as f64 / 1000.0);
551 if let Some(cost) = output.cost_usd {
552 metrics::gauge!(metric_names::AGENT_COST_USD_TOTAL, "model" => model_label.clone())
553 .increment(cost);
554 }
555 if let Some(tokens) = output.input_tokens {
556 metrics::counter!(metric_names::AGENT_TOKENS_INPUT_TOTAL, "model" => model_label.clone()).increment(tokens);
557 }
558 if let Some(tokens) = output.output_tokens {
559 metrics::counter!(metric_names::AGENT_TOKENS_OUTPUT_TOTAL, "model" => model_label)
560 .increment(tokens);
561 }
562 }
563
564 Ok(AgentResult { output })
565 }
566}
567
568impl Default for Agent {
569 fn default() -> Self {
570 Self::new()
571 }
572}
573
574#[derive(Debug)]
579pub struct AgentResult {
580 output: AgentOutput,
581}
582
583impl AgentResult {
584 pub fn text(&self) -> &str {
589 match self.output.value.as_str() {
590 Some(s) => s,
591 None => {
592 warn!(
593 value_type = self.output.value.to_string(),
594 "agent output is not a string, returning empty"
595 );
596 ""
597 }
598 }
599 }
600
601 pub fn value(&self) -> &Value {
603 &self.output.value
604 }
605
606 pub fn json<T: DeserializeOwned>(&self) -> Result<T, OperationError> {
616 from_value(self.output.value.clone()).map_err(OperationError::deserialize::<T>)
617 }
618
619 pub fn into_json<T: DeserializeOwned>(self) -> Result<T, OperationError> {
625 from_value(self.output.value).map_err(OperationError::deserialize::<T>)
626 }
627
628 #[cfg(test)]
633 pub(crate) fn from_output(output: AgentOutput) -> Self {
634 Self { output }
635 }
636
637 pub fn session_id(&self) -> Option<&str> {
639 self.output.session_id.as_deref()
640 }
641
642 pub fn cost_usd(&self) -> Option<f64> {
644 self.output.cost_usd
645 }
646
647 pub fn input_tokens(&self) -> Option<u64> {
649 self.output.input_tokens
650 }
651
652 pub fn output_tokens(&self) -> Option<u64> {
654 self.output.output_tokens
655 }
656
657 pub fn duration_ms(&self) -> u64 {
659 self.output.duration_ms
660 }
661
662 pub fn model(&self) -> Option<&str> {
664 self.output.model.as_deref()
665 }
666}
667
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::AgentError;
    use crate::provider::InvokeFuture;
    use serde_json::json;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::time::Duration;

    /// Builds an owned copy of `src` with its value replaced by `value`.
    ///
    /// The test providers share this instead of each repeating the
    /// field-by-field copy.
    fn output_with_value(src: &AgentOutput, value: Value) -> AgentOutput {
        AgentOutput {
            value,
            session_id: src.session_id.clone(),
            cost_usd: src.cost_usd,
            input_tokens: src.input_tokens,
            output_tokens: src.output_tokens,
            model: src.model.clone(),
            duration_ms: src.duration_ms,
        }
    }

    /// Provider that always succeeds with a copy of `output`.
    struct TestProvider {
        output: AgentOutput,
    }

    impl AgentProvider for TestProvider {
        fn invoke<'a>(&'a self, _config: &'a AgentConfig) -> InvokeFuture<'a> {
            Box::pin(async move { Ok(output_with_value(&self.output, self.output.value.clone())) })
        }
    }

    /// Provider that echoes the received config (as JSON) back as the output
    /// value, so tests can inspect what the builder produced.
    struct ConfigCapture {
        output: AgentOutput,
    }

    impl AgentProvider for ConfigCapture {
        fn invoke<'a>(&'a self, config: &'a AgentConfig) -> InvokeFuture<'a> {
            let config_json = serde_json::to_value(config).unwrap();
            Box::pin(async move { Ok(output_with_value(&self.output, config_json)) })
        }
    }

    fn default_output() -> AgentOutput {
        AgentOutput {
            value: json!("test output"),
            session_id: Some("sess-123".to_string()),
            cost_usd: Some(0.05),
            input_tokens: Some(100),
            output_tokens: Some(50),
            model: Some("sonnet".to_string()),
            duration_ms: 1500,
        }
    }

    #[test]
    fn model_constants_have_expected_values() {
        assert_eq!(Model::SONNET, "sonnet");
        assert_eq!(Model::OPUS, "opus");
        assert_eq!(Model::HAIKU, "haiku");
        assert_eq!(Model::HAIKU_45, "claude-haiku-4-5-20251001");
        assert_eq!(Model::SONNET_46, "claude-sonnet-4-6");
        assert_eq!(Model::OPUS_46, "claude-opus-4-6");
        assert_eq!(Model::SONNET_46_1M, "claude-sonnet-4-6[1m]");
        assert_eq!(Model::OPUS_46_1M, "claude-opus-4-6[1m]");
    }

    #[tokio::test]
    async fn agent_new_default_values() {
        let provider = ConfigCapture {
            output: default_output(),
        };
        let result = Agent::new().prompt("hi").run(&provider).await.unwrap();

        let config = result.value();
        assert_eq!(config["system_prompt"], json!(null));
        assert_eq!(config["prompt"], json!("hi"));
        assert_eq!(config["model"], json!("sonnet"));
        assert_eq!(config["allowed_tools"], json!([]));
        assert_eq!(config["max_turns"], json!(null));
        assert_eq!(config["max_budget_usd"], json!(null));
        assert_eq!(config["working_dir"], json!(null));
        assert_eq!(config["mcp_config"], json!(null));
        assert_eq!(config["permission_mode"], json!("Default"));
        assert_eq!(config["json_schema"], json!(null));
    }

    #[tokio::test]
    async fn agent_default_matches_new() {
        let provider = ConfigCapture {
            output: default_output(),
        };
        let result_new = Agent::new().prompt("x").run(&provider).await.unwrap();
        let result_default = Agent::default().prompt("x").run(&provider).await.unwrap();

        assert_eq!(result_new.value(), result_default.value());
    }

    #[tokio::test]
    async fn builder_methods_store_values_correctly() {
        let provider = ConfigCapture {
            output: default_output(),
        };
        let result = Agent::new()
            .system_prompt("you are a bot")
            .prompt("do something")
            .model(Model::OPUS)
            .allowed_tools(&["Read", "Write"])
            .max_turns(5)
            .max_budget_usd(1.5)
            .working_dir("/tmp")
            .mcp_config("{}")
            .permission_mode(PermissionMode::Auto)
            .run(&provider)
            .await
            .unwrap();

        let config = result.value();
        assert_eq!(config["system_prompt"], json!("you are a bot"));
        assert_eq!(config["prompt"], json!("do something"));
        assert_eq!(config["model"], json!("opus"));
        assert_eq!(config["allowed_tools"], json!(["Read", "Write"]));
        assert_eq!(config["max_turns"], json!(5));
        assert_eq!(config["max_budget_usd"], json!(1.5));
        assert_eq!(config["working_dir"], json!("/tmp"));
        assert_eq!(config["mcp_config"], json!("{}"));
        assert_eq!(config["permission_mode"], json!("Auto"));
    }

    #[test]
    #[should_panic(expected = "max_turns must be greater than 0")]
    fn max_turns_zero_panics() {
        let _ = Agent::new().max_turns(0);
    }

    #[test]
    #[should_panic(expected = "budget must be a positive finite number")]
    fn max_budget_negative_panics() {
        let _ = Agent::new().max_budget_usd(-1.0);
    }

    #[test]
    #[should_panic(expected = "budget must be a positive finite number")]
    fn max_budget_nan_panics() {
        let _ = Agent::new().max_budget_usd(f64::NAN);
    }

    #[test]
    #[should_panic(expected = "budget must be a positive finite number")]
    fn max_budget_infinity_panics() {
        let _ = Agent::new().max_budget_usd(f64::INFINITY);
    }

    #[tokio::test]
    async fn agent_result_text_with_string_value() {
        let provider = TestProvider {
            output: AgentOutput {
                value: json!("hello world"),
                ..default_output()
            },
        };
        let result = Agent::new().prompt("test").run(&provider).await.unwrap();
        assert_eq!(result.text(), "hello world");
    }

    #[tokio::test]
    async fn agent_result_text_with_non_string_value() {
        let provider = TestProvider {
            output: AgentOutput {
                value: json!(42),
                ..default_output()
            },
        };
        let result = Agent::new().prompt("test").run(&provider).await.unwrap();
        assert_eq!(result.text(), "");
    }

    #[tokio::test]
    async fn agent_result_text_with_null_value() {
        let provider = TestProvider {
            output: AgentOutput {
                value: json!(null),
                ..default_output()
            },
        };
        let result = Agent::new().prompt("test").run(&provider).await.unwrap();
        assert_eq!(result.text(), "");
    }

    #[tokio::test]
    async fn agent_result_json_successful_deserialize() {
        #[derive(Deserialize, PartialEq, Debug)]
        struct MyOutput {
            name: String,
            count: u32,
        }
        let provider = TestProvider {
            output: AgentOutput {
                value: json!({"name": "test", "count": 7}),
                ..default_output()
            },
        };
        let result = Agent::new().prompt("test").run(&provider).await.unwrap();
        let parsed: MyOutput = result.json().unwrap();
        assert_eq!(parsed.name, "test");
        assert_eq!(parsed.count, 7);
    }

    #[tokio::test]
    async fn agent_result_json_failed_deserialize() {
        #[derive(Debug, Deserialize)]
        #[allow(dead_code)] // fields exist only to shape deserialization
        struct MyOutput {
            name: String,
        }
        let provider = TestProvider {
            output: AgentOutput {
                value: json!(42),
                ..default_output()
            },
        };
        let result = Agent::new().prompt("test").run(&provider).await.unwrap();
        let err = result.json::<MyOutput>().unwrap_err();
        assert!(matches!(err, OperationError::Deserialize { .. }));
    }

    #[tokio::test]
    async fn agent_result_accessors() {
        let provider = TestProvider {
            output: AgentOutput {
                value: json!("v"),
                session_id: Some("s-1".to_string()),
                cost_usd: Some(0.123),
                input_tokens: Some(999),
                output_tokens: Some(456),
                model: Some("opus".to_string()),
                duration_ms: 2000,
            },
        };
        let result = Agent::new().prompt("test").run(&provider).await.unwrap();
        assert_eq!(result.session_id(), Some("s-1"));
        assert_eq!(result.cost_usd(), Some(0.123));
        assert_eq!(result.input_tokens(), Some(999));
        assert_eq!(result.output_tokens(), Some(456));
        assert_eq!(result.duration_ms(), 2000);
        assert_eq!(result.model(), Some("opus"));
    }

    #[tokio::test]
    async fn resume_passes_session_id_in_config() {
        let provider = ConfigCapture {
            output: default_output(),
        };
        let result = Agent::new()
            .prompt("followup")
            .resume("sess-abc")
            .run(&provider)
            .await
            .unwrap();

        let config = result.value();
        assert_eq!(config["resume_session_id"], json!("sess-abc"));
    }

    #[tokio::test]
    async fn no_resume_has_null_session_id() {
        let provider = ConfigCapture {
            output: default_output(),
        };
        let result = Agent::new()
            .prompt("first call")
            .run(&provider)
            .await
            .unwrap();

        let config = result.value();
        assert_eq!(config["resume_session_id"], json!(null));
    }

    #[test]
    #[should_panic(expected = "session_id must not be empty")]
    fn resume_empty_session_id_panics() {
        let _ = Agent::new().resume("");
    }

    #[test]
    #[should_panic(expected = "session_id must only contain")]
    fn resume_invalid_chars_panics() {
        let _ = Agent::new().resume("sess;rm -rf /");
    }

    #[test]
    fn resume_valid_formats_accepted() {
        let _ = Agent::new().resume("sess-abc123");
        let _ = Agent::new().resume("a1b2c3d4_session");
        let _ = Agent::new().resume("abc-DEF-123_456");
    }

    #[tokio::test]
    #[should_panic(expected = "prompt must not be empty")]
    async fn run_without_prompt_panics() {
        let provider = TestProvider {
            output: default_output(),
        };
        let _ = Agent::new().run(&provider).await;
    }

    #[tokio::test]
    #[should_panic(expected = "prompt must not be empty")]
    async fn run_with_whitespace_only_prompt_panics() {
        let provider = TestProvider {
            output: default_output(),
        };
        let _ = Agent::new().prompt(" ").run(&provider).await;
    }

    #[tokio::test]
    async fn model_accepts_custom_string() {
        let provider = ConfigCapture {
            output: default_output(),
        };
        let result = Agent::new()
            .prompt("hi")
            .model("mistral-large-latest")
            .run(&provider)
            .await
            .unwrap();
        assert_eq!(result.value()["model"], json!("mistral-large-latest"));
    }

    #[tokio::test]
    async fn model_accepts_owned_string() {
        let provider = ConfigCapture {
            output: default_output(),
        };
        let model_name = String::from("gpt-4o");
        let result = Agent::new()
            .prompt("hi")
            .model(model_name)
            .run(&provider)
            .await
            .unwrap();
        assert_eq!(result.value()["model"], json!("gpt-4o"));
    }

    #[tokio::test]
    async fn into_json_success() {
        #[derive(Deserialize, PartialEq, Debug)]
        struct Out {
            name: String,
        }
        let provider = TestProvider {
            output: AgentOutput {
                value: json!({"name": "test"}),
                ..default_output()
            },
        };
        let result = Agent::new().prompt("test").run(&provider).await.unwrap();
        let parsed: Out = result.into_json().unwrap();
        assert_eq!(parsed.name, "test");
    }

    #[tokio::test]
    async fn into_json_failure() {
        #[derive(Debug, Deserialize)]
        #[allow(dead_code)] // fields exist only to shape deserialization
        struct Out {
            name: String,
        }
        let provider = TestProvider {
            output: AgentOutput {
                value: json!(42),
                ..default_output()
            },
        };
        let result = Agent::new().prompt("test").run(&provider).await.unwrap();
        let err = result.into_json::<Out>().unwrap_err();
        assert!(matches!(err, OperationError::Deserialize { .. }));
    }

    #[test]
    fn from_output_creates_result() {
        let output = AgentOutput {
            value: json!("hello"),
            ..default_output()
        };
        let result = AgentResult::from_output(output);
        assert_eq!(result.text(), "hello");
        assert_eq!(result.cost_usd(), Some(0.05));
    }

    #[test]
    #[should_panic(expected = "budget must be a positive finite number")]
    fn max_budget_zero_panics() {
        let _ = Agent::new().max_budget_usd(0.0);
    }

    #[test]
    fn model_constant_equality() {
        assert_eq!(Model::SONNET, "sonnet");
        assert_ne!(Model::SONNET, Model::OPUS);
    }

    #[test]
    fn permission_mode_serialize_deserialize_roundtrip() {
        for mode in [
            PermissionMode::Default,
            PermissionMode::Auto,
            PermissionMode::DontAsk,
            PermissionMode::BypassPermissions,
        ] {
            let json = to_string(&mode).unwrap();
            let back: PermissionMode = serde_json::from_str(&json).unwrap();
            assert_eq!(format!("{:?}", mode), format!("{:?}", back));
        }
    }

    #[test]
    fn retry_builder_stores_policy() {
        let agent = Agent::new().retry(3);
        assert!(agent.retry_policy.is_some());
        assert_eq!(agent.retry_policy.unwrap().max_retries(), 3);
    }

    #[test]
    fn retry_policy_builder_stores_custom_policy() {
        let policy = RetryPolicy::new(5).backoff(Duration::from_secs(1));
        let agent = Agent::new().retry_policy(policy);
        let p = agent.retry_policy.unwrap();
        assert_eq!(p.max_retries(), 5);
    }

    #[test]
    fn no_retry_by_default() {
        let agent = Agent::new();
        assert!(agent.retry_policy.is_none());
    }

    /// Provider that fails with a (retryable) `ProcessFailed` error for the
    /// first `failures_before_success` calls, then succeeds.
    /// `fail_count` records the total number of calls made.
    struct FailNTimesProvider {
        fail_count: AtomicU32,
        failures_before_success: u32,
        output: AgentOutput,
    }

    impl AgentProvider for FailNTimesProvider {
        fn invoke<'a>(&'a self, _config: &'a AgentConfig) -> InvokeFuture<'a> {
            Box::pin(async move {
                let current = self.fail_count.fetch_add(1, Ordering::SeqCst);
                if current < self.failures_before_success {
                    Err(AgentError::ProcessFailed {
                        exit_code: 1,
                        stderr: format!("transient failure #{}", current + 1),
                    })
                } else {
                    Ok(output_with_value(&self.output, self.output.value.clone()))
                }
            })
        }
    }

    #[tokio::test]
    async fn retry_succeeds_after_transient_failures() {
        let provider = FailNTimesProvider {
            fail_count: AtomicU32::new(0),
            failures_before_success: 2,
            output: default_output(),
        };
        let result = Agent::new()
            .prompt("test")
            .retry_policy(RetryPolicy::new(3).backoff(Duration::from_millis(1)))
            .run(&provider)
            .await;

        assert!(result.is_ok());
        // Two failures followed by one success.
        assert_eq!(provider.fail_count.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn retry_exhausted_returns_last_error() {
        let provider = FailNTimesProvider {
            fail_count: AtomicU32::new(0),
            failures_before_success: 10,
            output: default_output(),
        };
        let result = Agent::new()
            .prompt("test")
            .retry_policy(RetryPolicy::new(2).backoff(Duration::from_millis(1)))
            .run(&provider)
            .await;

        assert!(result.is_err());
        // One initial attempt plus two retries.
        assert_eq!(provider.fail_count.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn retry_does_not_retry_non_retryable_errors() {
        let call_count = Arc::new(AtomicU32::new(0));
        let count = call_count.clone();

        struct CountingNonRetryable {
            count: Arc<AtomicU32>,
        }
        impl AgentProvider for CountingNonRetryable {
            fn invoke<'a>(&'a self, _config: &'a AgentConfig) -> InvokeFuture<'a> {
                self.count.fetch_add(1, Ordering::SeqCst);
                Box::pin(async move {
                    Err(AgentError::SchemaValidation {
                        expected: "object".to_string(),
                        got: "string".to_string(),
                    })
                })
            }
        }

        let provider = CountingNonRetryable { count };
        let result = Agent::new()
            .prompt("test")
            .retry_policy(RetryPolicy::new(3).backoff(Duration::from_millis(1)))
            .run(&provider)
            .await;

        assert!(result.is_err());
        assert_eq!(call_count.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn no_retry_without_policy() {
        let provider = FailNTimesProvider {
            fail_count: AtomicU32::new(0),
            failures_before_success: 1,
            output: default_output(),
        };
        let result = Agent::new().prompt("test").run(&provider).await;

        assert!(result.is_err());
        assert_eq!(provider.fail_count.load(Ordering::SeqCst), 1);
    }
}