zag_agent/agent.rs
1use crate::output::AgentOutput;
2use crate::sandbox::SandboxConfig;
3use anyhow::Result;
4use async_trait::async_trait;
5
/// Coarse capability tier used to pick an agent-specific model.
///
/// Each agent translates a tier into one of its own model names, so callers
/// can request "small", "medium" or "large" without knowing concrete model
/// identifiers.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ModelSize {
    /// Lightweight, fast tier intended for simple tasks.
    Small,
    /// Balanced tier suitable for most tasks (the default).
    Medium,
    /// Most capable tier, intended for complex reasoning.
    Large,
}
16
17impl std::str::FromStr for ModelSize {
18 type Err = ();
19
20 fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
21 match s.to_lowercase().as_str() {
22 "small" | "s" => Ok(ModelSize::Small),
23 "medium" | "m" | "default" => Ok(ModelSize::Medium),
24 "large" | "l" | "max" => Ok(ModelSize::Large),
25 _ => Err(()),
26 }
27 }
28}
29
30#[async_trait]
31#[allow(dead_code)]
32pub trait Agent {
33 fn name(&self) -> &str;
34
35 fn default_model() -> &'static str
36 where
37 Self: Sized;
38
39 /// Get the model name for a given size category.
40 fn model_for_size(size: ModelSize) -> &'static str
41 where
42 Self: Sized;
43
44 /// Resolve a model input (either a size alias or specific model name).
45 ///
46 /// If the input is a size alias (small/medium/large), returns the
47 /// corresponding model for this agent. Otherwise returns the input as-is.
48 fn resolve_model(model_input: &str) -> String
49 where
50 Self: Sized,
51 {
52 if let Ok(size) = model_input.parse::<ModelSize>() {
53 Self::model_for_size(size).to_string()
54 } else {
55 model_input.to_string()
56 }
57 }
58
59 /// Get the list of available models for this agent.
60 fn available_models() -> &'static [&'static str]
61 where
62 Self: Sized;
63
64 /// Validate that a model name is supported by this agent.
65 ///
66 /// Returns Ok(()) if valid, or an error with available models if invalid.
67 fn validate_model(model: &str, agent_name: &str) -> Result<()>
68 where
69 Self: Sized,
70 {
71 let available = Self::available_models();
72 if available.contains(&model) {
73 Ok(())
74 } else {
75 // Build error message with size aliases first
76 let small = Self::model_for_size(ModelSize::Small);
77 let medium = Self::model_for_size(ModelSize::Medium);
78 let large = Self::model_for_size(ModelSize::Large);
79
80 let mut models = vec![
81 format!("{} (small)", small),
82 format!("{} (medium)", medium),
83 format!("{} (large)", large),
84 ];
85
86 // Add other available models that aren't already in the size mappings
87 for m in available {
88 if m != &small && m != &medium && m != &large {
89 models.push(m.to_string());
90 }
91 }
92
93 anyhow::bail!(
94 "Invalid model '{}' for {}. Available models: {}",
95 model,
96 agent_name,
97 models.join(", ")
98 )
99 }
100 }
101
102 fn system_prompt(&self) -> &str;
103
104 fn set_system_prompt(&mut self, prompt: String);
105
106 fn get_model(&self) -> &str;
107
108 fn set_model(&mut self, model: String);
109
110 fn set_root(&mut self, root: String);
111
112 fn set_skip_permissions(&mut self, skip: bool);
113
114 fn set_output_format(&mut self, format: Option<String>);
115
116 /// Enable output capture mode.
117 ///
118 /// When set, non-interactive `run()` pipes stdout, captures the text,
119 /// and returns `Some(AgentOutput)`. Default is `false` (streams to terminal).
120 /// Claude handles capture via output_format, so the default is a no-op.
121 fn set_capture_output(&mut self, _capture: bool) {}
122
123 /// Set the maximum number of agentic turns.
124 fn set_max_turns(&mut self, _turns: u32) {}
125
126 /// Set sandbox configuration for running inside a Docker sandbox.
127 fn set_sandbox(&mut self, _config: SandboxConfig) {}
128
129 /// Set additional directories for the agent to include.
130 fn set_add_dirs(&mut self, dirs: Vec<String>);
131
132 /// Set environment variables to pass to the agent subprocess.
133 fn set_env_vars(&mut self, _vars: Vec<(String, String)>) {}
134
135 /// Get a reference to the concrete agent type (for downcasting).
136 fn as_any_ref(&self) -> &dyn std::any::Any;
137
138 /// Get a mutable reference to the concrete agent type (for downcasting).
139 fn as_any_mut(&mut self) -> &mut dyn std::any::Any;
140
141 /// Run the agent in non-interactive mode.
142 ///
143 /// Returns `Some(AgentOutput)` if the agent supports structured output
144 /// (e.g., JSON mode), otherwise returns `None`.
145 async fn run(&self, prompt: Option<&str>) -> Result<Option<AgentOutput>>;
146
147 async fn run_interactive(&self, prompt: Option<&str>) -> Result<()>;
148
149 /// Resume a previous session.
150 ///
151 /// If `session_id` is provided, resumes that specific session.
152 /// If `last` is true, resumes the most recent session.
153 /// If neither, shows a session picker or resumes the most recent.
154 async fn run_resume(&self, session_id: Option<&str>, last: bool) -> Result<()>;
155
156 /// Resume a previous session with a new prompt (for retry/correction).
157 ///
158 /// Returns `Some(AgentOutput)` if the agent supports structured output.
159 /// Default implementation returns an error indicating unsupported operation.
160 async fn run_resume_with_prompt(
161 &self,
162 _session_id: &str,
163 _prompt: &str,
164 ) -> Result<Option<AgentOutput>> {
165 anyhow::bail!("Resume with prompt is not supported by this agent")
166 }
167
168 /// Lightweight startup probe used by the provider fallback mechanism.
169 ///
170 /// Override this in providers that can cheaply detect a broken startup
171 /// state (e.g. missing auth) without consuming paid API quota. A non-Ok
172 /// return value is treated as a reason to downgrade to the next provider
173 /// in the tier list when the user has not pinned a provider with `-p`.
174 ///
175 /// The default implementation is a no-op because pre-flight PATH lookup
176 /// (`preflight::check_binary`) already catches the missing-binary case.
177 async fn probe(&self) -> Result<()> {
178 Ok(())
179 }
180
181 async fn cleanup(&self) -> Result<()>;
182}
183
// Unit tests for this module are loaded from `agent_tests.rs` and are only
// compiled for test builds.
#[cfg(test)]
#[path = "agent_tests.rs"]
mod tests;