Skip to main content

zag_agent/
agent.rs

1use crate::output::AgentOutput;
2use crate::sandbox::SandboxConfig;
3use anyhow::Result;
4use async_trait::async_trait;
5use std::sync::Arc;
6
/// Hook that receives the OS process id of a freshly spawned agent subprocess.
///
/// Register it with [`Agent::set_on_spawn_hook`] or
/// [`crate::builder::AgentBuilder::on_spawn`]. It lets callers act on the
/// running child right after spawn and before the terminal wait — e.g.
/// recording the pid in a process registry so `zag ps kill self` can SIGTERM
/// the agent child rather than the parent zag process.
///
/// Semantics:
/// - It is invoked once per spawn with the pid of the direct provider
///   subprocess; retries and resumes spawn a new child and invoke it again.
/// - The pid is not guaranteed to still be alive when the hook runs, so
///   confirm with the OS before signaling it.
pub type OnSpawnHook = Arc<dyn Fn(u32) + Sync + Send>;
21
/// Coarse model-size tiers; each agent maps a tier onto one of its own models.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ModelSize {
    /// Lightweight, fast tier intended for simple tasks.
    Small,
    /// Balanced tier suitable for most tasks (the default tier).
    Medium,
    /// Most capable tier, intended for complex reasoning.
    Large,
}

impl std::str::FromStr for ModelSize {
    type Err = ();

    /// Parses a size alias case-insensitively.
    ///
    /// Accepted spellings: `small`/`s`, `medium`/`m`/`default`, and
    /// `large`/`l`/`max`. Anything else (including surrounding whitespace)
    /// yields `Err(())`.
    fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
        let normalized = s.to_lowercase();
        let size = match normalized.as_str() {
            "s" | "small" => Self::Small,
            "m" | "medium" | "default" => Self::Medium,
            "l" | "large" | "max" => Self::Large,
            _ => return Err(()),
        };
        Ok(size)
    }
}
45
46#[async_trait]
47#[allow(dead_code)]
48pub trait Agent {
49    fn name(&self) -> &str;
50
51    fn default_model() -> &'static str
52    where
53        Self: Sized;
54
55    /// Get the model name for a given size category.
56    fn model_for_size(size: ModelSize) -> &'static str
57    where
58        Self: Sized;
59
60    /// Resolve a model input (either a size alias or specific model name).
61    ///
62    /// If the input is a size alias (small/medium/large), returns the
63    /// corresponding model for this agent. Otherwise returns the input as-is.
64    fn resolve_model(model_input: &str) -> String
65    where
66        Self: Sized,
67    {
68        if let Ok(size) = model_input.parse::<ModelSize>() {
69            Self::model_for_size(size).to_string()
70        } else {
71            model_input.to_string()
72        }
73    }
74
75    /// Get the list of available models for this agent.
76    fn available_models() -> &'static [&'static str]
77    where
78        Self: Sized;
79
80    /// Validate that a model name is supported by this agent.
81    ///
82    /// Returns Ok(()) if valid, or an error with available models if invalid.
83    fn validate_model(model: &str, agent_name: &str) -> Result<()>
84    where
85        Self: Sized,
86    {
87        let available = Self::available_models();
88        if available.contains(&model) {
89            Ok(())
90        } else {
91            // Build error message with size aliases first
92            let small = Self::model_for_size(ModelSize::Small);
93            let medium = Self::model_for_size(ModelSize::Medium);
94            let large = Self::model_for_size(ModelSize::Large);
95
96            let mut models = vec![
97                format!("{} (small)", small),
98                format!("{} (medium)", medium),
99                format!("{} (large)", large),
100            ];
101
102            // Add other available models that aren't already in the size mappings
103            for m in available {
104                if m != &small && m != &medium && m != &large {
105                    models.push(m.to_string());
106                }
107            }
108
109            anyhow::bail!(
110                "Invalid model '{}' for {}. Available models: {}",
111                model,
112                agent_name,
113                models.join(", ")
114            )
115        }
116    }
117
118    fn system_prompt(&self) -> &str;
119
120    fn set_system_prompt(&mut self, prompt: String);
121
122    fn get_model(&self) -> &str;
123
124    fn set_model(&mut self, model: String);
125
126    fn set_root(&mut self, root: String);
127
128    fn set_skip_permissions(&mut self, skip: bool);
129
130    fn set_output_format(&mut self, format: Option<String>);
131
132    /// Enable output capture mode.
133    ///
134    /// When set, non-interactive `run()` pipes stdout, captures the text,
135    /// and returns `Some(AgentOutput)`. Default is `false` (streams to terminal).
136    /// Claude handles capture via output_format, so the default is a no-op.
137    fn set_capture_output(&mut self, _capture: bool) {}
138
139    /// Set the maximum number of agentic turns.
140    fn set_max_turns(&mut self, _turns: u32) {}
141
142    /// Set sandbox configuration for running inside a Docker sandbox.
143    fn set_sandbox(&mut self, _config: SandboxConfig) {}
144
145    /// Set additional directories for the agent to include.
146    fn set_add_dirs(&mut self, dirs: Vec<String>);
147
148    /// Set environment variables to pass to the agent subprocess.
149    fn set_env_vars(&mut self, _vars: Vec<(String, String)>) {}
150
151    /// Register a callback that fires with the OS pid of the spawned
152    /// agent subprocess.
153    ///
154    /// Default impl is a no-op; providers that spawn an OS subprocess
155    /// override this to invoke the hook after spawn. See [`OnSpawnHook`]
156    /// for callback semantics.
157    fn set_on_spawn_hook(&mut self, _hook: OnSpawnHook) {}
158
159    /// Get a reference to the concrete agent type (for downcasting).
160    fn as_any_ref(&self) -> &dyn std::any::Any;
161
162    /// Get a mutable reference to the concrete agent type (for downcasting).
163    fn as_any_mut(&mut self) -> &mut dyn std::any::Any;
164
165    /// Run the agent in non-interactive mode.
166    ///
167    /// Returns `Some(AgentOutput)` if the agent supports structured output
168    /// (e.g., JSON mode), otherwise returns `None`.
169    async fn run(&self, prompt: Option<&str>) -> Result<Option<AgentOutput>>;
170
171    async fn run_interactive(&self, prompt: Option<&str>) -> Result<()>;
172
173    /// Resume a previous session.
174    ///
175    /// If `session_id` is provided, resumes that specific session.
176    /// If `last` is true, resumes the most recent session.
177    /// If neither, shows a session picker or resumes the most recent.
178    async fn run_resume(&self, session_id: Option<&str>, last: bool) -> Result<()>;
179
180    /// Resume a previous session with a new prompt (for retry/correction).
181    ///
182    /// Returns `Some(AgentOutput)` if the agent supports structured output.
183    /// Default implementation returns an error indicating unsupported operation.
184    async fn run_resume_with_prompt(
185        &self,
186        _session_id: &str,
187        _prompt: &str,
188    ) -> Result<Option<AgentOutput>> {
189        anyhow::bail!("Resume with prompt is not supported by this agent")
190    }
191
192    /// Lightweight startup probe used by the provider fallback mechanism.
193    ///
194    /// Override this in providers that can cheaply detect a broken startup
195    /// state (e.g. missing auth) without consuming paid API quota. A non-Ok
196    /// return value is treated as a reason to downgrade to the next provider
197    /// in the tier list when the user has not pinned a provider with `-p`.
198    ///
199    /// The default implementation is a no-op because pre-flight PATH lookup
200    /// (`preflight::check_binary`) already catches the missing-binary case.
201    async fn probe(&self) -> Result<()> {
202        Ok(())
203    }
204
205    async fn cleanup(&self) -> Result<()>;
206}
207
// Unit tests for this module are kept in the sibling file `agent_tests.rs`
// and are compiled only for test builds.
#[path = "agent_tests.rs"]
#[cfg(test)]
mod tests;