1use crate::git::GitRepo;
2use crate::instruction_presets::get_instruction_preset_library;
3use crate::llm::{
4 get_available_provider_names, get_default_model_for_provider, provider_requires_api_key,
5};
6use crate::log_debug;
7
8use anyhow::{Context, Result, anyhow};
9use dirs::config_dir;
10use serde::{Deserialize, Serialize};
11use std::collections::HashMap;
12use std::fs;
13use std::path::PathBuf;
14
/// Persistent git-iris settings, loaded from and saved to a TOML file in the user's
/// config directory (see `Config::load` / `Config::save`).
#[derive(Deserialize, Serialize, Clone, Debug)]
pub struct Config {
    /// Name of the provider used by default; looked up as a key in `providers`.
    pub default_provider: String,
    /// Per-provider settings, keyed by provider name.
    pub providers: HashMap<String, ProviderConfig>,
    /// Gitmoji toggle; `true` when the key is absent from the file.
    #[serde(default = "default_gitmoji")]
    pub use_gitmoji: bool,
    /// Custom instructions appended after the preset's instructions; empty when absent.
    #[serde(default)]
    pub instructions: String,
    /// Name of the instruction preset to apply; `"default"` when absent from the file.
    #[serde(default = "default_instruction_preset")]
    pub instruction_preset: String,
    /// Session-only override for `instructions`; never read from or written to the file.
    #[serde(skip)]
    pub temp_instructions: Option<String>,
    /// Session-only override for `instruction_preset`; never read from or written to the file.
    #[serde(skip)]
    pub temp_preset: Option<String>,
}
35
36#[derive(Deserialize, Serialize, Clone, Debug, Default)]
38pub struct ProviderConfig {
39 pub api_key: String,
41 pub model: String,
43 #[serde(default)]
45 pub additional_params: HashMap<String, String>,
46 pub token_limit: Option<usize>,
48}
49
/// Serde default for `Config::use_gitmoji`: Gitmoji is on unless the file says otherwise.
fn default_gitmoji() -> bool {
    const GITMOJI_ON_BY_DEFAULT: bool = true;
    GITMOJI_ON_BY_DEFAULT
}
54
/// Serde default for `Config::instruction_preset`: the preset named `"default"`.
fn default_instruction_preset() -> String {
    String::from("default")
}
59
60impl Config {
61 pub fn load() -> Result<Self> {
63 let config_path = Self::get_config_path()?;
64 if !config_path.exists() {
65 return Ok(Self::default());
66 }
67 let config_content = fs::read_to_string(config_path)?;
68 let mut config: Self = toml::from_str(&config_content)?;
69
70 let mut migration_performed = false;
72 if config.providers.contains_key("claude") {
73 log_debug!("Migrating 'claude' provider to 'anthropic'");
74 if let Some(claude_config) = config.providers.remove("claude") {
75 config
76 .providers
77 .insert("anthropic".to_string(), claude_config);
78 }
79
80 if config.default_provider == "claude" {
82 config.default_provider = "anthropic".to_string();
83 }
84
85 migration_performed = true;
86 }
87
88 if migration_performed {
90 log_debug!("Saving configuration after migration");
91 if let Err(e) = config.save() {
92 log_debug!("Failed to save migrated config: {}", e);
93 }
94 }
95
96 log_debug!("Configuration loaded: {:?}", config);
97 Ok(config)
98 }
99
100 pub fn save(&self) -> Result<()> {
102 let config_path = Self::get_config_path()?;
103 let config_content = toml::to_string(self)?;
104 fs::write(config_path, config_content)?;
105 log_debug!("Configuration saved: {:?}", self);
106 Ok(())
107 }
108
109 fn get_config_path() -> Result<PathBuf> {
111 let mut path =
112 config_dir().ok_or_else(|| anyhow!("Unable to determine config directory"))?;
113 path.push("git-iris");
114 std::fs::create_dir_all(&path)?;
115 path.push("config.toml");
116 Ok(path)
117 }
118
119 pub fn check_environment(&self) -> Result<()> {
121 if !GitRepo::is_inside_work_tree()? {
123 return Err(anyhow!(
124 "Not in a Git repository. Please run this command from within a Git repository."
125 ));
126 }
127
128 Ok(())
129 }
130
131 pub fn set_temp_instructions(&mut self, instructions: Option<String>) {
132 self.temp_instructions = instructions;
133 }
134
135 pub fn set_temp_preset(&mut self, preset: Option<String>) {
136 self.temp_preset = preset;
137 }
138
139 pub fn get_effective_instructions(&self) -> String {
140 let preset_library = get_instruction_preset_library();
141 let preset_instructions = self
142 .temp_preset
143 .as_ref()
144 .or(Some(&self.instruction_preset))
145 .and_then(|p| preset_library.get_preset(p))
146 .map(|p| p.instructions.clone())
147 .unwrap_or_default();
148
149 let custom_instructions = self
150 .temp_instructions
151 .as_ref()
152 .unwrap_or(&self.instructions);
153
154 format!("{preset_instructions}\n\n{custom_instructions}")
155 .trim()
156 .to_string()
157 }
158
159 #[allow(clippy::too_many_arguments)]
161 pub fn update(
162 &mut self,
163 provider: Option<String>,
164 api_key: Option<String>,
165 model: Option<String>,
166 additional_params: Option<HashMap<String, String>>,
167 use_gitmoji: Option<bool>,
168 instructions: Option<String>,
169 token_limit: Option<usize>,
170 ) -> anyhow::Result<()> {
171 if let Some(provider) = provider {
172 self.default_provider.clone_from(&provider);
173 if !self.providers.contains_key(&provider) {
174 if provider_requires_api_key(&provider.to_lowercase()) {
176 self.providers.insert(
177 provider.clone(),
178 ProviderConfig::default_for(&provider.to_lowercase()),
179 );
180 }
181 }
182 }
183
184 let provider_config = self
185 .providers
186 .get_mut(&self.default_provider)
187 .context("Could not get default provider")?;
188
189 if let Some(key) = api_key {
190 provider_config.api_key = key;
191 }
192 if let Some(model) = model {
193 provider_config.model = model;
194 }
195 if let Some(params) = additional_params {
196 provider_config.additional_params.extend(params);
197 }
198 if let Some(gitmoji) = use_gitmoji {
199 self.use_gitmoji = gitmoji;
200 }
201 if let Some(instr) = instructions {
202 self.instructions = instr;
203 }
204 if let Some(limit) = token_limit {
205 provider_config.token_limit = Some(limit);
206 }
207
208 log_debug!("Configuration updated: {:?}", self);
209 Ok(())
210 }
211
212 pub fn get_provider_config(&self, provider: &str) -> Option<&ProviderConfig> {
214 let provider_to_lookup = if provider.to_lowercase() == "claude" {
216 "anthropic"
217 } else {
218 provider
219 };
220
221 self.providers.get(provider_to_lookup).or_else(|| {
223 let lowercase_provider = provider_to_lookup.to_lowercase();
225
226 self.providers.get(&lowercase_provider).or_else(|| {
227 if get_available_provider_names().contains(&lowercase_provider) {
229 None
232 } else {
233 None
235 }
236 })
237 })
238 }
239}
240
241impl Default for Config {
242 fn default() -> Self {
243 let mut providers = HashMap::new();
244 for provider in get_available_provider_names() {
245 providers.insert(provider.clone(), ProviderConfig::default_for(&provider));
246 }
247
248 let default_provider = if providers.contains_key("openai") {
250 "openai".to_string()
251 } else {
252 providers.keys().next().map_or_else(
253 || "openai".to_string(), std::clone::Clone::clone,
255 )
256 };
257
258 Self {
259 default_provider,
260 providers,
261 use_gitmoji: default_gitmoji(),
262 instructions: String::new(),
263 instruction_preset: default_instruction_preset(),
264 temp_instructions: None,
265 temp_preset: None,
266 }
267 }
268}
269
270impl ProviderConfig {
271 pub fn default_for(provider: &str) -> Self {
273 Self {
274 api_key: String::new(),
275 model: get_default_model_for_provider(provider).to_string(),
276 additional_params: HashMap::new(),
277 token_limit: None, }
279 }
280
281 pub fn get_token_limit(&self) -> Option<usize> {
283 self.token_limit
284 }
285}