1use clap::Parser;
5use config::ConfigError;
6use modelexpress_common::config::{LogFormat, LogLevel, load_layered_config};
7use serde::{Deserialize, Serialize};
8use std::net::SocketAddr;
9use std::num::NonZeroU16;
10use std::path::PathBuf;
11use tracing::Level;
12
13use crate::cache::CacheEvictionConfig;
14
15#[derive(Parser, Debug)]
17#[command(author, version, about, long_about = None)]
18pub struct ServerArgs {
19 #[arg(short, long, value_name = "FILE")]
21 pub config: Option<PathBuf>,
22
23 #[arg(short, long, env = "MODEL_EXPRESS_SERVER_PORT")]
25 pub port: Option<NonZeroU16>,
26
27 #[arg(long, env = "MODEL_EXPRESS_SERVER_HOST")]
29 pub host: Option<String>,
30
31 #[arg(short, long, env = "MODEL_EXPRESS_LOG_LEVEL", value_enum)]
33 pub log_level: Option<LogLevel>,
34
35 #[arg(long, env = "MODEL_EXPRESS_LOG_FORMAT", value_enum)]
37 pub log_format: Option<LogFormat>,
38
39 #[arg(long, env = "MODEL_EXPRESS_CACHE_DIRECTORY")]
41 pub cache_directory: Option<PathBuf>,
42
43 #[arg(long, env = "MODEL_EXPRESS_CACHE_EVICTION_ENABLED")]
45 pub cache_eviction_enabled: Option<bool>,
46
47 #[arg(long)]
49 pub validate_config: bool,
50}
51
52#[derive(Debug, Clone, Serialize, Deserialize, Default)]
54pub struct ServerConfig {
55 pub server: ServerSettings,
57 pub cache: CacheConfig,
59 pub logging: LoggingConfig,
61}
62
63#[derive(Debug, Clone, Serialize, Deserialize)]
65pub struct ServerSettings {
66 pub host: String,
68 pub port: NonZeroU16,
70}
71
72#[derive(Debug, Clone, Serialize, Deserialize)]
74pub struct CacheConfig {
75 pub eviction: CacheEvictionConfig,
77 pub directory: PathBuf,
79 pub max_size_bytes: Option<u64>,
81}
82
83#[derive(Debug, Clone, Serialize, Deserialize, Default)]
85pub struct LoggingConfig {
86 #[serde(default)]
88 pub level: LogLevel,
89 #[serde(default)]
91 pub format: LogFormat,
92 pub file: Option<PathBuf>,
94 pub structured: bool,
96}
97
98impl Default for ServerSettings {
99 fn default() -> Self {
100 Self {
101 host: "0.0.0.0".to_string(),
102 port: modelexpress_common::constants::DEFAULT_GRPC_PORT,
103 }
104 }
105}
106
107impl Default for CacheConfig {
108 fn default() -> Self {
109 Self {
110 eviction: CacheEvictionConfig::default(),
111 directory: PathBuf::from("./cache"),
112 max_size_bytes: None,
113 }
114 }
115}
116
117impl ServerConfig {
118 pub fn load(args: ServerArgs) -> Result<Self, ConfigError> {
124 Self::load_internal(args, false)
125 }
126
127 pub fn load_and_validate_strict(args: ServerArgs) -> Result<Self, ConfigError> {
131 Self::load_internal(args, true)
132 }
133
134 fn load_internal(args: ServerArgs, strict_mode: bool) -> Result<Self, ConfigError> {
136 let mut config = if strict_mode {
137 if let Some(ref config_file) = args.config {
139 modelexpress_common::config::validate_config_file(config_file)?
141 } else {
142 Self::default()
144 }
145 } else {
146 load_layered_config(args.config.clone(), "MODEL_EXPRESS", Self::default())?
148 };
149
150 if let Some(port) = args.port {
152 config.server.port = port;
153 }
154
155 if let Some(host) = args.host {
156 config.server.host = host;
157 }
158
159 if let Some(log_level) = args.log_level {
160 config.logging.level = log_level;
161 }
162
163 if let Some(log_format) = args.log_format {
164 config.logging.format = log_format;
165 }
166
167 if let Some(cache_directory) = args.cache_directory {
169 config.cache.directory = cache_directory;
170 }
171
172 if let Some(cache_eviction_enabled) = args.cache_eviction_enabled {
173 config.cache.eviction.enabled = cache_eviction_enabled;
174 }
175
176 config.validate()?;
178
179 Ok(config)
180 }
181
182 pub fn validate(&self) -> Result<(), ConfigError> {
184 if let Some(parent) = self.cache.directory.parent()
186 && !parent.exists()
187 {
188 return Err(ConfigError::Message(format!(
189 "Cache directory parent does not exist: {}",
190 parent.display()
191 )));
192 }
193
194 Ok(())
195 }
196
197 pub fn socket_addr(&self) -> Result<SocketAddr, ConfigError> {
199 let addr = format!("{}:{}", self.server.host, self.server.port);
200 addr.parse()
201 .map_err(|e| ConfigError::Message(format!("Invalid server address {addr}: {e}")))
202 }
203
204 pub fn log_level(&self) -> Level {
206 self.logging.level.into()
207 }
208
209 pub fn print_config(&self) {
211 use tracing::info;
212
213 info!("Server Configuration:");
214 info!(" Host: {}", self.server.host);
215 info!(" Port: {}", self.server.port);
216
217 info!("Cache Configuration:");
218 info!(" Directory: {}", self.cache.directory.display());
219 info!(" Max Size: {:?}", self.cache.max_size_bytes);
220 info!(" Eviction Enabled: {}", self.cache.eviction.enabled);
221 info!(
222 " Eviction Check Interval: {}s",
223 self.cache.eviction.check_interval.num_seconds()
224 );
225
226 info!("Logging Configuration:");
227 info!(" Level: {}", self.logging.level);
228 info!(" Format: {}", self.logging.format);
229 info!(" File: {:?}", self.logging.file);
230 info!(" Structured: {}", self.logging.structured);
231 }
232}
233
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::Duration;
    use clap::Parser;
    use modelexpress_common::config::{DurationConfig, parse_duration_string};
    use std::fs;
    use tempfile::tempdir;

    /// `ServerArgs` with no CLI overrides, optionally pointing at a config file.
    fn args_with_config(config: Option<PathBuf>) -> ServerArgs {
        ServerArgs {
            config,
            port: None,
            host: None,
            log_level: None,
            log_format: None,
            cache_directory: None,
            cache_eviction_enabled: None,
            validate_config: false,
        }
    }

    /// Writes `contents` to `file_name` inside `dir` and returns the file path.
    #[allow(clippy::expect_used)]
    fn write_config(dir: &std::path::Path, file_name: &str, contents: &str) -> PathBuf {
        let path = dir.join(file_name);
        fs::write(&path, contents).expect("Failed to write config file");
        path
    }

    #[test]
    fn test_log_level_enum_parsing() {
        let valid_levels = [
            ("trace", LogLevel::Trace),
            ("debug", LogLevel::Debug),
            ("info", LogLevel::Info),
            ("warn", LogLevel::Warn),
            ("error", LogLevel::Error),
        ];

        for (level_str, expected_level) in valid_levels {
            let parsed_args =
                match ServerArgs::try_parse_from(["test", "--log-level", level_str]) {
                    Ok(parsed) => parsed,
                    Err(e) => panic!("Failed to parse valid log level {level_str}: {e}"),
                };
            assert_eq!(parsed_args.log_level, Some(expected_level));

            // Display must round-trip to the CLI spelling.
            assert_eq!(expected_level.to_string(), level_str);

            let tracing_level: Level = expected_level.into();
            let expected_tracing = match expected_level {
                LogLevel::Trace => Level::TRACE,
                LogLevel::Debug => Level::DEBUG,
                LogLevel::Info => Level::INFO,
                LogLevel::Warn => Level::WARN,
                LogLevel::Error => Level::ERROR,
            };
            assert_eq!(tracing_level, expected_tracing);
        }
    }

    #[test]
    fn test_log_level_enum_invalid() {
        let result = ServerArgs::try_parse_from(["test", "--log-level", "invalid"]);
        assert!(result.is_err());
    }

    #[test]
    fn test_log_level_display() {
        assert_eq!(LogLevel::Trace.to_string(), "trace");
        assert_eq!(LogLevel::Debug.to_string(), "debug");
        assert_eq!(LogLevel::Info.to_string(), "info");
        assert_eq!(LogLevel::Warn.to_string(), "warn");
        assert_eq!(LogLevel::Error.to_string(), "error");
    }

    #[test]
    #[allow(clippy::expect_used)]
    fn test_log_level_from_str() {
        let cases = [
            ("trace", LogLevel::Trace),
            ("debug", LogLevel::Debug),
            ("info", LogLevel::Info),
            ("warn", LogLevel::Warn),
            ("error", LogLevel::Error),
            // Parsing is case-insensitive.
            ("TRACE", LogLevel::Trace),
            ("Info", LogLevel::Info),
        ];
        for (input, expected) in cases {
            let parsed = input
                .parse::<LogLevel>()
                .unwrap_or_else(|_| panic!("Failed to parse '{input}'"));
            assert_eq!(parsed, expected);
        }

        assert!("invalid".parse::<LogLevel>().is_err());
    }

    #[test]
    fn test_log_format_display() {
        assert_eq!(LogFormat::Json.to_string(), "json");
        assert_eq!(LogFormat::Pretty.to_string(), "pretty");
        assert_eq!(LogFormat::Compact.to_string(), "compact");
    }

    #[test]
    #[allow(clippy::expect_used)]
    fn test_log_format_from_str() {
        let cases = [
            ("json", LogFormat::Json),
            ("pretty", LogFormat::Pretty),
            ("compact", LogFormat::Compact),
            // Parsing is case-insensitive.
            ("JSON", LogFormat::Json),
            ("Pretty", LogFormat::Pretty),
        ];
        for (input, expected) in cases {
            let parsed = input
                .parse::<LogFormat>()
                .unwrap_or_else(|_| panic!("Failed to parse '{input}'"));
            assert_eq!(parsed, expected);
        }

        assert!("invalid".parse::<LogFormat>().is_err());
    }

    #[test]
    #[allow(clippy::expect_used)]
    fn test_parse_duration_string() {
        let cases = [
            ("30m", 30 * 60),
            ("45s", 45),
            ("1d", 24 * 3600),
            ("2h", 2 * 3600),
            ("2h30m", 2 * 3600 + 30 * 60),
        ];
        for (input, expected_seconds) in cases {
            let parsed = parse_duration_string(input)
                .unwrap_or_else(|_| panic!("Failed to parse {input}"));
            assert_eq!(parsed.num_seconds(), expected_seconds);
        }

        assert!(parse_duration_string("invalid").is_err());
    }

    #[test]
    fn test_duration_config() {
        let duration_config = DurationConfig::new(Duration::hours(2));
        assert_eq!(duration_config.num_seconds(), 2 * 3600);
        assert_eq!(duration_config.as_chrono_duration(), Duration::hours(2));

        let duration_config = DurationConfig::hours(3);
        assert_eq!(duration_config.num_seconds(), 3 * 3600);

        assert_eq!(duration_config.to_string(), "10800s");
    }

    #[test]
    #[allow(clippy::expect_used)]
    fn test_duration_config_serde() {
        // Human-readable strings and raw second counts are both accepted.
        let duration_config: DurationConfig =
            serde_json::from_str(r#""2h""#).expect("Failed to parse");
        assert_eq!(duration_config.num_seconds(), 2 * 3600);

        let duration_config: DurationConfig =
            serde_json::from_str("3600").expect("Failed to parse");
        assert_eq!(duration_config.num_seconds(), 3600);

        // Serialization emits plain seconds.
        let serialized =
            serde_json::to_string(&DurationConfig::hours(1)).expect("Failed to serialize");
        assert_eq!(serialized, "3600");
    }

    #[test]
    #[allow(clippy::expect_used)]
    fn test_server_config_load_and_validate_strict_valid_config() {
        let temp_dir = tempdir().expect("Failed to create temp dir");
        let valid_config = r#"
server:
  host: "127.0.0.1"
  port: 8002
cache:
  eviction:
    enabled: false
    policy:
      type: lru
      unused_threshold: "3d"
      max_models: 10
      min_free_space_bytes: 1000000
    check_interval: "30m"
  directory: "./test_cache"
  max_size_bytes: 5000000
logging:
  level: Debug
  format: Json
  file: null
  structured: true
"#;
        let config_file = write_config(temp_dir.path(), "valid_server_config.yaml", valid_config);

        // `expect` (rather than `assert!(is_ok())`) reports the ConfigError on failure.
        let config = ServerConfig::load_and_validate_strict(args_with_config(Some(config_file)))
            .expect("Expected successful config parsing");
        assert_eq!(config.server.host, "127.0.0.1");
        assert_eq!(config.server.port.get(), 8002);
        assert!(!config.cache.eviction.enabled);
        assert_eq!(config.logging.level, LogLevel::Debug);
        assert_eq!(config.logging.format, LogFormat::Json);
    }

    #[test]
    #[allow(clippy::expect_used)]
    fn test_server_config_load_and_validate_strict_invalid_config() {
        let temp_dir = tempdir().expect("Failed to create temp dir");
        let invalid_config = r#"
server:
  host: "127.0.0.1"
  port: 8002
cache:
  eviction:
    enabled: "not_a_boolean" # Invalid type
"#;
        let config_file =
            write_config(temp_dir.path(), "invalid_server_config.yaml", invalid_config);

        let result = ServerConfig::load_and_validate_strict(args_with_config(Some(config_file)));
        assert!(result.is_err());
    }

    #[test]
    #[allow(clippy::expect_used)]
    fn test_server_config_load_and_validate_strict_with_cli_overrides() {
        let temp_dir = tempdir().expect("Failed to create temp dir");
        let base_config = r#"
server:
  host: "127.0.0.1"
  port: 8002
cache:
  eviction:
    enabled: true
    policy:
      type: lru
      unused_threshold: "1d"
      max_models: null
      min_free_space_bytes: null
    check_interval: "1h"
  directory: "./cache"
  max_size_bytes: null
logging:
  level: Info
  format: Pretty
  file: null
  structured: false
"#;
        let config_file = write_config(temp_dir.path(), "override_test_config.yaml", base_config);
        // Keep the override inside the temp dir instead of a hard-coded `/tmp`
        // path, so `validate` finds an existing parent on every platform.
        let override_cache = temp_dir.path().join("override_cache");

        let args = ServerArgs {
            port: Some(NonZeroU16::new(9000).expect("9000 is non-zero")),
            host: Some("0.0.0.0".to_string()),
            log_level: Some(LogLevel::Error),
            log_format: Some(LogFormat::Json),
            cache_directory: Some(override_cache.clone()),
            cache_eviction_enabled: Some(false),
            ..args_with_config(Some(config_file))
        };

        let config = ServerConfig::load_and_validate_strict(args)
            .expect("Expected successful config parsing");
        // Every CLI value must win over the file.
        assert_eq!(config.server.host, "0.0.0.0");
        assert_eq!(config.server.port.get(), 9000);
        assert_eq!(config.logging.level, LogLevel::Error);
        assert_eq!(config.logging.format, LogFormat::Json);
        assert_eq!(config.cache.directory, override_cache);
        assert!(!config.cache.eviction.enabled);
    }

    #[test]
    #[allow(clippy::expect_used)]
    fn test_server_config_load_and_validate_strict_no_config_file() {
        let args = ServerArgs {
            port: Some(NonZeroU16::new(9001).expect("9001 is non-zero")),
            host: Some("localhost".to_string()),
            log_level: Some(LogLevel::Warn),
            ..args_with_config(None)
        };

        // Without a file, strict mode starts from defaults and applies overrides.
        let config = ServerConfig::load_and_validate_strict(args)
            .expect("Expected successful config parsing");
        assert_eq!(config.server.host, "localhost");
        assert_eq!(config.server.port.get(), 9001);
        assert_eq!(config.logging.level, LogLevel::Warn);
    }

    #[test]
    #[allow(clippy::expect_used)]
    fn test_default_config_is_valid_and_binds_unspecified_address() {
        let config = ServerConfig::default();
        config.validate().expect("default config should validate");

        let addr = config
            .socket_addr()
            .expect("default host should form a socket address");
        assert!(addr.ip().is_unspecified());
        assert_eq!(addr.port(), config.server.port.get());
    }
}