Skip to main content

zag_agent/
config.rs

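//! Configuration management for the zag CLI.
//!
//! Configuration is stored in `~/.zag/projects/<sanitized-path>/zag.toml`,
//! where the sanitized path is derived from the explicit `--root` or, failing
//! that, the git repository root. Outside a git repository (and without
//! `--root`) the global `~/.zag/zag.toml` is used instead.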
5
6use anyhow::{Context, Result};
7use log::debug;
8use serde::{Deserialize, Serialize};
9use std::path::{Path, PathBuf};
10use std::process::Command;
11
12/// Agent-specific model configuration.
13#[derive(Debug, Clone, Default, Serialize, Deserialize)]
14pub struct AgentModels {
15    pub claude: Option<String>,
16    pub codex: Option<String>,
17    pub gemini: Option<String>,
18    pub copilot: Option<String>,
19    pub ollama: Option<String>,
20}
21
22/// Ollama-specific configuration.
23#[derive(Debug, Clone, Default, Serialize, Deserialize)]
24pub struct OllamaConfig {
25    /// Default model name (default: "qwen3.5")
26    pub model: Option<String>,
27    /// Default parameter size (default: "9b")
28    pub size: Option<String>,
29    /// Parameter size for small alias
30    pub size_small: Option<String>,
31    /// Parameter size for medium alias
32    pub size_medium: Option<String>,
33    /// Parameter size for large alias
34    pub size_large: Option<String>,
35}
36
37/// Default settings applied when not overridden by CLI flags.
38#[derive(Debug, Clone, Default, Serialize, Deserialize)]
39pub struct Defaults {
40    /// Auto-approve all actions (skip permission prompts)
41    pub auto_approve: Option<bool>,
42    /// Default model size for all agents (small, medium, large)
43    pub model: Option<String>,
44    /// Default provider (claude, codex, gemini, copilot)
45    pub provider: Option<String>,
46    /// Default maximum number of agentic turns
47    pub max_turns: Option<u32>,
48    /// Default system prompt for all agents
49    pub system_prompt: Option<String>,
50}
51
52/// Auto-selection configuration.
53#[derive(Debug, Clone, Default, Serialize, Deserialize)]
54pub struct AutoConfig {
55    /// Provider used for auto-selection (default: "claude")
56    pub provider: Option<String>,
57    /// Model used for auto-selection (default: "sonnet")
58    pub model: Option<String>,
59}
60
61/// Listen command configuration.
62#[derive(Debug, Clone, Default, Serialize, Deserialize)]
63pub struct ListenConfig {
64    /// Default output format: "text", "json", or "rich-text"
65    pub format: Option<String>,
66    /// strftime-style format for timestamps (default: "%H:%M:%S")
67    pub timestamp_format: Option<String>,
68}
69
70/// Root configuration structure.
71#[derive(Debug, Clone, Default, Serialize, Deserialize)]
72pub struct Config {
73    /// Default settings
74    #[serde(default)]
75    pub defaults: Defaults,
76    /// Per-agent model defaults
77    #[serde(default)]
78    pub models: AgentModels,
79    /// Auto-selection settings
80    #[serde(default)]
81    pub auto: AutoConfig,
82    /// Ollama-specific settings
83    #[serde(default)]
84    pub ollama: OllamaConfig,
85    /// Listen command settings
86    #[serde(default)]
87    pub listen: ListenConfig,
88}
89
90impl Config {
91    /// Load configuration from `~/.zag/projects/<id>/zag.toml`.
92    ///
93    /// The project ID is derived from the git repo root or explicit `--root`.
94    /// Returns default config if file doesn't exist.
95    pub fn load(root: Option<&str>) -> Result<Self> {
96        let path = Self::config_path(root);
97        debug!("Loading config from {}", path.display());
98        if !path.exists() {
99            debug!("Config file not found, using defaults");
100            return Ok(Self::default());
101        }
102
103        let content = std::fs::read_to_string(&path)
104            .with_context(|| format!("Failed to read config: {}", path.display()))?;
105        let config: Config = toml::from_str(&content)
106            .with_context(|| format!("Failed to parse config: {}", path.display()))?;
107        debug!("Config loaded successfully from {}", path.display());
108        Ok(config)
109    }
110
111    /// Save configuration to `~/.zag/projects/<id>/zag.toml`.
112    ///
113    /// Creates the directory if it doesn't exist.
114    pub fn save(&self, root: Option<&str>) -> Result<()> {
115        let path = Self::config_path(root);
116        debug!("Saving config to {}", path.display());
117        if let Some(parent) = path.parent() {
118            std::fs::create_dir_all(parent)
119                .with_context(|| format!("Failed to create directory: {}", parent.display()))?;
120        }
121
122        let content = toml::to_string_pretty(self).context("Failed to serialize config")?;
123        std::fs::write(&path, content)
124            .with_context(|| format!("Failed to write config: {}", path.display()))?;
125        debug!("Config saved to {}", path.display());
126        Ok(())
127    }
128
129    /// Initialize config file with defaults if it doesn't exist.
130    ///
131    /// Returns true if a new config was created, false if it already existed.
132    pub fn init(root: Option<&str>) -> Result<bool> {
133        let path = Self::config_path(root);
134        if path.exists() {
135            debug!("Config already exists at {}", path.display());
136            return Ok(false);
137        }
138
139        debug!("Initializing new config at {}", path.display());
140        let config = Self::default_with_comments();
141        if let Some(parent) = path.parent() {
142            std::fs::create_dir_all(parent)
143                .with_context(|| format!("Failed to create directory: {}", parent.display()))?;
144        }
145
146        std::fs::write(&path, config)
147            .with_context(|| format!("Failed to write config: {}", path.display()))?;
148
149        Ok(true)
150    }
151
152    /// Detect git repository root from a given directory.
153    /// Returns None if not in a git repository.
154    fn find_git_root(start_dir: &Path) -> Option<PathBuf> {
155        let output = Command::new("git")
156            .arg("rev-parse")
157            .arg("--show-toplevel")
158            .current_dir(start_dir)
159            .output()
160            .ok()?;
161
162        if output.status.success() {
163            let root = String::from_utf8(output.stdout).ok()?;
164            Some(PathBuf::from(root.trim()))
165        } else {
166            None
167        }
168    }
169
170    /// Get the global base directory (~/.zag).
171    pub fn global_base_dir() -> PathBuf {
172        dirs::home_dir()
173            .unwrap_or_else(|| PathBuf::from("."))
174            .join(".zag")
175    }
176
177    /// Sanitize an absolute path into a directory name.
178    /// Strips leading `/` and replaces `/` with `-`.
179    pub fn sanitize_path(path: &str) -> String {
180        path.trim_start_matches('/').replace('/', "-")
181    }
182
183    /// Resolve the project directory for config/session storage.
184    ///
185    /// All state is stored under `~/.zag/`:
186    /// - Per-project: `~/.zag/projects/<sanitized-path>/`
187    /// - Global (no repo): `~/.zag/`
188    fn resolve_project_dir(root: Option<&str>) -> PathBuf {
189        let base = Self::global_base_dir();
190
191        // Keep this helper free of logging. It is used by config/session path
192        // resolution on hot paths, and debug logging here can re-enter the same
193        // resolution flow through logger setup and formatting.
194        if let Some(r) = root {
195            let sanitized = Self::sanitize_path(r);
196            return base.join("projects").join(sanitized);
197        }
198
199        let current_dir = std::env::current_dir().unwrap_or_else(|_| PathBuf::from("."));
200
201        // Try to find git root
202        if let Some(git_root) = Self::find_git_root(&current_dir) {
203            let sanitized = Self::sanitize_path(&git_root.to_string_lossy());
204            return base.join("projects").join(sanitized);
205        }
206
207        // Fall back to global base directory (no project subdir)
208        base
209    }
210
211    /// Get the path to the config file.
212    pub fn config_path(root: Option<&str>) -> PathBuf {
213        Self::resolve_project_dir(root).join("zag.toml")
214    }
215
216    /// Get the project directory path (for sessions, etc.).
217    #[allow(dead_code)]
218    pub fn agent_dir(root: Option<&str>) -> PathBuf {
219        Self::resolve_project_dir(root)
220    }
221
222    /// Get the global logs directory path.
223    pub fn global_logs_dir() -> PathBuf {
224        Self::global_base_dir().join("logs")
225    }
226
227    /// Get the default model for a specific agent, if configured.
228    /// Checks agent-specific model first, then falls back to defaults.model.
229    pub fn get_model(&self, agent: &str) -> Option<&str> {
230        // First check agent-specific model
231        let agent_model = match agent {
232            "claude" => self.models.claude.as_deref(),
233            "codex" => self.models.codex.as_deref(),
234            "gemini" => self.models.gemini.as_deref(),
235            "copilot" => self.models.copilot.as_deref(),
236            "ollama" => self.models.ollama.as_deref(),
237            _ => None,
238        };
239
240        // Return agent-specific model if set, otherwise fall back to default
241        agent_model.or(self.defaults.model.as_deref())
242    }
243
244    /// Get the global default model (without agent-specific override).
245    #[allow(dead_code)]
246    pub fn default_model(&self) -> Option<&str> {
247        self.defaults.model.as_deref()
248    }
249
250    /// Get the ollama model name (default: "qwen3.5").
251    pub fn ollama_model(&self) -> &str {
252        self.ollama.model.as_deref().unwrap_or("qwen3.5")
253    }
254
255    /// Get the ollama default size (default: "9b").
256    pub fn ollama_size(&self) -> &str {
257        self.ollama.size.as_deref().unwrap_or("9b")
258    }
259
260    /// Get the ollama size for a model size alias, with config override.
261    pub fn ollama_size_for<'a>(&'a self, size: &'a str) -> &'a str {
262        match size {
263            "small" | "s" => self.ollama.size_small.as_deref().unwrap_or("2b"),
264            "medium" | "m" | "default" => self.ollama.size_medium.as_deref().unwrap_or("9b"),
265            "large" | "l" | "max" => self.ollama.size_large.as_deref().unwrap_or("35b"),
266            _ => size, // passthrough for explicit sizes like "27b"
267        }
268    }
269
270    /// Check if auto-approve is enabled by default.
271    pub fn auto_approve(&self) -> bool {
272        self.defaults.auto_approve.unwrap_or(false)
273    }
274
275    /// Get the default max turns, if configured.
276    pub fn max_turns(&self) -> Option<u32> {
277        self.defaults.max_turns
278    }
279
280    /// Get the default system prompt, if configured.
281    pub fn system_prompt(&self) -> Option<&str> {
282        self.defaults.system_prompt.as_deref()
283    }
284
285    /// Get the default provider, if configured.
286    pub fn provider(&self) -> Option<&str> {
287        self.defaults.provider.as_deref()
288    }
289
290    /// Get the auto-selection provider, if configured.
291    pub fn auto_provider(&self) -> Option<&str> {
292        self.auto.provider.as_deref()
293    }
294
295    /// Get the auto-selection model, if configured.
296    pub fn auto_model(&self) -> Option<&str> {
297        self.auto.model.as_deref()
298    }
299
300    /// Get the listen output format, if configured.
301    pub fn listen_format(&self) -> Option<&str> {
302        self.listen.format.as_deref()
303    }
304
305    /// Get the listen timestamp format (strftime-style, default: "%H:%M:%S").
306    pub fn listen_timestamp_format(&self) -> &str {
307        self.listen
308            .timestamp_format
309            .as_deref()
310            .unwrap_or("%H:%M:%S")
311    }
312
313    /// Valid provider names (including "auto").
314    #[cfg(not(test))]
315    pub const VALID_PROVIDERS: &'static [&'static str] =
316        &["claude", "codex", "gemini", "copilot", "ollama", "auto"];
317
318    /// Valid provider names (including "auto" and "mock" for testing).
319    #[cfg(test)]
320    pub const VALID_PROVIDERS: &'static [&'static str] = &[
321        "claude", "codex", "gemini", "copilot", "ollama", "auto", "mock",
322    ];
323
324    /// All valid config keys for listing/discovery.
325    pub const VALID_KEYS: &'static [&'static str] = &[
326        "provider",
327        "model",
328        "auto_approve",
329        "max_turns",
330        "system_prompt",
331        "model.claude",
332        "model.codex",
333        "model.gemini",
334        "model.copilot",
335        "model.ollama",
336        "auto.provider",
337        "auto.model",
338        "ollama.model",
339        "ollama.size",
340        "ollama.size_small",
341        "ollama.size_medium",
342        "ollama.size_large",
343        "listen.format",
344        "listen.timestamp_format",
345    ];
346
347    /// Get a config value by dot-notation key.
348    /// Get a config value by dot-notation key.
349    pub fn get_value(&self, key: &str) -> Option<String> {
350        match key {
351            "provider" => self.defaults.provider.clone(),
352            "model" => self.defaults.model.clone(),
353            "auto_approve" => self.defaults.auto_approve.map(|v| v.to_string()),
354            "max_turns" => self.defaults.max_turns.map(|v| v.to_string()),
355            "system_prompt" => self.defaults.system_prompt.clone(),
356            "model.claude" => self.models.claude.clone(),
357            "model.codex" => self.models.codex.clone(),
358            "model.gemini" => self.models.gemini.clone(),
359            "model.copilot" => self.models.copilot.clone(),
360            "model.ollama" => self.models.ollama.clone(),
361            "auto.provider" => self.auto.provider.clone(),
362            "auto.model" => self.auto.model.clone(),
363            "ollama.model" => self.ollama.model.clone(),
364            "ollama.size" => self.ollama.size.clone(),
365            "ollama.size_small" => self.ollama.size_small.clone(),
366            "ollama.size_medium" => self.ollama.size_medium.clone(),
367            "ollama.size_large" => self.ollama.size_large.clone(),
368            "listen.format" => self.listen.format.clone(),
369            "listen.timestamp_format" => self.listen.timestamp_format.clone(),
370            _ => None,
371        }
372    }
373
374    /// Set a config value by dot-notation key. Validates inputs.
375    pub fn set_value(&mut self, key: &str, value: &str) -> Result<()> {
376        debug!("Setting config: {} = {}", key, value);
377        match key {
378            "provider" => {
379                let v = value.to_lowercase();
380                if !Self::VALID_PROVIDERS.contains(&v.as_str()) {
381                    anyhow::bail!(
382                        "Invalid provider '{}'. Available: {}",
383                        value,
384                        Self::VALID_PROVIDERS.join(", ")
385                    );
386                }
387                self.defaults.provider = Some(v);
388            }
389            "model" => {
390                self.defaults.model = Some(value.to_string());
391            }
392            "max_turns" => {
393                let turns: u32 = value.parse().map_err(|_| {
394                    anyhow::anyhow!(
395                        "Invalid value '{}' for max_turns. Must be a positive integer.",
396                        value
397                    )
398                })?;
399                self.defaults.max_turns = Some(turns);
400            }
401            "system_prompt" => {
402                self.defaults.system_prompt = Some(value.to_string());
403            }
404            "auto_approve" => match value.to_lowercase().as_str() {
405                "true" | "1" | "yes" => self.defaults.auto_approve = Some(true),
406                "false" | "0" | "no" => self.defaults.auto_approve = Some(false),
407                _ => anyhow::bail!(
408                    "Invalid value '{}' for auto_approve. Use true or false.",
409                    value
410                ),
411            },
412            "model.claude" => self.models.claude = Some(value.to_string()),
413            "model.codex" => self.models.codex = Some(value.to_string()),
414            "model.gemini" => self.models.gemini = Some(value.to_string()),
415            "model.copilot" => self.models.copilot = Some(value.to_string()),
416            "model.ollama" => self.models.ollama = Some(value.to_string()),
417            "auto.provider" => self.auto.provider = Some(value.to_string()),
418            "auto.model" => self.auto.model = Some(value.to_string()),
419            "ollama.model" => self.ollama.model = Some(value.to_string()),
420            "ollama.size" => self.ollama.size = Some(value.to_string()),
421            "ollama.size_small" => self.ollama.size_small = Some(value.to_string()),
422            "ollama.size_medium" => self.ollama.size_medium = Some(value.to_string()),
423            "ollama.size_large" => self.ollama.size_large = Some(value.to_string()),
424            "listen.format" => {
425                let v = value.to_lowercase();
426                if !["text", "json", "rich-text"].contains(&v.as_str()) {
427                    anyhow::bail!(
428                        "Invalid listen format '{}'. Available: text, json, rich-text",
429                        value
430                    );
431                }
432                self.listen.format = Some(v);
433            }
434            "listen.timestamp_format" => {
435                self.listen.timestamp_format = Some(value.to_string());
436            }
437            _ => anyhow::bail!(
438                "Unknown config key '{}'. Available: provider, model, auto_approve, max_turns, system_prompt, model.claude, model.codex, model.gemini, model.copilot, model.ollama, auto.provider, auto.model, ollama.model, ollama.size, ollama.size_small, ollama.size_medium, ollama.size_large, listen.format, listen.timestamp_format",
439                key
440            ),
441        }
442        Ok(())
443    }
444
445    /// Unset a config value by dot-notation key (revert to default).
446    pub fn unset_value(&mut self, key: &str) -> Result<()> {
447        debug!("Unsetting config: {}", key);
448        match key {
449            "provider" => self.defaults.provider = None,
450            "model" => self.defaults.model = None,
451            "auto_approve" => self.defaults.auto_approve = None,
452            "max_turns" => self.defaults.max_turns = None,
453            "system_prompt" => self.defaults.system_prompt = None,
454            "model.claude" => self.models.claude = None,
455            "model.codex" => self.models.codex = None,
456            "model.gemini" => self.models.gemini = None,
457            "model.copilot" => self.models.copilot = None,
458            "model.ollama" => self.models.ollama = None,
459            "auto.provider" => self.auto.provider = None,
460            "auto.model" => self.auto.model = None,
461            "ollama.model" => self.ollama.model = None,
462            "ollama.size" => self.ollama.size = None,
463            "ollama.size_small" => self.ollama.size_small = None,
464            "ollama.size_medium" => self.ollama.size_medium = None,
465            "ollama.size_large" => self.ollama.size_large = None,
466            "listen.format" => self.listen.format = None,
467            "listen.timestamp_format" => self.listen.timestamp_format = None,
468            _ => anyhow::bail!(
469                "Unknown config key '{}'. Run 'zag config list' to see available keys.",
470                key
471            ),
472        }
473        Ok(())
474    }
475
476    /// Generate default config content with comments.
477    fn default_with_comments() -> String {
478        r#"# Zag CLI Configuration
479# This file configures default behavior for the zag CLI.
480# Settings here can be overridden by command-line flags.
481
482[defaults]
483# Default provider (claude, codex, gemini, copilot)
484# provider = "claude"
485
486# Auto-approve all actions (skip permission prompts)
487# auto_approve = false
488
489# Default model size for all agents (small, medium, large)
490# Can be overridden per-agent in [models] section
491model = "medium"
492
493# Default maximum number of agentic turns
494# max_turns = 10
495
496# Default system prompt for all agents
497# system_prompt = ""
498
499[models]
500# Default models for each agent (overrides defaults.model)
501# Use size aliases (small, medium, large) or specific model names
502# claude = "opus"
503# codex = "gpt-5.4"
504# gemini = "auto"
505# copilot = "claude-sonnet-4.6"
506
507[auto]
508# Settings for auto provider/model selection (-p auto / -m auto)
509# provider = "claude"
510# model = "haiku"
511
512[ollama]
513# Ollama-specific settings
514# model = "qwen3.5"
515# size = "9b"
516# size_small = "2b"
517# size_medium = "9b"
518# size_large = "35b"
519
520[listen]
521# Default output format for listen command: "text", "json", or "rich-text"
522# format = "text"
523# Timestamp format for --timestamps flag (strftime-style, default: "%H:%M:%S")
524# timestamp_format = "%H:%M:%S"
525"#
526        .to_string()
527    }
528}
529
530/// Resolve the provider name from a CLI flag, config default, or hardcoded fallback.
531///
532/// Validates the provider name against [`Config::VALID_PROVIDERS`].
533pub fn resolve_provider(flag: Option<&str>, root: Option<&str>) -> anyhow::Result<String> {
534    if let Some(p) = flag {
535        let p = p.to_lowercase();
536        if !Config::VALID_PROVIDERS.contains(&p.as_str()) {
537            anyhow::bail!(
538                "Invalid provider '{}'. Available: {}",
539                p,
540                Config::VALID_PROVIDERS.join(", ")
541            );
542        }
543        return Ok(p);
544    }
545
546    let config = Config::load(root).unwrap_or_default();
547    if let Some(p) = config.provider() {
548        return Ok(p.to_string());
549    }
550
551    Ok("claude".to_string())
552}
553
// Unit tests live in a sibling file and are compiled only for test builds.
#[path = "config_tests.rs"]
#[cfg(test)]
mod tests;