Skip to main content

auto_commit_rs/
config.rs

1use anyhow::{Context, Result};
2use serde::{Deserialize, Serialize};
3use std::collections::HashMap;
4use std::path::PathBuf;
5
/// Built-in system prompt sent to the LLM when no custom `llm_system_prompt` is configured.
///
/// Assembled from adjacent literals so each sentence fragment stays on its own source line.
const DEFAULT_SYSTEM_PROMPT: &str = concat!(
    "You are to act as an author of a commit message in git. ",
    "I'll send you an output of 'git diff --staged' command, and you are to convert ",
    "it into a commit message. Follow the Conventional Commits specification."
);
9
/// Fully resolved configuration for the tool.
///
/// Every field carries a serde default, so a partial TOML file still deserializes; the
/// `Default` impl uses the same default helpers. Individual fields can also be overridden
/// through `ACR_*` variables (see `ENV_FIELD_MAP`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AppConfig {
    /// LLM provider identifier (default `"groq"`).
    #[serde(default = "default_provider")]
    pub provider: String,
    /// Model name requested from the provider (default `"llama-3.3-70b-versatile"`).
    #[serde(default = "default_model")]
    pub model: String,
    /// API key; empty when unset. Shown masked by `fields_display`.
    #[serde(default)]
    pub api_key: String,
    /// Explicit API endpoint. Empty is displayed as "(auto from provider)" — presumably
    /// resolved from `provider` elsewhere; confirm with the caller.
    #[serde(default)]
    pub api_url: String,
    /// Extra API headers as a single string. Empty is displayed as "(auto from provider)";
    /// the expected format is not visible here — TODO confirm.
    #[serde(default)]
    pub api_headers: String,
    /// Output locale (default `"en"`); normalized to lowercase and validated on load.
    #[serde(default = "default_locale")]
    pub locale: String,
    /// Whether a single-line commit message is requested (default `true`).
    #[serde(default = "default_true")]
    pub one_liner: bool,
    /// Commit message template (default `"$msg"`). Presumably `$msg` is substituted with
    /// the generated message — verify against the caller.
    #[serde(default = "default_commit_template")]
    pub commit_template: String,
    /// System prompt sent to the LLM (default `DEFAULT_SYSTEM_PROMPT`).
    #[serde(default = "default_system_prompt")]
    pub llm_system_prompt: String,
    /// Whether gitmoji are used (default `false`).
    #[serde(default)]
    pub use_gitmoji: bool,
    /// Gitmoji rendering format (default `"unicode"`).
    #[serde(default = "default_gitmoji_format")]
    pub gitmoji_format: String,
    /// Whether the commit is reviewed before being made (default `false`).
    #[serde(default)]
    pub review_commit: bool,
    /// Push behaviour after committing: normalized to `"never"`, `"always"`, or `"ask"`
    /// (default) by `normalize_post_commit_push` whenever it is merged or set.
    #[serde(default = "default_post_commit_push")]
    pub post_commit_push: String,
    /// Whether output of invoked tools is suppressed (default `false`).
    #[serde(default)]
    pub suppress_tool_output: bool,
    /// Whether the staged-files warning is enabled (default `true`).
    #[serde(default = "default_true")]
    pub warn_staged_files_enabled: bool,
    /// Threshold for the staged-files warning (default 20) — presumably a count of staged
    /// files; confirm with the caller.
    #[serde(default = "default_warn_staged_files_threshold")]
    pub warn_staged_files_threshold: usize,
    /// Whether a confirmation is requested for a new version (default `true`); exact
    /// semantics are defined by the caller — TODO confirm.
    #[serde(default = "default_true")]
    pub confirm_new_version: bool,
}
47
48fn default_provider() -> String {
49    "groq".into()
50}
51fn default_model() -> String {
52    "llama-3.3-70b-versatile".into()
53}
54fn default_locale() -> String {
55    "en".into()
56}
57fn default_true() -> bool {
58    true
59}
60fn default_post_commit_push() -> String {
61    "ask".into()
62}
63fn default_commit_template() -> String {
64    "$msg".into()
65}
66fn default_system_prompt() -> String {
67    DEFAULT_SYSTEM_PROMPT.into()
68}
69fn default_gitmoji_format() -> String {
70    "unicode".into()
71}
72fn default_warn_staged_files_threshold() -> usize {
73    20
74}
75
76impl Default for AppConfig {
77    fn default() -> Self {
78        Self {
79            provider: default_provider(),
80            model: default_model(),
81            api_key: String::new(),
82            api_url: String::new(),
83            api_headers: String::new(),
84            locale: default_locale(),
85            one_liner: true,
86            commit_template: default_commit_template(),
87            llm_system_prompt: default_system_prompt(),
88            use_gitmoji: false,
89            gitmoji_format: default_gitmoji_format(),
90            review_commit: false,
91            post_commit_push: default_post_commit_push(),
92            suppress_tool_output: false,
93            warn_staged_files_enabled: true,
94            warn_staged_files_threshold: default_warn_staged_files_threshold(),
95            confirm_new_version: true,
96        }
97    }
98}
99
/// Map of `ACR_` env var suffix → struct field name.
///
/// The suffix column decides which `ACR_<SUFFIX>` keys are read from the process
/// environment and from the repo-local `.env`. The field-name column is descriptive only:
/// `apply_env_map` matches on the suffix and ignores it. Keep this table in sync with the
/// suffix matches in `apply_env_map`, `set_field` and `fields_display`.
const ENV_FIELD_MAP: &[(&str, &str)] = &[
    ("PROVIDER", "provider"),
    ("MODEL", "model"),
    ("API_KEY", "api_key"),
    ("API_URL", "api_url"),
    ("API_HEADERS", "api_headers"),
    ("LOCALE", "locale"),
    ("ONE_LINER", "one_liner"),
    ("COMMIT_TEMPLATE", "commit_template"),
    ("LLM_SYSTEM_PROMPT", "llm_system_prompt"),
    ("USE_GITMOJI", "use_gitmoji"),
    ("GITMOJI_FORMAT", "gitmoji_format"),
    ("REVIEW_COMMIT", "review_commit"),
    ("POST_COMMIT_PUSH", "post_commit_push"),
    ("SUPPRESS_TOOL_OUTPUT", "suppress_tool_output"),
    ("WARN_STAGED_FILES_ENABLED", "warn_staged_files_enabled"),
    ("WARN_STAGED_FILES_THRESHOLD", "warn_staged_files_threshold"),
    ("CONFIRM_NEW_VERSION", "confirm_new_version"),
];
120
121impl AppConfig {
122    /// Load config with layered resolution: defaults → global TOML → local .env → env vars
123    pub fn load() -> Result<Self> {
124        let mut cfg = Self::default();
125
126        // Layer 1: Global TOML
127        if let Some(path) = global_config_path() {
128            if path.exists() {
129                let content = std::fs::read_to_string(&path)
130                    .with_context(|| format!("Failed to read {}", path.display()))?;
131                let file_cfg: AppConfig = toml::from_str(&content)
132                    .with_context(|| format!("Failed to parse {}", path.display()))?;
133                cfg.merge_from(&file_cfg);
134            }
135        }
136
137        // Layer 2: Local .env (in git repo root)
138        if let Ok(root) = crate::git::find_repo_root() {
139            let env_path = PathBuf::from(&root).join(".env");
140            if env_path.exists() {
141                let env_map = parse_dotenv(&env_path)?;
142                cfg.apply_env_map(&env_map);
143            }
144        }
145
146        // Layer 3: Actual environment variables
147        let mut env_map = HashMap::new();
148        for (suffix, _) in ENV_FIELD_MAP {
149            let key = format!("ACR_{suffix}");
150            if let Ok(val) = std::env::var(&key) {
151                env_map.insert(key, val);
152            }
153        }
154        cfg.apply_env_map(&env_map);
155        cfg.ensure_valid_locale()?;
156
157        Ok(cfg)
158    }
159
160    fn merge_from(&mut self, other: &AppConfig) {
161        if !other.provider.is_empty() {
162            self.provider = other.provider.clone();
163        }
164        if !other.model.is_empty() {
165            self.model = other.model.clone();
166        }
167        if !other.api_key.is_empty() {
168            self.api_key = other.api_key.clone();
169        }
170        if !other.api_url.is_empty() {
171            self.api_url = other.api_url.clone();
172        }
173        if !other.api_headers.is_empty() {
174            self.api_headers = other.api_headers.clone();
175        }
176        if !other.locale.is_empty() {
177            self.locale = other.locale.clone();
178        }
179        self.one_liner = other.one_liner;
180        if !other.commit_template.is_empty() {
181            self.commit_template = other.commit_template.clone();
182        }
183        if !other.llm_system_prompt.is_empty() {
184            self.llm_system_prompt = other.llm_system_prompt.clone();
185        }
186        self.use_gitmoji = other.use_gitmoji;
187        if !other.gitmoji_format.is_empty() {
188            self.gitmoji_format = other.gitmoji_format.clone();
189        }
190        self.review_commit = other.review_commit;
191        if !other.post_commit_push.is_empty() {
192            self.post_commit_push = normalize_post_commit_push(&other.post_commit_push);
193        }
194        self.suppress_tool_output = other.suppress_tool_output;
195        self.warn_staged_files_enabled = other.warn_staged_files_enabled;
196        self.warn_staged_files_threshold = other.warn_staged_files_threshold;
197        self.confirm_new_version = other.confirm_new_version;
198    }
199
200    fn apply_env_map(&mut self, map: &HashMap<String, String>) {
201        for (suffix, _field) in ENV_FIELD_MAP {
202            let key = format!("ACR_{suffix}");
203            if let Some(val) = map.get(&key) {
204                match *suffix {
205                    "PROVIDER" => self.provider = val.clone(),
206                    "MODEL" => self.model = val.clone(),
207                    "API_KEY" => self.api_key = val.clone(),
208                    "API_URL" => self.api_url = val.clone(),
209                    "API_HEADERS" => self.api_headers = val.clone(),
210                    "LOCALE" => self.locale = val.clone(),
211                    "ONE_LINER" => self.one_liner = val == "1" || val.eq_ignore_ascii_case("true"),
212                    "COMMIT_TEMPLATE" => self.commit_template = val.clone(),
213                    "LLM_SYSTEM_PROMPT" => self.llm_system_prompt = val.clone(),
214                    "USE_GITMOJI" => {
215                        self.use_gitmoji = val == "1" || val.eq_ignore_ascii_case("true")
216                    }
217                    "GITMOJI_FORMAT" => self.gitmoji_format = val.clone(),
218                    "REVIEW_COMMIT" => {
219                        self.review_commit = val == "1" || val.eq_ignore_ascii_case("true")
220                    }
221                    "POST_COMMIT_PUSH" => self.post_commit_push = normalize_post_commit_push(val),
222                    "SUPPRESS_TOOL_OUTPUT" => {
223                        self.suppress_tool_output = val == "1" || val.eq_ignore_ascii_case("true")
224                    }
225                    "WARN_STAGED_FILES_ENABLED" => {
226                        self.warn_staged_files_enabled =
227                            val == "1" || val.eq_ignore_ascii_case("true")
228                    }
229                    "WARN_STAGED_FILES_THRESHOLD" => {
230                        self.warn_staged_files_threshold =
231                            parse_usize_or_default(val, default_warn_staged_files_threshold());
232                    }
233                    "CONFIRM_NEW_VERSION" => {
234                        self.confirm_new_version = val == "1" || val.eq_ignore_ascii_case("true")
235                    }
236                    _ => {}
237                }
238            }
239        }
240    }
241
242    /// Save to global TOML config file
243    pub fn save_global(&self) -> Result<()> {
244        let path = global_config_path().context("Could not determine global config directory")?;
245        if let Some(parent) = path.parent() {
246            std::fs::create_dir_all(parent)
247                .with_context(|| format!("Failed to create {}", parent.display()))?;
248        }
249        let content = toml::to_string_pretty(self).context("Failed to serialize config")?;
250        std::fs::write(&path, content)
251            .with_context(|| format!("Failed to write {}", path.display()))?;
252        Ok(())
253    }
254
255    /// Save to local .env file in the git repo root
256    pub fn save_local(&self) -> Result<()> {
257        let root = crate::git::find_repo_root().context("Not in a git repository")?;
258        let env_path = PathBuf::from(&root).join(".env");
259
260        let mut lines = Vec::new();
261        lines.push(format!("ACR_PROVIDER={}", self.provider));
262        lines.push(format!("ACR_MODEL={}", self.model));
263        if !self.api_key.is_empty() {
264            lines.push(format!("ACR_API_KEY={}", self.api_key));
265        }
266        if !self.api_url.is_empty() {
267            lines.push(format!("ACR_API_URL={}", self.api_url));
268        }
269        if !self.api_headers.is_empty() {
270            lines.push(format!("ACR_API_HEADERS={}", self.api_headers));
271        }
272        lines.push(format!("ACR_LOCALE={}", self.locale));
273        lines.push(format!(
274            "ACR_ONE_LINER={}",
275            if self.one_liner { "1" } else { "0" }
276        ));
277        if self.commit_template != "$msg" {
278            lines.push(format!("ACR_COMMIT_TEMPLATE={}", self.commit_template));
279        }
280        if self.llm_system_prompt != DEFAULT_SYSTEM_PROMPT {
281            lines.push(format!("ACR_LLM_SYSTEM_PROMPT={}", self.llm_system_prompt));
282        }
283        lines.push(format!(
284            "ACR_USE_GITMOJI={}",
285            if self.use_gitmoji { "1" } else { "0" }
286        ));
287        lines.push(format!("ACR_GITMOJI_FORMAT={}", self.gitmoji_format));
288        lines.push(format!(
289            "ACR_REVIEW_COMMIT={}",
290            if self.review_commit { "1" } else { "0" }
291        ));
292        lines.push(format!(
293            "ACR_POST_COMMIT_PUSH={}",
294            normalize_post_commit_push(&self.post_commit_push)
295        ));
296        lines.push(format!(
297            "ACR_SUPPRESS_TOOL_OUTPUT={}",
298            if self.suppress_tool_output { "1" } else { "0" }
299        ));
300        lines.push(format!(
301            "ACR_WARN_STAGED_FILES_ENABLED={}",
302            if self.warn_staged_files_enabled {
303                "1"
304            } else {
305                "0"
306            }
307        ));
308        lines.push(format!(
309            "ACR_WARN_STAGED_FILES_THRESHOLD={}",
310            self.warn_staged_files_threshold
311        ));
312        lines.push(format!(
313            "ACR_CONFIRM_NEW_VERSION={}",
314            if self.confirm_new_version { "1" } else { "0" }
315        ));
316
317        std::fs::write(&env_path, lines.join("\n") + "\n")
318            .with_context(|| format!("Failed to write {}", env_path.display()))?;
319        Ok(())
320    }
321
322    /// Get all fields as (display_name, env_suffix, current_value) tuples
323    pub fn fields_display(&self) -> Vec<(&'static str, &'static str, String)> {
324        vec![
325            ("Provider", "PROVIDER", self.provider.clone()),
326            ("Model", "MODEL", self.model.clone()),
327            (
328                "API Key",
329                "API_KEY",
330                if self.api_key.is_empty() {
331                    "(not set)".into()
332                } else {
333                    mask_key(&self.api_key)
334                },
335            ),
336            (
337                "API URL",
338                "API_URL",
339                if self.api_url.is_empty() {
340                    "(auto from provider)".into()
341                } else {
342                    self.api_url.clone()
343                },
344            ),
345            (
346                "API Headers",
347                "API_HEADERS",
348                if self.api_headers.is_empty() {
349                    "(auto from provider)".into()
350                } else {
351                    self.api_headers.clone()
352                },
353            ),
354            ("Locale", "LOCALE", self.locale.clone()),
355            (
356                "One-liner",
357                "ONE_LINER",
358                if self.one_liner {
359                    "1 (yes)".into()
360                } else {
361                    "0 (no)".into()
362                },
363            ),
364            (
365                "Commit Template",
366                "COMMIT_TEMPLATE",
367                self.commit_template.clone(),
368            ),
369            (
370                "System Prompt",
371                "LLM_SYSTEM_PROMPT",
372                truncate(&self.llm_system_prompt, 60),
373            ),
374            (
375                "Use Gitmoji",
376                "USE_GITMOJI",
377                if self.use_gitmoji {
378                    "1 (yes)".into()
379                } else {
380                    "0 (no)".into()
381                },
382            ),
383            (
384                "Gitmoji Format",
385                "GITMOJI_FORMAT",
386                self.gitmoji_format.clone(),
387            ),
388            (
389                "Review Commit",
390                "REVIEW_COMMIT",
391                if self.review_commit {
392                    "1 (yes)".into()
393                } else {
394                    "0 (no)".into()
395                },
396            ),
397            (
398                "Post Commit Push",
399                "POST_COMMIT_PUSH",
400                normalize_post_commit_push(&self.post_commit_push),
401            ),
402            (
403                "Suppress Tool Output",
404                "SUPPRESS_TOOL_OUTPUT",
405                if self.suppress_tool_output {
406                    "1 (yes)".into()
407                } else {
408                    "0 (no)".into()
409                },
410            ),
411            (
412                "Warn Staged Files",
413                "WARN_STAGED_FILES_ENABLED",
414                if self.warn_staged_files_enabled {
415                    "1 (yes)".into()
416                } else {
417                    "0 (no)".into()
418                },
419            ),
420            (
421                "Staged Warn Threshold",
422                "WARN_STAGED_FILES_THRESHOLD",
423                self.warn_staged_files_threshold.to_string(),
424            ),
425            (
426                "Confirm New Version",
427                "CONFIRM_NEW_VERSION",
428                if self.confirm_new_version {
429                    "1 (yes)".into()
430                } else {
431                    "0 (no)".into()
432                },
433            ),
434        ]
435    }
436
437    /// Set a field by its env suffix
438    pub fn set_field(&mut self, suffix: &str, value: &str) -> Result<()> {
439        match suffix {
440            "PROVIDER" => self.provider = value.into(),
441            "MODEL" => self.model = value.into(),
442            "API_KEY" => self.api_key = value.into(),
443            "API_URL" => self.api_url = value.into(),
444            "API_HEADERS" => self.api_headers = value.into(),
445            "LOCALE" => {
446                let locale = normalize_locale(value);
447                validate_locale(&locale)?;
448                self.locale = locale;
449            }
450            "ONE_LINER" => self.one_liner = value == "1" || value.eq_ignore_ascii_case("true"),
451            "COMMIT_TEMPLATE" => self.commit_template = value.into(),
452            "LLM_SYSTEM_PROMPT" => self.llm_system_prompt = value.into(),
453            "USE_GITMOJI" => self.use_gitmoji = value == "1" || value.eq_ignore_ascii_case("true"),
454            "GITMOJI_FORMAT" => self.gitmoji_format = value.into(),
455            "REVIEW_COMMIT" => {
456                self.review_commit = value == "1" || value.eq_ignore_ascii_case("true")
457            }
458            "POST_COMMIT_PUSH" => self.post_commit_push = normalize_post_commit_push(value),
459            "SUPPRESS_TOOL_OUTPUT" => {
460                self.suppress_tool_output = value == "1" || value.eq_ignore_ascii_case("true")
461            }
462            "WARN_STAGED_FILES_ENABLED" => {
463                self.warn_staged_files_enabled = value == "1" || value.eq_ignore_ascii_case("true");
464            }
465            "WARN_STAGED_FILES_THRESHOLD" => {
466                self.warn_staged_files_threshold =
467                    parse_usize_or_default(value, default_warn_staged_files_threshold());
468            }
469            "CONFIRM_NEW_VERSION" => {
470                self.confirm_new_version = value == "1" || value.eq_ignore_ascii_case("true");
471            }
472            _ => {}
473        }
474        Ok(())
475    }
476
477    fn ensure_valid_locale(&mut self) -> Result<()> {
478        self.locale = normalize_locale(&self.locale);
479        validate_locale(&self.locale)
480    }
481}
482
483/// Global config file path
484pub fn global_config_path() -> Option<PathBuf> {
485    if let Some(override_dir) = std::env::var_os("ACR_CONFIG_HOME") {
486        let override_path = PathBuf::from(override_dir);
487        if !override_path.as_os_str().is_empty() {
488            return Some(override_path.join("cgen").join("config.toml"));
489        }
490    }
491    dirs::config_dir().map(|d| d.join("cgen").join("config.toml"))
492}
493
/// Mask a secret for display.
///
/// Keys of 8 characters or fewer are fully replaced by `*` (one per character). Longer keys
/// show their first and last four characters around `...`. Counting and slicing are done in
/// `char`s rather than bytes, so non-ASCII keys cannot cause a slicing panic.
fn mask_key(key: &str) -> String {
    let chars: Vec<char> = key.chars().collect();
    let count = chars.len();
    if count <= 8 {
        "*".repeat(count)
    } else {
        let head: String = chars[..4].iter().collect();
        let tail: String = chars[count - 4..].iter().collect();
        format!("{head}...{tail}")
    }
}
501
/// Shorten `s` to at most `max` characters, appending `...` when anything was cut.
///
/// `max` counts `char`s, and the cut always lands on a character boundary. Slicing at byte
/// `max` panicked for non-ASCII text such as localized system prompts.
fn truncate(s: &str, max: usize) -> String {
    match s.char_indices().nth(max) {
        // Fewer than `max + 1` characters: nothing to cut.
        None => s.to_string(),
        Some((cut, _)) => format!("{}...", &s[..cut]),
    }
}
509
/// Canonicalize a post-commit push mode.
///
/// Matching is case-insensitive and ignores surrounding whitespace. `never` and `always`
/// are kept; every other input, including an empty one, becomes `ask`.
fn normalize_post_commit_push(value: &str) -> String {
    let requested = value.trim().to_ascii_lowercase();
    if requested == "never" || requested == "always" {
        requested
    } else {
        String::from("ask")
    }
}
517
/// Parse a whitespace-trimmed unsigned integer, falling back to `default` on any error.
fn parse_usize_or_default(value: &str, default: usize) -> usize {
    match value.trim().parse::<usize>() {
        Ok(parsed) => parsed,
        Err(_) => default,
    }
}
521
522fn normalize_locale(value: &str) -> String {
523    let normalized = value.trim();
524    if normalized.is_empty() {
525        default_locale()
526    } else {
527        normalized.to_ascii_lowercase()
528    }
529}
530
531fn validate_locale(locale: &str) -> Result<()> {
532    if locale == "en" || locale_has_i18n(locale) {
533        return Ok(());
534    }
535    anyhow::bail!(
536        "Unsupported locale '{}'. Only 'en' is available unless matching i18n resources exist. Set locale with `cgen config` or add i18n files first.",
537        locale
538    );
539}
540
541fn locale_has_i18n(locale: &str) -> bool {
542    locale_i18n_dirs()
543        .iter()
544        .any(|dir| locale_exists_in_i18n_dir(dir, locale))
545}
546
547fn locale_i18n_dirs() -> Vec<PathBuf> {
548    let mut dirs = Vec::new();
549    if let Ok(repo_root) = crate::git::find_repo_root() {
550        dirs.push(PathBuf::from(repo_root).join("i18n"));
551    }
552    if let Ok(current_dir) = std::env::current_dir() {
553        let i18n_dir = current_dir.join("i18n");
554        if !dirs.contains(&i18n_dir) {
555            dirs.push(i18n_dir);
556        }
557    }
558    dirs
559}
560
/// Return `true` when `i18n_dir` contains a resource for `locale`.
///
/// A match is either a sub-directory named `locale` or a regular file whose stem is
/// `locale` (e.g. `de.json` for `de`).
/// - Both comparisons ignore ASCII case. Locales are lower-cased by `normalize_locale`, but
///   resource names may not be (e.g. `pt-BR/`).
/// - A locale that is not a single plain path component (empty, `.`, `..`, or containing a
///   path separator) never matches. Previously `join(locale).is_dir()` accepted `..` and
///   similar values because the joined path was an existing directory.
/// - A missing or unreadable directory yields `false`.
fn locale_exists_in_i18n_dir(i18n_dir: &std::path::Path, locale: &str) -> bool {
    let is_plain_component = !locale.is_empty()
        && locale != "."
        && locale != ".."
        && !locale.contains(|c: char| c == '/' || c == '\\');
    if !is_plain_component {
        return false;
    }

    let entries = match std::fs::read_dir(i18n_dir) {
        Ok(entries) => entries,
        Err(_) => return false,
    };

    entries.filter_map(|entry| entry.ok()).any(|entry| {
        let path = entry.path();
        // Directories match by full name, files by stem (name without extension).
        let name = if path.is_dir() {
            path.file_name()
        } else if path.is_file() {
            path.file_stem()
        } else {
            None
        };
        name.and_then(|n| n.to_str())
            .map(|n| n.eq_ignore_ascii_case(locale))
            .unwrap_or(false)
    })
}
586
587fn parse_dotenv(path: &PathBuf) -> Result<HashMap<String, String>> {
588    let content = std::fs::read_to_string(path)
589        .with_context(|| format!("Failed to read {}", path.display()))?;
590    let mut map = HashMap::new();
591    for line in content.lines() {
592        let line = line.trim();
593        if line.is_empty() || line.starts_with('#') {
594            continue;
595        }
596        if let Some((key, val)) = line.split_once('=') {
597            let key = key.trim().to_string();
598            let val = val.trim().trim_matches('"').trim_matches('\'').to_string();
599            map.insert(key, val);
600        }
601    }
602    Ok(map)
603}