1use clap::Parser;
5use config::ConfigError;
6use modelexpress_common::config::{LogFormat, LogLevel, load_layered_config};
7use serde::{Deserialize, Serialize};
8use std::net::SocketAddr;
9use std::num::NonZeroU16;
10use std::path::PathBuf;
11use tracing::Level;
12
13use crate::cache::CacheEvictionConfig;
14
// Command-line arguments for the server.
//
// Every field except `validate_config` is optional; values that are present are
// applied on top of the loaded configuration in `ServerConfig::load_internal`.
//
// NOTE: plain `//` comments are used here on purpose. clap's derive macro turns
// `///` doc comments into help text, so adding them would change `--help` output.
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
pub struct ServerArgs {
    // Path to a configuration file (`-c` / `--config <FILE>`).
    #[arg(short, long, value_name = "FILE")]
    pub config: Option<PathBuf>,

    // Overrides `server.port` (`-p` / `--port`, or `MODEL_EXPRESS_SERVER_PORT`).
    // The `NonZeroU16` type rules out port 0.
    #[arg(short, long, env = "MODEL_EXPRESS_SERVER_PORT")]
    pub port: Option<NonZeroU16>,

    // Overrides `server.host` (`--host`, or `MODEL_EXPRESS_SERVER_HOST`).
    #[arg(long, env = "MODEL_EXPRESS_SERVER_HOST")]
    pub host: Option<String>,

    // Overrides `logging.level` (`-l` / `--log-level`, or `MODEL_EXPRESS_LOG_LEVEL`).
    #[arg(short, long, env = "MODEL_EXPRESS_LOG_LEVEL", value_enum)]
    pub log_level: Option<LogLevel>,

    // Overrides `logging.format` (`--log-format`, or `MODEL_EXPRESS_LOG_FORMAT`).
    #[arg(long, env = "MODEL_EXPRESS_LOG_FORMAT", value_enum)]
    pub log_format: Option<LogFormat>,

    // Overrides `database.path` (`-d` / `--database-path`, or `MODEL_EXPRESS_DATABASE_PATH`).
    #[arg(short, long, env = "MODEL_EXPRESS_DATABASE_PATH")]
    pub database_path: Option<PathBuf>,

    // Overrides `cache.directory` (`--cache-directory`, or `MODEL_EXPRESS_CACHE_DIRECTORY`).
    #[arg(long, env = "MODEL_EXPRESS_CACHE_DIRECTORY")]
    pub cache_directory: Option<PathBuf>,

    // Overrides `cache.eviction.enabled` (`--cache-eviction-enabled`, or
    // `MODEL_EXPRESS_CACHE_EVICTION_ENABLED`).
    #[arg(long, env = "MODEL_EXPRESS_CACHE_EVICTION_ENABLED")]
    pub cache_eviction_enabled: Option<bool>,

    // `--validate-config` flag. It is not read anywhere in this file; presumably the
    // caller uses it to pick `ServerConfig::load_and_validate_strict` — TODO confirm.
    #[arg(long)]
    pub validate_config: bool,
}
55
/// Complete server configuration, grouped into one section per concern.
///
/// `Default` is derived, so the default value is built from each section's own
/// `Default` implementation.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct ServerConfig {
    /// Listen address settings (host and port).
    pub server: ServerSettings,
    /// Database location.
    pub database: DatabaseSettings,
    /// Cache directory, size limit and eviction settings.
    pub cache: CacheConfig,
    /// Logging level, format and output settings.
    pub logging: LoggingConfig,
}
68
/// Address the server listens on.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ServerSettings {
    /// Host part of the listen address; combined with `port` in `ServerConfig::socket_addr`.
    pub host: String,
    /// Listen port; `NonZeroU16` makes port 0 unrepresentable.
    pub port: NonZeroU16,
}
77
/// Database settings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DatabaseSettings {
    /// Path of the database file. `ServerConfig::validate` requires its parent
    /// directory to exist.
    pub path: PathBuf,
}
84
/// Cache settings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CacheConfig {
    /// Eviction settings (see `crate::cache::CacheEvictionConfig`).
    pub eviction: CacheEvictionConfig,
    /// Cache directory. `ServerConfig::validate` requires its parent directory to exist.
    pub directory: PathBuf,
    /// Optional size limit in bytes. Presumably `None` means "no limit" — the code
    /// that enforces it is not in this file; TODO confirm.
    pub max_size_bytes: Option<u64>,
}
95
/// Logging settings.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct LoggingConfig {
    /// Log level; falls back to `LogLevel::default()` when absent from the input.
    #[serde(default)]
    pub level: LogLevel,
    /// Log output format; falls back to `LogFormat::default()` when absent from the input.
    #[serde(default)]
    pub format: LogFormat,
    /// Optional path; presumably the file that log output is written to — the
    /// consumer is not in this file, TODO confirm.
    pub file: Option<PathBuf>,
    /// Structured-logging switch. Its exact effect is decided by code outside this file.
    pub structured: bool,
}
110
111impl Default for ServerSettings {
112 fn default() -> Self {
113 Self {
114 host: "0.0.0.0".to_string(),
115 port: modelexpress_common::constants::DEFAULT_GRPC_PORT,
116 }
117 }
118}
119
120impl Default for DatabaseSettings {
121 fn default() -> Self {
122 Self {
123 path: PathBuf::from("./models.db"),
124 }
125 }
126}
127
128impl Default for CacheConfig {
129 fn default() -> Self {
130 Self {
131 eviction: CacheEvictionConfig::default(),
132 directory: PathBuf::from("./cache"),
133 max_size_bytes: None,
134 }
135 }
136}
137
138impl ServerConfig {
139 pub fn load(args: ServerArgs) -> Result<Self, ConfigError> {
145 Self::load_internal(args, false)
146 }
147
148 pub fn load_and_validate_strict(args: ServerArgs) -> Result<Self, ConfigError> {
152 Self::load_internal(args, true)
153 }
154
155 fn load_internal(args: ServerArgs, strict_mode: bool) -> Result<Self, ConfigError> {
157 let mut config = if strict_mode {
158 if let Some(ref config_file) = args.config {
160 modelexpress_common::config::validate_config_file(config_file)?
162 } else {
163 Self::default()
165 }
166 } else {
167 load_layered_config(args.config.clone(), "MODEL_EXPRESS", Self::default())?
169 };
170
171 if let Some(port) = args.port {
173 config.server.port = port;
174 }
175
176 if let Some(host) = args.host {
177 config.server.host = host;
178 }
179
180 if let Some(log_level) = args.log_level {
181 config.logging.level = log_level;
182 }
183
184 if let Some(log_format) = args.log_format {
185 config.logging.format = log_format;
186 }
187
188 if let Some(database_path) = args.database_path {
189 config.database.path = database_path;
190 }
191
192 if let Some(cache_directory) = args.cache_directory {
194 config.cache.directory = cache_directory;
195 }
196
197 if let Some(cache_eviction_enabled) = args.cache_eviction_enabled {
198 config.cache.eviction.enabled = cache_eviction_enabled;
199 }
200
201 config.validate()?;
203
204 Ok(config)
205 }
206
207 pub fn validate(&self) -> Result<(), ConfigError> {
209 if let Some(parent) = self.database.path.parent()
211 && !parent.exists()
212 {
213 return Err(ConfigError::Message(format!(
214 "Database directory does not exist: {}",
215 parent.display()
216 )));
217 }
218
219 if let Some(parent) = self.cache.directory.parent()
221 && !parent.exists()
222 {
223 return Err(ConfigError::Message(format!(
224 "Cache directory parent does not exist: {}",
225 parent.display()
226 )));
227 }
228
229 Ok(())
230 }
231
232 pub fn socket_addr(&self) -> Result<SocketAddr, ConfigError> {
234 let addr = format!("{}:{}", self.server.host, self.server.port);
235 addr.parse()
236 .map_err(|e| ConfigError::Message(format!("Invalid server address {addr}: {e}")))
237 }
238
239 pub fn log_level(&self) -> Level {
241 self.logging.level.into()
242 }
243
244 pub fn print_config(&self) {
246 use tracing::info;
247
248 info!("Server Configuration:");
249 info!(" Host: {}", self.server.host);
250 info!(" Port: {}", self.server.port);
251
252 info!("Database Configuration:");
253 info!(" Path: {}", self.database.path.display());
254
255 info!("Cache Configuration:");
256 info!(" Directory: {}", self.cache.directory.display());
257 info!(" Max Size: {:?}", self.cache.max_size_bytes);
258 info!(" Eviction Enabled: {}", self.cache.eviction.enabled);
259 info!(
260 " Eviction Check Interval: {}s",
261 self.cache.eviction.check_interval.num_seconds()
262 );
263
264 info!("Logging Configuration:");
265 info!(" Level: {}", self.logging.level);
266 info!(" Format: {}", self.logging.format);
267 info!(" File: {:?}", self.logging.file);
268 info!(" Structured: {}", self.logging.structured);
269 }
270}
271
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::Duration;
    use clap::Parser;
    use modelexpress_common::config::{DurationConfig, parse_duration_string};
    use std::fs;
    use tempfile::tempdir;

    /// Each supported `--log-level` value parses from the command line, prints
    /// back as the same lowercase string, and maps to the matching `tracing::Level`.
    #[test]
    fn test_log_level_enum_parsing() {
        let valid_levels = [
            ("trace", LogLevel::Trace),
            ("debug", LogLevel::Debug),
            ("info", LogLevel::Info),
            ("warn", LogLevel::Warn),
            ("error", LogLevel::Error),
        ];

        for (level_str, expected_level) in &valid_levels {
            let args = vec!["test", "--log-level", level_str];
            if let Ok(parsed_args) = ServerArgs::try_parse_from(args) {
                assert_eq!(parsed_args.log_level, Some(*expected_level));

                // `Display` must round-trip to the string that was parsed.
                assert_eq!(expected_level.to_string(), *level_str);

                // Conversion into the tracing crate's level type.
                let tracing_level: Level = (*expected_level).into();
                match expected_level {
                    LogLevel::Trace => assert_eq!(tracing_level, Level::TRACE),
                    LogLevel::Debug => assert_eq!(tracing_level, Level::DEBUG),
                    LogLevel::Info => assert_eq!(tracing_level, Level::INFO),
                    LogLevel::Warn => assert_eq!(tracing_level, Level::WARN),
                    LogLevel::Error => assert_eq!(tracing_level, Level::ERROR),
                }
            } else {
                panic!("Failed to parse valid log level: {level_str}");
            }
        }
    }

    /// An unknown `--log-level` value is rejected by the argument parser.
    #[test]
    fn test_log_level_enum_invalid() {
        let args = vec!["test", "--log-level", "invalid"];
        let result = ServerArgs::try_parse_from(args);
        assert!(result.is_err());
    }

    /// `LogLevel` displays as its lowercase name.
    #[test]
    fn test_log_level_display() {
        assert_eq!(LogLevel::Trace.to_string(), "trace");
        assert_eq!(LogLevel::Debug.to_string(), "debug");
        assert_eq!(LogLevel::Info.to_string(), "info");
        assert_eq!(LogLevel::Warn.to_string(), "warn");
        assert_eq!(LogLevel::Error.to_string(), "error");
    }

    /// `LogLevel` parses from strings regardless of case and rejects unknown names.
    #[test]
    #[allow(clippy::expect_used)] // `expect` is fine in tests: a failure should abort the test.
    fn test_log_level_from_str() {
        assert_eq!(
            "trace"
                .parse::<LogLevel>()
                .expect("Failed to parse 'trace'"),
            LogLevel::Trace
        );
        assert_eq!(
            "debug"
                .parse::<LogLevel>()
                .expect("Failed to parse 'debug'"),
            LogLevel::Debug
        );
        assert_eq!(
            "info".parse::<LogLevel>().expect("Failed to parse 'info'"),
            LogLevel::Info
        );
        assert_eq!(
            "warn".parse::<LogLevel>().expect("Failed to parse 'warn'"),
            LogLevel::Warn
        );
        assert_eq!(
            "error"
                .parse::<LogLevel>()
                .expect("Failed to parse 'error'"),
            LogLevel::Error
        );

        // Upper-case and mixed-case spellings are accepted too.
        assert_eq!(
            "TRACE"
                .parse::<LogLevel>()
                .expect("Failed to parse 'TRACE'"),
            LogLevel::Trace
        );
        assert_eq!(
            "Info".parse::<LogLevel>().expect("Failed to parse 'Info'"),
            LogLevel::Info
        );

        // Unknown names are an error.
        assert!("invalid".parse::<LogLevel>().is_err());
    }

    /// `LogFormat` displays as its lowercase name.
    #[test]
    fn test_log_format_display() {
        assert_eq!(LogFormat::Json.to_string(), "json");
        assert_eq!(LogFormat::Pretty.to_string(), "pretty");
        assert_eq!(LogFormat::Compact.to_string(), "compact");
    }

    /// `LogFormat` parses from strings regardless of case and rejects unknown names.
    #[test]
    #[allow(clippy::expect_used)] // `expect` is fine in tests: a failure should abort the test.
    fn test_log_format_from_str() {
        assert_eq!(
            "json".parse::<LogFormat>().expect("Failed to parse 'json'"),
            LogFormat::Json
        );
        assert_eq!(
            "pretty"
                .parse::<LogFormat>()
                .expect("Failed to parse 'pretty'"),
            LogFormat::Pretty
        );
        assert_eq!(
            "compact"
                .parse::<LogFormat>()
                .expect("Failed to parse 'compact'"),
            LogFormat::Compact
        );

        // Upper-case and mixed-case spellings are accepted too.
        assert_eq!(
            "JSON".parse::<LogFormat>().expect("Failed to parse 'JSON'"),
            LogFormat::Json
        );
        assert_eq!(
            "Pretty"
                .parse::<LogFormat>()
                .expect("Failed to parse 'Pretty'"),
            LogFormat::Pretty
        );

        // Unknown names are an error.
        assert!("invalid".parse::<LogFormat>().is_err());
    }

    /// Duration strings with `s`, `m`, `h`, `d` suffixes (and combinations) parse
    /// to the expected number of seconds; garbage is rejected.
    #[test]
    #[allow(clippy::expect_used)] // `expect` is fine in tests: a failure should abort the test.
    fn test_parse_duration_string() {
        assert_eq!(
            parse_duration_string("30m")
                .expect("Failed to parse 30m")
                .num_seconds(),
            30 * 60
        );
        assert_eq!(
            parse_duration_string("45s")
                .expect("Failed to parse 45s")
                .num_seconds(),
            45
        );
        assert_eq!(
            parse_duration_string("1d")
                .expect("Failed to parse 1d")
                .num_seconds(),
            24 * 3600
        );
        assert_eq!(
            parse_duration_string("2h")
                .expect("Failed to parse 2h")
                .num_seconds(),
            2 * 3600
        );
        assert_eq!(
            parse_duration_string("2h30m")
                .expect("Failed to parse 2h30m")
                .num_seconds(),
            2 * 3600 + 30 * 60
        );

        // Unparseable input is an error.
        assert!(parse_duration_string("invalid").is_err());
    }

    /// `DurationConfig` constructors, accessors and `Display`.
    #[test]
    fn test_duration_config() {
        // Construction from a chrono duration.
        let duration_config = DurationConfig::new(Duration::hours(2));
        assert_eq!(duration_config.num_seconds(), 2 * 3600);
        assert_eq!(duration_config.as_chrono_duration(), Duration::hours(2));

        // Convenience constructor.
        let duration_config = DurationConfig::hours(3);
        assert_eq!(duration_config.num_seconds(), 3 * 3600);

        // `Display` prints whole seconds with an `s` suffix.
        assert_eq!(duration_config.to_string(), "10800s");
    }

    /// `DurationConfig` deserializes from a duration string or a bare number of
    /// seconds, and serializes as a number of seconds.
    #[test]
    #[allow(clippy::expect_used)] // `expect` is fine in tests: a failure should abort the test.
    fn test_duration_config_serde() {
        // String form.
        let json = r#""2h""#;
        let duration_config: DurationConfig = serde_json::from_str(json).expect("Failed to parse");
        assert_eq!(duration_config.num_seconds(), 2 * 3600);

        // Numeric form (seconds).
        let json = r#"3600"#;
        let duration_config: DurationConfig = serde_json::from_str(json).expect("Failed to parse");
        assert_eq!(duration_config.num_seconds(), 3600);

        // Serialization.
        let duration_config = DurationConfig::hours(1);
        let serialized = serde_json::to_string(&duration_config).expect("Failed to serialize");
        assert_eq!(serialized, r#"3600"#);
    }

    /// A complete, well-formed config file passes strict loading and its values
    /// end up in the resulting `ServerConfig`.
    #[test]
    #[allow(clippy::expect_used)] // `expect` is fine in tests: a failure should abort the test.
    fn test_server_config_load_and_validate_strict_valid_config() {
        let temp_dir = tempdir().expect("Failed to create temp dir");
        let config_file = temp_dir.path().join("valid_server_config.yaml");

        // The nesting matters: each section's keys must be indented under it.
        let valid_config = r#"
            server:
              host: "127.0.0.1"
              port: 8002
            database:
              path: "./test.db"
            cache:
              eviction:
                enabled: false
                policy:
                  type: lru
                  unused_threshold: "3d"
                  max_models: 10
                  min_free_space_bytes: 1000000
                check_interval: "30m"
              directory: "./test_cache"
              max_size_bytes: 5000000
            logging:
              level: Debug
              format: Json
              file: null
              structured: true
        "#;

        fs::write(&config_file, valid_config).expect("Failed to write config file");

        let args = ServerArgs {
            config: Some(config_file),
            port: None,
            host: None,
            log_level: None,
            log_format: None,
            database_path: None,
            cache_directory: None,
            cache_eviction_enabled: None,
            validate_config: false,
        };

        let result = ServerConfig::load_and_validate_strict(args);
        assert!(result.is_ok());

        let config = result.expect("Expected successful config parsing");
        assert_eq!(config.server.host, "127.0.0.1");
        assert_eq!(config.server.port.get(), 8002);
        assert_eq!(config.database.path, PathBuf::from("./test.db"));
        assert!(!config.cache.eviction.enabled);
        assert_eq!(config.logging.level, LogLevel::Debug);
        assert_eq!(config.logging.format, LogFormat::Json);
    }

    /// A config file with a misspelled key and a wrongly typed value is rejected
    /// by strict loading.
    #[test]
    #[allow(clippy::expect_used)] // `expect` is fine in tests: a failure should abort the test.
    fn test_server_config_load_and_validate_strict_invalid_config() {
        let temp_dir = tempdir().expect("Failed to create temp dir");
        let config_file = temp_dir.path().join("invalid_server_config.yaml");

        let invalid_config = r#"
            server:
              host: "127.0.0.1"
              port: 8002
            database:
              pat: "./test.db" # Wrong field name (should be 'path')
            cache:
              eviction:
                enabled: "not_a_boolean" # Invalid type
        "#;

        fs::write(&config_file, invalid_config).expect("Failed to write config file");

        let args = ServerArgs {
            config: Some(config_file),
            port: None,
            host: None,
            log_level: None,
            log_format: None,
            database_path: None,
            cache_directory: None,
            cache_eviction_enabled: None,
            validate_config: false,
        };

        let result = ServerConfig::load_and_validate_strict(args);
        assert!(result.is_err());
    }

    /// CLI arguments take precedence over the values read from the config file.
    #[test]
    #[allow(clippy::expect_used)] // `expect` is fine in tests: a failure should abort the test.
    fn test_server_config_load_and_validate_strict_with_cli_overrides() {
        let temp_dir = tempdir().expect("Failed to create temp dir");
        let config_file = temp_dir.path().join("override_test_config.yaml");

        // `validate()` requires the cache directory's parent to exist, so place the
        // override inside the temp dir instead of relying on a platform path like `/tmp`.
        let override_cache = temp_dir.path().join("override_cache");

        let base_config = r#"
            server:
              host: "127.0.0.1"
              port: 8002
            database:
              path: "./test.db"
            cache:
              eviction:
                enabled: true
                policy:
                  type: lru
                  unused_threshold: "1d"
                  max_models: null
                  min_free_space_bytes: null
                check_interval: "1h"
              directory: "./cache"
              max_size_bytes: null
            logging:
              level: Info
              format: Pretty
              file: null
              structured: false
        "#;

        fs::write(&config_file, base_config).expect("Failed to write config file");

        let args = ServerArgs {
            config: Some(config_file),
            port: Some(NonZeroU16::new(9000).expect("9000 is non-zero")),
            host: Some("0.0.0.0".to_string()),
            log_level: Some(LogLevel::Error),
            log_format: Some(LogFormat::Json),
            database_path: Some(PathBuf::from("./override.db")),
            cache_directory: Some(override_cache.clone()),
            cache_eviction_enabled: Some(false),
            validate_config: false,
        };

        let result = ServerConfig::load_and_validate_strict(args);
        assert!(result.is_ok());

        let config = result.expect("Expected successful config parsing");
        // Every value below comes from the CLI arguments, not from the file.
        assert_eq!(config.server.host, "0.0.0.0");
        assert_eq!(config.server.port.get(), 9000);
        assert_eq!(config.database.path, PathBuf::from("./override.db"));
        assert_eq!(config.logging.level, LogLevel::Error);
        assert_eq!(config.logging.format, LogFormat::Json);
        assert_eq!(config.cache.directory, override_cache);
        assert!(!config.cache.eviction.enabled);
    }

    /// Without a config file, strict loading starts from the defaults and still
    /// applies CLI overrides.
    #[test]
    #[allow(clippy::expect_used)] // `expect` is fine in tests: a failure should abort the test.
    fn test_server_config_load_and_validate_strict_no_config_file() {
        let args = ServerArgs {
            config: None,
            port: Some(NonZeroU16::new(9001).expect("9001 is non-zero")),
            host: Some("localhost".to_string()),
            log_level: Some(LogLevel::Warn),
            log_format: None,
            database_path: None,
            cache_directory: None,
            cache_eviction_enabled: None,
            validate_config: false,
        };

        let result = ServerConfig::load_and_validate_strict(args);
        // Defaults plus overrides must validate successfully.
        assert!(result.is_ok());

        let config = result.expect("Expected successful config parsing");
        assert_eq!(config.server.host, "localhost");
        assert_eq!(config.server.port.get(), 9001);
        assert_eq!(config.logging.level, LogLevel::Warn);
    }
}