1use crate::error::{HeliosError, Result};
2use serde::{Deserialize, Serialize};
3use std::fs;
4use std::path::Path;
5
6#[derive(Debug, Clone, Serialize, Deserialize)]
7pub struct Config {
8 pub llm: LLMConfig,
9 #[serde(default)]
10 pub local: Option<LocalConfig>,
11}
12
13#[derive(Debug, Clone, Serialize, Deserialize)]
14pub struct LLMConfig {
15 pub model_name: String,
16 pub base_url: String,
17 pub api_key: String,
18 #[serde(default = "default_temperature")]
19 pub temperature: f32,
20 #[serde(default = "default_max_tokens")]
21 pub max_tokens: u32,
22}
23
24#[derive(Debug, Clone, Serialize, Deserialize)]
25pub struct LocalConfig {
26 pub huggingface_repo: String,
27 pub model_file: String,
28 #[serde(default = "default_context_size")]
29 pub context_size: usize,
30 #[serde(default = "default_temperature")]
31 pub temperature: f32,
32 #[serde(default = "default_max_tokens")]
33 pub max_tokens: u32,
34}
35
/// Serde default for omitted `temperature` fields.
fn default_temperature() -> f32 {
    const TEMPERATURE: f32 = 0.7;
    TEMPERATURE
}
39
/// Serde default for omitted `max_tokens` fields.
fn default_max_tokens() -> u32 {
    const MAX_TOKENS: u32 = 2_048;
    MAX_TOKENS
}
43
/// Serde default for an omitted `[local]` `context_size` field.
fn default_context_size() -> usize {
    const CONTEXT_SIZE: usize = 2_048;
    CONTEXT_SIZE
}
47
48impl Config {
49 pub fn from_file<P: AsRef<Path>>(path: P) -> Result<Self> {
51 let content = fs::read_to_string(path)
52 .map_err(|e| HeliosError::ConfigError(format!("Failed to read config file: {}", e)))?;
53
54 let config: Config = toml::from_str(&content)?;
55 Ok(config)
56 }
57
58 pub fn new_default() -> Self {
60 Self {
61 llm: LLMConfig {
62 model_name: "gpt-3.5-turbo".to_string(),
63 base_url: "https://api.openai.com/v1".to_string(),
64 api_key: "your-api-key-here".to_string(),
65 temperature: 0.7,
66 max_tokens: 2048,
67 },
68 local: None,
69 }
70 }
71
72 pub fn save<P: AsRef<Path>>(&self, path: P) -> Result<()> {
74 let content = toml::to_string_pretty(self)
75 .map_err(|e| HeliosError::ConfigError(format!("Failed to serialize config: {}", e)))?;
76
77 fs::write(path, content)
78 .map_err(|e| HeliosError::ConfigError(format!("Failed to write config file: {}", e)))?;
79
80 Ok(())
81 }
82}
83
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;
    use std::path::{Path, PathBuf};
    use tempfile::tempdir;

    /// An `[llm]` table containing only the required fields.
    const MINIMAL_LLM_TABLE: &str = r#"
[llm]
model_name = "gpt-4"
base_url = "https://api.openai.com/v1"
api_key = "test-key"
"#;

    /// Writes `content` to `config.toml` inside `dir` and returns its path.
    fn write_config(dir: &Path, content: &str) -> PathBuf {
        let path = dir.join("config.toml");
        fs::write(&path, content).unwrap();
        path
    }

    #[test]
    fn test_config_from_file() {
        let config_content = r#"
[llm]
model_name = "gpt-4"
base_url = "https://api.openai.com/v1"
api_key = "test-key"
temperature = 0.7
max_tokens = 2048

[local]
huggingface_repo = "test/repo"
model_file = "model.gguf"
context_size = 4096
temperature = 0.5
max_tokens = 1024
"#;
        let dir = tempdir().unwrap();
        let config_path = write_config(dir.path(), config_content);

        let config = Config::from_file(&config_path).unwrap();
        assert_eq!(config.llm.model_name, "gpt-4");
        assert_eq!(config.llm.base_url, "https://api.openai.com/v1");
        assert_eq!(config.llm.api_key, "test-key");
        assert_eq!(config.llm.temperature, 0.7);
        assert_eq!(config.llm.max_tokens, 2048);

        let local = config.local.as_ref().expect("[local] table should be parsed");
        assert_eq!(local.huggingface_repo, "test/repo");
        assert_eq!(local.model_file, "model.gguf");
        assert_eq!(local.context_size, 4096);
        assert_eq!(local.temperature, 0.5);
        assert_eq!(local.max_tokens, 1024);
    }

    #[test]
    fn test_config_from_file_applies_llm_defaults() {
        let dir = tempdir().unwrap();
        let config_path = write_config(dir.path(), MINIMAL_LLM_TABLE);

        let config = Config::from_file(&config_path).unwrap();
        assert_eq!(config.llm.temperature, default_temperature());
        assert_eq!(config.llm.max_tokens, default_max_tokens());
        assert!(config.local.is_none());
    }

    #[test]
    fn test_config_from_file_applies_local_defaults() {
        let content = format!(
            "{}\n[local]\nhuggingface_repo = \"test/repo\"\nmodel_file = \"model.gguf\"\n",
            MINIMAL_LLM_TABLE
        );
        let dir = tempdir().unwrap();
        let config_path = write_config(dir.path(), &content);

        let config = Config::from_file(&config_path).unwrap();
        let local = config.local.expect("[local] table should be parsed");
        assert_eq!(local.context_size, default_context_size());
        assert_eq!(local.temperature, default_temperature());
        assert_eq!(local.max_tokens, default_max_tokens());
    }

    #[test]
    fn test_config_from_file_missing_file_is_error() {
        let dir = tempdir().unwrap();
        let missing = dir.path().join("does-not-exist.toml");
        assert!(Config::from_file(&missing).is_err());
    }

    #[test]
    fn test_config_from_file_invalid_toml_is_error() {
        let dir = tempdir().unwrap();
        // `[llm]` is present but lacks every required key.
        let config_path = write_config(dir.path(), "[llm]\nmodel_name = 42\n");
        assert!(Config::from_file(&config_path).is_err());
    }

    #[test]
    fn test_config_new_default() {
        let config = Config::new_default();
        assert_eq!(config.llm.model_name, "gpt-3.5-turbo");
        assert_eq!(config.llm.base_url, "https://api.openai.com/v1");
        assert_eq!(config.llm.api_key, "your-api-key-here");
        assert_eq!(config.llm.temperature, 0.7);
        assert_eq!(config.llm.max_tokens, 2048);
        assert!(config.local.is_none());
    }

    #[test]
    fn test_config_save() {
        let config = Config::new_default();
        let dir = tempdir().unwrap();
        let config_path = dir.path().join("config.toml");

        config.save(&config_path).unwrap();
        assert!(config_path.exists());

        let loaded_config = Config::from_file(&config_path).unwrap();
        assert_eq!(loaded_config.llm.model_name, config.llm.model_name);
        assert_eq!(loaded_config.llm.base_url, config.llm.base_url);
        assert_eq!(loaded_config.llm.api_key, config.llm.api_key);
        assert_eq!(loaded_config.llm.temperature, config.llm.temperature);
        assert_eq!(loaded_config.llm.max_tokens, config.llm.max_tokens);
        assert!(loaded_config.local.is_none());
    }

    #[test]
    fn test_config_save_round_trips_local_section() {
        let mut config = Config::new_default();
        config.local = Some(LocalConfig {
            huggingface_repo: "test/repo".to_string(),
            model_file: "model.gguf".to_string(),
            context_size: 4096,
            temperature: 0.5,
            max_tokens: 1024,
        });
        let dir = tempdir().unwrap();
        let config_path = dir.path().join("config.toml");

        config.save(&config_path).unwrap();
        let loaded = Config::from_file(&config_path).unwrap();

        let original = config.local.as_ref().unwrap();
        let reloaded = loaded.local.as_ref().expect("[local] should survive a round-trip");
        assert_eq!(reloaded.huggingface_repo, original.huggingface_repo);
        assert_eq!(reloaded.model_file, original.model_file);
        assert_eq!(reloaded.context_size, original.context_size);
        assert_eq!(reloaded.temperature, original.temperature);
        assert_eq!(reloaded.max_tokens, original.max_tokens);
    }

    #[test]
    fn test_default_functions() {
        assert_eq!(default_temperature(), 0.7);
        assert_eq!(default_max_tokens(), 2048);
        assert_eq!(default_context_size(), 2048);
    }
}