1use crate::error::{HeliosError, Result};
8use serde::{Deserialize, Serialize};
9use std::fs;
10use std::path::Path;
11
/// Top-level application configuration, deserialized from a TOML file.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Config {
    /// Settings for the remote, API-based LLM backend.
    pub llm: LLMConfig,
    /// Optional settings for the local inference backend.
    /// Only compiled in when the crate is built with the `local` feature;
    /// `#[serde(default)]` makes the `[local]` table optional in the file.
    #[cfg(feature = "local")]
    #[serde(default)]
    pub local: Option<LocalConfig>,
}
22
/// Connection and sampling parameters for the remote LLM API.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LLMConfig {
    /// Model identifier sent to the API (e.g. "gpt-3.5-turbo").
    pub model_name: String,
    /// Base URL of the API endpoint.
    pub base_url: String,
    /// API key credential. Stored in plain text in the config file.
    pub api_key: String,
    /// Sampling temperature; falls back to 0.7 when absent from the file.
    #[serde(default = "default_temperature")]
    pub temperature: f32,
    /// Maximum number of tokens to generate; falls back to 2048 when absent.
    #[serde(default = "default_max_tokens")]
    pub max_tokens: u32,
}
39
/// Parameters for running a model locally (only with the `local` feature).
#[cfg(feature = "local")]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LocalConfig {
    /// Hugging Face repository to fetch the model from (e.g. "test/repo").
    pub huggingface_repo: String,
    /// Model file name within the repository (e.g. "model.gguf").
    pub model_file: String,
    /// Context window size in tokens; falls back to 2048 when absent.
    #[serde(default = "default_context_size")]
    pub context_size: usize,
    /// Sampling temperature; falls back to 0.7 when absent from the file.
    #[serde(default = "default_temperature")]
    pub temperature: f32,
    /// Maximum number of tokens to generate; falls back to 2048 when absent.
    #[serde(default = "default_max_tokens")]
    pub max_tokens: u32,
}
58
/// Serde default for `temperature` fields.
fn default_temperature() -> f32 { 0.7 }
63
/// Serde default for `max_tokens` fields.
fn default_max_tokens() -> u32 { 2048 }
68
/// Serde default for `LocalConfig::context_size`.
#[cfg(feature = "local")]
fn default_context_size() -> usize { 2048 }
74
75impl Config {
76 pub fn from_file<P: AsRef<Path>>(path: P) -> Result<Self> {
86 let content = fs::read_to_string(path)
87 .map_err(|e| HeliosError::ConfigError(format!("Failed to read config file: {}", e)))?;
88
89 let config: Config = toml::from_str(&content)?;
90 Ok(config)
91 }
92
93 pub fn new_default() -> Self {
95 Self {
96 llm: LLMConfig {
97 model_name: "gpt-3.5-turbo".to_string(),
98 base_url: "https://api.openai.com/v1".to_string(),
99 api_key: "your-api-key-here".to_string(),
100 temperature: 0.7,
101 max_tokens: 2048,
102 },
103 #[cfg(feature = "local")]
104 local: None,
105 }
106 }
107
108 pub fn save<P: AsRef<Path>>(&self, path: P) -> Result<()> {
118 let content = toml::to_string_pretty(self)
119 .map_err(|e| HeliosError::ConfigError(format!("Failed to serialize config: {}", e)))?;
120
121 fs::write(path, content)
122 .map_err(|e| HeliosError::ConfigError(format!("Failed to write config file: {}", e)))?;
123
124 Ok(())
125 }
126}
127
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;
    use tempfile::tempdir;

    // Round-trips a full config (including the [local] table) through a
    // temporary file. Only compiled when the `local` feature is enabled,
    // since Config::local does not exist otherwise.
    #[test]
    #[cfg(feature = "local")]
    fn test_config_from_file() {
        let config_content = r#"
[llm]
model_name = "gpt-4"
base_url = "https://api.openai.com/v1"
api_key = "test-key"
temperature = 0.7
max_tokens = 2048

[local]
huggingface_repo = "test/repo"
model_file = "model.gguf"
context_size = 4096
temperature = 0.5
max_tokens = 1024
"#;
        let dir = tempdir().unwrap();
        let config_path = dir.path().join("config.toml");
        fs::write(&config_path, config_content).unwrap();

        let config = Config::from_file(&config_path).unwrap();
        assert_eq!(config.llm.model_name, "gpt-4");
        assert_eq!(config.local.as_ref().unwrap().huggingface_repo, "test/repo");
    }

    // Same round-trip without the [local] table, for builds where the
    // `local` feature (and thus the field) is absent.
    #[test]
    #[cfg(not(feature = "local"))]
    fn test_config_from_file() {
        let config_content = r#"
[llm]
model_name = "gpt-4"
base_url = "https://api.openai.com/v1"
api_key = "test-key"
temperature = 0.7
max_tokens = 2048
"#;
        let dir = tempdir().unwrap();
        let config_path = dir.path().join("config.toml");
        fs::write(&config_path, config_content).unwrap();

        let config = Config::from_file(&config_path).unwrap();
        assert_eq!(config.llm.model_name, "gpt-4");
    }

    // Pins the placeholder values produced by Config::new_default.
    #[test]
    fn test_config_new_default() {
        let config = Config::new_default();
        assert_eq!(config.llm.model_name, "gpt-3.5-turbo");
        assert_eq!(config.llm.base_url, "https://api.openai.com/v1");
        assert_eq!(config.llm.api_key, "your-api-key-here");
        assert_eq!(config.llm.temperature, 0.7);
        assert_eq!(config.llm.max_tokens, 2048);
        #[cfg(feature = "local")]
        assert!(config.local.is_none());
    }

    // save() then from_file() must round-trip: file exists and the
    // reloaded config matches what was written.
    #[test]
    fn test_config_save() {
        let config = Config::new_default();
        let dir = tempdir().unwrap();
        let config_path = dir.path().join("config.toml");

        config.save(&config_path).unwrap();
        assert!(config_path.exists());

        let loaded_config = Config::from_file(&config_path).unwrap();
        assert_eq!(loaded_config.llm.model_name, config.llm.model_name);
    }

    // Guards the serde default helpers against accidental drift.
    #[test]
    fn test_default_functions() {
        assert_eq!(default_temperature(), 0.7);
        assert_eq!(default_max_tokens(), 2048);
        #[cfg(feature = "local")]
        assert_eq!(default_context_size(), 2048);
    }
}