1use anyhow::{Context, Result};
2use serde::{Deserialize, Serialize};
3use std::collections::HashMap;
4use std::path::PathBuf;
5
/// Built-in system prompt used when no custom `llm_system_prompt` is configured.
///
/// Assembled from fragments purely for readability; the resulting text is a single
/// line with no embedded newlines.
const DEFAULT_SYSTEM_PROMPT: &str = concat!(
    "You are to act as an author of a commit message in git. ",
    "I'll send you an output of 'git diff --staged' command, and you are to convert ",
    "it into a commit message. Follow the Conventional Commits specification."
);
9
/// Settings for the commit-message generator.
///
/// Deserialized from the global TOML config file; any key missing from the file
/// takes the serde default named on the field. `AppConfig::load` then layers a
/// repository `.env` file and `ACR_*` environment variables on top.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AppConfig {
    /// Provider identifier; defaults to `"groq"`.
    #[serde(default = "default_provider")]
    pub provider: String,
    /// Model identifier; defaults to `"llama-3.3-70b-versatile"`.
    #[serde(default = "default_model")]
    pub model: String,
    /// API key. Empty means "not set"; `fields_display` shows it masked and
    /// `save_local` only writes it when non-empty.
    #[serde(default)]
    pub api_key: String,
    /// API URL override. Empty is displayed as "(auto from provider)" —
    /// presumably resolved from `provider` elsewhere; confirm at the call site.
    #[serde(default)]
    pub api_url: String,
    /// Extra API headers as a raw string. Empty is displayed as
    /// "(auto from provider)"; the expected format is not visible here — TODO confirm.
    #[serde(default)]
    pub api_headers: String,
    /// Locale code; defaults to `"en"`. `load` normalizes it (trimmed, lowercased)
    /// and accepts only `"en"` or a locale with matching files under an `i18n` dir.
    #[serde(default = "default_locale")]
    pub locale: String,
    /// One-liner flag; defaults to `true`.
    #[serde(default = "default_true")]
    pub one_liner: bool,
    /// Commit message template; defaults to `"$msg"` (presumably `$msg` is
    /// substituted with the generated message — confirm).
    #[serde(default = "default_commit_template")]
    pub commit_template: String,
    /// System prompt for the LLM; defaults to `DEFAULT_SYSTEM_PROMPT`.
    #[serde(default = "default_system_prompt")]
    pub llm_system_prompt: String,
    /// Gitmoji flag; defaults to `false`.
    #[serde(default)]
    pub use_gitmoji: bool,
    /// Gitmoji format; defaults to `"unicode"`.
    #[serde(default = "default_gitmoji_format")]
    pub gitmoji_format: String,
    /// Review-commit flag; defaults to `false`.
    #[serde(default)]
    pub review_commit: bool,
    /// Post-commit push mode; defaults to `"ask"`. Values merged from a config
    /// file, an env map, or `set_field` are normalized to `"ask"`, `"never"`, or `"always"`.
    #[serde(default = "default_post_commit_push")]
    pub post_commit_push: String,
    /// Suppress-tool-output flag; defaults to `false`.
    #[serde(default)]
    pub suppress_tool_output: bool,
    /// Whether the staged-files warning is enabled; defaults to `true`.
    #[serde(default = "default_true")]
    pub warn_staged_files_enabled: bool,
    /// Staged-file count threshold for the warning; defaults to `20`.
    #[serde(default = "default_warn_staged_files_threshold")]
    pub warn_staged_files_threshold: usize,
    /// Confirm-new-version flag; defaults to `true`.
    #[serde(default = "default_true")]
    pub confirm_new_version: bool,
}
47
/// Serde/`Default` value for `AppConfig::provider`.
fn default_provider() -> String {
    String::from("groq")
}

/// Serde/`Default` value for `AppConfig::model`.
fn default_model() -> String {
    String::from("llama-3.3-70b-versatile")
}

/// Serde/`Default` value for `AppConfig::locale`; `"en"` is always accepted by
/// locale validation.
fn default_locale() -> String {
    String::from("en")
}

/// Serde default for flags that are on unless explicitly disabled.
fn default_true() -> bool {
    true
}

/// Serde/`Default` value for `AppConfig::post_commit_push`.
fn default_post_commit_push() -> String {
    String::from("ask")
}

/// Serde/`Default` value for `AppConfig::commit_template` (a bare `$msg` placeholder).
fn default_commit_template() -> String {
    String::from("$msg")
}
66fn default_system_prompt() -> String {
67 DEFAULT_SYSTEM_PROMPT.into()
68}
/// Serde/`Default` value for `AppConfig::gitmoji_format`.
fn default_gitmoji_format() -> String {
    "unicode".to_owned()
}

/// Serde/`Default` value for `AppConfig::warn_staged_files_threshold`.
fn default_warn_staged_files_threshold() -> usize {
    const THRESHOLD: usize = 20;
    THRESHOLD
}
75
76impl Default for AppConfig {
77 fn default() -> Self {
78 Self {
79 provider: default_provider(),
80 model: default_model(),
81 api_key: String::new(),
82 api_url: String::new(),
83 api_headers: String::new(),
84 locale: default_locale(),
85 one_liner: true,
86 commit_template: default_commit_template(),
87 llm_system_prompt: default_system_prompt(),
88 use_gitmoji: false,
89 gitmoji_format: default_gitmoji_format(),
90 review_commit: false,
91 post_commit_push: default_post_commit_push(),
92 suppress_tool_output: false,
93 warn_staged_files_enabled: true,
94 warn_staged_files_threshold: default_warn_staged_files_threshold(),
95 confirm_new_version: true,
96 }
97 }
98}
99
/// Table of `(SUFFIX, field_name)` pairs: each `SUFFIX` names the variable
/// `ACR_<SUFFIX>` that overrides the `AppConfig` field `field_name`.
///
/// `AppConfig::load` uses the suffixes to decide which process environment
/// variables to read, and `apply_env_map` iterates them in this order.
/// NOTE(review): the `field_name` column is not read by the code in this section;
/// confirm whether anything else uses it.
const ENV_FIELD_MAP: &[(&str, &str)] = &[
    ("PROVIDER", "provider"),
    ("MODEL", "model"),
    ("API_KEY", "api_key"),
    ("API_URL", "api_url"),
    ("API_HEADERS", "api_headers"),
    ("LOCALE", "locale"),
    ("ONE_LINER", "one_liner"),
    ("COMMIT_TEMPLATE", "commit_template"),
    ("LLM_SYSTEM_PROMPT", "llm_system_prompt"),
    ("USE_GITMOJI", "use_gitmoji"),
    ("GITMOJI_FORMAT", "gitmoji_format"),
    ("REVIEW_COMMIT", "review_commit"),
    ("POST_COMMIT_PUSH", "post_commit_push"),
    ("SUPPRESS_TOOL_OUTPUT", "suppress_tool_output"),
    ("WARN_STAGED_FILES_ENABLED", "warn_staged_files_enabled"),
    ("WARN_STAGED_FILES_THRESHOLD", "warn_staged_files_threshold"),
    ("CONFIRM_NEW_VERSION", "confirm_new_version"),
];
120
121impl AppConfig {
122 pub fn load() -> Result<Self> {
124 let mut cfg = Self::default();
125
126 if let Some(path) = global_config_path() {
128 if path.exists() {
129 let content = std::fs::read_to_string(&path)
130 .with_context(|| format!("Failed to read {}", path.display()))?;
131 let file_cfg: AppConfig = toml::from_str(&content)
132 .with_context(|| format!("Failed to parse {}", path.display()))?;
133 cfg.merge_from(&file_cfg);
134 }
135 }
136
137 if let Ok(root) = crate::git::find_repo_root() {
139 let env_path = PathBuf::from(&root).join(".env");
140 if env_path.exists() {
141 let env_map = parse_dotenv(&env_path)?;
142 cfg.apply_env_map(&env_map);
143 }
144 }
145
146 let mut env_map = HashMap::new();
148 for (suffix, _) in ENV_FIELD_MAP {
149 let key = format!("ACR_{suffix}");
150 if let Ok(val) = std::env::var(&key) {
151 env_map.insert(key, val);
152 }
153 }
154 cfg.apply_env_map(&env_map);
155 cfg.ensure_valid_locale()?;
156
157 Ok(cfg)
158 }
159
160 fn merge_from(&mut self, other: &AppConfig) {
161 if !other.provider.is_empty() {
162 self.provider = other.provider.clone();
163 }
164 if !other.model.is_empty() {
165 self.model = other.model.clone();
166 }
167 if !other.api_key.is_empty() {
168 self.api_key = other.api_key.clone();
169 }
170 if !other.api_url.is_empty() {
171 self.api_url = other.api_url.clone();
172 }
173 if !other.api_headers.is_empty() {
174 self.api_headers = other.api_headers.clone();
175 }
176 if !other.locale.is_empty() {
177 self.locale = other.locale.clone();
178 }
179 self.one_liner = other.one_liner;
180 if !other.commit_template.is_empty() {
181 self.commit_template = other.commit_template.clone();
182 }
183 if !other.llm_system_prompt.is_empty() {
184 self.llm_system_prompt = other.llm_system_prompt.clone();
185 }
186 self.use_gitmoji = other.use_gitmoji;
187 if !other.gitmoji_format.is_empty() {
188 self.gitmoji_format = other.gitmoji_format.clone();
189 }
190 self.review_commit = other.review_commit;
191 if !other.post_commit_push.is_empty() {
192 self.post_commit_push = normalize_post_commit_push(&other.post_commit_push);
193 }
194 self.suppress_tool_output = other.suppress_tool_output;
195 self.warn_staged_files_enabled = other.warn_staged_files_enabled;
196 self.warn_staged_files_threshold = other.warn_staged_files_threshold;
197 self.confirm_new_version = other.confirm_new_version;
198 }
199
200 fn apply_env_map(&mut self, map: &HashMap<String, String>) {
201 for (suffix, _field) in ENV_FIELD_MAP {
202 let key = format!("ACR_{suffix}");
203 if let Some(val) = map.get(&key) {
204 match *suffix {
205 "PROVIDER" => self.provider = val.clone(),
206 "MODEL" => self.model = val.clone(),
207 "API_KEY" => self.api_key = val.clone(),
208 "API_URL" => self.api_url = val.clone(),
209 "API_HEADERS" => self.api_headers = val.clone(),
210 "LOCALE" => self.locale = val.clone(),
211 "ONE_LINER" => self.one_liner = val == "1" || val.eq_ignore_ascii_case("true"),
212 "COMMIT_TEMPLATE" => self.commit_template = val.clone(),
213 "LLM_SYSTEM_PROMPT" => self.llm_system_prompt = val.clone(),
214 "USE_GITMOJI" => {
215 self.use_gitmoji = val == "1" || val.eq_ignore_ascii_case("true")
216 }
217 "GITMOJI_FORMAT" => self.gitmoji_format = val.clone(),
218 "REVIEW_COMMIT" => {
219 self.review_commit = val == "1" || val.eq_ignore_ascii_case("true")
220 }
221 "POST_COMMIT_PUSH" => self.post_commit_push = normalize_post_commit_push(val),
222 "SUPPRESS_TOOL_OUTPUT" => {
223 self.suppress_tool_output = val == "1" || val.eq_ignore_ascii_case("true")
224 }
225 "WARN_STAGED_FILES_ENABLED" => {
226 self.warn_staged_files_enabled =
227 val == "1" || val.eq_ignore_ascii_case("true")
228 }
229 "WARN_STAGED_FILES_THRESHOLD" => {
230 self.warn_staged_files_threshold =
231 parse_usize_or_default(val, default_warn_staged_files_threshold());
232 }
233 "CONFIRM_NEW_VERSION" => {
234 self.confirm_new_version = val == "1" || val.eq_ignore_ascii_case("true")
235 }
236 _ => {}
237 }
238 }
239 }
240 }
241
242 pub fn save_global(&self) -> Result<()> {
244 let path = global_config_path().context("Could not determine global config directory")?;
245 if let Some(parent) = path.parent() {
246 std::fs::create_dir_all(parent)
247 .with_context(|| format!("Failed to create {}", parent.display()))?;
248 }
249 let content = toml::to_string_pretty(self).context("Failed to serialize config")?;
250 std::fs::write(&path, content)
251 .with_context(|| format!("Failed to write {}", path.display()))?;
252 Ok(())
253 }
254
255 pub fn save_local(&self) -> Result<()> {
257 let root = crate::git::find_repo_root().context("Not in a git repository")?;
258 let env_path = PathBuf::from(&root).join(".env");
259
260 let mut lines = Vec::new();
261 lines.push(format!("ACR_PROVIDER={}", self.provider));
262 lines.push(format!("ACR_MODEL={}", self.model));
263 if !self.api_key.is_empty() {
264 lines.push(format!("ACR_API_KEY={}", self.api_key));
265 }
266 if !self.api_url.is_empty() {
267 lines.push(format!("ACR_API_URL={}", self.api_url));
268 }
269 if !self.api_headers.is_empty() {
270 lines.push(format!("ACR_API_HEADERS={}", self.api_headers));
271 }
272 lines.push(format!("ACR_LOCALE={}", self.locale));
273 lines.push(format!(
274 "ACR_ONE_LINER={}",
275 if self.one_liner { "1" } else { "0" }
276 ));
277 if self.commit_template != "$msg" {
278 lines.push(format!("ACR_COMMIT_TEMPLATE={}", self.commit_template));
279 }
280 if self.llm_system_prompt != DEFAULT_SYSTEM_PROMPT {
281 lines.push(format!("ACR_LLM_SYSTEM_PROMPT={}", self.llm_system_prompt));
282 }
283 lines.push(format!(
284 "ACR_USE_GITMOJI={}",
285 if self.use_gitmoji { "1" } else { "0" }
286 ));
287 lines.push(format!("ACR_GITMOJI_FORMAT={}", self.gitmoji_format));
288 lines.push(format!(
289 "ACR_REVIEW_COMMIT={}",
290 if self.review_commit { "1" } else { "0" }
291 ));
292 lines.push(format!(
293 "ACR_POST_COMMIT_PUSH={}",
294 normalize_post_commit_push(&self.post_commit_push)
295 ));
296 lines.push(format!(
297 "ACR_SUPPRESS_TOOL_OUTPUT={}",
298 if self.suppress_tool_output { "1" } else { "0" }
299 ));
300 lines.push(format!(
301 "ACR_WARN_STAGED_FILES_ENABLED={}",
302 if self.warn_staged_files_enabled {
303 "1"
304 } else {
305 "0"
306 }
307 ));
308 lines.push(format!(
309 "ACR_WARN_STAGED_FILES_THRESHOLD={}",
310 self.warn_staged_files_threshold
311 ));
312 lines.push(format!(
313 "ACR_CONFIRM_NEW_VERSION={}",
314 if self.confirm_new_version { "1" } else { "0" }
315 ));
316
317 std::fs::write(&env_path, lines.join("\n") + "\n")
318 .with_context(|| format!("Failed to write {}", env_path.display()))?;
319 Ok(())
320 }
321
322 pub fn fields_display(&self) -> Vec<(&'static str, &'static str, String)> {
324 vec![
325 ("Provider", "PROVIDER", self.provider.clone()),
326 ("Model", "MODEL", self.model.clone()),
327 (
328 "API Key",
329 "API_KEY",
330 if self.api_key.is_empty() {
331 "(not set)".into()
332 } else {
333 mask_key(&self.api_key)
334 },
335 ),
336 (
337 "API URL",
338 "API_URL",
339 if self.api_url.is_empty() {
340 "(auto from provider)".into()
341 } else {
342 self.api_url.clone()
343 },
344 ),
345 (
346 "API Headers",
347 "API_HEADERS",
348 if self.api_headers.is_empty() {
349 "(auto from provider)".into()
350 } else {
351 self.api_headers.clone()
352 },
353 ),
354 ("Locale", "LOCALE", self.locale.clone()),
355 (
356 "One-liner",
357 "ONE_LINER",
358 if self.one_liner {
359 "1 (yes)".into()
360 } else {
361 "0 (no)".into()
362 },
363 ),
364 (
365 "Commit Template",
366 "COMMIT_TEMPLATE",
367 self.commit_template.clone(),
368 ),
369 (
370 "System Prompt",
371 "LLM_SYSTEM_PROMPT",
372 truncate(&self.llm_system_prompt, 60),
373 ),
374 (
375 "Use Gitmoji",
376 "USE_GITMOJI",
377 if self.use_gitmoji {
378 "1 (yes)".into()
379 } else {
380 "0 (no)".into()
381 },
382 ),
383 (
384 "Gitmoji Format",
385 "GITMOJI_FORMAT",
386 self.gitmoji_format.clone(),
387 ),
388 (
389 "Review Commit",
390 "REVIEW_COMMIT",
391 if self.review_commit {
392 "1 (yes)".into()
393 } else {
394 "0 (no)".into()
395 },
396 ),
397 (
398 "Post Commit Push",
399 "POST_COMMIT_PUSH",
400 normalize_post_commit_push(&self.post_commit_push),
401 ),
402 (
403 "Suppress Tool Output",
404 "SUPPRESS_TOOL_OUTPUT",
405 if self.suppress_tool_output {
406 "1 (yes)".into()
407 } else {
408 "0 (no)".into()
409 },
410 ),
411 (
412 "Warn Staged Files",
413 "WARN_STAGED_FILES_ENABLED",
414 if self.warn_staged_files_enabled {
415 "1 (yes)".into()
416 } else {
417 "0 (no)".into()
418 },
419 ),
420 (
421 "Staged Warn Threshold",
422 "WARN_STAGED_FILES_THRESHOLD",
423 self.warn_staged_files_threshold.to_string(),
424 ),
425 (
426 "Confirm New Version",
427 "CONFIRM_NEW_VERSION",
428 if self.confirm_new_version {
429 "1 (yes)".into()
430 } else {
431 "0 (no)".into()
432 },
433 ),
434 ]
435 }
436
437 pub fn set_field(&mut self, suffix: &str, value: &str) -> Result<()> {
439 match suffix {
440 "PROVIDER" => self.provider = value.into(),
441 "MODEL" => self.model = value.into(),
442 "API_KEY" => self.api_key = value.into(),
443 "API_URL" => self.api_url = value.into(),
444 "API_HEADERS" => self.api_headers = value.into(),
445 "LOCALE" => {
446 let locale = normalize_locale(value);
447 validate_locale(&locale)?;
448 self.locale = locale;
449 }
450 "ONE_LINER" => self.one_liner = value == "1" || value.eq_ignore_ascii_case("true"),
451 "COMMIT_TEMPLATE" => self.commit_template = value.into(),
452 "LLM_SYSTEM_PROMPT" => self.llm_system_prompt = value.into(),
453 "USE_GITMOJI" => self.use_gitmoji = value == "1" || value.eq_ignore_ascii_case("true"),
454 "GITMOJI_FORMAT" => self.gitmoji_format = value.into(),
455 "REVIEW_COMMIT" => {
456 self.review_commit = value == "1" || value.eq_ignore_ascii_case("true")
457 }
458 "POST_COMMIT_PUSH" => self.post_commit_push = normalize_post_commit_push(value),
459 "SUPPRESS_TOOL_OUTPUT" => {
460 self.suppress_tool_output = value == "1" || value.eq_ignore_ascii_case("true")
461 }
462 "WARN_STAGED_FILES_ENABLED" => {
463 self.warn_staged_files_enabled = value == "1" || value.eq_ignore_ascii_case("true");
464 }
465 "WARN_STAGED_FILES_THRESHOLD" => {
466 self.warn_staged_files_threshold =
467 parse_usize_or_default(value, default_warn_staged_files_threshold());
468 }
469 "CONFIRM_NEW_VERSION" => {
470 self.confirm_new_version = value == "1" || value.eq_ignore_ascii_case("true");
471 }
472 _ => {}
473 }
474 Ok(())
475 }
476
477 fn ensure_valid_locale(&mut self) -> Result<()> {
478 self.locale = normalize_locale(&self.locale);
479 validate_locale(&self.locale)
480 }
481}
482
483pub fn global_config_path() -> Option<PathBuf> {
485 if let Some(override_dir) = std::env::var_os("ACR_CONFIG_HOME") {
486 let override_path = PathBuf::from(override_dir);
487 if !override_path.as_os_str().is_empty() {
488 return Some(override_path.join("cgen").join("config.toml"));
489 }
490 }
491 dirs::config_dir().map(|d| d.join("cgen").join("config.toml"))
492}
493
/// Masks a secret for display.
///
/// Keys of at most 8 characters become one `*` per character. Longer keys show
/// their first and last 4 characters around `...`. Lengths and slices are
/// counted in `char`s, so non-ASCII keys cannot cause a slicing panic on a
/// UTF-8 boundary.
fn mask_key(key: &str) -> String {
    let chars: Vec<char> = key.chars().collect();
    if chars.len() <= 8 {
        "*".repeat(chars.len())
    } else {
        let head: String = chars[..4].iter().collect();
        let tail: String = chars[chars.len() - 4..].iter().collect();
        format!("{head}...{tail}")
    }
}
501
/// Shortens `s` to at most `max` characters, appending `...` when anything was
/// cut off.
///
/// Counting is done in `char`s, not bytes, so multi-byte text (e.g. a localized
/// system prompt) never triggers a panic from slicing inside a UTF-8 sequence.
fn truncate(s: &str, max: usize) -> String {
    match s.char_indices().nth(max) {
        // Fewer than `max + 1` chars: nothing to cut.
        None => s.to_string(),
        // `cut` is the byte offset of the first char past the limit.
        Some((cut, _)) => format!("{}...", &s[..cut]),
    }
}
509
/// Canonicalizes a post-commit push mode to `"never"`, `"always"`, or `"ask"`.
///
/// Matching ignores surrounding whitespace and ASCII case. Anything
/// unrecognized, including the empty string, falls back to `"ask"`.
fn normalize_post_commit_push(value: &str) -> String {
    let mode = value.trim().to_ascii_lowercase();
    if matches!(mode.as_str(), "never" | "always") {
        mode
    } else {
        String::from("ask")
    }
}
517
/// Parses `value` as a `usize`, ignoring surrounding whitespace. Returns
/// `default` when the text is not a valid non-negative integer.
fn parse_usize_or_default(value: &str, default: usize) -> usize {
    match value.trim().parse::<usize>() {
        Ok(parsed) => parsed,
        Err(_) => default,
    }
}
521
522fn normalize_locale(value: &str) -> String {
523 let normalized = value.trim();
524 if normalized.is_empty() {
525 default_locale()
526 } else {
527 normalized.to_ascii_lowercase()
528 }
529}
530
531fn validate_locale(locale: &str) -> Result<()> {
532 if locale == "en" || locale_has_i18n(locale) {
533 return Ok(());
534 }
535 anyhow::bail!(
536 "Unsupported locale '{}'. Only 'en' is available unless matching i18n resources exist. Set locale with `cgen config` or add i18n files first.",
537 locale
538 );
539}
540
541fn locale_has_i18n(locale: &str) -> bool {
542 locale_i18n_dirs()
543 .iter()
544 .any(|dir| locale_exists_in_i18n_dir(dir, locale))
545}
546
547fn locale_i18n_dirs() -> Vec<PathBuf> {
548 let mut dirs = Vec::new();
549 if let Ok(repo_root) = crate::git::find_repo_root() {
550 dirs.push(PathBuf::from(repo_root).join("i18n"));
551 }
552 if let Ok(current_dir) = std::env::current_dir() {
553 let i18n_dir = current_dir.join("i18n");
554 if !dirs.contains(&i18n_dir) {
555 dirs.push(i18n_dir);
556 }
557 }
558 dirs
559}
560
/// Returns `true` if `i18n_dir` contains resources for `locale`: either a
/// sub-directory named `locale` or a file whose stem is `locale`
/// (e.g. `de.json`).
///
/// Both directory names and file stems are compared ASCII case-insensitively.
/// Callers pass locales that `normalize_locale` has lowercased, so a directory
/// such as `i18n/pt-BR/` must still match `pt-br` on case-sensitive filesystems.
/// Unreadable directories or entries are treated as "not found".
fn locale_exists_in_i18n_dir(i18n_dir: &std::path::Path, locale: &str) -> bool {
    if !i18n_dir.exists() {
        return false;
    }
    // Fast path: exact-case directory match without listing the directory.
    if i18n_dir.join(locale).is_dir() {
        return true;
    }

    let entries = match std::fs::read_dir(i18n_dir) {
        Ok(entries) => entries,
        Err(_) => return false,
    };

    entries.filter_map(|entry| entry.ok()).any(|entry| {
        let path = entry.path();
        let name = if path.is_dir() {
            path.file_name()
        } else if path.is_file() {
            path.file_stem()
        } else {
            None
        };
        name.and_then(|name| name.to_str())
            .map(|name| name.eq_ignore_ascii_case(locale))
            .unwrap_or(false)
    })
}
586
587fn parse_dotenv(path: &PathBuf) -> Result<HashMap<String, String>> {
588 let content = std::fs::read_to_string(path)
589 .with_context(|| format!("Failed to read {}", path.display()))?;
590 let mut map = HashMap::new();
591 for line in content.lines() {
592 let line = line.trim();
593 if line.is_empty() || line.starts_with('#') {
594 continue;
595 }
596 if let Some((key, val)) = line.split_once('=') {
597 let key = key.trim().to_string();
598 let val = val.trim().trim_matches('"').trim_matches('\'').to_string();
599 map.insert(key, val);
600 }
601 }
602 Ok(map)
603}