1use std::any;
29use std::sync::Arc;
30
31use schemars::{JsonSchema, schema_for};
32use serde::de::DeserializeOwned;
33use serde::{Deserialize, Serialize};
34use serde_json::{Value, from_value, to_string};
35use tokio::time;
36use tracing::{info, warn};
37
38use crate::error::OperationError;
39#[cfg(feature = "prometheus")]
40use crate::metric_names;
41use crate::provider::{AgentConfig, AgentOutput, AgentProvider, DebugMessage, LogSink};
42use crate::retry::RetryPolicy;
43
/// Namespace for the model identifier strings that can be passed to
/// [`Agent::model`].
///
/// `Model` is a unit struct that carries no data; it exists only to group the
/// associated constants below. [`Agent::model`] accepts any string, so these
/// constants are a convenience rather than an exhaustive list.
pub struct Model;

impl Model {
    // NOTE(review): the three short names below look like version-less
    // aliases that the provider resolves itself — confirm against the provider.

    /// The identifier `"sonnet"`.
    pub const SONNET: &str = "sonnet";
    /// The identifier `"opus"`.
    pub const OPUS: &str = "opus";
    /// The identifier `"haiku"`.
    pub const HAIKU: &str = "haiku";

    /// The pinned identifier `"claude-haiku-4-5-20251001"`.
    pub const HAIKU_45: &str = "claude-haiku-4-5-20251001";

    /// The pinned identifier `"claude-sonnet-4-6"`.
    pub const SONNET_46: &str = "claude-sonnet-4-6";
    /// The pinned identifier `"claude-opus-4-6"`.
    pub const OPUS_46: &str = "claude-opus-4-6";

    // NOTE(review): the `[1m]` suffix presumably selects a 1M-token context
    // variant of the same model — confirm against the provider documentation.

    /// The identifier `"claude-sonnet-4-6[1m]"`.
    pub const SONNET_46_1M: &str = "claude-sonnet-4-6[1m]";
    /// The identifier `"claude-opus-4-6[1m]"`.
    pub const OPUS_46_1M: &str = "claude-opus-4-6[1m]";

    /// The pinned identifier `"claude-opus-4-7"`.
    pub const OPUS_47: &str = "claude-opus-4-7";
    /// The identifier `"claude-opus-4-7[1m]"`.
    pub const OPUS_47_1M: &str = "claude-opus-4-7[1m]";
}
112
113#[derive(Debug, Default, Clone, Copy, Serialize)]
118pub enum PermissionMode {
119 #[default]
121 Default,
122 Auto,
124 DontAsk,
126 BypassPermissions,
130}
131
132impl<'de> Deserialize<'de> for PermissionMode {
133 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
134 where
135 D: serde::Deserializer<'de>,
136 {
137 let s = String::deserialize(deserializer)?;
138 Ok(match s.to_lowercase().replace('_', "").as_str() {
139 "auto" => Self::Auto,
140 "dontask" => Self::DontAsk,
141 "bypass" | "bypasspermissions" => Self::BypassPermissions,
142 _ => Self::Default,
143 })
144 }
145}
146
/// Builder for a single agent invocation.
///
/// Configure it with the chained setter methods, then call [`Agent::run`]
/// with a provider. Every setter consumes and returns the builder, and the
/// `#[must_use]` attribute flags builders that are created but never run.
#[must_use = "an Agent does nothing until .run() is awaited"]
pub struct Agent {
    /// Configuration handed to the provider on every invocation.
    config: AgentConfig,
    /// Explicit dry-run choice; `None` when the caller never set one. The
    /// value is resolved through `crate::dry_run::effective_dry_run` in `run`.
    dry_run: Option<bool>,
    /// Explicit retry policy; `None` when the caller never set one.
    retry_policy: Option<RetryPolicy>,
    /// Optional sink; when present the provider is invoked through
    /// `invoke_with_logs` instead of `invoke`.
    log_sink: Option<Arc<dyn LogSink>>,
}
183
184impl Agent {
185 pub fn new() -> Self {
190 Self {
191 config: AgentConfig::new(""),
192 dry_run: None,
193 retry_policy: None,
194 log_sink: None,
195 }
196 }
197
198 pub fn from_config(config: impl Into<AgentConfig>) -> Self {
217 Self {
218 config: config.into(),
219 dry_run: None,
220 retry_policy: None,
221 log_sink: None,
222 }
223 }
224
225 pub fn system_prompt(mut self, prompt: &str) -> Self {
227 self.config.system_prompt = Some(prompt.to_string());
228 self
229 }
230
231 pub fn prompt(mut self, prompt: &str) -> Self {
233 self.config.prompt = prompt.to_string();
234 self
235 }
236
237 pub fn model(mut self, model: impl Into<String>) -> Self {
244 self.config.model = model.into();
245 self
246 }
247
248 pub fn allowed_tools(mut self, tools: &[&str]) -> Self {
253 self.config.allowed_tools = tools.iter().map(|s| s.to_string()).collect();
254 self
255 }
256
257 pub fn max_turns(mut self, turns: u32) -> Self {
263 assert!(turns > 0, "max_turns must be greater than 0");
264 self.config.max_turns = Some(turns);
265 self
266 }
267
268 pub fn max_budget_usd(mut self, budget: f64) -> Self {
274 assert!(
275 budget.is_finite() && budget > 0.0,
276 "budget must be a positive finite number, got {budget}"
277 );
278 self.config.max_budget_usd = Some(budget);
279 self
280 }
281
282 pub fn working_dir(mut self, dir: &str) -> Self {
284 self.config.working_dir = Some(dir.to_string());
285 self
286 }
287
288 pub fn mcp_config(mut self, config: &str) -> Self {
290 self.config.mcp_config = Some(config.to_string());
291 self
292 }
293
294 pub fn permission_mode(mut self, mode: PermissionMode) -> Self {
298 self.config.permission_mode = mode;
299 self
300 }
301
302 pub fn output<T: JsonSchema>(mut self) -> Self {
333 let schema = schema_for!(T);
334 self.config.json_schema = match to_string(&schema) {
335 Ok(s) => Some(s),
336 Err(e) => {
337 warn!(error = %e, type_name = any::type_name::<T>(), "failed to serialize JSON schema, structured output disabled");
338 None
339 }
340 };
341 self
342 }
343
344 pub fn output_schema_raw(mut self, schema: &str) -> Self {
368 self.config.json_schema = Some(schema.to_string());
369 self
370 }
371
372 pub fn retry(mut self, max_retries: u32) -> Self {
400 self.retry_policy = Some(RetryPolicy::new(max_retries));
401 self
402 }
403
404 pub fn retry_policy(mut self, policy: RetryPolicy) -> Self {
431 self.retry_policy = Some(policy);
432 self
433 }
434
435 pub fn dry_run(mut self, enabled: bool) -> Self {
444 self.dry_run = Some(enabled);
445 self
446 }
447
448 pub fn log_sink(mut self, sink: Arc<dyn LogSink>) -> Self {
475 self.log_sink = Some(sink);
476 self
477 }
478
479 pub fn verbose(mut self) -> Self {
509 self.config.verbose = true;
510 self
511 }
512
513 pub fn resume(mut self, session_id: &str) -> Self {
549 assert!(!session_id.is_empty(), "session_id must not be empty");
550 assert!(
551 session_id
552 .chars()
553 .all(|c| c.is_ascii_alphanumeric() || c == '-' || c == '_'),
554 "session_id must only contain alphanumeric characters, hyphens, or underscores, got: {session_id}"
555 );
556 self.config.resume_session_id = Some(session_id.to_string());
557 self
558 }
559
560 #[tracing::instrument(name = "agent", skip_all, fields(model = %self.config.model, prompt_len = self.config.prompt.len()))]
581 pub async fn run(self, provider: &dyn AgentProvider) -> Result<AgentResult, OperationError> {
582 assert!(
583 !self.config.prompt.trim().is_empty(),
584 "prompt must not be empty - call .prompt(\"...\") before .run()"
585 );
586
587 if crate::dry_run::effective_dry_run(self.dry_run) {
588 info!(
589 prompt_len = self.config.prompt.len(),
590 "[dry-run] agent call skipped"
591 );
592 let mut output =
593 AgentOutput::new(Value::String("[dry-run] agent call skipped".to_string()));
594 output.cost_usd = Some(0.0);
595 output.input_tokens = Some(0);
596 output.output_tokens = Some(0);
597 return Ok(AgentResult { output });
598 }
599
600 let result = self.invoke_once(provider).await;
601
602 let default_schema_retry = RetryPolicy::new(2);
603 let policy = match &self.retry_policy {
604 Some(p) => p,
605 None if self.config.json_schema.is_some() => &default_schema_retry,
606 None => return result,
607 };
608
609 if let Err(ref err) = result {
611 if !crate::retry::is_retryable(err) {
612 return result;
613 }
614 } else {
615 return result;
616 }
617
618 let mut last_result = result;
619
620 for attempt in 0..policy.max_retries {
621 let delay = policy.delay_for_attempt(attempt);
622 let retry_reason = if matches!(
623 &last_result,
624 Err(OperationError::Agent(
625 crate::error::AgentError::SchemaValidation { .. }
626 ))
627 ) {
628 "structured_output was null (CLI non-determinism)"
629 } else {
630 "transient failure"
631 };
632 warn!(
633 attempt = attempt + 1,
634 max_retries = policy.max_retries,
635 delay_ms = delay.as_millis() as u64,
636 reason = retry_reason,
637 "retrying agent invocation"
638 );
639 time::sleep(delay).await;
640
641 last_result = self.invoke_once(provider).await;
642
643 match &last_result {
644 Ok(_) => return last_result,
645 Err(err) if !crate::retry::is_retryable(err) => return last_result,
646 _ => {}
647 }
648 }
649
650 last_result
651 }
652
653 async fn invoke_once(
655 &self,
656 provider: &dyn AgentProvider,
657 ) -> Result<AgentResult, OperationError> {
658 #[cfg(feature = "prometheus")]
659 let model_label = self.config.model.to_string();
660
661 let invoke_result = match self.log_sink {
662 Some(ref sink) => provider.invoke_with_logs(&self.config, sink.clone()).await,
663 None => provider.invoke(&self.config).await,
664 };
665 let output = match invoke_result {
666 Ok(output) => output,
667 Err(e) => {
668 #[cfg(feature = "prometheus")]
669 {
670 metrics::counter!(metric_names::AGENT_TOTAL, "model" => model_label.clone(), "status" => metric_names::STATUS_ERROR).increment(1);
671 }
672 return Err(OperationError::Agent(e));
673 }
674 };
675
676 info!(
677 duration_ms = output.duration_ms,
678 cost_usd = output.cost_usd,
679 input_tokens = output.input_tokens,
680 output_tokens = output.output_tokens,
681 model = output.model,
682 "agent completed"
683 );
684
685 #[cfg(feature = "prometheus")]
686 {
687 metrics::counter!(metric_names::AGENT_TOTAL, "model" => model_label.clone(), "status" => metric_names::STATUS_SUCCESS).increment(1);
688 metrics::histogram!(metric_names::AGENT_DURATION_SECONDS, "model" => model_label.clone())
689 .record(output.duration_ms as f64 / 1000.0);
690 if let Some(cost) = output.cost_usd {
691 metrics::gauge!(metric_names::AGENT_COST_USD_TOTAL, "model" => model_label.clone())
692 .increment(cost);
693 }
694 if let Some(tokens) = output.input_tokens {
695 metrics::counter!(metric_names::AGENT_TOKENS_INPUT_TOTAL, "model" => model_label.clone()).increment(tokens);
696 }
697 if let Some(tokens) = output.output_tokens {
698 metrics::counter!(metric_names::AGENT_TOKENS_OUTPUT_TOTAL, "model" => model_label)
699 .increment(tokens);
700 }
701 }
702
703 Ok(AgentResult { output })
704 }
705}
706
707impl Default for Agent {
708 fn default() -> Self {
709 Self::new()
710 }
711}
712
/// Result of a successful [`Agent::run`].
///
/// Wraps the provider's output and exposes it through read-only accessors.
#[derive(Debug)]
pub struct AgentResult {
    /// Raw provider output (value plus usage metadata).
    output: AgentOutput,
}
721
722impl AgentResult {
723 pub fn text(&self) -> &str {
728 match self.output.value.as_str() {
729 Some(s) => s,
730 None => {
731 warn!(
732 value_type = self.output.value.to_string(),
733 "agent output is not a string, returning empty"
734 );
735 ""
736 }
737 }
738 }
739
740 pub fn value(&self) -> &Value {
742 &self.output.value
743 }
744
745 pub fn json<T: DeserializeOwned>(&self) -> Result<T, OperationError> {
755 from_value(self.output.value.clone()).map_err(OperationError::deserialize::<T>)
756 }
757
758 pub fn into_json<T: DeserializeOwned>(self) -> Result<T, OperationError> {
764 from_value(self.output.value).map_err(OperationError::deserialize::<T>)
765 }
766
767 #[cfg(test)]
772 pub(crate) fn from_output(output: AgentOutput) -> Self {
773 Self { output }
774 }
775
776 pub fn session_id(&self) -> Option<&str> {
778 self.output.session_id.as_deref()
779 }
780
781 pub fn cost_usd(&self) -> Option<f64> {
783 self.output.cost_usd
784 }
785
786 pub fn input_tokens(&self) -> Option<u64> {
788 self.output.input_tokens
789 }
790
791 pub fn output_tokens(&self) -> Option<u64> {
793 self.output.output_tokens
794 }
795
796 pub fn duration_ms(&self) -> u64 {
798 self.output.duration_ms
799 }
800
801 pub fn model(&self) -> Option<&str> {
803 self.output.model.as_deref()
804 }
805
806 pub fn debug_messages(&self) -> Option<&[DebugMessage]> {
812 self.output.debug_messages.as_deref()
813 }
814}
815
// Unit tests for the `Agent` builder, `AgentResult` accessors, the retry loop
// in `Agent::run`, and log-sink routing. All providers here are in-memory mocks.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::AgentError;
    use crate::provider::InvokeFuture;
    use serde_json::json;

    // Provider that always succeeds with a copy of a canned output
    // (debug messages are always returned as `None`).
    struct TestProvider {
        output: AgentOutput,
    }

    impl AgentProvider for TestProvider {
        fn invoke<'a>(&'a self, _config: &'a AgentConfig) -> InvokeFuture<'a> {
            Box::pin(async move {
                Ok(AgentOutput {
                    value: self.output.value.clone(),
                    session_id: self.output.session_id.clone(),
                    cost_usd: self.output.cost_usd,
                    input_tokens: self.output.input_tokens,
                    output_tokens: self.output.output_tokens,
                    model: self.output.model.clone(),
                    duration_ms: self.output.duration_ms,
                    debug_messages: None,
                })
            })
        }
    }

    // Provider that echoes the serialized `AgentConfig` back as the output
    // value, so tests can inspect exactly what the builder produced.
    struct ConfigCapture {
        output: AgentOutput,
    }

    impl AgentProvider for ConfigCapture {
        fn invoke<'a>(&'a self, config: &'a AgentConfig) -> InvokeFuture<'a> {
            // Serialize eagerly, before the future is created.
            let config_json = serde_json::to_value(config).unwrap();
            Box::pin(async move {
                Ok(AgentOutput {
                    value: config_json,
                    session_id: self.output.session_id.clone(),
                    cost_usd: self.output.cost_usd,
                    input_tokens: self.output.input_tokens,
                    output_tokens: self.output.output_tokens,
                    model: self.output.model.clone(),
                    duration_ms: self.output.duration_ms,
                    debug_messages: None,
                })
            })
        }
    }

    // Canned output shared by most tests.
    fn default_output() -> AgentOutput {
        AgentOutput {
            value: json!("test output"),
            session_id: Some("sess-123".to_string()),
            cost_usd: Some(0.05),
            input_tokens: Some(100),
            output_tokens: Some(50),
            model: Some("sonnet".to_string()),
            duration_ms: 1500,
            debug_messages: None,
        }
    }

    // --- Model constants ---

    // Pins the literal value of every `Model` constant.
    #[test]
    fn model_constants_have_expected_values() {
        assert_eq!(Model::SONNET, "sonnet");
        assert_eq!(Model::OPUS, "opus");
        assert_eq!(Model::HAIKU, "haiku");
        assert_eq!(Model::HAIKU_45, "claude-haiku-4-5-20251001");
        assert_eq!(Model::SONNET_46, "claude-sonnet-4-6");
        assert_eq!(Model::OPUS_46, "claude-opus-4-6");
        assert_eq!(Model::SONNET_46_1M, "claude-sonnet-4-6[1m]");
        assert_eq!(Model::OPUS_46_1M, "claude-opus-4-6[1m]");
        assert_eq!(Model::OPUS_47, "claude-opus-4-7");
        assert_eq!(Model::OPUS_47_1M, "claude-opus-4-7[1m]");
    }

    // --- Builder defaults and setters ---

    // A fresh builder with only a prompt leaves every optional field unset.
    #[tokio::test]
    async fn agent_new_default_values() {
        let provider = ConfigCapture {
            output: default_output(),
        };
        let result = Agent::new().prompt("hi").run(&provider).await.unwrap();

        let config = result.value();
        assert_eq!(config["system_prompt"], json!(null));
        assert_eq!(config["prompt"], json!("hi"));
        assert_eq!(config["model"], json!("sonnet"));
        assert_eq!(config["allowed_tools"], json!([]));
        assert_eq!(config["max_turns"], json!(null));
        assert_eq!(config["max_budget_usd"], json!(null));
        assert_eq!(config["working_dir"], json!(null));
        assert_eq!(config["mcp_config"], json!(null));
        assert_eq!(config["permission_mode"], json!("Default"));
        assert_eq!(config["json_schema"], json!(null));
    }

    // `Agent::default()` and `Agent::new()` produce the same configuration.
    #[tokio::test]
    async fn agent_default_matches_new() {
        let provider = ConfigCapture {
            output: default_output(),
        };
        let result_new = Agent::new().prompt("x").run(&provider).await.unwrap();
        let result_default = Agent::default().prompt("x").run(&provider).await.unwrap();

        assert_eq!(result_new.value(), result_default.value());
    }

    // Every setter ends up in the configuration seen by the provider.
    #[tokio::test]
    async fn builder_methods_store_values_correctly() {
        let provider = ConfigCapture {
            output: default_output(),
        };
        let result = Agent::new()
            .system_prompt("you are a bot")
            .prompt("do something")
            .model(Model::OPUS)
            .allowed_tools(&["Read", "Write"])
            .max_turns(5)
            .max_budget_usd(1.5)
            .working_dir("/tmp")
            .mcp_config("{}")
            .permission_mode(PermissionMode::Auto)
            .run(&provider)
            .await
            .unwrap();

        let config = result.value();
        assert_eq!(config["system_prompt"], json!("you are a bot"));
        assert_eq!(config["prompt"], json!("do something"));
        assert_eq!(config["model"], json!("opus"));
        assert_eq!(config["allowed_tools"], json!(["Read", "Write"]));
        assert_eq!(config["max_turns"], json!(5));
        assert_eq!(config["max_budget_usd"], json!(1.5));
        assert_eq!(config["working_dir"], json!("/tmp"));
        assert_eq!(config["mcp_config"], json!("{}"));
        assert_eq!(config["permission_mode"], json!("Auto"));
    }

    // --- Builder argument validation (panics) ---

    #[test]
    #[should_panic(expected = "max_turns must be greater than 0")]
    fn max_turns_zero_panics() {
        let _ = Agent::new().max_turns(0);
    }

    #[test]
    #[should_panic(expected = "budget must be a positive finite number")]
    fn max_budget_negative_panics() {
        let _ = Agent::new().max_budget_usd(-1.0);
    }

    #[test]
    #[should_panic(expected = "budget must be a positive finite number")]
    fn max_budget_nan_panics() {
        let _ = Agent::new().max_budget_usd(f64::NAN);
    }

    #[test]
    #[should_panic(expected = "budget must be a positive finite number")]
    fn max_budget_infinity_panics() {
        let _ = Agent::new().max_budget_usd(f64::INFINITY);
    }

    // --- AgentResult accessors ---

    // `text()` returns the string when the output value is a JSON string.
    #[tokio::test]
    async fn agent_result_text_with_string_value() {
        let provider = TestProvider {
            output: AgentOutput {
                value: json!("hello world"),
                ..default_output()
            },
        };
        let result = Agent::new().prompt("test").run(&provider).await.unwrap();
        assert_eq!(result.text(), "hello world");
    }

    // `text()` falls back to "" for a non-string value.
    #[tokio::test]
    async fn agent_result_text_with_non_string_value() {
        let provider = TestProvider {
            output: AgentOutput {
                value: json!(42),
                ..default_output()
            },
        };
        let result = Agent::new().prompt("test").run(&provider).await.unwrap();
        assert_eq!(result.text(), "");
    }

    // `text()` falls back to "" for JSON null as well.
    #[tokio::test]
    async fn agent_result_text_with_null_value() {
        let provider = TestProvider {
            output: AgentOutput {
                value: json!(null),
                ..default_output()
            },
        };
        let result = Agent::new().prompt("test").run(&provider).await.unwrap();
        assert_eq!(result.text(), "");
    }

    // `json()` deserializes a matching object.
    #[tokio::test]
    async fn agent_result_json_successful_deserialize() {
        #[derive(Deserialize, PartialEq, Debug)]
        struct MyOutput {
            name: String,
            count: u32,
        }
        let provider = TestProvider {
            output: AgentOutput {
                value: json!({"name": "test", "count": 7}),
                ..default_output()
            },
        };
        let result = Agent::new().prompt("test").run(&provider).await.unwrap();
        let parsed: MyOutput = result.json().unwrap();
        assert_eq!(parsed.name, "test");
        assert_eq!(parsed.count, 7);
    }

    // `json()` reports `OperationError::Deserialize` on a shape mismatch.
    #[tokio::test]
    async fn agent_result_json_failed_deserialize() {
        #[derive(Debug, Deserialize)]
        #[allow(dead_code)]
        struct MyOutput {
            name: String,
        }
        let provider = TestProvider {
            output: AgentOutput {
                value: json!(42),
                ..default_output()
            },
        };
        let result = Agent::new().prompt("test").run(&provider).await.unwrap();
        let err = result.json::<MyOutput>().unwrap_err();
        assert!(matches!(err, OperationError::Deserialize { .. }));
    }

    // Each metadata accessor returns the value supplied by the provider.
    #[tokio::test]
    async fn agent_result_accessors() {
        let provider = TestProvider {
            output: AgentOutput {
                value: json!("v"),
                session_id: Some("s-1".to_string()),
                cost_usd: Some(0.123),
                input_tokens: Some(999),
                output_tokens: Some(456),
                model: Some("opus".to_string()),
                duration_ms: 2000,
                debug_messages: None,
            },
        };
        let result = Agent::new().prompt("test").run(&provider).await.unwrap();
        assert_eq!(result.session_id(), Some("s-1"));
        assert_eq!(result.cost_usd(), Some(0.123));
        assert_eq!(result.input_tokens(), Some(999));
        assert_eq!(result.output_tokens(), Some(456));
        assert_eq!(result.duration_ms(), 2000);
        assert_eq!(result.model(), Some("opus"));
    }

    // --- Session resume ---

    // `resume()` stores the session id in the configuration.
    #[tokio::test]
    async fn resume_passes_session_id_in_config() {
        let provider = ConfigCapture {
            output: default_output(),
        };
        let result = Agent::new()
            .prompt("followup")
            .resume("sess-abc")
            .run(&provider)
            .await
            .unwrap();

        let config = result.value();
        assert_eq!(config["resume_session_id"], json!("sess-abc"));
    }

    // Without `resume()` the session id serializes as null.
    #[tokio::test]
    async fn no_resume_has_null_session_id() {
        let provider = ConfigCapture {
            output: default_output(),
        };
        let result = Agent::new()
            .prompt("first call")
            .run(&provider)
            .await
            .unwrap();

        let config = result.value();
        assert_eq!(config["resume_session_id"], json!(null));
    }

    #[test]
    #[should_panic(expected = "session_id must not be empty")]
    fn resume_empty_session_id_panics() {
        let _ = Agent::new().resume("");
    }

    // Characters outside [A-Za-z0-9_-] are rejected.
    #[test]
    #[should_panic(expected = "session_id must only contain")]
    fn resume_invalid_chars_panics() {
        let _ = Agent::new().resume("sess;rm -rf /");
    }

    // Letters, digits, hyphens and underscores are all accepted.
    #[test]
    fn resume_valid_formats_accepted() {
        let _ = Agent::new().resume("sess-abc123");
        let _ = Agent::new().resume("a1b2c3d4_session");
        let _ = Agent::new().resume("abc-DEF-123_456");
    }

    // --- Prompt validation in run() ---

    #[tokio::test]
    #[should_panic(expected = "prompt must not be empty")]
    async fn run_without_prompt_panics() {
        let provider = TestProvider {
            output: default_output(),
        };
        let _ = Agent::new().run(&provider).await;
    }

    // A whitespace-only prompt counts as empty.
    #[tokio::test]
    #[should_panic(expected = "prompt must not be empty")]
    async fn run_with_whitespace_only_prompt_panics() {
        let provider = TestProvider {
            output: default_output(),
        };
        let _ = Agent::new().prompt("   ").run(&provider).await;
    }

    // --- Model / verbose settings ---

    // Arbitrary model strings are passed through unchanged.
    #[tokio::test]
    async fn model_accepts_custom_string() {
        let provider = ConfigCapture {
            output: default_output(),
        };
        let result = Agent::new()
            .prompt("hi")
            .model("mistral-large-latest")
            .run(&provider)
            .await
            .unwrap();
        assert_eq!(result.value()["model"], json!("mistral-large-latest"));
    }

    #[tokio::test]
    async fn verbose_sets_config_flag() {
        let provider = ConfigCapture {
            output: default_output(),
        };
        let result = Agent::new()
            .prompt("hi")
            .verbose()
            .run(&provider)
            .await
            .unwrap();
        assert_eq!(result.value()["verbose"], json!(true));
    }

    #[tokio::test]
    async fn verbose_not_set_by_default() {
        let provider = ConfigCapture {
            output: default_output(),
        };
        let result = Agent::new().prompt("hi").run(&provider).await.unwrap();
        assert_eq!(result.value()["verbose"], json!(false));
    }

    // The mock provider returns no debug messages, so the accessor is `None`.
    #[tokio::test]
    async fn debug_messages_none_without_verbose() {
        let provider = TestProvider {
            output: default_output(),
        };
        let result = Agent::new().prompt("test").run(&provider).await.unwrap();
        assert!(result.debug_messages().is_none());
    }

    // `model()` accepts an owned `String` as well as `&str`.
    #[tokio::test]
    async fn model_accepts_owned_string() {
        let provider = ConfigCapture {
            output: default_output(),
        };
        let model_name = String::from("gpt-4o");
        let result = Agent::new()
            .prompt("hi")
            .model(model_name)
            .run(&provider)
            .await
            .unwrap();
        assert_eq!(result.value()["model"], json!("gpt-4o"));
    }

    // --- into_json / from_output ---

    #[tokio::test]
    async fn into_json_success() {
        #[derive(Deserialize, PartialEq, Debug)]
        struct Out {
            name: String,
        }
        let provider = TestProvider {
            output: AgentOutput {
                value: json!({"name": "test"}),
                ..default_output()
            },
        };
        let result = Agent::new().prompt("test").run(&provider).await.unwrap();
        let parsed: Out = result.into_json().unwrap();
        assert_eq!(parsed.name, "test");
    }

    #[tokio::test]
    async fn into_json_failure() {
        #[derive(Debug, Deserialize)]
        #[allow(dead_code)]
        struct Out {
            name: String,
        }
        let provider = TestProvider {
            output: AgentOutput {
                value: json!(42),
                ..default_output()
            },
        };
        let result = Agent::new().prompt("test").run(&provider).await.unwrap();
        let err = result.into_json::<Out>().unwrap_err();
        assert!(matches!(err, OperationError::Deserialize { .. }));
    }

    // The test-only constructor wraps the output without running an agent.
    #[test]
    fn from_output_creates_result() {
        let output = AgentOutput {
            value: json!("hello"),
            ..default_output()
        };
        let result = AgentResult::from_output(output);
        assert_eq!(result.text(), "hello");
        assert_eq!(result.cost_usd(), Some(0.05));
    }

    // Zero is rejected too: the budget must be strictly positive.
    #[test]
    #[should_panic(expected = "budget must be a positive finite number")]
    fn max_budget_zero_panics() {
        let _ = Agent::new().max_budget_usd(0.0);
    }

    #[test]
    fn model_constant_equality() {
        assert_eq!(Model::SONNET, "sonnet");
        assert_ne!(Model::SONNET, Model::OPUS);
    }

    // Every variant survives serialize -> deserialize; variants are compared
    // through their `Debug` output.
    #[test]
    fn permission_mode_serialize_deserialize_roundtrip() {
        for mode in [
            PermissionMode::Default,
            PermissionMode::Auto,
            PermissionMode::DontAsk,
            PermissionMode::BypassPermissions,
        ] {
            let json = to_string(&mode).unwrap();
            let back: PermissionMode = serde_json::from_str(&json).unwrap();
            assert_eq!(format!("{:?}", mode), format!("{:?}", back));
        }
    }

    // --- Retry builder ---

    #[test]
    fn retry_builder_stores_policy() {
        let agent = Agent::new().retry(3);
        assert!(agent.retry_policy.is_some());
        assert_eq!(agent.retry_policy.unwrap().max_retries(), 3);
    }

    #[test]
    fn retry_policy_builder_stores_custom_policy() {
        use crate::retry::RetryPolicy;
        let policy = RetryPolicy::new(5).backoff(Duration::from_secs(1));
        let agent = Agent::new().retry_policy(policy);
        let p = agent.retry_policy.unwrap();
        assert_eq!(p.max_retries(), 5);
    }

    #[test]
    fn no_retry_by_default() {
        let agent = Agent::new();
        assert!(agent.retry_policy.is_none());
    }

    // --- Retry behaviour ---

    // Imports used by the retry and log-sink tests below.
    use std::sync::Arc;
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::time::Duration;

    // Provider that fails with `ProcessFailed` for the first
    // `failures_before_success` calls and succeeds afterwards.
    // `fail_count` counts every call, successful or not.
    struct FailNTimesProvider {
        fail_count: AtomicU32,
        failures_before_success: u32,
        output: AgentOutput,
    }

    impl AgentProvider for FailNTimesProvider {
        fn invoke<'a>(&'a self, _config: &'a AgentConfig) -> InvokeFuture<'a> {
            Box::pin(async move {
                let current = self.fail_count.fetch_add(1, Ordering::SeqCst);
                if current < self.failures_before_success {
                    Err(AgentError::ProcessFailed {
                        exit_code: 1,
                        stderr: format!("transient failure #{}", current + 1),
                    })
                } else {
                    Ok(AgentOutput {
                        value: self.output.value.clone(),
                        session_id: self.output.session_id.clone(),
                        cost_usd: self.output.cost_usd,
                        input_tokens: self.output.input_tokens,
                        output_tokens: self.output.output_tokens,
                        model: self.output.model.clone(),
                        duration_ms: self.output.duration_ms,
                        debug_messages: None,
                    })
                }
            })
        }
    }

    // Two failures followed by a success fit within three allowed retries.
    #[tokio::test]
    async fn retry_succeeds_after_transient_failures() {
        let provider = FailNTimesProvider {
            fail_count: AtomicU32::new(0),
            failures_before_success: 2,
            output: default_output(),
        };
        let result = Agent::new()
            .prompt("test")
            .retry_policy(crate::retry::RetryPolicy::new(3).backoff(Duration::from_millis(1)))
            .run(&provider)
            .await;

        assert!(result.is_ok());
        // 3 calls total: 2 failures + 1 success.
        assert_eq!(provider.fail_count.load(Ordering::SeqCst), 3);
    }

    // When every attempt fails, the error from the last attempt is returned.
    #[tokio::test]
    async fn retry_exhausted_returns_last_error() {
        let provider = FailNTimesProvider {
            fail_count: AtomicU32::new(0),
            // More failures than the policy allows attempts.
            failures_before_success: 10,
            output: default_output(),
        };
        let result = Agent::new()
            .prompt("test")
            .retry_policy(crate::retry::RetryPolicy::new(2).backoff(Duration::from_millis(1)))
            .run(&provider)
            .await;

        assert!(result.is_err());
        // 3 calls total: 1 initial attempt + 2 retries.
        assert_eq!(provider.fail_count.load(Ordering::SeqCst), 3);
    }

    // `PromptTooLarge` is not retried: the provider is called exactly once.
    #[tokio::test]
    async fn retry_does_not_retry_prompt_too_large() {
        let call_count = Arc::new(AtomicU32::new(0));
        let count = call_count.clone();

        struct CountingNonRetryable {
            count: Arc<AtomicU32>,
        }
        impl AgentProvider for CountingNonRetryable {
            fn invoke<'a>(&'a self, _config: &'a AgentConfig) -> InvokeFuture<'a> {
                self.count.fetch_add(1, Ordering::SeqCst);
                Box::pin(async move {
                    Err(AgentError::PromptTooLarge {
                        chars: 1_000_000,
                        estimated_tokens: 250_000,
                        model_limit: 200_000,
                    })
                })
            }
        }

        let provider = CountingNonRetryable { count };
        let result = Agent::new()
            .prompt("test")
            .retry_policy(crate::retry::RetryPolicy::new(3).backoff(Duration::from_millis(1)))
            .run(&provider)
            .await;

        assert!(result.is_err());
        assert_eq!(call_count.load(Ordering::SeqCst), 1);
    }

    // `SchemaValidation` errors are retried until the policy is exhausted.
    #[tokio::test]
    async fn retry_retries_schema_validation_errors() {
        let call_count = Arc::new(AtomicU32::new(0));
        let count = call_count.clone();

        struct SchemaFailProvider {
            count: Arc<AtomicU32>,
        }
        impl AgentProvider for SchemaFailProvider {
            fn invoke<'a>(&'a self, _config: &'a AgentConfig) -> InvokeFuture<'a> {
                self.count.fetch_add(1, Ordering::SeqCst);
                Box::pin(async move {
                    Err(AgentError::SchemaValidation {
                        expected: "object".to_string(),
                        got: "null".to_string(),
                        debug_messages: Vec::new(),
                        partial_usage: Box::default(),
                        raw_response: None,
                    })
                })
            }
        }

        let provider = SchemaFailProvider { count };
        let result = Agent::new()
            .prompt("test")
            .retry_policy(crate::retry::RetryPolicy::new(2).backoff(Duration::from_millis(1)))
            .run(&provider)
            .await;

        assert!(result.is_err());
        // 3 calls total: 1 initial attempt + 2 retries.
        assert_eq!(call_count.load(Ordering::SeqCst), 3);
    }

    // A schema failure on the first call is recovered by a single retry.
    #[tokio::test]
    async fn schema_validation_succeeds_on_retry() {
        let call_count = Arc::new(AtomicU32::new(0));
        let count = call_count.clone();

        struct SchemaFailThenSucceed {
            count: Arc<AtomicU32>,
            output: AgentOutput,
        }
        impl AgentProvider for SchemaFailThenSucceed {
            fn invoke<'a>(&'a self, _config: &'a AgentConfig) -> InvokeFuture<'a> {
                let current = self.count.fetch_add(1, Ordering::SeqCst);
                let output = self.output.clone();
                Box::pin(async move {
                    // Only the very first call fails.
                    if current == 0 {
                        Err(AgentError::SchemaValidation {
                            expected: "structured_output field".to_string(),
                            got: "null".to_string(),
                            debug_messages: Vec::new(),
                            partial_usage: Box::default(),
                            raw_response: None,
                        })
                    } else {
                        Ok(output)
                    }
                })
            }
        }

        let provider = SchemaFailThenSucceed {
            count,
            output: default_output(),
        };
        let result = Agent::new()
            .prompt("test")
            .retry_policy(crate::retry::RetryPolicy::new(1).backoff(Duration::from_millis(1)))
            .run(&provider)
            .await;

        assert!(result.is_ok());
        assert_eq!(call_count.load(Ordering::SeqCst), 2);
    }

    // With a JSON schema set and no explicit policy, `run` falls back to
    // `RetryPolicy::new(2)`.
    #[tokio::test]
    async fn auto_retry_applied_when_json_schema_set() {
        let call_count = Arc::new(AtomicU32::new(0));
        let count = call_count.clone();

        struct AlwaysSchemaFail {
            count: Arc<AtomicU32>,
        }
        impl AgentProvider for AlwaysSchemaFail {
            fn invoke<'a>(&'a self, _config: &'a AgentConfig) -> InvokeFuture<'a> {
                self.count.fetch_add(1, Ordering::SeqCst);
                Box::pin(async move {
                    Err(AgentError::SchemaValidation {
                        expected: "object".to_string(),
                        got: "null".to_string(),
                        debug_messages: Vec::new(),
                        partial_usage: Box::default(),
                        raw_response: None,
                    })
                })
            }
        }

        let provider = AlwaysSchemaFail { count };
        let result = Agent::new()
            .prompt("test")
            .output_schema_raw(r#"{"type":"object"}"#)
            .run(&provider)
            .await;

        assert!(result.is_err());
        // 3 calls total: 1 initial attempt + 2 default retries.
        assert_eq!(call_count.load(Ordering::SeqCst), 3);
    }

    // Without a policy and without a schema, a failure is returned immediately.
    #[tokio::test]
    async fn no_retry_without_policy() {
        let provider = FailNTimesProvider {
            fail_count: AtomicU32::new(0),
            failures_before_success: 1,
            output: default_output(),
        };
        let result = Agent::new().prompt("test").run(&provider).await;

        assert!(result.is_err());
        assert_eq!(provider.fail_count.load(Ordering::SeqCst), 1);
    }

    // --- Log sink routing ---

    use crate::test_support::VecSink;

    // Provider that counts calls to `invoke_with_logs` in `saw_logs` and
    // writes one line to the sink before delegating to `invoke`.
    struct SinkCapture {
        output: AgentOutput,
        saw_logs: Arc<AtomicU32>,
    }

    impl AgentProvider for SinkCapture {
        fn invoke<'a>(&'a self, _config: &'a AgentConfig) -> InvokeFuture<'a> {
            Box::pin(async {
                Ok(AgentOutput {
                    value: self.output.value.clone(),
                    session_id: self.output.session_id.clone(),
                    cost_usd: self.output.cost_usd,
                    input_tokens: self.output.input_tokens,
                    output_tokens: self.output.output_tokens,
                    model: self.output.model.clone(),
                    duration_ms: self.output.duration_ms,
                    debug_messages: None,
                })
            })
        }

        fn invoke_with_logs<'a>(
            &'a self,
            config: &'a AgentConfig,
            log_sink: Arc<dyn LogSink>,
        ) -> InvokeFuture<'a> {
            self.saw_logs.fetch_add(1, Ordering::SeqCst);
            log_sink.log("stdout", "streaming line");
            self.invoke(config)
        }
    }

    // With a sink attached, `run` goes through `invoke_with_logs`.
    #[tokio::test]
    async fn log_sink_routes_to_invoke_with_logs() {
        let saw_logs = Arc::new(AtomicU32::new(0));
        let provider = SinkCapture {
            output: default_output(),
            saw_logs: saw_logs.clone(),
        };
        let sink: Arc<dyn LogSink> = VecSink::new();

        let result = Agent::new()
            .prompt("test")
            .log_sink(sink)
            .run(&provider)
            .await;

        assert!(result.is_ok());
        assert_eq!(saw_logs.load(Ordering::SeqCst), 1);
    }

    // Without a sink, `run` calls plain `invoke`.
    #[tokio::test]
    async fn no_log_sink_routes_to_invoke() {
        let saw_logs = Arc::new(AtomicU32::new(0));
        let provider = SinkCapture {
            output: default_output(),
            saw_logs: saw_logs.clone(),
        };

        let result = Agent::new().prompt("test").run(&provider).await;

        assert!(result.is_ok());
        assert_eq!(saw_logs.load(Ordering::SeqCst), 0);
    }

    // The line written by the provider reaches the caller's sink.
    #[tokio::test]
    async fn log_sink_receives_provider_lines() {
        let saw_logs = Arc::new(AtomicU32::new(0));
        let provider = SinkCapture {
            output: default_output(),
            saw_logs: saw_logs.clone(),
        };
        let sink = VecSink::new();

        let _ = Agent::new()
            .prompt("test")
            .log_sink(sink.clone() as Arc<dyn LogSink>)
            .run(&provider)
            .await;

        let lines = sink.0.lock().unwrap();
        assert_eq!(lines.len(), 1);
        assert_eq!(lines[0].0, "stdout");
        assert_eq!(lines[0].1, "streaming line");
    }
}