zag-agent 0.12.4

Core library for zag — a unified interface for AI coding agents
Documentation
use crate::output::AgentOutput;
use crate::sandbox::SandboxConfig;
use anyhow::Result;
use async_trait::async_trait;

/// Coarse model-size tiers; every agent translates each tier into one of its
/// own concrete model names.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ModelSize {
    /// Lightweight, fast tier intended for simple tasks.
    Small,
    /// General-purpose tier suited to most work (also reachable via the `default` alias).
    Medium,
    /// Highest-capability tier for demanding reasoning.
    Large,
}

impl std::str::FromStr for ModelSize {
    type Err = ();

    /// Parse a size alias, ignoring letter case.
    ///
    /// Accepted spellings:
    /// - `small`/`s` → [`ModelSize::Small`]
    /// - `medium`/`m`/`default` → [`ModelSize::Medium`]
    /// - `large`/`l`/`max` → [`ModelSize::Large`]
    ///
    /// Anything else (including input with surrounding whitespace) yields `Err(())`.
    fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
        let key = s.to_lowercase();
        let tier = match key.as_str() {
            "small" | "s" => ModelSize::Small,
            "medium" | "m" | "default" => ModelSize::Medium,
            "large" | "l" | "max" => ModelSize::Large,
            _ => return Err(()),
        };
        Ok(tier)
    }
}

/// Common interface shared by every coding-agent backend.
///
/// Associated functions bounded by `Self: Sized` (model lookup and validation)
/// can only be called on concrete agent types, not through `dyn Agent`; the
/// remaining methods make up the object-level API.
// NOTE(review): the reason `dead_code` is silenced is not visible here —
// presumably some items are only exercised by particular backends; confirm.
#[async_trait]
#[allow(dead_code)]
pub trait Agent {
    /// Name of this agent.
    fn name(&self) -> &str;

    /// Model name this agent uses by default.
    // NOTE(review): presumably the fallback when no model is requested — confirm at call sites.
    fn default_model() -> &'static str
    where
        Self: Sized;

    /// Concrete model name this agent uses for the given size tier.
    fn model_for_size(size: ModelSize) -> &'static str
    where
        Self: Sized;

    /// Turn user-supplied model input into a concrete model name.
    ///
    /// Input that parses as a [`ModelSize`] alias (e.g. `small`, `m`, `max`)
    /// is mapped through [`Agent::model_for_size`]; any other input is
    /// returned unchanged.
    fn resolve_model(model_input: &str) -> String
    where
        Self: Sized,
    {
        match model_input.parse::<ModelSize>() {
            Ok(size) => Self::model_for_size(size).to_owned(),
            Err(_) => model_input.to_owned(),
        }
    }

    /// Every model name this agent accepts.
    fn available_models() -> &'static [&'static str]
    where
        Self: Sized;

    /// Check that `model` is exactly one of [`Agent::available_models`].
    ///
    /// Size aliases are not resolved here; call [`Agent::resolve_model`] first
    /// if aliases should be accepted.
    ///
    /// # Errors
    ///
    /// If `model` is not listed, returns an error naming `agent_name`. The
    /// error lists the valid choices: the small, medium and large models come
    /// first, each tagged with its tier. They are followed by every other
    /// available model, in declaration order.
    fn validate_model(model: &str, agent_name: &str) -> Result<()>
    where
        Self: Sized,
    {
        let available = Self::available_models();
        if available.contains(&model) {
            return Ok(());
        }

        // Tier-mapped models lead the listing so users see the aliases' targets first.
        let tiers = [
            Self::model_for_size(ModelSize::Small),
            Self::model_for_size(ModelSize::Medium),
            Self::model_for_size(ModelSize::Large),
        ];
        let [small, medium, large] = tiers;
        let mut listing = vec![
            format!("{} (small)", small),
            format!("{} (medium)", medium),
            format!("{} (large)", large),
        ];
        // Append the remaining models, skipping any already shown with a tier tag.
        listing.extend(
            available
                .iter()
                .copied()
                .filter(|candidate| !tiers.contains(candidate))
                .map(String::from),
        );

        anyhow::bail!(
            "Invalid model '{}' for {}. Available models: {}",
            model,
            agent_name,
            listing.join(", ")
        )
    }

    /// Current system prompt.
    fn system_prompt(&self) -> &str;

    /// Replace the system prompt.
    fn set_system_prompt(&mut self, prompt: String);

    /// Currently selected model name.
    fn get_model(&self) -> &str;

    /// Select the model to use.
    fn set_model(&mut self, model: String);

    /// Set the agent's root path.
    // NOTE(review): presumably the working/project directory — confirm.
    fn set_root(&mut self, root: String);

    /// Toggle skipping of permission checks; exact effect is backend-specific.
    fn set_skip_permissions(&mut self, skip: bool);

    /// Choose an output format.
    // NOTE(review): `None` presumably means "backend default" — confirm.
    fn set_output_format(&mut self, format: Option<String>);

    /// Turn output capture on or off.
    ///
    /// When enabled, non-interactive `run()` pipes stdout, collects the text
    /// and returns it as `Some(AgentOutput)`. When disabled (the default),
    /// output streams to the terminal. The default body does nothing because
    /// Claude performs capture through `output_format` instead.
    fn set_capture_output(&mut self, _capture: bool) {}

    /// Cap the number of agentic turns. No-op unless overridden.
    fn set_max_turns(&mut self, _turns: u32) {}

    /// Provide the configuration for running inside a Docker sandbox. No-op unless overridden.
    fn set_sandbox(&mut self, _config: SandboxConfig) {}

    /// Extra directories the agent should include.
    fn set_add_dirs(&mut self, dirs: Vec<String>);

    /// Environment variables forwarded to the agent subprocess. No-op unless overridden.
    fn set_env_vars(&mut self, _vars: Vec<(String, String)>) {}

    /// Borrow `self` as `Any` so callers can downcast to the concrete agent.
    fn as_any_ref(&self) -> &dyn std::any::Any;

    /// Mutably borrow `self` as `Any` so callers can downcast to the concrete agent.
    fn as_any_mut(&mut self) -> &mut dyn std::any::Any;

    /// Execute the agent non-interactively, optionally with a prompt.
    ///
    /// Yields `Some(AgentOutput)` when the backend produces structured output
    /// (for example JSON mode), and `None` otherwise.
    async fn run(&self, prompt: Option<&str>) -> Result<Option<AgentOutput>>;

    /// Start an interactive session, optionally seeded with `prompt`.
    async fn run_interactive(&self, prompt: Option<&str>) -> Result<()>;

    /// Continue an earlier session.
    ///
    /// A given `session_id` selects that exact session. Otherwise, `last`
    /// picks the most recent one. With neither, the backend either shows a
    /// session picker or falls back to the most recent session.
    async fn run_resume(&self, session_id: Option<&str>, last: bool) -> Result<()>;

    /// Continue an earlier session with a fresh prompt (e.g. for retry or correction).
    ///
    /// Backends that support structured output return `Some(AgentOutput)`.
    /// The default body always fails, signalling that the operation is
    /// unsupported.
    async fn run_resume_with_prompt(
        &self,
        _session_id: &str,
        _prompt: &str,
    ) -> Result<Option<AgentOutput>> {
        anyhow::bail!("Resume with prompt is not supported by this agent")
    }

    /// Cheap startup health check used by provider fallback.
    ///
    /// Backends that can detect a broken startup state (such as missing auth)
    /// without spending paid API quota should override this. An `Err` is
    /// treated as grounds to fall back to the next provider in the tier list,
    /// unless the user pinned a provider with `-p`.
    ///
    /// The default succeeds unconditionally, because the pre-flight PATH
    /// lookup (`preflight::check_binary`) already covers a missing binary.
    async fn probe(&self) -> Result<()> {
        Ok(())
    }

    /// Tear-down hook invoked when the agent is done; what it releases is backend-specific.
    async fn cleanup(&self) -> Result<()>;
}

// Unit tests for this module live in the separate file `agent_tests.rs`
// (wired in via `#[path]`) and are compiled only for test builds.
#[cfg(test)]
#[path = "agent_tests.rs"]
mod tests;