1use crate::core::llm::{
2 get_available_provider_names, get_default_model_for_provider, provider_requires_api_key,
3};
4use crate::debug;
5use crate::git::GitRepo;
6use crate::instruction_presets::get_instruction_preset_library;
7use anyhow::{Context, Result, anyhow};
12use git2::Config as GitConfig;
13use serde::{Deserialize, Serialize};
14use std::collections::HashMap;
15use std::process::Command;
16
17#[derive(Deserialize, Serialize, Clone, Debug)]
19pub struct Config {
20 pub default_provider: String,
22 pub providers: HashMap<String, ProviderConfig>,
24 #[serde(default = "default_emoji")]
26 pub use_emoji: bool,
27 #[serde(default)]
29 pub instructions: String,
30 #[serde(default = "default_instruction_preset")]
31 pub instruction_preset: String,
32 #[serde(skip)]
33 pub temp_instructions: Option<String>,
34 #[serde(skip)]
35 pub temp_preset: Option<String>,
36 #[serde(skip)]
38 pub is_project_config: bool,
39}
40
41#[derive(Deserialize, Serialize, Clone, Debug, Default)]
43pub struct ProviderConfig {
44 pub api_key: String,
46 pub model: String,
48 #[serde(default)]
50 pub additional_params: HashMap<String, String>,
51 pub token_limit: Option<usize>,
53}
54
/// Serde default for `Config::use_emoji`: emoji are off.
fn default_emoji() -> bool {
    const EMOJI_ENABLED_BY_DEFAULT: bool = false;
    EMOJI_ENABLED_BY_DEFAULT
}
59
/// Serde default for `Config::instruction_preset`: the preset named `"default"`.
fn default_instruction_preset() -> String {
    String::from("default")
}
64
65impl Config {
66 pub fn load() -> Result<Self> {
68 let mut config = Self::load_from_config("gitai");
69
70 if let Ok(project_config) = Self::load_project_config() {
72 config.merge_with_project_config(project_config);
73 }
74
75 debug!("Configuration loaded: {config:?}");
76 Ok(config)
77 }
78
79 fn load_from_config(prefix: &str) -> Self {
81 let default_provider = Self::get_git_config_value(&format!("{prefix}.defaultprovider"))
82 .unwrap_or("openai".to_string());
83 let use_emoji = Self::get_git_config_bool(&format!("{prefix}.useemoji")).unwrap_or(true);
84 let instructions =
85 Self::get_git_config_value(&format!("{prefix}.instructions")).unwrap_or_default();
86 let instruction_preset = Self::get_git_config_value(&format!("{prefix}.instructionpreset"))
87 .unwrap_or("default".to_string());
88
89 let mut providers = HashMap::new();
90 for provider in get_available_provider_names() {
93 if let Some(api_key) =
94 Self::get_git_config_value(&format!("{prefix}.{provider}-apikey"))
95 {
96 let default_model = get_default_model_for_provider(&provider).to_string();
97 let model = Self::get_git_config_value(&format!("{prefix}.{provider}-model"))
98 .unwrap_or(default_model);
99 let token_limit =
100 Self::get_git_config_i64(&format!("{prefix}.{provider}-tokenlimit")).map(|v| {
101 usize::try_from(v).expect("Failed to convert token limit from i64 to usize")
102 });
103 let additional_params = HashMap::new();
104 providers.insert(
106 provider.to_string(),
107 ProviderConfig {
108 api_key,
109 model,
110 additional_params,
111 token_limit,
112 },
113 );
114 }
115 }
116
117 Self {
118 default_provider,
119 providers,
120 use_emoji,
121 instructions,
122 instruction_preset,
123 temp_instructions: None,
124 temp_preset: None,
125 is_project_config: false,
126 }
127 }
128
129 fn get_git_config_value(key: &str) -> Option<String> {
130 let output = Command::new("git")
131 .args(["config", "--get", key])
132 .output()
133 .ok()?;
134 if output.status.success() {
135 Some(String::from_utf8_lossy(&output.stdout).trim().to_string())
136 } else {
137 None
138 }
139 }
140
141 fn get_git_config_bool(key: &str) -> Option<bool> {
142 Self::get_git_config_value(key).and_then(|v| v.parse().ok())
143 }
144
145 fn get_git_config_i64(key: &str) -> Option<i64> {
146 Self::get_git_config_value(key).and_then(|v| v.parse().ok())
147 }
148
149 pub fn load_project_config() -> Result<Self, anyhow::Error> {
151 let mut project_config = Self::load_from_config("gitai");
152 project_config.is_project_config = true;
153 Ok(project_config)
154 }
155
156 pub fn merge_with_project_config(&mut self, project_config: Self) {
159 debug!("Merging with project configuration");
160
161 if project_config.default_provider != Self::default().default_provider {
163 self.default_provider = project_config.default_provider;
164 }
165
166 for (provider, proj_provider_config) in project_config.providers {
168 let entry = self.providers.entry(provider).or_default();
169
170 if !proj_provider_config.model.is_empty() {
172 entry.model = proj_provider_config.model;
173 }
174
175 entry
177 .additional_params
178 .extend(proj_provider_config.additional_params);
179
180 if proj_provider_config.token_limit.is_some() {
182 entry.token_limit = proj_provider_config.token_limit;
183 }
184 }
185
186 self.use_emoji = project_config.use_emoji;
188
189 self.instructions = project_config.instructions.clone();
191
192 if project_config.instruction_preset != default_instruction_preset() {
194 self.instruction_preset = project_config.instruction_preset;
195 }
196 }
197
198 pub fn save(&self) -> Result<()> {
200 if self.is_project_config {
202 return Ok(());
203 }
204
205 let mut config = GitConfig::open_default()?;
206 self.save_to_config(&mut config, "gitai")?;
207 debug!("Configuration saved to global git config: {self:?}");
208 Ok(())
209 }
210
211 fn save_to_config(&self, config: &mut GitConfig, prefix: &str) -> Result<()> {
213 config.set_str(&format!("{prefix}.defaultprovider"), &self.default_provider)?;
215
216 config.set_bool(&format!("{prefix}.useemoji"), self.use_emoji)?;
218
219 config.set_str(&format!("{prefix}.instructions"), &self.instructions)?;
221
222 config.set_str(
224 &format!("{prefix}.instructionpreset"),
225 &self.instruction_preset,
226 )?;
227
228 for (provider, provider_config) in &self.providers {
229 if !provider_config.api_key.is_empty() {
231 config.set_str(
232 &format!("{prefix}.{provider}-apikey"),
233 &provider_config.api_key,
234 )?;
235 }
236
237 config.set_str(
239 &format!("{prefix}.{provider}-model"),
240 &provider_config.model,
241 )?;
242
243 if let Some(token_limit) = provider_config.token_limit {
244 config.set_i64(
245 &format!("{prefix}.{provider}-tokenlimit"),
246 i64::try_from(token_limit).context("Token limit exceeds i64 range")?,
247 )?;
248 }
249
250 for (key, value) in &provider_config.additional_params {
251 config.set_str(&format!("{prefix}.{provider}-additional{key}"), value)?;
252 }
253 }
254
255 Ok(())
256 }
257
258 pub fn save_as_project_config(&self) -> Result<(), anyhow::Error> {
260 let repo = git2::Repository::discover(".")?;
261
262 let mut project_config = self.clone();
264
265 for provider_config in project_config.providers.values_mut() {
267 provider_config.api_key.clear();
268 }
269
270 project_config.is_project_config = true;
272
273 let mut config = repo.config()?;
275 project_config.save_to_config(&mut config, "gitai")?;
276 debug!("Project configuration saved to local git config: {project_config:?}");
277 Ok(())
278 }
279
280 pub fn check_environment(&self) -> Result<()> {
282 if !GitRepo::is_inside_work_tree()? {
284 return Err(anyhow!(
285 "Not in a Git repository. Please run this command from within a Git repository."
286 ));
287 }
288
289 Ok(())
290 }
291
292 pub fn set_temp_instructions(&mut self, instructions: Option<String>) {
293 self.temp_instructions = instructions;
294 }
295
296 pub fn set_temp_preset(&mut self, preset: Option<String>) {
297 self.temp_preset = preset;
298 }
299
300 pub fn get_effective_instructions(&self) -> String {
301 let preset_library = get_instruction_preset_library();
302 let preset_instructions = self
303 .temp_preset
304 .as_ref()
305 .or(Some(&self.instruction_preset))
306 .and_then(|p| preset_library.get_preset(p))
307 .map(|p| p.instructions.clone())
308 .unwrap_or_default();
309
310 let custom_instructions = self
311 .temp_instructions
312 .as_ref()
313 .unwrap_or(&self.instructions);
314
315 format!("{preset_instructions}\n\n{custom_instructions}")
316 .trim()
317 .to_string()
318 }
319
320 #[allow(clippy::too_many_arguments)]
322 pub fn update(
323 &mut self,
324 provider: Option<String>,
325 api_key: Option<String>,
326 model: Option<String>,
327 additional_params: Option<HashMap<String, String>>,
328 use_emoji: Option<bool>,
329 instructions: Option<String>,
330 token_limit: Option<usize>,
331 ) -> anyhow::Result<()> {
332 if let Some(provider) = provider {
333 self.default_provider.clone_from(&provider);
334 if !self.providers.contains_key(&provider) {
335 if provider_requires_api_key(&provider.to_lowercase()) {
337 self.providers.insert(
338 provider.clone(),
339 ProviderConfig::default_for(&provider.to_lowercase()),
340 );
341 }
342 }
343 }
344
345 let provider_config = self
346 .providers
347 .get_mut(&self.default_provider)
348 .context("Could not get default provider")?;
349
350 if let Some(key) = api_key {
351 provider_config.api_key = key;
352 }
353 if let Some(model) = model {
354 provider_config.model = model;
355 }
356 if let Some(params) = additional_params {
357 provider_config.additional_params.extend(params);
358 }
359 if let Some(emoji) = use_emoji {
360 self.use_emoji = emoji;
361 }
362 if let Some(instr) = instructions {
363 self.instructions = instr;
364 }
365 if let Some(limit) = token_limit {
366 provider_config.token_limit = Some(limit);
367 }
368
369 debug!("Configuration updated: {self:?}");
370 Ok(())
371 }
372
373 pub fn get_provider_config(&self, provider: &str) -> Option<&ProviderConfig> {
375 let provider_to_lookup = if provider.to_lowercase() == "claude" {
377 "anthropic"
378 } else {
379 provider
380 };
381
382 self.providers.get(provider_to_lookup).or_else(|| {
384 let lowercase_provider = provider_to_lookup.to_lowercase();
386
387 self.providers.get(&lowercase_provider).or_else(|| {
388 if get_available_provider_names().contains(&lowercase_provider) {
390 None
393 } else {
394 None
396 }
397 })
398 })
399 }
400
401 pub fn set_project_config(&mut self, is_project: bool) {
403 self.is_project_config = is_project;
404 }
405
406 pub fn is_project_config(&self) -> bool {
408 self.is_project_config
409 }
410}
411
412impl Default for Config {
413 fn default() -> Self {
414 let mut providers = HashMap::new();
415 for provider in get_available_provider_names() {
416 providers.insert(provider.clone(), ProviderConfig::default_for(&provider));
417 }
418
419 let default_provider = if providers.contains_key("openai") {
421 "openai".to_string()
422 } else {
423 providers.keys().next().map_or_else(
424 || "openai".to_string(), std::clone::Clone::clone,
426 )
427 };
428
429 Self {
430 default_provider,
431 providers,
432 use_emoji: default_emoji(),
433 instructions: String::new(),
434 instruction_preset: default_instruction_preset(),
435 temp_instructions: None,
436 temp_preset: None,
437 is_project_config: false,
438 }
439 }
440}
441
442impl ProviderConfig {
443 pub fn default_for(provider: &str) -> Self {
445 Self {
446 api_key: String::new(),
447 model: get_default_model_for_provider(provider).to_string(),
448 additional_params: HashMap::new(),
449 token_limit: None, }
451 }
452
453 pub fn get_token_limit(&self) -> Option<usize> {
455 self.token_limit
456 }
457}