Skip to main content

git_iris/
config.rs

1//! Configuration management for Git-Iris.
2//!
3//! Handles personal config (~/.config/git-iris/config.toml) and
4//! per-project config (.irisconfig) with proper layering.
5
6use crate::git::GitRepo;
7use crate::instruction_presets::get_instruction_preset_library;
8use crate::log_debug;
9use crate::providers::{Provider, ProviderConfig};
10
11use anyhow::{Context, Result, anyhow};
12use dirs::{config_dir, home_dir};
13use serde::{Deserialize, Serialize};
14use std::collections::HashMap;
15use std::fs;
16use std::path::{Path, PathBuf};
17
18/// Project configuration filename
19pub const PROJECT_CONFIG_FILENAME: &str = ".irisconfig";
20
21/// Main configuration structure
22#[derive(Deserialize, Serialize, Clone, Debug)]
23pub struct Config {
24    /// Default LLM provider
25    #[serde(default, skip_serializing_if = "String::is_empty")]
26    pub default_provider: String,
27    /// Provider-specific configurations (keyed by provider name)
28    #[serde(default, skip_serializing_if = "HashMap::is_empty")]
29    pub providers: HashMap<String, ProviderConfig>,
30    /// Use gitmoji in commit messages
31    #[serde(default = "default_true", skip_serializing_if = "is_true")]
32    pub use_gitmoji: bool,
33    /// Custom instructions for all operations
34    #[serde(default, skip_serializing_if = "String::is_empty")]
35    pub instructions: String,
36    /// Instruction preset name
37    #[serde(default = "default_preset", skip_serializing_if = "is_default_preset")]
38    pub instruction_preset: String,
39    /// Theme name (empty = default `SilkCircuit` Neon)
40    #[serde(default, skip_serializing_if = "String::is_empty")]
41    pub theme: String,
42    /// Timeout in seconds for parallel subagent tasks (default: 120)
43    #[serde(
44        default = "default_subagent_timeout",
45        skip_serializing_if = "is_default_subagent_timeout"
46    )]
47    pub subagent_timeout_secs: u64,
48    /// Turn budget for parallel subagent tasks (default: 20)
49    #[serde(
50        default = "default_subagent_max_turns",
51        skip_serializing_if = "is_default_subagent_max_turns"
52    )]
53    pub subagent_max_turns: usize,
54    /// Run a critic verification pass after generated artifacts (default: true)
55    #[serde(default = "default_true", skip_serializing_if = "is_true")]
56    pub critic_enabled: bool,
57    /// Runtime-only: temporary instructions override
58    #[serde(skip)]
59    pub temp_instructions: Option<String>,
60    /// Runtime-only: temporary preset override
61    #[serde(skip)]
62    pub temp_preset: Option<String>,
63    /// Runtime-only: flag if loaded from project config
64    #[serde(skip)]
65    pub is_project_config: bool,
66    /// Runtime-only: whether gitmoji was explicitly set via CLI (None = use style detection)
67    #[serde(skip)]
68    pub gitmoji_override: Option<bool>,
69}
70
/// Serde default for boolean options that are enabled unless configured otherwise.
const fn default_true() -> bool {
    true
}

/// Serde `skip_serializing_if` predicate: omit a flag while it still holds its default `true`.
// serde hands the field over by reference, so the `&bool` signature is required.
#[allow(clippy::trivially_copy_pass_by_ref)]
fn is_true(val: &bool) -> bool {
    let &flag = val;
    flag
}
79
/// Name of the instruction preset used when none is configured.
fn default_preset() -> String {
    String::from("default")
}

/// Whether `val` denotes the default preset; an empty name counts as default too.
fn is_default_preset(val: &str) -> bool {
    matches!(val, "" | "default")
}
87
/// Default timeout, in seconds, for parallel subagent tasks (2 minutes).
fn default_subagent_timeout() -> u64 {
    120
}

/// Serde `skip_serializing_if` predicate: omit the timeout while it equals the default.
///
/// Compares against [`default_subagent_timeout`] rather than repeating the
/// literal, so changing the default cannot desynchronize the two.
// serde hands the field over by reference, so the `&u64` signature is required.
#[allow(clippy::trivially_copy_pass_by_ref)]
fn is_default_subagent_timeout(val: &u64) -> bool {
    *val == default_subagent_timeout()
}
96
/// Default turn budget for parallel subagent tasks.
fn default_subagent_max_turns() -> usize {
    20
}

/// Serde `skip_serializing_if` predicate: omit the turn budget while it equals the default.
///
/// Compares against [`default_subagent_max_turns`] rather than repeating the
/// literal, so changing the default cannot desynchronize the two.
// serde hands the field over by reference, so the `&usize` signature is required.
#[allow(clippy::trivially_copy_pass_by_ref)]
fn is_default_subagent_max_turns(val: &usize) -> bool {
    *val == default_subagent_max_turns()
}
105
106impl Default for Config {
107    fn default() -> Self {
108        let mut providers = HashMap::new();
109        for provider in Provider::ALL {
110            providers.insert(
111                provider.name().to_string(),
112                ProviderConfig::with_defaults(*provider),
113            );
114        }
115
116        Self {
117            default_provider: Provider::default().name().to_string(),
118            providers,
119            use_gitmoji: true,
120            instructions: String::new(),
121            instruction_preset: default_preset(),
122            theme: String::new(),
123            subagent_timeout_secs: default_subagent_timeout(),
124            subagent_max_turns: default_subagent_max_turns(),
125            critic_enabled: true,
126            temp_instructions: None,
127            temp_preset: None,
128            is_project_config: false,
129            gitmoji_override: None,
130        }
131    }
132}
133
134impl Config {
135    /// Load configuration (personal + project overlay)
136    ///
137    /// # Errors
138    ///
139    /// Returns an error when personal or project configuration cannot be read or parsed.
140    pub fn load() -> Result<Self> {
141        let config_path = Self::get_personal_config_path()?;
142        let mut config = if config_path.exists() {
143            let content = fs::read_to_string(&config_path)?;
144            let parsed: Self = toml::from_str(&content)?;
145            let (migrated, needs_save) = Self::migrate_if_needed(parsed);
146            if needs_save && let Err(e) = migrated.save() {
147                log_debug!("Failed to save migrated config: {}", e);
148            }
149            migrated
150        } else {
151            Self::default()
152        };
153
154        // Overlay project config if available
155        if let Ok((project_config, project_source)) = Self::load_project_config_with_source() {
156            config.merge_loaded_project_config(project_config, &project_source);
157        }
158
159        log_debug!(
160            "Configuration loaded (provider: {}, gitmoji: {})",
161            config.default_provider,
162            config.use_gitmoji
163        );
164        Ok(config)
165    }
166
167    /// Load project-specific configuration
168    ///
169    /// # Errors
170    ///
171    /// Returns an error when the project configuration file is missing or invalid.
172    pub fn load_project_config() -> Result<Self> {
173        let (config, _) = Self::load_project_config_with_source()?;
174        Ok(config)
175    }
176
177    fn load_project_config_with_source() -> Result<(Self, toml::Value)> {
178        let config_path = Self::get_project_config_path()?;
179        if !config_path.exists() {
180            return Err(anyhow!("Project configuration file not found"));
181        }
182
183        let content = fs::read_to_string(&config_path)
184            .with_context(|| format!("Failed to read {}", config_path.display()))?;
185        let project_source = toml::from_str(&content).with_context(|| {
186            format!(
187                "Invalid {} format. Check for syntax errors.",
188                PROJECT_CONFIG_FILENAME
189            )
190        })?;
191
192        let mut config: Self = toml::from_str(&content).with_context(|| {
193            format!(
194                "Invalid {} format. Check for syntax errors.",
195                PROJECT_CONFIG_FILENAME
196            )
197        })?;
198
199        config.is_project_config = true;
200        Ok((config, project_source))
201    }
202
203    /// Get path to project config file
204    ///
205    /// # Errors
206    ///
207    /// Returns an error when the current repository root cannot be resolved.
208    pub fn get_project_config_path() -> Result<PathBuf> {
209        let repo_root = GitRepo::get_repo_root()?;
210        Ok(repo_root.join(PROJECT_CONFIG_FILENAME))
211    }
212
213    /// Merge project config into this config (project takes precedence, but never API keys)
214    pub fn merge_with_project_config(&mut self, project_config: Self) {
215        log_debug!("Merging with project configuration");
216
217        // Override default provider if set
218        if !project_config.default_provider.is_empty()
219            && project_config.default_provider != Provider::default().name()
220        {
221            self.default_provider = project_config.default_provider;
222        }
223
224        // Merge provider configs (never override API keys from project config)
225        for (provider_name, proj_config) in project_config.providers {
226            let entry = self.providers.entry(provider_name).or_default();
227
228            if !proj_config.model.is_empty() {
229                entry.model = proj_config.model;
230            }
231            if proj_config.fast_model.is_some() {
232                entry.fast_model = proj_config.fast_model;
233            }
234            if proj_config.token_limit.is_some() {
235                entry.token_limit = proj_config.token_limit;
236            }
237            entry
238                .additional_params
239                .extend(proj_config.additional_params);
240        }
241
242        // Override other settings
243        self.use_gitmoji = project_config.use_gitmoji;
244        self.instructions = project_config.instructions;
245
246        if project_config.instruction_preset != default_preset() {
247            self.instruction_preset = project_config.instruction_preset;
248        }
249
250        // Theme override
251        if !project_config.theme.is_empty() {
252            self.theme = project_config.theme;
253        }
254
255        // Subagent timeout override
256        if project_config.subagent_timeout_secs != default_subagent_timeout() {
257            self.subagent_timeout_secs = project_config.subagent_timeout_secs;
258        }
259        if project_config.subagent_max_turns != default_subagent_max_turns() {
260            self.subagent_max_turns = project_config.subagent_max_turns;
261        }
262    }
263
264    fn merge_loaded_project_config(&mut self, project_config: Self, project_source: &toml::Value) {
265        log_debug!("Merging loaded project configuration with explicit field tracking");
266
267        self.merge_project_provider_config(&project_config);
268
269        if Self::project_config_has_key(project_source, "default_provider") {
270            self.default_provider = project_config.default_provider;
271        }
272        if Self::project_config_has_key(project_source, "use_gitmoji") {
273            self.use_gitmoji = project_config.use_gitmoji;
274        }
275        if Self::project_config_has_key(project_source, "instructions") {
276            self.instructions = project_config.instructions;
277        }
278        if Self::project_config_has_key(project_source, "instruction_preset") {
279            self.instruction_preset = project_config.instruction_preset;
280        }
281        if Self::project_config_has_key(project_source, "theme") {
282            self.theme = project_config.theme;
283        }
284        if Self::project_config_has_key(project_source, "subagent_timeout_secs") {
285            self.subagent_timeout_secs = project_config.subagent_timeout_secs;
286        }
287        if Self::project_config_has_key(project_source, "subagent_max_turns") {
288            self.subagent_max_turns = project_config.subagent_max_turns;
289        }
290        if Self::project_config_has_key(project_source, "critic_enabled") {
291            self.critic_enabled = project_config.critic_enabled;
292        }
293    }
294
295    fn merge_project_provider_config(&mut self, project_config: &Self) {
296        for (provider_name, proj_config) in &project_config.providers {
297            let entry = self.providers.entry(provider_name.clone()).or_default();
298
299            if !proj_config.model.is_empty() {
300                proj_config.model.clone_into(&mut entry.model);
301            }
302            if proj_config.fast_model.is_some() {
303                entry.fast_model.clone_from(&proj_config.fast_model);
304            }
305            if proj_config.token_limit.is_some() {
306                entry.token_limit = proj_config.token_limit;
307            }
308            entry
309                .additional_params
310                .extend(proj_config.additional_params.clone());
311        }
312    }
313
314    fn project_config_has_key(project_source: &toml::Value, key: &str) -> bool {
315        project_source
316            .as_table()
317            .is_some_and(|table| table.contains_key(key))
318    }
319
320    /// Migrate older config formats. Pure — never touches the filesystem.
321    ///
322    /// Returns the (possibly updated) config and a flag indicating whether any
323    /// migration actually happened. Callers that loaded from disk (i.e. `load`)
324    /// are responsible for persisting the migrated form; tests and other
325    /// in-memory users can ignore the flag. Keeping this pure stops test
326    /// fixtures from clobbering the user's real config file.
327    fn migrate_if_needed(mut config: Self) -> (Self, bool) {
328        let mut migrated = false;
329
330        for (legacy, canonical) in [("claude", "anthropic"), ("gemini", "google")] {
331            if let Some(legacy_config) = config.providers.remove(legacy) {
332                log_debug!("Migrating '{legacy}' provider to '{canonical}'");
333
334                if config.providers.contains_key(canonical) {
335                    log_debug!(
336                        "Keeping existing '{canonical}' config and dropping legacy '{legacy}' entry"
337                    );
338                } else {
339                    config
340                        .providers
341                        .insert(canonical.to_string(), legacy_config);
342                }
343
344                migrated = true;
345            }
346
347            if config.default_provider.eq_ignore_ascii_case(legacy) {
348                config.default_provider = canonical.to_string();
349                migrated = true;
350            }
351        }
352
353        (config, migrated)
354    }
355
356    /// Save configuration to personal config file
357    ///
358    /// # Errors
359    ///
360    /// Returns an error when the personal configuration file cannot be serialized or written.
361    pub fn save(&self) -> Result<()> {
362        if self.is_project_config {
363            return Ok(());
364        }
365
366        let config_path = Self::get_personal_config_path()?;
367        let content = toml::to_string_pretty(self)?;
368        Self::write_config_file(&config_path, &content)?;
369        log_debug!("Configuration saved");
370        Ok(())
371    }
372
373    /// Save as project-specific configuration (strips API keys)
374    ///
375    /// # Errors
376    ///
377    /// Returns an error when the project configuration file cannot be serialized or written.
378    pub fn save_as_project_config(&self) -> Result<()> {
379        let config_path = Self::get_project_config_path()?;
380
381        let mut project_config = self.clone();
382        project_config.is_project_config = true;
383
384        // Strip API keys for security
385        for provider_config in project_config.providers.values_mut() {
386            provider_config.api_key.clear();
387        }
388
389        let content = toml::to_string_pretty(&project_config)?;
390        Self::write_config_file(&config_path, &content)?;
391        Ok(())
392    }
393
394    /// Write content to a config file with restricted permissions.
395    ///
396    /// On Unix, creates a temp file with 0o600 permissions first, writes content,
397    /// then renames into place — so the target path is never world-readable.
398    /// Warns (via stderr) if permission hardening fails rather than silently ignoring.
399    fn write_config_file(path: &Path, content: &str) -> Result<()> {
400        #[cfg(unix)]
401        {
402            use std::os::unix::fs::PermissionsExt;
403
404            // Write to a sibling temp file so rename is atomic on the same filesystem
405            let tmp_path = path.with_extension("tmp");
406            fs::write(&tmp_path, content)?;
407            if let Err(e) = fs::set_permissions(&tmp_path, fs::Permissions::from_mode(0o600)) {
408                eprintln!(
409                    "Warning: Could not restrict config permissions on {}: {e}",
410                    tmp_path.display()
411                );
412            }
413            fs::rename(&tmp_path, path)?;
414        }
415
416        #[cfg(not(unix))]
417        {
418            fs::write(path, content)?;
419        }
420
421        Ok(())
422    }
423
424    /// Resolve the directory that should hold `config.toml`.
425    ///
426    /// Precedence:
427    /// 1. `$XDG_CONFIG_HOME/git-iris` when the env var is set and non-empty.
428    /// 2. `~/Library/Application Support/git-iris` on macOS **only** when a
429    ///    config already exists there — this keeps pre-XDG installs working.
430    /// 3. `$HOME/.config/git-iris` — the XDG-style default that lines up with
431    ///    how `gh`, `neovim`, `bat`, `ripgrep`, `helix`, `starship`, and the
432    ///    rest of the modern CLI ecosystem behave on macOS.
433    /// 4. `dirs::config_dir()/git-iris` as a last-resort fallback when `$HOME`
434    ///    is unreachable (should only happen in exotic sandboxes).
435    ///
436    /// This function is pure — filesystem probing for the legacy macOS path
437    /// happens in `get_personal_config_path` so the resolver stays easy to
438    /// unit-test with synthetic inputs.
439    fn resolve_personal_config_dir(
440        xdg_config_home: Option<PathBuf>,
441        home_dir: Option<PathBuf>,
442        platform_config_dir: Option<PathBuf>,
443        legacy_macos_config_exists: bool,
444    ) -> Result<PathBuf> {
445        if let Some(xdg) = xdg_config_home.filter(|path| !path.as_os_str().is_empty()) {
446            return Ok(xdg.join("git-iris"));
447        }
448
449        if legacy_macos_config_exists && let Some(platform) = platform_config_dir.clone() {
450            return Ok(platform.join("git-iris"));
451        }
452
453        if let Some(home) = home_dir {
454            return Ok(home.join(".config").join("git-iris"));
455        }
456
457        platform_config_dir
458            .map(|p| p.join("git-iris"))
459            .ok_or_else(|| anyhow!("Unable to determine config directory"))
460    }
461
462    /// Get path to personal config file
463    ///
464    /// # Errors
465    ///
466    /// Returns an error when the config directory cannot be resolved or created.
467    pub fn get_personal_config_path() -> Result<PathBuf> {
468        let platform_dir = config_dir();
469
470        // Only probe the legacy macOS location on macOS. On every other
471        // platform `dirs::config_dir()` already maps to `$HOME/.config` (or an
472        // equivalent), so treating the existence check as macOS-only avoids
473        // falsely flagging a Linux user's `~/.config/git-iris` as "legacy".
474        let legacy_macos_config_exists = cfg!(target_os = "macos")
475            && platform_dir
476                .as_ref()
477                .is_some_and(|dir| dir.join("git-iris").join("config.toml").exists());
478
479        let mut path = Self::resolve_personal_config_dir(
480            std::env::var_os("XDG_CONFIG_HOME").map(PathBuf::from),
481            home_dir(),
482            platform_dir,
483            legacy_macos_config_exists,
484        )?;
485        fs::create_dir_all(&path)?;
486        path.push("config.toml");
487        Ok(path)
488    }
489
490    /// Check environment prerequisites
491    ///
492    /// # Errors
493    ///
494    /// Returns an error when the current working directory is not inside a Git repository.
495    pub fn check_environment(&self) -> Result<()> {
496        if !GitRepo::is_inside_work_tree()? {
497            return Err(anyhow!(
498                "Not in a Git repository. Please run this command from within a Git repository."
499            ));
500        }
501        Ok(())
502    }
503
504    /// Set temporary instructions for this session
505    pub fn set_temp_instructions(&mut self, instructions: Option<String>) {
506        self.temp_instructions = instructions;
507    }
508
509    /// Set temporary preset for this session
510    pub fn set_temp_preset(&mut self, preset: Option<String>) {
511        self.temp_preset = preset;
512    }
513
514    /// Get effective preset name (temp overrides saved)
515    #[must_use]
516    pub fn get_effective_preset_name(&self) -> &str {
517        self.temp_preset
518            .as_deref()
519            .unwrap_or(&self.instruction_preset)
520    }
521
522    /// Get effective instructions (combines preset + custom)
523    #[must_use]
524    pub fn get_effective_instructions(&self) -> String {
525        let preset_library = get_instruction_preset_library();
526        let preset_instructions = self
527            .temp_preset
528            .as_ref()
529            .or(Some(&self.instruction_preset))
530            .and_then(|p| preset_library.get_preset(p))
531            .map(|p| p.instructions.clone())
532            .unwrap_or_default();
533
534        let custom = self
535            .temp_instructions
536            .as_ref()
537            .unwrap_or(&self.instructions);
538
539        format!("{preset_instructions}\n\n{custom}")
540            .trim()
541            .to_string()
542    }
543
544    /// Update configuration with new values
545    #[allow(clippy::too_many_arguments, clippy::needless_pass_by_value)]
546    ///
547    /// # Errors
548    ///
549    /// Returns an error when the provider is invalid or the provider config cannot be updated.
550    pub fn update(
551        &mut self,
552        provider: Option<String>,
553        api_key: Option<String>,
554        model: Option<String>,
555        fast_model: Option<String>,
556        additional_params: Option<HashMap<String, String>>,
557        use_gitmoji: Option<bool>,
558        instructions: Option<String>,
559        token_limit: Option<usize>,
560    ) -> Result<()> {
561        if let Some(ref provider_name) = provider {
562            // Validate provider
563            let parsed: Provider = provider_name.parse().with_context(|| {
564                format!(
565                    "Unknown provider '{}'. Supported: {}",
566                    provider_name,
567                    Provider::all_names().join(", ")
568                )
569            })?;
570
571            self.default_provider = parsed.name().to_string();
572
573            // Ensure provider config exists
574            if !self.providers.contains_key(parsed.name()) {
575                self.providers.insert(
576                    parsed.name().to_string(),
577                    ProviderConfig::with_defaults(parsed),
578                );
579            }
580        }
581
582        let provider_config = self
583            .providers
584            .get_mut(&self.default_provider)
585            .context("Could not get default provider config")?;
586
587        if let Some(key) = api_key {
588            provider_config.api_key = key;
589        }
590        if let Some(m) = model {
591            provider_config.model = m;
592        }
593        if let Some(fm) = fast_model {
594            provider_config.fast_model = Some(fm);
595        }
596        if let Some(params) = additional_params {
597            provider_config.additional_params.extend(params);
598        }
599        if let Some(gitmoji) = use_gitmoji {
600            self.use_gitmoji = gitmoji;
601        }
602        if let Some(instr) = instructions {
603            self.instructions = instr;
604        }
605        if let Some(limit) = token_limit {
606            provider_config.token_limit = Some(limit);
607        }
608
609        log_debug!("Configuration updated");
610        Ok(())
611    }
612
613    /// Get the provider configuration for a specific provider
614    #[must_use]
615    pub fn get_provider_config(&self, provider: &str) -> Option<&ProviderConfig> {
616        // Handle legacy/common aliases
617        let name = if provider.eq_ignore_ascii_case("claude") {
618            "anthropic"
619        } else if provider.eq_ignore_ascii_case("gemini") {
620            "google"
621        } else {
622            provider
623        };
624
625        self.providers
626            .get(name)
627            .or_else(|| self.providers.get(&name.to_lowercase()))
628    }
629
630    /// Get the current provider as `Provider` enum
631    #[must_use]
632    pub fn provider(&self) -> Option<Provider> {
633        self.default_provider.parse().ok()
634    }
635
636    /// Validate that the current provider is properly configured
637    ///
638    /// # Errors
639    ///
640    /// Returns an error when the provider is invalid or no API key is configured.
641    pub fn validate(&self) -> Result<()> {
642        let provider: Provider = self
643            .default_provider
644            .parse()
645            .with_context(|| format!("Invalid provider: {}", self.default_provider))?;
646
647        let config = self
648            .get_provider_config(provider.name())
649            .ok_or_else(|| anyhow!("No configuration found for provider: {}", provider.name()))?;
650
651        if !config.has_api_key() {
652            // Check environment variable as fallback
653            if std::env::var(provider.api_key_env()).is_err() {
654                return Err(anyhow!(
655                    "API key required for {}. Set {} or configure in ~/.config/git-iris/config.toml",
656                    provider.name(),
657                    provider.api_key_env()
658                ));
659            }
660        }
661
662        Ok(())
663    }
664}
665
666#[cfg(test)]
667mod tests;