// zag_agent/agent.rs

1use crate::output::AgentOutput;
2use crate::sandbox::SandboxConfig;
3use anyhow::Result;
4use async_trait::async_trait;
5
/// Model size categories that map to agent-specific models.
///
/// A size is an abstract tier; the concrete model name it stands for is
/// chosen per agent. [`ModelSize::Medium`] is the value produced by
/// [`Default::default`], matching its role as the default tier.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum ModelSize {
    /// Fast and lightweight model for simple tasks.
    Small,
    /// Balanced model for most tasks (default).
    #[default]
    Medium,
    /// Most capable model for complex reasoning.
    Large,
}
16
17impl std::str::FromStr for ModelSize {
18    type Err = ();
19
20    fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
21        match s.to_lowercase().as_str() {
22            "small" | "s" => Ok(ModelSize::Small),
23            "medium" | "m" | "default" => Ok(ModelSize::Medium),
24            "large" | "l" | "max" => Ok(ModelSize::Large),
25            _ => Err(()),
26        }
27    }
28}
29
30#[async_trait]
31#[allow(dead_code)]
32pub trait Agent {
33    fn name(&self) -> &str;
34
35    fn default_model() -> &'static str
36    where
37        Self: Sized;
38
39    /// Get the model name for a given size category.
40    fn model_for_size(size: ModelSize) -> &'static str
41    where
42        Self: Sized;
43
44    /// Resolve a model input (either a size alias or specific model name).
45    ///
46    /// If the input is a size alias (small/medium/large), returns the
47    /// corresponding model for this agent. Otherwise returns the input as-is.
48    fn resolve_model(model_input: &str) -> String
49    where
50        Self: Sized,
51    {
52        if let Ok(size) = model_input.parse::<ModelSize>() {
53            Self::model_for_size(size).to_string()
54        } else {
55            model_input.to_string()
56        }
57    }
58
59    /// Get the list of available models for this agent.
60    fn available_models() -> &'static [&'static str]
61    where
62        Self: Sized;
63
64    /// Validate that a model name is supported by this agent.
65    ///
66    /// Returns Ok(()) if valid, or an error with available models if invalid.
67    fn validate_model(model: &str, agent_name: &str) -> Result<()>
68    where
69        Self: Sized,
70    {
71        let available = Self::available_models();
72        if available.contains(&model) {
73            Ok(())
74        } else {
75            // Build error message with size aliases first
76            let small = Self::model_for_size(ModelSize::Small);
77            let medium = Self::model_for_size(ModelSize::Medium);
78            let large = Self::model_for_size(ModelSize::Large);
79
80            let mut models = vec![
81                format!("{} (small)", small),
82                format!("{} (medium)", medium),
83                format!("{} (large)", large),
84            ];
85
86            // Add other available models that aren't already in the size mappings
87            for m in available {
88                if m != &small && m != &medium && m != &large {
89                    models.push(m.to_string());
90                }
91            }
92
93            anyhow::bail!(
94                "Invalid model '{}' for {}. Available models: {}",
95                model,
96                agent_name,
97                models.join(", ")
98            )
99        }
100    }
101
102    fn system_prompt(&self) -> &str;
103
104    fn set_system_prompt(&mut self, prompt: String);
105
106    fn get_model(&self) -> &str;
107
108    fn set_model(&mut self, model: String);
109
110    fn set_root(&mut self, root: String);
111
112    fn set_skip_permissions(&mut self, skip: bool);
113
114    fn set_output_format(&mut self, format: Option<String>);
115
116    /// Enable output capture mode.
117    ///
118    /// When set, non-interactive `run()` pipes stdout, captures the text,
119    /// and returns `Some(AgentOutput)`. Default is `false` (streams to terminal).
120    /// Claude handles capture via output_format, so the default is a no-op.
121    fn set_capture_output(&mut self, _capture: bool) {}
122
123    /// Set the maximum number of agentic turns.
124    fn set_max_turns(&mut self, _turns: u32) {}
125
126    /// Set sandbox configuration for running inside a Docker sandbox.
127    fn set_sandbox(&mut self, _config: SandboxConfig) {}
128
129    /// Set additional directories for the agent to include.
130    fn set_add_dirs(&mut self, dirs: Vec<String>);
131
132    /// Get a reference to the concrete agent type (for downcasting).
133    fn as_any_ref(&self) -> &dyn std::any::Any;
134
135    /// Get a mutable reference to the concrete agent type (for downcasting).
136    fn as_any_mut(&mut self) -> &mut dyn std::any::Any;
137
138    /// Run the agent in non-interactive mode.
139    ///
140    /// Returns `Some(AgentOutput)` if the agent supports structured output
141    /// (e.g., JSON mode), otherwise returns `None`.
142    async fn run(&self, prompt: Option<&str>) -> Result<Option<AgentOutput>>;
143
144    async fn run_interactive(&self, prompt: Option<&str>) -> Result<()>;
145
146    /// Resume a previous session.
147    ///
148    /// If `session_id` is provided, resumes that specific session.
149    /// If `last` is true, resumes the most recent session.
150    /// If neither, shows a session picker or resumes the most recent.
151    async fn run_resume(&self, session_id: Option<&str>, last: bool) -> Result<()>;
152
153    /// Resume a previous session with a new prompt (for retry/correction).
154    ///
155    /// Returns `Some(AgentOutput)` if the agent supports structured output.
156    /// Default implementation returns an error indicating unsupported operation.
157    async fn run_resume_with_prompt(
158        &self,
159        _session_id: &str,
160        _prompt: &str,
161    ) -> Result<Option<AgentOutput>> {
162        anyhow::bail!("Resume with prompt is not supported by this agent")
163    }
164
165    async fn cleanup(&self) -> Result<()>;
166}
167
168#[cfg(test)]
169#[path = "agent_tests.rs"]
170mod tests;