gitai/
config.rs

1use crate::core::llm::{
2    get_available_provider_names, get_default_model_for_provider, provider_requires_api_key,
3};
4use crate::debug;
5use crate::git::GitRepo;
6
7use anyhow::{Context, Result, anyhow};
8use git2::Config as GitConfig;
9use serde::{Deserialize, Serialize};
10use std::collections::HashMap;
11use std::process::Command;
12
13/// Configuration structure
14#[derive(Deserialize, Serialize, Clone, Debug)]
15pub struct Config {
16    /// Default LLM provider
17    pub default_provider: String,
18    /// Provider-specific configurations
19    pub providers: HashMap<String, ProviderConfig>,
20    /// Instructions for commit messages
21    #[serde(default)]
22    pub instructions: String,
23    #[serde(skip)]
24    pub temp_instructions: Option<String>,
25    /// Flag indicating if this config is local
26    #[serde(skip)]
27    pub is_local: bool,
28}
29
30/// Provider-specific configuration structure
31#[derive(Deserialize, Serialize, Clone, Debug, Default)]
32pub struct ProviderConfig {
33    /// API key for the provider
34    pub api_key: String,
35    /// Model to be used with the provider
36    pub model_name: String,
37    /// Additional parameters for the provider
38    #[serde(default)]
39    pub additional_params: HashMap<String, String>,
40    /// Token limit, if set by the user
41    pub token_limit: Option<usize>,
42}
43
44impl Config {
45    /// Load the configuration from git config
46    pub fn load() -> Result<Self> {
47        let mut config = Self::load_from_config("gitai");
48
49        // Then try to load and merge project config if available
50        if let Ok(project_config) = Self::load_project_config() {
51            config.merge_with_project_config(project_config);
52        }
53
54        debug!("Configuration loaded: {config:?}");
55        Ok(config)
56    }
57
58    /// Load configuration from git config
59    fn load_from_config(prefix: &str) -> Self {
60        let default_provider = Self::get_git_config_value(&format!("{prefix}.defaultprovider"))
61            .unwrap_or("openai".to_string());
62        let instructions =
63            Self::get_git_config_value(&format!("{prefix}.instructions")).unwrap_or_default();
64
65        let mut providers = HashMap::new();
66        // To load providers, we need to iterate over all keys with prefix
67        // But git2 Config doesn't have easy way to iterate, so for now, assume known providers
68        for provider in get_available_provider_names() {
69            if let Some(api_key) =
70                Self::get_git_config_value(&format!("{prefix}.{provider}-apikey"))
71            {
72                let default_model = get_default_model_for_provider(&provider).to_string();
73                let model = Self::get_git_config_value(&format!("{prefix}.{provider}-model"))
74                    .unwrap_or(default_model);
75                let token_limit =
76                    Self::get_git_config_i64(&format!("{prefix}.{provider}-tokenlimit")).map(|v| {
77                        usize::try_from(v).expect("Failed to convert token limit from i64 to usize")
78                    });
79                let additional_params = HashMap::new();
80                // For additional params, it's hard to iterate, so skip for now
81                providers.insert(
82                    provider.to_string(),
83                    ProviderConfig {
84                        api_key,
85                        model_name: model,
86                        additional_params,
87                        token_limit,
88                    },
89                );
90            }
91        }
92
93        Self {
94            default_provider,
95            providers,
96            instructions,
97            temp_instructions: None,
98            is_local: false,
99        }
100    }
101
102    fn get_git_config_value(key: &str) -> Option<String> {
103        let output = Command::new("git")
104            .args(["config", "--get", key])
105            .output()
106            .ok()?;
107        if output.status.success() {
108            Some(String::from_utf8_lossy(&output.stdout).trim().to_string())
109        } else {
110            None
111        }
112    }
113
114    #[allow(unused)]
115    fn get_git_config_bool(key: &str) -> Option<bool> {
116        Self::get_git_config_value(key).and_then(|v| v.parse().ok())
117    }
118
119    fn get_git_config_i64(key: &str) -> Option<i64> {
120        Self::get_git_config_value(key).and_then(|v| v.parse().ok())
121    }
122
123    /// Load project-specific configuration
124    pub fn load_project_config() -> Result<Self, anyhow::Error> {
125        let mut project_config = Self::load_from_config("gitai");
126        project_config.is_local = true;
127        Ok(project_config)
128    }
129
130    /// Merge this config with project-specific config, with project config taking precedence
131    /// But never allow API keys from project config
132    pub fn merge_with_project_config(&mut self, project_config: Self) {
133        debug!("Merging with project configuration");
134
135        // Override default provider if set in project config
136        if project_config.default_provider != Self::default().default_provider {
137            self.default_provider = project_config.default_provider;
138        }
139
140        // Merge provider configs, but never allow API keys from project config
141        for (provider, proj_provider_config) in project_config.providers {
142            let entry = self.providers.entry(provider).or_default();
143
144            // Don't override API keys from project config (security)
145            if !proj_provider_config.model_name.is_empty() {
146                entry.model_name = proj_provider_config.model_name;
147            }
148
149            // Merge additional params
150            entry
151                .additional_params
152                .extend(proj_provider_config.additional_params);
153
154            // Override token limit if set in project config
155            if proj_provider_config.token_limit.is_some() {
156                entry.token_limit = proj_provider_config.token_limit;
157            }
158        }
159
160        // Always override instructions field if set in project config
161        self.instructions = project_config.instructions.clone();
162    }
163
164    /// Save the configuration to git config
165    pub fn save(&self) -> Result<()> {
166        // Don't save project configs to personal config file
167        if self.is_local {
168            return Ok(());
169        }
170
171        let mut config = GitConfig::open_default()?;
172        self.save_to_config(&mut config, "gitai")?;
173        debug!("Configuration saved to global git config: {self:?}");
174        Ok(())
175    }
176
177    /// Save the configuration to a git config
178    fn save_to_config(&self, config: &mut GitConfig, prefix: &str) -> Result<()> {
179        // Set default provider
180        config.set_str(&format!("{prefix}.defaultprovider"), &self.default_provider)?;
181
182        // Set instructions
183        config.set_str(&format!("{prefix}.instructions"), &self.instructions)?;
184
185        for (provider, provider_config) in &self.providers {
186            // Set api key only if not empty
187            if !provider_config.api_key.is_empty() {
188                config.set_str(
189                    &format!("{prefix}.{provider}-apikey"),
190                    &provider_config.api_key,
191                )?;
192            }
193
194            // Set model
195            config.set_str(
196                &format!("{prefix}.{provider}-model"),
197                &provider_config.model_name,
198            )?;
199
200            if let Some(token_limit) = provider_config.token_limit {
201                config.set_i64(
202                    &format!("{prefix}.{provider}-tokenlimit"),
203                    i64::try_from(token_limit).context("Token limit exceeds i64 range")?,
204                )?;
205            }
206
207            for (key, value) in &provider_config.additional_params {
208                config.set_str(&format!("{prefix}.{provider}-additional{key}"), value)?;
209            }
210        }
211
212        Ok(())
213    }
214
215    /// Save the configuration as a project-specific configuration
216    pub fn save_as_project_config(&self) -> Result<(), anyhow::Error> {
217        let repo = git2::Repository::discover(".")?;
218
219        // Before saving, create a copy that excludes API keys
220        let mut project_config = self.clone();
221
222        // Remove API keys from all providers
223        for provider_config in project_config.providers.values_mut() {
224            provider_config.api_key.clear();
225        }
226
227        // Mark as project config
228        project_config.is_local = true;
229
230        // Save to local git config
231        let mut config = repo.config()?;
232        project_config.save_to_config(&mut config, "gitai")?;
233        debug!("Project configuration saved to local git config: {project_config:?}");
234        Ok(())
235    }
236
237    /// Check the environment for necessary prerequisites
238    pub fn check_environment(&self) -> Result<()> {
239        // Check if we're in a git repository
240        if !GitRepo::is_inside_work_tree()? {
241            return Err(anyhow!(
242                "Not in a Git repository. Please run this command from within a Git repository."
243            ));
244        }
245
246        Ok(())
247    }
248
249    pub fn set_temp_instructions(&mut self, instructions: Option<String>) {
250        self.temp_instructions = instructions;
251    }
252
253    pub fn get_effective_instructions(&self) -> String {
254        let custom_instructions = self
255            .temp_instructions
256            .as_ref()
257            .unwrap_or(&self.instructions);
258
259        custom_instructions.trim().to_string()
260    }
261
262    /// Update the configuration with new values
263    #[allow(clippy::too_many_arguments)]
264    pub fn update(
265        &mut self,
266        provider: Option<String>,
267        api_key: Option<String>,
268        model: Option<String>,
269        additional_params: Option<HashMap<String, String>>,
270        instructions: Option<String>,
271        token_limit: Option<usize>,
272    ) -> anyhow::Result<()> {
273        if let Some(provider) = provider {
274            self.default_provider.clone_from(&provider);
275            if !self.providers.contains_key(&provider) {
276                // Only insert a new provider if it requires configuration
277                if provider_requires_api_key(&provider.to_lowercase()) {
278                    self.providers.insert(
279                        provider.clone(),
280                        ProviderConfig::default_for(&provider.to_lowercase()),
281                    );
282                }
283            }
284        }
285
286        let provider_config = self
287            .providers
288            .get_mut(&self.default_provider)
289            .context("Could not get default provider")?;
290
291        if let Some(key) = api_key {
292            provider_config.api_key = key;
293        }
294        if let Some(model) = model {
295            provider_config.model_name = model;
296        }
297        if let Some(params) = additional_params {
298            provider_config.additional_params.extend(params);
299        }
300
301        if let Some(instr) = instructions {
302            self.instructions = instr;
303        }
304        if let Some(limit) = token_limit {
305            provider_config.token_limit = Some(limit);
306        }
307
308        debug!("Configuration updated: {self:?}");
309        Ok(())
310    }
311
312    /// Get the configuration for a specific provider
313    pub fn get_provider_config(&self, provider: &str) -> Option<&ProviderConfig> {
314        // Special case: redirect "claude" to "anthropic"
315        let provider_to_lookup = if provider.to_lowercase() == "claude" {
316            "anthropic"
317        } else {
318            provider
319        };
320
321        // First try direct lookup
322        self.providers.get(provider_to_lookup).or_else(|| {
323            // If not found, try lowercased version
324            let lowercase_provider = provider_to_lookup.to_lowercase();
325
326            self.providers.get(&lowercase_provider).or_else(|| {
327                // If the provider is not in the config, check if it's a valid provider
328                if get_available_provider_names().contains(&lowercase_provider) {
329                    // Return None for valid providers not in the config
330                    // This allows the code to use default values for providers like Ollama
331                    None
332                } else {
333                    // Return None for invalid providers
334                    None
335                }
336            })
337        })
338    }
339
340    /// Set whether this config is a project config
341    pub fn set_project_config(&mut self, is_project: bool) {
342        self.is_local = is_project;
343    }
344
345    /// Check if this is a project config
346    pub fn is_project_config(&self) -> bool {
347        self.is_local
348    }
349}
350
351impl Default for Config {
352    fn default() -> Self {
353        let mut providers = HashMap::new();
354        for provider in get_available_provider_names() {
355            providers.insert(provider.clone(), ProviderConfig::default_for(&provider));
356        }
357
358        // Default to OpenAI if available, otherwise use the first available provider
359        let default_provider = if providers.contains_key("openai") {
360            "openai".to_string()
361        } else {
362            providers.keys().next().map_or_else(
363                || "openai".to_string(), // Fallback even if no providers (should never happen)
364                std::clone::Clone::clone,
365            )
366        };
367
368        Self {
369            default_provider,
370            providers,
371            instructions: String::new(),
372            temp_instructions: None,
373            is_local: false,
374        }
375    }
376}
377
378impl ProviderConfig {
379    /// Create a default provider configuration for a given provider
380    pub fn default_for(provider: &str) -> Self {
381        Self {
382            api_key: String::new(),
383            model_name: get_default_model_for_provider(provider).to_string(),
384            additional_params: HashMap::new(),
385            token_limit: None, // Will use the default from get_default_token_limit_for_provider
386        }
387    }
388
389    /// Get the token limit for this provider configuration
390    pub fn get_token_limit(&self) -> Option<usize> {
391        self.token_limit
392    }
393}