1use std::any;
29use std::sync::Arc;
30
31use schemars::{JsonSchema, schema_for};
32use serde::de::DeserializeOwned;
33use serde::{Deserialize, Serialize};
34use serde_json::{Value, from_value, to_string};
35use tokio::time;
36use tracing::{info, warn};
37
38use crate::error::OperationError;
39#[cfg(feature = "prometheus")]
40use crate::metric_names;
41use crate::provider::{AgentConfig, AgentOutput, AgentProvider, DebugMessage, LogSink};
42use crate::retry::RetryPolicy;
43
/// Namespace of well-known model identifiers that can be passed to `Agent::model`.
///
/// `Agent::model` stores whatever string it is given, so these constants are a convenience,
/// not an exhaustive list. The short aliases are presumably resolved by the provider to a
/// concrete model (TODO confirm); the `_1M` variants append a `[1m]` suffix, presumably
/// selecting a 1M-token context window (TODO confirm against provider docs).
pub struct Model;

impl Model {
    // ----- Short family aliases -------------------------------------------------------

    /// Family alias `"sonnet"`.
    pub const SONNET: &str = "sonnet";
    /// Family alias `"opus"`.
    pub const OPUS: &str = "opus";
    /// Family alias `"haiku"`.
    pub const HAIKU: &str = "haiku";

    // ----- Haiku -----------------------------------------------------------------------

    /// Pinned Haiku 4.5 identifier.
    pub const HAIKU_45: &str = "claude-haiku-4-5-20251001";

    // ----- Sonnet ----------------------------------------------------------------------

    /// Sonnet 4.6.
    pub const SONNET_46: &str = "claude-sonnet-4-6";
    /// Sonnet 4.6 with the `[1m]` suffix.
    pub const SONNET_46_1M: &str = "claude-sonnet-4-6[1m]";

    // ----- Opus ------------------------------------------------------------------------

    /// Opus 4.6.
    pub const OPUS_46: &str = "claude-opus-4-6";
    /// Opus 4.6 with the `[1m]` suffix.
    pub const OPUS_46_1M: &str = "claude-opus-4-6[1m]";
    /// Opus 4.7.
    pub const OPUS_47: &str = "claude-opus-4-7";
    /// Opus 4.7 with the `[1m]` suffix.
    pub const OPUS_47_1M: &str = "claude-opus-4-7[1m]";
    /// Opus 4.8.
    pub const OPUS_48: &str = "claude-opus-4-8";
    /// Opus 4.8 with the `[1m]` suffix.
    pub const OPUS_48_1M: &str = "claude-opus-4-8[1m]";
}
119
120#[derive(Debug, Default, Clone, Copy, Serialize)]
125pub enum PermissionMode {
126 #[default]
128 Default,
129 Auto,
131 DontAsk,
133 BypassPermissions,
137}
138
139impl<'de> Deserialize<'de> for PermissionMode {
140 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
141 where
142 D: serde::Deserializer<'de>,
143 {
144 let s = String::deserialize(deserializer)?;
145 Ok(match s.to_lowercase().replace('_', "").as_str() {
146 "auto" => Self::Auto,
147 "dontask" => Self::DontAsk,
148 "bypass" | "bypasspermissions" => Self::BypassPermissions,
149 _ => Self::Default,
150 })
151 }
152}
153
154#[must_use = "an Agent does nothing until .run() is awaited"]
184pub struct Agent {
185 config: AgentConfig,
186 dry_run: Option<bool>,
187 retry_policy: Option<RetryPolicy>,
188 log_sink: Option<Arc<dyn LogSink>>,
189}
190
191impl Agent {
192 pub fn new() -> Self {
197 Self {
198 config: AgentConfig::new(""),
199 dry_run: None,
200 retry_policy: None,
201 log_sink: None,
202 }
203 }
204
205 pub fn from_config(config: impl Into<AgentConfig>) -> Self {
224 Self {
225 config: config.into(),
226 dry_run: None,
227 retry_policy: None,
228 log_sink: None,
229 }
230 }
231
232 pub fn system_prompt(mut self, prompt: &str) -> Self {
234 self.config.system_prompt = Some(prompt.to_string());
235 self
236 }
237
238 pub fn prompt(mut self, prompt: &str) -> Self {
240 self.config.prompt = prompt.to_string();
241 self
242 }
243
244 pub fn model(mut self, model: impl Into<String>) -> Self {
251 self.config.model = model.into();
252 self
253 }
254
255 pub fn allowed_tools(mut self, tools: &[&str]) -> Self {
260 self.config.allowed_tools = tools.iter().map(|s| s.to_string()).collect();
261 self
262 }
263
264 pub fn max_turns(mut self, turns: u32) -> Self {
270 assert!(turns > 0, "max_turns must be greater than 0");
271 self.config.max_turns = Some(turns);
272 self
273 }
274
275 pub fn max_budget_usd(mut self, budget: f64) -> Self {
281 assert!(
282 budget.is_finite() && budget > 0.0,
283 "budget must be a positive finite number, got {budget}"
284 );
285 self.config.max_budget_usd = Some(budget);
286 self
287 }
288
289 pub fn working_dir(mut self, dir: &str) -> Self {
291 self.config.working_dir = Some(dir.to_string());
292 self
293 }
294
295 pub fn mcp_config(mut self, config: &str) -> Self {
297 self.config.mcp_config = Some(config.to_string());
298 self
299 }
300
301 pub fn permission_mode(mut self, mode: PermissionMode) -> Self {
305 self.config.permission_mode = mode;
306 self
307 }
308
309 pub fn output<T: JsonSchema>(mut self) -> Self {
340 let schema = schema_for!(T);
341 self.config.json_schema = match to_string(&schema) {
342 Ok(s) => Some(s),
343 Err(e) => {
344 warn!(error = %e, type_name = any::type_name::<T>(), "failed to serialize JSON schema, structured output disabled");
345 None
346 }
347 };
348 self
349 }
350
351 pub fn output_schema_raw(mut self, schema: &str) -> Self {
375 self.config.json_schema = Some(schema.to_string());
376 self
377 }
378
379 pub fn retry(mut self, max_retries: u32) -> Self {
407 self.retry_policy = Some(RetryPolicy::new(max_retries));
408 self
409 }
410
411 pub fn retry_policy(mut self, policy: RetryPolicy) -> Self {
438 self.retry_policy = Some(policy);
439 self
440 }
441
442 pub fn dry_run(mut self, enabled: bool) -> Self {
451 self.dry_run = Some(enabled);
452 self
453 }
454
455 pub fn log_sink(mut self, sink: Arc<dyn LogSink>) -> Self {
482 self.log_sink = Some(sink);
483 self
484 }
485
486 pub fn verbose(mut self) -> Self {
516 self.config.verbose = true;
517 self
518 }
519
520 pub fn resume(mut self, session_id: &str) -> Self {
556 assert!(!session_id.is_empty(), "session_id must not be empty");
557 assert!(
558 session_id
559 .chars()
560 .all(|c| c.is_ascii_alphanumeric() || c == '-' || c == '_'),
561 "session_id must only contain alphanumeric characters, hyphens, or underscores, got: {session_id}"
562 );
563 self.config.resume_session_id = Some(session_id.to_string());
564 self
565 }
566
567 #[tracing::instrument(name = "agent", skip_all, fields(model = %self.config.model, prompt_len = self.config.prompt.len()))]
588 pub async fn run(self, provider: &dyn AgentProvider) -> Result<AgentResult, OperationError> {
589 assert!(
590 !self.config.prompt.trim().is_empty(),
591 "prompt must not be empty - call .prompt(\"...\") before .run()"
592 );
593
594 if crate::dry_run::effective_dry_run(self.dry_run) {
595 info!(
596 prompt_len = self.config.prompt.len(),
597 "[dry-run] agent call skipped"
598 );
599 let mut output =
600 AgentOutput::new(Value::String("[dry-run] agent call skipped".to_string()));
601 output.cost_usd = Some(0.0);
602 output.input_tokens = Some(0);
603 output.output_tokens = Some(0);
604 return Ok(AgentResult { output });
605 }
606
607 let result = self.invoke_once(provider).await;
608
609 let default_schema_retry = RetryPolicy::new(2);
610 let policy = match &self.retry_policy {
611 Some(p) => p,
612 None if self.config.json_schema.is_some() => &default_schema_retry,
613 None => return result,
614 };
615
616 if let Err(ref err) = result {
618 if !crate::retry::is_retryable(err) {
619 return result;
620 }
621 } else {
622 return result;
623 }
624
625 let mut last_result = result;
626
627 for attempt in 0..policy.max_retries {
628 let delay = policy.delay_for_attempt(attempt);
629 let retry_reason = if matches!(
630 &last_result,
631 Err(OperationError::Agent(
632 crate::error::AgentError::SchemaValidation { .. }
633 ))
634 ) {
635 "structured_output was null (CLI non-determinism)"
636 } else {
637 "transient failure"
638 };
639 warn!(
640 attempt = attempt + 1,
641 max_retries = policy.max_retries,
642 delay_ms = delay.as_millis() as u64,
643 reason = retry_reason,
644 "retrying agent invocation"
645 );
646 time::sleep(delay).await;
647
648 last_result = self.invoke_once(provider).await;
649
650 match &last_result {
651 Ok(_) => return last_result,
652 Err(err) if !crate::retry::is_retryable(err) => return last_result,
653 _ => {}
654 }
655 }
656
657 last_result
658 }
659
660 async fn invoke_once(
662 &self,
663 provider: &dyn AgentProvider,
664 ) -> Result<AgentResult, OperationError> {
665 #[cfg(feature = "prometheus")]
666 let model_label = self.config.model.to_string();
667
668 let invoke_result = match self.log_sink {
669 Some(ref sink) => provider.invoke_with_logs(&self.config, sink.clone()).await,
670 None => provider.invoke(&self.config).await,
671 };
672 let output = match invoke_result {
673 Ok(output) => output,
674 Err(e) => {
675 #[cfg(feature = "prometheus")]
676 {
677 metrics::counter!(metric_names::AGENT_TOTAL, "model" => model_label.clone(), "status" => metric_names::STATUS_ERROR).increment(1);
678 }
679 return Err(OperationError::Agent(e));
680 }
681 };
682
683 info!(
684 duration_ms = output.duration_ms,
685 cost_usd = output.cost_usd,
686 input_tokens = output.input_tokens,
687 output_tokens = output.output_tokens,
688 model = output.model,
689 "agent completed"
690 );
691
692 #[cfg(feature = "prometheus")]
693 {
694 metrics::counter!(metric_names::AGENT_TOTAL, "model" => model_label.clone(), "status" => metric_names::STATUS_SUCCESS).increment(1);
695 metrics::histogram!(metric_names::AGENT_DURATION_SECONDS, "model" => model_label.clone())
696 .record(output.duration_ms as f64 / 1000.0);
697 if let Some(cost) = output.cost_usd {
698 metrics::gauge!(metric_names::AGENT_COST_USD_TOTAL, "model" => model_label.clone())
699 .increment(cost);
700 }
701 if let Some(tokens) = output.input_tokens {
702 metrics::counter!(metric_names::AGENT_TOKENS_INPUT_TOTAL, "model" => model_label.clone()).increment(tokens);
703 }
704 if let Some(tokens) = output.output_tokens {
705 metrics::counter!(metric_names::AGENT_TOKENS_OUTPUT_TOTAL, "model" => model_label)
706 .increment(tokens);
707 }
708 }
709
710 Ok(AgentResult { output })
711 }
712}
713
714impl Default for Agent {
715 fn default() -> Self {
716 Self::new()
717 }
718}
719
720#[derive(Debug)]
725pub struct AgentResult {
726 output: AgentOutput,
727}
728
729impl AgentResult {
730 pub fn text(&self) -> &str {
735 match self.output.value.as_str() {
736 Some(s) => s,
737 None => {
738 warn!(
739 value_type = self.output.value.to_string(),
740 "agent output is not a string, returning empty"
741 );
742 ""
743 }
744 }
745 }
746
747 pub fn value(&self) -> &Value {
749 &self.output.value
750 }
751
752 pub fn json<T: DeserializeOwned>(&self) -> Result<T, OperationError> {
762 from_value(self.output.value.clone()).map_err(OperationError::deserialize::<T>)
763 }
764
765 pub fn into_json<T: DeserializeOwned>(self) -> Result<T, OperationError> {
771 from_value(self.output.value).map_err(OperationError::deserialize::<T>)
772 }
773
774 #[cfg(test)]
779 pub(crate) fn from_output(output: AgentOutput) -> Self {
780 Self { output }
781 }
782
783 pub fn session_id(&self) -> Option<&str> {
785 self.output.session_id.as_deref()
786 }
787
788 pub fn cost_usd(&self) -> Option<f64> {
790 self.output.cost_usd
791 }
792
793 pub fn input_tokens(&self) -> Option<u64> {
795 self.output.input_tokens
796 }
797
798 pub fn output_tokens(&self) -> Option<u64> {
800 self.output.output_tokens
801 }
802
803 pub fn duration_ms(&self) -> u64 {
805 self.output.duration_ms
806 }
807
808 pub fn model(&self) -> Option<&str> {
810 self.output.model.as_deref()
811 }
812
813 pub fn debug_messages(&self) -> Option<&[DebugMessage]> {
819 self.output.debug_messages.as_deref()
820 }
821}
822
#[cfg(test)]
mod tests {
    //! Unit tests for the `Agent` builder, the dry-run/retry/log-sink behavior of `run`,
    //! and the `AgentResult` accessors.
    //!
    //! Most tests use a stub `AgentProvider`. `ConfigCapture` echoes the serialized
    //! `AgentConfig` back as the output value, so builder settings can be asserted on.

    use super::*;
    use crate::error::AgentError;
    use crate::provider::InvokeFuture;
    use crate::test_support::VecSink;
    use serde_json::json;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::time::Duration;

    /// Stub provider that always succeeds with a copy of `output` (without debug messages).
    struct TestProvider {
        output: AgentOutput,
    }

    impl AgentProvider for TestProvider {
        fn invoke<'a>(&'a self, _config: &'a AgentConfig) -> InvokeFuture<'a> {
            Box::pin(async move {
                Ok(AgentOutput {
                    value: self.output.value.clone(),
                    session_id: self.output.session_id.clone(),
                    cost_usd: self.output.cost_usd,
                    input_tokens: self.output.input_tokens,
                    output_tokens: self.output.output_tokens,
                    model: self.output.model.clone(),
                    duration_ms: self.output.duration_ms,
                    debug_messages: None,
                })
            })
        }
    }

    /// Stub provider that returns the serialized `AgentConfig` it received as the output value.
    struct ConfigCapture {
        output: AgentOutput,
    }

    impl AgentProvider for ConfigCapture {
        fn invoke<'a>(&'a self, config: &'a AgentConfig) -> InvokeFuture<'a> {
            let config_json = serde_json::to_value(config).unwrap();
            Box::pin(async move {
                Ok(AgentOutput {
                    value: config_json,
                    session_id: self.output.session_id.clone(),
                    cost_usd: self.output.cost_usd,
                    input_tokens: self.output.input_tokens,
                    output_tokens: self.output.output_tokens,
                    model: self.output.model.clone(),
                    duration_ms: self.output.duration_ms,
                    debug_messages: None,
                })
            })
        }
    }

    /// Baseline output used by most stubs.
    fn default_output() -> AgentOutput {
        AgentOutput {
            value: json!("test output"),
            session_id: Some("sess-123".to_string()),
            cost_usd: Some(0.05),
            input_tokens: Some(100),
            output_tokens: Some(50),
            model: Some("sonnet".to_string()),
            duration_ms: 1500,
            debug_messages: None,
        }
    }

    /// Pins every public model identifier string.
    #[test]
    fn model_constants_have_expected_values() {
        assert_eq!(Model::SONNET, "sonnet");
        assert_eq!(Model::OPUS, "opus");
        assert_eq!(Model::HAIKU, "haiku");
        assert_eq!(Model::HAIKU_45, "claude-haiku-4-5-20251001");
        assert_eq!(Model::SONNET_46, "claude-sonnet-4-6");
        assert_eq!(Model::OPUS_46, "claude-opus-4-6");
        assert_eq!(Model::SONNET_46_1M, "claude-sonnet-4-6[1m]");
        assert_eq!(Model::OPUS_46_1M, "claude-opus-4-6[1m]");
        assert_eq!(Model::OPUS_47, "claude-opus-4-7");
        assert_eq!(Model::OPUS_47_1M, "claude-opus-4-7[1m]");
        assert_eq!(Model::OPUS_48, "claude-opus-4-8");
        assert_eq!(Model::OPUS_48_1M, "claude-opus-4-8[1m]");
    }

    /// `Agent::new()` sends the config defaults to the provider.
    #[tokio::test]
    async fn agent_new_default_values() {
        let provider = ConfigCapture {
            output: default_output(),
        };
        let result = Agent::new().prompt("hi").run(&provider).await.unwrap();

        let config = result.value();
        assert_eq!(config["system_prompt"], json!(null));
        assert_eq!(config["prompt"], json!("hi"));
        assert_eq!(config["model"], json!("sonnet"));
        assert_eq!(config["allowed_tools"], json!([]));
        assert_eq!(config["max_turns"], json!(null));
        assert_eq!(config["max_budget_usd"], json!(null));
        assert_eq!(config["working_dir"], json!(null));
        assert_eq!(config["mcp_config"], json!(null));
        assert_eq!(config["permission_mode"], json!("Default"));
        assert_eq!(config["json_schema"], json!(null));
    }

    #[tokio::test]
    async fn agent_default_matches_new() {
        let provider = ConfigCapture {
            output: default_output(),
        };
        let result_new = Agent::new().prompt("x").run(&provider).await.unwrap();
        let result_default = Agent::default().prompt("x").run(&provider).await.unwrap();

        assert_eq!(result_new.value(), result_default.value());
    }

    /// Every builder setter is reflected in the config the provider receives.
    #[tokio::test]
    async fn builder_methods_store_values_correctly() {
        let provider = ConfigCapture {
            output: default_output(),
        };
        let result = Agent::new()
            .system_prompt("you are a bot")
            .prompt("do something")
            .model(Model::OPUS)
            .allowed_tools(&["Read", "Write"])
            .max_turns(5)
            .max_budget_usd(1.5)
            .working_dir("/tmp")
            .mcp_config("{}")
            .permission_mode(PermissionMode::Auto)
            .run(&provider)
            .await
            .unwrap();

        let config = result.value();
        assert_eq!(config["system_prompt"], json!("you are a bot"));
        assert_eq!(config["prompt"], json!("do something"));
        assert_eq!(config["model"], json!("opus"));
        assert_eq!(config["allowed_tools"], json!(["Read", "Write"]));
        assert_eq!(config["max_turns"], json!(5));
        assert_eq!(config["max_budget_usd"], json!(1.5));
        assert_eq!(config["working_dir"], json!("/tmp"));
        assert_eq!(config["mcp_config"], json!("{}"));
        assert_eq!(config["permission_mode"], json!("Auto"));
    }

    /// Input validation in the builder panics eagerly.
    #[test]
    #[should_panic(expected = "max_turns must be greater than 0")]
    fn max_turns_zero_panics() {
        let _ = Agent::new().max_turns(0);
    }

    #[test]
    #[should_panic(expected = "budget must be a positive finite number")]
    fn max_budget_negative_panics() {
        let _ = Agent::new().max_budget_usd(-1.0);
    }

    #[test]
    #[should_panic(expected = "budget must be a positive finite number")]
    fn max_budget_nan_panics() {
        let _ = Agent::new().max_budget_usd(f64::NAN);
    }

    #[test]
    #[should_panic(expected = "budget must be a positive finite number")]
    fn max_budget_infinity_panics() {
        let _ = Agent::new().max_budget_usd(f64::INFINITY);
    }

    /// `text()` returns string output as-is, and `""` for anything else.
    #[tokio::test]
    async fn agent_result_text_with_string_value() {
        let provider = TestProvider {
            output: AgentOutput {
                value: json!("hello world"),
                ..default_output()
            },
        };
        let result = Agent::new().prompt("test").run(&provider).await.unwrap();
        assert_eq!(result.text(), "hello world");
    }

    #[tokio::test]
    async fn agent_result_text_with_non_string_value() {
        let provider = TestProvider {
            output: AgentOutput {
                value: json!(42),
                ..default_output()
            },
        };
        let result = Agent::new().prompt("test").run(&provider).await.unwrap();
        assert_eq!(result.text(), "");
    }

    #[tokio::test]
    async fn agent_result_text_with_null_value() {
        let provider = TestProvider {
            output: AgentOutput {
                value: json!(null),
                ..default_output()
            },
        };
        let result = Agent::new().prompt("test").run(&provider).await.unwrap();
        assert_eq!(result.text(), "");
    }

    #[tokio::test]
    async fn agent_result_json_successful_deserialize() {
        #[derive(Deserialize, PartialEq, Debug)]
        struct MyOutput {
            name: String,
            count: u32,
        }
        let provider = TestProvider {
            output: AgentOutput {
                value: json!({"name": "test", "count": 7}),
                ..default_output()
            },
        };
        let result = Agent::new().prompt("test").run(&provider).await.unwrap();
        let parsed: MyOutput = result.json().unwrap();
        assert_eq!(parsed.name, "test");
        assert_eq!(parsed.count, 7);
    }

    #[tokio::test]
    async fn agent_result_json_failed_deserialize() {
        #[derive(Debug, Deserialize)]
        #[allow(dead_code)] // Only the deserialization outcome matters here.
        struct MyOutput {
            name: String,
        }
        let provider = TestProvider {
            output: AgentOutput {
                value: json!(42),
                ..default_output()
            },
        };
        let result = Agent::new().prompt("test").run(&provider).await.unwrap();
        let err = result.json::<MyOutput>().unwrap_err();
        assert!(matches!(err, OperationError::Deserialize { .. }));
    }

    #[tokio::test]
    async fn agent_result_accessors() {
        let provider = TestProvider {
            output: AgentOutput {
                value: json!("v"),
                session_id: Some("s-1".to_string()),
                cost_usd: Some(0.123),
                input_tokens: Some(999),
                output_tokens: Some(456),
                model: Some("opus".to_string()),
                duration_ms: 2000,
                debug_messages: None,
            },
        };
        let result = Agent::new().prompt("test").run(&provider).await.unwrap();
        assert_eq!(result.session_id(), Some("s-1"));
        assert_eq!(result.cost_usd(), Some(0.123));
        assert_eq!(result.input_tokens(), Some(999));
        assert_eq!(result.output_tokens(), Some(456));
        assert_eq!(result.duration_ms(), 2000);
        assert_eq!(result.model(), Some("opus"));
    }

    /// `resume` forwards the session id, and validates it.
    #[tokio::test]
    async fn resume_passes_session_id_in_config() {
        let provider = ConfigCapture {
            output: default_output(),
        };
        let result = Agent::new()
            .prompt("followup")
            .resume("sess-abc")
            .run(&provider)
            .await
            .unwrap();

        let config = result.value();
        assert_eq!(config["resume_session_id"], json!("sess-abc"));
    }

    #[tokio::test]
    async fn no_resume_has_null_session_id() {
        let provider = ConfigCapture {
            output: default_output(),
        };
        let result = Agent::new()
            .prompt("first call")
            .run(&provider)
            .await
            .unwrap();

        let config = result.value();
        assert_eq!(config["resume_session_id"], json!(null));
    }

    #[test]
    #[should_panic(expected = "session_id must not be empty")]
    fn resume_empty_session_id_panics() {
        let _ = Agent::new().resume("");
    }

    #[test]
    #[should_panic(expected = "session_id must only contain")]
    fn resume_invalid_chars_panics() {
        let _ = Agent::new().resume("sess;rm -rf /");
    }

    #[test]
    fn resume_valid_formats_accepted() {
        let _ = Agent::new().resume("sess-abc123");
        let _ = Agent::new().resume("a1b2c3d4_session");
        let _ = Agent::new().resume("abc-DEF-123_456");
    }

    #[tokio::test]
    #[should_panic(expected = "prompt must not be empty")]
    async fn run_without_prompt_panics() {
        let provider = TestProvider {
            output: default_output(),
        };
        let _ = Agent::new().run(&provider).await;
    }

    #[tokio::test]
    #[should_panic(expected = "prompt must not be empty")]
    async fn run_with_whitespace_only_prompt_panics() {
        let provider = TestProvider {
            output: default_output(),
        };
        let _ = Agent::new().prompt(" ").run(&provider).await;
    }

    /// `model()` accepts arbitrary identifiers, both borrowed and owned.
    #[tokio::test]
    async fn model_accepts_custom_string() {
        let provider = ConfigCapture {
            output: default_output(),
        };
        let result = Agent::new()
            .prompt("hi")
            .model("mistral-large-latest")
            .run(&provider)
            .await
            .unwrap();
        assert_eq!(result.value()["model"], json!("mistral-large-latest"));
    }

    #[tokio::test]
    async fn verbose_sets_config_flag() {
        let provider = ConfigCapture {
            output: default_output(),
        };
        let result = Agent::new()
            .prompt("hi")
            .verbose()
            .run(&provider)
            .await
            .unwrap();
        assert_eq!(result.value()["verbose"], json!(true));
    }

    #[tokio::test]
    async fn verbose_not_set_by_default() {
        let provider = ConfigCapture {
            output: default_output(),
        };
        let result = Agent::new().prompt("hi").run(&provider).await.unwrap();
        assert_eq!(result.value()["verbose"], json!(false));
    }

    #[tokio::test]
    async fn debug_messages_none_without_verbose() {
        let provider = TestProvider {
            output: default_output(),
        };
        let result = Agent::new().prompt("test").run(&provider).await.unwrap();
        assert!(result.debug_messages().is_none());
    }

    #[tokio::test]
    async fn model_accepts_owned_string() {
        let provider = ConfigCapture {
            output: default_output(),
        };
        let model_name = String::from("gpt-4o");
        let result = Agent::new()
            .prompt("hi")
            .model(model_name)
            .run(&provider)
            .await
            .unwrap();
        assert_eq!(result.value()["model"], json!("gpt-4o"));
    }

    #[tokio::test]
    async fn into_json_success() {
        #[derive(Deserialize, PartialEq, Debug)]
        struct Out {
            name: String,
        }
        let provider = TestProvider {
            output: AgentOutput {
                value: json!({"name": "test"}),
                ..default_output()
            },
        };
        let result = Agent::new().prompt("test").run(&provider).await.unwrap();
        let parsed: Out = result.into_json().unwrap();
        assert_eq!(parsed.name, "test");
    }

    #[tokio::test]
    async fn into_json_failure() {
        #[derive(Debug, Deserialize)]
        #[allow(dead_code)] // Only the deserialization outcome matters here.
        struct Out {
            name: String,
        }
        let provider = TestProvider {
            output: AgentOutput {
                value: json!(42),
                ..default_output()
            },
        };
        let result = Agent::new().prompt("test").run(&provider).await.unwrap();
        let err = result.into_json::<Out>().unwrap_err();
        assert!(matches!(err, OperationError::Deserialize { .. }));
    }

    #[test]
    fn from_output_creates_result() {
        let output = AgentOutput {
            value: json!("hello"),
            ..default_output()
        };
        let result = AgentResult::from_output(output);
        assert_eq!(result.text(), "hello");
        assert_eq!(result.cost_usd(), Some(0.05));
    }

    #[test]
    #[should_panic(expected = "budget must be a positive finite number")]
    fn max_budget_zero_panics() {
        let _ = Agent::new().max_budget_usd(0.0);
    }

    #[test]
    fn model_constant_equality() {
        assert_eq!(Model::SONNET, "sonnet");
        assert_ne!(Model::SONNET, Model::OPUS);
    }

    /// Every variant survives a serialize/deserialize round trip through JSON.
    #[test]
    fn permission_mode_serialize_deserialize_roundtrip() {
        for mode in [
            PermissionMode::Default,
            PermissionMode::Auto,
            PermissionMode::DontAsk,
            PermissionMode::BypassPermissions,
        ] {
            let json = to_string(&mode).unwrap();
            let back: PermissionMode = serde_json::from_str(&json).unwrap();
            assert_eq!(format!("{:?}", mode), format!("{:?}", back));
        }
    }

    /// The retry builders store the policy; nothing is stored by default.
    #[test]
    fn retry_builder_stores_policy() {
        let agent = Agent::new().retry(3);
        assert!(agent.retry_policy.is_some());
        assert_eq!(agent.retry_policy.unwrap().max_retries(), 3);
    }

    #[test]
    fn retry_policy_builder_stores_custom_policy() {
        use crate::retry::RetryPolicy;
        let policy = RetryPolicy::new(5).backoff(Duration::from_secs(1));
        let agent = Agent::new().retry_policy(policy);
        let p = agent.retry_policy.unwrap();
        assert_eq!(p.max_retries(), 5);
    }

    #[test]
    fn no_retry_by_default() {
        let agent = Agent::new();
        assert!(agent.retry_policy.is_none());
    }

    /// Fails with a (retryable) `ProcessFailed` for the first `failures_before_success`
    /// calls, then succeeds. `fail_count` counts every call. Its stderr is
    /// "transient failure #N", where N is the 1-based call number.
    struct FailNTimesProvider {
        fail_count: AtomicU32,
        failures_before_success: u32,
        output: AgentOutput,
    }

    impl AgentProvider for FailNTimesProvider {
        fn invoke<'a>(&'a self, _config: &'a AgentConfig) -> InvokeFuture<'a> {
            Box::pin(async move {
                let current = self.fail_count.fetch_add(1, Ordering::SeqCst);
                if current < self.failures_before_success {
                    Err(AgentError::ProcessFailed {
                        exit_code: 1,
                        stderr: format!("transient failure #{}", current + 1),
                    })
                } else {
                    Ok(AgentOutput {
                        value: self.output.value.clone(),
                        session_id: self.output.session_id.clone(),
                        cost_usd: self.output.cost_usd,
                        input_tokens: self.output.input_tokens,
                        output_tokens: self.output.output_tokens,
                        model: self.output.model.clone(),
                        duration_ms: self.output.duration_ms,
                        debug_messages: None,
                    })
                }
            })
        }
    }

    #[tokio::test]
    async fn retry_succeeds_after_transient_failures() {
        let provider = FailNTimesProvider {
            fail_count: AtomicU32::new(0),
            failures_before_success: 2,
            output: default_output(),
        };
        let result = Agent::new()
            .prompt("test")
            .retry_policy(crate::retry::RetryPolicy::new(3).backoff(Duration::from_millis(1)))
            .run(&provider)
            .await;

        assert!(result.is_ok());
        // Two failures, then the successful third call.
        assert_eq!(provider.fail_count.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn retry_exhausted_returns_last_error() {
        let provider = FailNTimesProvider {
            fail_count: AtomicU32::new(0),
            failures_before_success: 10,
            output: default_output(),
        };
        let result = Agent::new()
            .prompt("test")
            .retry_policy(crate::retry::RetryPolicy::new(2).backoff(Duration::from_millis(1)))
            .run(&provider)
            .await;

        // One initial attempt plus two retries.
        assert_eq!(provider.fail_count.load(Ordering::SeqCst), 3);
        // The surfaced error must come from the final (third) attempt, not an earlier one.
        match result {
            Err(OperationError::Agent(AgentError::ProcessFailed { stderr, .. })) => {
                assert_eq!(stderr, "transient failure #3");
            }
            other => panic!("expected ProcessFailed from the last attempt, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn retry_does_not_retry_prompt_too_large() {
        let call_count = Arc::new(AtomicU32::new(0));
        let count = call_count.clone();

        /// Always fails with a non-retryable `PromptTooLarge`. It counts calls synchronously,
        /// when `invoke` is called rather than when the future is polled.
        struct CountingNonRetryable {
            count: Arc<AtomicU32>,
        }
        impl AgentProvider for CountingNonRetryable {
            fn invoke<'a>(&'a self, _config: &'a AgentConfig) -> InvokeFuture<'a> {
                self.count.fetch_add(1, Ordering::SeqCst);
                Box::pin(async move {
                    Err(AgentError::PromptTooLarge {
                        chars: 1_000_000,
                        estimated_tokens: 250_000,
                        model_limit: 200_000,
                    })
                })
            }
        }

        let provider = CountingNonRetryable { count };
        let result = Agent::new()
            .prompt("test")
            .retry_policy(crate::retry::RetryPolicy::new(3).backoff(Duration::from_millis(1)))
            .run(&provider)
            .await;

        assert!(result.is_err());
        assert_eq!(call_count.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn retry_retries_schema_validation_errors() {
        let call_count = Arc::new(AtomicU32::new(0));
        let count = call_count.clone();

        /// Always fails with `SchemaValidation`, which the retry logic treats as retryable.
        struct SchemaFailProvider {
            count: Arc<AtomicU32>,
        }
        impl AgentProvider for SchemaFailProvider {
            fn invoke<'a>(&'a self, _config: &'a AgentConfig) -> InvokeFuture<'a> {
                self.count.fetch_add(1, Ordering::SeqCst);
                Box::pin(async move {
                    Err(AgentError::SchemaValidation {
                        expected: "object".to_string(),
                        got: "null".to_string(),
                        debug_messages: Vec::new(),
                        partial_usage: Box::default(),
                        raw_response: None,
                    })
                })
            }
        }

        let provider = SchemaFailProvider { count };
        let result = Agent::new()
            .prompt("test")
            .retry_policy(crate::retry::RetryPolicy::new(2).backoff(Duration::from_millis(1)))
            .run(&provider)
            .await;

        assert!(result.is_err());
        assert_eq!(call_count.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn schema_validation_succeeds_on_retry() {
        let call_count = Arc::new(AtomicU32::new(0));
        let count = call_count.clone();

        /// Fails with `SchemaValidation` on the first call only, then succeeds.
        struct SchemaFailThenSucceed {
            count: Arc<AtomicU32>,
            output: AgentOutput,
        }
        impl AgentProvider for SchemaFailThenSucceed {
            fn invoke<'a>(&'a self, _config: &'a AgentConfig) -> InvokeFuture<'a> {
                let current = self.count.fetch_add(1, Ordering::SeqCst);
                let output = self.output.clone();
                Box::pin(async move {
                    if current == 0 {
                        Err(AgentError::SchemaValidation {
                            expected: "structured_output field".to_string(),
                            got: "null".to_string(),
                            debug_messages: Vec::new(),
                            partial_usage: Box::default(),
                            raw_response: None,
                        })
                    } else {
                        Ok(output)
                    }
                })
            }
        }

        let provider = SchemaFailThenSucceed {
            count,
            output: default_output(),
        };
        let result = Agent::new()
            .prompt("test")
            .retry_policy(crate::retry::RetryPolicy::new(1).backoff(Duration::from_millis(1)))
            .run(&provider)
            .await;

        assert!(result.is_ok());
        assert_eq!(call_count.load(Ordering::SeqCst), 2);
    }

    /// With a JSON schema and no explicit policy, `run` applies its implicit 2-retry default.
    #[tokio::test]
    async fn auto_retry_applied_when_json_schema_set() {
        let call_count = Arc::new(AtomicU32::new(0));
        let count = call_count.clone();

        struct AlwaysSchemaFail {
            count: Arc<AtomicU32>,
        }
        impl AgentProvider for AlwaysSchemaFail {
            fn invoke<'a>(&'a self, _config: &'a AgentConfig) -> InvokeFuture<'a> {
                self.count.fetch_add(1, Ordering::SeqCst);
                Box::pin(async move {
                    Err(AgentError::SchemaValidation {
                        expected: "object".to_string(),
                        got: "null".to_string(),
                        debug_messages: Vec::new(),
                        partial_usage: Box::default(),
                        raw_response: None,
                    })
                })
            }
        }

        let provider = AlwaysSchemaFail { count };
        let result = Agent::new()
            .prompt("test")
            .output_schema_raw(r#"{"type":"object"}"#)
            .run(&provider)
            .await;

        assert!(result.is_err());
        assert_eq!(call_count.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn no_retry_without_policy() {
        let provider = FailNTimesProvider {
            fail_count: AtomicU32::new(0),
            failures_before_success: 1,
            output: default_output(),
        };
        let result = Agent::new().prompt("test").run(&provider).await;

        assert!(result.is_err());
        assert_eq!(provider.fail_count.load(Ordering::SeqCst), 1);
    }

    /// Records in `saw_logs` how often `invoke_with_logs` was used. That method writes one
    /// line to the sink and then delegates to `invoke`.
    struct SinkCapture {
        output: AgentOutput,
        saw_logs: Arc<AtomicU32>,
    }

    impl AgentProvider for SinkCapture {
        fn invoke<'a>(&'a self, _config: &'a AgentConfig) -> InvokeFuture<'a> {
            Box::pin(async {
                Ok(AgentOutput {
                    value: self.output.value.clone(),
                    session_id: self.output.session_id.clone(),
                    cost_usd: self.output.cost_usd,
                    input_tokens: self.output.input_tokens,
                    output_tokens: self.output.output_tokens,
                    model: self.output.model.clone(),
                    duration_ms: self.output.duration_ms,
                    debug_messages: None,
                })
            })
        }

        fn invoke_with_logs<'a>(
            &'a self,
            config: &'a AgentConfig,
            log_sink: Arc<dyn LogSink>,
        ) -> InvokeFuture<'a> {
            self.saw_logs.fetch_add(1, Ordering::SeqCst);
            log_sink.log("stdout", "streaming line");
            self.invoke(config)
        }
    }

    #[tokio::test]
    async fn log_sink_routes_to_invoke_with_logs() {
        let saw_logs = Arc::new(AtomicU32::new(0));
        let provider = SinkCapture {
            output: default_output(),
            saw_logs: saw_logs.clone(),
        };
        let sink: Arc<dyn LogSink> = VecSink::new();

        let result = Agent::new()
            .prompt("test")
            .log_sink(sink)
            .run(&provider)
            .await;

        assert!(result.is_ok());
        assert_eq!(saw_logs.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn no_log_sink_routes_to_invoke() {
        let saw_logs = Arc::new(AtomicU32::new(0));
        let provider = SinkCapture {
            output: default_output(),
            saw_logs: saw_logs.clone(),
        };

        let result = Agent::new().prompt("test").run(&provider).await;

        assert!(result.is_ok());
        assert_eq!(saw_logs.load(Ordering::SeqCst), 0);
    }

    #[tokio::test]
    async fn log_sink_receives_provider_lines() {
        let saw_logs = Arc::new(AtomicU32::new(0));
        let provider = SinkCapture {
            output: default_output(),
            saw_logs: saw_logs.clone(),
        };
        let sink = VecSink::new();

        let _ = Agent::new()
            .prompt("test")
            .log_sink(sink.clone() as Arc<dyn LogSink>)
            .run(&provider)
            .await;

        let lines = sink.0.lock().unwrap();
        assert_eq!(lines.len(), 1);
        assert_eq!(lines[0].0, "stdout");
        assert_eq!(lines[0].1, "streaming line");
    }
}