Skip to main content

ferrum_cli/
config.rs

//! CLI configuration management.
//!
//! Handles loading, parsing, saving, and validation of the CLI tool's TOML
//! configuration file.

use std::collections::HashMap;
use std::io::ErrorKind;
use std::path::Path;

use ferrum_types::Result;
use serde::{Deserialize, Serialize};
use tokio::fs;

11/// CLI configuration
12#[derive(Debug, Clone, Serialize, Deserialize)]
13pub struct CliConfig {
14    /// Server configuration
15    pub server: ServerCliConfig,
16
17    /// Model configuration
18    pub models: ModelCliConfig,
19
20    /// Benchmark configuration
21    pub benchmark: BenchmarkConfig,
22
23    /// Client configuration
24    pub client: ClientConfig,
25
26    /// Development configuration
27    pub dev: DevConfig,
28}
29
30/// Server CLI configuration
31#[derive(Debug, Clone, Serialize, Deserialize)]
32pub struct ServerCliConfig {
33    /// Default host
34    pub host: String,
35
36    /// Default port
37    pub port: u16,
38
39    /// Configuration file path
40    pub config_path: String,
41
42    /// Log level
43    pub log_level: String,
44
45    /// Enable hot reload
46    pub hot_reload: bool,
47}
48
49/// Model CLI configuration
50#[derive(Debug, Clone, Serialize, Deserialize)]
51pub struct ModelCliConfig {
52    /// Default model directory
53    pub model_dir: String,
54
55    /// Model cache directory
56    pub cache_dir: String,
57
58    /// Default model
59    pub default_model: Option<String>,
60
61    /// Model aliases
62    pub aliases: HashMap<String, String>,
63
64    /// Download settings
65    pub download: DownloadConfig,
66}
67
68/// Download configuration
69#[derive(Debug, Clone, Serialize, Deserialize)]
70pub struct DownloadConfig {
71    /// HuggingFace cache directory
72    pub hf_cache_dir: String,
73
74    /// Download timeout in seconds
75    pub timeout_seconds: u64,
76
77    /// Max concurrent downloads
78    pub max_concurrent: usize,
79
80    /// Retry attempts
81    pub retry_attempts: u32,
82}
83
84/// Benchmark configuration
85#[derive(Debug, Clone, Serialize, Deserialize)]
86pub struct BenchmarkConfig {
87    /// Default number of requests
88    pub num_requests: usize,
89
90    /// Default concurrency level
91    pub concurrency: usize,
92
93    /// Default prompt length
94    pub prompt_length: usize,
95
96    /// Default max tokens
97    pub max_tokens: usize,
98
99    /// Warmup requests
100    pub warmup_requests: usize,
101
102    /// Output directory for reports
103    pub output_dir: String,
104}
105
106/// Client configuration
107#[derive(Debug, Clone, Serialize, Deserialize)]
108pub struct ClientConfig {
109    /// Default API base URL
110    pub base_url: String,
111
112    /// Default API key
113    pub api_key: Option<String>,
114
115    /// Request timeout
116    pub timeout_seconds: u64,
117
118    /// Retry configuration
119    pub retry: RetryConfig,
120}
121
122/// Retry configuration
123#[derive(Debug, Clone, Serialize, Deserialize)]
124pub struct RetryConfig {
125    /// Maximum retry attempts
126    pub max_attempts: u32,
127
128    /// Initial delay in milliseconds
129    pub initial_delay_ms: u64,
130
131    /// Maximum delay in milliseconds
132    pub max_delay_ms: u64,
133
134    /// Backoff multiplier
135    pub backoff_multiplier: f64,
136}
137
138/// Development configuration
139#[derive(Debug, Clone, Serialize, Deserialize)]
140pub struct DevConfig {
141    /// Enable debug mode
142    pub debug: bool,
143
144    /// Profile memory usage
145    pub profile_memory: bool,
146
147    /// Enable GPU profiling
148    pub profile_gpu: bool,
149
150    /// Mock backends for testing
151    pub mock_backends: bool,
152
153    /// Test data directory
154    pub test_data_dir: String,
155}
156
157impl CliConfig {
158    /// Load configuration from file
159    pub async fn load<P: AsRef<Path>>(path: P) -> Result<Self> {
160        let path = path.as_ref();
161
162        if !path.exists() {
163            // Create default config file
164            let default_config = Self::default();
165            let content = toml::to_string_pretty(&default_config).map_err(|e| {
166                ferrum_types::FerrumError::configuration(format!(
167                    "Failed to serialize default config: {}",
168                    e
169                ))
170            })?;
171
172            if let Some(parent) = path.parent() {
173                fs::create_dir_all(parent).await.map_err(|e| {
174                    ferrum_types::FerrumError::io_str(format!(
175                        "Failed to create config directory: {}",
176                        e
177                    ))
178                })?;
179            }
180
181            fs::write(path, content).await.map_err(|e| {
182                ferrum_types::FerrumError::io_str(format!("Failed to write default config: {}", e))
183            })?;
184
185            return Ok(default_config);
186        }
187
188        let content = fs::read_to_string(path).await.map_err(|e| {
189            ferrum_types::FerrumError::io_str(format!("Failed to read config file: {}", e))
190        })?;
191
192        toml::from_str(&content).map_err(|e| {
193            ferrum_types::FerrumError::configuration(format!("Failed to parse config: {}", e))
194        })
195    }
196
197    /// Save configuration to file
198    pub async fn save<P: AsRef<Path>>(&self, path: P) -> Result<()> {
199        let content = toml::to_string_pretty(self).map_err(|e| {
200            ferrum_types::FerrumError::configuration(format!("Failed to serialize config: {}", e))
201        })?;
202
203        fs::write(path, content).await.map_err(|e| {
204            ferrum_types::FerrumError::io_str(format!("Failed to write config file: {}", e))
205        })
206    }
207
208    /// Validate configuration
209    pub fn validate(&self) -> Result<()> {
210        // Validate server config
211        if self.server.port == 0 {
212            return Err(ferrum_types::FerrumError::configuration(
213                "Server port cannot be 0".to_string(),
214            ));
215        }
216
217        // Validate model config
218        if !Path::new(&self.models.model_dir).exists() {
219            return Err(ferrum_types::FerrumError::configuration(format!(
220                "Model directory does not exist: {}",
221                self.models.model_dir
222            )));
223        }
224
225        // Validate benchmark config
226        if self.benchmark.num_requests == 0 {
227            return Err(ferrum_types::FerrumError::configuration(
228                "Number of requests cannot be 0".to_string(),
229            ));
230        }
231
232        if self.benchmark.concurrency == 0 {
233            return Err(ferrum_types::FerrumError::configuration(
234                "Concurrency cannot be 0".to_string(),
235            ));
236        }
237
238        Ok(())
239    }
240}
241
242impl Default for CliConfig {
243    fn default() -> Self {
244        Self {
245            server: ServerCliConfig::default(),
246            models: ModelCliConfig::default(),
247            benchmark: BenchmarkConfig::default(),
248            client: ClientConfig::default(),
249            dev: DevConfig::default(),
250        }
251    }
252}
253
254impl Default for ServerCliConfig {
255    fn default() -> Self {
256        Self {
257            host: "127.0.0.1".to_string(),
258            port: 8000,
259            config_path: "server.toml".to_string(),
260            log_level: "info".to_string(),
261            hot_reload: false,
262        }
263    }
264}
265
266impl Default for ModelCliConfig {
267    fn default() -> Self {
268        Self {
269            model_dir: "./models".to_string(),
270            cache_dir: "./cache".to_string(),
271            default_model: None,
272            aliases: HashMap::new(),
273            download: DownloadConfig::default(),
274        }
275    }
276}
277
278impl Default for DownloadConfig {
279    fn default() -> Self {
280        Self {
281            hf_cache_dir: std::env::var("HF_HOME")
282                .ok()
283                .or_else(|| {
284                    dirs::home_dir()
285                        .map(|h| h.join(".cache/huggingface").to_string_lossy().to_string())
286                })
287                .unwrap_or_else(|| "./hf_cache".to_string()),
288            timeout_seconds: 300,
289            max_concurrent: 4,
290            retry_attempts: 3,
291        }
292    }
293}
294
295impl Default for BenchmarkConfig {
296    fn default() -> Self {
297        Self {
298            num_requests: 100,
299            concurrency: 10,
300            prompt_length: 512,
301            max_tokens: 256,
302            warmup_requests: 10,
303            output_dir: "./benchmark_results".to_string(),
304        }
305    }
306}
307
308impl Default for ClientConfig {
309    fn default() -> Self {
310        Self {
311            base_url: "http://127.0.0.1:8000".to_string(),
312            api_key: None,
313            timeout_seconds: 30,
314            retry: RetryConfig::default(),
315        }
316    }
317}
318
319impl Default for RetryConfig {
320    fn default() -> Self {
321        Self {
322            max_attempts: 3,
323            initial_delay_ms: 100,
324            max_delay_ms: 5000,
325            backoff_multiplier: 2.0,
326        }
327    }
328}
329
330impl Default for DevConfig {
331    fn default() -> Self {
332        Self {
333            debug: false,
334            profile_memory: false,
335            profile_gpu: false,
336            mock_backends: false,
337            test_data_dir: "./test_data".to_string(),
338        }
339    }
340}