1use crate::git::GitRepo;
7use crate::instruction_presets::get_instruction_preset_library;
8use crate::log_debug;
9use crate::providers::{Provider, ProviderConfig};
10
11use anyhow::{Context, Result, anyhow};
12use dirs::config_dir;
13use serde::{Deserialize, Serialize};
14use std::collections::HashMap;
15use std::fs;
16use std::path::PathBuf;
17
/// File name of the per-repository configuration file, looked up directly
/// under the repository root.
pub const PROJECT_CONFIG_FILENAME: &str = ".irisconfig";
20
21#[derive(Deserialize, Serialize, Clone, Debug)]
23pub struct Config {
24 #[serde(default, skip_serializing_if = "String::is_empty")]
26 pub default_provider: String,
27 #[serde(default, skip_serializing_if = "HashMap::is_empty")]
29 pub providers: HashMap<String, ProviderConfig>,
30 #[serde(default = "default_true", skip_serializing_if = "is_true")]
32 pub use_gitmoji: bool,
33 #[serde(default, skip_serializing_if = "String::is_empty")]
35 pub instructions: String,
36 #[serde(default = "default_preset", skip_serializing_if = "is_default_preset")]
38 pub instruction_preset: String,
39 #[serde(default, skip_serializing_if = "String::is_empty")]
41 pub theme: String,
42 #[serde(
44 default = "default_subagent_timeout",
45 skip_serializing_if = "is_default_subagent_timeout"
46 )]
47 pub subagent_timeout_secs: u64,
48 #[serde(skip)]
50 pub temp_instructions: Option<String>,
51 #[serde(skip)]
53 pub temp_preset: Option<String>,
54 #[serde(skip)]
56 pub is_project_config: bool,
57 #[serde(skip)]
59 pub gitmoji_override: Option<bool>,
60}
61
/// Serde default provider for boolean fields that should be `true` when the
/// key is missing from the input.
fn default_true() -> bool {
    true
}
65
/// Serde `skip_serializing_if` predicate: reports whether the flag is `true`.
// The parameter has to stay a reference because serde calls the predicate
// with `&T`; hence the lint exemption.
#[allow(clippy::trivially_copy_pass_by_ref)]
fn is_true(val: &bool) -> bool {
    let &enabled = val;
    enabled
}
70
/// Returns the key of the instruction preset used when none is configured.
fn default_preset() -> String {
    String::from("default")
}
74
/// Serde `skip_serializing_if` predicate: a preset key counts as the default
/// when it is empty or exactly `"default"`.
fn is_default_preset(val: &str) -> bool {
    matches!(val, "" | "default")
}
78
/// Returns the subagent timeout used when none is configured.
fn default_subagent_timeout() -> u64 {
    const DEFAULT_TIMEOUT: u64 = 120;
    DEFAULT_TIMEOUT
}
82
/// Serde `skip_serializing_if` predicate: reports whether the timeout equals
/// the default of 120.
// The parameter has to stay a reference because serde calls the predicate
// with `&T`; hence the lint exemption.
#[allow(clippy::trivially_copy_pass_by_ref)]
fn is_default_subagent_timeout(val: &u64) -> bool {
    matches!(*val, 120)
}
87
88impl Default for Config {
89 fn default() -> Self {
90 let mut providers = HashMap::new();
91 for provider in Provider::ALL {
92 providers.insert(
93 provider.name().to_string(),
94 ProviderConfig::with_defaults(*provider),
95 );
96 }
97
98 Self {
99 default_provider: Provider::default().name().to_string(),
100 providers,
101 use_gitmoji: true,
102 instructions: String::new(),
103 instruction_preset: default_preset(),
104 theme: String::new(),
105 subagent_timeout_secs: default_subagent_timeout(),
106 temp_instructions: None,
107 temp_preset: None,
108 is_project_config: false,
109 gitmoji_override: None,
110 }
111 }
112}
113
114impl Config {
115 pub fn load() -> Result<Self> {
117 let config_path = Self::get_config_path()?;
118 let mut config = if config_path.exists() {
119 let content = fs::read_to_string(&config_path)?;
120 let config: Self = toml::from_str(&content)?;
121 Self::migrate_if_needed(config)
122 } else {
123 Self::default()
124 };
125
126 if let Ok(project_config) = Self::load_project_config() {
128 config.merge_with_project_config(project_config);
129 }
130
131 log_debug!(
132 "Configuration loaded (provider: {}, gitmoji: {})",
133 config.default_provider,
134 config.use_gitmoji
135 );
136 Ok(config)
137 }
138
139 pub fn load_project_config() -> Result<Self> {
141 let config_path = Self::get_project_config_path()?;
142 if !config_path.exists() {
143 return Err(anyhow!("Project configuration file not found"));
144 }
145
146 let content = fs::read_to_string(&config_path)
147 .with_context(|| format!("Failed to read {}", config_path.display()))?;
148
149 let mut config: Self = toml::from_str(&content).with_context(|| {
150 format!(
151 "Invalid {} format. Check for syntax errors.",
152 PROJECT_CONFIG_FILENAME
153 )
154 })?;
155
156 config.is_project_config = true;
157 Ok(config)
158 }
159
160 pub fn get_project_config_path() -> Result<PathBuf> {
162 let repo_root = GitRepo::get_repo_root()?;
163 Ok(repo_root.join(PROJECT_CONFIG_FILENAME))
164 }
165
166 pub fn merge_with_project_config(&mut self, project_config: Self) {
168 log_debug!("Merging with project configuration");
169
170 if !project_config.default_provider.is_empty()
172 && project_config.default_provider != Provider::default().name()
173 {
174 self.default_provider = project_config.default_provider;
175 }
176
177 for (provider_name, proj_config) in project_config.providers {
179 let entry = self.providers.entry(provider_name).or_default();
180
181 if !proj_config.model.is_empty() {
182 entry.model = proj_config.model;
183 }
184 if proj_config.fast_model.is_some() {
185 entry.fast_model = proj_config.fast_model;
186 }
187 if proj_config.token_limit.is_some() {
188 entry.token_limit = proj_config.token_limit;
189 }
190 entry
191 .additional_params
192 .extend(proj_config.additional_params);
193 }
194
195 self.use_gitmoji = project_config.use_gitmoji;
197 self.instructions = project_config.instructions;
198
199 if project_config.instruction_preset != default_preset() {
200 self.instruction_preset = project_config.instruction_preset;
201 }
202
203 if !project_config.theme.is_empty() {
205 self.theme = project_config.theme;
206 }
207
208 if project_config.subagent_timeout_secs != default_subagent_timeout() {
210 self.subagent_timeout_secs = project_config.subagent_timeout_secs;
211 }
212 }
213
214 fn migrate_if_needed(mut config: Self) -> Self {
216 let mut migrated = false;
217
218 if config.providers.contains_key("claude") {
220 log_debug!("Migrating 'claude' provider to 'anthropic'");
221 if let Some(claude_config) = config.providers.remove("claude") {
222 config
223 .providers
224 .insert("anthropic".to_string(), claude_config);
225 }
226 if config.default_provider == "claude" {
227 config.default_provider = "anthropic".to_string();
228 }
229 migrated = true;
230 }
231
232 if migrated && let Err(e) = config.save() {
233 log_debug!("Failed to save migrated config: {}", e);
234 }
235
236 config
237 }
238
239 pub fn save(&self) -> Result<()> {
241 if self.is_project_config {
242 return Ok(());
243 }
244
245 let config_path = Self::get_config_path()?;
246 let content = toml::to_string_pretty(self)?;
247 fs::write(config_path, content)?;
248 log_debug!("Configuration saved");
249 Ok(())
250 }
251
252 pub fn save_as_project_config(&self) -> Result<()> {
254 let config_path = Self::get_project_config_path()?;
255
256 let mut project_config = self.clone();
257 project_config.is_project_config = true;
258
259 for provider_config in project_config.providers.values_mut() {
261 provider_config.api_key.clear();
262 }
263
264 let content = toml::to_string_pretty(&project_config)?;
265 fs::write(config_path, content)?;
266 Ok(())
267 }
268
269 fn get_config_path() -> Result<PathBuf> {
271 let mut path =
272 config_dir().ok_or_else(|| anyhow!("Unable to determine config directory"))?;
273 path.push("git-iris");
274 fs::create_dir_all(&path)?;
275 path.push("config.toml");
276 Ok(path)
277 }
278
279 pub fn check_environment(&self) -> Result<()> {
281 if !GitRepo::is_inside_work_tree()? {
282 return Err(anyhow!(
283 "Not in a Git repository. Please run this command from within a Git repository."
284 ));
285 }
286 Ok(())
287 }
288
289 pub fn set_temp_instructions(&mut self, instructions: Option<String>) {
291 self.temp_instructions = instructions;
292 }
293
294 pub fn set_temp_preset(&mut self, preset: Option<String>) {
296 self.temp_preset = preset;
297 }
298
299 pub fn get_effective_preset_name(&self) -> &str {
301 self.temp_preset
302 .as_deref()
303 .unwrap_or(&self.instruction_preset)
304 }
305
306 pub fn get_effective_instructions(&self) -> String {
308 let preset_library = get_instruction_preset_library();
309 let preset_instructions = self
310 .temp_preset
311 .as_ref()
312 .or(Some(&self.instruction_preset))
313 .and_then(|p| preset_library.get_preset(p))
314 .map(|p| p.instructions.clone())
315 .unwrap_or_default();
316
317 let custom = self
318 .temp_instructions
319 .as_ref()
320 .unwrap_or(&self.instructions);
321
322 format!("{preset_instructions}\n\n{custom}")
323 .trim()
324 .to_string()
325 }
326
327 #[allow(clippy::too_many_arguments, clippy::needless_pass_by_value)]
329 pub fn update(
330 &mut self,
331 provider: Option<String>,
332 api_key: Option<String>,
333 model: Option<String>,
334 fast_model: Option<String>,
335 additional_params: Option<HashMap<String, String>>,
336 use_gitmoji: Option<bool>,
337 instructions: Option<String>,
338 token_limit: Option<usize>,
339 ) -> Result<()> {
340 if let Some(ref provider_name) = provider {
341 let parsed: Provider = provider_name.parse().with_context(|| {
343 format!(
344 "Unknown provider '{}'. Supported: {}",
345 provider_name,
346 Provider::all_names().join(", ")
347 )
348 })?;
349
350 self.default_provider = parsed.name().to_string();
351
352 if !self.providers.contains_key(parsed.name()) {
354 self.providers.insert(
355 parsed.name().to_string(),
356 ProviderConfig::with_defaults(parsed),
357 );
358 }
359 }
360
361 let provider_config = self
362 .providers
363 .get_mut(&self.default_provider)
364 .context("Could not get default provider config")?;
365
366 if let Some(key) = api_key {
367 provider_config.api_key = key;
368 }
369 if let Some(m) = model {
370 provider_config.model = m;
371 }
372 if let Some(fm) = fast_model {
373 provider_config.fast_model = Some(fm);
374 }
375 if let Some(params) = additional_params {
376 provider_config.additional_params.extend(params);
377 }
378 if let Some(gitmoji) = use_gitmoji {
379 self.use_gitmoji = gitmoji;
380 }
381 if let Some(instr) = instructions {
382 self.instructions = instr;
383 }
384 if let Some(limit) = token_limit {
385 provider_config.token_limit = Some(limit);
386 }
387
388 log_debug!("Configuration updated");
389 Ok(())
390 }
391
392 pub fn get_provider_config(&self, provider: &str) -> Option<&ProviderConfig> {
394 let name = if provider.eq_ignore_ascii_case("claude") {
396 "anthropic"
397 } else {
398 provider
399 };
400
401 self.providers
402 .get(name)
403 .or_else(|| self.providers.get(&name.to_lowercase()))
404 }
405
406 pub fn provider(&self) -> Option<Provider> {
408 self.default_provider.parse().ok()
409 }
410
411 pub fn validate(&self) -> Result<()> {
413 let provider: Provider = self
414 .default_provider
415 .parse()
416 .with_context(|| format!("Invalid provider: {}", self.default_provider))?;
417
418 let config = self
419 .get_provider_config(provider.name())
420 .ok_or_else(|| anyhow!("No configuration found for provider: {}", provider.name()))?;
421
422 if !config.has_api_key() {
423 if std::env::var(provider.api_key_env()).is_err() {
425 return Err(anyhow!(
426 "API key required for {}. Set {} or configure in ~/.config/git-iris/config.toml",
427 provider.name(),
428 provider.api_key_env()
429 ));
430 }
431 }
432
433 Ok(())
434 }
435}