1use std::any;
29
30use schemars::{JsonSchema, schema_for};
31use serde::de::DeserializeOwned;
32use serde::{Deserialize, Serialize};
33use serde_json::{Value, from_value, to_string};
34use tokio::time;
35use tracing::{info, warn};
36
37use crate::error::OperationError;
38#[cfg(feature = "prometheus")]
39use crate::metric_names;
40use crate::provider::{AgentConfig, AgentOutput, AgentProvider, DebugMessage};
41use crate::retry::RetryPolicy;
42
43pub struct Model;
74
75impl Model {
76 pub const SONNET: &str = "sonnet";
80 pub const OPUS: &str = "opus";
82 pub const HAIKU: &str = "haiku";
84
85 pub const HAIKU_45: &str = "claude-haiku-4-5-20251001";
89
90 pub const SONNET_46: &str = "claude-sonnet-4-6";
94 pub const OPUS_46: &str = "claude-opus-4-6";
96
97 pub const SONNET_46_1M: &str = "claude-sonnet-4-6[1m]";
101 pub const OPUS_46_1M: &str = "claude-opus-4-6[1m]";
103}
104
/// Permission mode stored in [`AgentConfig`] and handed to the agent provider.
///
/// Serializes as the bare variant name (e.g. `"BypassPermissions"`).
/// Deserialization is lenient (case-insensitive, `_` ignored) and falls back
/// to [`PermissionMode::Default`] for unrecognized strings.
// NOTE: variant order is part of the serialized form for formats that encode
// unit variants by index, so do not reorder these.
#[derive(Debug, Default, Clone, Copy, Serialize)]
pub enum PermissionMode {
    /// The `#[default]` variant; also the fallback for unrecognized input.
    #[default]
    Default,
    /// `Auto` mode. NOTE(review): exact behavior is defined by the provider — confirm.
    Auto,
    /// `DontAsk` mode. NOTE(review): exact behavior is defined by the provider — confirm.
    DontAsk,
    /// `BypassPermissions` mode; `"bypass"` is accepted as shorthand when deserializing.
    BypassPermissions,
}
123
124impl<'de> Deserialize<'de> for PermissionMode {
125 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
126 where
127 D: serde::Deserializer<'de>,
128 {
129 let s = String::deserialize(deserializer)?;
130 Ok(match s.to_lowercase().replace('_', "").as_str() {
131 "auto" => Self::Auto,
132 "dontask" => Self::DontAsk,
133 "bypass" | "bypasspermissions" => Self::BypassPermissions,
134 _ => Self::Default,
135 })
136 }
137}
138
139#[must_use = "an Agent does nothing until .run() is awaited"]
169pub struct Agent {
170 config: AgentConfig,
171 dry_run: Option<bool>,
172 retry_policy: Option<RetryPolicy>,
173}
174
175impl Agent {
176 pub fn new() -> Self {
181 Self {
182 config: AgentConfig::new(""),
183 dry_run: None,
184 retry_policy: None,
185 }
186 }
187
188 pub fn from_config(config: AgentConfig) -> Self {
207 Self {
208 config,
209 dry_run: None,
210 retry_policy: None,
211 }
212 }
213
214 pub fn system_prompt(mut self, prompt: &str) -> Self {
216 self.config.system_prompt = Some(prompt.to_string());
217 self
218 }
219
220 pub fn prompt(mut self, prompt: &str) -> Self {
222 self.config.prompt = prompt.to_string();
223 self
224 }
225
226 pub fn model(mut self, model: impl Into<String>) -> Self {
233 self.config.model = model.into();
234 self
235 }
236
237 pub fn allowed_tools(mut self, tools: &[&str]) -> Self {
242 self.config.allowed_tools = tools.iter().map(|s| s.to_string()).collect();
243 self
244 }
245
246 pub fn max_turns(mut self, turns: u32) -> Self {
252 assert!(turns > 0, "max_turns must be greater than 0");
253 self.config.max_turns = Some(turns);
254 self
255 }
256
257 pub fn max_budget_usd(mut self, budget: f64) -> Self {
263 assert!(
264 budget.is_finite() && budget > 0.0,
265 "budget must be a positive finite number, got {budget}"
266 );
267 self.config.max_budget_usd = Some(budget);
268 self
269 }
270
271 pub fn working_dir(mut self, dir: &str) -> Self {
273 self.config.working_dir = Some(dir.to_string());
274 self
275 }
276
277 pub fn mcp_config(mut self, config: &str) -> Self {
279 self.config.mcp_config = Some(config.to_string());
280 self
281 }
282
283 pub fn permission_mode(mut self, mode: PermissionMode) -> Self {
287 self.config.permission_mode = mode;
288 self
289 }
290
291 pub fn output<T: JsonSchema>(mut self) -> Self {
322 let schema = schema_for!(T);
323 self.config.json_schema = match to_string(&schema) {
324 Ok(s) => Some(s),
325 Err(e) => {
326 warn!(error = %e, type_name = any::type_name::<T>(), "failed to serialize JSON schema, structured output disabled");
327 None
328 }
329 };
330 self
331 }
332
333 pub fn output_schema_raw(mut self, schema: &str) -> Self {
357 self.config.json_schema = Some(schema.to_string());
358 self
359 }
360
361 pub fn retry(mut self, max_retries: u32) -> Self {
389 self.retry_policy = Some(RetryPolicy::new(max_retries));
390 self
391 }
392
393 pub fn retry_policy(mut self, policy: RetryPolicy) -> Self {
420 self.retry_policy = Some(policy);
421 self
422 }
423
424 pub fn dry_run(mut self, enabled: bool) -> Self {
433 self.dry_run = Some(enabled);
434 self
435 }
436
437 pub fn verbose(mut self) -> Self {
467 self.config.verbose = true;
468 self
469 }
470
471 pub fn resume(mut self, session_id: &str) -> Self {
507 assert!(!session_id.is_empty(), "session_id must not be empty");
508 assert!(
509 session_id
510 .chars()
511 .all(|c| c.is_ascii_alphanumeric() || c == '-' || c == '_'),
512 "session_id must only contain alphanumeric characters, hyphens, or underscores, got: {session_id}"
513 );
514 self.config.resume_session_id = Some(session_id.to_string());
515 self
516 }
517
518 #[tracing::instrument(name = "agent", skip_all, fields(model = %self.config.model, prompt_len = self.config.prompt.len()))]
535 pub async fn run(self, provider: &dyn AgentProvider) -> Result<AgentResult, OperationError> {
536 assert!(
537 !self.config.prompt.trim().is_empty(),
538 "prompt must not be empty - call .prompt(\"...\") before .run()"
539 );
540
541 if crate::dry_run::effective_dry_run(self.dry_run) {
542 info!(
543 prompt_len = self.config.prompt.len(),
544 "[dry-run] agent call skipped"
545 );
546 let mut output =
547 AgentOutput::new(Value::String("[dry-run] agent call skipped".to_string()));
548 output.cost_usd = Some(0.0);
549 output.input_tokens = Some(0);
550 output.output_tokens = Some(0);
551 return Ok(AgentResult { output });
552 }
553
554 let result = self.invoke_once(provider).await;
555
556 let policy = match &self.retry_policy {
557 Some(p) => p,
558 None => return result,
559 };
560
561 if let Err(ref err) = result {
563 if !crate::retry::is_retryable(err) {
564 return result;
565 }
566 } else {
567 return result;
568 }
569
570 let mut last_result = result;
571
572 for attempt in 0..policy.max_retries {
573 let delay = policy.delay_for_attempt(attempt);
574 warn!(
575 attempt = attempt + 1,
576 max_retries = policy.max_retries,
577 delay_ms = delay.as_millis() as u64,
578 "retrying agent invocation"
579 );
580 time::sleep(delay).await;
581
582 last_result = self.invoke_once(provider).await;
583
584 match &last_result {
585 Ok(_) => return last_result,
586 Err(err) if !crate::retry::is_retryable(err) => return last_result,
587 _ => {}
588 }
589 }
590
591 last_result
592 }
593
594 async fn invoke_once(
596 &self,
597 provider: &dyn AgentProvider,
598 ) -> Result<AgentResult, OperationError> {
599 #[cfg(feature = "prometheus")]
600 let model_label = self.config.model.to_string();
601
602 let output = match provider.invoke(&self.config).await {
603 Ok(output) => output,
604 Err(e) => {
605 #[cfg(feature = "prometheus")]
606 {
607 metrics::counter!(metric_names::AGENT_TOTAL, "model" => model_label.clone(), "status" => metric_names::STATUS_ERROR).increment(1);
608 }
609 return Err(OperationError::Agent(e));
610 }
611 };
612
613 info!(
614 duration_ms = output.duration_ms,
615 cost_usd = output.cost_usd,
616 input_tokens = output.input_tokens,
617 output_tokens = output.output_tokens,
618 model = output.model,
619 "agent completed"
620 );
621
622 #[cfg(feature = "prometheus")]
623 {
624 metrics::counter!(metric_names::AGENT_TOTAL, "model" => model_label.clone(), "status" => metric_names::STATUS_SUCCESS).increment(1);
625 metrics::histogram!(metric_names::AGENT_DURATION_SECONDS, "model" => model_label.clone())
626 .record(output.duration_ms as f64 / 1000.0);
627 if let Some(cost) = output.cost_usd {
628 metrics::gauge!(metric_names::AGENT_COST_USD_TOTAL, "model" => model_label.clone())
629 .increment(cost);
630 }
631 if let Some(tokens) = output.input_tokens {
632 metrics::counter!(metric_names::AGENT_TOKENS_INPUT_TOTAL, "model" => model_label.clone()).increment(tokens);
633 }
634 if let Some(tokens) = output.output_tokens {
635 metrics::counter!(metric_names::AGENT_TOKENS_OUTPUT_TOTAL, "model" => model_label)
636 .increment(tokens);
637 }
638 }
639
640 Ok(AgentResult { output })
641 }
642}
643
644impl Default for Agent {
645 fn default() -> Self {
646 Self::new()
647 }
648}
649
650#[derive(Debug)]
655pub struct AgentResult {
656 output: AgentOutput,
657}
658
659impl AgentResult {
660 pub fn text(&self) -> &str {
665 match self.output.value.as_str() {
666 Some(s) => s,
667 None => {
668 warn!(
669 value_type = self.output.value.to_string(),
670 "agent output is not a string, returning empty"
671 );
672 ""
673 }
674 }
675 }
676
677 pub fn value(&self) -> &Value {
679 &self.output.value
680 }
681
682 pub fn json<T: DeserializeOwned>(&self) -> Result<T, OperationError> {
692 from_value(self.output.value.clone()).map_err(OperationError::deserialize::<T>)
693 }
694
695 pub fn into_json<T: DeserializeOwned>(self) -> Result<T, OperationError> {
701 from_value(self.output.value).map_err(OperationError::deserialize::<T>)
702 }
703
704 #[cfg(test)]
709 pub(crate) fn from_output(output: AgentOutput) -> Self {
710 Self { output }
711 }
712
713 pub fn session_id(&self) -> Option<&str> {
715 self.output.session_id.as_deref()
716 }
717
718 pub fn cost_usd(&self) -> Option<f64> {
720 self.output.cost_usd
721 }
722
723 pub fn input_tokens(&self) -> Option<u64> {
725 self.output.input_tokens
726 }
727
728 pub fn output_tokens(&self) -> Option<u64> {
730 self.output.output_tokens
731 }
732
733 pub fn duration_ms(&self) -> u64 {
735 self.output.duration_ms
736 }
737
738 pub fn model(&self) -> Option<&str> {
740 self.output.model.as_deref()
741 }
742
743 pub fn debug_messages(&self) -> Option<&[DebugMessage]> {
749 self.output.debug_messages.as_deref()
750 }
751}
752
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::AgentError;
    use crate::provider::InvokeFuture;
    use serde_json::json;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::time::Duration;

    /// Builds an `AgentOutput` carrying `value` plus every metadata field of
    /// `template`; `debug_messages` is always `None`.
    fn output_with_value(template: &AgentOutput, value: Value) -> AgentOutput {
        AgentOutput {
            value,
            session_id: template.session_id.clone(),
            cost_usd: template.cost_usd,
            input_tokens: template.input_tokens,
            output_tokens: template.output_tokens,
            model: template.model.clone(),
            duration_ms: template.duration_ms,
            debug_messages: None,
        }
    }

    /// Provider that always succeeds with a copy of `output`.
    struct TestProvider {
        output: AgentOutput,
    }

    impl AgentProvider for TestProvider {
        fn invoke<'a>(&'a self, _config: &'a AgentConfig) -> InvokeFuture<'a> {
            Box::pin(async move { Ok(output_with_value(&self.output, self.output.value.clone())) })
        }
    }

    /// Provider that echoes the serialized `AgentConfig` back as the output
    /// value, so tests can inspect what the builder produced.
    struct ConfigCapture {
        output: AgentOutput,
    }

    impl AgentProvider for ConfigCapture {
        fn invoke<'a>(&'a self, config: &'a AgentConfig) -> InvokeFuture<'a> {
            let config_json = serde_json::to_value(config).expect("AgentConfig should serialize");
            Box::pin(async move { Ok(output_with_value(&self.output, config_json)) })
        }
    }

    fn default_output() -> AgentOutput {
        AgentOutput {
            value: json!("test output"),
            session_id: Some("sess-123".to_string()),
            cost_usd: Some(0.05),
            input_tokens: Some(100),
            output_tokens: Some(50),
            model: Some("sonnet".to_string()),
            duration_ms: 1500,
            debug_messages: None,
        }
    }

    #[test]
    fn model_constants_have_expected_values() {
        assert_eq!(Model::SONNET, "sonnet");
        assert_eq!(Model::OPUS, "opus");
        assert_eq!(Model::HAIKU, "haiku");
        assert_eq!(Model::HAIKU_45, "claude-haiku-4-5-20251001");
        assert_eq!(Model::SONNET_46, "claude-sonnet-4-6");
        assert_eq!(Model::OPUS_46, "claude-opus-4-6");
        assert_eq!(Model::SONNET_46_1M, "claude-sonnet-4-6[1m]");
        assert_eq!(Model::OPUS_46_1M, "claude-opus-4-6[1m]");
    }

    #[tokio::test]
    async fn agent_new_default_values() {
        let provider = ConfigCapture {
            output: default_output(),
        };
        let result = Agent::new().prompt("hi").run(&provider).await.unwrap();

        let config = result.value();
        assert_eq!(config["system_prompt"], json!(null));
        assert_eq!(config["prompt"], json!("hi"));
        assert_eq!(config["model"], json!("sonnet"));
        assert_eq!(config["allowed_tools"], json!([]));
        assert_eq!(config["max_turns"], json!(null));
        assert_eq!(config["max_budget_usd"], json!(null));
        assert_eq!(config["working_dir"], json!(null));
        assert_eq!(config["mcp_config"], json!(null));
        assert_eq!(config["permission_mode"], json!("Default"));
        assert_eq!(config["json_schema"], json!(null));
    }

    #[tokio::test]
    async fn agent_default_matches_new() {
        let provider = ConfigCapture {
            output: default_output(),
        };
        let result_new = Agent::new().prompt("x").run(&provider).await.unwrap();
        let result_default = Agent::default().prompt("x").run(&provider).await.unwrap();

        assert_eq!(result_new.value(), result_default.value());
    }

    #[tokio::test]
    async fn builder_methods_store_values_correctly() {
        let provider = ConfigCapture {
            output: default_output(),
        };
        let result = Agent::new()
            .system_prompt("you are a bot")
            .prompt("do something")
            .model(Model::OPUS)
            .allowed_tools(&["Read", "Write"])
            .max_turns(5)
            .max_budget_usd(1.5)
            .working_dir("/tmp")
            .mcp_config("{}")
            .permission_mode(PermissionMode::Auto)
            .run(&provider)
            .await
            .unwrap();

        let config = result.value();
        assert_eq!(config["system_prompt"], json!("you are a bot"));
        assert_eq!(config["prompt"], json!("do something"));
        assert_eq!(config["model"], json!("opus"));
        assert_eq!(config["allowed_tools"], json!(["Read", "Write"]));
        assert_eq!(config["max_turns"], json!(5));
        assert_eq!(config["max_budget_usd"], json!(1.5));
        assert_eq!(config["working_dir"], json!("/tmp"));
        assert_eq!(config["mcp_config"], json!("{}"));
        assert_eq!(config["permission_mode"], json!("Auto"));
    }

    #[test]
    #[should_panic(expected = "max_turns must be greater than 0")]
    fn max_turns_zero_panics() {
        let _ = Agent::new().max_turns(0);
    }

    #[test]
    #[should_panic(expected = "budget must be a positive finite number")]
    fn max_budget_negative_panics() {
        let _ = Agent::new().max_budget_usd(-1.0);
    }

    #[test]
    #[should_panic(expected = "budget must be a positive finite number")]
    fn max_budget_nan_panics() {
        let _ = Agent::new().max_budget_usd(f64::NAN);
    }

    #[test]
    #[should_panic(expected = "budget must be a positive finite number")]
    fn max_budget_infinity_panics() {
        let _ = Agent::new().max_budget_usd(f64::INFINITY);
    }

    #[test]
    #[should_panic(expected = "budget must be a positive finite number")]
    fn max_budget_zero_panics() {
        let _ = Agent::new().max_budget_usd(0.0);
    }

    #[tokio::test]
    async fn agent_result_text_with_string_value() {
        let provider = TestProvider {
            output: AgentOutput {
                value: json!("hello world"),
                ..default_output()
            },
        };
        let result = Agent::new().prompt("test").run(&provider).await.unwrap();
        assert_eq!(result.text(), "hello world");
    }

    #[tokio::test]
    async fn agent_result_text_with_non_string_value() {
        let provider = TestProvider {
            output: AgentOutput {
                value: json!(42),
                ..default_output()
            },
        };
        let result = Agent::new().prompt("test").run(&provider).await.unwrap();
        assert_eq!(result.text(), "");
    }

    #[tokio::test]
    async fn agent_result_text_with_null_value() {
        let provider = TestProvider {
            output: AgentOutput {
                value: json!(null),
                ..default_output()
            },
        };
        let result = Agent::new().prompt("test").run(&provider).await.unwrap();
        assert_eq!(result.text(), "");
    }

    #[tokio::test]
    async fn agent_result_json_successful_deserialize() {
        #[derive(Deserialize, PartialEq, Debug)]
        struct MyOutput {
            name: String,
            count: u32,
        }
        let provider = TestProvider {
            output: AgentOutput {
                value: json!({"name": "test", "count": 7}),
                ..default_output()
            },
        };
        let result = Agent::new().prompt("test").run(&provider).await.unwrap();
        let parsed: MyOutput = result.json().unwrap();
        assert_eq!(parsed.name, "test");
        assert_eq!(parsed.count, 7);
    }

    #[tokio::test]
    async fn agent_result_json_failed_deserialize() {
        #[derive(Debug, Deserialize)]
        #[allow(dead_code)] // fields exist only to shape the expected JSON
        struct MyOutput {
            name: String,
        }
        let provider = TestProvider {
            output: AgentOutput {
                value: json!(42),
                ..default_output()
            },
        };
        let result = Agent::new().prompt("test").run(&provider).await.unwrap();
        let err = result.json::<MyOutput>().unwrap_err();
        assert!(matches!(err, OperationError::Deserialize { .. }));
    }

    #[tokio::test]
    async fn agent_result_accessors() {
        let provider = TestProvider {
            output: AgentOutput {
                value: json!("v"),
                session_id: Some("s-1".to_string()),
                cost_usd: Some(0.123),
                input_tokens: Some(999),
                output_tokens: Some(456),
                model: Some("opus".to_string()),
                duration_ms: 2000,
                debug_messages: None,
            },
        };
        let result = Agent::new().prompt("test").run(&provider).await.unwrap();
        assert_eq!(result.session_id(), Some("s-1"));
        assert_eq!(result.cost_usd(), Some(0.123));
        assert_eq!(result.input_tokens(), Some(999));
        assert_eq!(result.output_tokens(), Some(456));
        assert_eq!(result.duration_ms(), 2000);
        assert_eq!(result.model(), Some("opus"));
    }

    #[tokio::test]
    async fn resume_passes_session_id_in_config() {
        let provider = ConfigCapture {
            output: default_output(),
        };
        let result = Agent::new()
            .prompt("followup")
            .resume("sess-abc")
            .run(&provider)
            .await
            .unwrap();

        let config = result.value();
        assert_eq!(config["resume_session_id"], json!("sess-abc"));
    }

    #[tokio::test]
    async fn no_resume_has_null_session_id() {
        let provider = ConfigCapture {
            output: default_output(),
        };
        let result = Agent::new()
            .prompt("first call")
            .run(&provider)
            .await
            .unwrap();

        let config = result.value();
        assert_eq!(config["resume_session_id"], json!(null));
    }

    #[test]
    #[should_panic(expected = "session_id must not be empty")]
    fn resume_empty_session_id_panics() {
        let _ = Agent::new().resume("");
    }

    #[test]
    #[should_panic(expected = "session_id must only contain")]
    fn resume_invalid_chars_panics() {
        let _ = Agent::new().resume("sess;rm -rf /");
    }

    #[test]
    fn resume_valid_formats_accepted() {
        let _ = Agent::new().resume("sess-abc123");
        let _ = Agent::new().resume("a1b2c3d4_session");
        let _ = Agent::new().resume("abc-DEF-123_456");
    }

    #[tokio::test]
    #[should_panic(expected = "prompt must not be empty")]
    async fn run_without_prompt_panics() {
        let provider = TestProvider {
            output: default_output(),
        };
        let _ = Agent::new().run(&provider).await;
    }

    #[tokio::test]
    #[should_panic(expected = "prompt must not be empty")]
    async fn run_with_whitespace_only_prompt_panics() {
        let provider = TestProvider {
            output: default_output(),
        };
        let _ = Agent::new().prompt("   ").run(&provider).await;
    }

    #[tokio::test]
    async fn model_accepts_custom_string() {
        let provider = ConfigCapture {
            output: default_output(),
        };
        let result = Agent::new()
            .prompt("hi")
            .model("mistral-large-latest")
            .run(&provider)
            .await
            .unwrap();
        assert_eq!(result.value()["model"], json!("mistral-large-latest"));
    }

    #[tokio::test]
    async fn model_accepts_owned_string() {
        let provider = ConfigCapture {
            output: default_output(),
        };
        let model_name = String::from("gpt-4o");
        let result = Agent::new()
            .prompt("hi")
            .model(model_name)
            .run(&provider)
            .await
            .unwrap();
        assert_eq!(result.value()["model"], json!("gpt-4o"));
    }

    #[tokio::test]
    async fn verbose_sets_config_flag() {
        let provider = ConfigCapture {
            output: default_output(),
        };
        let result = Agent::new()
            .prompt("hi")
            .verbose()
            .run(&provider)
            .await
            .unwrap();
        assert_eq!(result.value()["verbose"], json!(true));
    }

    #[tokio::test]
    async fn verbose_not_set_by_default() {
        let provider = ConfigCapture {
            output: default_output(),
        };
        let result = Agent::new().prompt("hi").run(&provider).await.unwrap();
        assert_eq!(result.value()["verbose"], json!(false));
    }

    #[tokio::test]
    async fn debug_messages_none_without_verbose() {
        let provider = TestProvider {
            output: default_output(),
        };
        let result = Agent::new().prompt("test").run(&provider).await.unwrap();
        assert!(result.debug_messages().is_none());
    }

    #[tokio::test]
    async fn into_json_success() {
        #[derive(Deserialize, PartialEq, Debug)]
        struct Out {
            name: String,
        }
        let provider = TestProvider {
            output: AgentOutput {
                value: json!({"name": "test"}),
                ..default_output()
            },
        };
        let result = Agent::new().prompt("test").run(&provider).await.unwrap();
        let parsed: Out = result.into_json().unwrap();
        assert_eq!(parsed.name, "test");
    }

    #[tokio::test]
    async fn into_json_failure() {
        #[derive(Debug, Deserialize)]
        #[allow(dead_code)] // fields exist only to shape the expected JSON
        struct Out {
            name: String,
        }
        let provider = TestProvider {
            output: AgentOutput {
                value: json!(42),
                ..default_output()
            },
        };
        let result = Agent::new().prompt("test").run(&provider).await.unwrap();
        let err = result.into_json::<Out>().unwrap_err();
        assert!(matches!(err, OperationError::Deserialize { .. }));
    }

    #[test]
    fn from_output_creates_result() {
        let output = AgentOutput {
            value: json!("hello"),
            ..default_output()
        };
        let result = AgentResult::from_output(output);
        assert_eq!(result.text(), "hello");
        assert_eq!(result.cost_usd(), Some(0.05));
    }

    #[test]
    fn model_constant_equality() {
        assert_eq!(Model::SONNET, "sonnet");
        assert_ne!(Model::SONNET, Model::OPUS);
    }

    #[test]
    fn permission_mode_serialize_deserialize_roundtrip() {
        for mode in [
            PermissionMode::Default,
            PermissionMode::Auto,
            PermissionMode::DontAsk,
            PermissionMode::BypassPermissions,
        ] {
            let json = to_string(&mode).unwrap();
            let back: PermissionMode = serde_json::from_str(&json).unwrap();
            assert_eq!(format!("{:?}", mode), format!("{:?}", back));
        }
    }

    #[test]
    fn permission_mode_deserialize_is_lenient() {
        let parse =
            |s: &str| -> PermissionMode { serde_json::from_value(json!(s)).unwrap() };
        assert!(matches!(parse("auto"), PermissionMode::Auto));
        assert!(matches!(parse("dont_ask"), PermissionMode::DontAsk));
        assert!(matches!(parse("BYPASS"), PermissionMode::BypassPermissions));
        assert!(matches!(
            parse("bypass_permissions"),
            PermissionMode::BypassPermissions
        ));
        assert!(matches!(parse("default"), PermissionMode::Default));
        // Unknown values fall back to Default rather than failing.
        assert!(matches!(parse("not-a-mode"), PermissionMode::Default));
    }

    #[test]
    fn retry_builder_stores_policy() {
        let agent = Agent::new().retry(3);
        assert!(agent.retry_policy.is_some());
        assert_eq!(agent.retry_policy.unwrap().max_retries(), 3);
    }

    #[test]
    fn retry_policy_builder_stores_custom_policy() {
        let policy = RetryPolicy::new(5).backoff(Duration::from_secs(1));
        let agent = Agent::new().retry_policy(policy);
        let p = agent.retry_policy.unwrap();
        assert_eq!(p.max_retries(), 5);
    }

    #[test]
    fn no_retry_by_default() {
        let agent = Agent::new();
        assert!(agent.retry_policy.is_none());
    }

    /// Provider that fails with `ProcessFailed` ("transient failure #N") for the
    /// first `failures_before_success` calls and succeeds afterwards.
    /// `fail_count` counts every call, successful or not.
    struct FailNTimesProvider {
        fail_count: AtomicU32,
        failures_before_success: u32,
        output: AgentOutput,
    }

    impl AgentProvider for FailNTimesProvider {
        fn invoke<'a>(&'a self, _config: &'a AgentConfig) -> InvokeFuture<'a> {
            Box::pin(async move {
                let current = self.fail_count.fetch_add(1, Ordering::SeqCst);
                if current < self.failures_before_success {
                    Err(AgentError::ProcessFailed {
                        exit_code: 1,
                        stderr: format!("transient failure #{}", current + 1),
                    })
                } else {
                    Ok(output_with_value(&self.output, self.output.value.clone()))
                }
            })
        }
    }

    #[tokio::test]
    async fn retry_succeeds_after_transient_failures() {
        let provider = FailNTimesProvider {
            fail_count: AtomicU32::new(0),
            failures_before_success: 2,
            output: default_output(),
        };
        let result = Agent::new()
            .prompt("test")
            .retry_policy(RetryPolicy::new(3).backoff(Duration::from_millis(1)))
            .run(&provider)
            .await;

        assert!(result.is_ok());
        // Two failures followed by the successful third call.
        assert_eq!(provider.fail_count.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn retry_exhausted_returns_last_error() {
        let provider = FailNTimesProvider {
            fail_count: AtomicU32::new(0),
            failures_before_success: 10,
            output: default_output(),
        };
        let result = Agent::new()
            .prompt("test")
            .retry_policy(RetryPolicy::new(2).backoff(Duration::from_millis(1)))
            .run(&provider)
            .await;

        // One initial attempt plus two retries.
        assert_eq!(provider.fail_count.load(Ordering::SeqCst), 3);
        // The returned error must come from the final attempt, not an earlier one.
        assert!(
            matches!(
                &result,
                Err(OperationError::Agent(AgentError::ProcessFailed { stderr, .. }))
                    if stderr == "transient failure #3"
            ),
            "expected the third attempt's error to be returned"
        );
    }

    #[tokio::test]
    async fn retry_does_not_retry_non_retryable_errors() {
        let call_count = Arc::new(AtomicU32::new(0));
        let count = call_count.clone();

        struct CountingNonRetryable {
            count: Arc<AtomicU32>,
        }
        impl AgentProvider for CountingNonRetryable {
            fn invoke<'a>(&'a self, _config: &'a AgentConfig) -> InvokeFuture<'a> {
                self.count.fetch_add(1, Ordering::SeqCst);
                Box::pin(async move {
                    Err(AgentError::SchemaValidation {
                        expected: "object".to_string(),
                        got: "string".to_string(),
                    })
                })
            }
        }

        let provider = CountingNonRetryable { count };
        let result = Agent::new()
            .prompt("test")
            .retry_policy(RetryPolicy::new(3).backoff(Duration::from_millis(1)))
            .run(&provider)
            .await;

        assert!(result.is_err());
        assert_eq!(call_count.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn no_retry_without_policy() {
        let provider = FailNTimesProvider {
            fail_count: AtomicU32::new(0),
            failures_before_success: 1,
            output: default_output(),
        };
        let result = Agent::new().prompt("test").run(&provider).await;

        assert!(result.is_err());
        assert_eq!(provider.fail_count.load(Ordering::SeqCst), 1);
    }
}
1365}