helios_engine/config.rs

//! # Configuration Module
//!
//! This module defines the data structures for configuring the Helios Engine.
//! It includes settings for both remote and local Language Models (LLMs),
//! and provides methods for loading and saving configurations from/to TOML files.
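//!
//! A minimal `config.toml` that this module can load might look like the
//! following sketch (values are illustrative, matching the defaults used in
//! [`Config::new_default`]):
//!
//! ```toml
//! [llm]
//! model_name = "gpt-3.5-turbo"
//! base_url = "https://api.openai.com/v1"
//! api_key = "your-api-key-here"
//! temperature = 0.7
//! max_tokens = 2048
//! ```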

use crate::error::{HeliosError, Result};
use serde::{Deserialize, Serialize};
use std::fs;
use std::path::Path;

/// The main configuration for the Helios Engine.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Config {
    /// The configuration for the remote LLM.
    pub llm: LLMConfig,
    /// The configuration for the local LLM (optional).
    #[cfg(feature = "local")]
    #[serde(default)]
    pub local: Option<LocalConfig>,
}

/// Configuration for a remote Language Model (LLM).
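///
/// In TOML this corresponds to the `[llm]` table. `temperature` and
/// `max_tokens` may be omitted and fall back to the serde defaults below;
/// a minimal sketch with only the required fields (values are illustrative):
///
/// ```toml
/// [llm]
/// model_name = "gpt-4"
/// base_url = "https://api.openai.com/v1"
/// api_key = "your-api-key"
/// ```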
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LLMConfig {
    /// The name of the model to use.
    pub model_name: String,
    /// The base URL of the LLM API.
    pub base_url: String,
    /// The API key for the LLM API.
    pub api_key: String,
    /// The temperature to use for the LLM.
    #[serde(default = "default_temperature")]
    pub temperature: f32,
    /// The maximum number of tokens to generate.
    #[serde(default = "default_max_tokens")]
    pub max_tokens: u32,
}

/// Configuration for a local Language Model (LLM).
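///
/// Only compiled when the `local` feature is enabled. In TOML this
/// corresponds to the `[local]` table; a sketch with illustrative values
/// (the same shape used in the tests below), where `temperature` and
/// `max_tokens` may again be omitted:
///
/// ```toml
/// [local]
/// huggingface_repo = "test/repo"
/// model_file = "model.gguf"
/// context_size = 4096
/// ```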
#[cfg(feature = "local")]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LocalConfig {
    /// The Hugging Face repository of the model.
    pub huggingface_repo: String,
    /// The model file to use.
    pub model_file: String,
    /// The context size to use for the LLM.
    #[serde(default = "default_context_size")]
    pub context_size: usize,
    /// The temperature to use for the LLM.
    #[serde(default = "default_temperature")]
    pub temperature: f32,
    /// The maximum number of tokens to generate.
    #[serde(default = "default_max_tokens")]
    pub max_tokens: u32,
}

/// Returns the default temperature value.
fn default_temperature() -> f32 {
    0.7
}

/// Returns the default maximum number of tokens.
fn default_max_tokens() -> u32 {
    2048
}

/// Returns the default context size.
#[cfg(feature = "local")]
fn default_context_size() -> usize {
    2048
}

impl Config {
    /// Loads the configuration from a TOML file.
    ///
    /// # Arguments
    ///
    /// * `path` - The path to the TOML file.
    ///
    /// # Returns
    ///
    /// A `Result` containing the loaded `Config`.
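    ///
    /// # Example
    ///
    /// A minimal sketch, assuming a valid `config.toml` exists at the path:
    ///
    /// ```rust,no_run
    /// use helios_engine::Config;
    /// let config = Config::from_file("config.toml").expect("failed to load config");
    /// ```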
    pub fn from_file<P: AsRef<Path>>(path: P) -> Result<Self> {
        let content = fs::read_to_string(path)
            .map_err(|e| HeliosError::ConfigError(format!("Failed to read config file: {}", e)))?;

        let config: Config = toml::from_str(&content)?;
        Ok(config)
    }

    /// Creates a new default configuration.
    pub fn new_default() -> Self {
        Self {
            llm: LLMConfig {
                model_name: "gpt-3.5-turbo".to_string(),
                base_url: "https://api.openai.com/v1".to_string(),
                api_key: "your-api-key-here".to_string(),
                temperature: 0.7,
                max_tokens: 2048,
            },
            #[cfg(feature = "local")]
            local: None,
        }
    }

    /// Loads configuration from a file or falls back to defaults.
    ///
    /// This is a convenience method that attempts to load from the specified file
    /// and returns the default configuration if the file doesn't exist or can't be read.
    ///
    /// # Arguments
    ///
    /// * `path` - Path to the configuration file (typically "config.toml")
    ///
    /// # Returns
    ///
    /// The loaded configuration, or the default configuration if loading fails.
    ///
    /// # Example
    ///
    /// ```rust,no_run
    /// use helios_engine::Config;
    /// let config = Config::load_or_default("config.toml");
    /// ```
    pub fn load_or_default<P: AsRef<Path>>(path: P) -> Self {
        Self::from_file(path).unwrap_or_else(|_| Self::new_default())
    }

    /// Creates a new configuration builder for fluent initialization.
    ///
    /// # Example
    ///
    /// ```rust,no_run
    /// # use helios_engine::Config;
    /// let config = Config::builder()
    ///     .model("gpt-4")
    ///     .api_key("your-key")
    ///     .build();
    /// ```
    pub fn builder() -> ConfigBuilder {
        ConfigBuilder::new()
    }

    /// Saves the configuration to a TOML file.
    ///
    /// # Arguments
    ///
    /// * `path` - The path to the TOML file.
    ///
    /// # Returns
    ///
    /// A `Result` indicating success or failure.
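    ///
    /// # Example
    ///
    /// A minimal sketch that writes the default configuration to
    /// `config.toml`:
    ///
    /// ```rust,no_run
    /// use helios_engine::Config;
    /// let config = Config::new_default();
    /// config.save("config.toml").expect("failed to save config");
    /// ```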
    pub fn save<P: AsRef<Path>>(&self, path: P) -> Result<()> {
        let content = toml::to_string_pretty(self)
            .map_err(|e| HeliosError::ConfigError(format!("Failed to serialize config: {}", e)))?;

        fs::write(path, content)
            .map_err(|e| HeliosError::ConfigError(format!("Failed to write config file: {}", e)))?;

        Ok(())
    }
}

/// A builder for creating configurations with a fluent API.
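///
/// A sketch chaining the setters together, using the shorthand aliases
/// defined below (each alias simply delegates to its full-length setter):
///
/// ```rust,no_run
/// use helios_engine::Config;
/// let config = Config::builder()
///     .model("gpt-4")
///     .url("https://api.openai.com/v1")
///     .key("your-key")
///     .temp(0.2)
///     .tokens(1024)
///     .build();
/// ```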
pub struct ConfigBuilder {
    model_name: String,
    base_url: String,
    api_key: String,
    temperature: f32,
    max_tokens: u32,
}

impl ConfigBuilder {
    /// Creates a new configuration builder with default values.
    ///
    /// The API key is read from the `OPENAI_API_KEY` environment variable if
    /// it is set; otherwise a placeholder value is used.
    pub fn new() -> Self {
        Self {
            model_name: "gpt-3.5-turbo".to_string(),
            base_url: "https://api.openai.com/v1".to_string(),
            api_key: std::env::var("OPENAI_API_KEY")
                .unwrap_or_else(|_| "your-api-key-here".to_string()),
            temperature: 0.7,
            max_tokens: 2048,
        }
    }

    /// Sets the model name.
    pub fn model(mut self, model: impl Into<String>) -> Self {
        self.model_name = model.into();
        self
    }

    /// Shorthand for [`Self::model`].
    pub fn m(self, model: impl Into<String>) -> Self {
        self.model(model)
    }

    /// Sets the API key.
    pub fn api_key(mut self, key: impl Into<String>) -> Self {
        self.api_key = key.into();
        self
    }

    /// Shorthand for [`Self::api_key`].
    pub fn key(self, key: impl Into<String>) -> Self {
        self.api_key(key)
    }

    /// Sets the base URL for the API.
    pub fn base_url(mut self, url: impl Into<String>) -> Self {
        self.base_url = url.into();
        self
    }

    /// Shorthand for [`Self::base_url`].
    pub fn url(self, url: impl Into<String>) -> Self {
        self.base_url(url)
    }

    /// Sets the temperature for generation.
    pub fn temperature(mut self, temp: f32) -> Self {
        self.temperature = temp;
        self
    }

    /// Shorthand for [`Self::temperature`].
    pub fn temp(self, temp: f32) -> Self {
        self.temperature(temp)
    }

    /// Sets the maximum tokens for generation.
    pub fn max_tokens(mut self, tokens: u32) -> Self {
        self.max_tokens = tokens;
        self
    }

    /// Shorthand for [`Self::max_tokens`].
    pub fn tokens(self, tokens: u32) -> Self {
        self.max_tokens(tokens)
    }

    /// Builds the configuration.
    pub fn build(self) -> Config {
        Config {
            llm: LLMConfig {
                model_name: self.model_name,
                base_url: self.base_url,
                api_key: self.api_key,
                temperature: self.temperature,
                max_tokens: self.max_tokens,
            },
            #[cfg(feature = "local")]
            local: None,
        }
    }
}

impl Default for ConfigBuilder {
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;
    use tempfile::tempdir;

    /// Tests loading a configuration from a file.
    #[test]
    #[cfg(feature = "local")]
    fn test_config_from_file() {
        let config_content = r#"
[llm]
model_name = "gpt-4"
base_url = "https://api.openai.com/v1"
api_key = "test-key"
temperature = 0.7
max_tokens = 2048

[local]
huggingface_repo = "test/repo"
model_file = "model.gguf"
context_size = 4096
temperature = 0.5
max_tokens = 1024
"#;
        let dir = tempdir().unwrap();
        let config_path = dir.path().join("config.toml");
        fs::write(&config_path, config_content).unwrap();

        let config = Config::from_file(&config_path).unwrap();
        assert_eq!(config.llm.model_name, "gpt-4");
        assert_eq!(config.local.as_ref().unwrap().huggingface_repo, "test/repo");
    }

    /// Tests loading a configuration from a file without local config.
    #[test]
    #[cfg(not(feature = "local"))]
    fn test_config_from_file() {
        let config_content = r#"
[llm]
model_name = "gpt-4"
base_url = "https://api.openai.com/v1"
api_key = "test-key"
temperature = 0.7
max_tokens = 2048
"#;
        let dir = tempdir().unwrap();
        let config_path = dir.path().join("config.toml");
        fs::write(&config_path, config_content).unwrap();

        let config = Config::from_file(&config_path).unwrap();
        assert_eq!(config.llm.model_name, "gpt-4");
    }

    /// Tests creating a new default configuration.
    #[test]
    fn test_config_new_default() {
        let config = Config::new_default();
        assert_eq!(config.llm.model_name, "gpt-3.5-turbo");
        assert_eq!(config.llm.base_url, "https://api.openai.com/v1");
        assert_eq!(config.llm.api_key, "your-api-key-here");
        assert_eq!(config.llm.temperature, 0.7);
        assert_eq!(config.llm.max_tokens, 2048);
        #[cfg(feature = "local")]
        assert!(config.local.is_none());
    }

    /// Tests saving a configuration to a file.
    #[test]
    fn test_config_save() {
        let config = Config::new_default();
        let dir = tempdir().unwrap();
        let config_path = dir.path().join("config.toml");

        config.save(&config_path).unwrap();
        assert!(config_path.exists());

        let loaded_config = Config::from_file(&config_path).unwrap();
        assert_eq!(loaded_config.llm.model_name, config.llm.model_name);
    }

    /// Tests the default value functions.
    #[test]
    fn test_default_functions() {
        assert_eq!(default_temperature(), 0.7);
        assert_eq!(default_max_tokens(), 2048);
        #[cfg(feature = "local")]
        assert_eq!(default_context_size(), 2048);
    }
}