git_iris/
config.rs

1use crate::git::GitRepo;
2use crate::instruction_presets::get_instruction_preset_library;
3use crate::llm::{
4    get_available_provider_names, get_default_model_for_provider, provider_requires_api_key,
5};
6use crate::log_debug;
7
8use anyhow::{Context, Result, anyhow};
9use dirs::config_dir;
10use serde::{Deserialize, Serialize};
11use std::collections::HashMap;
12use std::fs;
13use std::path::PathBuf;
14
15/// Configuration structure for the Git-Iris application
16#[derive(Deserialize, Serialize, Clone, Debug)]
17pub struct Config {
18    /// Default LLM provider
19    pub default_provider: String,
20    /// Provider-specific configurations
21    pub providers: HashMap<String, ProviderConfig>,
22    /// Flag indicating whether to use Gitmoji
23    #[serde(default = "default_gitmoji")]
24    pub use_gitmoji: bool,
25    /// Instructions for commit messages
26    #[serde(default)]
27    pub instructions: String,
28    #[serde(default = "default_instruction_preset")]
29    pub instruction_preset: String,
30    #[serde(skip)]
31    pub temp_instructions: Option<String>,
32    #[serde(skip)]
33    pub temp_preset: Option<String>,
34    /// Flag indicating if this config is from a project file
35    #[serde(skip)]
36    pub is_project_config: bool,
37}
38
39/// Provider-specific configuration structure
40#[derive(Deserialize, Serialize, Clone, Debug, Default)]
41pub struct ProviderConfig {
42    /// API key for the provider
43    pub api_key: String,
44    /// Model to be used with the provider
45    pub model: String,
46    /// Additional parameters for the provider
47    #[serde(default)]
48    pub additional_params: HashMap<String, String>,
49    /// Token limit, if set by the user
50    pub token_limit: Option<usize>,
51}
52
/// Serde default for `Config::use_gitmoji`: Gitmoji is enabled unless the
/// configuration file says otherwise.
fn default_gitmoji() -> bool {
    let enabled = true;
    enabled
}
57
/// Serde default for `Config::instruction_preset`: the preset named `"default"`.
fn default_instruction_preset() -> String {
    String::from("default")
}
62
/// File name of the per-project configuration, resolved relative to the
/// repository root.
pub const PROJECT_CONFIG_FILENAME: &str = ".irisconfig";
65
66impl Config {
67    /// Load the configuration from the file
68    pub fn load() -> Result<Self> {
69        // First load personal config
70        let config_path = Self::get_config_path()?;
71        let mut config = if config_path.exists() {
72            let config_content = fs::read_to_string(&config_path)?;
73            let config: Self = toml::from_str(&config_content)?;
74            Self::migrate_if_needed(config)
75        } else {
76            Self::default()
77        };
78
79        // Then try to load and merge project config if available
80        if let Ok(project_config) = Self::load_project_config() {
81            config.merge_with_project_config(project_config);
82        }
83
84        log_debug!("Configuration loaded: {:?}", config);
85        Ok(config)
86    }
87
88    /// Load project-specific configuration
89    pub fn load_project_config() -> Result<Self, anyhow::Error> {
90        let config_path = Self::get_project_config_path()?;
91        if !config_path.exists() {
92            return Err(anyhow::anyhow!("Project configuration file not found"));
93        }
94
95        // Read the config file with improved error handling
96        let config_str = match fs::read_to_string(&config_path) {
97            Ok(content) => content,
98            Err(e) => return Err(anyhow::anyhow!("Failed to read project config file: {}", e)),
99        };
100
101        // Parse the TOML with improved error handling
102        let mut config: Self = match toml::from_str(&config_str) {
103            Ok(config) => config,
104            Err(e) => {
105                return Err(anyhow::anyhow!(
106                    "Invalid project configuration file format: {}. Please check your {} file for syntax errors.",
107                    e,
108                    PROJECT_CONFIG_FILENAME
109                ));
110            }
111        };
112
113        config.is_project_config = true;
114        Ok(config)
115    }
116
117    /// Get the path to the project configuration file
118    pub fn get_project_config_path() -> Result<PathBuf, anyhow::Error> {
119        // Use the static method to get repo root
120        let repo_root = crate::git::GitRepo::get_repo_root()?;
121        Ok(repo_root.join(PROJECT_CONFIG_FILENAME))
122    }
123
124    /// Merge this config with project-specific config, with project config taking precedence
125    /// But never allow API keys from project config
126    pub fn merge_with_project_config(&mut self, project_config: Self) {
127        log_debug!("Merging with project configuration");
128
129        // Override default provider if set in project config
130        if project_config.default_provider != Self::default().default_provider {
131            self.default_provider = project_config.default_provider;
132        }
133
134        // Merge provider configs, but never allow API keys from project config
135        for (provider, proj_provider_config) in project_config.providers {
136            let entry = self.providers.entry(provider).or_default();
137
138            // Don't override API keys from project config (security)
139            if !proj_provider_config.model.is_empty() {
140                entry.model = proj_provider_config.model;
141            }
142
143            // Merge additional params
144            entry
145                .additional_params
146                .extend(proj_provider_config.additional_params);
147
148            // Override token limit if set in project config
149            if proj_provider_config.token_limit.is_some() {
150                entry.token_limit = proj_provider_config.token_limit;
151            }
152        }
153
154        // Override other settings
155        self.use_gitmoji = project_config.use_gitmoji;
156
157        // Always override instructions field if set in project config
158        self.instructions = project_config.instructions.clone();
159
160        // Override preset
161        if project_config.instruction_preset != default_instruction_preset() {
162            self.instruction_preset = project_config.instruction_preset;
163        }
164    }
165
166    /// Migrate older config formats if needed
167    fn migrate_if_needed(mut config: Self) -> Self {
168        // Migration: rename "claude" provider to "anthropic" if it exists
169        let mut migration_performed = false;
170        if config.providers.contains_key("claude") {
171            log_debug!("Migrating 'claude' provider to 'anthropic'");
172            if let Some(claude_config) = config.providers.remove("claude") {
173                config
174                    .providers
175                    .insert("anthropic".to_string(), claude_config);
176            }
177
178            // Update default provider if it was set to claude
179            if config.default_provider == "claude" {
180                config.default_provider = "anthropic".to_string();
181            }
182
183            migration_performed = true;
184        }
185
186        // Save the config if a migration was performed
187        if migration_performed {
188            log_debug!("Saving configuration after migration");
189            if let Err(e) = config.save() {
190                log_debug!("Failed to save migrated config: {}", e);
191            }
192        }
193
194        config
195    }
196
197    /// Save the configuration to the file
198    pub fn save(&self) -> Result<()> {
199        // Don't save project configs to personal config file
200        if self.is_project_config {
201            return Ok(());
202        }
203
204        let config_path = Self::get_config_path()?;
205        let config_content = toml::to_string(self)?;
206        fs::write(config_path, config_content)?;
207        log_debug!("Configuration saved: {:?}", self);
208        Ok(())
209    }
210
211    /// Save the configuration as a project-specific configuration
212    pub fn save_as_project_config(&self) -> Result<(), anyhow::Error> {
213        let config_path = Self::get_project_config_path()?;
214
215        // Before saving, create a copy that excludes API keys
216        let mut project_config = self.clone();
217
218        // Remove API keys from all providers
219        for provider_config in project_config.providers.values_mut() {
220            provider_config.api_key.clear();
221        }
222
223        // Mark as project config
224        project_config.is_project_config = true;
225
226        // Convert to TOML string
227        let config_str = toml::to_string_pretty(&project_config)?;
228
229        // Write to file
230        fs::write(config_path, config_str)?;
231
232        Ok(())
233    }
234
235    /// Get the path to the configuration file
236    fn get_config_path() -> Result<PathBuf> {
237        let mut path =
238            config_dir().ok_or_else(|| anyhow!("Unable to determine config directory"))?;
239        path.push("git-iris");
240        std::fs::create_dir_all(&path)?;
241        path.push("config.toml");
242        Ok(path)
243    }
244
245    /// Check the environment for necessary prerequisites
246    pub fn check_environment(&self) -> Result<()> {
247        // Check if we're in a git repository
248        if !GitRepo::is_inside_work_tree()? {
249            return Err(anyhow!(
250                "Not in a Git repository. Please run this command from within a Git repository."
251            ));
252        }
253
254        Ok(())
255    }
256
257    pub fn set_temp_instructions(&mut self, instructions: Option<String>) {
258        self.temp_instructions = instructions;
259    }
260
261    pub fn set_temp_preset(&mut self, preset: Option<String>) {
262        self.temp_preset = preset;
263    }
264
265    pub fn get_effective_instructions(&self) -> String {
266        let preset_library = get_instruction_preset_library();
267        let preset_instructions = self
268            .temp_preset
269            .as_ref()
270            .or(Some(&self.instruction_preset))
271            .and_then(|p| preset_library.get_preset(p))
272            .map(|p| p.instructions.clone())
273            .unwrap_or_default();
274
275        let custom_instructions = self
276            .temp_instructions
277            .as_ref()
278            .unwrap_or(&self.instructions);
279
280        format!("{preset_instructions}\n\n{custom_instructions}")
281            .trim()
282            .to_string()
283    }
284
285    /// Update the configuration with new values
286    #[allow(clippy::too_many_arguments)]
287    pub fn update(
288        &mut self,
289        provider: Option<String>,
290        api_key: Option<String>,
291        model: Option<String>,
292        additional_params: Option<HashMap<String, String>>,
293        use_gitmoji: Option<bool>,
294        instructions: Option<String>,
295        token_limit: Option<usize>,
296    ) -> anyhow::Result<()> {
297        if let Some(provider) = provider {
298            self.default_provider.clone_from(&provider);
299            if !self.providers.contains_key(&provider) {
300                // Only insert a new provider if it requires configuration
301                if provider_requires_api_key(&provider.to_lowercase()) {
302                    self.providers.insert(
303                        provider.clone(),
304                        ProviderConfig::default_for(&provider.to_lowercase()),
305                    );
306                }
307            }
308        }
309
310        let provider_config = self
311            .providers
312            .get_mut(&self.default_provider)
313            .context("Could not get default provider")?;
314
315        if let Some(key) = api_key {
316            provider_config.api_key = key;
317        }
318        if let Some(model) = model {
319            provider_config.model = model;
320        }
321        if let Some(params) = additional_params {
322            provider_config.additional_params.extend(params);
323        }
324        if let Some(gitmoji) = use_gitmoji {
325            self.use_gitmoji = gitmoji;
326        }
327        if let Some(instr) = instructions {
328            self.instructions = instr;
329        }
330        if let Some(limit) = token_limit {
331            provider_config.token_limit = Some(limit);
332        }
333
334        log_debug!("Configuration updated: {:?}", self);
335        Ok(())
336    }
337
338    /// Get the configuration for a specific provider
339    pub fn get_provider_config(&self, provider: &str) -> Option<&ProviderConfig> {
340        // Special case: redirect "claude" to "anthropic"
341        let provider_to_lookup = if provider.to_lowercase() == "claude" {
342            "anthropic"
343        } else {
344            provider
345        };
346
347        // First try direct lookup
348        self.providers.get(provider_to_lookup).or_else(|| {
349            // If not found, try lowercased version
350            let lowercase_provider = provider_to_lookup.to_lowercase();
351
352            self.providers.get(&lowercase_provider).or_else(|| {
353                // If the provider is not in the config, check if it's a valid provider
354                if get_available_provider_names().contains(&lowercase_provider) {
355                    // Return None for valid providers not in the config
356                    // This allows the code to use default values for providers like Ollama
357                    None
358                } else {
359                    // Return None for invalid providers
360                    None
361                }
362            })
363        })
364    }
365
366    /// Set whether this config is a project config
367    pub fn set_project_config(&mut self, is_project: bool) {
368        self.is_project_config = is_project;
369    }
370
371    /// Check if this is a project config
372    pub fn is_project_config(&self) -> bool {
373        self.is_project_config
374    }
375}
376
377impl Default for Config {
378    fn default() -> Self {
379        let mut providers = HashMap::new();
380        for provider in get_available_provider_names() {
381            providers.insert(provider.clone(), ProviderConfig::default_for(&provider));
382        }
383
384        // Default to OpenAI if available, otherwise use the first available provider
385        let default_provider = if providers.contains_key("openai") {
386            "openai".to_string()
387        } else {
388            providers.keys().next().map_or_else(
389                || "openai".to_string(), // Fallback even if no providers (should never happen)
390                std::clone::Clone::clone,
391            )
392        };
393
394        Self {
395            default_provider,
396            providers,
397            use_gitmoji: default_gitmoji(),
398            instructions: String::new(),
399            instruction_preset: default_instruction_preset(),
400            temp_instructions: None,
401            temp_preset: None,
402            is_project_config: false,
403        }
404    }
405}
406
407impl ProviderConfig {
408    /// Create a default provider configuration for a given provider
409    pub fn default_for(provider: &str) -> Self {
410        Self {
411            api_key: String::new(),
412            model: get_default_model_for_provider(provider).to_string(),
413            additional_params: HashMap::new(),
414            token_limit: None, // Will use the default from get_default_token_limit_for_provider
415        }
416    }
417
418    /// Get the token limit for this provider configuration
419    pub fn get_token_limit(&self) -> Option<usize> {
420        self.token_limit
421    }
422}