helios_engine/config.rs

//! # Configuration Module
//!
//! This module defines the data structures for configuring the Helios Engine.
//! It includes settings for both remote and local Large Language Models (LLMs),
//! and provides methods for loading configurations from, and saving them to, TOML files.
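//!
//! # Example
//!
//! A minimal usage sketch (the crate/module path and file name are illustrative,
//! assumed from the crate layout):
//!
//! ```no_run
//! use helios_engine::config::Config;
//!
//! // Load an existing configuration, or fall back to the defaults.
//! let config = Config::from_file("config.toml")
//!     .unwrap_or_else(|_| Config::new_default());
//!
//! // Persist the (possibly default) configuration back to disk.
//! config.save("config.toml").unwrap();
//! ```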

use crate::error::{HeliosError, Result};
use serde::{Deserialize, Serialize};
use std::fs;
use std::path::Path;

/// The main configuration for the Helios Engine.
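///
/// # Configuration file format
///
/// An illustrative TOML layout, mirroring the fixture used in this module's
/// tests (the `[local]` table may be omitted, and `temperature`/`max_tokens`
/// fall back to their defaults when absent):
///
/// ```toml
/// [llm]
/// model_name = "gpt-4"
/// base_url = "https://api.openai.com/v1"
/// api_key = "your-api-key-here"
/// temperature = 0.7
/// max_tokens = 2048
///
/// [local]
/// huggingface_repo = "test/repo"
/// model_file = "model.gguf"
/// context_size = 4096
/// ```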
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Config {
    /// The configuration for the remote LLM.
    pub llm: LLMConfig,
    /// The configuration for the local LLM (optional).
    #[serde(default)]
    pub local: Option<LocalConfig>,
}

/// Configuration for a remote Large Language Model (LLM).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LLMConfig {
    /// The name of the model to use.
    pub model_name: String,
    /// The base URL of the LLM API.
    pub base_url: String,
    /// The API key for the LLM API.
    pub api_key: String,
    /// The sampling temperature to use for generation.
    #[serde(default = "default_temperature")]
    pub temperature: f32,
    /// The maximum number of tokens to generate.
    #[serde(default = "default_max_tokens")]
    pub max_tokens: u32,
}

/// Configuration for a local Large Language Model (LLM).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LocalConfig {
    /// The Hugging Face repository that hosts the model.
    pub huggingface_repo: String,
    /// The model file within the repository (e.g. a GGUF file).
    pub model_file: String,
    /// The context window size, in tokens.
    #[serde(default = "default_context_size")]
    pub context_size: usize,
    /// The sampling temperature to use for generation.
    #[serde(default = "default_temperature")]
    pub temperature: f32,
    /// The maximum number of tokens to generate.
    #[serde(default = "default_max_tokens")]
    pub max_tokens: u32,
}

/// Returns the default temperature value.
fn default_temperature() -> f32 {
    0.7
}

/// Returns the default maximum number of tokens.
fn default_max_tokens() -> u32 {
    2048
}

/// Returns the default context size.
fn default_context_size() -> usize {
    2048
}

impl Config {
    /// Loads the configuration from a TOML file.
    ///
    /// # Arguments
    ///
    /// * `path` - The path to the TOML file.
    ///
    /// # Returns
    ///
    /// A `Result` containing the loaded `Config`.
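    ///
    /// # Example
    ///
    /// A minimal sketch (the path is illustrative):
    ///
    /// ```no_run
    /// use helios_engine::config::Config;
    ///
    /// let config = Config::from_file("config.toml").unwrap();
    /// println!("model: {}", config.llm.model_name);
    /// ```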
    pub fn from_file<P: AsRef<Path>>(path: P) -> Result<Self> {
        let content = fs::read_to_string(path)
            .map_err(|e| HeliosError::ConfigError(format!("Failed to read config file: {}", e)))?;

        let config: Config = toml::from_str(&content)?;
        Ok(config)
    }

    /// Creates a new default configuration targeting the OpenAI API with a
    /// placeholder API key; replace `api_key` before use.
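    ///
    /// # Example
    ///
    /// A minimal sketch (module path assumed from the crate layout):
    ///
    /// ```
    /// use helios_engine::config::Config;
    ///
    /// let config = Config::new_default();
    /// assert!(config.local.is_none());
    /// ```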
    pub fn new_default() -> Self {
        Self {
            llm: LLMConfig {
                model_name: "gpt-3.5-turbo".to_string(),
                base_url: "https://api.openai.com/v1".to_string(),
                api_key: "your-api-key-here".to_string(),
                temperature: 0.7,
                max_tokens: 2048,
            },
            local: None,
        }
    }

    /// Saves the configuration to a TOML file.
    ///
    /// # Arguments
    ///
    /// * `path` - The path to the TOML file.
    ///
    /// # Returns
    ///
    /// A `Result` indicating success or failure.
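    ///
    /// # Example
    ///
    /// A minimal sketch (the path is illustrative):
    ///
    /// ```no_run
    /// use helios_engine::config::Config;
    ///
    /// let config = Config::new_default();
    /// config.save("config.toml").unwrap();
    /// ```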
    pub fn save<P: AsRef<Path>>(&self, path: P) -> Result<()> {
        let content = toml::to_string_pretty(self)
            .map_err(|e| HeliosError::ConfigError(format!("Failed to serialize config: {}", e)))?;

        fs::write(path, content)
            .map_err(|e| HeliosError::ConfigError(format!("Failed to write config file: {}", e)))?;

        Ok(())
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;
    use tempfile::tempdir;

    /// Tests loading a configuration from a file.
    #[test]
    fn test_config_from_file() {
        let config_content = r#"
[llm]
model_name = "gpt-4"
base_url = "https://api.openai.com/v1"
api_key = "test-key"
temperature = 0.7
max_tokens = 2048

[local]
huggingface_repo = "test/repo"
model_file = "model.gguf"
context_size = 4096
temperature = 0.5
max_tokens = 1024
"#;
        let dir = tempdir().unwrap();
        let config_path = dir.path().join("config.toml");
        fs::write(&config_path, config_content).unwrap();

        let config = Config::from_file(&config_path).unwrap();
        assert_eq!(config.llm.model_name, "gpt-4");
        assert_eq!(config.local.as_ref().unwrap().huggingface_repo, "test/repo");
    }

    /// Tests creating a new default configuration.
    #[test]
    fn test_config_new_default() {
        let config = Config::new_default();
        assert_eq!(config.llm.model_name, "gpt-3.5-turbo");
        assert_eq!(config.llm.base_url, "https://api.openai.com/v1");
        assert_eq!(config.llm.api_key, "your-api-key-here");
        assert_eq!(config.llm.temperature, 0.7);
        assert_eq!(config.llm.max_tokens, 2048);
        assert!(config.local.is_none());
    }

    /// Tests saving a configuration to a file.
    #[test]
    fn test_config_save() {
        let config = Config::new_default();
        let dir = tempdir().unwrap();
        let config_path = dir.path().join("config.toml");

        config.save(&config_path).unwrap();
        assert!(config_path.exists());

        let loaded_config = Config::from_file(&config_path).unwrap();
        assert_eq!(loaded_config.llm.model_name, config.llm.model_name);
    }

    /// Tests the default value functions.
    #[test]
    fn test_default_functions() {
        assert_eq!(default_temperature(), 0.7);
        assert_eq!(default_max_tokens(), 2048);
        assert_eq!(default_context_size(), 2048);
    }
}