1use crate::git::GitRepo;
2use crate::instruction_presets::get_instruction_preset_library;
3use crate::llm::{
4 get_available_provider_names, get_default_model_for_provider, provider_requires_api_key,
5};
6use crate::log_debug;
7
8use anyhow::{Context, Result, anyhow};
9use dirs::config_dir;
10use serde::{Deserialize, Serialize};
11use std::collections::HashMap;
12use std::fs;
13use std::path::PathBuf;
14
15#[derive(Deserialize, Serialize, Clone, Debug)]
17pub struct Config {
18 pub default_provider: String,
20 pub providers: HashMap<String, ProviderConfig>,
22 #[serde(default = "default_gitmoji")]
24 pub use_gitmoji: bool,
25 #[serde(default)]
27 pub instructions: String,
28 #[serde(default = "default_instruction_preset")]
29 pub instruction_preset: String,
30 #[serde(skip)]
31 pub temp_instructions: Option<String>,
32 #[serde(skip)]
33 pub temp_preset: Option<String>,
34 #[serde(skip)]
36 pub is_project_config: bool,
37}
38
39#[derive(Deserialize, Serialize, Clone, Debug, Default)]
41pub struct ProviderConfig {
42 pub api_key: String,
44 pub model: String,
46 #[serde(default)]
48 pub additional_params: HashMap<String, String>,
49 pub token_limit: Option<usize>,
51}
52
/// Serde default for `Config::use_gitmoji`: gitmoji are enabled unless the file
/// turns them off.
fn default_gitmoji() -> bool {
    true
}
57
/// Serde default for `Config::instruction_preset`: the preset named `"default"`.
fn default_instruction_preset() -> String {
    String::from("default")
}

/// File name of the project-level configuration, resolved against the Git repository root.
pub const PROJECT_CONFIG_FILENAME: &str = ".irisconfig";
65
66impl Config {
67 pub fn load() -> Result<Self> {
69 let config_path = Self::get_config_path()?;
71 let mut config = if config_path.exists() {
72 let config_content = fs::read_to_string(&config_path)?;
73 let config: Self = toml::from_str(&config_content)?;
74 Self::migrate_if_needed(config)
75 } else {
76 Self::default()
77 };
78
79 if let Ok(project_config) = Self::load_project_config() {
81 config.merge_with_project_config(project_config);
82 }
83
84 log_debug!("Configuration loaded: {:?}", config);
85 Ok(config)
86 }
87
88 pub fn load_project_config() -> Result<Self, anyhow::Error> {
90 let config_path = Self::get_project_config_path()?;
91 if !config_path.exists() {
92 return Err(anyhow::anyhow!("Project configuration file not found"));
93 }
94
95 let config_str = match fs::read_to_string(&config_path) {
97 Ok(content) => content,
98 Err(e) => return Err(anyhow::anyhow!("Failed to read project config file: {}", e)),
99 };
100
101 let mut config: Self = match toml::from_str(&config_str) {
103 Ok(config) => config,
104 Err(e) => {
105 return Err(anyhow::anyhow!(
106 "Invalid project configuration file format: {}. Please check your {} file for syntax errors.",
107 e,
108 PROJECT_CONFIG_FILENAME
109 ));
110 }
111 };
112
113 config.is_project_config = true;
114 Ok(config)
115 }
116
117 pub fn get_project_config_path() -> Result<PathBuf, anyhow::Error> {
119 let repo_root = crate::git::GitRepo::get_repo_root()?;
121 Ok(repo_root.join(PROJECT_CONFIG_FILENAME))
122 }
123
124 pub fn merge_with_project_config(&mut self, project_config: Self) {
127 log_debug!("Merging with project configuration");
128
129 if project_config.default_provider != Self::default().default_provider {
131 self.default_provider = project_config.default_provider;
132 }
133
134 for (provider, proj_provider_config) in project_config.providers {
136 let entry = self.providers.entry(provider).or_default();
137
138 if !proj_provider_config.model.is_empty() {
140 entry.model = proj_provider_config.model;
141 }
142
143 entry
145 .additional_params
146 .extend(proj_provider_config.additional_params);
147
148 if proj_provider_config.token_limit.is_some() {
150 entry.token_limit = proj_provider_config.token_limit;
151 }
152 }
153
154 self.use_gitmoji = project_config.use_gitmoji;
156
157 self.instructions = project_config.instructions.clone();
159
160 if project_config.instruction_preset != default_instruction_preset() {
162 self.instruction_preset = project_config.instruction_preset;
163 }
164 }
165
166 fn migrate_if_needed(mut config: Self) -> Self {
168 let mut migration_performed = false;
170 if config.providers.contains_key("claude") {
171 log_debug!("Migrating 'claude' provider to 'anthropic'");
172 if let Some(claude_config) = config.providers.remove("claude") {
173 config
174 .providers
175 .insert("anthropic".to_string(), claude_config);
176 }
177
178 if config.default_provider == "claude" {
180 config.default_provider = "anthropic".to_string();
181 }
182
183 migration_performed = true;
184 }
185
186 if migration_performed {
188 log_debug!("Saving configuration after migration");
189 if let Err(e) = config.save() {
190 log_debug!("Failed to save migrated config: {}", e);
191 }
192 }
193
194 config
195 }
196
197 pub fn save(&self) -> Result<()> {
199 if self.is_project_config {
201 return Ok(());
202 }
203
204 let config_path = Self::get_config_path()?;
205 let config_content = toml::to_string(self)?;
206 fs::write(config_path, config_content)?;
207 log_debug!("Configuration saved: {:?}", self);
208 Ok(())
209 }
210
211 pub fn save_as_project_config(&self) -> Result<(), anyhow::Error> {
213 let config_path = Self::get_project_config_path()?;
214
215 let mut project_config = self.clone();
217
218 for provider_config in project_config.providers.values_mut() {
220 provider_config.api_key.clear();
221 }
222
223 project_config.is_project_config = true;
225
226 let config_str = toml::to_string_pretty(&project_config)?;
228
229 fs::write(config_path, config_str)?;
231
232 Ok(())
233 }
234
235 fn get_config_path() -> Result<PathBuf> {
237 let mut path =
238 config_dir().ok_or_else(|| anyhow!("Unable to determine config directory"))?;
239 path.push("git-iris");
240 std::fs::create_dir_all(&path)?;
241 path.push("config.toml");
242 Ok(path)
243 }
244
245 pub fn check_environment(&self) -> Result<()> {
247 if !GitRepo::is_inside_work_tree()? {
249 return Err(anyhow!(
250 "Not in a Git repository. Please run this command from within a Git repository."
251 ));
252 }
253
254 Ok(())
255 }
256
257 pub fn set_temp_instructions(&mut self, instructions: Option<String>) {
258 self.temp_instructions = instructions;
259 }
260
261 pub fn set_temp_preset(&mut self, preset: Option<String>) {
262 self.temp_preset = preset;
263 }
264
265 pub fn get_effective_instructions(&self) -> String {
266 let preset_library = get_instruction_preset_library();
267 let preset_instructions = self
268 .temp_preset
269 .as_ref()
270 .or(Some(&self.instruction_preset))
271 .and_then(|p| preset_library.get_preset(p))
272 .map(|p| p.instructions.clone())
273 .unwrap_or_default();
274
275 let custom_instructions = self
276 .temp_instructions
277 .as_ref()
278 .unwrap_or(&self.instructions);
279
280 format!("{preset_instructions}\n\n{custom_instructions}")
281 .trim()
282 .to_string()
283 }
284
285 #[allow(clippy::too_many_arguments)]
287 pub fn update(
288 &mut self,
289 provider: Option<String>,
290 api_key: Option<String>,
291 model: Option<String>,
292 additional_params: Option<HashMap<String, String>>,
293 use_gitmoji: Option<bool>,
294 instructions: Option<String>,
295 token_limit: Option<usize>,
296 ) -> anyhow::Result<()> {
297 if let Some(provider) = provider {
298 self.default_provider.clone_from(&provider);
299 if !self.providers.contains_key(&provider) {
300 if provider_requires_api_key(&provider.to_lowercase()) {
302 self.providers.insert(
303 provider.clone(),
304 ProviderConfig::default_for(&provider.to_lowercase()),
305 );
306 }
307 }
308 }
309
310 let provider_config = self
311 .providers
312 .get_mut(&self.default_provider)
313 .context("Could not get default provider")?;
314
315 if let Some(key) = api_key {
316 provider_config.api_key = key;
317 }
318 if let Some(model) = model {
319 provider_config.model = model;
320 }
321 if let Some(params) = additional_params {
322 provider_config.additional_params.extend(params);
323 }
324 if let Some(gitmoji) = use_gitmoji {
325 self.use_gitmoji = gitmoji;
326 }
327 if let Some(instr) = instructions {
328 self.instructions = instr;
329 }
330 if let Some(limit) = token_limit {
331 provider_config.token_limit = Some(limit);
332 }
333
334 log_debug!("Configuration updated: {:?}", self);
335 Ok(())
336 }
337
338 pub fn get_provider_config(&self, provider: &str) -> Option<&ProviderConfig> {
340 let provider_to_lookup = if provider.to_lowercase() == "claude" {
342 "anthropic"
343 } else {
344 provider
345 };
346
347 self.providers.get(provider_to_lookup).or_else(|| {
349 let lowercase_provider = provider_to_lookup.to_lowercase();
351
352 self.providers.get(&lowercase_provider).or_else(|| {
353 if get_available_provider_names().contains(&lowercase_provider) {
355 None
358 } else {
359 None
361 }
362 })
363 })
364 }
365
366 pub fn set_project_config(&mut self, is_project: bool) {
368 self.is_project_config = is_project;
369 }
370
371 pub fn is_project_config(&self) -> bool {
373 self.is_project_config
374 }
375}
376
377impl Default for Config {
378 fn default() -> Self {
379 let mut providers = HashMap::new();
380 for provider in get_available_provider_names() {
381 providers.insert(provider.clone(), ProviderConfig::default_for(&provider));
382 }
383
384 let default_provider = if providers.contains_key("openai") {
386 "openai".to_string()
387 } else {
388 providers.keys().next().map_or_else(
389 || "openai".to_string(), std::clone::Clone::clone,
391 )
392 };
393
394 Self {
395 default_provider,
396 providers,
397 use_gitmoji: default_gitmoji(),
398 instructions: String::new(),
399 instruction_preset: default_instruction_preset(),
400 temp_instructions: None,
401 temp_preset: None,
402 is_project_config: false,
403 }
404 }
405}
406
407impl ProviderConfig {
408 pub fn default_for(provider: &str) -> Self {
410 Self {
411 api_key: String::new(),
412 model: get_default_model_for_provider(provider).to_string(),
413 additional_params: HashMap::new(),
414 token_limit: None, }
416 }
417
418 pub fn get_token_limit(&self) -> Option<usize> {
420 self.token_limit
421 }
422}