1use serde::{Deserialize, Serialize};
6use std::path::Path;
7
8use crate::error::{RuntimeError, RuntimeResult};
9
10#[derive(Debug, Clone, Default, Serialize, Deserialize)]
12#[serde(default)]
13pub struct OxiBonsaiConfig {
14 pub server: ServerConfig,
16 pub sampling: SamplingConfig,
18 pub model: ModelConfig,
20 pub observability: ObservabilityConfig,
22}
23
24#[derive(Debug, Clone, Serialize, Deserialize)]
26#[serde(default)]
27pub struct ServerConfig {
28 pub host: String,
30 pub port: u16,
32}
33
34#[derive(Debug, Clone, Serialize, Deserialize)]
36#[serde(default)]
37pub struct SamplingConfig {
38 pub temperature: f32,
40 pub top_k: usize,
42 pub top_p: f32,
44 pub repetition_penalty: f32,
46 pub max_tokens: usize,
48}
49
50#[derive(Debug, Clone, Serialize, Deserialize)]
52#[serde(default)]
53pub struct ModelConfig {
54 pub model_path: Option<String>,
56 pub tokenizer_path: Option<String>,
58 pub max_seq_len: usize,
60}
61
62#[derive(Debug, Clone, Serialize, Deserialize)]
64#[serde(default)]
65pub struct ObservabilityConfig {
66 pub log_level: String,
68 pub json_logs: bool,
70}
71
72impl Default for ServerConfig {
73 fn default() -> Self {
74 Self {
75 host: "0.0.0.0".to_string(),
76 port: 8080,
77 }
78 }
79}
80
81impl Default for SamplingConfig {
82 fn default() -> Self {
83 Self {
84 temperature: 0.7,
85 top_k: 40,
86 top_p: 0.9,
87 repetition_penalty: 1.1,
88 max_tokens: 512,
89 }
90 }
91}
92
93impl Default for ModelConfig {
94 fn default() -> Self {
95 Self {
96 model_path: None,
97 tokenizer_path: None,
98 max_seq_len: 4096,
99 }
100 }
101}
102
103impl Default for ObservabilityConfig {
104 fn default() -> Self {
105 Self {
106 log_level: "info".to_string(),
107 json_logs: false,
108 }
109 }
110}
111
/// Severity attached to a [`ConfigWarning`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum WarningSeverity {
    /// Informational note; the dry-run check uses it for an unset tokenizer path.
    Info,
    /// Likely problem; the dry-run check uses it for an unset model path, a
    /// very large `max_seq_len`, and a high temperature.
    Warning,
    /// Definite problem; the dry-run check uses it for configured paths that
    /// do not exist.
    Error,
}
122
123impl std::fmt::Display for WarningSeverity {
124 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
125 match self {
126 Self::Info => write!(f, "info"),
127 Self::Warning => write!(f, "warning"),
128 Self::Error => write!(f, "error"),
129 }
130 }
131}
132
133#[derive(Debug, Clone)]
135pub struct ConfigWarning {
136 pub field: String,
138 pub message: String,
140 pub severity: WarningSeverity,
142}
143
144impl std::fmt::Display for ConfigWarning {
145 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
146 write!(f, "[{}] {}: {}", self.severity, self.field, self.message)
147 }
148}
149
150impl OxiBonsaiConfig {
151 pub fn load(path: &Path) -> RuntimeResult<Self> {
153 let content = std::fs::read_to_string(path).map_err(|e| {
154 RuntimeError::Config(format!(
155 "failed to read config file {}: {e}",
156 path.display()
157 ))
158 })?;
159 let config: Self = toml::from_str(&content).map_err(|e| {
160 RuntimeError::Config(format!(
161 "failed to parse config file {}: {e}",
162 path.display()
163 ))
164 })?;
165 Ok(config)
166 }
167
168 pub fn load_or_default(path: Option<&Path>) -> Self {
170 match path {
171 Some(p) => match Self::load(p) {
172 Ok(cfg) => cfg,
173 Err(e) => {
174 tracing::warn!(error = %e, "failed to load config, using defaults");
175 Self::default()
176 }
177 },
178 None => Self::default(),
179 }
180 }
181
182 pub fn validate(&self) -> RuntimeResult<()> {
184 if self.sampling.temperature < 0.0 {
185 return Err(RuntimeError::Config(format!(
186 "sampling.temperature must be >= 0.0, got {}",
187 self.sampling.temperature
188 )));
189 }
190 if self.sampling.top_p < 0.0 || self.sampling.top_p > 1.0 {
191 return Err(RuntimeError::Config(format!(
192 "sampling.top_p must be in [0.0, 1.0], got {}",
193 self.sampling.top_p
194 )));
195 }
196 if self.sampling.repetition_penalty < 1.0 {
197 return Err(RuntimeError::Config(format!(
198 "sampling.repetition_penalty must be >= 1.0, got {}",
199 self.sampling.repetition_penalty
200 )));
201 }
202 if self.sampling.max_tokens == 0 {
203 return Err(RuntimeError::Config(
204 "sampling.max_tokens must be > 0".to_string(),
205 ));
206 }
207 if self.model.max_seq_len == 0 {
208 return Err(RuntimeError::Config(
209 "model.max_seq_len must be > 0".to_string(),
210 ));
211 }
212 if self.server.host.is_empty() {
213 return Err(RuntimeError::Config(
214 "server.host must not be empty".to_string(),
215 ));
216 }
217 Ok(())
219 }
220
221 pub fn dry_run_check(&self) -> Vec<ConfigWarning> {
227 let mut warnings = Vec::new();
228
229 match &self.model.model_path {
231 None => {
232 warnings.push(ConfigWarning {
233 field: "model.model_path".to_string(),
234 message: "no model path configured".to_string(),
235 severity: WarningSeverity::Warning,
236 });
237 }
238 Some(path) => {
239 if !Path::new(path).exists() {
240 warnings.push(ConfigWarning {
241 field: "model.model_path".to_string(),
242 message: format!("model file does not exist: {}", path),
243 severity: WarningSeverity::Error,
244 });
245 }
246 }
247 }
248
249 match &self.model.tokenizer_path {
251 None => {
252 warnings.push(ConfigWarning {
253 field: "model.tokenizer_path".to_string(),
254 message: "no tokenizer path configured; token IDs will be used".to_string(),
255 severity: WarningSeverity::Info,
256 });
257 }
258 Some(path) => {
259 if !Path::new(path).exists() {
260 warnings.push(ConfigWarning {
261 field: "model.tokenizer_path".to_string(),
262 message: format!("tokenizer file does not exist: {}", path),
263 severity: WarningSeverity::Error,
264 });
265 }
266 }
267 }
268
269 if self.model.max_seq_len > 65536 {
271 warnings.push(ConfigWarning {
272 field: "model.max_seq_len".to_string(),
273 message: format!(
274 "very large max_seq_len ({}); may require significant memory",
275 self.model.max_seq_len
276 ),
277 severity: WarningSeverity::Warning,
278 });
279 }
280
281 if self.sampling.temperature > 2.0 {
283 warnings.push(ConfigWarning {
284 field: "sampling.temperature".to_string(),
285 message: format!(
286 "high temperature ({}) may produce incoherent output",
287 self.sampling.temperature
288 ),
289 severity: WarningSeverity::Warning,
290 });
291 }
292
293 warnings
294 }
295}
296
#[cfg(test)]
mod tests {
    use super::*;

    // ---- defaults and parsing -------------------------------------------

    #[test]
    fn default_values() {
        let cfg = OxiBonsaiConfig::default();
        assert_eq!(cfg.server.host, "0.0.0.0");
        assert_eq!(cfg.server.port, 8080);
        assert!((cfg.sampling.temperature - 0.7).abs() < f32::EPSILON);
        assert_eq!(cfg.sampling.top_k, 40);
        assert!((cfg.sampling.top_p - 0.9).abs() < f32::EPSILON);
        assert!((cfg.sampling.repetition_penalty - 1.1).abs() < f32::EPSILON);
        assert_eq!(cfg.sampling.max_tokens, 512);
        assert_eq!(cfg.model.max_seq_len, 4096);
        assert!(cfg.model.model_path.is_none());
        assert!(cfg.model.tokenizer_path.is_none());
        assert_eq!(cfg.observability.log_level, "info");
        assert!(!cfg.observability.json_logs);
    }

    #[test]
    fn toml_parsing() {
        let model_path = std::env::temp_dir().join("model.gguf");
        let tokenizer_path = std::env::temp_dir().join("tokenizer.json");
        // The paths go into TOML *literal* strings (single quotes): in a
        // basic (double-quoted) string the backslashes of a Windows path
        // would be read as escape sequences and the document would not parse.
        let toml_str = format!(
            r#"
[server]
host = "127.0.0.1"
port = 3000

[sampling]
temperature = 0.5
top_k = 50
top_p = 0.95
repetition_penalty = 1.2
max_tokens = 1024

[model]
model_path = '{}'
tokenizer_path = '{}'
max_seq_len = 8192

[observability]
log_level = "debug"
json_logs = true
"#,
            model_path.display(),
            tokenizer_path.display()
        );
        let cfg: OxiBonsaiConfig = toml::from_str(&toml_str).expect("should parse valid TOML");
        assert_eq!(cfg.server.host, "127.0.0.1");
        assert_eq!(cfg.server.port, 3000);
        assert!((cfg.sampling.temperature - 0.5).abs() < f32::EPSILON);
        assert_eq!(cfg.sampling.top_k, 50);
        assert!((cfg.sampling.top_p - 0.95).abs() < f32::EPSILON);
        assert!((cfg.sampling.repetition_penalty - 1.2).abs() < f32::EPSILON);
        assert_eq!(cfg.sampling.max_tokens, 1024);
        assert_eq!(
            cfg.model.model_path.as_deref(),
            Some(model_path.to_str().expect("path is valid UTF-8"))
        );
        assert_eq!(
            cfg.model.tokenizer_path.as_deref(),
            Some(tokenizer_path.to_str().expect("path is valid UTF-8"))
        );
        assert_eq!(cfg.model.max_seq_len, 8192);
        assert_eq!(cfg.observability.log_level, "debug");
        assert!(cfg.observability.json_logs);
    }

    #[test]
    fn partial_toml_uses_defaults() {
        let toml_str = r#"
[server]
port = 9090
"#;
        let cfg: OxiBonsaiConfig = toml::from_str(toml_str).expect("should parse partial TOML");
        assert_eq!(cfg.server.port, 9090);
        // Everything not mentioned in the document keeps its default.
        assert_eq!(cfg.server.host, "0.0.0.0");
        assert!((cfg.sampling.temperature - 0.7).abs() < f32::EPSILON);
        assert_eq!(cfg.model.max_seq_len, 4096);
    }

    #[test]
    fn missing_file_returns_default() {
        let path = std::env::temp_dir().join("nonexistent_oxibonsai_config_12345.toml");
        let cfg = OxiBonsaiConfig::load_or_default(Some(&path));
        assert_eq!(cfg.server.port, 8080);
    }

    #[test]
    fn load_or_default_none_returns_default() {
        let cfg = OxiBonsaiConfig::load_or_default(None);
        assert_eq!(cfg.server.host, "0.0.0.0");
    }

    // ---- validate --------------------------------------------------------

    #[test]
    fn validate_defaults_ok() {
        let cfg = OxiBonsaiConfig::default();
        assert!(cfg.validate().is_ok());
    }

    #[test]
    fn validate_negative_temperature() {
        let mut cfg = OxiBonsaiConfig::default();
        cfg.sampling.temperature = -1.0;
        assert!(cfg.validate().is_err());
    }

    #[test]
    fn validate_top_p_out_of_range() {
        let mut cfg = OxiBonsaiConfig::default();
        cfg.sampling.top_p = 1.5;
        assert!(cfg.validate().is_err());

        cfg.sampling.top_p = -0.1;
        assert!(cfg.validate().is_err());
    }

    #[test]
    fn validate_repetition_penalty_too_low() {
        let mut cfg = OxiBonsaiConfig::default();
        cfg.sampling.repetition_penalty = 0.5;
        assert!(cfg.validate().is_err());
    }

    #[test]
    fn validate_max_tokens_zero() {
        let mut cfg = OxiBonsaiConfig::default();
        cfg.sampling.max_tokens = 0;
        assert!(cfg.validate().is_err());
    }

    #[test]
    fn validate_max_seq_len_zero() {
        let mut cfg = OxiBonsaiConfig::default();
        cfg.model.max_seq_len = 0;
        assert!(cfg.validate().is_err());
    }

    #[test]
    fn validate_empty_host() {
        let mut cfg = OxiBonsaiConfig::default();
        cfg.server.host = String::new();
        assert!(cfg.validate().is_err());
    }

    // ---- dry_run_check ---------------------------------------------------

    #[test]
    fn dry_run_no_model_path() {
        let cfg = OxiBonsaiConfig::default();
        let warnings = cfg.dry_run_check();
        assert!(warnings.iter().any(|w| w.field == "model.model_path"));
    }

    #[test]
    fn dry_run_nonexistent_model() {
        let mut cfg = OxiBonsaiConfig::default();
        cfg.model.model_path = Some(
            std::env::temp_dir()
                .join("nonexistent_oxibonsai_test_99999.gguf")
                .display()
                .to_string(),
        );
        let warnings = cfg.dry_run_check();
        let model_warning = warnings
            .iter()
            .find(|w| w.field == "model.model_path")
            .expect("should have model warning");
        assert_eq!(model_warning.severity, WarningSeverity::Error);
    }

    #[test]
    fn dry_run_high_temperature() {
        let mut cfg = OxiBonsaiConfig::default();
        cfg.sampling.temperature = 3.0;
        let warnings = cfg.dry_run_check();
        assert!(warnings.iter().any(|w| w.field == "sampling.temperature"));
    }

    #[test]
    fn dry_run_large_seq_len() {
        let mut cfg = OxiBonsaiConfig::default();
        cfg.model.max_seq_len = 100_000;
        let warnings = cfg.dry_run_check();
        assert!(warnings.iter().any(|w| w.field == "model.max_seq_len"));
    }

    // ---- Display impls ---------------------------------------------------

    #[test]
    fn warning_severity_display() {
        assert_eq!(format!("{}", WarningSeverity::Info), "info");
        assert_eq!(format!("{}", WarningSeverity::Warning), "warning");
        assert_eq!(format!("{}", WarningSeverity::Error), "error");
    }

    #[test]
    fn config_warning_display() {
        let w = ConfigWarning {
            field: "test.field".to_string(),
            message: "test message".to_string(),
            severity: WarningSeverity::Warning,
        };
        let s = format!("{}", w);
        assert!(s.contains("warning"));
        assert!(s.contains("test.field"));
        assert!(s.contains("test message"));
    }

    // ---- load ------------------------------------------------------------

    #[test]
    fn load_from_temp_file() {
        // Per-process file name so concurrent test runs sharing the temp
        // directory do not overwrite or delete each other's file.
        let path = std::env::temp_dir().join(format!(
            "oxibonsai_test_config_{}.toml",
            std::process::id()
        ));
        std::fs::write(
            &path,
            r#"
[server]
host = "10.0.0.1"
port = 4444
"#,
        )
        .expect("write temp config");

        // Remove the file before asserting so a failed load or assertion
        // does not leave it behind.
        let loaded = OxiBonsaiConfig::load(&path);
        let _ = std::fs::remove_file(&path);

        let cfg = loaded.expect("should load temp config");
        assert_eq!(cfg.server.host, "10.0.0.1");
        assert_eq!(cfg.server.port, 4444);
    }
}