use crate::error::{HeliosError, Result};
use serde::{Deserialize, Serialize};
use std::fs;
use std::path::Path;

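/// Top-level Helios configuration, typically loaded from a TOML file.
///
/// The file contains a required `[llm]` table and an optional `[local]`
/// table, for example:
///
/// ```toml
/// [llm]
/// model_name = "gpt-3.5-turbo"
/// base_url = "https://api.openai.com/v1"
/// api_key = "your-api-key-here"
/// ```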
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Config {
    pub llm: LLMConfig,
    #[serde(default)]
    pub local: Option<LocalConfig>,
}

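/// Settings for a remote LLM served over an OpenAI-compatible API.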
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LLMConfig {
    pub model_name: String,
    pub base_url: String,
    pub api_key: String,
    #[serde(default = "default_temperature")]
    pub temperature: f32,
    #[serde(default = "default_max_tokens")]
    pub max_tokens: u32,
}

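/// Settings for a local model file fetched from a Hugging Face repository.
/// `context_size`, `temperature`, and `max_tokens` fall back to the
/// defaults below when omitted from the config file.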
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LocalConfig {
    pub huggingface_repo: String,
    pub model_file: String,
    #[serde(default = "default_context_size")]
    pub context_size: usize,
    #[serde(default = "default_temperature")]
    pub temperature: f32,
    #[serde(default = "default_max_tokens")]
    pub max_tokens: u32,
}

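// Fallback values used via `#[serde(default = "...")]` when a field is
// omitted from the config file.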
fn default_temperature() -> f32 {
    0.7
}

fn default_max_tokens() -> u32 {
    2048
}

fn default_context_size() -> usize {
    2048
}

impl Config {
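    /// Reads the file at `path` and parses it as TOML.
    ///
    /// Any I/O or parse failure is reported as a `HeliosError::ConfigError`.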
    pub fn from_file<P: AsRef<Path>>(path: P) -> Result<Self> {
        let content = fs::read_to_string(path)
            .map_err(|e| HeliosError::ConfigError(format!("Failed to read config file: {}", e)))?;

        let config: Config = toml::from_str(&content)
            .map_err(|e| HeliosError::ConfigError(format!("Failed to parse config file: {}", e)))?;
        Ok(config)
    }

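    /// Builds a default configuration targeting the OpenAI API with a
    /// placeholder API key that callers should replace.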
    pub fn new_default() -> Self {
        Self {
            llm: LLMConfig {
                model_name: "gpt-3.5-turbo".to_string(),
                base_url: "https://api.openai.com/v1".to_string(),
                api_key: "your-api-key-here".to_string(),
                temperature: 0.7,
                max_tokens: 2048,
            },
            local: None,
        }
    }

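    /// Serializes the configuration as pretty-printed TOML and writes it to `path`.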
    pub fn save<P: AsRef<Path>>(&self, path: P) -> Result<()> {
        let content = toml::to_string_pretty(self)
            .map_err(|e| HeliosError::ConfigError(format!("Failed to serialize config: {}", e)))?;

        fs::write(path, content)
            .map_err(|e| HeliosError::ConfigError(format!("Failed to write config file: {}", e)))?;

        Ok(())
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;
    use tempfile::tempdir;

    #[test]
    fn test_config_from_file() {
        let config_content = r#"
[llm]
model_name = "gpt-4"
base_url = "https://api.openai.com/v1"
api_key = "test-key"
temperature = 0.7
max_tokens = 2048

[local]
huggingface_repo = "test/repo"
model_file = "model.gguf"
context_size = 4096
temperature = 0.5
max_tokens = 1024
"#;
        let dir = tempdir().unwrap();
        let config_path = dir.path().join("config.toml");
        fs::write(&config_path, config_content).unwrap();

        let config = Config::from_file(&config_path).unwrap();
        assert_eq!(config.llm.model_name, "gpt-4");
        assert_eq!(config.local.as_ref().unwrap().huggingface_repo, "test/repo");
    }

    #[test]
    fn test_config_new_default() {
        let config = Config::new_default();
        assert_eq!(config.llm.model_name, "gpt-3.5-turbo");
        assert_eq!(config.llm.base_url, "https://api.openai.com/v1");
        assert_eq!(config.llm.api_key, "your-api-key-here");
        assert_eq!(config.llm.temperature, 0.7);
        assert_eq!(config.llm.max_tokens, 2048);
        assert!(config.local.is_none());
    }

    #[test]
    fn test_config_save() {
        let config = Config::new_default();
        let dir = tempdir().unwrap();
        let config_path = dir.path().join("config.toml");

        config.save(&config_path).unwrap();
        assert!(config_path.exists());

        let loaded_config = Config::from_file(&config_path).unwrap();
        assert_eq!(loaded_config.llm.model_name, config.llm.model_name);
    }

    #[test]
    fn test_default_functions() {
        assert_eq!(default_temperature(), 0.7);
        assert_eq!(default_max_tokens(), 2048);
        assert_eq!(default_context_size(), 2048);
    }
}