1use crate::git::GitRepo;
2use crate::instruction_presets::get_instruction_preset_library;
3use crate::llm_providers::{
4 get_available_providers, get_provider_metadata, LLMProviderConfig, LLMProviderType,
5};
6use crate::log_debug;
7
8use anyhow::{anyhow, Context, Result};
9use dirs::config_dir;
10use serde::{Deserialize, Serialize};
11use std::collections::HashMap;
12use std::fs;
13use std::path::PathBuf;
14use std::str::FromStr;
15
/// Top-level git-iris configuration, loaded from and saved to a TOML file.
#[derive(Deserialize, Serialize, Clone, Debug)]
pub struct Config {
    /// Name of the provider used by default; expected to be a key of `providers`.
    pub default_provider: String,
    /// Per-provider settings, keyed by provider name.
    pub providers: HashMap<String, ProviderConfig>,
    /// Gitmoji toggle; defaults to `true` when absent from the file.
    #[serde(default = "default_gitmoji")]
    pub use_gitmoji: bool,
    /// Custom instructions, appended after the selected preset's instructions.
    #[serde(default)]
    pub instructions: String,
    /// Name of the instruction preset; defaults to `"default"` when absent from the file.
    #[serde(default = "default_instruction_preset")]
    pub instruction_preset: String,
    /// Temporary override for `instructions`; never serialized.
    #[serde(skip)]
    pub temp_instructions: Option<String>,
    /// Temporary override for `instruction_preset`; never serialized.
    #[serde(skip)]
    pub temp_preset: Option<String>,
}
36
37#[derive(Deserialize, Serialize, Clone, Debug, Default)]
39pub struct ProviderConfig {
40 pub api_key: String,
42 pub model: String,
44 #[serde(default)]
46 pub additional_params: HashMap<String, String>,
47 pub token_limit: Option<usize>,
49}
50
/// Serde default for `Config::use_gitmoji`: Gitmoji is enabled unless the file says otherwise.
fn default_gitmoji() -> bool {
    const GITMOJI_ENABLED: bool = true;
    GITMOJI_ENABLED
}
55
/// Serde default for `Config::instruction_preset`: the preset named `"default"`.
fn default_instruction_preset() -> String {
    String::from("default")
}
60
61impl Config {
62 pub fn load() -> Result<Self> {
64 let config_path = Self::get_config_path()?;
65 if !config_path.exists() {
66 return Ok(Self::default());
67 }
68 let config_content = fs::read_to_string(config_path)?;
69 let config: Self = toml::from_str(&config_content)?;
70 log_debug!("Configuration loaded: {:?}", config);
71 Ok(config)
72 }
73
74 pub fn save(&self) -> Result<()> {
76 let config_path = Self::get_config_path()?;
77 let config_content = toml::to_string(self)?;
78 fs::write(config_path, config_content)?;
79 log_debug!("Configuration saved: {:?}", self);
80 Ok(())
81 }
82
83 fn get_config_path() -> Result<PathBuf> {
85 let mut path =
86 config_dir().ok_or_else(|| anyhow!("Unable to determine config directory"))?;
87 path.push("git-iris");
88 std::fs::create_dir_all(&path)?;
89 path.push("config.toml");
90 Ok(path)
91 }
92
93 pub fn check_environment(&self) -> Result<()> {
95 if !GitRepo::is_inside_work_tree()? {
97 return Err(anyhow!(
98 "Not in a Git repository. Please run this command from within a Git repository."
99 ));
100 }
101
102 Ok(())
103 }
104
105 pub fn set_temp_instructions(&mut self, instructions: Option<String>) {
106 self.temp_instructions = instructions;
107 }
108
109 pub fn set_temp_preset(&mut self, preset: Option<String>) {
110 self.temp_preset = preset;
111 }
112
113 pub fn get_effective_instructions(&self) -> String {
114 let preset_library = get_instruction_preset_library();
115 let preset_instructions = self
116 .temp_preset
117 .as_ref()
118 .or(Some(&self.instruction_preset))
119 .and_then(|p| preset_library.get_preset(p))
120 .map(|p| p.instructions.clone())
121 .unwrap_or_default();
122
123 let custom_instructions = self
124 .temp_instructions
125 .as_ref()
126 .unwrap_or(&self.instructions);
127
128 format!("{preset_instructions}\n\n{custom_instructions}")
129 .trim()
130 .to_string()
131 }
132
133 #[allow(clippy::too_many_arguments)]
135 pub fn update(
136 &mut self,
137 provider: Option<String>,
138 api_key: Option<String>,
139 model: Option<String>,
140 additional_params: Option<HashMap<String, String>>,
141 use_gitmoji: Option<bool>,
142 instructions: Option<String>,
143 token_limit: Option<usize>,
144 ) -> anyhow::Result<()> {
145 if let Some(provider) = provider {
146 self.default_provider.clone_from(&provider);
147 if !self.providers.contains_key(&provider) {
148 let provider_type =
150 LLMProviderType::from_str(&provider).unwrap_or(LLMProviderType::OpenAI);
151 if get_provider_metadata(&provider_type).requires_api_key {
152 self.providers
153 .insert(provider.clone(), ProviderConfig::default_for(&provider));
154 }
155 }
156 }
157
158 let provider_config = self
159 .providers
160 .get_mut(&self.default_provider)
161 .context("Could not get default provider")?;
162
163 if let Some(key) = api_key {
164 provider_config.api_key = key;
165 }
166 if let Some(model) = model {
167 provider_config.model = model;
168 }
169 if let Some(params) = additional_params {
170 provider_config.additional_params.extend(params);
171 }
172 if let Some(gitmoji) = use_gitmoji {
173 self.use_gitmoji = gitmoji;
174 }
175 if let Some(instr) = instructions {
176 self.instructions = instr;
177 }
178 if let Some(limit) = token_limit {
179 provider_config.token_limit = Some(limit);
180 }
181
182 log_debug!("Configuration updated: {:?}", self);
183 Ok(())
184 }
185
186 pub fn get_provider_config(&self, provider: &str) -> Option<&ProviderConfig> {
188 self.providers.get(provider).or_else(|| {
189 if LLMProviderType::from_str(provider).is_ok() {
191 None
194 } else {
195 None
197 }
198 })
199 }
200}
201
202impl Default for Config {
203 #[allow(clippy::unwrap_used)] fn default() -> Self {
205 let mut providers = HashMap::new();
206 for provider in get_available_providers() {
207 providers.insert(
208 provider.to_string(),
209 ProviderConfig::default_for(provider.as_ref()),
210 );
211 }
212
213 Self {
214 default_provider: get_available_providers().first().unwrap().to_string(),
215 providers,
216 use_gitmoji: true,
217 instructions: String::new(),
218 instruction_preset: default_instruction_preset(),
219 temp_instructions: None,
220 temp_preset: None,
221 }
222 }
223}
224
225impl ProviderConfig {
226 pub fn default_for(provider: &str) -> Self {
228 let provider_type =
229 LLMProviderType::from_str(provider).unwrap_or_else(|_| get_available_providers()[0]);
230 let metadata = get_provider_metadata(&provider_type);
231 Self {
232 api_key: String::new(),
233 model: metadata.default_model.to_string(),
234 additional_params: HashMap::new(),
235 token_limit: Some(metadata.default_token_limit),
236 }
237 }
238
239 pub fn get_token_limit(&self) -> usize {
241 self.token_limit.unwrap_or_else(|| {
242 let provider_type = LLMProviderType::from_str(&self.model)
243 .unwrap_or_else(|_| get_available_providers()[0]);
244 get_provider_metadata(&provider_type).default_token_limit
245 })
246 }
247
248 pub fn to_llm_provider_config(&self) -> LLMProviderConfig {
250 let mut additional_params = self.additional_params.clone();
251
252 if let Some(limit) = self.token_limit {
254 additional_params.insert("token_limit".to_string(), limit.to_string());
255 }
256
257 LLMProviderConfig {
258 api_key: self.api_key.clone(),
259 model: self.model.clone(),
260 additional_params,
261 }
262 }
263}