git_iris/
config.rs

1use crate::git::GitRepo;
2use crate::instruction_presets::get_instruction_preset_library;
3use crate::llm_providers::{
4    get_available_providers, get_provider_metadata, LLMProviderConfig, LLMProviderType,
5};
6use crate::log_debug;
7
8use anyhow::{anyhow, Context, Result};
9use dirs::config_dir;
10use serde::{Deserialize, Serialize};
11use std::collections::HashMap;
12use std::fs;
13use std::path::PathBuf;
14use std::str::FromStr;
15
16/// Configuration structure for the Git-Iris application
17#[derive(Deserialize, Serialize, Clone, Debug)]
18pub struct Config {
19    /// Default LLM provider
20    pub default_provider: String,
21    /// Provider-specific configurations
22    pub providers: HashMap<String, ProviderConfig>,
23    /// Flag indicating whether to use Gitmoji
24    #[serde(default = "default_gitmoji")]
25    pub use_gitmoji: bool,
26    /// Instructions for commit messages
27    #[serde(default)]
28    pub instructions: String,
29    #[serde(default = "default_instruction_preset")]
30    pub instruction_preset: String,
31    #[serde(skip)]
32    pub temp_instructions: Option<String>,
33    #[serde(skip)]
34    pub temp_preset: Option<String>,
35}
36
37/// Provider-specific configuration structure
38#[derive(Deserialize, Serialize, Clone, Debug, Default)]
39pub struct ProviderConfig {
40    /// API key for the provider
41    pub api_key: String,
42    /// Model to be used with the provider
43    pub model: String,
44    /// Additional parameters for the provider
45    #[serde(default)]
46    pub additional_params: HashMap<String, String>,
47    /// Token limit, if set by the user
48    pub token_limit: Option<usize>,
49}
50
/// Serde fallback for `use_gitmoji` when the key is absent from the config file.
///
/// Gitmoji is opt-out, so this always yields `true`.
const fn default_gitmoji() -> bool {
    true
}

/// Serde fallback for `instruction_preset`: the preset name used when the
/// config file does not choose one.
fn default_instruction_preset() -> String {
    String::from("default")
}

61impl Config {
62    /// Load the configuration from the file
63    pub fn load() -> Result<Self> {
64        let config_path = Self::get_config_path()?;
65        if !config_path.exists() {
66            return Ok(Self::default());
67        }
68        let config_content = fs::read_to_string(config_path)?;
69        let config: Self = toml::from_str(&config_content)?;
70        log_debug!("Configuration loaded: {:?}", config);
71        Ok(config)
72    }
73
74    /// Save the configuration to the file
75    pub fn save(&self) -> Result<()> {
76        let config_path = Self::get_config_path()?;
77        let config_content = toml::to_string(self)?;
78        fs::write(config_path, config_content)?;
79        log_debug!("Configuration saved: {:?}", self);
80        Ok(())
81    }
82
83    /// Get the path to the configuration file
84    fn get_config_path() -> Result<PathBuf> {
85        let mut path =
86            config_dir().ok_or_else(|| anyhow!("Unable to determine config directory"))?;
87        path.push("git-iris");
88        std::fs::create_dir_all(&path)?;
89        path.push("config.toml");
90        Ok(path)
91    }
92
93    /// Check the environment for necessary prerequisites
94    pub fn check_environment(&self) -> Result<()> {
95        // Check if we're in a git repository
96        if !GitRepo::is_inside_work_tree()? {
97            return Err(anyhow!(
98                "Not in a Git repository. Please run this command from within a Git repository."
99            ));
100        }
101
102        Ok(())
103    }
104
105    pub fn set_temp_instructions(&mut self, instructions: Option<String>) {
106        self.temp_instructions = instructions;
107    }
108
109    pub fn set_temp_preset(&mut self, preset: Option<String>) {
110        self.temp_preset = preset;
111    }
112
113    pub fn get_effective_instructions(&self) -> String {
114        let preset_library = get_instruction_preset_library();
115        let preset_instructions = self
116            .temp_preset
117            .as_ref()
118            .or(Some(&self.instruction_preset))
119            .and_then(|p| preset_library.get_preset(p))
120            .map(|p| p.instructions.clone())
121            .unwrap_or_default();
122
123        let custom_instructions = self
124            .temp_instructions
125            .as_ref()
126            .unwrap_or(&self.instructions);
127
128        format!("{preset_instructions}\n\n{custom_instructions}")
129            .trim()
130            .to_string()
131    }
132
133    /// Update the configuration with new values
134    #[allow(clippy::too_many_arguments)]
135    pub fn update(
136        &mut self,
137        provider: Option<String>,
138        api_key: Option<String>,
139        model: Option<String>,
140        additional_params: Option<HashMap<String, String>>,
141        use_gitmoji: Option<bool>,
142        instructions: Option<String>,
143        token_limit: Option<usize>,
144    ) -> anyhow::Result<()> {
145        if let Some(provider) = provider {
146            self.default_provider.clone_from(&provider);
147            if !self.providers.contains_key(&provider) {
148                // Only insert a new provider if it requires configuration
149                let provider_type =
150                    LLMProviderType::from_str(&provider).unwrap_or(LLMProviderType::OpenAI);
151                if get_provider_metadata(&provider_type).requires_api_key {
152                    self.providers
153                        .insert(provider.clone(), ProviderConfig::default_for(&provider));
154                }
155            }
156        }
157
158        let provider_config = self
159            .providers
160            .get_mut(&self.default_provider)
161            .context("Could not get default provider")?;
162
163        if let Some(key) = api_key {
164            provider_config.api_key = key;
165        }
166        if let Some(model) = model {
167            provider_config.model = model;
168        }
169        if let Some(params) = additional_params {
170            provider_config.additional_params.extend(params);
171        }
172        if let Some(gitmoji) = use_gitmoji {
173            self.use_gitmoji = gitmoji;
174        }
175        if let Some(instr) = instructions {
176            self.instructions = instr;
177        }
178        if let Some(limit) = token_limit {
179            provider_config.token_limit = Some(limit);
180        }
181
182        log_debug!("Configuration updated: {:?}", self);
183        Ok(())
184    }
185
186    /// Get the configuration for a specific provider
187    pub fn get_provider_config(&self, provider: &str) -> Option<&ProviderConfig> {
188        self.providers.get(provider).or_else(|| {
189            // If the provider is not in the config, check if it's a valid provider
190            if LLMProviderType::from_str(provider).is_ok() {
191                // Return None for valid providers not in the config
192                // This allows the code to use default values for providers like Ollama
193                None
194            } else {
195                // Return None for invalid providers
196                None
197            }
198        })
199    }
200}
201
202impl Default for Config {
203    #[allow(clippy::unwrap_used)] // todo: handle unwrap
204    fn default() -> Self {
205        let mut providers = HashMap::new();
206        for provider in get_available_providers() {
207            providers.insert(
208                provider.to_string(),
209                ProviderConfig::default_for(provider.as_ref()),
210            );
211        }
212
213        Self {
214            default_provider: get_available_providers().first().unwrap().to_string(),
215            providers,
216            use_gitmoji: true,
217            instructions: String::new(),
218            instruction_preset: default_instruction_preset(),
219            temp_instructions: None,
220            temp_preset: None,
221        }
222    }
223}
224
225impl ProviderConfig {
226    /// Create a default provider configuration for a given provider
227    pub fn default_for(provider: &str) -> Self {
228        let provider_type =
229            LLMProviderType::from_str(provider).unwrap_or_else(|_| get_available_providers()[0]);
230        let metadata = get_provider_metadata(&provider_type);
231        Self {
232            api_key: String::new(),
233            model: metadata.default_model.to_string(),
234            additional_params: HashMap::new(),
235            token_limit: Some(metadata.default_token_limit),
236        }
237    }
238
239    /// Get the token limit for this provider configuration
240    pub fn get_token_limit(&self) -> usize {
241        self.token_limit.unwrap_or_else(|| {
242            let provider_type = LLMProviderType::from_str(&self.model)
243                .unwrap_or_else(|_| get_available_providers()[0]);
244            get_provider_metadata(&provider_type).default_token_limit
245        })
246    }
247
248    /// Convert to `LLMProviderConfig`
249    pub fn to_llm_provider_config(&self) -> LLMProviderConfig {
250        let mut additional_params = self.additional_params.clone();
251
252        // Add token limit to additional params if set
253        if let Some(limit) = self.token_limit {
254            additional_params.insert("token_limit".to_string(), limit.to_string());
255        }
256
257        LLMProviderConfig {
258            api_key: self.api_key.clone(),
259            model: self.model.clone(),
260            additional_params,
261        }
262    }
263}