Skip to main content

git_cli/
config.rs

1use serde::{Deserialize, Serialize};
2use std::collections::HashMap;
3use std::fs;
4use std::path::PathBuf;
5
/// Model name used for `Config::model_fast` when the config file does not set one.
const DEFAULT_MODEL_FAST: &str = "qwen2.5:3b";

/// Model name used for `Config::model_smart` when the config file does not set one.
/// Currently the same value as `DEFAULT_MODEL_FAST`.
const DEFAULT_MODEL_SMART: &str = "qwen2.5:3b";

/// Endpoint URL used for `Config::endpoint` when the config file does not set one.
const DEFAULT_ENDPOINT: &str = "http://localhost:11434";

/// Keep-alive value used for `Config::keep_alive` when the config file does not set one.
const DEFAULT_KEEP_ALIVE: &str = "10m";
10
11#[derive(Debug, Serialize, Deserialize, Clone)]
12pub struct Config {
13    #[serde(default = "default_model_fast")]
14    pub model_fast: String,
15    #[serde(default = "default_model_smart")]
16    pub model_smart: String,
17    #[serde(default)]
18    pub model: Option<String>,
19    #[serde(default = "default_endpoint")]
20    pub endpoint: String,
21    #[serde(default = "default_keep_alive")]
22    pub keep_alive: String,
23    #[serde(default)]
24    pub aliases: HashMap<String, String>,
25}
26
27fn default_model_fast() -> String {
28    DEFAULT_MODEL_FAST.to_string()
29}
30
31fn default_model_smart() -> String {
32    DEFAULT_MODEL_SMART.to_string()
33}
34
35fn default_endpoint() -> String {
36    DEFAULT_ENDPOINT.to_string()
37}
38
39fn default_keep_alive() -> String {
40    DEFAULT_KEEP_ALIVE.to_string()
41}
42
43impl Default for Config {
44    fn default() -> Self {
45        Self {
46            model_fast: default_model_fast(),
47            model_smart: default_model_smart(),
48            model: None,
49            endpoint: default_endpoint(),
50            keep_alive: default_keep_alive(),
51            aliases: HashMap::new(),
52        }
53    }
54}
55
/// Phrases that mark a task as complex when they occur anywhere in the
/// lowercased task text.
///
/// These are matched as plain substrings, so "squash" also matches "squashed".
const COMPLEX_KEYWORDS: &[&str] = &[
    "rewrite", "rebase", "squash", "cherry-pick", "cherry pick",
    "bisect", "filter", "reflog", "submodule", "subtree",
    "worktree", "every commit", "all commits", "multiple commits",
    "rename commit", "reword", "interactive",
    "conflict", "resolve", "hook", "migrate",
    "convert", "split", "reorganize", "restructure",
    "history", "rewrite history",
    "how many", "how much", "who are", "who has", "which branches",
    "pending", "review", "pull request",
    "compare", "between", "since", "contributors", "committers",
    "analyze", "statistics", "stats", "summary",
    "multiple branches", "all branches", "merge all",
];

/// Short terms that mark a task as complex only when they appear as a whole
/// word.
///
/// "pr" is too short for substring matching (it would hit "print", "prune"),
/// and matching it as `"pr "` missed a trailing "pr" or one followed by
/// punctuation, so it is compared against whole words instead.
const COMPLEX_WORDS: &[&str] = &["pr"];

/// Returns `true` when `task` looks like it needs the "smart" model.
///
/// The check is case-insensitive. A task is complex when its lowercased text
/// contains any entry of `COMPLEX_KEYWORDS` as a substring, or when any of its
/// words (maximal runs of alphanumeric characters) equals an entry of
/// `COMPLEX_WORDS`.
pub fn is_complex_task(task: &str) -> bool {
    let lower = task.to_lowercase();

    if COMPLEX_KEYWORDS.iter().any(|&keyword| lower.contains(keyword)) {
        return true;
    }

    // Split on anything that is not a letter or digit so "pr?", "pr#12" and a
    // trailing "pr" all yield the word "pr".
    lower
        .split(|c: char| !c.is_alphanumeric())
        .any(|word| COMPLEX_WORDS.iter().any(|&w| w == word))
}
75
76impl Config {
77    pub fn config_path() -> Option<PathBuf> {
78        dirs::home_dir().map(|h| h.join(".git-cli.toml"))
79    }
80
81    pub fn load() -> Self {
82        let Some(path) = Self::config_path() else {
83            return Self::default();
84        };
85
86        match fs::read_to_string(&path) {
87            Ok(contents) => toml::from_str(&contents).unwrap_or_default(),
88            Err(_) => Self::default(),
89        }
90    }
91
92    pub fn save(&self) -> Result<(), String> {
93        let path = Self::config_path().ok_or("Could not determine home directory")?;
94        let contents =
95            toml::to_string_pretty(self).map_err(|e| format!("Failed to serialize config: {e}"))?;
96        fs::write(&path, contents).map_err(|e| format!("Failed to write {}: {e}", path.display()))
97    }
98
99    pub fn apply_overrides(mut self, model: Option<String>, endpoint: Option<String>) -> Self {
100        if let Some(m) = model {
101            self.model = Some(m);
102        }
103        if let Some(e) = endpoint {
104            self.endpoint = e;
105        }
106        self
107    }
108
109    pub fn select_model(&self, task: &str) -> String {
110        if let Some(ref m) = self.model {
111            return m.clone();
112        }
113        if is_complex_task(task) {
114            self.model_smart.clone()
115        } else {
116            self.model_fast.clone()
117        }
118    }
119
120    pub fn resolve_alias(&self, input: &str) -> String {
121        self.aliases
122            .get(input)
123            .cloned()
124            .unwrap_or_else(|| input.to_string())
125    }
126}