gitai/
config.rs

1use crate::core::llm::{
2    get_available_provider_names, get_default_model_for_provider, provider_requires_api_key,
3};
4use crate::debug;
5use crate::git::GitRepo;
6use crate::instruction_presets::get_instruction_preset_library;
7// use crate::llm::{
8//     get_available_provider_names, get_default_model_for_provider, provider_requires_api_key,
9// };
10
11use anyhow::{Context, Result, anyhow};
12use git2::Config as GitConfig;
13use serde::{Deserialize, Serialize};
14use std::collections::HashMap;
15use std::process::Command;
16
17/// Configuration structure
18#[derive(Deserialize, Serialize, Clone, Debug)]
19pub struct Config {
20    /// Default LLM provider
21    pub default_provider: String,
22    /// Provider-specific configurations
23    pub providers: HashMap<String, ProviderConfig>,
24    /// Flag indicating whether to use emoji
25    #[serde(default = "default_emoji")]
26    pub use_emoji: bool,
27    /// Instructions for commit messages
28    #[serde(default)]
29    pub instructions: String,
30    #[serde(default = "default_instruction_preset")]
31    pub instruction_preset: String,
32    #[serde(skip)]
33    pub temp_instructions: Option<String>,
34    #[serde(skip)]
35    pub temp_preset: Option<String>,
36    /// Flag indicating if this config is from a project file
37    #[serde(skip)]
38    pub is_project_config: bool,
39}
40
41/// Provider-specific configuration structure
42#[derive(Deserialize, Serialize, Clone, Debug, Default)]
43pub struct ProviderConfig {
44    /// API key for the provider
45    pub api_key: String,
46    /// Model to be used with the provider
47    pub model: String,
48    /// Additional parameters for the provider
49    #[serde(default)]
50    pub additional_params: HashMap<String, String>,
51    /// Token limit, if set by the user
52    pub token_limit: Option<usize>,
53}
54
/// Serde and `Default` fallback for `Config::use_emoji`.
///
/// Emoji are disabled unless explicitly turned on.
fn default_emoji() -> bool {
    let emoji_enabled_by_default = false;
    emoji_enabled_by_default
}
59
/// Name of the instruction preset used when none is configured.
///
/// Serves as the serde default for `Config::instruction_preset`.
fn default_instruction_preset() -> String {
    String::from("default")
}
64
65impl Config {
66    /// Load the configuration from git config
67    pub fn load() -> Result<Self> {
68        let mut config = Self::load_from_config("gitai");
69
70        // Then try to load and merge project config if available
71        if let Ok(project_config) = Self::load_project_config() {
72            config.merge_with_project_config(project_config);
73        }
74
75        debug!("Configuration loaded: {config:?}");
76        Ok(config)
77    }
78
79    /// Load configuration from git config
80    fn load_from_config(prefix: &str) -> Self {
81        let default_provider = Self::get_git_config_value(&format!("{prefix}.defaultprovider"))
82            .unwrap_or("openai".to_string());
83        let use_emoji = Self::get_git_config_bool(&format!("{prefix}.useemoji")).unwrap_or(true);
84        let instructions =
85            Self::get_git_config_value(&format!("{prefix}.instructions")).unwrap_or_default();
86        let instruction_preset = Self::get_git_config_value(&format!("{prefix}.instructionpreset"))
87            .unwrap_or("default".to_string());
88
89        let mut providers = HashMap::new();
90        // To load providers, we need to iterate over all keys with prefix
91        // But git2 Config doesn't have easy way to iterate, so for now, assume known providers
92        for provider in get_available_provider_names() {
93            if let Some(api_key) =
94                Self::get_git_config_value(&format!("{prefix}.{provider}-apikey"))
95            {
96                let default_model = get_default_model_for_provider(&provider).to_string();
97                let model = Self::get_git_config_value(&format!("{prefix}.{provider}-model"))
98                    .unwrap_or(default_model);
99                let token_limit =
100                    Self::get_git_config_i64(&format!("{prefix}.{provider}-tokenlimit")).map(|v| {
101                        usize::try_from(v).expect("Failed to convert token limit from i64 to usize")
102                    });
103                let additional_params = HashMap::new();
104                // For additional params, it's hard to iterate, so skip for now
105                providers.insert(
106                    provider.to_string(),
107                    ProviderConfig {
108                        api_key,
109                        model,
110                        additional_params,
111                        token_limit,
112                    },
113                );
114            }
115        }
116
117        Self {
118            default_provider,
119            providers,
120            use_emoji,
121            instructions,
122            instruction_preset,
123            temp_instructions: None,
124            temp_preset: None,
125            is_project_config: false,
126        }
127    }
128
129    fn get_git_config_value(key: &str) -> Option<String> {
130        let output = Command::new("git")
131            .args(["config", "--get", key])
132            .output()
133            .ok()?;
134        if output.status.success() {
135            Some(String::from_utf8_lossy(&output.stdout).trim().to_string())
136        } else {
137            None
138        }
139    }
140
141    fn get_git_config_bool(key: &str) -> Option<bool> {
142        Self::get_git_config_value(key).and_then(|v| v.parse().ok())
143    }
144
145    fn get_git_config_i64(key: &str) -> Option<i64> {
146        Self::get_git_config_value(key).and_then(|v| v.parse().ok())
147    }
148
149    /// Load project-specific configuration
150    pub fn load_project_config() -> Result<Self, anyhow::Error> {
151        let mut project_config = Self::load_from_config("gitai");
152        project_config.is_project_config = true;
153        Ok(project_config)
154    }
155
156    /// Merge this config with project-specific config, with project config taking precedence
157    /// But never allow API keys from project config
158    pub fn merge_with_project_config(&mut self, project_config: Self) {
159        debug!("Merging with project configuration");
160
161        // Override default provider if set in project config
162        if project_config.default_provider != Self::default().default_provider {
163            self.default_provider = project_config.default_provider;
164        }
165
166        // Merge provider configs, but never allow API keys from project config
167        for (provider, proj_provider_config) in project_config.providers {
168            let entry = self.providers.entry(provider).or_default();
169
170            // Don't override API keys from project config (security)
171            if !proj_provider_config.model.is_empty() {
172                entry.model = proj_provider_config.model;
173            }
174
175            // Merge additional params
176            entry
177                .additional_params
178                .extend(proj_provider_config.additional_params);
179
180            // Override token limit if set in project config
181            if proj_provider_config.token_limit.is_some() {
182                entry.token_limit = proj_provider_config.token_limit;
183            }
184        }
185
186        // Override other settings
187        self.use_emoji = project_config.use_emoji;
188
189        // Always override instructions field if set in project config
190        self.instructions = project_config.instructions.clone();
191
192        // Override preset
193        if project_config.instruction_preset != default_instruction_preset() {
194            self.instruction_preset = project_config.instruction_preset;
195        }
196    }
197
198    /// Save the configuration to git config
199    pub fn save(&self) -> Result<()> {
200        // Don't save project configs to personal config file
201        if self.is_project_config {
202            return Ok(());
203        }
204
205        let mut config = GitConfig::open_default()?;
206        self.save_to_config(&mut config, "gitai")?;
207        debug!("Configuration saved to global git config: {self:?}");
208        Ok(())
209    }
210
211    /// Save the configuration to a git config
212    fn save_to_config(&self, config: &mut GitConfig, prefix: &str) -> Result<()> {
213        // Set default provider
214        config.set_str(&format!("{prefix}.defaultprovider"), &self.default_provider)?;
215
216        // Set use emoji
217        config.set_bool(&format!("{prefix}.useemoji"), self.use_emoji)?;
218
219        // Set instructions
220        config.set_str(&format!("{prefix}.instructions"), &self.instructions)?;
221
222        // Set instruction preset
223        config.set_str(
224            &format!("{prefix}.instructionpreset"),
225            &self.instruction_preset,
226        )?;
227
228        for (provider, provider_config) in &self.providers {
229            // Set api key only if not empty
230            if !provider_config.api_key.is_empty() {
231                config.set_str(
232                    &format!("{prefix}.{provider}-apikey"),
233                    &provider_config.api_key,
234                )?;
235            }
236
237            // Set model
238            config.set_str(
239                &format!("{prefix}.{provider}-model"),
240                &provider_config.model,
241            )?;
242
243            if let Some(token_limit) = provider_config.token_limit {
244                config.set_i64(
245                    &format!("{prefix}.{provider}-tokenlimit"),
246                    i64::try_from(token_limit).context("Token limit exceeds i64 range")?,
247                )?;
248            }
249
250            for (key, value) in &provider_config.additional_params {
251                config.set_str(&format!("{prefix}.{provider}-additional{key}"), value)?;
252            }
253        }
254
255        Ok(())
256    }
257
258    /// Save the configuration as a project-specific configuration
259    pub fn save_as_project_config(&self) -> Result<(), anyhow::Error> {
260        let repo = git2::Repository::discover(".")?;
261
262        // Before saving, create a copy that excludes API keys
263        let mut project_config = self.clone();
264
265        // Remove API keys from all providers
266        for provider_config in project_config.providers.values_mut() {
267            provider_config.api_key.clear();
268        }
269
270        // Mark as project config
271        project_config.is_project_config = true;
272
273        // Save to local git config
274        let mut config = repo.config()?;
275        project_config.save_to_config(&mut config, "gitai")?;
276        debug!("Project configuration saved to local git config: {project_config:?}");
277        Ok(())
278    }
279
280    /// Check the environment for necessary prerequisites
281    pub fn check_environment(&self) -> Result<()> {
282        // Check if we're in a git repository
283        if !GitRepo::is_inside_work_tree()? {
284            return Err(anyhow!(
285                "Not in a Git repository. Please run this command from within a Git repository."
286            ));
287        }
288
289        Ok(())
290    }
291
292    pub fn set_temp_instructions(&mut self, instructions: Option<String>) {
293        self.temp_instructions = instructions;
294    }
295
296    pub fn set_temp_preset(&mut self, preset: Option<String>) {
297        self.temp_preset = preset;
298    }
299
300    pub fn get_effective_instructions(&self) -> String {
301        let preset_library = get_instruction_preset_library();
302        let preset_instructions = self
303            .temp_preset
304            .as_ref()
305            .or(Some(&self.instruction_preset))
306            .and_then(|p| preset_library.get_preset(p))
307            .map(|p| p.instructions.clone())
308            .unwrap_or_default();
309
310        let custom_instructions = self
311            .temp_instructions
312            .as_ref()
313            .unwrap_or(&self.instructions);
314
315        format!("{preset_instructions}\n\n{custom_instructions}")
316            .trim()
317            .to_string()
318    }
319
320    /// Update the configuration with new values
321    #[allow(clippy::too_many_arguments)]
322    pub fn update(
323        &mut self,
324        provider: Option<String>,
325        api_key: Option<String>,
326        model: Option<String>,
327        additional_params: Option<HashMap<String, String>>,
328        use_emoji: Option<bool>,
329        instructions: Option<String>,
330        token_limit: Option<usize>,
331    ) -> anyhow::Result<()> {
332        if let Some(provider) = provider {
333            self.default_provider.clone_from(&provider);
334            if !self.providers.contains_key(&provider) {
335                // Only insert a new provider if it requires configuration
336                if provider_requires_api_key(&provider.to_lowercase()) {
337                    self.providers.insert(
338                        provider.clone(),
339                        ProviderConfig::default_for(&provider.to_lowercase()),
340                    );
341                }
342            }
343        }
344
345        let provider_config = self
346            .providers
347            .get_mut(&self.default_provider)
348            .context("Could not get default provider")?;
349
350        if let Some(key) = api_key {
351            provider_config.api_key = key;
352        }
353        if let Some(model) = model {
354            provider_config.model = model;
355        }
356        if let Some(params) = additional_params {
357            provider_config.additional_params.extend(params);
358        }
359        if let Some(emoji) = use_emoji {
360            self.use_emoji = emoji;
361        }
362        if let Some(instr) = instructions {
363            self.instructions = instr;
364        }
365        if let Some(limit) = token_limit {
366            provider_config.token_limit = Some(limit);
367        }
368
369        debug!("Configuration updated: {self:?}");
370        Ok(())
371    }
372
373    /// Get the configuration for a specific provider
374    pub fn get_provider_config(&self, provider: &str) -> Option<&ProviderConfig> {
375        // Special case: redirect "claude" to "anthropic"
376        let provider_to_lookup = if provider.to_lowercase() == "claude" {
377            "anthropic"
378        } else {
379            provider
380        };
381
382        // First try direct lookup
383        self.providers.get(provider_to_lookup).or_else(|| {
384            // If not found, try lowercased version
385            let lowercase_provider = provider_to_lookup.to_lowercase();
386
387            self.providers.get(&lowercase_provider).or_else(|| {
388                // If the provider is not in the config, check if it's a valid provider
389                if get_available_provider_names().contains(&lowercase_provider) {
390                    // Return None for valid providers not in the config
391                    // This allows the code to use default values for providers like Ollama
392                    None
393                } else {
394                    // Return None for invalid providers
395                    None
396                }
397            })
398        })
399    }
400
401    /// Set whether this config is a project config
402    pub fn set_project_config(&mut self, is_project: bool) {
403        self.is_project_config = is_project;
404    }
405
406    /// Check if this is a project config
407    pub fn is_project_config(&self) -> bool {
408        self.is_project_config
409    }
410}
411
412impl Default for Config {
413    fn default() -> Self {
414        let mut providers = HashMap::new();
415        for provider in get_available_provider_names() {
416            providers.insert(provider.clone(), ProviderConfig::default_for(&provider));
417        }
418
419        // Default to OpenAI if available, otherwise use the first available provider
420        let default_provider = if providers.contains_key("openai") {
421            "openai".to_string()
422        } else {
423            providers.keys().next().map_or_else(
424                || "openai".to_string(), // Fallback even if no providers (should never happen)
425                std::clone::Clone::clone,
426            )
427        };
428
429        Self {
430            default_provider,
431            providers,
432            use_emoji: default_emoji(),
433            instructions: String::new(),
434            instruction_preset: default_instruction_preset(),
435            temp_instructions: None,
436            temp_preset: None,
437            is_project_config: false,
438        }
439    }
440}
441
442impl ProviderConfig {
443    /// Create a default provider configuration for a given provider
444    pub fn default_for(provider: &str) -> Self {
445        Self {
446            api_key: String::new(),
447            model: get_default_model_for_provider(provider).to_string(),
448            additional_params: HashMap::new(),
449            token_limit: None, // Will use the default from get_default_token_limit_for_provider
450        }
451    }
452
453    /// Get the token limit for this provider configuration
454    pub fn get_token_limit(&self) -> Option<usize> {
455        self.token_limit
456    }
457}