1use std::fmt::{self, Display};
29
30use schemars::{JsonSchema, schema_for};
31use serde::de::DeserializeOwned;
32use serde::{Deserialize, Serialize};
33use serde_json::{Value, from_value, to_string};
34use tracing::{info, warn};
35
36use crate::error::{AgentError, OperationError};
37#[cfg(feature = "prometheus")]
38use crate::metric_names;
39use crate::provider::{AgentConfig, AgentOutput, AgentProvider};
40use crate::retry::RetryPolicy;
41use crate::utils::estimate_tokens;
42
43#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
49#[serde(rename_all = "lowercase")]
50pub enum Model {
51 Sonnet,
54 Opus,
56 Haiku,
58
59 #[serde(rename = "claude-haiku-4-5-20251001")]
62 Haiku45,
63
64 #[serde(rename = "claude-sonnet-4-6")]
67 Sonnet46,
68 #[serde(rename = "claude-opus-4-6")]
70 Opus46,
71
72 #[serde(rename = "claude-sonnet-4-6[1m]")]
75 Sonnet46_1M,
76 #[serde(rename = "claude-opus-4-6[1m]")]
78 Opus46_1M,
79}
80
81impl Model {
82 pub fn context_window(self) -> usize {
84 match self {
85 Self::Sonnet46_1M | Self::Opus46_1M => 1_000_000,
86 _ => 200_000,
87 }
88 }
89}
90
91impl Display for Model {
92 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
93 match serde_json::to_value(self) {
94 Ok(Value::String(s)) => f.write_str(&s),
95 _ => f.write_str("unknown"),
96 }
97 }
98}
99
100#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
105pub enum PermissionMode {
106 Default,
108 Auto,
110 DontAsk,
112 BypassPermissions,
116}
117
118#[must_use = "an Agent does nothing until .run() is awaited"]
148pub struct Agent {
149 config: AgentConfig,
150 dry_run: Option<bool>,
151 retry_policy: Option<RetryPolicy>,
152}
153
154impl Agent {
155 pub fn new() -> Self {
160 Self {
161 config: AgentConfig::new(""),
162 dry_run: None,
163 retry_policy: None,
164 }
165 }
166
167 pub fn system_prompt(mut self, prompt: &str) -> Self {
169 self.config.system_prompt = Some(prompt.to_string());
170 self
171 }
172
173 pub fn prompt(mut self, prompt: &str) -> Self {
175 self.config.prompt = prompt.to_string();
176 self
177 }
178
179 pub fn model(mut self, model: Model) -> Self {
183 self.config.model = model;
184 self
185 }
186
187 pub fn allowed_tools(mut self, tools: &[&str]) -> Self {
192 self.config.allowed_tools = tools.iter().map(|s| s.to_string()).collect();
193 self
194 }
195
196 pub fn max_turns(mut self, turns: u32) -> Self {
202 assert!(turns > 0, "max_turns must be greater than 0");
203 self.config.max_turns = Some(turns);
204 self
205 }
206
207 pub fn max_budget_usd(mut self, budget: f64) -> Self {
213 assert!(
214 budget.is_finite() && budget > 0.0,
215 "budget must be a positive finite number, got {budget}"
216 );
217 self.config.max_budget_usd = Some(budget);
218 self
219 }
220
221 pub fn working_dir(mut self, dir: &str) -> Self {
223 self.config.working_dir = Some(dir.to_string());
224 self
225 }
226
227 pub fn mcp_config(mut self, config: &str) -> Self {
229 self.config.mcp_config = Some(config.to_string());
230 self
231 }
232
233 pub fn permission_mode(mut self, mode: PermissionMode) -> Self {
237 self.config.permission_mode = mode;
238 self
239 }
240
241 pub fn output<T: JsonSchema>(mut self) -> Self {
272 let schema = schema_for!(T);
273 self.config.json_schema = match to_string(&schema) {
274 Ok(s) => Some(s),
275 Err(e) => {
276 warn!(error = %e, type_name = std::any::type_name::<T>(), "failed to serialize JSON schema, structured output disabled");
277 None
278 }
279 };
280 self
281 }
282
283 pub fn retry(mut self, max_retries: u32) -> Self {
311 self.retry_policy = Some(RetryPolicy::new(max_retries));
312 self
313 }
314
315 pub fn retry_policy(mut self, policy: RetryPolicy) -> Self {
342 self.retry_policy = Some(policy);
343 self
344 }
345
346 pub fn dry_run(mut self, enabled: bool) -> Self {
355 self.dry_run = Some(enabled);
356 self
357 }
358
359 pub fn resume(mut self, session_id: &str) -> Self {
395 assert!(!session_id.is_empty(), "session_id must not be empty");
396 assert!(
397 session_id
398 .chars()
399 .all(|c| c.is_ascii_alphanumeric() || c == '-' || c == '_'),
400 "session_id must only contain alphanumeric characters, hyphens, or underscores, got: {session_id}"
401 );
402 self.config.resume_session_id = Some(session_id.to_string());
403 self
404 }
405
406 #[tracing::instrument(name = "agent", skip_all, fields(model = %self.config.model, prompt_len = self.config.prompt.len()))]
423 pub async fn run(self, provider: &dyn AgentProvider) -> Result<AgentResult, OperationError> {
424 assert!(
425 !self.config.prompt.trim().is_empty(),
426 "prompt must not be empty - call .prompt(\"...\") before .run()"
427 );
428
429 let total_chars =
431 self.config.prompt.len() + self.config.system_prompt.as_ref().map_or(0, |s| s.len());
432 let estimated_tokens = estimate_tokens(total_chars);
433 let model_limit = self.config.model.context_window();
434 if estimated_tokens > model_limit {
435 return Err(OperationError::Agent(AgentError::PromptTooLarge {
436 chars: total_chars,
437 estimated_tokens,
438 model_limit,
439 }));
440 }
441
442 if crate::dry_run::effective_dry_run(self.dry_run) {
443 info!(
444 prompt_len = self.config.prompt.len(),
445 "[dry-run] agent call skipped"
446 );
447 let mut output =
448 AgentOutput::new(Value::String("[dry-run] agent call skipped".to_string()));
449 output.cost_usd = Some(0.0);
450 output.input_tokens = Some(0);
451 output.output_tokens = Some(0);
452 return Ok(AgentResult { output });
453 }
454
455 let result = self.invoke_once(provider).await;
456
457 let policy = match &self.retry_policy {
458 Some(p) => p,
459 None => return result,
460 };
461
462 if let Err(ref err) = result {
464 if !crate::retry::is_retryable(err) {
465 return result;
466 }
467 } else {
468 return result;
469 }
470
471 let mut last_result = result;
472
473 for attempt in 0..policy.max_retries {
474 let delay = policy.delay_for_attempt(attempt);
475 warn!(
476 attempt = attempt + 1,
477 max_retries = policy.max_retries,
478 delay_ms = delay.as_millis() as u64,
479 "retrying agent invocation"
480 );
481 tokio::time::sleep(delay).await;
482
483 last_result = self.invoke_once(provider).await;
484
485 match &last_result {
486 Ok(_) => return last_result,
487 Err(err) if !crate::retry::is_retryable(err) => return last_result,
488 _ => {}
489 }
490 }
491
492 last_result
493 }
494
495 async fn invoke_once(
497 &self,
498 provider: &dyn AgentProvider,
499 ) -> Result<AgentResult, OperationError> {
500 #[cfg(feature = "prometheus")]
501 let model_label = self.config.model.to_string();
502
503 let output = match provider.invoke(&self.config).await {
504 Ok(output) => output,
505 Err(e) => {
506 #[cfg(feature = "prometheus")]
507 {
508 metrics::counter!(metric_names::AGENT_TOTAL, "model" => model_label.clone(), "status" => metric_names::STATUS_ERROR).increment(1);
509 }
510 return Err(OperationError::Agent(e));
511 }
512 };
513
514 info!(
515 duration_ms = output.duration_ms,
516 cost_usd = output.cost_usd,
517 input_tokens = output.input_tokens,
518 output_tokens = output.output_tokens,
519 model = output.model,
520 "agent completed"
521 );
522
523 #[cfg(feature = "prometheus")]
524 {
525 metrics::counter!(metric_names::AGENT_TOTAL, "model" => model_label.clone(), "status" => metric_names::STATUS_SUCCESS).increment(1);
526 metrics::histogram!(metric_names::AGENT_DURATION_SECONDS, "model" => model_label.clone())
527 .record(output.duration_ms as f64 / 1000.0);
528 if let Some(cost) = output.cost_usd {
529 metrics::gauge!(metric_names::AGENT_COST_USD_TOTAL, "model" => model_label.clone())
530 .increment(cost);
531 }
532 if let Some(tokens) = output.input_tokens {
533 metrics::counter!(metric_names::AGENT_TOKENS_INPUT_TOTAL, "model" => model_label.clone()).increment(tokens);
534 }
535 if let Some(tokens) = output.output_tokens {
536 metrics::counter!(metric_names::AGENT_TOKENS_OUTPUT_TOTAL, "model" => model_label)
537 .increment(tokens);
538 }
539 }
540
541 Ok(AgentResult { output })
542 }
543}
544
545impl Default for Agent {
546 fn default() -> Self {
547 Self::new()
548 }
549}
550
551#[derive(Debug)]
556pub struct AgentResult {
557 output: AgentOutput,
558}
559
560impl AgentResult {
561 pub fn text(&self) -> &str {
566 match self.output.value.as_str() {
567 Some(s) => s,
568 None => {
569 warn!(
570 value_type = self.output.value.to_string(),
571 "agent output is not a string, returning empty"
572 );
573 ""
574 }
575 }
576 }
577
578 pub fn value(&self) -> &Value {
580 &self.output.value
581 }
582
583 pub fn json<T: DeserializeOwned>(&self) -> Result<T, OperationError> {
593 from_value(self.output.value.clone()).map_err(OperationError::deserialize::<T>)
594 }
595
596 pub fn into_json<T: DeserializeOwned>(self) -> Result<T, OperationError> {
602 from_value(self.output.value).map_err(OperationError::deserialize::<T>)
603 }
604
605 #[cfg(test)]
610 pub(crate) fn from_output(output: AgentOutput) -> Self {
611 Self { output }
612 }
613
614 pub fn session_id(&self) -> Option<&str> {
616 self.output.session_id.as_deref()
617 }
618
619 pub fn cost_usd(&self) -> Option<f64> {
621 self.output.cost_usd
622 }
623
624 pub fn input_tokens(&self) -> Option<u64> {
626 self.output.input_tokens
627 }
628
629 pub fn output_tokens(&self) -> Option<u64> {
631 self.output.output_tokens
632 }
633
634 pub fn duration_ms(&self) -> u64 {
636 self.output.duration_ms
637 }
638
639 pub fn model(&self) -> Option<&str> {
641 self.output.model.as_deref()
642 }
643}
644
#[cfg(test)]
mod tests {
    use std::sync::Arc;
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::time::Duration;

    use serde_json::json;

    use super::*;
    use crate::provider::InvokeFuture;
    use crate::retry::RetryPolicy;

    // ---------------------------------------------------------------------
    // Fixtures
    // ---------------------------------------------------------------------

    /// Baseline provider output shared by most tests.
    fn default_output() -> AgentOutput {
        AgentOutput {
            value: json!("test output"),
            session_id: Some("sess-123".to_string()),
            cost_usd: Some(0.05),
            input_tokens: Some(100),
            output_tokens: Some(50),
            model: Some("sonnet".to_string()),
            duration_ms: 1500,
        }
    }

    /// Field-by-field copy of `source`.
    fn copy_output(source: &AgentOutput) -> AgentOutput {
        AgentOutput {
            value: source.value.clone(),
            session_id: source.session_id.clone(),
            cost_usd: source.cost_usd,
            input_tokens: source.input_tokens,
            output_tokens: source.output_tokens,
            model: source.model.clone(),
            duration_ms: source.duration_ms,
        }
    }

    /// Provider that always succeeds with a copy of its stored output.
    struct TestProvider {
        output: AgentOutput,
    }

    impl TestProvider {
        /// Provider returning the baseline output.
        fn with_defaults() -> Self {
            Self {
                output: default_output(),
            }
        }

        /// Provider returning the baseline output with `value` swapped in.
        fn returning(value: Value) -> Self {
            Self {
                output: AgentOutput {
                    value,
                    ..default_output()
                },
            }
        }
    }

    impl AgentProvider for TestProvider {
        fn invoke<'a>(&'a self, _config: &'a AgentConfig) -> InvokeFuture<'a> {
            Box::pin(async move { Ok(copy_output(&self.output)) })
        }
    }

    /// Provider that echoes the serialized config back as the output value.
    struct ConfigCapture {
        output: AgentOutput,
    }

    impl ConfigCapture {
        fn new() -> Self {
            Self {
                output: default_output(),
            }
        }
    }

    impl AgentProvider for ConfigCapture {
        fn invoke<'a>(&'a self, config: &'a AgentConfig) -> InvokeFuture<'a> {
            let captured = serde_json::to_value(config).unwrap();
            Box::pin(async move {
                Ok(AgentOutput {
                    value: captured,
                    ..copy_output(&self.output)
                })
            })
        }
    }

    /// Provider that fails with `ProcessFailed` a fixed number of times
    /// before succeeding; `fail_count` counts every invocation.
    struct FailNTimesProvider {
        fail_count: AtomicU32,
        failures_before_success: u32,
        output: AgentOutput,
    }

    impl FailNTimesProvider {
        fn failing(failures_before_success: u32) -> Self {
            Self {
                fail_count: AtomicU32::new(0),
                failures_before_success,
                output: default_output(),
            }
        }

        fn calls(&self) -> u32 {
            self.fail_count.load(Ordering::SeqCst)
        }
    }

    impl AgentProvider for FailNTimesProvider {
        fn invoke<'a>(&'a self, _config: &'a AgentConfig) -> InvokeFuture<'a> {
            Box::pin(async move {
                let current = self.fail_count.fetch_add(1, Ordering::SeqCst);
                if current >= self.failures_before_success {
                    return Ok(copy_output(&self.output));
                }
                Err(AgentError::ProcessFailed {
                    exit_code: 1,
                    stderr: format!("transient failure #{}", current + 1),
                })
            })
        }
    }

    /// Runs an agent with the prompt `"test"` and unwraps the result.
    async fn run_test_prompt(provider: &dyn AgentProvider) -> AgentResult {
        Agent::new().prompt("test").run(provider).await.unwrap()
    }

    /// Retry policy with a 1 ms backoff so retry tests stay fast.
    fn fast_policy(max_retries: u32) -> RetryPolicy {
        RetryPolicy::new(max_retries).backoff(Duration::from_millis(1))
    }

    // ---------------------------------------------------------------------
    // Model display
    // ---------------------------------------------------------------------

    #[test]
    fn model_display_sonnet() {
        let rendered = Model::Sonnet.to_string();
        assert_eq!(rendered, "sonnet");
    }

    #[test]
    fn model_display_opus() {
        let rendered = Model::Opus.to_string();
        assert_eq!(rendered, "opus");
    }

    #[test]
    fn model_display_haiku() {
        let rendered = Model::Haiku.to_string();
        assert_eq!(rendered, "haiku");
    }

    // ---------------------------------------------------------------------
    // Builder
    // ---------------------------------------------------------------------

    #[tokio::test]
    async fn agent_new_default_values() {
        let provider = ConfigCapture::new();
        let result = Agent::new().prompt("hi").run(&provider).await.unwrap();

        let config = result.value();
        let expectations = [
            ("system_prompt", json!(null)),
            ("prompt", json!("hi")),
            ("model", json!("sonnet")),
            ("allowed_tools", json!([])),
            ("max_turns", json!(null)),
            ("max_budget_usd", json!(null)),
            ("working_dir", json!(null)),
            ("mcp_config", json!(null)),
            ("permission_mode", json!("Default")),
            ("json_schema", json!(null)),
        ];
        for (key, expected) in expectations {
            assert_eq!(config[key], expected, "unexpected value for `{key}`");
        }
    }

    #[tokio::test]
    async fn agent_default_matches_new() {
        let provider = ConfigCapture::new();
        let from_new = Agent::new().prompt("x").run(&provider).await.unwrap();
        let from_default = Agent::default().prompt("x").run(&provider).await.unwrap();

        assert_eq!(from_new.value(), from_default.value());
    }

    #[tokio::test]
    async fn builder_methods_store_values_correctly() {
        let provider = ConfigCapture::new();
        let agent = Agent::new()
            .system_prompt("you are a bot")
            .prompt("do something")
            .model(Model::Opus)
            .allowed_tools(&["Read", "Write"])
            .max_turns(5)
            .max_budget_usd(1.5)
            .working_dir("/tmp")
            .mcp_config("{}")
            .permission_mode(PermissionMode::Auto);
        let result = agent.run(&provider).await.unwrap();

        let config = result.value();
        let expectations = [
            ("system_prompt", json!("you are a bot")),
            ("prompt", json!("do something")),
            ("model", json!("opus")),
            ("allowed_tools", json!(["Read", "Write"])),
            ("max_turns", json!(5)),
            ("max_budget_usd", json!(1.5)),
            ("working_dir", json!("/tmp")),
            ("mcp_config", json!("{}")),
            ("permission_mode", json!("Auto")),
        ];
        for (key, expected) in expectations {
            assert_eq!(config[key], expected, "unexpected value for `{key}`");
        }
    }

    #[test]
    #[should_panic(expected = "max_turns must be greater than 0")]
    fn max_turns_zero_panics() {
        let _ = Agent::new().max_turns(0);
    }

    #[test]
    #[should_panic(expected = "budget must be a positive finite number")]
    fn max_budget_negative_panics() {
        let _ = Agent::new().max_budget_usd(-1.0);
    }

    #[test]
    #[should_panic(expected = "budget must be a positive finite number")]
    fn max_budget_nan_panics() {
        let _ = Agent::new().max_budget_usd(f64::NAN);
    }

    #[test]
    #[should_panic(expected = "budget must be a positive finite number")]
    fn max_budget_infinity_panics() {
        let _ = Agent::new().max_budget_usd(f64::INFINITY);
    }

    // ---------------------------------------------------------------------
    // AgentResult
    // ---------------------------------------------------------------------

    #[tokio::test]
    async fn agent_result_text_with_string_value() {
        let provider = TestProvider::returning(json!("hello world"));
        let result = run_test_prompt(&provider).await;
        assert_eq!(result.text(), "hello world");
    }

    #[tokio::test]
    async fn agent_result_text_with_non_string_value() {
        let provider = TestProvider::returning(json!(42));
        let result = run_test_prompt(&provider).await;
        assert_eq!(result.text(), "");
    }

    #[tokio::test]
    async fn agent_result_text_with_null_value() {
        let provider = TestProvider::returning(json!(null));
        let result = run_test_prompt(&provider).await;
        assert_eq!(result.text(), "");
    }

    #[tokio::test]
    async fn agent_result_json_successful_deserialize() {
        #[derive(Deserialize, PartialEq, Debug)]
        struct MyOutput {
            name: String,
            count: u32,
        }

        let provider = TestProvider::returning(json!({"name": "test", "count": 7}));
        let result = run_test_prompt(&provider).await;

        let parsed: MyOutput = result.json().unwrap();
        assert_eq!(parsed.name, "test");
        assert_eq!(parsed.count, 7);
    }

    #[tokio::test]
    async fn agent_result_json_failed_deserialize() {
        #[derive(Debug, Deserialize)]
        #[allow(dead_code)]
        struct MyOutput {
            name: String,
        }

        let provider = TestProvider::returning(json!(42));
        let result = run_test_prompt(&provider).await;

        let err = result.json::<MyOutput>().unwrap_err();
        assert!(matches!(err, OperationError::Deserialize { .. }));
    }

    #[tokio::test]
    async fn agent_result_accessors() {
        let output = AgentOutput {
            value: json!("v"),
            session_id: Some("s-1".to_string()),
            cost_usd: Some(0.123),
            input_tokens: Some(999),
            output_tokens: Some(456),
            model: Some("opus".to_string()),
            duration_ms: 2000,
        };
        let provider = TestProvider { output };
        let result = run_test_prompt(&provider).await;

        assert_eq!(result.session_id(), Some("s-1"));
        assert_eq!(result.cost_usd(), Some(0.123));
        assert_eq!(result.input_tokens(), Some(999));
        assert_eq!(result.output_tokens(), Some(456));
        assert_eq!(result.duration_ms(), 2000);
        assert_eq!(result.model(), Some("opus"));
    }

    // ---------------------------------------------------------------------
    // Session resume
    // ---------------------------------------------------------------------

    #[tokio::test]
    async fn resume_passes_session_id_in_config() {
        let provider = ConfigCapture::new();
        let agent = Agent::new().prompt("followup").resume("sess-abc");
        let result = agent.run(&provider).await.unwrap();

        assert_eq!(result.value()["resume_session_id"], json!("sess-abc"));
    }

    #[tokio::test]
    async fn no_resume_has_null_session_id() {
        let provider = ConfigCapture::new();
        let agent = Agent::new().prompt("first call");
        let result = agent.run(&provider).await.unwrap();

        assert_eq!(result.value()["resume_session_id"], json!(null));
    }

    #[test]
    #[should_panic(expected = "session_id must not be empty")]
    fn resume_empty_session_id_panics() {
        let _ = Agent::new().resume("");
    }

    #[test]
    #[should_panic(expected = "session_id must only contain")]
    fn resume_invalid_chars_panics() {
        let _ = Agent::new().resume("sess;rm -rf /");
    }

    #[test]
    fn resume_valid_formats_accepted() {
        for session_id in ["sess-abc123", "a1b2c3d4_session", "abc-DEF-123_456"] {
            let _ = Agent::new().resume(session_id);
        }
    }

    // ---------------------------------------------------------------------
    // Prompt validation
    // ---------------------------------------------------------------------

    #[tokio::test]
    #[should_panic(expected = "prompt must not be empty")]
    async fn run_without_prompt_panics() {
        let provider = TestProvider::with_defaults();
        let _ = Agent::new().run(&provider).await;
    }

    #[tokio::test]
    #[should_panic(expected = "prompt must not be empty")]
    async fn run_with_whitespace_only_prompt_panics() {
        let provider = TestProvider::with_defaults();
        let _ = Agent::new().prompt(" ").run(&provider).await;
    }

    // ---------------------------------------------------------------------
    // Model serialization and limits
    // ---------------------------------------------------------------------

    #[test]
    fn model_serialize_deserialize_roundtrip() {
        let every_model = [
            Model::Sonnet,
            Model::Opus,
            Model::Haiku,
            Model::Haiku45,
            Model::Sonnet46,
            Model::Opus46,
            Model::Sonnet46_1M,
            Model::Opus46_1M,
        ];
        for model in every_model {
            let encoded = serde_json::to_string(&model).unwrap();
            let decoded: Model = serde_json::from_str(&encoded).unwrap();
            assert_eq!(model.to_string(), decoded.to_string());
        }
    }

    #[test]
    fn model_display_explicit_ids() {
        let expected = [
            (Model::Haiku45, "claude-haiku-4-5-20251001"),
            (Model::Sonnet46, "claude-sonnet-4-6"),
            (Model::Opus46, "claude-opus-4-6"),
            (Model::Sonnet46_1M, "claude-sonnet-4-6[1m]"),
            (Model::Opus46_1M, "claude-opus-4-6[1m]"),
        ];
        for (model, id) in expected {
            assert_eq!(model.to_string(), id);
        }
    }

    #[test]
    fn model_context_window() {
        let expected = [
            (Model::Sonnet, 200_000),
            (Model::Opus, 200_000),
            (Model::Haiku, 200_000),
            (Model::Sonnet46_1M, 1_000_000),
            (Model::Opus46_1M, 1_000_000),
        ];
        for (model, window) in expected {
            assert_eq!(model.context_window(), window);
        }
    }

    #[tokio::test]
    async fn into_json_success() {
        #[derive(Deserialize, PartialEq, Debug)]
        struct Out {
            name: String,
        }

        let provider = TestProvider::returning(json!({"name": "test"}));
        let result = run_test_prompt(&provider).await;

        let parsed: Out = result.into_json().unwrap();
        assert_eq!(parsed.name, "test");
    }

    #[tokio::test]
    async fn into_json_failure() {
        #[derive(Debug, Deserialize)]
        #[allow(dead_code)]
        struct Out {
            name: String,
        }

        let provider = TestProvider::returning(json!(42));
        let result = run_test_prompt(&provider).await;

        let err = result.into_json::<Out>().unwrap_err();
        assert!(matches!(err, OperationError::Deserialize { .. }));
    }

    #[test]
    fn from_output_creates_result() {
        let result = AgentResult::from_output(AgentOutput {
            value: json!("hello"),
            ..default_output()
        });
        assert_eq!(result.text(), "hello");
        assert_eq!(result.cost_usd(), Some(0.05));
    }

    #[test]
    #[should_panic(expected = "budget must be a positive finite number")]
    fn max_budget_zero_panics() {
        let _ = Agent::new().max_budget_usd(0.0);
    }

    #[tokio::test]
    async fn run_with_prompt_too_large_returns_error() {
        let provider = TestProvider::with_defaults();
        let large_prompt = "x".repeat(800_004);
        let agent = Agent::new().prompt(&large_prompt).model(Model::Sonnet);

        let err = agent.run(&provider).await.unwrap_err();
        match err {
            OperationError::Agent(crate::error::AgentError::PromptTooLarge {
                chars,
                estimated_tokens,
                model_limit,
            }) => {
                assert_eq!(chars, 800_004);
                assert_eq!(estimated_tokens, 200_001);
                assert_eq!(model_limit, 200_000);
            }
            other => panic!("expected PromptTooLarge, got: {other}"),
        }
    }

    #[tokio::test]
    async fn run_with_prompt_at_limit_succeeds() {
        let provider = TestProvider::with_defaults();
        let prompt = "x".repeat(800_000);
        let agent = Agent::new().prompt(&prompt).model(Model::Sonnet);

        assert!(agent.run(&provider).await.is_ok());
    }

    #[tokio::test]
    async fn run_with_system_prompt_counts_toward_limit() {
        let provider = TestProvider::with_defaults();
        let system = "s".repeat(400_004);
        let prompt = "p".repeat(400_000);
        let agent = Agent::new()
            .system_prompt(&system)
            .prompt(&prompt)
            .model(Model::Sonnet);

        let err = agent.run(&provider).await.unwrap_err();
        assert!(matches!(
            err,
            OperationError::Agent(crate::error::AgentError::PromptTooLarge { .. })
        ));
    }

    #[tokio::test]
    async fn run_with_1m_model_allows_larger_prompt() {
        let provider = TestProvider::with_defaults();
        let prompt = "x".repeat(900_000);
        let agent = Agent::new().prompt(&prompt).model(Model::Sonnet46_1M);

        assert!(agent.run(&provider).await.is_ok());
    }

    #[test]
    fn model_equality() {
        assert_eq!(Model::Sonnet, Model::Sonnet);
        assert_ne!(Model::Sonnet, Model::Opus);
    }

    #[test]
    fn permission_mode_serialize_deserialize_roundtrip() {
        let every_mode = [
            PermissionMode::Default,
            PermissionMode::Auto,
            PermissionMode::DontAsk,
            PermissionMode::BypassPermissions,
        ];
        for mode in every_mode {
            let encoded = serde_json::to_string(&mode).unwrap();
            let decoded: PermissionMode = serde_json::from_str(&encoded).unwrap();
            assert_eq!(format!("{:?}", mode), format!("{:?}", decoded));
        }
    }

    // ---------------------------------------------------------------------
    // Retry
    // ---------------------------------------------------------------------

    #[test]
    fn retry_builder_stores_policy() {
        let agent = Agent::new().retry(3);
        let policy = agent
            .retry_policy
            .expect("retry(3) should install a policy");
        assert_eq!(policy.max_retries(), 3);
    }

    #[test]
    fn retry_policy_builder_stores_custom_policy() {
        let custom = RetryPolicy::new(5).backoff(Duration::from_secs(1));
        let agent = Agent::new().retry_policy(custom);
        let stored = agent.retry_policy.unwrap();
        assert_eq!(stored.max_retries(), 5);
    }

    #[test]
    fn no_retry_by_default() {
        assert!(Agent::new().retry_policy.is_none());
    }

    #[tokio::test]
    async fn retry_succeeds_after_transient_failures() {
        let provider = FailNTimesProvider::failing(2);
        let result = Agent::new()
            .prompt("test")
            .retry_policy(fast_policy(3))
            .run(&provider)
            .await;

        assert!(result.is_ok());
        // Two failures followed by one success.
        assert_eq!(provider.calls(), 3);
    }

    #[tokio::test]
    async fn retry_exhausted_returns_last_error() {
        let provider = FailNTimesProvider::failing(10);
        let result = Agent::new()
            .prompt("test")
            .retry_policy(fast_policy(2))
            .run(&provider)
            .await;

        assert!(result.is_err());
        // One initial attempt plus two retries.
        assert_eq!(provider.calls(), 3);
    }

    #[tokio::test]
    async fn retry_does_not_retry_non_retryable_errors() {
        struct CountingNonRetryable {
            count: Arc<AtomicU32>,
        }

        impl AgentProvider for CountingNonRetryable {
            fn invoke<'a>(&'a self, _config: &'a AgentConfig) -> InvokeFuture<'a> {
                self.count.fetch_add(1, Ordering::SeqCst);
                Box::pin(async move {
                    Err(AgentError::SchemaValidation {
                        expected: "object".to_string(),
                        got: "string".to_string(),
                    })
                })
            }
        }

        let call_count = Arc::new(AtomicU32::new(0));
        let provider = CountingNonRetryable {
            count: Arc::clone(&call_count),
        };
        let result = Agent::new()
            .prompt("test")
            .retry_policy(fast_policy(3))
            .run(&provider)
            .await;

        assert!(result.is_err());
        assert_eq!(call_count.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn no_retry_without_policy() {
        let provider = FailNTimesProvider::failing(1);
        let result = Agent::new().prompt("test").run(&provider).await;

        assert!(result.is_err());
        assert_eq!(provider.calls(), 1);
    }
}