1use crate::git::GitRepo;
7use crate::instruction_presets::get_instruction_preset_library;
8use crate::log_debug;
9use crate::providers::{Provider, ProviderConfig};
10
11use anyhow::{Context, Result, anyhow};
12use dirs::config_dir;
13use serde::{Deserialize, Serialize};
14use std::collections::HashMap;
15use std::fs;
16use std::path::{Path, PathBuf};
17
/// File name of the per-repository configuration, looked up at the Git repository root.
pub const PROJECT_CONFIG_FILENAME: &str = ".irisconfig";
20
21#[derive(Deserialize, Serialize, Clone, Debug)]
23pub struct Config {
24 #[serde(default, skip_serializing_if = "String::is_empty")]
26 pub default_provider: String,
27 #[serde(default, skip_serializing_if = "HashMap::is_empty")]
29 pub providers: HashMap<String, ProviderConfig>,
30 #[serde(default = "default_true", skip_serializing_if = "is_true")]
32 pub use_gitmoji: bool,
33 #[serde(default, skip_serializing_if = "String::is_empty")]
35 pub instructions: String,
36 #[serde(default = "default_preset", skip_serializing_if = "is_default_preset")]
38 pub instruction_preset: String,
39 #[serde(default, skip_serializing_if = "String::is_empty")]
41 pub theme: String,
42 #[serde(
44 default = "default_subagent_timeout",
45 skip_serializing_if = "is_default_subagent_timeout"
46 )]
47 pub subagent_timeout_secs: u64,
48 #[serde(skip)]
50 pub temp_instructions: Option<String>,
51 #[serde(skip)]
53 pub temp_preset: Option<String>,
54 #[serde(skip)]
56 pub is_project_config: bool,
57 #[serde(skip)]
59 pub gitmoji_override: Option<bool>,
60}
61
/// Serde default for `use_gitmoji`.
fn default_true() -> bool {
    // Gitmoji is opt-out rather than opt-in.
    true
}
65
/// `skip_serializing_if` predicate for `use_gitmoji`: the enabled state is the default, so it
/// is left out of the saved file.
// serde hands `skip_serializing_if` predicates a reference, so `&bool` is required here.
#[allow(clippy::trivially_copy_pass_by_ref)]
fn is_true(val: &bool) -> bool {
    let &enabled = val;
    enabled
}
70
/// Serde default for `instruction_preset`.
fn default_preset() -> String {
    String::from("default")
}
74
/// `skip_serializing_if` predicate for `instruction_preset`: both an empty name and
/// `"default"` count as the default preset and are left out of the saved file.
fn is_default_preset(val: &str) -> bool {
    matches!(val, "" | "default")
}
78
/// Serde default for `subagent_timeout_secs`: two minutes, in seconds.
fn default_subagent_timeout() -> u64 {
    2 * 60
}
82
83#[allow(clippy::trivially_copy_pass_by_ref)]
84fn is_default_subagent_timeout(val: &u64) -> bool {
85 *val == 120
86}
87
88impl Default for Config {
89 fn default() -> Self {
90 let mut providers = HashMap::new();
91 for provider in Provider::ALL {
92 providers.insert(
93 provider.name().to_string(),
94 ProviderConfig::with_defaults(*provider),
95 );
96 }
97
98 Self {
99 default_provider: Provider::default().name().to_string(),
100 providers,
101 use_gitmoji: true,
102 instructions: String::new(),
103 instruction_preset: default_preset(),
104 theme: String::new(),
105 subagent_timeout_secs: default_subagent_timeout(),
106 temp_instructions: None,
107 temp_preset: None,
108 is_project_config: false,
109 gitmoji_override: None,
110 }
111 }
112}
113
114impl Config {
115 pub fn load() -> Result<Self> {
121 let config_path = Self::get_personal_config_path()?;
122 let mut config = if config_path.exists() {
123 let content = fs::read_to_string(&config_path)?;
124 let config: Self = toml::from_str(&content)?;
125 Self::migrate_if_needed(config)
126 } else {
127 Self::default()
128 };
129
130 if let Ok((project_config, project_source)) = Self::load_project_config_with_source() {
132 config.merge_loaded_project_config(project_config, &project_source);
133 }
134
135 log_debug!(
136 "Configuration loaded (provider: {}, gitmoji: {})",
137 config.default_provider,
138 config.use_gitmoji
139 );
140 Ok(config)
141 }
142
143 pub fn load_project_config() -> Result<Self> {
149 let (config, _) = Self::load_project_config_with_source()?;
150 Ok(config)
151 }
152
153 fn load_project_config_with_source() -> Result<(Self, toml::Value)> {
154 let config_path = Self::get_project_config_path()?;
155 if !config_path.exists() {
156 return Err(anyhow!("Project configuration file not found"));
157 }
158
159 let content = fs::read_to_string(&config_path)
160 .with_context(|| format!("Failed to read {}", config_path.display()))?;
161 let project_source = toml::from_str(&content).with_context(|| {
162 format!(
163 "Invalid {} format. Check for syntax errors.",
164 PROJECT_CONFIG_FILENAME
165 )
166 })?;
167
168 let mut config: Self = toml::from_str(&content).with_context(|| {
169 format!(
170 "Invalid {} format. Check for syntax errors.",
171 PROJECT_CONFIG_FILENAME
172 )
173 })?;
174
175 config.is_project_config = true;
176 Ok((config, project_source))
177 }
178
179 pub fn get_project_config_path() -> Result<PathBuf> {
185 let repo_root = GitRepo::get_repo_root()?;
186 Ok(repo_root.join(PROJECT_CONFIG_FILENAME))
187 }
188
189 pub fn merge_with_project_config(&mut self, project_config: Self) {
191 log_debug!("Merging with project configuration");
192
193 if !project_config.default_provider.is_empty()
195 && project_config.default_provider != Provider::default().name()
196 {
197 self.default_provider = project_config.default_provider;
198 }
199
200 for (provider_name, proj_config) in project_config.providers {
202 let entry = self.providers.entry(provider_name).or_default();
203
204 if !proj_config.model.is_empty() {
205 entry.model = proj_config.model;
206 }
207 if proj_config.fast_model.is_some() {
208 entry.fast_model = proj_config.fast_model;
209 }
210 if proj_config.token_limit.is_some() {
211 entry.token_limit = proj_config.token_limit;
212 }
213 entry
214 .additional_params
215 .extend(proj_config.additional_params);
216 }
217
218 self.use_gitmoji = project_config.use_gitmoji;
220 self.instructions = project_config.instructions;
221
222 if project_config.instruction_preset != default_preset() {
223 self.instruction_preset = project_config.instruction_preset;
224 }
225
226 if !project_config.theme.is_empty() {
228 self.theme = project_config.theme;
229 }
230
231 if project_config.subagent_timeout_secs != default_subagent_timeout() {
233 self.subagent_timeout_secs = project_config.subagent_timeout_secs;
234 }
235 }
236
237 fn merge_loaded_project_config(&mut self, project_config: Self, project_source: &toml::Value) {
238 log_debug!("Merging loaded project configuration with explicit field tracking");
239
240 self.merge_project_provider_config(&project_config);
241
242 if Self::project_config_has_key(project_source, "default_provider") {
243 self.default_provider = project_config.default_provider;
244 }
245 if Self::project_config_has_key(project_source, "use_gitmoji") {
246 self.use_gitmoji = project_config.use_gitmoji;
247 }
248 if Self::project_config_has_key(project_source, "instructions") {
249 self.instructions = project_config.instructions;
250 }
251 if Self::project_config_has_key(project_source, "instruction_preset") {
252 self.instruction_preset = project_config.instruction_preset;
253 }
254 if Self::project_config_has_key(project_source, "theme") {
255 self.theme = project_config.theme;
256 }
257 if Self::project_config_has_key(project_source, "subagent_timeout_secs") {
258 self.subagent_timeout_secs = project_config.subagent_timeout_secs;
259 }
260 }
261
262 fn merge_project_provider_config(&mut self, project_config: &Self) {
263 for (provider_name, proj_config) in &project_config.providers {
264 let entry = self.providers.entry(provider_name.clone()).or_default();
265
266 if !proj_config.model.is_empty() {
267 proj_config.model.clone_into(&mut entry.model);
268 }
269 if proj_config.fast_model.is_some() {
270 entry.fast_model.clone_from(&proj_config.fast_model);
271 }
272 if proj_config.token_limit.is_some() {
273 entry.token_limit = proj_config.token_limit;
274 }
275 entry
276 .additional_params
277 .extend(proj_config.additional_params.clone());
278 }
279 }
280
281 fn project_config_has_key(project_source: &toml::Value, key: &str) -> bool {
282 project_source
283 .as_table()
284 .is_some_and(|table| table.contains_key(key))
285 }
286
287 fn migrate_if_needed(mut config: Self) -> Self {
289 let mut migrated = false;
290
291 for (legacy, canonical) in [("claude", "anthropic"), ("gemini", "google")] {
292 if let Some(legacy_config) = config.providers.remove(legacy) {
293 log_debug!("Migrating '{legacy}' provider to '{canonical}'");
294
295 if config.providers.contains_key(canonical) {
296 log_debug!(
297 "Keeping existing '{canonical}' config and dropping legacy '{legacy}' entry"
298 );
299 } else {
300 config
301 .providers
302 .insert(canonical.to_string(), legacy_config);
303 }
304
305 migrated = true;
306 }
307
308 if config.default_provider.eq_ignore_ascii_case(legacy) {
309 config.default_provider = canonical.to_string();
310 migrated = true;
311 }
312 }
313
314 if migrated && let Err(e) = config.save() {
315 log_debug!("Failed to save migrated config: {}", e);
316 }
317
318 config
319 }
320
321 pub fn save(&self) -> Result<()> {
327 if self.is_project_config {
328 return Ok(());
329 }
330
331 let config_path = Self::get_personal_config_path()?;
332 let content = toml::to_string_pretty(self)?;
333 Self::write_config_file(&config_path, &content)?;
334 log_debug!("Configuration saved");
335 Ok(())
336 }
337
338 pub fn save_as_project_config(&self) -> Result<()> {
344 let config_path = Self::get_project_config_path()?;
345
346 let mut project_config = self.clone();
347 project_config.is_project_config = true;
348
349 for provider_config in project_config.providers.values_mut() {
351 provider_config.api_key.clear();
352 }
353
354 let content = toml::to_string_pretty(&project_config)?;
355 Self::write_config_file(&config_path, &content)?;
356 Ok(())
357 }
358
359 fn write_config_file(path: &Path, content: &str) -> Result<()> {
365 #[cfg(unix)]
366 {
367 use std::os::unix::fs::PermissionsExt;
368
369 let tmp_path = path.with_extension("tmp");
371 fs::write(&tmp_path, content)?;
372 if let Err(e) = fs::set_permissions(&tmp_path, fs::Permissions::from_mode(0o600)) {
373 eprintln!(
374 "Warning: Could not restrict config permissions on {}: {e}",
375 tmp_path.display()
376 );
377 }
378 fs::rename(&tmp_path, path)?;
379 }
380
381 #[cfg(not(unix))]
382 {
383 fs::write(path, content)?;
384 }
385
386 Ok(())
387 }
388
389 fn resolve_personal_config_dir(
390 xdg_config_home: Option<PathBuf>,
391 platform_config_dir: Option<PathBuf>,
392 ) -> Result<PathBuf> {
393 let base_dir = xdg_config_home
394 .filter(|path| !path.as_os_str().is_empty())
395 .or(platform_config_dir)
396 .ok_or_else(|| anyhow!("Unable to determine config directory"))?;
397
398 Ok(base_dir.join("git-iris"))
399 }
400
401 pub fn get_personal_config_path() -> Result<PathBuf> {
407 let mut path = Self::resolve_personal_config_dir(
408 std::env::var_os("XDG_CONFIG_HOME").map(PathBuf::from),
409 config_dir(),
410 )?;
411 fs::create_dir_all(&path)?;
412 path.push("config.toml");
413 Ok(path)
414 }
415
416 pub fn check_environment(&self) -> Result<()> {
422 if !GitRepo::is_inside_work_tree()? {
423 return Err(anyhow!(
424 "Not in a Git repository. Please run this command from within a Git repository."
425 ));
426 }
427 Ok(())
428 }
429
430 pub fn set_temp_instructions(&mut self, instructions: Option<String>) {
432 self.temp_instructions = instructions;
433 }
434
435 pub fn set_temp_preset(&mut self, preset: Option<String>) {
437 self.temp_preset = preset;
438 }
439
440 #[must_use]
442 pub fn get_effective_preset_name(&self) -> &str {
443 self.temp_preset
444 .as_deref()
445 .unwrap_or(&self.instruction_preset)
446 }
447
448 #[must_use]
450 pub fn get_effective_instructions(&self) -> String {
451 let preset_library = get_instruction_preset_library();
452 let preset_instructions = self
453 .temp_preset
454 .as_ref()
455 .or(Some(&self.instruction_preset))
456 .and_then(|p| preset_library.get_preset(p))
457 .map(|p| p.instructions.clone())
458 .unwrap_or_default();
459
460 let custom = self
461 .temp_instructions
462 .as_ref()
463 .unwrap_or(&self.instructions);
464
465 format!("{preset_instructions}\n\n{custom}")
466 .trim()
467 .to_string()
468 }
469
470 #[allow(clippy::too_many_arguments, clippy::needless_pass_by_value)]
472 pub fn update(
477 &mut self,
478 provider: Option<String>,
479 api_key: Option<String>,
480 model: Option<String>,
481 fast_model: Option<String>,
482 additional_params: Option<HashMap<String, String>>,
483 use_gitmoji: Option<bool>,
484 instructions: Option<String>,
485 token_limit: Option<usize>,
486 ) -> Result<()> {
487 if let Some(ref provider_name) = provider {
488 let parsed: Provider = provider_name.parse().with_context(|| {
490 format!(
491 "Unknown provider '{}'. Supported: {}",
492 provider_name,
493 Provider::all_names().join(", ")
494 )
495 })?;
496
497 self.default_provider = parsed.name().to_string();
498
499 if !self.providers.contains_key(parsed.name()) {
501 self.providers.insert(
502 parsed.name().to_string(),
503 ProviderConfig::with_defaults(parsed),
504 );
505 }
506 }
507
508 let provider_config = self
509 .providers
510 .get_mut(&self.default_provider)
511 .context("Could not get default provider config")?;
512
513 if let Some(key) = api_key {
514 provider_config.api_key = key;
515 }
516 if let Some(m) = model {
517 provider_config.model = m;
518 }
519 if let Some(fm) = fast_model {
520 provider_config.fast_model = Some(fm);
521 }
522 if let Some(params) = additional_params {
523 provider_config.additional_params.extend(params);
524 }
525 if let Some(gitmoji) = use_gitmoji {
526 self.use_gitmoji = gitmoji;
527 }
528 if let Some(instr) = instructions {
529 self.instructions = instr;
530 }
531 if let Some(limit) = token_limit {
532 provider_config.token_limit = Some(limit);
533 }
534
535 log_debug!("Configuration updated");
536 Ok(())
537 }
538
539 #[must_use]
541 pub fn get_provider_config(&self, provider: &str) -> Option<&ProviderConfig> {
542 let name = if provider.eq_ignore_ascii_case("claude") {
544 "anthropic"
545 } else if provider.eq_ignore_ascii_case("gemini") {
546 "google"
547 } else {
548 provider
549 };
550
551 self.providers
552 .get(name)
553 .or_else(|| self.providers.get(&name.to_lowercase()))
554 }
555
556 #[must_use]
558 pub fn provider(&self) -> Option<Provider> {
559 self.default_provider.parse().ok()
560 }
561
562 pub fn validate(&self) -> Result<()> {
568 let provider: Provider = self
569 .default_provider
570 .parse()
571 .with_context(|| format!("Invalid provider: {}", self.default_provider))?;
572
573 let config = self
574 .get_provider_config(provider.name())
575 .ok_or_else(|| anyhow!("No configuration found for provider: {}", provider.name()))?;
576
577 if !config.has_api_key() {
578 if std::env::var(provider.api_key_env()).is_err() {
580 return Err(anyhow!(
581 "API key required for {}. Set {} or configure in ~/.config/git-iris/config.toml",
582 provider.name(),
583 provider.api_key_env()
584 ));
585 }
586 }
587
588 Ok(())
589 }
590}
591
592#[cfg(test)]
593mod tests;