Skip to main content

git_cli/
config.rs

1use serde::{Deserialize, Serialize};
2use std::collections::HashMap;
3use std::fs;
4use std::path::PathBuf;
5use dirs;
6
/// Fallback for `Config::model_fast` when the config file omits it.
const DEFAULT_MODEL_FAST: &str = "qwen2.5:3b";

/// Fallback for `Config::model_smart` when the config file omits it.
///
/// Currently the same model as [`DEFAULT_MODEL_FAST`].
const DEFAULT_MODEL_SMART: &str = "qwen2.5:3b";

/// Fallback for `Config::endpoint` when the config file omits it.
const DEFAULT_ENDPOINT: &str = "http://localhost:11434";

/// Fallback for `Config::keep_alive` when the config file omits it.
const DEFAULT_KEEP_ALIVE: &str = "10m";
11
12#[derive(Debug, Serialize, Deserialize, Clone)]
13pub struct Config {
14    #[serde(default = "default_model_fast")]
15    pub model_fast: String,
16    #[serde(default = "default_model_smart")]
17    pub model_smart: String,
18    #[serde(default)]
19    pub model: Option<String>,
20    #[serde(default = "default_endpoint")]
21    pub endpoint: String,
22    #[serde(default = "default_keep_alive")]
23    pub keep_alive: String,
24    #[serde(default)]
25    pub aliases: HashMap<String, String>,
26}
27
28fn default_model_fast() -> String {
29    DEFAULT_MODEL_FAST.to_string()
30}
31
32fn default_model_smart() -> String {
33    DEFAULT_MODEL_SMART.to_string()
34}
35
36fn default_endpoint() -> String {
37    DEFAULT_ENDPOINT.to_string()
38}
39
40fn default_keep_alive() -> String {
41    DEFAULT_KEEP_ALIVE.to_string()
42}
43
44impl Default for Config {
45    fn default() -> Self {
46        Self {
47            model_fast: default_model_fast(),
48            model_smart: default_model_smart(),
49            model: None,
50            endpoint: default_endpoint(),
51            keep_alive: default_keep_alive(),
52            aliases: HashMap::new(),
53        }
54    }
55}
56
/// Lowercase substrings that mark a task as "complex".
///
/// `is_complex_task` lowercases the task and tests each entry with
/// `str::contains`, so entries match anywhere in the text, including inside
/// longer words. Order is irrelevant to the result.
const COMPLEX_KEYWORDS: &[&str] = &[
    "rewrite",
    "rebase",
    "squash",
    "cherry-pick",
    "cherry pick",
    "bisect",
    "filter",
    "reflog",
    "submodule",
    "subtree",
    "worktree",
    "every commit",
    "all commits",
    "multiple commits",
    "rename commit",
    "reword",
    "interactive",
    "conflict",
    "resolve",
    "hook",
    "migrate",
    "convert",
    "split",
    "reorganize",
    "restructure",
    "history",
    "rewrite history",
    "how many",
    "how much",
    "who are",
    "who has",
    "which branches",
    "pending",
    "review",
    "pull request",
    // The trailing space is part of the keyword.
    "pr ",
    "compare",
    "between",
    "since",
    "contributors",
    "committers",
    "analyze",
    "statistics",
    "stats",
    "summary",
    "multiple branches",
    "all branches",
    "merge all",
];
71
/// Lowercase words and phrases that mark a task as pull-request related.
///
/// Unlike a raw substring list, these are matched by `is_pr_task` as whole
/// words only, so `"pr"` matches "open a pr" but not "show previous commits".
const PR_KEYWORDS: &[&str] = &["pull request", "pull requests", "pr", "prs"];

/// Returns `true` when `task` talks about a pull request.
///
/// Matching is case-insensitive and word-based: the task is split on every
/// non-alphanumeric character, so punctuation and hyphens act as word
/// boundaries ("merge the PR.", "pull-request"). A task matches when it
/// contains any [`PR_KEYWORDS`] entry as a whole word or phrase, or a word of
/// the form `pr<digits>` such as "pr42".
///
/// Word-based matching avoids classifying tasks like "list projects" or
/// "show previous commits" as PR tasks, and recognises "pr" at the very
/// start or end of the task.
pub fn is_pr_task(task: &str) -> bool {
    let lower = task.to_lowercase();
    let words: Vec<&str> = lower
        .split(|c: char| !c.is_alphanumeric())
        .filter(|w| !w.is_empty())
        .collect();

    // "pr" immediately followed by a number, e.g. "merge pr12".
    let numbered_pr = words.iter().any(|w| match w.strip_prefix("pr") {
        Some(rest) => !rest.is_empty() && rest.chars().all(|c| c.is_ascii_digit()),
        None => false,
    });
    if numbered_pr {
        return true;
    }

    // Pad both the text and each keyword with spaces so a keyword can only
    // match on word boundaries, including at the start or end of the task.
    let padded = format!(" {} ", words.join(" "));
    PR_KEYWORDS
        .iter()
        .any(|k| padded.contains(format!(" {k} ").as_str()))
}
92
93pub fn is_complex_task(task: &str) -> bool {
94    let lower = task.to_lowercase();
95    is_pr_task(&lower) || COMPLEX_KEYWORDS.iter().any(|k| lower.contains(k))
96}
97
98impl Config {
99    pub fn config_path() -> Option<PathBuf> {
100        dirs::home_dir().map(|h| h.join(".git-cli.toml"))
101    }
102
103    pub fn load() -> Self {
104        let Some(path) = Self::config_path() else {
105            return Self::default();
106        };
107
108        match fs::read_to_string(&path) {
109            Ok(contents) => toml::from_str(&contents).unwrap_or_default(),
110            Err(_) => Self::default(),
111        }
112    }
113
114    pub fn save(&self) -> Result<(), String> {
115        let path = Self::config_path().ok_or("Could not determine home directory")?;
116        let contents =
117            toml::to_string_pretty(self).map_err(|e| format!("Failed to serialize config: {e}"))?;
118        fs::write(&path, contents).map_err(|e| format!("Failed to write {}: {e}", path.display()))
119    }
120
121    pub fn apply_overrides(mut self, model: Option<String>, endpoint: Option<String>) -> Self {
122        if let Some(m) = model {
123            self.model = Some(m);
124        }
125        if let Some(e) = endpoint {
126            self.endpoint = e;
127        }
128        self
129    }
130
131    pub fn select_model(&self, task: &str) -> String {
132        if let Some(ref m) = self.model {
133            return m.clone();
134        }
135        if is_complex_task(task) {
136            self.model_smart.clone()
137        } else {
138            self.model_fast.clone()
139        }
140    }
141
142    pub fn resolve_alias(&self, input: &str) -> String {
143        self.aliases
144            .get(input)
145            .cloned()
146            .unwrap_or_else(|| input.to_string())
147    }
148}
149
150#[derive(Debug, Serialize, Deserialize, Clone)]
151pub struct PromptExample {
152    pub task: String,
153    pub commands: String,
154}
155
156#[derive(Debug, Serialize, Deserialize, Clone)]
157pub struct PromptConfig {
158    #[serde(default)]
159    pub preamble: Option<String>,
160    #[serde(default)]
161    pub examples: Vec<PromptExample>,
162}
163
164impl Default for PromptConfig {
165    fn default() -> Self {
166        Self {
167            preamble: None,
168            examples: Vec::new(),
169        }
170    }
171}
172
173impl PromptConfig {
174    pub fn config_dir() -> Option<PathBuf> {
175        dirs::home_dir().map(|h| h.join(".config").join("git-cli"))
176    }
177
178    pub fn config_path() -> Option<PathBuf> {
179        Self::config_dir().map(|d| d.join("prompt.toml"))
180    }
181
182    pub fn load() -> Self {
183        let Some(path) = Self::config_path() else {
184            return Self::default();
185        };
186
187        match fs::read_to_string(&path) {
188            Ok(contents) => toml::from_str(&contents).unwrap_or_default(),
189            Err(_) => Self::default(),
190        }
191    }
192}