helios_engine/config.rs

//! # Configuration Module
//!
//! This module defines the data structures for configuring the Helios Engine.
//! It includes settings for both remote and local Language Models (LLMs),
//! and provides methods for loading and saving configurations to and from TOML files.
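//!
//! # Example
//!
//! A minimal sketch of loading a configuration, assuming a `config.toml`
//! in the working directory (the path is illustrative):
//!
//! ```rust,no_run
//! use helios_engine::Config;
//!
//! // Falls back to the built-in defaults if the file is missing or invalid.
//! let config = Config::load_or_default("config.toml");
//! println!("model: {}", config.llm.model_name);
//! ```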

use crate::error::{HeliosError, Result};
use serde::{Deserialize, Serialize};
use std::fs;
use std::path::Path;

/// The main configuration for the Helios Engine.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Config {
    /// The configuration for the remote LLM.
    pub llm: LLMConfig,
    /// The configuration for the local LLM (optional).
    #[cfg(feature = "local")]
    #[serde(default)]
    pub local: Option<LocalConfig>,
    /// The configuration for the Candle-based local LLM (optional).
    #[cfg(feature = "candle")]
    #[serde(default)]
    pub candle: Option<CandleConfig>,
}

/// Configuration for a remote Language Model (LLM).
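///
/// Deserialized from the `[llm]` table of a TOML config file. A minimal
/// sketch (values are illustrative; `temperature` and `max_tokens` fall
/// back to their defaults when omitted):
///
/// ```toml
/// [llm]
/// model_name = "gpt-4"
/// base_url = "https://api.openai.com/v1"
/// api_key = "your-api-key-here"
/// ```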
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LLMConfig {
    /// The name of the model to use.
    pub model_name: String,
    /// The base URL of the LLM API.
    pub base_url: String,
    /// The API key for the LLM API.
    pub api_key: String,
    /// The temperature to use for the LLM.
    #[serde(default = "default_temperature")]
    pub temperature: f32,
    /// The maximum number of tokens to generate.
    #[serde(default = "default_max_tokens")]
    pub max_tokens: u32,
}

/// Configuration for a local Language Model (LLM).
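///
/// Deserialized from the optional `[local]` table of a TOML config file.
/// A minimal sketch (values are illustrative):
///
/// ```toml
/// [local]
/// huggingface_repo = "test/repo"
/// model_file = "model.gguf"
/// ```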
#[cfg(feature = "local")]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LocalConfig {
    /// The Hugging Face repository of the model.
    pub huggingface_repo: String,
    /// The model file to use.
    pub model_file: String,
    /// The context size to use for the LLM.
    #[serde(default = "default_context_size")]
    pub context_size: usize,
    /// The temperature to use for the LLM.
    #[serde(default = "default_temperature")]
    pub temperature: f32,
    /// The maximum number of tokens to generate.
    #[serde(default = "default_max_tokens")]
    pub max_tokens: u32,
}

/// Configuration for a Candle-based local Language Model (LLM).
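///
/// Deserialized from the optional `[candle]` table of a TOML config file.
/// A minimal sketch (values are illustrative):
///
/// ```toml
/// [candle]
/// huggingface_repo = "test/repo"
/// model_file = "model.safetensors"
/// use_gpu = true
/// ```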
#[cfg(feature = "candle")]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CandleConfig {
    /// The Hugging Face repository of the model.
    pub huggingface_repo: String,
    /// The model file to use (e.g., model.safetensors).
    pub model_file: String,
    /// The context size to use for the LLM.
    #[serde(default = "default_context_size")]
    pub context_size: usize,
    /// The temperature to use for the LLM.
    #[serde(default = "default_temperature")]
    pub temperature: f32,
    /// The maximum number of tokens to generate.
    #[serde(default = "default_max_tokens")]
    pub max_tokens: u32,
    /// Whether to use the GPU, if available.
    #[serde(default = "default_use_gpu")]
    pub use_gpu: bool,
}

/// Returns the default temperature value.
fn default_temperature() -> f32 {
    0.7
}

/// Returns the default maximum number of tokens.
fn default_max_tokens() -> u32 {
    2048
}

/// Returns the default context size.
#[cfg(any(feature = "local", feature = "candle"))]
fn default_context_size() -> usize {
    2048
}

/// Returns the default `use_gpu` setting.
#[cfg(feature = "candle")]
fn default_use_gpu() -> bool {
    true
}

impl Config {
    /// Loads the configuration from a TOML file.
    ///
    /// # Arguments
    ///
    /// * `path` - The path to the TOML file.
    ///
    /// # Returns
    ///
    /// A `Result` containing the loaded `Config`.
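    ///
    /// # Example
    ///
    /// A minimal sketch, assuming a valid `config.toml` exists at the given path:
    ///
    /// ```rust,no_run
    /// use helios_engine::Config;
    ///
    /// let config = Config::from_file("config.toml").expect("failed to load config");
    /// ```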
    pub fn from_file<P: AsRef<Path>>(path: P) -> Result<Self> {
        let content = fs::read_to_string(path)
            .map_err(|e| HeliosError::ConfigError(format!("Failed to read config file: {}", e)))?;

        let config: Config = toml::from_str(&content)?;
        Ok(config)
    }

    /// Creates a new default configuration.
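    ///
    /// # Example
    ///
    /// A quick sketch of the built-in defaults:
    ///
    /// ```rust,no_run
    /// use helios_engine::Config;
    ///
    /// let config = Config::new_default();
    /// assert_eq!(config.llm.model_name, "gpt-3.5-turbo");
    /// ```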
    pub fn new_default() -> Self {
        Self {
            llm: LLMConfig {
                model_name: "gpt-3.5-turbo".to_string(),
                base_url: "https://api.openai.com/v1".to_string(),
                api_key: "your-api-key-here".to_string(),
                temperature: 0.7,
                max_tokens: 2048,
            },
            #[cfg(feature = "local")]
            local: None,
            #[cfg(feature = "candle")]
            candle: None,
        }
    }

    /// Loads configuration from a file or falls back to defaults.
    ///
    /// This is a convenience method that attempts to load from the specified file
    /// and returns the default configuration if the file doesn't exist, can't be
    /// read, or can't be parsed.
    ///
    /// # Arguments
    ///
    /// * `path` - Path to the configuration file (typically "config.toml")
    ///
    /// # Returns
    ///
    /// The loaded configuration, or the default if loading fails
    ///
    /// # Example
    ///
    /// ```rust,no_run
    /// use helios_engine::Config;
    /// let config = Config::load_or_default("config.toml");
    /// ```
    pub fn load_or_default<P: AsRef<Path>>(path: P) -> Self {
        Self::from_file(path).unwrap_or_else(|_| Self::new_default())
    }

    /// Creates a new configuration builder for fluent initialization.
    ///
    /// # Example
    ///
    /// ```rust,no_run
    /// # use helios_engine::Config;
    /// let config = Config::builder()
    ///     .model("gpt-4")
    ///     .api_key("your-key")
    ///     .build();
    /// ```
    pub fn builder() -> ConfigBuilder {
        ConfigBuilder::new()
    }

    /// Saves the configuration to a TOML file.
    ///
    /// # Arguments
    ///
    /// * `path` - The path to the TOML file.
    ///
    /// # Returns
    ///
    /// A `Result` indicating success or failure.
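    ///
    /// # Example
    ///
    /// A minimal sketch (the output path is illustrative):
    ///
    /// ```rust,no_run
    /// use helios_engine::Config;
    ///
    /// let config = Config::new_default();
    /// config.save("config.toml").expect("failed to save config");
    /// ```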
    pub fn save<P: AsRef<Path>>(&self, path: P) -> Result<()> {
        let content = toml::to_string_pretty(self)
            .map_err(|e| HeliosError::ConfigError(format!("Failed to serialize config: {}", e)))?;

        fs::write(path, content)
            .map_err(|e| HeliosError::ConfigError(format!("Failed to write config file: {}", e)))?;

        Ok(())
    }
}

/// A builder for creating configurations with a fluent API.
pub struct ConfigBuilder {
    model_name: String,
    base_url: String,
    api_key: String,
    temperature: f32,
    max_tokens: u32,
}

impl ConfigBuilder {
    /// Creates a new configuration builder with default values.
    ///
    /// The API key is taken from the `OPENAI_API_KEY` environment variable if set.
    pub fn new() -> Self {
        Self {
            model_name: "gpt-3.5-turbo".to_string(),
            base_url: "https://api.openai.com/v1".to_string(),
            api_key: std::env::var("OPENAI_API_KEY")
                .unwrap_or_else(|_| "your-api-key-here".to_string()),
            temperature: 0.7,
            max_tokens: 2048,
        }
    }

    /// Sets the model name.
    pub fn model(mut self, model: impl Into<String>) -> Self {
        self.model_name = model.into();
        self
    }

    /// Shorthand for [`Self::model`].
    pub fn m(self, model: impl Into<String>) -> Self {
        self.model(model)
    }

    /// Sets the API key.
    pub fn api_key(mut self, key: impl Into<String>) -> Self {
        self.api_key = key.into();
        self
    }

    /// Shorthand for [`Self::api_key`].
    pub fn key(self, key: impl Into<String>) -> Self {
        self.api_key(key)
    }

    /// Sets the base URL for the API.
    pub fn base_url(mut self, url: impl Into<String>) -> Self {
        self.base_url = url.into();
        self
    }

    /// Shorthand for [`Self::base_url`].
    pub fn url(self, url: impl Into<String>) -> Self {
        self.base_url(url)
    }

    /// Sets the temperature for generation.
    pub fn temperature(mut self, temp: f32) -> Self {
        self.temperature = temp;
        self
    }

    /// Shorthand for [`Self::temperature`].
    pub fn temp(self, temp: f32) -> Self {
        self.temperature(temp)
    }

    /// Sets the maximum number of tokens for generation.
    pub fn max_tokens(mut self, tokens: u32) -> Self {
        self.max_tokens = tokens;
        self
    }

    /// Shorthand for [`Self::max_tokens`].
    pub fn tokens(self, tokens: u32) -> Self {
        self.max_tokens(tokens)
    }

    /// Builds the configuration.
    pub fn build(self) -> Config {
        Config {
            llm: LLMConfig {
                model_name: self.model_name,
                base_url: self.base_url,
                api_key: self.api_key,
                temperature: self.temperature,
                max_tokens: self.max_tokens,
            },
            #[cfg(feature = "local")]
            local: None,
            #[cfg(feature = "candle")]
            candle: None,
        }
    }
}

impl Default for ConfigBuilder {
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;
    use tempfile::tempdir;

    /// Tests loading a configuration from a file.
    #[test]
    #[cfg(feature = "local")]
    fn test_config_from_file() {
        let config_content = r#"
[llm]
model_name = "gpt-4"
base_url = "https://api.openai.com/v1"
api_key = "test-key"
temperature = 0.7
max_tokens = 2048

[local]
huggingface_repo = "test/repo"
model_file = "model.gguf"
context_size = 4096
temperature = 0.5
max_tokens = 1024
"#;
        let dir = tempdir().unwrap();
        let config_path = dir.path().join("config.toml");
        fs::write(&config_path, config_content).unwrap();

        let config = Config::from_file(&config_path).unwrap();
        assert_eq!(config.llm.model_name, "gpt-4");
        assert_eq!(config.local.as_ref().unwrap().huggingface_repo, "test/repo");
    }

    /// Tests loading a configuration from a file without local config.
    #[test]
    #[cfg(not(feature = "local"))]
    fn test_config_from_file() {
        let config_content = r#"
[llm]
model_name = "gpt-4"
base_url = "https://api.openai.com/v1"
api_key = "test-key"
temperature = 0.7
max_tokens = 2048
"#;
        let dir = tempdir().unwrap();
        let config_path = dir.path().join("config.toml");
        fs::write(&config_path, config_content).unwrap();

        let config = Config::from_file(&config_path).unwrap();
        assert_eq!(config.llm.model_name, "gpt-4");
    }

    /// Tests creating a new default configuration.
    #[test]
    fn test_config_new_default() {
        let config = Config::new_default();
        assert_eq!(config.llm.model_name, "gpt-3.5-turbo");
        assert_eq!(config.llm.base_url, "https://api.openai.com/v1");
        assert_eq!(config.llm.api_key, "your-api-key-here");
        assert_eq!(config.llm.temperature, 0.7);
        assert_eq!(config.llm.max_tokens, 2048);
        #[cfg(feature = "local")]
        assert!(config.local.is_none());
    }

    /// Tests saving a configuration to a file.
    #[test]
    fn test_config_save() {
        let config = Config::new_default();
        let dir = tempdir().unwrap();
        let config_path = dir.path().join("config.toml");

        config.save(&config_path).unwrap();
        assert!(config_path.exists());

        let loaded_config = Config::from_file(&config_path).unwrap();
        assert_eq!(loaded_config.llm.model_name, config.llm.model_name);
    }

    /// Tests the default value functions.
    #[test]
    fn test_default_functions() {
        assert_eq!(default_temperature(), 0.7);
        assert_eq!(default_max_tokens(), 2048);
        #[cfg(feature = "local")]
        assert_eq!(default_context_size(), 2048);
    }
}