// oxibonsai_runtime/config.rs
1//! Layered configuration system for OxiBonsai.
2//!
3//! Loading order: defaults → TOML file → CLI argument overrides.
4
5use serde::{Deserialize, Serialize};
6use std::path::Path;
7
8use crate::error::{RuntimeError, RuntimeResult};
9
10/// Top-level OxiBonsai configuration.
11#[derive(Debug, Clone, Default, Serialize, Deserialize)]
12#[serde(default)]
13pub struct OxiBonsaiConfig {
14    /// Server configuration.
15    pub server: ServerConfig,
16    /// Sampling parameters.
17    pub sampling: SamplingConfig,
18    /// Model paths and limits.
19    pub model: ModelConfig,
20    /// Observability settings.
21    pub observability: ObservabilityConfig,
22}
23
24/// HTTP server configuration.
25#[derive(Debug, Clone, Serialize, Deserialize)]
26#[serde(default)]
27pub struct ServerConfig {
28    /// Bind host address.
29    pub host: String,
30    /// Bind port.
31    pub port: u16,
32}
33
34/// Sampling parameters configuration.
35#[derive(Debug, Clone, Serialize, Deserialize)]
36#[serde(default)]
37pub struct SamplingConfig {
38    /// Temperature for softmax scaling. 0.0 = greedy.
39    pub temperature: f32,
40    /// Top-k filtering (0 = disabled).
41    pub top_k: usize,
42    /// Top-p (nucleus) threshold (1.0 = disabled).
43    pub top_p: f32,
44    /// Repetition penalty (1.0 = disabled).
45    pub repetition_penalty: f32,
46    /// Maximum tokens to generate.
47    pub max_tokens: usize,
48}
49
50/// Model configuration.
51#[derive(Debug, Clone, Serialize, Deserialize)]
52#[serde(default)]
53pub struct ModelConfig {
54    /// Path to the GGUF model file.
55    pub model_path: Option<String>,
56    /// Path to tokenizer.json file.
57    pub tokenizer_path: Option<String>,
58    /// Maximum sequence length (prompt + generated).
59    pub max_seq_len: usize,
60}
61
62/// Observability configuration.
63#[derive(Debug, Clone, Serialize, Deserialize)]
64#[serde(default)]
65pub struct ObservabilityConfig {
66    /// Log level filter (e.g. "info", "debug", "warn").
67    pub log_level: String,
68    /// Whether to emit JSON-formatted logs.
69    pub json_logs: bool,
70}
71
72impl Default for ServerConfig {
73    fn default() -> Self {
74        Self {
75            host: "0.0.0.0".to_string(),
76            port: 8080,
77        }
78    }
79}
80
81impl Default for SamplingConfig {
82    fn default() -> Self {
83        Self {
84            temperature: 0.7,
85            top_k: 40,
86            top_p: 0.9,
87            repetition_penalty: 1.1,
88            max_tokens: 512,
89        }
90    }
91}
92
93impl Default for ModelConfig {
94    fn default() -> Self {
95        Self {
96            model_path: None,
97            tokenizer_path: None,
98            max_seq_len: 4096,
99        }
100    }
101}
102
103impl Default for ObservabilityConfig {
104    fn default() -> Self {
105        Self {
106            log_level: "info".to_string(),
107            json_logs: false,
108        }
109    }
110}
111
/// How serious a configuration warning is.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum WarningSeverity {
    /// Purely informational; no action needed.
    Info,
    /// May lead to suboptimal behavior.
    Warning,
    /// Will likely cause a failure at runtime.
    Error,
}
122
123impl std::fmt::Display for WarningSeverity {
124    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
125        match self {
126            Self::Info => write!(f, "info"),
127            Self::Warning => write!(f, "warning"),
128            Self::Error => write!(f, "error"),
129        }
130    }
131}
132
133/// A warning about a configuration value.
134#[derive(Debug, Clone)]
135pub struct ConfigWarning {
136    /// Which configuration field this warning applies to.
137    pub field: String,
138    /// Human-readable warning message.
139    pub message: String,
140    /// Severity of this warning.
141    pub severity: WarningSeverity,
142}
143
144impl std::fmt::Display for ConfigWarning {
145    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
146        write!(f, "[{}] {}: {}", self.severity, self.field, self.message)
147    }
148}
149
150impl OxiBonsaiConfig {
151    /// Load configuration from a TOML file.
152    pub fn load(path: &Path) -> RuntimeResult<Self> {
153        let content = std::fs::read_to_string(path).map_err(|e| {
154            RuntimeError::Config(format!(
155                "failed to read config file {}: {e}",
156                path.display()
157            ))
158        })?;
159        let config: Self = toml::from_str(&content).map_err(|e| {
160            RuntimeError::Config(format!(
161                "failed to parse config file {}: {e}",
162                path.display()
163            ))
164        })?;
165        Ok(config)
166    }
167
168    /// Load configuration from a TOML file if a path is given, otherwise return defaults.
169    pub fn load_or_default(path: Option<&Path>) -> Self {
170        match path {
171            Some(p) => match Self::load(p) {
172                Ok(cfg) => cfg,
173                Err(e) => {
174                    tracing::warn!(error = %e, "failed to load config, using defaults");
175                    Self::default()
176                }
177            },
178            None => Self::default(),
179        }
180    }
181
182    /// Validate this configuration, returning an error if any field is invalid.
183    pub fn validate(&self) -> RuntimeResult<()> {
184        if self.sampling.temperature < 0.0 {
185            return Err(RuntimeError::Config(format!(
186                "sampling.temperature must be >= 0.0, got {}",
187                self.sampling.temperature
188            )));
189        }
190        if self.sampling.top_p < 0.0 || self.sampling.top_p > 1.0 {
191            return Err(RuntimeError::Config(format!(
192                "sampling.top_p must be in [0.0, 1.0], got {}",
193                self.sampling.top_p
194            )));
195        }
196        if self.sampling.repetition_penalty < 1.0 {
197            return Err(RuntimeError::Config(format!(
198                "sampling.repetition_penalty must be >= 1.0, got {}",
199                self.sampling.repetition_penalty
200            )));
201        }
202        if self.sampling.max_tokens == 0 {
203            return Err(RuntimeError::Config(
204                "sampling.max_tokens must be > 0".to_string(),
205            ));
206        }
207        if self.model.max_seq_len == 0 {
208            return Err(RuntimeError::Config(
209                "model.max_seq_len must be > 0".to_string(),
210            ));
211        }
212        if self.server.host.is_empty() {
213            return Err(RuntimeError::Config(
214                "server.host must not be empty".to_string(),
215            ));
216        }
217        // Port 0 is technically valid (OS assigns), so no check needed
218        Ok(())
219    }
220
221    /// Run a dry-run check of this configuration.
222    ///
223    /// Returns warnings about potential issues without stopping execution.
224    /// Checks for model file existence, tokenizer existence, and
225    /// reasonable parameter values.
226    pub fn dry_run_check(&self) -> Vec<ConfigWarning> {
227        let mut warnings = Vec::new();
228
229        // Check model file
230        match &self.model.model_path {
231            None => {
232                warnings.push(ConfigWarning {
233                    field: "model.model_path".to_string(),
234                    message: "no model path configured".to_string(),
235                    severity: WarningSeverity::Warning,
236                });
237            }
238            Some(path) => {
239                if !Path::new(path).exists() {
240                    warnings.push(ConfigWarning {
241                        field: "model.model_path".to_string(),
242                        message: format!("model file does not exist: {}", path),
243                        severity: WarningSeverity::Error,
244                    });
245                }
246            }
247        }
248
249        // Check tokenizer file
250        match &self.model.tokenizer_path {
251            None => {
252                warnings.push(ConfigWarning {
253                    field: "model.tokenizer_path".to_string(),
254                    message: "no tokenizer path configured; token IDs will be used".to_string(),
255                    severity: WarningSeverity::Info,
256                });
257            }
258            Some(path) => {
259                if !Path::new(path).exists() {
260                    warnings.push(ConfigWarning {
261                        field: "model.tokenizer_path".to_string(),
262                        message: format!("tokenizer file does not exist: {}", path),
263                        severity: WarningSeverity::Error,
264                    });
265                }
266            }
267        }
268
269        // Check sequence length
270        if self.model.max_seq_len > 65536 {
271            warnings.push(ConfigWarning {
272                field: "model.max_seq_len".to_string(),
273                message: format!(
274                    "very large max_seq_len ({}); may require significant memory",
275                    self.model.max_seq_len
276                ),
277                severity: WarningSeverity::Warning,
278            });
279        }
280
281        // Check temperature
282        if self.sampling.temperature > 2.0 {
283            warnings.push(ConfigWarning {
284                field: "sampling.temperature".to_string(),
285                message: format!(
286                    "high temperature ({}) may produce incoherent output",
287                    self.sampling.temperature
288                ),
289                severity: WarningSeverity::Warning,
290            });
291        }
292
293        warnings
294    }
295}
296
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn default_values() {
        let config = OxiBonsaiConfig::default();
        assert_eq!(config.server.host, "0.0.0.0");
        assert_eq!(config.server.port, 8080);
        assert!((config.sampling.temperature - 0.7).abs() < f32::EPSILON);
        assert_eq!(config.sampling.top_k, 40);
        assert!((config.sampling.top_p - 0.9).abs() < f32::EPSILON);
        assert!((config.sampling.repetition_penalty - 1.1).abs() < f32::EPSILON);
        assert_eq!(config.sampling.max_tokens, 512);
        assert_eq!(config.model.max_seq_len, 4096);
        assert!(config.model.model_path.is_none());
        assert!(config.model.tokenizer_path.is_none());
        assert_eq!(config.observability.log_level, "info");
        assert!(!config.observability.json_logs);
    }

    #[test]
    fn toml_parsing() {
        let gguf = std::env::temp_dir().join("model.gguf");
        let tok = std::env::temp_dir().join("tokenizer.json");
        let raw = format!(
            r#"
[server]
host = "127.0.0.1"
port = 3000

[sampling]
temperature = 0.5
top_k = 50
top_p = 0.95
repetition_penalty = 1.2
max_tokens = 1024

[model]
model_path = "{}"
tokenizer_path = "{}"
max_seq_len = 8192

[observability]
log_level = "debug"
json_logs = true
"#,
            gguf.display(),
            tok.display()
        );
        let config: OxiBonsaiConfig = toml::from_str(&raw).expect("should parse valid TOML");
        assert_eq!(config.server.host, "127.0.0.1");
        assert_eq!(config.server.port, 3000);
        assert!((config.sampling.temperature - 0.5).abs() < f32::EPSILON);
        assert_eq!(config.sampling.top_k, 50);
        assert_eq!(config.sampling.max_tokens, 1024);
        assert_eq!(
            config.model.model_path.as_deref(),
            Some(gguf.to_str().expect("path is valid UTF-8"))
        );
        assert_eq!(config.model.max_seq_len, 8192);
        assert_eq!(config.observability.log_level, "debug");
        assert!(config.observability.json_logs);
    }

    #[test]
    fn partial_toml_uses_defaults() {
        let raw = r#"
[server]
port = 9090
"#;
        let config: OxiBonsaiConfig = toml::from_str(raw).expect("should parse partial TOML");
        assert_eq!(config.server.port, 9090);
        // Everything not present in the TOML keeps its default.
        assert_eq!(config.server.host, "0.0.0.0");
        assert!((config.sampling.temperature - 0.7).abs() < f32::EPSILON);
        assert_eq!(config.model.max_seq_len, 4096);
    }

    #[test]
    fn missing_file_returns_default() {
        let missing = std::env::temp_dir().join("nonexistent_oxibonsai_config_12345.toml");
        let config = OxiBonsaiConfig::load_or_default(Some(&missing));
        assert_eq!(config.server.port, 8080);
    }

    #[test]
    fn load_or_default_none_returns_default() {
        let config = OxiBonsaiConfig::load_or_default(None);
        assert_eq!(config.server.host, "0.0.0.0");
    }

    // ── Validation tests ──

    #[test]
    fn validate_defaults_ok() {
        assert!(OxiBonsaiConfig::default().validate().is_ok());
    }

    #[test]
    fn validate_negative_temperature() {
        let config = OxiBonsaiConfig {
            sampling: SamplingConfig {
                temperature: -1.0,
                ..SamplingConfig::default()
            },
            ..OxiBonsaiConfig::default()
        };
        assert!(config.validate().is_err());
    }

    #[test]
    fn validate_top_p_out_of_range() {
        let mut config = OxiBonsaiConfig::default();

        config.sampling.top_p = 1.5;
        assert!(config.validate().is_err());

        config.sampling.top_p = -0.1;
        assert!(config.validate().is_err());
    }

    #[test]
    fn validate_repetition_penalty_too_low() {
        let config = OxiBonsaiConfig {
            sampling: SamplingConfig {
                repetition_penalty: 0.5,
                ..SamplingConfig::default()
            },
            ..OxiBonsaiConfig::default()
        };
        assert!(config.validate().is_err());
    }

    #[test]
    fn validate_max_tokens_zero() {
        let config = OxiBonsaiConfig {
            sampling: SamplingConfig {
                max_tokens: 0,
                ..SamplingConfig::default()
            },
            ..OxiBonsaiConfig::default()
        };
        assert!(config.validate().is_err());
    }

    #[test]
    fn validate_max_seq_len_zero() {
        let config = OxiBonsaiConfig {
            model: ModelConfig {
                max_seq_len: 0,
                ..ModelConfig::default()
            },
            ..OxiBonsaiConfig::default()
        };
        assert!(config.validate().is_err());
    }

    #[test]
    fn validate_empty_host() {
        let config = OxiBonsaiConfig {
            server: ServerConfig {
                host: String::new(),
                ..ServerConfig::default()
            },
            ..OxiBonsaiConfig::default()
        };
        assert!(config.validate().is_err());
    }

    // ── Dry-run check tests ──

    #[test]
    fn dry_run_no_model_path() {
        let warnings = OxiBonsaiConfig::default().dry_run_check();
        assert!(warnings.iter().any(|w| w.field == "model.model_path"));
    }

    #[test]
    fn dry_run_nonexistent_model() {
        let mut config = OxiBonsaiConfig::default();
        config.model.model_path = Some(
            std::env::temp_dir()
                .join("nonexistent_oxibonsai_test_99999.gguf")
                .display()
                .to_string(),
        );
        let warnings = config.dry_run_check();
        let found = warnings
            .iter()
            .find(|w| w.field == "model.model_path")
            .expect("should have model warning");
        assert_eq!(found.severity, WarningSeverity::Error);
    }

    #[test]
    fn dry_run_high_temperature() {
        let mut config = OxiBonsaiConfig::default();
        config.sampling.temperature = 3.0;
        let warnings = config.dry_run_check();
        assert!(warnings.iter().any(|w| w.field == "sampling.temperature"));
    }

    #[test]
    fn dry_run_large_seq_len() {
        let mut config = OxiBonsaiConfig::default();
        config.model.max_seq_len = 100_000;
        let warnings = config.dry_run_check();
        assert!(warnings.iter().any(|w| w.field == "model.max_seq_len"));
    }

    #[test]
    fn warning_severity_display() {
        assert_eq!(WarningSeverity::Info.to_string(), "info");
        assert_eq!(WarningSeverity::Warning.to_string(), "warning");
        assert_eq!(WarningSeverity::Error.to_string(), "error");
    }

    #[test]
    fn config_warning_display() {
        let warning = ConfigWarning {
            field: "test.field".to_string(),
            message: "test message".to_string(),
            severity: WarningSeverity::Warning,
        };
        let rendered = warning.to_string();
        assert!(rendered.contains("warning"));
        assert!(rendered.contains("test.field"));
        assert!(rendered.contains("test message"));
    }

    #[test]
    fn load_from_temp_file() {
        let file_path = std::env::temp_dir().join("oxibonsai_test_config.toml");
        std::fs::write(
            &file_path,
            r#"
[server]
host = "10.0.0.1"
port = 4444
"#,
        )
        .expect("write temp config");

        let config = OxiBonsaiConfig::load(&file_path).expect("should load temp config");
        assert_eq!(config.server.host, "10.0.0.1");
        assert_eq!(config.server.port, 4444);

        let _ = std::fs::remove_file(&file_path);
    }
}