use crate::error::{HeliosError, Result};
use serde::{Deserialize, Serialize};
use std::fs;
use std::path::Path;

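/// Top-level application configuration.
///
/// Typically deserialized from a TOML file shaped like this sketch (the
/// values are illustrative; the keys match the structs in this module):
///
/// ```toml
/// [llm]
/// model_name = "gpt-3.5-turbo"
/// base_url = "https://api.openai.com/v1"
/// api_key = "your-api-key-here"
/// temperature = 0.7
/// max_tokens = 2048
/// ```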
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Config {
    pub llm: LLMConfig,
    #[cfg(feature = "local")]
    #[serde(default)]
    pub local: Option<LocalConfig>,
}

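/// Connection and sampling settings for the remote LLM endpoint
/// (an OpenAI-compatible API, judging by the defaults).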
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LLMConfig {
    pub model_name: String,
    pub base_url: String,
    pub api_key: String,
    #[serde(default = "default_temperature")]
    pub temperature: f32,
    #[serde(default = "default_max_tokens")]
    pub max_tokens: u32,
}

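/// Settings for a locally hosted model pulled from a Hugging Face repo.
/// Only compiled in when the `local` feature is enabled.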
#[cfg(feature = "local")]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LocalConfig {
    pub huggingface_repo: String,
    pub model_file: String,
    #[serde(default = "default_context_size")]
    pub context_size: usize,
    #[serde(default = "default_temperature")]
    pub temperature: f32,
    #[serde(default = "default_max_tokens")]
    pub max_tokens: u32,
}

fn default_temperature() -> f32 {
    0.7
}

fn default_max_tokens() -> u32 {
    2048
}

#[cfg(feature = "local")]
fn default_context_size() -> usize {
    2048
}

impl Config {
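    /// Loads a [`Config`] from a TOML file at `path`.
    ///
    /// A minimal usage sketch (marked `ignore` since it omits crate imports;
    /// the file name is illustrative):
    ///
    /// ```ignore
    /// let config = Config::from_file("config.toml")?;
    /// println!("model: {}", config.llm.model_name);
    /// ```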
    pub fn from_file<P: AsRef<Path>>(path: P) -> Result<Self> {
        let content = fs::read_to_string(path)
            .map_err(|e| HeliosError::ConfigError(format!("Failed to read config file: {}", e)))?;

        let config: Config = toml::from_str(&content)
            .map_err(|e| HeliosError::ConfigError(format!("Failed to parse config file: {}", e)))?;
        Ok(config)
    }

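    /// Returns a built-in default configuration pointing at the OpenAI API
    /// with a placeholder API key.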
    pub fn new_default() -> Self {
        Self {
            llm: LLMConfig {
                model_name: "gpt-3.5-turbo".to_string(),
                base_url: "https://api.openai.com/v1".to_string(),
                api_key: "your-api-key-here".to_string(),
                temperature: 0.7,
                max_tokens: 2048,
            },
            #[cfg(feature = "local")]
            local: None,
        }
    }

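    /// Loads the configuration from `path`, falling back to
    /// [`Config::new_default`] if the file is missing or fails to parse.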
    pub fn load_or_default<P: AsRef<Path>>(path: P) -> Self {
        Self::from_file(path).unwrap_or_else(|_| Self::new_default())
    }

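    /// Returns a [`ConfigBuilder`] for fluent construction.
    ///
    /// Sketch of typical usage (values are illustrative; marked `ignore`
    /// since it omits crate imports):
    ///
    /// ```ignore
    /// let config = Config::builder()
    ///     .model("gpt-4")
    ///     .temperature(0.2)
    ///     .max_tokens(1024)
    ///     .build();
    /// ```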
    pub fn builder() -> ConfigBuilder {
        ConfigBuilder::new()
    }

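    /// Serializes this configuration to pretty-printed TOML and writes it
    /// to `path`.
    ///
    /// Sketch (the file name is illustrative; marked `ignore` since it omits
    /// crate imports):
    ///
    /// ```ignore
    /// Config::new_default().save("config.toml")?;
    /// ```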
    pub fn save<P: AsRef<Path>>(&self, path: P) -> Result<()> {
        let content = toml::to_string_pretty(self)
            .map_err(|e| HeliosError::ConfigError(format!("Failed to serialize config: {}", e)))?;

        fs::write(path, content)
            .map_err(|e| HeliosError::ConfigError(format!("Failed to write config file: {}", e)))?;

        Ok(())
    }
}

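/// Fluent builder for [`Config`]. Defaults mirror [`Config::new_default`],
/// except that the API key is read from the `OPENAI_API_KEY` environment
/// variable when it is set.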
pub struct ConfigBuilder {
    model_name: String,
    base_url: String,
    api_key: String,
    temperature: f32,
    max_tokens: u32,
}

impl ConfigBuilder {
    pub fn new() -> Self {
        Self {
            model_name: "gpt-3.5-turbo".to_string(),
            base_url: "https://api.openai.com/v1".to_string(),
            api_key: std::env::var("OPENAI_API_KEY")
                .unwrap_or_else(|_| "your-api-key-here".to_string()),
            temperature: 0.7,
            max_tokens: 2048,
        }
    }

    pub fn model(mut self, model: impl Into<String>) -> Self {
        self.model_name = model.into();
        self
    }

    /// Shorthand for [`Self::model`].
    pub fn m(self, model: impl Into<String>) -> Self {
        self.model(model)
    }

    pub fn api_key(mut self, key: impl Into<String>) -> Self {
        self.api_key = key.into();
        self
    }

    /// Shorthand for [`Self::api_key`].
    pub fn key(self, key: impl Into<String>) -> Self {
        self.api_key(key)
    }

    pub fn base_url(mut self, url: impl Into<String>) -> Self {
        self.base_url = url.into();
        self
    }

    /// Shorthand for [`Self::base_url`].
    pub fn url(self, url: impl Into<String>) -> Self {
        self.base_url(url)
    }

    pub fn temperature(mut self, temp: f32) -> Self {
        self.temperature = temp;
        self
    }

    /// Shorthand for [`Self::temperature`].
    pub fn temp(self, temp: f32) -> Self {
        self.temperature(temp)
    }

    pub fn max_tokens(mut self, tokens: u32) -> Self {
        self.max_tokens = tokens;
        self
    }

    /// Shorthand for [`Self::max_tokens`].
    pub fn tokens(self, tokens: u32) -> Self {
        self.max_tokens(tokens)
    }

    pub fn build(self) -> Config {
        Config {
            llm: LLMConfig {
                model_name: self.model_name,
                base_url: self.base_url,
                api_key: self.api_key,
                temperature: self.temperature,
                max_tokens: self.max_tokens,
            },
            #[cfg(feature = "local")]
            local: None,
        }
    }
}

impl Default for ConfigBuilder {
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;
    use tempfile::tempdir;

    #[test]
    #[cfg(feature = "local")]
    fn test_config_from_file() {
        let config_content = r#"
[llm]
model_name = "gpt-4"
base_url = "https://api.openai.com/v1"
api_key = "test-key"
temperature = 0.7
max_tokens = 2048

[local]
huggingface_repo = "test/repo"
model_file = "model.gguf"
context_size = 4096
temperature = 0.5
max_tokens = 1024
"#;
        let dir = tempdir().unwrap();
        let config_path = dir.path().join("config.toml");
        fs::write(&config_path, config_content).unwrap();

        let config = Config::from_file(&config_path).unwrap();
        assert_eq!(config.llm.model_name, "gpt-4");
        assert_eq!(config.local.as_ref().unwrap().huggingface_repo, "test/repo");
    }

    #[test]
    #[cfg(not(feature = "local"))]
    fn test_config_from_file() {
        let config_content = r#"
[llm]
model_name = "gpt-4"
base_url = "https://api.openai.com/v1"
api_key = "test-key"
temperature = 0.7
max_tokens = 2048
"#;
        let dir = tempdir().unwrap();
        let config_path = dir.path().join("config.toml");
        fs::write(&config_path, config_content).unwrap();

        let config = Config::from_file(&config_path).unwrap();
        assert_eq!(config.llm.model_name, "gpt-4");
    }

    #[test]
    fn test_config_new_default() {
        let config = Config::new_default();
        assert_eq!(config.llm.model_name, "gpt-3.5-turbo");
        assert_eq!(config.llm.base_url, "https://api.openai.com/v1");
        assert_eq!(config.llm.api_key, "your-api-key-here");
        assert_eq!(config.llm.temperature, 0.7);
        assert_eq!(config.llm.max_tokens, 2048);
        #[cfg(feature = "local")]
        assert!(config.local.is_none());
    }

    #[test]
    fn test_config_save() {
        let config = Config::new_default();
        let dir = tempdir().unwrap();
        let config_path = dir.path().join("config.toml");

        config.save(&config_path).unwrap();
        assert!(config_path.exists());

        let loaded_config = Config::from_file(&config_path).unwrap();
        assert_eq!(loaded_config.llm.model_name, config.llm.model_name);
    }

    #[test]
    fn test_default_functions() {
        assert_eq!(default_temperature(), 0.7);
        assert_eq!(default_max_tokens(), 2048);
        #[cfg(feature = "local")]
        assert_eq!(default_context_size(), 2048);
    }
}