1use std::any;
29
30use schemars::{JsonSchema, schema_for};
31use serde::de::DeserializeOwned;
32use serde::{Deserialize, Serialize};
33use serde_json::{Value, from_value, to_string};
34use tokio::time;
35use tracing::{info, warn};
36
37use crate::error::OperationError;
38#[cfg(feature = "prometheus")]
39use crate::metric_names;
40use crate::provider::{AgentConfig, AgentOutput, AgentProvider, DebugMessage};
41use crate::retry::RetryPolicy;
42
/// Namespace for well-known model identifiers accepted by `Agent::model`.
///
/// `Agent::model` takes any `impl Into<String>`, so these constants are a
/// convenience, not an exhaustive list of valid models.
pub struct Model;

impl Model {
    // --- Short aliases. Presumably resolved to a concrete version by the
    // --- provider/CLI — confirm against the provider in use.

    /// Alias `"haiku"`.
    pub const HAIKU: &str = "haiku";
    /// Alias `"sonnet"`.
    pub const SONNET: &str = "sonnet";
    /// Alias `"opus"`.
    pub const OPUS: &str = "opus";

    // --- Fully qualified model identifiers.

    /// Claude Haiku 4.5 (dated identifier).
    pub const HAIKU_45: &str = "claude-haiku-4-5-20251001";
    /// Claude Sonnet 4.6.
    pub const SONNET_46: &str = "claude-sonnet-4-6";
    /// Claude Opus 4.6.
    pub const OPUS_46: &str = "claude-opus-4-6";
    /// Claude Opus 4.7.
    pub const OPUS_47: &str = "claude-opus-4-7";

    // --- `[1m]`-suffixed variants (presumably the 1M-token context window
    // --- variants — confirm with the provider).

    /// Claude Sonnet 4.6, `[1m]` variant.
    pub const SONNET_46_1M: &str = "claude-sonnet-4-6[1m]";
    /// Claude Opus 4.6, `[1m]` variant.
    pub const OPUS_46_1M: &str = "claude-opus-4-6[1m]";
    /// Claude Opus 4.7, `[1m]` variant.
    pub const OPUS_47_1M: &str = "claude-opus-4-7[1m]";
}
111
112#[derive(Debug, Default, Clone, Copy, Serialize)]
117pub enum PermissionMode {
118 #[default]
120 Default,
121 Auto,
123 DontAsk,
125 BypassPermissions,
129}
130
131impl<'de> Deserialize<'de> for PermissionMode {
132 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
133 where
134 D: serde::Deserializer<'de>,
135 {
136 let s = String::deserialize(deserializer)?;
137 Ok(match s.to_lowercase().replace('_', "").as_str() {
138 "auto" => Self::Auto,
139 "dontask" => Self::DontAsk,
140 "bypass" | "bypasspermissions" => Self::BypassPermissions,
141 _ => Self::Default,
142 })
143 }
144}
145
146#[must_use = "an Agent does nothing until .run() is awaited"]
176pub struct Agent {
177 config: AgentConfig,
178 dry_run: Option<bool>,
179 retry_policy: Option<RetryPolicy>,
180}
181
182impl Agent {
183 pub fn new() -> Self {
188 Self {
189 config: AgentConfig::new(""),
190 dry_run: None,
191 retry_policy: None,
192 }
193 }
194
195 pub fn from_config(config: impl Into<AgentConfig>) -> Self {
214 Self {
215 config: config.into(),
216 dry_run: None,
217 retry_policy: None,
218 }
219 }
220
221 pub fn system_prompt(mut self, prompt: &str) -> Self {
223 self.config.system_prompt = Some(prompt.to_string());
224 self
225 }
226
227 pub fn prompt(mut self, prompt: &str) -> Self {
229 self.config.prompt = prompt.to_string();
230 self
231 }
232
233 pub fn model(mut self, model: impl Into<String>) -> Self {
240 self.config.model = model.into();
241 self
242 }
243
244 pub fn allowed_tools(mut self, tools: &[&str]) -> Self {
249 self.config.allowed_tools = tools.iter().map(|s| s.to_string()).collect();
250 self
251 }
252
253 pub fn max_turns(mut self, turns: u32) -> Self {
259 assert!(turns > 0, "max_turns must be greater than 0");
260 self.config.max_turns = Some(turns);
261 self
262 }
263
264 pub fn max_budget_usd(mut self, budget: f64) -> Self {
270 assert!(
271 budget.is_finite() && budget > 0.0,
272 "budget must be a positive finite number, got {budget}"
273 );
274 self.config.max_budget_usd = Some(budget);
275 self
276 }
277
278 pub fn working_dir(mut self, dir: &str) -> Self {
280 self.config.working_dir = Some(dir.to_string());
281 self
282 }
283
284 pub fn mcp_config(mut self, config: &str) -> Self {
286 self.config.mcp_config = Some(config.to_string());
287 self
288 }
289
290 pub fn permission_mode(mut self, mode: PermissionMode) -> Self {
294 self.config.permission_mode = mode;
295 self
296 }
297
298 pub fn output<T: JsonSchema>(mut self) -> Self {
329 let schema = schema_for!(T);
330 self.config.json_schema = match to_string(&schema) {
331 Ok(s) => Some(s),
332 Err(e) => {
333 warn!(error = %e, type_name = any::type_name::<T>(), "failed to serialize JSON schema, structured output disabled");
334 None
335 }
336 };
337 self
338 }
339
340 pub fn output_schema_raw(mut self, schema: &str) -> Self {
364 self.config.json_schema = Some(schema.to_string());
365 self
366 }
367
368 pub fn retry(mut self, max_retries: u32) -> Self {
396 self.retry_policy = Some(RetryPolicy::new(max_retries));
397 self
398 }
399
400 pub fn retry_policy(mut self, policy: RetryPolicy) -> Self {
427 self.retry_policy = Some(policy);
428 self
429 }
430
431 pub fn dry_run(mut self, enabled: bool) -> Self {
440 self.dry_run = Some(enabled);
441 self
442 }
443
444 pub fn verbose(mut self) -> Self {
474 self.config.verbose = true;
475 self
476 }
477
478 pub fn resume(mut self, session_id: &str) -> Self {
514 assert!(!session_id.is_empty(), "session_id must not be empty");
515 assert!(
516 session_id
517 .chars()
518 .all(|c| c.is_ascii_alphanumeric() || c == '-' || c == '_'),
519 "session_id must only contain alphanumeric characters, hyphens, or underscores, got: {session_id}"
520 );
521 self.config.resume_session_id = Some(session_id.to_string());
522 self
523 }
524
525 #[tracing::instrument(name = "agent", skip_all, fields(model = %self.config.model, prompt_len = self.config.prompt.len()))]
542 pub async fn run(self, provider: &dyn AgentProvider) -> Result<AgentResult, OperationError> {
543 assert!(
544 !self.config.prompt.trim().is_empty(),
545 "prompt must not be empty - call .prompt(\"...\") before .run()"
546 );
547
548 if crate::dry_run::effective_dry_run(self.dry_run) {
549 info!(
550 prompt_len = self.config.prompt.len(),
551 "[dry-run] agent call skipped"
552 );
553 let mut output =
554 AgentOutput::new(Value::String("[dry-run] agent call skipped".to_string()));
555 output.cost_usd = Some(0.0);
556 output.input_tokens = Some(0);
557 output.output_tokens = Some(0);
558 return Ok(AgentResult { output });
559 }
560
561 let result = self.invoke_once(provider).await;
562
563 let policy = match &self.retry_policy {
564 Some(p) => p,
565 None => return result,
566 };
567
568 if let Err(ref err) = result {
570 if !crate::retry::is_retryable(err) {
571 return result;
572 }
573 } else {
574 return result;
575 }
576
577 let mut last_result = result;
578
579 for attempt in 0..policy.max_retries {
580 let delay = policy.delay_for_attempt(attempt);
581 warn!(
582 attempt = attempt + 1,
583 max_retries = policy.max_retries,
584 delay_ms = delay.as_millis() as u64,
585 "retrying agent invocation"
586 );
587 time::sleep(delay).await;
588
589 last_result = self.invoke_once(provider).await;
590
591 match &last_result {
592 Ok(_) => return last_result,
593 Err(err) if !crate::retry::is_retryable(err) => return last_result,
594 _ => {}
595 }
596 }
597
598 last_result
599 }
600
601 async fn invoke_once(
603 &self,
604 provider: &dyn AgentProvider,
605 ) -> Result<AgentResult, OperationError> {
606 #[cfg(feature = "prometheus")]
607 let model_label = self.config.model.to_string();
608
609 let output = match provider.invoke(&self.config).await {
610 Ok(output) => output,
611 Err(e) => {
612 #[cfg(feature = "prometheus")]
613 {
614 metrics::counter!(metric_names::AGENT_TOTAL, "model" => model_label.clone(), "status" => metric_names::STATUS_ERROR).increment(1);
615 }
616 return Err(OperationError::Agent(e));
617 }
618 };
619
620 info!(
621 duration_ms = output.duration_ms,
622 cost_usd = output.cost_usd,
623 input_tokens = output.input_tokens,
624 output_tokens = output.output_tokens,
625 model = output.model,
626 "agent completed"
627 );
628
629 #[cfg(feature = "prometheus")]
630 {
631 metrics::counter!(metric_names::AGENT_TOTAL, "model" => model_label.clone(), "status" => metric_names::STATUS_SUCCESS).increment(1);
632 metrics::histogram!(metric_names::AGENT_DURATION_SECONDS, "model" => model_label.clone())
633 .record(output.duration_ms as f64 / 1000.0);
634 if let Some(cost) = output.cost_usd {
635 metrics::gauge!(metric_names::AGENT_COST_USD_TOTAL, "model" => model_label.clone())
636 .increment(cost);
637 }
638 if let Some(tokens) = output.input_tokens {
639 metrics::counter!(metric_names::AGENT_TOKENS_INPUT_TOTAL, "model" => model_label.clone()).increment(tokens);
640 }
641 if let Some(tokens) = output.output_tokens {
642 metrics::counter!(metric_names::AGENT_TOKENS_OUTPUT_TOTAL, "model" => model_label)
643 .increment(tokens);
644 }
645 }
646
647 Ok(AgentResult { output })
648 }
649}
650
651impl Default for Agent {
652 fn default() -> Self {
653 Self::new()
654 }
655}
656
657#[derive(Debug)]
662pub struct AgentResult {
663 output: AgentOutput,
664}
665
666impl AgentResult {
667 pub fn text(&self) -> &str {
672 match self.output.value.as_str() {
673 Some(s) => s,
674 None => {
675 warn!(
676 value_type = self.output.value.to_string(),
677 "agent output is not a string, returning empty"
678 );
679 ""
680 }
681 }
682 }
683
684 pub fn value(&self) -> &Value {
686 &self.output.value
687 }
688
689 pub fn json<T: DeserializeOwned>(&self) -> Result<T, OperationError> {
699 from_value(self.output.value.clone()).map_err(OperationError::deserialize::<T>)
700 }
701
702 pub fn into_json<T: DeserializeOwned>(self) -> Result<T, OperationError> {
708 from_value(self.output.value).map_err(OperationError::deserialize::<T>)
709 }
710
711 #[cfg(test)]
716 pub(crate) fn from_output(output: AgentOutput) -> Self {
717 Self { output }
718 }
719
720 pub fn session_id(&self) -> Option<&str> {
722 self.output.session_id.as_deref()
723 }
724
725 pub fn cost_usd(&self) -> Option<f64> {
727 self.output.cost_usd
728 }
729
730 pub fn input_tokens(&self) -> Option<u64> {
732 self.output.input_tokens
733 }
734
735 pub fn output_tokens(&self) -> Option<u64> {
737 self.output.output_tokens
738 }
739
740 pub fn duration_ms(&self) -> u64 {
742 self.output.duration_ms
743 }
744
745 pub fn model(&self) -> Option<&str> {
747 self.output.model.as_deref()
748 }
749
750 pub fn debug_messages(&self) -> Option<&[DebugMessage]> {
756 self.output.debug_messages.as_deref()
757 }
758}
759
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::AgentError;
    use crate::provider::InvokeFuture;
    use serde_json::json;

    // Provider that always succeeds, returning a copy of `output`
    // (debug messages always `None`).
    struct TestProvider {
        output: AgentOutput,
    }

    impl AgentProvider for TestProvider {
        fn invoke<'a>(&'a self, _config: &'a AgentConfig) -> InvokeFuture<'a> {
            Box::pin(async move {
                Ok(AgentOutput {
                    value: self.output.value.clone(),
                    session_id: self.output.session_id.clone(),
                    cost_usd: self.output.cost_usd,
                    input_tokens: self.output.input_tokens,
                    output_tokens: self.output.output_tokens,
                    model: self.output.model.clone(),
                    duration_ms: self.output.duration_ms,
                    debug_messages: None,
                })
            })
        }
    }

    // Provider that echoes the received `AgentConfig` back as the output value,
    // so tests can assert what the builder actually passed to the provider.
    struct ConfigCapture {
        output: AgentOutput,
    }

    impl AgentProvider for ConfigCapture {
        fn invoke<'a>(&'a self, config: &'a AgentConfig) -> InvokeFuture<'a> {
            let config_json = serde_json::to_value(config).unwrap();
            Box::pin(async move {
                Ok(AgentOutput {
                    value: config_json,
                    session_id: self.output.session_id.clone(),
                    cost_usd: self.output.cost_usd,
                    input_tokens: self.output.input_tokens,
                    output_tokens: self.output.output_tokens,
                    model: self.output.model.clone(),
                    duration_ms: self.output.duration_ms,
                    debug_messages: None,
                })
            })
        }
    }

    // Baseline output shared by most tests; individual tests override fields
    // with struct-update syntax.
    fn default_output() -> AgentOutput {
        AgentOutput {
            value: json!("test output"),
            session_id: Some("sess-123".to_string()),
            cost_usd: Some(0.05),
            input_tokens: Some(100),
            output_tokens: Some(50),
            model: Some("sonnet".to_string()),
            duration_ms: 1500,
            debug_messages: None,
        }
    }

    #[test]
    fn model_constants_have_expected_values() {
        assert_eq!(Model::SONNET, "sonnet");
        assert_eq!(Model::OPUS, "opus");
        assert_eq!(Model::HAIKU, "haiku");
        assert_eq!(Model::HAIKU_45, "claude-haiku-4-5-20251001");
        assert_eq!(Model::SONNET_46, "claude-sonnet-4-6");
        assert_eq!(Model::OPUS_46, "claude-opus-4-6");
        assert_eq!(Model::SONNET_46_1M, "claude-sonnet-4-6[1m]");
        assert_eq!(Model::OPUS_46_1M, "claude-opus-4-6[1m]");
        assert_eq!(Model::OPUS_47, "claude-opus-4-7");
        assert_eq!(Model::OPUS_47_1M, "claude-opus-4-7[1m]");
    }

    // Pins the configuration produced by `Agent::new()` plus a prompt.
    #[tokio::test]
    async fn agent_new_default_values() {
        let provider = ConfigCapture {
            output: default_output(),
        };
        let result = Agent::new().prompt("hi").run(&provider).await.unwrap();

        let config = result.value();
        assert_eq!(config["system_prompt"], json!(null));
        assert_eq!(config["prompt"], json!("hi"));
        assert_eq!(config["model"], json!("sonnet"));
        assert_eq!(config["allowed_tools"], json!([]));
        assert_eq!(config["max_turns"], json!(null));
        assert_eq!(config["max_budget_usd"], json!(null));
        assert_eq!(config["working_dir"], json!(null));
        assert_eq!(config["mcp_config"], json!(null));
        assert_eq!(config["permission_mode"], json!("Default"));
        assert_eq!(config["json_schema"], json!(null));
    }

    #[tokio::test]
    async fn agent_default_matches_new() {
        let provider = ConfigCapture {
            output: default_output(),
        };
        let result_new = Agent::new().prompt("x").run(&provider).await.unwrap();
        let result_default = Agent::default().prompt("x").run(&provider).await.unwrap();

        assert_eq!(result_new.value(), result_default.value());
    }

    // Every builder setter must be reflected in the config seen by the provider.
    #[tokio::test]
    async fn builder_methods_store_values_correctly() {
        let provider = ConfigCapture {
            output: default_output(),
        };
        let result = Agent::new()
            .system_prompt("you are a bot")
            .prompt("do something")
            .model(Model::OPUS)
            .allowed_tools(&["Read", "Write"])
            .max_turns(5)
            .max_budget_usd(1.5)
            .working_dir("/tmp")
            .mcp_config("{}")
            .permission_mode(PermissionMode::Auto)
            .run(&provider)
            .await
            .unwrap();

        let config = result.value();
        assert_eq!(config["system_prompt"], json!("you are a bot"));
        assert_eq!(config["prompt"], json!("do something"));
        assert_eq!(config["model"], json!("opus"));
        assert_eq!(config["allowed_tools"], json!(["Read", "Write"]));
        assert_eq!(config["max_turns"], json!(5));
        assert_eq!(config["max_budget_usd"], json!(1.5));
        assert_eq!(config["working_dir"], json!("/tmp"));
        assert_eq!(config["mcp_config"], json!("{}"));
        assert_eq!(config["permission_mode"], json!("Auto"));
    }

    // --- Builder precondition panics ---

    #[test]
    #[should_panic(expected = "max_turns must be greater than 0")]
    fn max_turns_zero_panics() {
        let _ = Agent::new().max_turns(0);
    }

    #[test]
    #[should_panic(expected = "budget must be a positive finite number")]
    fn max_budget_negative_panics() {
        let _ = Agent::new().max_budget_usd(-1.0);
    }

    #[test]
    #[should_panic(expected = "budget must be a positive finite number")]
    fn max_budget_nan_panics() {
        let _ = Agent::new().max_budget_usd(f64::NAN);
    }

    #[test]
    #[should_panic(expected = "budget must be a positive finite number")]
    fn max_budget_infinity_panics() {
        let _ = Agent::new().max_budget_usd(f64::INFINITY);
    }

    // --- AgentResult::text ---

    #[tokio::test]
    async fn agent_result_text_with_string_value() {
        let provider = TestProvider {
            output: AgentOutput {
                value: json!("hello world"),
                ..default_output()
            },
        };
        let result = Agent::new().prompt("test").run(&provider).await.unwrap();
        assert_eq!(result.text(), "hello world");
    }

    // Non-string values yield "" rather than an error.
    #[tokio::test]
    async fn agent_result_text_with_non_string_value() {
        let provider = TestProvider {
            output: AgentOutput {
                value: json!(42),
                ..default_output()
            },
        };
        let result = Agent::new().prompt("test").run(&provider).await.unwrap();
        assert_eq!(result.text(), "");
    }

    #[tokio::test]
    async fn agent_result_text_with_null_value() {
        let provider = TestProvider {
            output: AgentOutput {
                value: json!(null),
                ..default_output()
            },
        };
        let result = Agent::new().prompt("test").run(&provider).await.unwrap();
        assert_eq!(result.text(), "");
    }

    // --- AgentResult::json ---

    #[tokio::test]
    async fn agent_result_json_successful_deserialize() {
        #[derive(Deserialize, PartialEq, Debug)]
        struct MyOutput {
            name: String,
            count: u32,
        }
        let provider = TestProvider {
            output: AgentOutput {
                value: json!({"name": "test", "count": 7}),
                ..default_output()
            },
        };
        let result = Agent::new().prompt("test").run(&provider).await.unwrap();
        let parsed: MyOutput = result.json().unwrap();
        assert_eq!(parsed.name, "test");
        assert_eq!(parsed.count, 7);
    }

    // A type mismatch must surface as OperationError::Deserialize.
    #[tokio::test]
    async fn agent_result_json_failed_deserialize() {
        #[derive(Debug, Deserialize)]
        #[allow(dead_code)]
        struct MyOutput {
            name: String,
        }
        let provider = TestProvider {
            output: AgentOutput {
                value: json!(42),
                ..default_output()
            },
        };
        let result = Agent::new().prompt("test").run(&provider).await.unwrap();
        let err = result.json::<MyOutput>().unwrap_err();
        assert!(matches!(err, OperationError::Deserialize { .. }));
    }

    #[tokio::test]
    async fn agent_result_accessors() {
        let provider = TestProvider {
            output: AgentOutput {
                value: json!("v"),
                session_id: Some("s-1".to_string()),
                cost_usd: Some(0.123),
                input_tokens: Some(999),
                output_tokens: Some(456),
                model: Some("opus".to_string()),
                duration_ms: 2000,
                debug_messages: None,
            },
        };
        let result = Agent::new().prompt("test").run(&provider).await.unwrap();
        assert_eq!(result.session_id(), Some("s-1"));
        assert_eq!(result.cost_usd(), Some(0.123));
        assert_eq!(result.input_tokens(), Some(999));
        assert_eq!(result.output_tokens(), Some(456));
        assert_eq!(result.duration_ms(), 2000);
        assert_eq!(result.model(), Some("opus"));
    }

    // --- Session resume ---

    #[tokio::test]
    async fn resume_passes_session_id_in_config() {
        let provider = ConfigCapture {
            output: default_output(),
        };
        let result = Agent::new()
            .prompt("followup")
            .resume("sess-abc")
            .run(&provider)
            .await
            .unwrap();

        let config = result.value();
        assert_eq!(config["resume_session_id"], json!("sess-abc"));
    }

    #[tokio::test]
    async fn no_resume_has_null_session_id() {
        let provider = ConfigCapture {
            output: default_output(),
        };
        let result = Agent::new()
            .prompt("first call")
            .run(&provider)
            .await
            .unwrap();

        let config = result.value();
        assert_eq!(config["resume_session_id"], json!(null));
    }

    #[test]
    #[should_panic(expected = "session_id must not be empty")]
    fn resume_empty_session_id_panics() {
        let _ = Agent::new().resume("");
    }

    // Shell metacharacters and spaces must be rejected by the whitelist.
    #[test]
    #[should_panic(expected = "session_id must only contain")]
    fn resume_invalid_chars_panics() {
        let _ = Agent::new().resume("sess;rm -rf /");
    }

    #[test]
    fn resume_valid_formats_accepted() {
        let _ = Agent::new().resume("sess-abc123");
        let _ = Agent::new().resume("a1b2c3d4_session");
        let _ = Agent::new().resume("abc-DEF-123_456");
    }

    // --- Prompt precondition in run() ---

    #[tokio::test]
    #[should_panic(expected = "prompt must not be empty")]
    async fn run_without_prompt_panics() {
        let provider = TestProvider {
            output: default_output(),
        };
        let _ = Agent::new().run(&provider).await;
    }

    #[tokio::test]
    #[should_panic(expected = "prompt must not be empty")]
    async fn run_with_whitespace_only_prompt_panics() {
        let provider = TestProvider {
            output: default_output(),
        };
        let _ = Agent::new().prompt(" ").run(&provider).await;
    }

    // Model strings outside the `Model` constants are passed through unchanged.
    #[tokio::test]
    async fn model_accepts_custom_string() {
        let provider = ConfigCapture {
            output: default_output(),
        };
        let result = Agent::new()
            .prompt("hi")
            .model("mistral-large-latest")
            .run(&provider)
            .await
            .unwrap();
        assert_eq!(result.value()["model"], json!("mistral-large-latest"));
    }

    // --- Verbose / debug messages ---

    #[tokio::test]
    async fn verbose_sets_config_flag() {
        let provider = ConfigCapture {
            output: default_output(),
        };
        let result = Agent::new()
            .prompt("hi")
            .verbose()
            .run(&provider)
            .await
            .unwrap();
        assert_eq!(result.value()["verbose"], json!(true));
    }

    #[tokio::test]
    async fn verbose_not_set_by_default() {
        let provider = ConfigCapture {
            output: default_output(),
        };
        let result = Agent::new().prompt("hi").run(&provider).await.unwrap();
        assert_eq!(result.value()["verbose"], json!(false));
    }

    #[tokio::test]
    async fn debug_messages_none_without_verbose() {
        let provider = TestProvider {
            output: default_output(),
        };
        let result = Agent::new().prompt("test").run(&provider).await.unwrap();
        assert!(result.debug_messages().is_none());
    }

    #[tokio::test]
    async fn model_accepts_owned_string() {
        let provider = ConfigCapture {
            output: default_output(),
        };
        let model_name = String::from("gpt-4o");
        let result = Agent::new()
            .prompt("hi")
            .model(model_name)
            .run(&provider)
            .await
            .unwrap();
        assert_eq!(result.value()["model"], json!("gpt-4o"));
    }

    // --- AgentResult::into_json ---

    #[tokio::test]
    async fn into_json_success() {
        #[derive(Deserialize, PartialEq, Debug)]
        struct Out {
            name: String,
        }
        let provider = TestProvider {
            output: AgentOutput {
                value: json!({"name": "test"}),
                ..default_output()
            },
        };
        let result = Agent::new().prompt("test").run(&provider).await.unwrap();
        let parsed: Out = result.into_json().unwrap();
        assert_eq!(parsed.name, "test");
    }

    #[tokio::test]
    async fn into_json_failure() {
        #[derive(Debug, Deserialize)]
        #[allow(dead_code)]
        struct Out {
            name: String,
        }
        let provider = TestProvider {
            output: AgentOutput {
                value: json!(42),
                ..default_output()
            },
        };
        let result = Agent::new().prompt("test").run(&provider).await.unwrap();
        let err = result.into_json::<Out>().unwrap_err();
        assert!(matches!(err, OperationError::Deserialize { .. }));
    }

    #[test]
    fn from_output_creates_result() {
        let output = AgentOutput {
            value: json!("hello"),
            ..default_output()
        };
        let result = AgentResult::from_output(output);
        assert_eq!(result.text(), "hello");
        assert_eq!(result.cost_usd(), Some(0.05));
    }

    // Zero is rejected: the budget must be strictly positive.
    #[test]
    #[should_panic(expected = "budget must be a positive finite number")]
    fn max_budget_zero_panics() {
        let _ = Agent::new().max_budget_usd(0.0);
    }

    #[test]
    fn model_constant_equality() {
        assert_eq!(Model::SONNET, "sonnet");
        assert_ne!(Model::SONNET, Model::OPUS);
    }

    // Serialized variant names must be accepted back by the lenient
    // Deserialize impl. Compared through Debug output so the test does not
    // rely on a PartialEq impl.
    #[test]
    fn permission_mode_serialize_deserialize_roundtrip() {
        for mode in [
            PermissionMode::Default,
            PermissionMode::Auto,
            PermissionMode::DontAsk,
            PermissionMode::BypassPermissions,
        ] {
            let json = to_string(&mode).unwrap();
            let back: PermissionMode = serde_json::from_str(&json).unwrap();
            assert_eq!(format!("{:?}", mode), format!("{:?}", back));
        }
    }

    // --- Retry configuration ---

    #[test]
    fn retry_builder_stores_policy() {
        let agent = Agent::new().retry(3);
        assert!(agent.retry_policy.is_some());
        assert_eq!(agent.retry_policy.unwrap().max_retries(), 3);
    }

    #[test]
    fn retry_policy_builder_stores_custom_policy() {
        use crate::retry::RetryPolicy;
        let policy = RetryPolicy::new(5).backoff(Duration::from_secs(1));
        let agent = Agent::new().retry_policy(policy);
        let p = agent.retry_policy.unwrap();
        assert_eq!(p.max_retries(), 5);
    }

    #[test]
    fn no_retry_by_default() {
        let agent = Agent::new();
        assert!(agent.retry_policy.is_none());
    }

    use std::sync::Arc;
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::time::Duration;

    // Provider that fails with a ProcessFailed error for the first
    // `failures_before_success` calls, then succeeds. `fail_count` counts
    // every call (failed or not), so tests can assert the exact number of
    // provider invocations.
    struct FailNTimesProvider {
        fail_count: AtomicU32,
        failures_before_success: u32,
        output: AgentOutput,
    }

    impl AgentProvider for FailNTimesProvider {
        fn invoke<'a>(&'a self, _config: &'a AgentConfig) -> InvokeFuture<'a> {
            Box::pin(async move {
                let current = self.fail_count.fetch_add(1, Ordering::SeqCst);
                if current < self.failures_before_success {
                    Err(AgentError::ProcessFailed {
                        exit_code: 1,
                        stderr: format!("transient failure #{}", current + 1),
                    })
                } else {
                    Ok(AgentOutput {
                        value: self.output.value.clone(),
                        session_id: self.output.session_id.clone(),
                        cost_usd: self.output.cost_usd,
                        input_tokens: self.output.input_tokens,
                        output_tokens: self.output.output_tokens,
                        model: self.output.model.clone(),
                        duration_ms: self.output.duration_ms,
                        debug_messages: None,
                    })
                }
            })
        }
    }

    // --- Retry execution ---

    #[tokio::test]
    async fn retry_succeeds_after_transient_failures() {
        let provider = FailNTimesProvider {
            fail_count: AtomicU32::new(0),
            failures_before_success: 2,
            output: default_output(),
        };
        let result = Agent::new()
            .prompt("test")
            .retry_policy(crate::retry::RetryPolicy::new(3).backoff(Duration::from_millis(1)))
            .run(&provider)
            .await;

        assert!(result.is_ok());
        // Two failures followed by one success.
        assert_eq!(provider.fail_count.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn retry_exhausted_returns_last_error() {
        let provider = FailNTimesProvider {
            fail_count: AtomicU32::new(0),
            // More failures than the policy allows retries.
            failures_before_success: 10,
            output: default_output(),
        };
        let result = Agent::new()
            .prompt("test")
            .retry_policy(crate::retry::RetryPolicy::new(2).backoff(Duration::from_millis(1)))
            .run(&provider)
            .await;

        assert!(result.is_err());
        // Initial attempt plus two retries.
        assert_eq!(provider.fail_count.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn retry_does_not_retry_non_retryable_errors() {
        let call_count = Arc::new(AtomicU32::new(0));
        let count = call_count.clone();

        // Always fails with SchemaValidation, which these tests expect to be
        // classified as non-retryable.
        struct CountingNonRetryable {
            count: Arc<AtomicU32>,
        }
        impl AgentProvider for CountingNonRetryable {
            fn invoke<'a>(&'a self, _config: &'a AgentConfig) -> InvokeFuture<'a> {
                self.count.fetch_add(1, Ordering::SeqCst);
                Box::pin(async move {
                    Err(AgentError::SchemaValidation {
                        expected: "object".to_string(),
                        got: "string".to_string(),
                        debug_messages: Vec::new(),
                        partial_usage: Box::default(),
                    })
                })
            }
        }

        let provider = CountingNonRetryable { count };
        let result = Agent::new()
            .prompt("test")
            .retry_policy(crate::retry::RetryPolicy::new(3).backoff(Duration::from_millis(1)))
            .run(&provider)
            .await;

        assert!(result.is_err());
        // A non-retryable error stops after the first call despite the policy.
        assert_eq!(call_count.load(Ordering::SeqCst), 1);
    }

    // Without a retry policy, even a retryable failure is returned immediately.
    #[tokio::test]
    async fn no_retry_without_policy() {
        let provider = FailNTimesProvider {
            fail_count: AtomicU32::new(0),
            failures_before_success: 1,
            output: default_output(),
        };
        let result = Agent::new().prompt("test").run(&provider).await;

        assert!(result.is_err());
        assert_eq!(provider.fail_count.load(Ordering::SeqCst), 1);
    }
}