1use crate::git::GitRepo;
7use crate::instruction_presets::get_instruction_preset_library;
8use crate::log_debug;
9use crate::providers::{Provider, ProviderConfig};
10
11use anyhow::{Context, Result, anyhow};
12use dirs::{config_dir, home_dir};
13use serde::{Deserialize, Serialize};
14use std::collections::HashMap;
15use std::fs;
16use std::path::{Path, PathBuf};
17
/// File name of the per-repository configuration, resolved against the Git repository root.
pub const PROJECT_CONFIG_FILENAME: &str = ".irisconfig";
20
21#[derive(Deserialize, Serialize, Clone, Debug)]
23pub struct Config {
24 #[serde(default, skip_serializing_if = "String::is_empty")]
26 pub default_provider: String,
27 #[serde(default, skip_serializing_if = "HashMap::is_empty")]
29 pub providers: HashMap<String, ProviderConfig>,
30 #[serde(default = "default_true", skip_serializing_if = "is_true")]
32 pub use_gitmoji: bool,
33 #[serde(default, skip_serializing_if = "String::is_empty")]
35 pub instructions: String,
36 #[serde(default = "default_preset", skip_serializing_if = "is_default_preset")]
38 pub instruction_preset: String,
39 #[serde(default, skip_serializing_if = "String::is_empty")]
41 pub theme: String,
42 #[serde(
44 default = "default_subagent_timeout",
45 skip_serializing_if = "is_default_subagent_timeout"
46 )]
47 pub subagent_timeout_secs: u64,
48 #[serde(
50 default = "default_subagent_max_turns",
51 skip_serializing_if = "is_default_subagent_max_turns"
52 )]
53 pub subagent_max_turns: usize,
54 #[serde(default = "default_true", skip_serializing_if = "is_true")]
56 pub critic_enabled: bool,
57 #[serde(skip)]
59 pub temp_instructions: Option<String>,
60 #[serde(skip)]
62 pub temp_preset: Option<String>,
63 #[serde(skip)]
65 pub is_project_config: bool,
66 #[serde(skip)]
68 pub gitmoji_override: Option<bool>,
69}
70
/// Serde default helper for boolean settings that are enabled unless configured otherwise.
fn default_true() -> bool { true }
74
/// Serde `skip_serializing_if` predicate: true when the flag is still at its default (`true`).
// serde calls `skip_serializing_if` predicates with a reference to the field, so the
// signature must be `fn(&bool) -> bool` even though clippy would prefer a by-value `bool`.
#[allow(clippy::trivially_copy_pass_by_ref)]
fn is_true(val: &bool) -> bool {
    *val
}
79
/// Name of the instruction preset used when none is configured.
fn default_preset() -> String {
    String::from("default")
}
83
/// Serde `skip_serializing_if` predicate: an empty or `"default"` preset need not be written.
fn is_default_preset(val: &str) -> bool {
    matches!(val, "" | "default")
}
87
/// Default subagent timeout, in seconds.
fn default_subagent_timeout() -> u64 {
    120
}
91
92#[allow(clippy::trivially_copy_pass_by_ref)]
93fn is_default_subagent_timeout(val: &u64) -> bool {
94 *val == 120
95}
96
/// Default maximum number of subagent turns.
fn default_subagent_max_turns() -> usize { 20 }
100
101#[allow(clippy::trivially_copy_pass_by_ref)]
102fn is_default_subagent_max_turns(val: &usize) -> bool {
103 *val == 20
104}
105
106impl Default for Config {
107 fn default() -> Self {
108 let mut providers = HashMap::new();
109 for provider in Provider::ALL {
110 providers.insert(
111 provider.name().to_string(),
112 ProviderConfig::with_defaults(*provider),
113 );
114 }
115
116 Self {
117 default_provider: Provider::default().name().to_string(),
118 providers,
119 use_gitmoji: true,
120 instructions: String::new(),
121 instruction_preset: default_preset(),
122 theme: String::new(),
123 subagent_timeout_secs: default_subagent_timeout(),
124 subagent_max_turns: default_subagent_max_turns(),
125 critic_enabled: true,
126 temp_instructions: None,
127 temp_preset: None,
128 is_project_config: false,
129 gitmoji_override: None,
130 }
131 }
132}
133
134impl Config {
135 pub fn load() -> Result<Self> {
141 let config_path = Self::get_personal_config_path()?;
142 let mut config = if config_path.exists() {
143 let content = fs::read_to_string(&config_path)?;
144 let parsed: Self = toml::from_str(&content)?;
145 let (migrated, needs_save) = Self::migrate_if_needed(parsed);
146 if needs_save && let Err(e) = migrated.save() {
147 log_debug!("Failed to save migrated config: {}", e);
148 }
149 migrated
150 } else {
151 Self::default()
152 };
153
154 if let Ok((project_config, project_source)) = Self::load_project_config_with_source() {
156 config.merge_loaded_project_config(project_config, &project_source);
157 }
158
159 log_debug!(
160 "Configuration loaded (provider: {}, gitmoji: {})",
161 config.default_provider,
162 config.use_gitmoji
163 );
164 Ok(config)
165 }
166
167 pub fn load_project_config() -> Result<Self> {
173 let (config, _) = Self::load_project_config_with_source()?;
174 Ok(config)
175 }
176
177 fn load_project_config_with_source() -> Result<(Self, toml::Value)> {
178 let config_path = Self::get_project_config_path()?;
179 if !config_path.exists() {
180 return Err(anyhow!("Project configuration file not found"));
181 }
182
183 let content = fs::read_to_string(&config_path)
184 .with_context(|| format!("Failed to read {}", config_path.display()))?;
185 let project_source = toml::from_str(&content).with_context(|| {
186 format!(
187 "Invalid {} format. Check for syntax errors.",
188 PROJECT_CONFIG_FILENAME
189 )
190 })?;
191
192 let mut config: Self = toml::from_str(&content).with_context(|| {
193 format!(
194 "Invalid {} format. Check for syntax errors.",
195 PROJECT_CONFIG_FILENAME
196 )
197 })?;
198
199 config.is_project_config = true;
200 Ok((config, project_source))
201 }
202
203 pub fn get_project_config_path() -> Result<PathBuf> {
209 let repo_root = GitRepo::get_repo_root()?;
210 Ok(repo_root.join(PROJECT_CONFIG_FILENAME))
211 }
212
213 pub fn merge_with_project_config(&mut self, project_config: Self) {
215 log_debug!("Merging with project configuration");
216
217 if !project_config.default_provider.is_empty()
219 && project_config.default_provider != Provider::default().name()
220 {
221 self.default_provider = project_config.default_provider;
222 }
223
224 for (provider_name, proj_config) in project_config.providers {
226 let entry = self.providers.entry(provider_name).or_default();
227
228 if !proj_config.model.is_empty() {
229 entry.model = proj_config.model;
230 }
231 if proj_config.fast_model.is_some() {
232 entry.fast_model = proj_config.fast_model;
233 }
234 if proj_config.token_limit.is_some() {
235 entry.token_limit = proj_config.token_limit;
236 }
237 entry
238 .additional_params
239 .extend(proj_config.additional_params);
240 }
241
242 self.use_gitmoji = project_config.use_gitmoji;
244 self.instructions = project_config.instructions;
245
246 if project_config.instruction_preset != default_preset() {
247 self.instruction_preset = project_config.instruction_preset;
248 }
249
250 if !project_config.theme.is_empty() {
252 self.theme = project_config.theme;
253 }
254
255 if project_config.subagent_timeout_secs != default_subagent_timeout() {
257 self.subagent_timeout_secs = project_config.subagent_timeout_secs;
258 }
259 if project_config.subagent_max_turns != default_subagent_max_turns() {
260 self.subagent_max_turns = project_config.subagent_max_turns;
261 }
262 }
263
264 fn merge_loaded_project_config(&mut self, project_config: Self, project_source: &toml::Value) {
265 log_debug!("Merging loaded project configuration with explicit field tracking");
266
267 self.merge_project_provider_config(&project_config);
268
269 if Self::project_config_has_key(project_source, "default_provider") {
270 self.default_provider = project_config.default_provider;
271 }
272 if Self::project_config_has_key(project_source, "use_gitmoji") {
273 self.use_gitmoji = project_config.use_gitmoji;
274 }
275 if Self::project_config_has_key(project_source, "instructions") {
276 self.instructions = project_config.instructions;
277 }
278 if Self::project_config_has_key(project_source, "instruction_preset") {
279 self.instruction_preset = project_config.instruction_preset;
280 }
281 if Self::project_config_has_key(project_source, "theme") {
282 self.theme = project_config.theme;
283 }
284 if Self::project_config_has_key(project_source, "subagent_timeout_secs") {
285 self.subagent_timeout_secs = project_config.subagent_timeout_secs;
286 }
287 if Self::project_config_has_key(project_source, "subagent_max_turns") {
288 self.subagent_max_turns = project_config.subagent_max_turns;
289 }
290 if Self::project_config_has_key(project_source, "critic_enabled") {
291 self.critic_enabled = project_config.critic_enabled;
292 }
293 }
294
295 fn merge_project_provider_config(&mut self, project_config: &Self) {
296 for (provider_name, proj_config) in &project_config.providers {
297 let entry = self.providers.entry(provider_name.clone()).or_default();
298
299 if !proj_config.model.is_empty() {
300 proj_config.model.clone_into(&mut entry.model);
301 }
302 if proj_config.fast_model.is_some() {
303 entry.fast_model.clone_from(&proj_config.fast_model);
304 }
305 if proj_config.token_limit.is_some() {
306 entry.token_limit = proj_config.token_limit;
307 }
308 entry
309 .additional_params
310 .extend(proj_config.additional_params.clone());
311 }
312 }
313
314 fn project_config_has_key(project_source: &toml::Value, key: &str) -> bool {
315 project_source
316 .as_table()
317 .is_some_and(|table| table.contains_key(key))
318 }
319
320 fn migrate_if_needed(mut config: Self) -> (Self, bool) {
328 let mut migrated = false;
329
330 for (legacy, canonical) in [("claude", "anthropic"), ("gemini", "google")] {
331 if let Some(legacy_config) = config.providers.remove(legacy) {
332 log_debug!("Migrating '{legacy}' provider to '{canonical}'");
333
334 if config.providers.contains_key(canonical) {
335 log_debug!(
336 "Keeping existing '{canonical}' config and dropping legacy '{legacy}' entry"
337 );
338 } else {
339 config
340 .providers
341 .insert(canonical.to_string(), legacy_config);
342 }
343
344 migrated = true;
345 }
346
347 if config.default_provider.eq_ignore_ascii_case(legacy) {
348 config.default_provider = canonical.to_string();
349 migrated = true;
350 }
351 }
352
353 (config, migrated)
354 }
355
356 pub fn save(&self) -> Result<()> {
362 if self.is_project_config {
363 return Ok(());
364 }
365
366 let config_path = Self::get_personal_config_path()?;
367 let content = toml::to_string_pretty(self)?;
368 Self::write_config_file(&config_path, &content)?;
369 log_debug!("Configuration saved");
370 Ok(())
371 }
372
373 pub fn save_as_project_config(&self) -> Result<()> {
379 let config_path = Self::get_project_config_path()?;
380
381 let mut project_config = self.clone();
382 project_config.is_project_config = true;
383
384 for provider_config in project_config.providers.values_mut() {
386 provider_config.api_key.clear();
387 }
388
389 let content = toml::to_string_pretty(&project_config)?;
390 Self::write_config_file(&config_path, &content)?;
391 Ok(())
392 }
393
394 fn write_config_file(path: &Path, content: &str) -> Result<()> {
400 #[cfg(unix)]
401 {
402 use std::os::unix::fs::PermissionsExt;
403
404 let tmp_path = path.with_extension("tmp");
406 fs::write(&tmp_path, content)?;
407 if let Err(e) = fs::set_permissions(&tmp_path, fs::Permissions::from_mode(0o600)) {
408 eprintln!(
409 "Warning: Could not restrict config permissions on {}: {e}",
410 tmp_path.display()
411 );
412 }
413 fs::rename(&tmp_path, path)?;
414 }
415
416 #[cfg(not(unix))]
417 {
418 fs::write(path, content)?;
419 }
420
421 Ok(())
422 }
423
424 fn resolve_personal_config_dir(
440 xdg_config_home: Option<PathBuf>,
441 home_dir: Option<PathBuf>,
442 platform_config_dir: Option<PathBuf>,
443 legacy_macos_config_exists: bool,
444 ) -> Result<PathBuf> {
445 if let Some(xdg) = xdg_config_home.filter(|path| !path.as_os_str().is_empty()) {
446 return Ok(xdg.join("git-iris"));
447 }
448
449 if legacy_macos_config_exists && let Some(platform) = platform_config_dir.clone() {
450 return Ok(platform.join("git-iris"));
451 }
452
453 if let Some(home) = home_dir {
454 return Ok(home.join(".config").join("git-iris"));
455 }
456
457 platform_config_dir
458 .map(|p| p.join("git-iris"))
459 .ok_or_else(|| anyhow!("Unable to determine config directory"))
460 }
461
462 pub fn get_personal_config_path() -> Result<PathBuf> {
468 let platform_dir = config_dir();
469
470 let legacy_macos_config_exists = cfg!(target_os = "macos")
475 && platform_dir
476 .as_ref()
477 .is_some_and(|dir| dir.join("git-iris").join("config.toml").exists());
478
479 let mut path = Self::resolve_personal_config_dir(
480 std::env::var_os("XDG_CONFIG_HOME").map(PathBuf::from),
481 home_dir(),
482 platform_dir,
483 legacy_macos_config_exists,
484 )?;
485 fs::create_dir_all(&path)?;
486 path.push("config.toml");
487 Ok(path)
488 }
489
490 pub fn check_environment(&self) -> Result<()> {
496 if !GitRepo::is_inside_work_tree()? {
497 return Err(anyhow!(
498 "Not in a Git repository. Please run this command from within a Git repository."
499 ));
500 }
501 Ok(())
502 }
503
504 pub fn set_temp_instructions(&mut self, instructions: Option<String>) {
506 self.temp_instructions = instructions;
507 }
508
509 pub fn set_temp_preset(&mut self, preset: Option<String>) {
511 self.temp_preset = preset;
512 }
513
514 #[must_use]
516 pub fn get_effective_preset_name(&self) -> &str {
517 self.temp_preset
518 .as_deref()
519 .unwrap_or(&self.instruction_preset)
520 }
521
522 #[must_use]
524 pub fn get_effective_instructions(&self) -> String {
525 let preset_library = get_instruction_preset_library();
526 let preset_instructions = self
527 .temp_preset
528 .as_ref()
529 .or(Some(&self.instruction_preset))
530 .and_then(|p| preset_library.get_preset(p))
531 .map(|p| p.instructions.clone())
532 .unwrap_or_default();
533
534 let custom = self
535 .temp_instructions
536 .as_ref()
537 .unwrap_or(&self.instructions);
538
539 format!("{preset_instructions}\n\n{custom}")
540 .trim()
541 .to_string()
542 }
543
544 #[allow(clippy::too_many_arguments, clippy::needless_pass_by_value)]
546 pub fn update(
551 &mut self,
552 provider: Option<String>,
553 api_key: Option<String>,
554 model: Option<String>,
555 fast_model: Option<String>,
556 additional_params: Option<HashMap<String, String>>,
557 use_gitmoji: Option<bool>,
558 instructions: Option<String>,
559 token_limit: Option<usize>,
560 ) -> Result<()> {
561 if let Some(ref provider_name) = provider {
562 let parsed: Provider = provider_name.parse().with_context(|| {
564 format!(
565 "Unknown provider '{}'. Supported: {}",
566 provider_name,
567 Provider::all_names().join(", ")
568 )
569 })?;
570
571 self.default_provider = parsed.name().to_string();
572
573 if !self.providers.contains_key(parsed.name()) {
575 self.providers.insert(
576 parsed.name().to_string(),
577 ProviderConfig::with_defaults(parsed),
578 );
579 }
580 }
581
582 let provider_config = self
583 .providers
584 .get_mut(&self.default_provider)
585 .context("Could not get default provider config")?;
586
587 if let Some(key) = api_key {
588 provider_config.api_key = key;
589 }
590 if let Some(m) = model {
591 provider_config.model = m;
592 }
593 if let Some(fm) = fast_model {
594 provider_config.fast_model = Some(fm);
595 }
596 if let Some(params) = additional_params {
597 provider_config.additional_params.extend(params);
598 }
599 if let Some(gitmoji) = use_gitmoji {
600 self.use_gitmoji = gitmoji;
601 }
602 if let Some(instr) = instructions {
603 self.instructions = instr;
604 }
605 if let Some(limit) = token_limit {
606 provider_config.token_limit = Some(limit);
607 }
608
609 log_debug!("Configuration updated");
610 Ok(())
611 }
612
613 #[must_use]
615 pub fn get_provider_config(&self, provider: &str) -> Option<&ProviderConfig> {
616 let name = if provider.eq_ignore_ascii_case("claude") {
618 "anthropic"
619 } else if provider.eq_ignore_ascii_case("gemini") {
620 "google"
621 } else {
622 provider
623 };
624
625 self.providers
626 .get(name)
627 .or_else(|| self.providers.get(&name.to_lowercase()))
628 }
629
630 #[must_use]
632 pub fn provider(&self) -> Option<Provider> {
633 self.default_provider.parse().ok()
634 }
635
636 pub fn validate(&self) -> Result<()> {
642 let provider: Provider = self
643 .default_provider
644 .parse()
645 .with_context(|| format!("Invalid provider: {}", self.default_provider))?;
646
647 let config = self
648 .get_provider_config(provider.name())
649 .ok_or_else(|| anyhow!("No configuration found for provider: {}", provider.name()))?;
650
651 if !config.has_api_key() {
652 if std::env::var(provider.api_key_env()).is_err() {
654 return Err(anyhow!(
655 "API key required for {}. Set {} or configure in ~/.config/git-iris/config.toml",
656 provider.name(),
657 provider.api_key_env()
658 ));
659 }
660 }
661
662 Ok(())
663 }
664}
665
666#[cfg(test)]
667mod tests;