use crate::error::{HeliosError, Result};
use serde::{Deserialize, Serialize};
use std::fs;
use std::path::Path;
/// Top-level application configuration, deserialized from a TOML file.
///
/// The `[llm]` table is required; the feature-gated `[local]` and `[candle]`
/// tables are optional and deserialize to `None` when absent
/// (`#[serde(default)]`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Config {
    /// Remote LLM endpoint settings (`[llm]` table).
    pub llm: LLMConfig,
    /// Optional `[local]` table; only compiled with the `local` feature.
    #[cfg(feature = "local")]
    #[serde(default)]
    pub local: Option<LocalConfig>,
    /// Optional `[candle]` table; only compiled with the `candle` feature.
    #[cfg(feature = "candle")]
    #[serde(default)]
    pub candle: Option<CandleConfig>,
}
/// Settings for a remote LLM API endpoint (the `[llm]` table).
///
/// Defaults elsewhere in this file point at an OpenAI-compatible API
/// (`https://api.openai.com/v1`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LLMConfig {
    /// Model identifier sent to the API (e.g. "gpt-3.5-turbo").
    pub model_name: String,
    /// Base URL of the API endpoint.
    pub base_url: String,
    /// API key. NOTE(review): stored as plain text in the config file —
    /// confirm this is acceptable for the deployment model.
    pub api_key: String,
    /// Sampling temperature; defaults to 0.7 when omitted from the file.
    #[serde(default = "default_temperature")]
    pub temperature: f32,
    /// Maximum number of tokens to generate; defaults to 2048 when omitted.
    #[serde(default = "default_max_tokens")]
    pub max_tokens: u32,
}
/// Settings for a local model (the `[local]` table).
///
/// Only compiled with the `local` feature. The repo/file pair presumably
/// identifies a model artifact on Hugging Face fetched by the local
/// backend — confirm against the loader.
#[cfg(feature = "local")]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LocalConfig {
    /// Hugging Face repository id (e.g. "owner/repo").
    pub huggingface_repo: String,
    /// Model file name within the repository (tests use a `.gguf` file).
    pub model_file: String,
    /// Context window size in tokens; defaults to 2048 when omitted.
    #[serde(default = "default_context_size")]
    pub context_size: usize,
    /// Sampling temperature; defaults to 0.7 when omitted.
    #[serde(default = "default_temperature")]
    pub temperature: f32,
    /// Maximum number of tokens to generate; defaults to 2048 when omitted.
    #[serde(default = "default_max_tokens")]
    pub max_tokens: u32,
}
/// Settings for candle-based local inference (the `[candle]` table).
///
/// Only compiled with the `candle` feature. Mirrors [`LocalConfig`] plus a
/// GPU toggle.
#[cfg(feature = "candle")]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CandleConfig {
    /// Hugging Face repository id (e.g. "owner/repo").
    pub huggingface_repo: String,
    /// Model file name within the repository.
    pub model_file: String,
    /// Context window size in tokens; defaults to 2048 when omitted.
    #[serde(default = "default_context_size")]
    pub context_size: usize,
    /// Sampling temperature; defaults to 0.7 when omitted.
    #[serde(default = "default_temperature")]
    pub temperature: f32,
    /// Maximum number of tokens to generate; defaults to 2048 when omitted.
    #[serde(default = "default_max_tokens")]
    pub max_tokens: u32,
    /// Whether to run inference on the GPU; defaults to `true` when omitted.
    #[serde(default = "default_use_gpu")]
    pub use_gpu: bool,
}
/// Serde fallback used when a `temperature` field is omitted from the file.
fn default_temperature() -> f32 {
    0.7
}
/// Serde fallback used when a `max_tokens` field is omitted from the file.
fn default_max_tokens() -> u32 {
    2048
}
/// Serde fallback for `context_size` on the local/candle model configs.
#[cfg(any(feature = "local", feature = "candle"))]
fn default_context_size() -> usize {
    2048
}
/// Serde fallback for `use_gpu`: prefer the GPU when candle is enabled.
#[cfg(feature = "candle")]
fn default_use_gpu() -> bool {
    true
}
impl Config {
pub fn from_file<P: AsRef<Path>>(path: P) -> Result<Self> {
let content = fs::read_to_string(path)
.map_err(|e| HeliosError::ConfigError(format!("Failed to read config file: {}", e)))?;
let config: Config = toml::from_str(&content)?;
Ok(config)
}
pub fn new_default() -> Self {
Self {
llm: LLMConfig {
model_name: "gpt-3.5-turbo".to_string(),
base_url: "https://api.openai.com/v1".to_string(),
api_key: "your-api-key-here".to_string(),
temperature: 0.7,
max_tokens: 2048,
},
#[cfg(feature = "local")]
local: None,
#[cfg(feature = "candle")]
candle: None,
}
}
pub fn load_or_default<P: AsRef<Path>>(path: P) -> Self {
Self::from_file(path).unwrap_or_else(|_| Self::new_default())
}
pub fn builder() -> ConfigBuilder {
ConfigBuilder::new()
}
pub fn save<P: AsRef<Path>>(&self, path: P) -> Result<()> {
let content = toml::to_string_pretty(self)
.map_err(|e| HeliosError::ConfigError(format!("Failed to serialize config: {}", e)))?;
fs::write(path, content)
.map_err(|e| HeliosError::ConfigError(format!("Failed to write config file: {}", e)))?;
Ok(())
}
}
/// Fluent builder for [`Config`].
///
/// Covers only the `[llm]` settings; the feature-gated `local`/`candle`
/// sections are always `None` in the built config. Fields mirror
/// [`LLMConfig`] — see that struct for their meanings.
pub struct ConfigBuilder {
    model_name: String,
    base_url: String,
    api_key: String,
    temperature: f32,
    max_tokens: u32,
}
impl ConfigBuilder {
    /// Create a builder seeded with the standard defaults. The API key is
    /// read from the `OPENAI_API_KEY` environment variable when set,
    /// otherwise a placeholder is used.
    pub fn new() -> Self {
        let api_key = std::env::var("OPENAI_API_KEY")
            .unwrap_or_else(|_| "your-api-key-here".to_string());
        Self {
            model_name: "gpt-3.5-turbo".to_string(),
            base_url: "https://api.openai.com/v1".to_string(),
            api_key,
            temperature: 0.7,
            max_tokens: 2048,
        }
    }

    /// Set the model name.
    pub fn model(self, model: impl Into<String>) -> Self {
        Self {
            model_name: model.into(),
            ..self
        }
    }

    /// Short alias for [`ConfigBuilder::model`].
    pub fn m(self, model: impl Into<String>) -> Self {
        self.model(model)
    }

    /// Set the API key.
    pub fn api_key(self, key: impl Into<String>) -> Self {
        Self {
            api_key: key.into(),
            ..self
        }
    }

    /// Short alias for [`ConfigBuilder::api_key`].
    pub fn key(self, key: impl Into<String>) -> Self {
        self.api_key(key)
    }

    /// Set the API base URL.
    pub fn base_url(self, url: impl Into<String>) -> Self {
        Self {
            base_url: url.into(),
            ..self
        }
    }

    /// Short alias for [`ConfigBuilder::base_url`].
    pub fn url(self, url: impl Into<String>) -> Self {
        self.base_url(url)
    }

    /// Set the sampling temperature.
    pub fn temperature(self, temp: f32) -> Self {
        Self {
            temperature: temp,
            ..self
        }
    }

    /// Short alias for [`ConfigBuilder::temperature`].
    pub fn temp(self, temp: f32) -> Self {
        self.temperature(temp)
    }

    /// Set the maximum number of tokens to generate.
    pub fn max_tokens(self, tokens: u32) -> Self {
        Self {
            max_tokens: tokens,
            ..self
        }
    }

    /// Short alias for [`ConfigBuilder::max_tokens`].
    pub fn tokens(self, tokens: u32) -> Self {
        self.max_tokens(tokens)
    }

    /// Consume the builder and assemble the final [`Config`]. Feature-gated
    /// sections are left unset (`None`).
    pub fn build(self) -> Config {
        let Self {
            model_name,
            base_url,
            api_key,
            temperature,
            max_tokens,
        } = self;
        Config {
            llm: LLMConfig {
                model_name,
                base_url,
                api_key,
                temperature,
                max_tokens,
            },
            #[cfg(feature = "local")]
            local: None,
            #[cfg(feature = "candle")]
            candle: None,
        }
    }
}
impl Default for ConfigBuilder {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;
    use tempfile::tempdir;

    /// Round-trips a file that includes the feature-gated `[local]` table.
    #[test]
    #[cfg(feature = "local")]
    fn test_config_from_file() {
        let config_content = r#"
[llm]
model_name = "gpt-4"
base_url = "https://api.openai.com/v1"
api_key = "test-key"
temperature = 0.7
max_tokens = 2048
[local]
huggingface_repo = "test/repo"
model_file = "model.gguf"
context_size = 4096
temperature = 0.5
max_tokens = 1024
"#;
        let dir = tempdir().unwrap();
        let config_path = dir.path().join("config.toml");
        fs::write(&config_path, config_content).unwrap();
        let config = Config::from_file(&config_path).unwrap();
        assert_eq!(config.llm.model_name, "gpt-4");
        assert_eq!(config.local.as_ref().unwrap().huggingface_repo, "test/repo");
    }

    /// Same parse test without the `local` feature: only `[llm]` is present.
    #[test]
    #[cfg(not(feature = "local"))]
    fn test_config_from_file() {
        let config_content = r#"
[llm]
model_name = "gpt-4"
base_url = "https://api.openai.com/v1"
api_key = "test-key"
temperature = 0.7
max_tokens = 2048
"#;
        let dir = tempdir().unwrap();
        let config_path = dir.path().join("config.toml");
        fs::write(&config_path, config_content).unwrap();
        let config = Config::from_file(&config_path).unwrap();
        assert_eq!(config.llm.model_name, "gpt-4");
    }

    /// The built-in defaults match the documented values, and feature-gated
    /// sections start out unset.
    #[test]
    fn test_config_new_default() {
        let config = Config::new_default();
        assert_eq!(config.llm.model_name, "gpt-3.5-turbo");
        assert_eq!(config.llm.base_url, "https://api.openai.com/v1");
        assert_eq!(config.llm.api_key, "your-api-key-here");
        assert_eq!(config.llm.temperature, 0.7);
        assert_eq!(config.llm.max_tokens, 2048);
        #[cfg(feature = "local")]
        assert!(config.local.is_none());
        // Previously untested: the candle section must also default to None.
        #[cfg(feature = "candle")]
        assert!(config.candle.is_none());
    }

    /// Save then reload: the written TOML must parse back to the same model.
    #[test]
    fn test_config_save() {
        let config = Config::new_default();
        let dir = tempdir().unwrap();
        let config_path = dir.path().join("config.toml");
        config.save(&config_path).unwrap();
        assert!(config_path.exists());
        let loaded_config = Config::from_file(&config_path).unwrap();
        assert_eq!(loaded_config.llm.model_name, config.llm.model_name);
    }

    /// The serde default helpers return the documented values.
    #[test]
    fn test_default_functions() {
        assert_eq!(default_temperature(), 0.7);
        assert_eq!(default_max_tokens(), 2048);
        // default_context_size is declared under any(local, candle), so test
        // it under the same gate (it was previously skipped on candle-only
        // builds).
        #[cfg(any(feature = "local", feature = "candle"))]
        assert_eq!(default_context_size(), 2048);
        // Previously untested helper.
        #[cfg(feature = "candle")]
        assert!(default_use_gpu());
    }
}