// helios-engine 0.5.5
//
// A powerful and flexible Rust framework for building LLM-powered agents with
// tool support, both locally and online. Documentation follows below.
//! # Configuration Module
//!
//! This module defines the data structures for configuring the Helios Engine.
//! It includes settings for both remote and local Language Models (LLMs),
//! and provides methods for loading and saving configurations from/to TOML files.

use crate::error::{HeliosError, Result};
use serde::{Deserialize, Serialize};
use std::fs;
use std::path::Path;

/// The main configuration for the Helios Engine.
///
/// Loaded from and saved to TOML via [`Config::from_file`] and [`Config::save`].
/// The `local` and `candle` sections are only compiled in when their respective
/// Cargo features are enabled.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Config {
    /// The configuration for the remote LLM (the required `[llm]` table).
    pub llm: LLMConfig,
    /// The configuration for the local LLM (optional `[local]` table).
    /// `#[serde(default)]` makes this `None` when the section is absent.
    #[cfg(feature = "local")]
    #[serde(default)]
    pub local: Option<LocalConfig>,
    /// The configuration for the Candle-based local LLM (optional `[candle]` table).
    /// `#[serde(default)]` makes this `None` when the section is absent.
    #[cfg(feature = "candle")]
    #[serde(default)]
    pub candle: Option<CandleConfig>,
}

/// Configuration for a remote Language Model (LLM).
///
/// Corresponds to the `[llm]` table of the TOML configuration file.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LLMConfig {
    /// The name of the model to use (e.g. "gpt-3.5-turbo").
    pub model_name: String,
    /// The base URL of the LLM API (e.g. "https://api.openai.com/v1").
    pub base_url: String,
    /// The API key for the LLM API.
    pub api_key: String,
    /// The sampling temperature to use for the LLM. Defaults to 0.7 when
    /// omitted from the TOML file.
    #[serde(default = "default_temperature")]
    pub temperature: f32,
    /// The maximum number of tokens to generate. Defaults to 2048 when
    /// omitted from the TOML file.
    #[serde(default = "default_max_tokens")]
    pub max_tokens: u32,
}

/// Configuration for a local Language Model (LLM).
///
/// Corresponds to the optional `[local]` table of the TOML configuration file;
/// only compiled in when the `local` Cargo feature is enabled.
#[cfg(feature = "local")]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LocalConfig {
    /// The Hugging Face repository of the model (e.g. "org/repo").
    pub huggingface_repo: String,
    /// The model file to use from that repository.
    pub model_file: String,
    /// The context size to use for the LLM. Defaults to 2048 when omitted.
    #[serde(default = "default_context_size")]
    pub context_size: usize,
    /// The sampling temperature to use for the LLM. Defaults to 0.7 when omitted.
    #[serde(default = "default_temperature")]
    pub temperature: f32,
    /// The maximum number of tokens to generate. Defaults to 2048 when omitted.
    #[serde(default = "default_max_tokens")]
    pub max_tokens: u32,
}

/// Configuration for a Candle-based local Language Model (LLM).
///
/// Corresponds to the optional `[candle]` table of the TOML configuration file;
/// only compiled in when the `candle` Cargo feature is enabled.
#[cfg(feature = "candle")]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CandleConfig {
    /// The Hugging Face repository of the model (e.g. "org/repo").
    pub huggingface_repo: String,
    /// The model file to use (e.g., model.safetensors).
    pub model_file: String,
    /// The context size to use for the LLM. Defaults to 2048 when omitted.
    #[serde(default = "default_context_size")]
    pub context_size: usize,
    /// The sampling temperature to use for the LLM. Defaults to 0.7 when omitted.
    #[serde(default = "default_temperature")]
    pub temperature: f32,
    /// The maximum number of tokens to generate. Defaults to 2048 when omitted.
    #[serde(default = "default_max_tokens")]
    pub max_tokens: u32,
    /// Whether to use the GPU if available. Defaults to `true` when omitted.
    #[serde(default = "default_use_gpu")]
    pub use_gpu: bool,
}

/// Serde default: sampling temperature used when the TOML omits `temperature`.
fn default_temperature() -> f32 {
    const DEFAULT_TEMPERATURE: f32 = 0.7;
    DEFAULT_TEMPERATURE
}

/// Serde default: generation limit used when the TOML omits `max_tokens`.
fn default_max_tokens() -> u32 {
    const DEFAULT_MAX_TOKENS: u32 = 2048;
    DEFAULT_MAX_TOKENS
}

/// Serde default: context size used when the TOML omits `context_size`.
/// Present whenever either local backend feature is enabled.
#[cfg(any(feature = "local", feature = "candle"))]
fn default_context_size() -> usize {
    const DEFAULT_CONTEXT_SIZE: usize = 2048;
    DEFAULT_CONTEXT_SIZE
}

/// Serde default: GPU preference used when the TOML omits `use_gpu`.
#[cfg(feature = "candle")]
fn default_use_gpu() -> bool {
    const DEFAULT_USE_GPU: bool = true;
    DEFAULT_USE_GPU
}

impl Config {
    /// Loads the configuration from a TOML file.
    ///
    /// # Arguments
    ///
    /// * `path` - The path to the TOML file.
    ///
    /// # Returns
    ///
    /// A `Result` containing the loaded `Config`.
    pub fn from_file<P: AsRef<Path>>(path: P) -> Result<Self> {
        let content = fs::read_to_string(path)
            .map_err(|e| HeliosError::ConfigError(format!("Failed to read config file: {}", e)))?;

        let config: Config = toml::from_str(&content)?;
        Ok(config)
    }

    /// Creates a new default configuration.
    pub fn new_default() -> Self {
        Self {
            llm: LLMConfig {
                model_name: "gpt-3.5-turbo".to_string(),
                base_url: "https://api.openai.com/v1".to_string(),
                api_key: "your-api-key-here".to_string(),
                temperature: 0.7,
                max_tokens: 2048,
            },
            #[cfg(feature = "local")]
            local: None,
            #[cfg(feature = "candle")]
            candle: None,
        }
    }

    /// Loads configuration from a file or falls back to defaults.
    ///
    /// This is a convenience method that attempts to load from the specified file
    /// and returns the default configuration if the file doesn't exist or can't be read.
    ///
    /// # Arguments
    ///
    /// * `path` - Path to the configuration file (typically "config.toml")
    ///
    /// # Returns
    ///
    /// The loaded configuration, or default if loading fails
    ///
    /// # Example
    ///
    /// ```rust,no_run
    /// use helios_engine::Config;
    /// let config = Config::load_or_default("config.toml");
    /// ```
    pub fn load_or_default<P: AsRef<Path>>(path: P) -> Self {
        Self::from_file(path).unwrap_or_else(|_| Self::new_default())
    }

    /// Creates a new configuration builder for fluent initialization.
    ///
    /// # Example
    ///
    /// ```rust,no_run
    /// # use helios_engine::Config;
    /// let config = Config::builder()
    ///     .model("gpt-4")
    ///     .api_key("your-key")
    ///     .build();
    /// ```
    pub fn builder() -> ConfigBuilder {
        ConfigBuilder::new()
    }

    /// Saves the configuration to a TOML file.
    ///
    /// # Arguments
    ///
    /// * `path` - The path to the TOML file.
    ///
    /// # Returns
    ///
    /// A `Result` indicating success or failure.
    pub fn save<P: AsRef<Path>>(&self, path: P) -> Result<()> {
        let content = toml::to_string_pretty(self)
            .map_err(|e| HeliosError::ConfigError(format!("Failed to serialize config: {}", e)))?;

        fs::write(path, content)
            .map_err(|e| HeliosError::ConfigError(format!("Failed to write config file: {}", e)))?;

        Ok(())
    }
}

/// A builder for creating configurations with a fluent API.
///
/// Construct via [`Config::builder`] or [`ConfigBuilder::new`], chain the
/// setter methods, then call [`ConfigBuilder::build`] to obtain a [`Config`].
pub struct ConfigBuilder {
    // Staged values for the resulting `LLMConfig`.
    model_name: String,
    base_url: String,
    api_key: String,
    temperature: f32,
    max_tokens: u32,
}

impl ConfigBuilder {
    /// Creates a new configuration builder with default values.
    pub fn new() -> Self {
        Self {
            model_name: "gpt-3.5-turbo".to_string(),
            base_url: "https://api.openai.com/v1".to_string(),
            api_key: std::env::var("OPENAI_API_KEY")
                .unwrap_or_else(|_| "your-api-key-here".to_string()),
            temperature: 0.7,
            max_tokens: 2048,
        }
    }

    /// Sets the model name.
    pub fn model(mut self, model: impl Into<String>) -> Self {
        self.model_name = model.into();
        self
    }

    /// Shorthand: set model with 'm'
    pub fn m(self, model: impl Into<String>) -> Self {
        self.model(model)
    }

    /// Sets the API key.
    pub fn api_key(mut self, key: impl Into<String>) -> Self {
        self.api_key = key.into();
        self
    }

    /// Shorthand: set API key with 'key'
    pub fn key(self, key: impl Into<String>) -> Self {
        self.api_key(key)
    }

    /// Sets the base URL for the API.
    pub fn base_url(mut self, url: impl Into<String>) -> Self {
        self.base_url = url.into();
        self
    }

    /// Shorthand: set base URL with 'url'
    pub fn url(self, url: impl Into<String>) -> Self {
        self.base_url(url)
    }

    /// Sets the temperature for generation.
    pub fn temperature(mut self, temp: f32) -> Self {
        self.temperature = temp;
        self
    }

    /// Shorthand: set temperature with 'temp'
    pub fn temp(self, temp: f32) -> Self {
        self.temperature(temp)
    }

    /// Sets the maximum tokens for generation.
    pub fn max_tokens(mut self, tokens: u32) -> Self {
        self.max_tokens = tokens;
        self
    }

    /// Shorthand: set max tokens with 'tokens'
    pub fn tokens(self, tokens: u32) -> Self {
        self.max_tokens(tokens)
    }

    /// Builds the configuration.
    pub fn build(self) -> Config {
        Config {
            llm: LLMConfig {
                model_name: self.model_name,
                base_url: self.base_url,
                api_key: self.api_key,
                temperature: self.temperature,
                max_tokens: self.max_tokens,
            },
            #[cfg(feature = "local")]
            local: None,
            #[cfg(feature = "candle")]
            candle: None,
        }
    }
}

impl Default for ConfigBuilder {
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;
    use tempfile::tempdir;

    /// Tests loading a configuration (including the `[local]` table) from a file.
    #[test]
    #[cfg(feature = "local")]
    fn test_config_from_file() {
        let config_content = r#"
[llm]
model_name = "gpt-4"
base_url = "https://api.openai.com/v1"
api_key = "test-key"
temperature = 0.7
max_tokens = 2048

[local]
huggingface_repo = "test/repo"
model_file = "model.gguf"
context_size = 4096
temperature = 0.5
max_tokens = 1024
"#;
        let dir = tempdir().unwrap();
        let config_path = dir.path().join("config.toml");
        fs::write(&config_path, config_content).unwrap();

        let config = Config::from_file(&config_path).unwrap();
        assert_eq!(config.llm.model_name, "gpt-4");
        assert_eq!(config.local.as_ref().unwrap().huggingface_repo, "test/repo");
    }

    /// Tests loading a configuration from a file without local config.
    #[test]
    #[cfg(not(feature = "local"))]
    fn test_config_from_file() {
        let config_content = r#"
[llm]
model_name = "gpt-4"
base_url = "https://api.openai.com/v1"
api_key = "test-key"
temperature = 0.7
max_tokens = 2048
"#;
        let dir = tempdir().unwrap();
        let config_path = dir.path().join("config.toml");
        fs::write(&config_path, config_content).unwrap();

        let config = Config::from_file(&config_path).unwrap();
        assert_eq!(config.llm.model_name, "gpt-4");
    }

    /// Tests creating a new default configuration.
    #[test]
    fn test_config_new_default() {
        let config = Config::new_default();
        assert_eq!(config.llm.model_name, "gpt-3.5-turbo");
        assert_eq!(config.llm.base_url, "https://api.openai.com/v1");
        assert_eq!(config.llm.api_key, "your-api-key-here");
        assert_eq!(config.llm.temperature, 0.7);
        assert_eq!(config.llm.max_tokens, 2048);
        #[cfg(feature = "local")]
        assert!(config.local.is_none());
    }

    /// Tests round-tripping a configuration through save and load.
    #[test]
    fn test_config_save() {
        let config = Config::new_default();
        let dir = tempdir().unwrap();
        let config_path = dir.path().join("config.toml");

        config.save(&config_path).unwrap();
        assert!(config_path.exists());

        let loaded_config = Config::from_file(&config_path).unwrap();
        assert_eq!(loaded_config.llm.model_name, config.llm.model_name);
    }

    /// Tests the default value functions.
    #[test]
    fn test_default_functions() {
        assert_eq!(default_temperature(), 0.7);
        assert_eq!(default_max_tokens(), 2048);
        // `default_context_size` is compiled under any(local, candle), not just
        // `local` — gate the assertion the same way so candle-only builds cover it.
        #[cfg(any(feature = "local", feature = "candle"))]
        assert_eq!(default_context_size(), 2048);
        // Also cover the candle-only GPU default, previously untested.
        #[cfg(feature = "candle")]
        assert!(default_use_gpu());
    }
}