1use std::any;
29use std::sync::Arc;
30
31use schemars::{JsonSchema, schema_for};
32use serde::de::DeserializeOwned;
33use serde::{Deserialize, Serialize};
34use serde_json::{Value, from_value, to_string};
35use tokio::time;
36use tracing::{info, warn};
37
38use crate::error::OperationError;
39#[cfg(feature = "prometheus")]
40use crate::metric_names;
41use crate::provider::{AgentConfig, AgentOutput, AgentProvider, DebugMessage, LogSink};
42use crate::retry::RetryPolicy;
43
/// Namespace for well-known model identifier strings.
///
/// `Model` carries no data; it only groups the associated constants below so
/// they can be passed to [`Agent::model`], which also accepts any other
/// model name string.
pub struct Model;

impl Model {
    // --- Short names: mapping to a concrete snapshot is left to the provider.

    /// Short name `"haiku"`.
    pub const HAIKU: &str = "haiku";
    /// Short name `"sonnet"`.
    pub const SONNET: &str = "sonnet";
    /// Short name `"opus"`.
    pub const OPUS: &str = "opus";

    // --- Haiku family.

    /// Dated Haiku 4.5 identifier.
    pub const HAIKU_45: &str = "claude-haiku-4-5-20251001";

    // --- Sonnet family. The `[1m]` suffix presumably selects the 1M-token
    // context variant — NOTE(review): confirm against the provider.

    /// Sonnet 4.6.
    pub const SONNET_46: &str = "claude-sonnet-4-6";
    /// Sonnet 4.6 with the `[1m]` suffix.
    pub const SONNET_46_1M: &str = "claude-sonnet-4-6[1m]";

    // --- Opus family (same `[1m]` suffix convention).

    /// Opus 4.6.
    pub const OPUS_46: &str = "claude-opus-4-6";
    /// Opus 4.6 with the `[1m]` suffix.
    pub const OPUS_46_1M: &str = "claude-opus-4-6[1m]";
    /// Opus 4.7.
    pub const OPUS_47: &str = "claude-opus-4-7";
    /// Opus 4.7 with the `[1m]` suffix.
    pub const OPUS_47_1M: &str = "claude-opus-4-7[1m]";
}
112
113#[derive(Debug, Default, Clone, Copy, Serialize)]
118pub enum PermissionMode {
119 #[default]
121 Default,
122 Auto,
124 DontAsk,
126 BypassPermissions,
130}
131
132impl<'de> Deserialize<'de> for PermissionMode {
133 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
134 where
135 D: serde::Deserializer<'de>,
136 {
137 let s = String::deserialize(deserializer)?;
138 Ok(match s.to_lowercase().replace('_', "").as_str() {
139 "auto" => Self::Auto,
140 "dontask" => Self::DontAsk,
141 "bypass" | "bypasspermissions" => Self::BypassPermissions,
142 _ => Self::Default,
143 })
144 }
145}
146
147#[must_use = "an Agent does nothing until .run() is awaited"]
177pub struct Agent {
178 config: AgentConfig,
179 dry_run: Option<bool>,
180 retry_policy: Option<RetryPolicy>,
181 log_sink: Option<Arc<dyn LogSink>>,
182}
183
184impl Agent {
185 pub fn new() -> Self {
190 Self {
191 config: AgentConfig::new(""),
192 dry_run: None,
193 retry_policy: None,
194 log_sink: None,
195 }
196 }
197
198 pub fn from_config(config: impl Into<AgentConfig>) -> Self {
217 Self {
218 config: config.into(),
219 dry_run: None,
220 retry_policy: None,
221 log_sink: None,
222 }
223 }
224
225 pub fn system_prompt(mut self, prompt: &str) -> Self {
227 self.config.system_prompt = Some(prompt.to_string());
228 self
229 }
230
231 pub fn prompt(mut self, prompt: &str) -> Self {
233 self.config.prompt = prompt.to_string();
234 self
235 }
236
237 pub fn model(mut self, model: impl Into<String>) -> Self {
244 self.config.model = model.into();
245 self
246 }
247
248 pub fn allowed_tools(mut self, tools: &[&str]) -> Self {
253 self.config.allowed_tools = tools.iter().map(|s| s.to_string()).collect();
254 self
255 }
256
257 pub fn max_turns(mut self, turns: u32) -> Self {
263 assert!(turns > 0, "max_turns must be greater than 0");
264 self.config.max_turns = Some(turns);
265 self
266 }
267
268 pub fn max_budget_usd(mut self, budget: f64) -> Self {
274 assert!(
275 budget.is_finite() && budget > 0.0,
276 "budget must be a positive finite number, got {budget}"
277 );
278 self.config.max_budget_usd = Some(budget);
279 self
280 }
281
282 pub fn working_dir(mut self, dir: &str) -> Self {
284 self.config.working_dir = Some(dir.to_string());
285 self
286 }
287
288 pub fn mcp_config(mut self, config: &str) -> Self {
290 self.config.mcp_config = Some(config.to_string());
291 self
292 }
293
294 pub fn permission_mode(mut self, mode: PermissionMode) -> Self {
298 self.config.permission_mode = mode;
299 self
300 }
301
302 pub fn output<T: JsonSchema>(mut self) -> Self {
333 let schema = schema_for!(T);
334 self.config.json_schema = match to_string(&schema) {
335 Ok(s) => Some(s),
336 Err(e) => {
337 warn!(error = %e, type_name = any::type_name::<T>(), "failed to serialize JSON schema, structured output disabled");
338 None
339 }
340 };
341 self
342 }
343
344 pub fn output_schema_raw(mut self, schema: &str) -> Self {
368 self.config.json_schema = Some(schema.to_string());
369 self
370 }
371
372 pub fn retry(mut self, max_retries: u32) -> Self {
400 self.retry_policy = Some(RetryPolicy::new(max_retries));
401 self
402 }
403
404 pub fn retry_policy(mut self, policy: RetryPolicy) -> Self {
431 self.retry_policy = Some(policy);
432 self
433 }
434
435 pub fn dry_run(mut self, enabled: bool) -> Self {
444 self.dry_run = Some(enabled);
445 self
446 }
447
448 pub fn log_sink(mut self, sink: Arc<dyn LogSink>) -> Self {
475 self.log_sink = Some(sink);
476 self
477 }
478
479 pub fn verbose(mut self) -> Self {
509 self.config.verbose = true;
510 self
511 }
512
513 pub fn resume(mut self, session_id: &str) -> Self {
549 assert!(!session_id.is_empty(), "session_id must not be empty");
550 assert!(
551 session_id
552 .chars()
553 .all(|c| c.is_ascii_alphanumeric() || c == '-' || c == '_'),
554 "session_id must only contain alphanumeric characters, hyphens, or underscores, got: {session_id}"
555 );
556 self.config.resume_session_id = Some(session_id.to_string());
557 self
558 }
559
560 #[tracing::instrument(name = "agent", skip_all, fields(model = %self.config.model, prompt_len = self.config.prompt.len()))]
577 pub async fn run(self, provider: &dyn AgentProvider) -> Result<AgentResult, OperationError> {
578 assert!(
579 !self.config.prompt.trim().is_empty(),
580 "prompt must not be empty - call .prompt(\"...\") before .run()"
581 );
582
583 if crate::dry_run::effective_dry_run(self.dry_run) {
584 info!(
585 prompt_len = self.config.prompt.len(),
586 "[dry-run] agent call skipped"
587 );
588 let mut output =
589 AgentOutput::new(Value::String("[dry-run] agent call skipped".to_string()));
590 output.cost_usd = Some(0.0);
591 output.input_tokens = Some(0);
592 output.output_tokens = Some(0);
593 return Ok(AgentResult { output });
594 }
595
596 let result = self.invoke_once(provider).await;
597
598 let policy = match &self.retry_policy {
599 Some(p) => p,
600 None => return result,
601 };
602
603 if let Err(ref err) = result {
605 if !crate::retry::is_retryable(err) {
606 return result;
607 }
608 } else {
609 return result;
610 }
611
612 let mut last_result = result;
613
614 for attempt in 0..policy.max_retries {
615 let delay = policy.delay_for_attempt(attempt);
616 warn!(
617 attempt = attempt + 1,
618 max_retries = policy.max_retries,
619 delay_ms = delay.as_millis() as u64,
620 "retrying agent invocation"
621 );
622 time::sleep(delay).await;
623
624 last_result = self.invoke_once(provider).await;
625
626 match &last_result {
627 Ok(_) => return last_result,
628 Err(err) if !crate::retry::is_retryable(err) => return last_result,
629 _ => {}
630 }
631 }
632
633 last_result
634 }
635
636 async fn invoke_once(
638 &self,
639 provider: &dyn AgentProvider,
640 ) -> Result<AgentResult, OperationError> {
641 #[cfg(feature = "prometheus")]
642 let model_label = self.config.model.to_string();
643
644 let invoke_result = match self.log_sink {
645 Some(ref sink) => provider.invoke_with_logs(&self.config, sink.clone()).await,
646 None => provider.invoke(&self.config).await,
647 };
648 let output = match invoke_result {
649 Ok(output) => output,
650 Err(e) => {
651 #[cfg(feature = "prometheus")]
652 {
653 metrics::counter!(metric_names::AGENT_TOTAL, "model" => model_label.clone(), "status" => metric_names::STATUS_ERROR).increment(1);
654 }
655 return Err(OperationError::Agent(e));
656 }
657 };
658
659 info!(
660 duration_ms = output.duration_ms,
661 cost_usd = output.cost_usd,
662 input_tokens = output.input_tokens,
663 output_tokens = output.output_tokens,
664 model = output.model,
665 "agent completed"
666 );
667
668 #[cfg(feature = "prometheus")]
669 {
670 metrics::counter!(metric_names::AGENT_TOTAL, "model" => model_label.clone(), "status" => metric_names::STATUS_SUCCESS).increment(1);
671 metrics::histogram!(metric_names::AGENT_DURATION_SECONDS, "model" => model_label.clone())
672 .record(output.duration_ms as f64 / 1000.0);
673 if let Some(cost) = output.cost_usd {
674 metrics::gauge!(metric_names::AGENT_COST_USD_TOTAL, "model" => model_label.clone())
675 .increment(cost);
676 }
677 if let Some(tokens) = output.input_tokens {
678 metrics::counter!(metric_names::AGENT_TOKENS_INPUT_TOTAL, "model" => model_label.clone()).increment(tokens);
679 }
680 if let Some(tokens) = output.output_tokens {
681 metrics::counter!(metric_names::AGENT_TOKENS_OUTPUT_TOTAL, "model" => model_label)
682 .increment(tokens);
683 }
684 }
685
686 Ok(AgentResult { output })
687 }
688}
689
690impl Default for Agent {
691 fn default() -> Self {
692 Self::new()
693 }
694}
695
696#[derive(Debug)]
701pub struct AgentResult {
702 output: AgentOutput,
703}
704
705impl AgentResult {
706 pub fn text(&self) -> &str {
711 match self.output.value.as_str() {
712 Some(s) => s,
713 None => {
714 warn!(
715 value_type = self.output.value.to_string(),
716 "agent output is not a string, returning empty"
717 );
718 ""
719 }
720 }
721 }
722
723 pub fn value(&self) -> &Value {
725 &self.output.value
726 }
727
728 pub fn json<T: DeserializeOwned>(&self) -> Result<T, OperationError> {
738 from_value(self.output.value.clone()).map_err(OperationError::deserialize::<T>)
739 }
740
741 pub fn into_json<T: DeserializeOwned>(self) -> Result<T, OperationError> {
747 from_value(self.output.value).map_err(OperationError::deserialize::<T>)
748 }
749
750 #[cfg(test)]
755 pub(crate) fn from_output(output: AgentOutput) -> Self {
756 Self { output }
757 }
758
759 pub fn session_id(&self) -> Option<&str> {
761 self.output.session_id.as_deref()
762 }
763
764 pub fn cost_usd(&self) -> Option<f64> {
766 self.output.cost_usd
767 }
768
769 pub fn input_tokens(&self) -> Option<u64> {
771 self.output.input_tokens
772 }
773
774 pub fn output_tokens(&self) -> Option<u64> {
776 self.output.output_tokens
777 }
778
779 pub fn duration_ms(&self) -> u64 {
781 self.output.duration_ms
782 }
783
784 pub fn model(&self) -> Option<&str> {
786 self.output.model.as_deref()
787 }
788
789 pub fn debug_messages(&self) -> Option<&[DebugMessage]> {
795 self.output.debug_messages.as_deref()
796 }
797}
798
#[cfg(test)]
mod tests {
    // Unit tests for the `Agent` builder, `AgentResult` accessors, retry
    // behaviour and log-sink routing. Everything runs against in-memory fake
    // providers; no real agent is invoked.
    use super::*;
    use crate::error::AgentError;
    use crate::provider::InvokeFuture;
    use serde_json::json;

    /// Fake provider whose `invoke` returns a copy of `output`.
    struct TestProvider {
        output: AgentOutput,
    }

    impl AgentProvider for TestProvider {
        fn invoke<'a>(&'a self, _config: &'a AgentConfig) -> InvokeFuture<'a> {
            Box::pin(async move {
                Ok(AgentOutput {
                    value: self.output.value.clone(),
                    session_id: self.output.session_id.clone(),
                    cost_usd: self.output.cost_usd,
                    input_tokens: self.output.input_tokens,
                    output_tokens: self.output.output_tokens,
                    model: self.output.model.clone(),
                    duration_ms: self.output.duration_ms,
                    debug_messages: None,
                })
            })
        }
    }

    /// Fake provider that returns the serialized `AgentConfig` it received as
    /// the output value, so tests can inspect what the builder produced.
    struct ConfigCapture {
        output: AgentOutput,
    }

    impl AgentProvider for ConfigCapture {
        fn invoke<'a>(&'a self, config: &'a AgentConfig) -> InvokeFuture<'a> {
            // Serialize eagerly, outside the future, so the echoed value is
            // exactly the config passed to this call.
            let config_json = serde_json::to_value(config).unwrap();
            Box::pin(async move {
                Ok(AgentOutput {
                    value: config_json,
                    session_id: self.output.session_id.clone(),
                    cost_usd: self.output.cost_usd,
                    input_tokens: self.output.input_tokens,
                    output_tokens: self.output.output_tokens,
                    model: self.output.model.clone(),
                    duration_ms: self.output.duration_ms,
                    debug_messages: None,
                })
            })
        }
    }

    /// Baseline provider output used by most tests.
    fn default_output() -> AgentOutput {
        AgentOutput {
            value: json!("test output"),
            session_id: Some("sess-123".to_string()),
            cost_usd: Some(0.05),
            input_tokens: Some(100),
            output_tokens: Some(50),
            model: Some("sonnet".to_string()),
            duration_ms: 1500,
            debug_messages: None,
        }
    }

    /// Pins the exact string value of every `Model` constant.
    #[test]
    fn model_constants_have_expected_values() {
        assert_eq!(Model::SONNET, "sonnet");
        assert_eq!(Model::OPUS, "opus");
        assert_eq!(Model::HAIKU, "haiku");
        assert_eq!(Model::HAIKU_45, "claude-haiku-4-5-20251001");
        assert_eq!(Model::SONNET_46, "claude-sonnet-4-6");
        assert_eq!(Model::OPUS_46, "claude-opus-4-6");
        assert_eq!(Model::SONNET_46_1M, "claude-sonnet-4-6[1m]");
        assert_eq!(Model::OPUS_46_1M, "claude-opus-4-6[1m]");
        assert_eq!(Model::OPUS_47, "claude-opus-4-7");
        assert_eq!(Model::OPUS_47_1M, "claude-opus-4-7[1m]");
    }

    /// `Agent::new()` produces the documented default config (model "sonnet",
    /// everything optional unset, permission mode "Default").
    #[tokio::test]
    async fn agent_new_default_values() {
        let provider = ConfigCapture {
            output: default_output(),
        };
        let result = Agent::new().prompt("hi").run(&provider).await.unwrap();

        let config = result.value();
        assert_eq!(config["system_prompt"], json!(null));
        assert_eq!(config["prompt"], json!("hi"));
        assert_eq!(config["model"], json!("sonnet"));
        assert_eq!(config["allowed_tools"], json!([]));
        assert_eq!(config["max_turns"], json!(null));
        assert_eq!(config["max_budget_usd"], json!(null));
        assert_eq!(config["working_dir"], json!(null));
        assert_eq!(config["mcp_config"], json!(null));
        assert_eq!(config["permission_mode"], json!("Default"));
        assert_eq!(config["json_schema"], json!(null));
    }

    /// `Agent::default()` and `Agent::new()` yield identical configs.
    #[tokio::test]
    async fn agent_default_matches_new() {
        let provider = ConfigCapture {
            output: default_output(),
        };
        let result_new = Agent::new().prompt("x").run(&provider).await.unwrap();
        let result_default = Agent::default().prompt("x").run(&provider).await.unwrap();

        assert_eq!(result_new.value(), result_default.value());
    }

    /// Every builder setter lands in the serialized config.
    #[tokio::test]
    async fn builder_methods_store_values_correctly() {
        let provider = ConfigCapture {
            output: default_output(),
        };
        let result = Agent::new()
            .system_prompt("you are a bot")
            .prompt("do something")
            .model(Model::OPUS)
            .allowed_tools(&["Read", "Write"])
            .max_turns(5)
            .max_budget_usd(1.5)
            .working_dir("/tmp")
            .mcp_config("{}")
            .permission_mode(PermissionMode::Auto)
            .run(&provider)
            .await
            .unwrap();

        let config = result.value();
        assert_eq!(config["system_prompt"], json!("you are a bot"));
        assert_eq!(config["prompt"], json!("do something"));
        assert_eq!(config["model"], json!("opus"));
        assert_eq!(config["allowed_tools"], json!(["Read", "Write"]));
        assert_eq!(config["max_turns"], json!(5));
        assert_eq!(config["max_budget_usd"], json!(1.5));
        assert_eq!(config["working_dir"], json!("/tmp"));
        assert_eq!(config["mcp_config"], json!("{}"));
        assert_eq!(config["permission_mode"], json!("Auto"));
    }

    // --- Builder precondition panics ---

    #[test]
    #[should_panic(expected = "max_turns must be greater than 0")]
    fn max_turns_zero_panics() {
        let _ = Agent::new().max_turns(0);
    }

    #[test]
    #[should_panic(expected = "budget must be a positive finite number")]
    fn max_budget_negative_panics() {
        let _ = Agent::new().max_budget_usd(-1.0);
    }

    #[test]
    #[should_panic(expected = "budget must be a positive finite number")]
    fn max_budget_nan_panics() {
        let _ = Agent::new().max_budget_usd(f64::NAN);
    }

    #[test]
    #[should_panic(expected = "budget must be a positive finite number")]
    fn max_budget_infinity_panics() {
        let _ = Agent::new().max_budget_usd(f64::INFINITY);
    }

    // --- AgentResult::text ---

    /// A string payload is returned verbatim by `text()`.
    #[tokio::test]
    async fn agent_result_text_with_string_value() {
        let provider = TestProvider {
            output: AgentOutput {
                value: json!("hello world"),
                ..default_output()
            },
        };
        let result = Agent::new().prompt("test").run(&provider).await.unwrap();
        assert_eq!(result.text(), "hello world");
    }

    /// Non-string payloads make `text()` return "".
    #[tokio::test]
    async fn agent_result_text_with_non_string_value() {
        let provider = TestProvider {
            output: AgentOutput {
                value: json!(42),
                ..default_output()
            },
        };
        let result = Agent::new().prompt("test").run(&provider).await.unwrap();
        assert_eq!(result.text(), "");
    }

    /// A `null` payload also makes `text()` return "".
    #[tokio::test]
    async fn agent_result_text_with_null_value() {
        let provider = TestProvider {
            output: AgentOutput {
                value: json!(null),
                ..default_output()
            },
        };
        let result = Agent::new().prompt("test").run(&provider).await.unwrap();
        assert_eq!(result.text(), "");
    }

    // --- AgentResult::json ---

    /// `json()` deserializes a matching object payload.
    #[tokio::test]
    async fn agent_result_json_successful_deserialize() {
        #[derive(Deserialize, PartialEq, Debug)]
        struct MyOutput {
            name: String,
            count: u32,
        }
        let provider = TestProvider {
            output: AgentOutput {
                value: json!({"name": "test", "count": 7}),
                ..default_output()
            },
        };
        let result = Agent::new().prompt("test").run(&provider).await.unwrap();
        let parsed: MyOutput = result.json().unwrap();
        assert_eq!(parsed.name, "test");
        assert_eq!(parsed.count, 7);
    }

    /// A type mismatch surfaces as `OperationError::Deserialize`.
    #[tokio::test]
    async fn agent_result_json_failed_deserialize() {
        #[derive(Debug, Deserialize)]
        #[allow(dead_code)]
        struct MyOutput {
            name: String,
        }
        let provider = TestProvider {
            output: AgentOutput {
                value: json!(42),
                ..default_output()
            },
        };
        let result = Agent::new().prompt("test").run(&provider).await.unwrap();
        let err = result.json::<MyOutput>().unwrap_err();
        assert!(matches!(err, OperationError::Deserialize { .. }));
    }

    /// Metadata accessors mirror the provider output fields.
    #[tokio::test]
    async fn agent_result_accessors() {
        let provider = TestProvider {
            output: AgentOutput {
                value: json!("v"),
                session_id: Some("s-1".to_string()),
                cost_usd: Some(0.123),
                input_tokens: Some(999),
                output_tokens: Some(456),
                model: Some("opus".to_string()),
                duration_ms: 2000,
                debug_messages: None,
            },
        };
        let result = Agent::new().prompt("test").run(&provider).await.unwrap();
        assert_eq!(result.session_id(), Some("s-1"));
        assert_eq!(result.cost_usd(), Some(0.123));
        assert_eq!(result.input_tokens(), Some(999));
        assert_eq!(result.output_tokens(), Some(456));
        assert_eq!(result.duration_ms(), 2000);
        assert_eq!(result.model(), Some("opus"));
    }

    // --- Session resume ---

    /// `resume()` forwards the session id into the config.
    #[tokio::test]
    async fn resume_passes_session_id_in_config() {
        let provider = ConfigCapture {
            output: default_output(),
        };
        let result = Agent::new()
            .prompt("followup")
            .resume("sess-abc")
            .run(&provider)
            .await
            .unwrap();

        let config = result.value();
        assert_eq!(config["resume_session_id"], json!("sess-abc"));
    }

    #[tokio::test]
    async fn no_resume_has_null_session_id() {
        let provider = ConfigCapture {
            output: default_output(),
        };
        let result = Agent::new()
            .prompt("first call")
            .run(&provider)
            .await
            .unwrap();

        let config = result.value();
        assert_eq!(config["resume_session_id"], json!(null));
    }

    #[test]
    #[should_panic(expected = "session_id must not be empty")]
    fn resume_empty_session_id_panics() {
        let _ = Agent::new().resume("");
    }

    /// Shell metacharacters and spaces are rejected.
    #[test]
    #[should_panic(expected = "session_id must only contain")]
    fn resume_invalid_chars_panics() {
        let _ = Agent::new().resume("sess;rm -rf /");
    }

    #[test]
    fn resume_valid_formats_accepted() {
        let _ = Agent::new().resume("sess-abc123");
        let _ = Agent::new().resume("a1b2c3d4_session");
        let _ = Agent::new().resume("abc-DEF-123_456");
    }

    // --- Prompt precondition in run() ---

    #[tokio::test]
    #[should_panic(expected = "prompt must not be empty")]
    async fn run_without_prompt_panics() {
        let provider = TestProvider {
            output: default_output(),
        };
        let _ = Agent::new().run(&provider).await;
    }

    #[tokio::test]
    #[should_panic(expected = "prompt must not be empty")]
    async fn run_with_whitespace_only_prompt_panics() {
        let provider = TestProvider {
            output: default_output(),
        };
        let _ = Agent::new().prompt(" ").run(&provider).await;
    }

    /// Model names outside the `Model` constants pass through unchanged.
    #[tokio::test]
    async fn model_accepts_custom_string() {
        let provider = ConfigCapture {
            output: default_output(),
        };
        let result = Agent::new()
            .prompt("hi")
            .model("mistral-large-latest")
            .run(&provider)
            .await
            .unwrap();
        assert_eq!(result.value()["model"], json!("mistral-large-latest"));
    }

    #[tokio::test]
    async fn verbose_sets_config_flag() {
        let provider = ConfigCapture {
            output: default_output(),
        };
        let result = Agent::new()
            .prompt("hi")
            .verbose()
            .run(&provider)
            .await
            .unwrap();
        assert_eq!(result.value()["verbose"], json!(true));
    }

    #[tokio::test]
    async fn verbose_not_set_by_default() {
        let provider = ConfigCapture {
            output: default_output(),
        };
        let result = Agent::new().prompt("hi").run(&provider).await.unwrap();
        assert_eq!(result.value()["verbose"], json!(false));
    }

    #[tokio::test]
    async fn debug_messages_none_without_verbose() {
        let provider = TestProvider {
            output: default_output(),
        };
        let result = Agent::new().prompt("test").run(&provider).await.unwrap();
        assert!(result.debug_messages().is_none());
    }

    /// `model()` accepts an owned `String` as well as `&str`.
    #[tokio::test]
    async fn model_accepts_owned_string() {
        let provider = ConfigCapture {
            output: default_output(),
        };
        let model_name = String::from("gpt-4o");
        let result = Agent::new()
            .prompt("hi")
            .model(model_name)
            .run(&provider)
            .await
            .unwrap();
        assert_eq!(result.value()["model"], json!("gpt-4o"));
    }

    // --- AgentResult::into_json ---

    #[tokio::test]
    async fn into_json_success() {
        #[derive(Deserialize, PartialEq, Debug)]
        struct Out {
            name: String,
        }
        let provider = TestProvider {
            output: AgentOutput {
                value: json!({"name": "test"}),
                ..default_output()
            },
        };
        let result = Agent::new().prompt("test").run(&provider).await.unwrap();
        let parsed: Out = result.into_json().unwrap();
        assert_eq!(parsed.name, "test");
    }

    #[tokio::test]
    async fn into_json_failure() {
        #[derive(Debug, Deserialize)]
        #[allow(dead_code)]
        struct Out {
            name: String,
        }
        let provider = TestProvider {
            output: AgentOutput {
                value: json!(42),
                ..default_output()
            },
        };
        let result = Agent::new().prompt("test").run(&provider).await.unwrap();
        let err = result.into_json::<Out>().unwrap_err();
        assert!(matches!(err, OperationError::Deserialize { .. }));
    }

    #[test]
    fn from_output_creates_result() {
        let output = AgentOutput {
            value: json!("hello"),
            ..default_output()
        };
        let result = AgentResult::from_output(output);
        assert_eq!(result.text(), "hello");
        assert_eq!(result.cost_usd(), Some(0.05));
    }

    /// Zero is rejected: the budget must be strictly positive.
    #[test]
    #[should_panic(expected = "budget must be a positive finite number")]
    fn max_budget_zero_panics() {
        let _ = Agent::new().max_budget_usd(0.0);
    }

    #[test]
    fn model_constant_equality() {
        assert_eq!(Model::SONNET, "sonnet");
        assert_ne!(Model::SONNET, Model::OPUS);
    }

    /// Serialize -> Deserialize returns the same variant. The comparison
    /// goes through `Debug` output because `PermissionMode` does not
    /// implement `PartialEq`.
    #[test]
    fn permission_mode_serialize_deserialize_roundtrip() {
        for mode in [
            PermissionMode::Default,
            PermissionMode::Auto,
            PermissionMode::DontAsk,
            PermissionMode::BypassPermissions,
        ] {
            let json = to_string(&mode).unwrap();
            let back: PermissionMode = serde_json::from_str(&json).unwrap();
            assert_eq!(format!("{:?}", mode), format!("{:?}", back));
        }
    }

    // --- Retry configuration ---

    #[test]
    fn retry_builder_stores_policy() {
        let agent = Agent::new().retry(3);
        assert!(agent.retry_policy.is_some());
        assert_eq!(agent.retry_policy.unwrap().max_retries(), 3);
    }

    #[test]
    fn retry_policy_builder_stores_custom_policy() {
        use crate::retry::RetryPolicy;
        let policy = RetryPolicy::new(5).backoff(Duration::from_secs(1));
        let agent = Agent::new().retry_policy(policy);
        let p = agent.retry_policy.unwrap();
        assert_eq!(p.max_retries(), 5);
    }

    #[test]
    fn no_retry_by_default() {
        let agent = Agent::new();
        assert!(agent.retry_policy.is_none());
    }

    use std::sync::Arc;
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::time::Duration;

    /// Fake provider that fails the first `failures_before_success` calls with
    /// `AgentError::ProcessFailed`, then succeeds. `fail_count` counts every
    /// call, including the successful ones.
    struct FailNTimesProvider {
        fail_count: AtomicU32,
        failures_before_success: u32,
        output: AgentOutput,
    }

    impl AgentProvider for FailNTimesProvider {
        fn invoke<'a>(&'a self, _config: &'a AgentConfig) -> InvokeFuture<'a> {
            Box::pin(async move {
                let current = self.fail_count.fetch_add(1, Ordering::SeqCst);
                if current < self.failures_before_success {
                    Err(AgentError::ProcessFailed {
                        exit_code: 1,
                        stderr: format!("transient failure #{}", current + 1),
                    })
                } else {
                    Ok(AgentOutput {
                        value: self.output.value.clone(),
                        session_id: self.output.session_id.clone(),
                        cost_usd: self.output.cost_usd,
                        input_tokens: self.output.input_tokens,
                        output_tokens: self.output.output_tokens,
                        model: self.output.model.clone(),
                        duration_ms: self.output.duration_ms,
                        debug_messages: None,
                    })
                }
            })
        }
    }

    /// Two failures, then success, within a 3-retry budget.
    #[tokio::test]
    async fn retry_succeeds_after_transient_failures() {
        let provider = FailNTimesProvider {
            fail_count: AtomicU32::new(0),
            failures_before_success: 2,
            output: default_output(),
        };
        let result = Agent::new()
            .prompt("test")
            .retry_policy(crate::retry::RetryPolicy::new(3).backoff(Duration::from_millis(1)))
            .run(&provider)
            .await;

        assert!(result.is_ok());
        // Two failed attempts plus the successful third.
        assert_eq!(provider.fail_count.load(Ordering::SeqCst), 3);
    }

    /// With 2 retries the provider is called 1 + 2 = 3 times, then the last
    /// error is returned.
    #[tokio::test]
    async fn retry_exhausted_returns_last_error() {
        let provider = FailNTimesProvider {
            fail_count: AtomicU32::new(0),
            // More failures than the policy will ever retry.
            failures_before_success: 10,
            output: default_output(),
        };
        let result = Agent::new()
            .prompt("test")
            .retry_policy(crate::retry::RetryPolicy::new(2).backoff(Duration::from_millis(1)))
            .run(&provider)
            .await;

        assert!(result.is_err());
        assert_eq!(provider.fail_count.load(Ordering::SeqCst), 3);
    }

    /// A non-retryable error (`SchemaValidation`) is returned after a single
    /// call, even though a retry policy is configured.
    #[tokio::test]
    async fn retry_does_not_retry_non_retryable_errors() {
        let call_count = Arc::new(AtomicU32::new(0));
        let count = call_count.clone();

        // Counts calls eagerly in `invoke` (before the future is polled) and
        // always fails with a non-retryable error.
        struct CountingNonRetryable {
            count: Arc<AtomicU32>,
        }
        impl AgentProvider for CountingNonRetryable {
            fn invoke<'a>(&'a self, _config: &'a AgentConfig) -> InvokeFuture<'a> {
                self.count.fetch_add(1, Ordering::SeqCst);
                Box::pin(async move {
                    Err(AgentError::SchemaValidation {
                        expected: "object".to_string(),
                        got: "string".to_string(),
                        debug_messages: Vec::new(),
                        partial_usage: Box::default(),
                        raw_response: None,
                    })
                })
            }
        }

        let provider = CountingNonRetryable { count };
        let result = Agent::new()
            .prompt("test")
            .retry_policy(crate::retry::RetryPolicy::new(3).backoff(Duration::from_millis(1)))
            .run(&provider)
            .await;

        assert!(result.is_err());
        assert_eq!(call_count.load(Ordering::SeqCst), 1);
    }

    /// Without a retry policy, even a retryable error is returned after one call.
    #[tokio::test]
    async fn no_retry_without_policy() {
        let provider = FailNTimesProvider {
            fail_count: AtomicU32::new(0),
            failures_before_success: 1,
            output: default_output(),
        };
        let result = Agent::new().prompt("test").run(&provider).await;

        assert!(result.is_err());
        assert_eq!(provider.fail_count.load(Ordering::SeqCst), 1);
    }

    use crate::test_support::VecSink;

    /// Fake provider that overrides `invoke_with_logs`. It counts those calls
    /// in `saw_logs`, writes one line to the sink, then delegates to `invoke`.
    struct SinkCapture {
        output: AgentOutput,
        saw_logs: Arc<AtomicU32>,
    }

    impl AgentProvider for SinkCapture {
        fn invoke<'a>(&'a self, _config: &'a AgentConfig) -> InvokeFuture<'a> {
            Box::pin(async {
                Ok(AgentOutput {
                    value: self.output.value.clone(),
                    session_id: self.output.session_id.clone(),
                    cost_usd: self.output.cost_usd,
                    input_tokens: self.output.input_tokens,
                    output_tokens: self.output.output_tokens,
                    model: self.output.model.clone(),
                    duration_ms: self.output.duration_ms,
                    debug_messages: None,
                })
            })
        }

        fn invoke_with_logs<'a>(
            &'a self,
            config: &'a AgentConfig,
            log_sink: Arc<dyn LogSink>,
        ) -> InvokeFuture<'a> {
            self.saw_logs.fetch_add(1, Ordering::SeqCst);
            log_sink.log("stdout", "streaming line");
            self.invoke(config)
        }
    }

    /// With a sink configured, `run` goes through `invoke_with_logs`.
    #[tokio::test]
    async fn log_sink_routes_to_invoke_with_logs() {
        let saw_logs = Arc::new(AtomicU32::new(0));
        let provider = SinkCapture {
            output: default_output(),
            saw_logs: saw_logs.clone(),
        };
        let sink: Arc<dyn LogSink> = VecSink::new();

        let result = Agent::new()
            .prompt("test")
            .log_sink(sink)
            .run(&provider)
            .await;

        assert!(result.is_ok());
        assert_eq!(saw_logs.load(Ordering::SeqCst), 1);
    }

    /// Without a sink, `invoke_with_logs` is never called.
    #[tokio::test]
    async fn no_log_sink_routes_to_invoke() {
        let saw_logs = Arc::new(AtomicU32::new(0));
        let provider = SinkCapture {
            output: default_output(),
            saw_logs: saw_logs.clone(),
        };

        let result = Agent::new().prompt("test").run(&provider).await;

        assert!(result.is_ok());
        assert_eq!(saw_logs.load(Ordering::SeqCst), 0);
    }

    /// Lines the provider writes to the sink are visible to the caller.
    #[tokio::test]
    async fn log_sink_receives_provider_lines() {
        let saw_logs = Arc::new(AtomicU32::new(0));
        let provider = SinkCapture {
            output: default_output(),
            saw_logs: saw_logs.clone(),
        };
        let sink = VecSink::new();

        let _ = Agent::new()
            .prompt("test")
            .log_sink(sink.clone() as Arc<dyn LogSink>)
            .run(&provider)
            .await;

        let lines = sink.0.lock().unwrap();
        assert_eq!(lines.len(), 1);
        assert_eq!(lines[0].0, "stdout");
        assert_eq!(lines[0].1, "streaming line");
    }
}
1506}