helios_engine/
config.rs
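//! Configuration types for the Helios engine: a [`Config`] is loaded from and
//! saved to TOML, with an `[llm]` table for the remote endpoint and an optional
//! `[local]` table for a local model. A representative file (values are
//! illustrative, mirroring the test fixture at the bottom of this module):
//!
//! ```toml
//! [llm]
//! model_name = "gpt-4"
//! base_url = "https://api.openai.com/v1"
//! api_key = "test-key"
//!
//! [local]
//! huggingface_repo = "test/repo"
//! model_file = "model.gguf"
//! context_size = 4096
//! ```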

use crate::error::{HeliosError, Result};
use serde::{Deserialize, Serialize};
use std::fs;
use std::path::Path;

/// Top-level configuration: remote LLM settings plus an optional local-model section.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Config {
    pub llm: LLMConfig,
    /// Optional `[local]` table; a missing table deserializes to `None`.
    #[serde(default)]
    pub local: Option<LocalConfig>,
}

/// Settings for the remote LLM endpoint (the defaults in `new_default` target the OpenAI API).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LLMConfig {
    pub model_name: String,
    pub base_url: String,
    pub api_key: String,
    #[serde(default = "default_temperature")]
    pub temperature: f32,
    #[serde(default = "default_max_tokens")]
    pub max_tokens: u32,
}

/// Settings for a local model pulled from a Hugging Face repository.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LocalConfig {
    pub huggingface_repo: String,
    pub model_file: String,
    #[serde(default = "default_context_size")]
    pub context_size: usize,
    #[serde(default = "default_temperature")]
    pub temperature: f32,
    #[serde(default = "default_max_tokens")]
    pub max_tokens: u32,
}

// Serde default helpers, used when the corresponding key is omitted from the TOML file.
fn default_temperature() -> f32 {
    0.7
}

fn default_max_tokens() -> u32 {
    2048
}

fn default_context_size() -> usize {
    2048
}

impl Config {
    /// Load configuration from a TOML file.
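    ///
    /// Missing optional keys (`temperature`, `max_tokens`, `context_size`) fall
    /// back to the serde defaults above. A minimal usage sketch; the path is
    /// illustrative, and this assumes the crate re-exports `Config` at its root:
    ///
    /// ```no_run
    /// use helios_engine::Config;
    ///
    /// let config = Config::from_file("config.toml").expect("failed to load config");
    /// println!("model: {}", config.llm.model_name);
    /// ```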
    pub fn from_file<P: AsRef<Path>>(path: P) -> Result<Self> {
        let content = fs::read_to_string(path)
            .map_err(|e| HeliosError::ConfigError(format!("Failed to read config file: {}", e)))?;

        // `?` relies on a `From<toml::de::Error>` impl for `HeliosError`.
        let config: Config = toml::from_str(&content)?;
        Ok(config)
    }

    /// Create a default configuration (no `[local]` section).
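    ///
    /// The defaults target the OpenAI API with a placeholder key (crate-root
    /// re-export of `Config` assumed, as above):
    ///
    /// ```
    /// use helios_engine::Config;
    ///
    /// let config = Config::new_default();
    /// assert_eq!(config.llm.model_name, "gpt-3.5-turbo");
    /// assert!(config.local.is_none());
    /// ```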
    pub fn new_default() -> Self {
        Self {
            llm: LLMConfig {
                model_name: "gpt-3.5-turbo".to_string(),
                base_url: "https://api.openai.com/v1".to_string(),
                api_key: "your-api-key-here".to_string(),
                temperature: 0.7,
                max_tokens: 2048,
            },
            local: None,
        }
    }

    /// Save configuration to a TOML file.
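    ///
    /// A round-trip sketch (path and crate-root re-export are illustrative
    /// assumptions, as above):
    ///
    /// ```no_run
    /// use helios_engine::Config;
    ///
    /// let config = Config::new_default();
    /// config.save("config.toml").expect("failed to save config");
    /// let reloaded = Config::from_file("config.toml").expect("failed to reload config");
    /// assert_eq!(reloaded.llm.model_name, config.llm.model_name);
    /// ```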
    pub fn save<P: AsRef<Path>>(&self, path: P) -> Result<()> {
        let content = toml::to_string_pretty(self)
            .map_err(|e| HeliosError::ConfigError(format!("Failed to serialize config: {}", e)))?;

        fs::write(path, content)
            .map_err(|e| HeliosError::ConfigError(format!("Failed to write config file: {}", e)))?;

        Ok(())
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;
    use tempfile::tempdir;

    #[test]
    fn test_config_from_file() {
        let config_content = r#"
[llm]
model_name = "gpt-4"
base_url = "https://api.openai.com/v1"
api_key = "test-key"
temperature = 0.7
max_tokens = 2048

[local]
huggingface_repo = "test/repo"
model_file = "model.gguf"
context_size = 4096
temperature = 0.5
max_tokens = 1024
"#;
        let dir = tempdir().unwrap();
        let config_path = dir.path().join("config.toml");
        fs::write(&config_path, config_content).unwrap();

        let config = Config::from_file(&config_path).unwrap();
        assert_eq!(config.llm.model_name, "gpt-4");
        assert_eq!(config.local.as_ref().unwrap().huggingface_repo, "test/repo");
    }

    #[test]
    fn test_config_new_default() {
        let config = Config::new_default();
        assert_eq!(config.llm.model_name, "gpt-3.5-turbo");
        assert_eq!(config.llm.base_url, "https://api.openai.com/v1");
        assert_eq!(config.llm.api_key, "your-api-key-here");
        assert_eq!(config.llm.temperature, 0.7);
        assert_eq!(config.llm.max_tokens, 2048);
        assert!(config.local.is_none());
    }

    #[test]
    fn test_config_save() {
        let config = Config::new_default();
        let dir = tempdir().unwrap();
        let config_path = dir.path().join("config.toml");

        config.save(&config_path).unwrap();
        assert!(config_path.exists());

        let loaded_config = Config::from_file(&config_path).unwrap();
        assert_eq!(loaded_config.llm.model_name, config.llm.model_name);
    }

    #[test]
    fn test_default_functions() {
        assert_eq!(default_temperature(), 0.7);
        assert_eq!(default_max_tokens(), 2048);
        assert_eq!(default_context_size(), 2048);
    }
}