helios_engine/config.rs

//! # Configuration Module
//!
//! This module defines the data structures used to configure the Helios Engine.
//! It covers settings for both remote and local Language Models (LLMs) and
//! provides methods for loading configurations from, and saving them to, TOML files.
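//!
//! # Example
//!
//! A minimal sketch of the save/load round trip. This assumes the crate is
//! consumed as `helios_engine` and that this module is reachable at
//! `helios_engine::config`; adjust the path to match the actual crate layout.
//!
//! ```no_run
//! use helios_engine::config::Config;
//!
//! // Build a default config, persist it, and read it back.
//! let config = Config::new_default();
//! config.save("config.toml").expect("failed to save config");
//!
//! let loaded = Config::from_file("config.toml").expect("failed to load config");
//! assert_eq!(loaded.llm.model_name, "gpt-3.5-turbo");
//! ```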

use crate::error::{HeliosError, Result};
use serde::{Deserialize, Serialize};
use std::fs;
use std::path::Path;

/// The main configuration for the Helios Engine.
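///
/// Corresponds to the top-level tables of the TOML config file. A minimal
/// example (values are the same placeholders used by [`Config::new_default`]):
///
/// ```toml
/// [llm]
/// model_name = "gpt-3.5-turbo"
/// base_url = "https://api.openai.com/v1"
/// api_key = "your-api-key-here"
/// temperature = 0.7
/// max_tokens = 2048
/// ```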
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Config {
    /// The configuration for the remote LLM.
    pub llm: LLMConfig,
    /// The configuration for the local LLM (optional).
    #[cfg(feature = "local")]
    #[serde(default)]
    pub local: Option<LocalConfig>,
}

/// Configuration for a remote Language Model (LLM).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LLMConfig {
    /// The name of the model to use.
    pub model_name: String,
    /// The base URL of the LLM API.
    pub base_url: String,
    /// The API key for the LLM API.
    pub api_key: String,
    /// The sampling temperature to use for generation.
    #[serde(default = "default_temperature")]
    pub temperature: f32,
    /// The maximum number of tokens to generate.
    #[serde(default = "default_max_tokens")]
    pub max_tokens: u32,
}

/// Configuration for a local Language Model (LLM).
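///
/// Deserialized from the optional `[local]` table. A sketch of the expected
/// shape (values are illustrative, taken from the tests below):
///
/// ```toml
/// [local]
/// huggingface_repo = "test/repo"
/// model_file = "model.gguf"
/// context_size = 4096
/// temperature = 0.5
/// max_tokens = 1024
/// ```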
#[cfg(feature = "local")]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LocalConfig {
    /// The Hugging Face repository of the model.
    pub huggingface_repo: String,
    /// The model file to load from the repository (e.g. `model.gguf`).
    pub model_file: String,
    /// The context window size, in tokens.
    #[serde(default = "default_context_size")]
    pub context_size: usize,
    /// The sampling temperature to use for generation.
    #[serde(default = "default_temperature")]
    pub temperature: f32,
    /// The maximum number of tokens to generate.
    #[serde(default = "default_max_tokens")]
    pub max_tokens: u32,
}

/// Returns the default temperature value.
fn default_temperature() -> f32 {
    0.7
}

/// Returns the default maximum number of tokens.
fn default_max_tokens() -> u32 {
    2048
}

/// Returns the default context size.
#[cfg(feature = "local")]
fn default_context_size() -> usize {
    2048
}

impl Config {
    /// Loads the configuration from a TOML file.
    ///
    /// # Arguments
    ///
    /// * `path` - The path to the TOML file.
    ///
    /// # Returns
    ///
    /// A `Result` containing the loaded `Config`, or an error if the file
    /// cannot be read or parsed.
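    ///
    /// # Example
    ///
    /// A minimal sketch (assumes a valid `config.toml` exists; the module
    /// path `helios_engine::config` is an assumption about the crate layout):
    ///
    /// ```no_run
    /// use helios_engine::config::Config;
    ///
    /// let config = Config::from_file("config.toml").expect("failed to load config");
    /// println!("using model: {}", config.llm.model_name);
    /// ```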
    pub fn from_file<P: AsRef<Path>>(path: P) -> Result<Self> {
        let content = fs::read_to_string(path)
            .map_err(|e| HeliosError::ConfigError(format!("Failed to read config file: {}", e)))?;

        let config: Config = toml::from_str(&content)?;
        Ok(config)
    }

    /// Creates a new configuration populated with default values.
    ///
    /// The `api_key` field is a placeholder and must be replaced with a real
    /// key before the remote LLM can be used.
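    ///
    /// # Example
    ///
    /// A minimal sketch (module path `helios_engine::config` assumed):
    ///
    /// ```
    /// use helios_engine::config::Config;
    ///
    /// let config = Config::new_default();
    /// assert_eq!(config.llm.model_name, "gpt-3.5-turbo");
    /// ```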
    pub fn new_default() -> Self {
        Self {
            llm: LLMConfig {
                model_name: "gpt-3.5-turbo".to_string(),
                base_url: "https://api.openai.com/v1".to_string(),
                api_key: "your-api-key-here".to_string(),
                temperature: 0.7,
                max_tokens: 2048,
            },
            #[cfg(feature = "local")]
            local: None,
        }
    }

    /// Saves the configuration to a TOML file.
    ///
    /// # Arguments
    ///
    /// * `path` - The path to the TOML file.
    ///
    /// # Returns
    ///
    /// A `Result` indicating success, or an error if serialization or the
    /// file write fails.
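    ///
    /// # Example
    ///
    /// A minimal sketch (assumes the target path is writable; module path
    /// `helios_engine::config` assumed):
    ///
    /// ```no_run
    /// use helios_engine::config::Config;
    ///
    /// let config = Config::new_default();
    /// config.save("config.toml").expect("failed to save config");
    /// ```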
    pub fn save<P: AsRef<Path>>(&self, path: P) -> Result<()> {
        let content = toml::to_string_pretty(self)
            .map_err(|e| HeliosError::ConfigError(format!("Failed to serialize config: {}", e)))?;

        fs::write(path, content)
            .map_err(|e| HeliosError::ConfigError(format!("Failed to write config file: {}", e)))?;

        Ok(())
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;
    use tempfile::tempdir;

    /// Tests loading a configuration from a file.
    #[test]
    #[cfg(feature = "local")]
    fn test_config_from_file() {
        let config_content = r#"
[llm]
model_name = "gpt-4"
base_url = "https://api.openai.com/v1"
api_key = "test-key"
temperature = 0.7
max_tokens = 2048

[local]
huggingface_repo = "test/repo"
model_file = "model.gguf"
context_size = 4096
temperature = 0.5
max_tokens = 1024
"#;
        let dir = tempdir().unwrap();
        let config_path = dir.path().join("config.toml");
        fs::write(&config_path, config_content).unwrap();

        let config = Config::from_file(&config_path).unwrap();
        assert_eq!(config.llm.model_name, "gpt-4");
        assert_eq!(config.local.as_ref().unwrap().huggingface_repo, "test/repo");
    }

    /// Tests loading a configuration from a file without local config.
    #[test]
    #[cfg(not(feature = "local"))]
    fn test_config_from_file() {
        let config_content = r#"
[llm]
model_name = "gpt-4"
base_url = "https://api.openai.com/v1"
api_key = "test-key"
temperature = 0.7
max_tokens = 2048
"#;
        let dir = tempdir().unwrap();
        let config_path = dir.path().join("config.toml");
        fs::write(&config_path, config_content).unwrap();

        let config = Config::from_file(&config_path).unwrap();
        assert_eq!(config.llm.model_name, "gpt-4");
    }

    /// Tests creating a new default configuration.
    #[test]
    fn test_config_new_default() {
        let config = Config::new_default();
        assert_eq!(config.llm.model_name, "gpt-3.5-turbo");
        assert_eq!(config.llm.base_url, "https://api.openai.com/v1");
        assert_eq!(config.llm.api_key, "your-api-key-here");
        assert_eq!(config.llm.temperature, 0.7);
        assert_eq!(config.llm.max_tokens, 2048);
        #[cfg(feature = "local")]
        assert!(config.local.is_none());
    }

    /// Tests saving a configuration to a file.
    #[test]
    fn test_config_save() {
        let config = Config::new_default();
        let dir = tempdir().unwrap();
        let config_path = dir.path().join("config.toml");

        config.save(&config_path).unwrap();
        assert!(config_path.exists());

        let loaded_config = Config::from_file(&config_path).unwrap();
        assert_eq!(loaded_config.llm.model_name, config.llm.model_name);
    }

    /// Tests the default value functions.
    #[test]
    fn test_default_functions() {
        assert_eq!(default_temperature(), 0.7);
        assert_eq!(default_max_tokens(), 2048);
        #[cfg(feature = "local")]
        assert_eq!(default_context_size(), 2048);
216    }
217}