modelexpress_server/
config.rs

// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
3
4use clap::Parser;
5use config::ConfigError;
6use modelexpress_common::config::{LogFormat, LogLevel, load_layered_config};
7use serde::{Deserialize, Serialize};
8use std::net::SocketAddr;
9use std::num::NonZeroU16;
10use std::path::PathBuf;
11use tracing::Level;
12
13use crate::cache::CacheEvictionConfig;
14
15/// Command line arguments for the server
16#[derive(Parser, Debug)]
17#[command(author, version, about, long_about = None)]
18pub struct ServerArgs {
19    /// Configuration file path
20    #[arg(short, long, value_name = "FILE")]
21    pub config: Option<PathBuf>,
22
23    /// Server port
24    #[arg(short, long, env = "MODEL_EXPRESS_SERVER_PORT")]
25    pub port: Option<NonZeroU16>,
26
27    /// Server host address
28    #[arg(long, env = "MODEL_EXPRESS_SERVER_HOST")]
29    pub host: Option<String>,
30
31    /// Log level
32    #[arg(short, long, env = "MODEL_EXPRESS_LOG_LEVEL", value_enum)]
33    pub log_level: Option<LogLevel>,
34
35    /// Log format
36    #[arg(long, env = "MODEL_EXPRESS_LOG_FORMAT", value_enum)]
37    pub log_format: Option<LogFormat>,
38
39    /// Database file path
40    #[arg(short, long, env = "MODEL_EXPRESS_DATABASE_PATH")]
41    pub database_path: Option<PathBuf>,
42
43    /// Cache directory path
44    #[arg(long, env = "MODEL_EXPRESS_CACHE_DIRECTORY")]
45    pub cache_directory: Option<PathBuf>,
46
47    /// Enable cache eviction
48    #[arg(long, env = "MODEL_EXPRESS_CACHE_EVICTION_ENABLED")]
49    pub cache_eviction_enabled: Option<bool>,
50
51    /// Validate configuration and exit
52    #[arg(long)]
53    pub validate_config: bool,
54}
55
56/// Complete server configuration
57#[derive(Debug, Clone, Serialize, Deserialize, Default)]
58pub struct ServerConfig {
59    /// Server settings
60    pub server: ServerSettings,
61    /// Database settings
62    pub database: DatabaseSettings,
63    /// Cache configuration
64    pub cache: CacheConfig,
65    /// Logging configuration
66    pub logging: LoggingConfig,
67}
68
69/// Server-specific settings
70#[derive(Debug, Clone, Serialize, Deserialize)]
71pub struct ServerSettings {
72    /// Server host address
73    pub host: String,
74    /// Server port
75    pub port: NonZeroU16,
76}
77
78/// Database configuration
79#[derive(Debug, Clone, Serialize, Deserialize)]
80pub struct DatabaseSettings {
81    /// Database file path
82    pub path: PathBuf,
83}
84
85/// Cache configuration wrapper
86#[derive(Debug, Clone, Serialize, Deserialize)]
87pub struct CacheConfig {
88    /// Cache eviction settings
89    pub eviction: CacheEvictionConfig,
90    /// Cache directory path
91    pub directory: PathBuf,
92    /// Maximum cache size in bytes
93    pub max_size_bytes: Option<u64>,
94}
95
96/// Logging configuration
97#[derive(Debug, Clone, Serialize, Deserialize, Default)]
98pub struct LoggingConfig {
99    /// Log level
100    #[serde(default)]
101    pub level: LogLevel,
102    /// Log format (json, pretty, compact)
103    #[serde(default)]
104    pub format: LogFormat,
105    /// Log to file
106    pub file: Option<PathBuf>,
107    /// Enable structured logging
108    pub structured: bool,
109}
110
111impl Default for ServerSettings {
112    fn default() -> Self {
113        Self {
114            host: "0.0.0.0".to_string(),
115            port: modelexpress_common::constants::DEFAULT_GRPC_PORT,
116        }
117    }
118}
119
120impl Default for DatabaseSettings {
121    fn default() -> Self {
122        Self {
123            path: PathBuf::from("./models.db"),
124        }
125    }
126}
127
128impl Default for CacheConfig {
129    fn default() -> Self {
130        Self {
131            eviction: CacheEvictionConfig::default(),
132            directory: PathBuf::from("./cache"),
133            max_size_bytes: None,
134        }
135    }
136}
137
138impl ServerConfig {
139    /// Load configuration from multiple sources in order of precedence:
140    /// 1. Command line arguments (highest priority)
141    /// 2. Environment variables
142    /// 3. Configuration file
143    /// 4. Default values (lowest priority)
144    pub fn load(args: ServerArgs) -> Result<Self, ConfigError> {
145        Self::load_internal(args, false)
146    }
147
148    /// Load and validate configuration file strictly without fallbacks.
149    /// This method should be used when validating configuration files.
150    /// It will return an error if the file has invalid syntax or values.
151    pub fn load_and_validate_strict(args: ServerArgs) -> Result<Self, ConfigError> {
152        Self::load_internal(args, true)
153    }
154
155    /// Internal method to load configuration with optional strict mode
156    fn load_internal(args: ServerArgs, strict_mode: bool) -> Result<Self, ConfigError> {
157        let mut config = if strict_mode {
158            // Use strict loading - fail on any configuration errors
159            if let Some(ref config_file) = args.config {
160                // Load file strictly without fallbacks
161                modelexpress_common::config::validate_config_file(config_file)?
162            } else {
163                // No config file specified, use defaults
164                Self::default()
165            }
166        } else {
167            // Use layered config loading with fallbacks to defaults
168            load_layered_config(args.config.clone(), "MODEL_EXPRESS", Self::default())?
169        };
170
171        // Apply command line overrides (same for both modes)
172        if let Some(port) = args.port {
173            config.server.port = port;
174        }
175
176        if let Some(host) = args.host {
177            config.server.host = host;
178        }
179
180        if let Some(log_level) = args.log_level {
181            config.logging.level = log_level;
182        }
183
184        if let Some(log_format) = args.log_format {
185            config.logging.format = log_format;
186        }
187
188        if let Some(database_path) = args.database_path {
189            config.database.path = database_path;
190        }
191
192        // Apply cache overrides
193        if let Some(cache_directory) = args.cache_directory {
194            config.cache.directory = cache_directory;
195        }
196
197        if let Some(cache_eviction_enabled) = args.cache_eviction_enabled {
198            config.cache.eviction.enabled = cache_eviction_enabled;
199        }
200
201        // Validate the final configuration
202        config.validate()?;
203
204        Ok(config)
205    }
206
207    /// Validate the configuration
208    pub fn validate(&self) -> Result<(), ConfigError> {
209        // Validate database path parent directory exists
210        if let Some(parent) = self.database.path.parent()
211            && !parent.exists()
212        {
213            return Err(ConfigError::Message(format!(
214                "Database directory does not exist: {}",
215                parent.display()
216            )));
217        }
218
219        // Validate cache directory
220        if let Some(parent) = self.cache.directory.parent()
221            && !parent.exists()
222        {
223            return Err(ConfigError::Message(format!(
224                "Cache directory parent does not exist: {}",
225                parent.display()
226            )));
227        }
228
229        Ok(())
230    }
231
232    /// Get the server socket address
233    pub fn socket_addr(&self) -> Result<SocketAddr, ConfigError> {
234        let addr = format!("{}:{}", self.server.host, self.server.port);
235        addr.parse()
236            .map_err(|e| ConfigError::Message(format!("Invalid server address {addr}: {e}")))
237    }
238
239    /// Get the logging level as a tracing Level
240    pub fn log_level(&self) -> Level {
241        self.logging.level.into()
242    }
243
244    /// Print the configuration for debugging
245    pub fn print_config(&self) {
246        use tracing::info;
247
248        info!("Server Configuration:");
249        info!("  Host: {}", self.server.host);
250        info!("  Port: {}", self.server.port);
251
252        info!("Database Configuration:");
253        info!("  Path: {}", self.database.path.display());
254
255        info!("Cache Configuration:");
256        info!("  Directory: {}", self.cache.directory.display());
257        info!("  Max Size: {:?}", self.cache.max_size_bytes);
258        info!("  Eviction Enabled: {}", self.cache.eviction.enabled);
259        info!(
260            "  Eviction Check Interval: {}s",
261            self.cache.eviction.check_interval.num_seconds()
262        );
263
264        info!("Logging Configuration:");
265        info!("  Level: {}", self.logging.level);
266        info!("  Format: {}", self.logging.format);
267        info!("  File: {:?}", self.logging.file);
268        info!("  Structured: {}", self.logging.structured);
269    }
270}
271
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::Duration;
    use clap::Parser;
    use modelexpress_common::config::{DurationConfig, parse_duration_string};
    use std::fs;
    use tempfile::tempdir;

    /// Builds `ServerArgs` with no command line overrides, optionally pointing at
    /// a config file. Tests customise it with struct-update syntax.
    fn args_without_overrides(config: Option<PathBuf>) -> ServerArgs {
        ServerArgs {
            config,
            port: None,
            host: None,
            log_level: None,
            log_format: None,
            database_path: None,
            cache_directory: None,
            cache_eviction_enabled: None,
            validate_config: false,
        }
    }

    #[test]
    fn test_log_level_enum_parsing() {
        // Every valid level: CLI spelling and the enum value it must map to.
        let valid_levels = [
            ("trace", LogLevel::Trace),
            ("debug", LogLevel::Debug),
            ("info", LogLevel::Info),
            ("warn", LogLevel::Warn),
            ("error", LogLevel::Error),
        ];

        for (level_str, expected_level) in valid_levels {
            let Ok(parsed_args) = ServerArgs::try_parse_from(["test", "--log-level", level_str])
            else {
                panic!("Failed to parse valid log level: {level_str}");
            };
            assert_eq!(parsed_args.log_level, Some(expected_level));

            // Display must round-trip to the CLI spelling.
            assert_eq!(expected_level.to_string(), level_str);

            // Conversion to tracing::Level must pick the matching level.
            let tracing_level: Level = expected_level.into();
            let expected_tracing_level = match expected_level {
                LogLevel::Trace => Level::TRACE,
                LogLevel::Debug => Level::DEBUG,
                LogLevel::Info => Level::INFO,
                LogLevel::Warn => Level::WARN,
                LogLevel::Error => Level::ERROR,
            };
            assert_eq!(tracing_level, expected_tracing_level);
        }
    }

    #[test]
    fn test_log_level_enum_invalid() {
        // An unknown level must be rejected by the CLI parser.
        let result = ServerArgs::try_parse_from(["test", "--log-level", "invalid"]);
        assert!(result.is_err());
    }

    #[test]
    fn test_log_level_display() {
        assert_eq!(LogLevel::Trace.to_string(), "trace");
        assert_eq!(LogLevel::Debug.to_string(), "debug");
        assert_eq!(LogLevel::Info.to_string(), "info");
        assert_eq!(LogLevel::Warn.to_string(), "warn");
        assert_eq!(LogLevel::Error.to_string(), "error");
    }

    #[test]
    fn test_log_level_from_str() {
        // Lower-case spellings plus two mixed/upper-case ones (parsing is case insensitive).
        let cases = [
            ("trace", LogLevel::Trace),
            ("debug", LogLevel::Debug),
            ("info", LogLevel::Info),
            ("warn", LogLevel::Warn),
            ("error", LogLevel::Error),
            ("TRACE", LogLevel::Trace),
            ("Info", LogLevel::Info),
        ];
        for (input, expected) in cases {
            assert_eq!(
                input.parse::<LogLevel>().ok(),
                Some(expected),
                "Failed to parse '{input}'"
            );
        }

        // Test invalid
        assert!("invalid".parse::<LogLevel>().is_err());
    }

    #[test]
    fn test_log_format_display() {
        assert_eq!(LogFormat::Json.to_string(), "json");
        assert_eq!(LogFormat::Pretty.to_string(), "pretty");
        assert_eq!(LogFormat::Compact.to_string(), "compact");
    }

    #[test]
    fn test_log_format_from_str() {
        // Lower-case spellings plus two mixed/upper-case ones (parsing is case insensitive).
        let cases = [
            ("json", LogFormat::Json),
            ("pretty", LogFormat::Pretty),
            ("compact", LogFormat::Compact),
            ("JSON", LogFormat::Json),
            ("Pretty", LogFormat::Pretty),
        ];
        for (input, expected) in cases {
            assert_eq!(
                input.parse::<LogFormat>().ok(),
                Some(expected),
                "Failed to parse '{input}'"
            );
        }

        // Test invalid
        assert!("invalid".parse::<LogFormat>().is_err());
    }

    #[test]
    #[allow(clippy::expect_used)]
    fn test_parse_duration_string() {
        // Test various duration formats
        assert_eq!(
            parse_duration_string("30m")
                .expect("Failed to parse 30m")
                .num_seconds(),
            30 * 60
        );
        assert_eq!(
            parse_duration_string("45s")
                .expect("Failed to parse 45s")
                .num_seconds(),
            45
        );
        assert_eq!(
            parse_duration_string("1d")
                .expect("Failed to parse 1d")
                .num_seconds(),
            24 * 3600
        );
        assert_eq!(
            parse_duration_string("2h")
                .expect("Failed to parse 2h")
                .num_seconds(),
            2 * 3600
        );
        assert_eq!(
            parse_duration_string("2h30m")
                .expect("Failed to parse 2h30m")
                .num_seconds(),
            2 * 3600 + 30 * 60
        );

        // Test invalid format
        assert!(parse_duration_string("invalid").is_err());
    }

    #[test]
    fn test_duration_config() {
        // Test creation
        let duration_config = DurationConfig::new(Duration::hours(2));
        assert_eq!(duration_config.num_seconds(), 2 * 3600);
        assert_eq!(duration_config.as_chrono_duration(), Duration::hours(2));

        // Test hours constructor
        let duration_config = DurationConfig::hours(3);
        assert_eq!(duration_config.num_seconds(), 3 * 3600);

        // Test display
        assert_eq!(duration_config.to_string(), "10800s");
    }

    #[test]
    #[allow(clippy::expect_used)]
    fn test_duration_config_serde() {
        // Test deserializing from string
        let json = r#""2h""#;
        let duration_config: DurationConfig = serde_json::from_str(json).expect("Failed to parse");
        assert_eq!(duration_config.num_seconds(), 2 * 3600);

        // Test deserializing from number (seconds)
        let json = r#"3600"#;
        let duration_config: DurationConfig = serde_json::from_str(json).expect("Failed to parse");
        assert_eq!(duration_config.num_seconds(), 3600);

        // Test serializing (it should serialize as just the number)
        let duration_config = DurationConfig::hours(1);
        let serialized = serde_json::to_string(&duration_config).expect("Failed to serialize");
        assert_eq!(serialized, r#"3600"#);
    }

    #[test]
    #[allow(clippy::expect_used)]
    fn test_server_config_load_and_validate_strict_valid_config() {
        let temp_dir = tempdir().expect("Failed to create temp dir");
        let config_file = temp_dir.path().join("valid_server_config.yaml");

        let valid_config = r#"
            server:
              host: "127.0.0.1"
              port: 8002
            database:
              path: "./test.db"
            cache:
              eviction:
                enabled: false
                policy:
                  type: lru
                  unused_threshold: "3d"
                  max_models: 10
                  min_free_space_bytes: 1000000
                check_interval: "30m"
              directory: "./test_cache"
              max_size_bytes: 5000000
            logging:
              level: Debug
              format: Json
              file: null
              structured: true
        "#;

        fs::write(&config_file, valid_config).expect("Failed to write config file");

        let args = args_without_overrides(Some(config_file));

        // `expect` already fails the test on `Err`, and prints the error while doing so.
        let config = ServerConfig::load_and_validate_strict(args)
            .expect("Expected successful config parsing");
        assert_eq!(config.server.host, "127.0.0.1");
        assert_eq!(config.server.port.get(), 8002);
        assert_eq!(config.database.path, PathBuf::from("./test.db"));
        assert!(!config.cache.eviction.enabled);
        assert_eq!(config.logging.level, LogLevel::Debug);
        assert_eq!(config.logging.format, LogFormat::Json);
    }

    #[test]
    #[allow(clippy::expect_used)]
    fn test_server_config_load_and_validate_strict_invalid_config() {
        let temp_dir = tempdir().expect("Failed to create temp dir");
        let config_file = temp_dir.path().join("invalid_server_config.yaml");

        let invalid_config = r#"
            server:
              host: "127.0.0.1"
              port: 8002
            database:
              pat: "./test.db"  # Wrong field name (should be 'path')
            cache:
              eviction:
                enabled: "not_a_boolean"  # Invalid type
        "#;

        fs::write(&config_file, invalid_config).expect("Failed to write config file");

        let args = args_without_overrides(Some(config_file));

        let result = ServerConfig::load_and_validate_strict(args);
        assert!(result.is_err());
    }

    #[test]
    #[allow(clippy::expect_used)]
    fn test_server_config_load_and_validate_strict_with_cli_overrides() {
        let temp_dir = tempdir().expect("Failed to create temp dir");
        let config_file = temp_dir.path().join("override_test_config.yaml");

        let base_config = r#"
            server:
              host: "127.0.0.1"
              port: 8002
            database:
              path: "./test.db"
            cache:
              eviction:
                enabled: true
                policy:
                  type: lru
                  unused_threshold: "1d"
                  max_models: null
                  min_free_space_bytes: null
                check_interval: "1h"
              directory: "./cache"
              max_size_bytes: null
            logging:
              level: Info
              format: Pretty
              file: null
              structured: false
        "#;

        fs::write(&config_file, base_config).expect("Failed to write config file");

        // Validation requires the cache directory's parent to exist. Placing the
        // override inside the temp dir guarantees that on every platform, unlike a
        // hard-coded `/tmp/...` path.
        let cache_override = temp_dir.path().join("override_cache");

        let args = ServerArgs {
            port: Some(NonZeroU16::new(9000).expect("9000 is non-zero")),
            host: Some("0.0.0.0".to_string()),
            log_level: Some(LogLevel::Error),
            log_format: Some(LogFormat::Json),
            database_path: Some(PathBuf::from("./override.db")),
            cache_directory: Some(cache_override.clone()),
            cache_eviction_enabled: Some(false),
            ..args_without_overrides(Some(config_file))
        };

        let config = ServerConfig::load_and_validate_strict(args)
            .expect("Expected successful config parsing");
        // CLI overrides should be applied
        assert_eq!(config.server.host, "0.0.0.0");
        assert_eq!(config.server.port.get(), 9000);
        assert_eq!(config.database.path, PathBuf::from("./override.db"));
        assert_eq!(config.logging.level, LogLevel::Error);
        assert_eq!(config.logging.format, LogFormat::Json);
        assert_eq!(config.cache.directory, cache_override);
        assert!(!config.cache.eviction.enabled);
    }

    #[test]
    #[allow(clippy::expect_used)]
    fn test_server_config_load_and_validate_strict_no_config_file() {
        let args = ServerArgs {
            port: Some(NonZeroU16::new(9001).expect("9001 is non-zero")),
            host: Some("localhost".to_string()),
            log_level: Some(LogLevel::Warn),
            ..args_without_overrides(None)
        };

        // With no config file, strict loading starts from `ServerConfig::default()`
        // and then applies the CLI overrides.
        let config = ServerConfig::load_and_validate_strict(args)
            .expect("Expected successful config parsing");
        assert_eq!(config.server.host, "localhost");
        assert_eq!(config.server.port.get(), 9001);
        assert_eq!(config.logging.level, LogLevel::Warn);
    }
}
666        assert_eq!(config.logging.level, LogLevel::Warn);
667    }
668}