git_iris/
config.rs

1use crate::git::GitRepo;
2use crate::instruction_presets::get_instruction_preset_library;
3use crate::llm::{
4    get_available_provider_names, get_default_model_for_provider, provider_requires_api_key,
5};
6use crate::log_debug;
7
8use anyhow::{Context, Result, anyhow};
9use dirs::config_dir;
10use serde::{Deserialize, Serialize};
11use std::collections::HashMap;
12use std::fs;
13use std::path::PathBuf;
14
15/// Configuration structure for the Git-Iris application
16#[derive(Deserialize, Serialize, Clone, Debug)]
17pub struct Config {
18    /// Default LLM provider
19    pub default_provider: String,
20    /// Provider-specific configurations
21    pub providers: HashMap<String, ProviderConfig>,
22    /// Flag indicating whether to use Gitmoji
23    #[serde(default = "default_gitmoji")]
24    pub use_gitmoji: bool,
25    /// Instructions for commit messages
26    #[serde(default)]
27    pub instructions: String,
28    #[serde(default = "default_instruction_preset")]
29    pub instruction_preset: String,
30    #[serde(skip)]
31    pub temp_instructions: Option<String>,
32    #[serde(skip)]
33    pub temp_preset: Option<String>,
34}
35
36/// Provider-specific configuration structure
37#[derive(Deserialize, Serialize, Clone, Debug, Default)]
38pub struct ProviderConfig {
39    /// API key for the provider
40    pub api_key: String,
41    /// Model to be used with the provider
42    pub model: String,
43    /// Additional parameters for the provider
44    #[serde(default)]
45    pub additional_params: HashMap<String, String>,
46    /// Token limit, if set by the user
47    pub token_limit: Option<usize>,
48}
49
/// Serde fallback for `use_gitmoji`: Gitmoji is enabled unless the
/// configuration says otherwise.
fn default_gitmoji() -> bool {
    true
}
54
/// Serde fallback for `instruction_preset`: the name of the preset used when
/// the configuration does not specify one.
fn default_instruction_preset() -> String {
    String::from("default")
}
59
60impl Config {
61    /// Load the configuration from the file
62    pub fn load() -> Result<Self> {
63        let config_path = Self::get_config_path()?;
64        if !config_path.exists() {
65            return Ok(Self::default());
66        }
67        let config_content = fs::read_to_string(config_path)?;
68        let mut config: Self = toml::from_str(&config_content)?;
69
70        // Migration: rename "claude" provider to "anthropic" if it exists
71        let mut migration_performed = false;
72        if config.providers.contains_key("claude") {
73            log_debug!("Migrating 'claude' provider to 'anthropic'");
74            if let Some(claude_config) = config.providers.remove("claude") {
75                config
76                    .providers
77                    .insert("anthropic".to_string(), claude_config);
78            }
79
80            // Update default provider if it was set to claude
81            if config.default_provider == "claude" {
82                config.default_provider = "anthropic".to_string();
83            }
84
85            migration_performed = true;
86        }
87
88        // Save the config if a migration was performed
89        if migration_performed {
90            log_debug!("Saving configuration after migration");
91            if let Err(e) = config.save() {
92                log_debug!("Failed to save migrated config: {}", e);
93            }
94        }
95
96        log_debug!("Configuration loaded: {:?}", config);
97        Ok(config)
98    }
99
100    /// Save the configuration to the file
101    pub fn save(&self) -> Result<()> {
102        let config_path = Self::get_config_path()?;
103        let config_content = toml::to_string(self)?;
104        fs::write(config_path, config_content)?;
105        log_debug!("Configuration saved: {:?}", self);
106        Ok(())
107    }
108
109    /// Get the path to the configuration file
110    fn get_config_path() -> Result<PathBuf> {
111        let mut path =
112            config_dir().ok_or_else(|| anyhow!("Unable to determine config directory"))?;
113        path.push("git-iris");
114        std::fs::create_dir_all(&path)?;
115        path.push("config.toml");
116        Ok(path)
117    }
118
119    /// Check the environment for necessary prerequisites
120    pub fn check_environment(&self) -> Result<()> {
121        // Check if we're in a git repository
122        if !GitRepo::is_inside_work_tree()? {
123            return Err(anyhow!(
124                "Not in a Git repository. Please run this command from within a Git repository."
125            ));
126        }
127
128        Ok(())
129    }
130
131    pub fn set_temp_instructions(&mut self, instructions: Option<String>) {
132        self.temp_instructions = instructions;
133    }
134
135    pub fn set_temp_preset(&mut self, preset: Option<String>) {
136        self.temp_preset = preset;
137    }
138
139    pub fn get_effective_instructions(&self) -> String {
140        let preset_library = get_instruction_preset_library();
141        let preset_instructions = self
142            .temp_preset
143            .as_ref()
144            .or(Some(&self.instruction_preset))
145            .and_then(|p| preset_library.get_preset(p))
146            .map(|p| p.instructions.clone())
147            .unwrap_or_default();
148
149        let custom_instructions = self
150            .temp_instructions
151            .as_ref()
152            .unwrap_or(&self.instructions);
153
154        format!("{preset_instructions}\n\n{custom_instructions}")
155            .trim()
156            .to_string()
157    }
158
159    /// Update the configuration with new values
160    #[allow(clippy::too_many_arguments)]
161    pub fn update(
162        &mut self,
163        provider: Option<String>,
164        api_key: Option<String>,
165        model: Option<String>,
166        additional_params: Option<HashMap<String, String>>,
167        use_gitmoji: Option<bool>,
168        instructions: Option<String>,
169        token_limit: Option<usize>,
170    ) -> anyhow::Result<()> {
171        if let Some(provider) = provider {
172            self.default_provider.clone_from(&provider);
173            if !self.providers.contains_key(&provider) {
174                // Only insert a new provider if it requires configuration
175                if provider_requires_api_key(&provider.to_lowercase()) {
176                    self.providers.insert(
177                        provider.clone(),
178                        ProviderConfig::default_for(&provider.to_lowercase()),
179                    );
180                }
181            }
182        }
183
184        let provider_config = self
185            .providers
186            .get_mut(&self.default_provider)
187            .context("Could not get default provider")?;
188
189        if let Some(key) = api_key {
190            provider_config.api_key = key;
191        }
192        if let Some(model) = model {
193            provider_config.model = model;
194        }
195        if let Some(params) = additional_params {
196            provider_config.additional_params.extend(params);
197        }
198        if let Some(gitmoji) = use_gitmoji {
199            self.use_gitmoji = gitmoji;
200        }
201        if let Some(instr) = instructions {
202            self.instructions = instr;
203        }
204        if let Some(limit) = token_limit {
205            provider_config.token_limit = Some(limit);
206        }
207
208        log_debug!("Configuration updated: {:?}", self);
209        Ok(())
210    }
211
212    /// Get the configuration for a specific provider
213    pub fn get_provider_config(&self, provider: &str) -> Option<&ProviderConfig> {
214        // Special case: redirect "claude" to "anthropic"
215        let provider_to_lookup = if provider.to_lowercase() == "claude" {
216            "anthropic"
217        } else {
218            provider
219        };
220
221        // First try direct lookup
222        self.providers.get(provider_to_lookup).or_else(|| {
223            // If not found, try lowercased version
224            let lowercase_provider = provider_to_lookup.to_lowercase();
225
226            self.providers.get(&lowercase_provider).or_else(|| {
227                // If the provider is not in the config, check if it's a valid provider
228                if get_available_provider_names().contains(&lowercase_provider) {
229                    // Return None for valid providers not in the config
230                    // This allows the code to use default values for providers like Ollama
231                    None
232                } else {
233                    // Return None for invalid providers
234                    None
235                }
236            })
237        })
238    }
239}
240
241impl Default for Config {
242    fn default() -> Self {
243        let mut providers = HashMap::new();
244        for provider in get_available_provider_names() {
245            providers.insert(provider.clone(), ProviderConfig::default_for(&provider));
246        }
247
248        // Default to OpenAI if available, otherwise use the first available provider
249        let default_provider = if providers.contains_key("openai") {
250            "openai".to_string()
251        } else {
252            providers.keys().next().map_or_else(
253                || "openai".to_string(), // Fallback even if no providers (should never happen)
254                std::clone::Clone::clone,
255            )
256        };
257
258        Self {
259            default_provider,
260            providers,
261            use_gitmoji: default_gitmoji(),
262            instructions: String::new(),
263            instruction_preset: default_instruction_preset(),
264            temp_instructions: None,
265            temp_preset: None,
266        }
267    }
268}
269
270impl ProviderConfig {
271    /// Create a default provider configuration for a given provider
272    pub fn default_for(provider: &str) -> Self {
273        Self {
274            api_key: String::new(),
275            model: get_default_model_for_provider(provider).to_string(),
276            additional_params: HashMap::new(),
277            token_limit: None, // Will use the default from get_default_token_limit_for_provider
278        }
279    }
280
281    /// Get the token limit for this provider configuration
282    pub fn get_token_limit(&self) -> Option<usize> {
283        self.token_limit
284    }
285}