use crate::error::{HeliosError, Result};
use serde::{Deserialize, Serialize};
use std::fs;
use std::path::Path;
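
/// Top-level configuration, deserialized from a TOML file. The optional
/// backend sections exist only when the corresponding Cargo feature is enabled.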
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Config {
    pub llm: LLMConfig,
    #[cfg(feature = "local")]
    #[serde(default)]
    pub local: Option<LocalConfig>,
    #[cfg(feature = "candle")]
    #[serde(default)]
    pub candle: Option<CandleConfig>,
}
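
/// Connection and sampling settings for an OpenAI-compatible API endpoint.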
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LLMConfig {
    pub model_name: String,
    pub base_url: String,
    pub api_key: String,
    #[serde(default = "default_temperature")]
    pub temperature: f32,
    #[serde(default = "default_max_tokens")]
    pub max_tokens: u32,
}
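
/// Settings for a local model file fetched from a Hugging Face repository
/// (requires the `local` feature).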
#[cfg(feature = "local")]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LocalConfig {
    pub huggingface_repo: String,
    pub model_file: String,
    #[serde(default = "default_context_size")]
    pub context_size: usize,
    #[serde(default = "default_temperature")]
    pub temperature: f32,
    #[serde(default = "default_max_tokens")]
    pub max_tokens: u32,
}
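
/// Settings for the Candle backend (requires the `candle` feature),
/// including whether to prefer GPU execution.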
#[cfg(feature = "candle")]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CandleConfig {
    pub huggingface_repo: String,
    pub model_file: String,
    #[serde(default = "default_context_size")]
    pub context_size: usize,
    #[serde(default = "default_temperature")]
    pub temperature: f32,
    #[serde(default = "default_max_tokens")]
    pub max_tokens: u32,
    #[serde(default = "default_use_gpu")]
    pub use_gpu: bool,
}
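
// Serde default helpers referenced by the `#[serde(default = "...")]`
// attributes above.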
fn default_temperature() -> f32 {
    0.7
}

fn default_max_tokens() -> u32 {
    2048
}

#[cfg(any(feature = "local", feature = "candle"))]
fn default_context_size() -> usize {
    2048
}

#[cfg(feature = "candle")]
fn default_use_gpu() -> bool {
    true
}

impl Config {
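    /// Reads the file at `path` and parses it as TOML.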
    pub fn from_file<P: AsRef<Path>>(path: P) -> Result<Self> {
        let content = fs::read_to_string(path)
            .map_err(|e| HeliosError::ConfigError(format!("Failed to read config file: {}", e)))?;

        let config: Config = toml::from_str(&content)?;
        Ok(config)
    }
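
    /// Builds a config with placeholder OpenAI-style defaults; the API key
    /// must be replaced before use.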
    pub fn new_default() -> Self {
        Self {
            llm: LLMConfig {
                model_name: "gpt-3.5-turbo".to_string(),
                base_url: "https://api.openai.com/v1".to_string(),
                api_key: "your-api-key-here".to_string(),
                temperature: 0.7,
                max_tokens: 2048,
            },
            #[cfg(feature = "local")]
            local: None,
            #[cfg(feature = "candle")]
            candle: None,
        }
    }
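
    /// Loads the config from `path`, falling back to [`Config::new_default`]
    /// when the file is missing or fails to parse.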
    pub fn load_or_default<P: AsRef<Path>>(path: P) -> Self {
        Self::from_file(path).unwrap_or_else(|_| Self::new_default())
    }
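
    /// Returns a [`ConfigBuilder`] for fluent construction.
    ///
    /// A sketch of typical usage (the crate path in the `use` line is an
    /// assumption; adjust it to this crate's actual module layout):
    ///
    /// ```ignore
    /// use helios::config::Config; // assumed path, not confirmed by this crate
    ///
    /// let config = Config::builder()
    ///     .model("gpt-4")
    ///     .temperature(0.2)
    ///     .max_tokens(512)
    ///     .build();
    /// ```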
    pub fn builder() -> ConfigBuilder {
        ConfigBuilder::new()
    }
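
    /// Serializes the config as pretty-printed TOML and writes it to `path`.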
    pub fn save<P: AsRef<Path>>(&self, path: P) -> Result<()> {
        let content = toml::to_string_pretty(self)
            .map_err(|e| HeliosError::ConfigError(format!("Failed to serialize config: {}", e)))?;

        fs::write(path, content)
            .map_err(|e| HeliosError::ConfigError(format!("Failed to write config file: {}", e)))?;

        Ok(())
    }
}
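
/// Fluent builder for [`Config`]. Only the `llm` section is configurable here;
/// feature-gated backend sections are left as `None` by [`ConfigBuilder::build`].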
pub struct ConfigBuilder {
    model_name: String,
    base_url: String,
    api_key: String,
    temperature: f32,
    max_tokens: u32,
}

impl ConfigBuilder {
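    /// Starts from the same defaults as [`Config::new_default`], except that
    /// the API key is taken from the `OPENAI_API_KEY` environment variable
    /// when it is set.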
    pub fn new() -> Self {
        Self {
            model_name: "gpt-3.5-turbo".to_string(),
            base_url: "https://api.openai.com/v1".to_string(),
            api_key: std::env::var("OPENAI_API_KEY")
                .unwrap_or_else(|_| "your-api-key-here".to_string()),
            temperature: 0.7,
            max_tokens: 2048,
        }
    }

    pub fn model(mut self, model: impl Into<String>) -> Self {
        self.model_name = model.into();
        self
    }
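
    /// Shorthand for [`Self::model`].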
    pub fn m(self, model: impl Into<String>) -> Self {
        self.model(model)
    }

    pub fn api_key(mut self, key: impl Into<String>) -> Self {
        self.api_key = key.into();
        self
    }
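
    /// Shorthand for [`Self::api_key`].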
    pub fn key(self, key: impl Into<String>) -> Self {
        self.api_key(key)
    }

    pub fn base_url(mut self, url: impl Into<String>) -> Self {
        self.base_url = url.into();
        self
    }
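
    /// Shorthand for [`Self::base_url`].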
    pub fn url(self, url: impl Into<String>) -> Self {
        self.base_url(url)
    }

    pub fn temperature(mut self, temp: f32) -> Self {
        self.temperature = temp;
        self
    }
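
    /// Shorthand for [`Self::temperature`].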
    pub fn temp(self, temp: f32) -> Self {
        self.temperature(temp)
    }

    pub fn max_tokens(mut self, tokens: u32) -> Self {
        self.max_tokens = tokens;
        self
    }
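
    /// Shorthand for [`Self::max_tokens`].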
    pub fn tokens(self, tokens: u32) -> Self {
        self.max_tokens(tokens)
    }
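
    /// Consumes the builder and produces a [`Config`]; feature-gated backend
    /// sections are left unset.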
    pub fn build(self) -> Config {
        Config {
            llm: LLMConfig {
                model_name: self.model_name,
                base_url: self.base_url,
                api_key: self.api_key,
                temperature: self.temperature,
                max_tokens: self.max_tokens,
            },
            #[cfg(feature = "local")]
            local: None,
            #[cfg(feature = "candle")]
            candle: None,
        }
    }
}

impl Default for ConfigBuilder {
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;
    use tempfile::tempdir;

    #[test]
    #[cfg(feature = "local")]
    fn test_config_from_file() {
        let config_content = r#"
[llm]
model_name = "gpt-4"
base_url = "https://api.openai.com/v1"
api_key = "test-key"
temperature = 0.7
max_tokens = 2048

[local]
huggingface_repo = "test/repo"
model_file = "model.gguf"
context_size = 4096
temperature = 0.5
max_tokens = 1024
"#;
        let dir = tempdir().unwrap();
        let config_path = dir.path().join("config.toml");
        fs::write(&config_path, config_content).unwrap();

        let config = Config::from_file(&config_path).unwrap();
        assert_eq!(config.llm.model_name, "gpt-4");
        assert_eq!(config.local.as_ref().unwrap().huggingface_repo, "test/repo");
    }

    #[test]
    #[cfg(not(feature = "local"))]
    fn test_config_from_file() {
        let config_content = r#"
[llm]
model_name = "gpt-4"
base_url = "https://api.openai.com/v1"
api_key = "test-key"
temperature = 0.7
max_tokens = 2048
"#;
        let dir = tempdir().unwrap();
        let config_path = dir.path().join("config.toml");
        fs::write(&config_path, config_content).unwrap();

        let config = Config::from_file(&config_path).unwrap();
        assert_eq!(config.llm.model_name, "gpt-4");
    }

    #[test]
    fn test_config_new_default() {
        let config = Config::new_default();
        assert_eq!(config.llm.model_name, "gpt-3.5-turbo");
        assert_eq!(config.llm.base_url, "https://api.openai.com/v1");
        assert_eq!(config.llm.api_key, "your-api-key-here");
        assert_eq!(config.llm.temperature, 0.7);
        assert_eq!(config.llm.max_tokens, 2048);
        #[cfg(feature = "local")]
        assert!(config.local.is_none());
    }

    #[test]
    fn test_config_save() {
        let config = Config::new_default();
        let dir = tempdir().unwrap();
        let config_path = dir.path().join("config.toml");

        config.save(&config_path).unwrap();
        assert!(config_path.exists());

        let loaded_config = Config::from_file(&config_path).unwrap();
        assert_eq!(loaded_config.llm.model_name, config.llm.model_name);
    }

    #[test]
    fn test_default_functions() {
        assert_eq!(default_temperature(), 0.7);
        assert_eq!(default_max_tokens(), 2048);
        #[cfg(any(feature = "local", feature = "candle"))]
        assert_eq!(default_context_size(), 2048);
    }
}