1use crate::core::llm::{
2 get_available_provider_names, get_default_model_for_provider, provider_requires_api_key,
3};
4use crate::debug;
5use crate::git::GitRepo;
6
7use anyhow::{Context, Result, anyhow};
8use git2::Config as GitConfig;
9use serde::{Deserialize, Serialize};
10use std::collections::HashMap;
11use std::process::Command;
12
/// Top-level gitai configuration.
///
/// Values are read from and written to `git config` under the `gitai.*`
/// prefix by the `impl Config` methods; `serde` derives are also provided.
#[derive(Deserialize, Serialize, Clone, Debug)]
pub struct Config {
    /// Name of the provider used when no provider is given explicitly
    /// (it is the key looked up in `providers` by `Config::update`).
    pub default_provider: String,
    /// Per-provider settings, keyed by provider name.
    pub providers: HashMap<String, ProviderConfig>,
    /// Persistent custom instructions; an empty string when absent from
    /// serialized input.
    #[serde(default)]
    pub instructions: String,
    /// One-off instructions that take precedence over `instructions` in
    /// `Config::get_effective_instructions`; never serialized.
    #[serde(skip)]
    pub temp_instructions: Option<String>,
    /// `true` when this value represents project (repository-local)
    /// configuration; `Config::save` does nothing in that case. Never serialized.
    #[serde(skip)]
    pub is_local: bool,
}
29
30#[derive(Deserialize, Serialize, Clone, Debug, Default)]
32pub struct ProviderConfig {
33 pub api_key: String,
35 pub model_name: String,
37 #[serde(default)]
39 pub additional_params: HashMap<String, String>,
40 pub token_limit: Option<usize>,
42}
43
44impl Config {
45 pub fn load() -> Result<Self> {
47 let mut config = Self::load_from_config("gitai");
48
49 if let Ok(project_config) = Self::load_project_config() {
51 config.merge_with_project_config(project_config);
52 }
53
54 debug!("Configuration loaded: {config:?}");
55 Ok(config)
56 }
57
58 fn load_from_config(prefix: &str) -> Self {
60 let default_provider = Self::get_git_config_value(&format!("{prefix}.defaultprovider"))
61 .unwrap_or("openai".to_string());
62 let instructions =
63 Self::get_git_config_value(&format!("{prefix}.instructions")).unwrap_or_default();
64
65 let mut providers = HashMap::new();
66 for provider in get_available_provider_names() {
69 if let Some(api_key) =
70 Self::get_git_config_value(&format!("{prefix}.{provider}-apikey"))
71 {
72 let default_model = get_default_model_for_provider(&provider).to_string();
73 let model = Self::get_git_config_value(&format!("{prefix}.{provider}-model"))
74 .unwrap_or(default_model);
75 let token_limit =
76 Self::get_git_config_i64(&format!("{prefix}.{provider}-tokenlimit")).map(|v| {
77 usize::try_from(v).expect("Failed to convert token limit from i64 to usize")
78 });
79 let additional_params = HashMap::new();
80 providers.insert(
82 provider.to_string(),
83 ProviderConfig {
84 api_key,
85 model_name: model,
86 additional_params,
87 token_limit,
88 },
89 );
90 }
91 }
92
93 Self {
94 default_provider,
95 providers,
96 instructions,
97 temp_instructions: None,
98 is_local: false,
99 }
100 }
101
102 fn get_git_config_value(key: &str) -> Option<String> {
103 let output = Command::new("git")
104 .args(["config", "--get", key])
105 .output()
106 .ok()?;
107 if output.status.success() {
108 Some(String::from_utf8_lossy(&output.stdout).trim().to_string())
109 } else {
110 None
111 }
112 }
113
114 #[allow(unused)]
115 fn get_git_config_bool(key: &str) -> Option<bool> {
116 Self::get_git_config_value(key).and_then(|v| v.parse().ok())
117 }
118
119 fn get_git_config_i64(key: &str) -> Option<i64> {
120 Self::get_git_config_value(key).and_then(|v| v.parse().ok())
121 }
122
123 pub fn load_project_config() -> Result<Self, anyhow::Error> {
125 let mut project_config = Self::load_from_config("gitai");
126 project_config.is_local = true;
127 Ok(project_config)
128 }
129
130 pub fn merge_with_project_config(&mut self, project_config: Self) {
133 debug!("Merging with project configuration");
134
135 if project_config.default_provider != Self::default().default_provider {
137 self.default_provider = project_config.default_provider;
138 }
139
140 for (provider, proj_provider_config) in project_config.providers {
142 let entry = self.providers.entry(provider).or_default();
143
144 if !proj_provider_config.model_name.is_empty() {
146 entry.model_name = proj_provider_config.model_name;
147 }
148
149 entry
151 .additional_params
152 .extend(proj_provider_config.additional_params);
153
154 if proj_provider_config.token_limit.is_some() {
156 entry.token_limit = proj_provider_config.token_limit;
157 }
158 }
159
160 self.instructions = project_config.instructions.clone();
162 }
163
164 pub fn save(&self) -> Result<()> {
166 if self.is_local {
168 return Ok(());
169 }
170
171 let mut config = GitConfig::open_default()?;
172 self.save_to_config(&mut config, "gitai")?;
173 debug!("Configuration saved to global git config: {self:?}");
174 Ok(())
175 }
176
177 fn save_to_config(&self, config: &mut GitConfig, prefix: &str) -> Result<()> {
179 config.set_str(&format!("{prefix}.defaultprovider"), &self.default_provider)?;
181
182 config.set_str(&format!("{prefix}.instructions"), &self.instructions)?;
184
185 for (provider, provider_config) in &self.providers {
186 if !provider_config.api_key.is_empty() {
188 config.set_str(
189 &format!("{prefix}.{provider}-apikey"),
190 &provider_config.api_key,
191 )?;
192 }
193
194 config.set_str(
196 &format!("{prefix}.{provider}-model"),
197 &provider_config.model_name,
198 )?;
199
200 if let Some(token_limit) = provider_config.token_limit {
201 config.set_i64(
202 &format!("{prefix}.{provider}-tokenlimit"),
203 i64::try_from(token_limit).context("Token limit exceeds i64 range")?,
204 )?;
205 }
206
207 for (key, value) in &provider_config.additional_params {
208 config.set_str(&format!("{prefix}.{provider}-additional{key}"), value)?;
209 }
210 }
211
212 Ok(())
213 }
214
215 pub fn save_as_project_config(&self) -> Result<(), anyhow::Error> {
217 let repo = git2::Repository::discover(".")?;
218
219 let mut project_config = self.clone();
221
222 for provider_config in project_config.providers.values_mut() {
224 provider_config.api_key.clear();
225 }
226
227 project_config.is_local = true;
229
230 let mut config = repo.config()?;
232 project_config.save_to_config(&mut config, "gitai")?;
233 debug!("Project configuration saved to local git config: {project_config:?}");
234 Ok(())
235 }
236
237 pub fn check_environment(&self) -> Result<()> {
239 if !GitRepo::is_inside_work_tree()? {
241 return Err(anyhow!(
242 "Not in a Git repository. Please run this command from within a Git repository."
243 ));
244 }
245
246 Ok(())
247 }
248
249 pub fn set_temp_instructions(&mut self, instructions: Option<String>) {
250 self.temp_instructions = instructions;
251 }
252
253 pub fn get_effective_instructions(&self) -> String {
254 let custom_instructions = self
255 .temp_instructions
256 .as_ref()
257 .unwrap_or(&self.instructions);
258
259 custom_instructions.trim().to_string()
260 }
261
262 #[allow(clippy::too_many_arguments)]
264 pub fn update(
265 &mut self,
266 provider: Option<String>,
267 api_key: Option<String>,
268 model: Option<String>,
269 additional_params: Option<HashMap<String, String>>,
270 instructions: Option<String>,
271 token_limit: Option<usize>,
272 ) -> anyhow::Result<()> {
273 if let Some(provider) = provider {
274 self.default_provider.clone_from(&provider);
275 if !self.providers.contains_key(&provider) {
276 if provider_requires_api_key(&provider.to_lowercase()) {
278 self.providers.insert(
279 provider.clone(),
280 ProviderConfig::default_for(&provider.to_lowercase()),
281 );
282 }
283 }
284 }
285
286 let provider_config = self
287 .providers
288 .get_mut(&self.default_provider)
289 .context("Could not get default provider")?;
290
291 if let Some(key) = api_key {
292 provider_config.api_key = key;
293 }
294 if let Some(model) = model {
295 provider_config.model_name = model;
296 }
297 if let Some(params) = additional_params {
298 provider_config.additional_params.extend(params);
299 }
300
301 if let Some(instr) = instructions {
302 self.instructions = instr;
303 }
304 if let Some(limit) = token_limit {
305 provider_config.token_limit = Some(limit);
306 }
307
308 debug!("Configuration updated: {self:?}");
309 Ok(())
310 }
311
312 pub fn get_provider_config(&self, provider: &str) -> Option<&ProviderConfig> {
314 let provider_to_lookup = if provider.to_lowercase() == "claude" {
316 "anthropic"
317 } else {
318 provider
319 };
320
321 self.providers.get(provider_to_lookup).or_else(|| {
323 let lowercase_provider = provider_to_lookup.to_lowercase();
325
326 self.providers.get(&lowercase_provider).or_else(|| {
327 if get_available_provider_names().contains(&lowercase_provider) {
329 None
332 } else {
333 None
335 }
336 })
337 })
338 }
339
340 pub fn set_project_config(&mut self, is_project: bool) {
342 self.is_local = is_project;
343 }
344
345 pub fn is_project_config(&self) -> bool {
347 self.is_local
348 }
349}
350
351impl Default for Config {
352 fn default() -> Self {
353 let mut providers = HashMap::new();
354 for provider in get_available_provider_names() {
355 providers.insert(provider.clone(), ProviderConfig::default_for(&provider));
356 }
357
358 let default_provider = if providers.contains_key("openai") {
360 "openai".to_string()
361 } else {
362 providers.keys().next().map_or_else(
363 || "openai".to_string(), std::clone::Clone::clone,
365 )
366 };
367
368 Self {
369 default_provider,
370 providers,
371 instructions: String::new(),
372 temp_instructions: None,
373 is_local: false,
374 }
375 }
376}
377
378impl ProviderConfig {
379 pub fn default_for(provider: &str) -> Self {
381 Self {
382 api_key: String::new(),
383 model_name: get_default_model_for_provider(provider).to_string(),
384 additional_params: HashMap::new(),
385 token_limit: None, }
387 }
388
389 pub fn get_token_limit(&self) -> Option<usize> {
391 self.token_limit
392 }
393}