Skip to main content

modelexpress_server/
config.rs

1// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4use clap::Parser;
5use config::ConfigError;
6use modelexpress_common::config::{LogFormat, LogLevel, load_layered_config};
7use serde::{Deserialize, Serialize};
8use std::net::SocketAddr;
9use std::num::NonZeroU16;
10use std::path::PathBuf;
11use tracing::Level;
12
13use crate::cache::CacheEvictionConfig;
14
// NOTE(review): with clap's derive API the `///` comment on each field is used as that
// option's `--help` text, so those comments are user-facing output — reword them with care.
// Additional maintainer notes below therefore use plain `//` comments.
//
// Every value-carrying option is an `Option<_>`: `ServerConfig::load_internal` only overrides
// a loaded setting when the corresponding field is `Some`.
/// Command line arguments for the server
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
pub struct ServerArgs {
    /// Configuration file path
    #[arg(short, long, value_name = "FILE")]
    pub config: Option<PathBuf>,

    // `NonZeroU16` makes port 0 unrepresentable; the value may also come from the
    // `MODEL_EXPRESS_SERVER_PORT` environment variable.
    /// Server port
    #[arg(short, long, env = "MODEL_EXPRESS_SERVER_PORT")]
    pub port: Option<NonZeroU16>,

    // Kept as a free-form string here; it is only turned into an address in
    // `ServerConfig::socket_addr`.
    /// Server host address
    #[arg(long, env = "MODEL_EXPRESS_SERVER_HOST")]
    pub host: Option<String>,

    /// Log level
    #[arg(short, long, env = "MODEL_EXPRESS_LOG_LEVEL", value_enum)]
    pub log_level: Option<LogLevel>,

    /// Log format
    #[arg(long, env = "MODEL_EXPRESS_LOG_FORMAT", value_enum)]
    pub log_format: Option<LogFormat>,

    /// Cache directory path
    #[arg(long, env = "MODEL_EXPRESS_CACHE_DIRECTORY")]
    pub cache_directory: Option<PathBuf>,

    // Tri-state on purpose: `None` leaves the loaded eviction setting untouched, while
    // `Some(true)` / `Some(false)` force it on or off.
    /// Enable cache eviction
    #[arg(long, env = "MODEL_EXPRESS_CACHE_EVICTION_ENABLED")]
    pub cache_eviction_enabled: Option<bool>,

    // Not read anywhere in this file; presumably the caller uses it to pick
    // `load_and_validate_strict` and exit — TODO confirm against the binary's `main`.
    /// Validate configuration and exit
    #[arg(long)]
    pub validate_config: bool,
}
51
52/// Complete server configuration
53#[derive(Debug, Clone, Serialize, Deserialize, Default)]
54pub struct ServerConfig {
55    /// Server settings
56    pub server: ServerSettings,
57    /// Cache configuration
58    pub cache: CacheConfig,
59    /// Logging configuration
60    pub logging: LoggingConfig,
61}
62
63/// Server-specific settings
64#[derive(Debug, Clone, Serialize, Deserialize)]
65pub struct ServerSettings {
66    /// Server host address
67    pub host: String,
68    /// Server port
69    pub port: NonZeroU16,
70}
71
72/// Cache configuration wrapper
73#[derive(Debug, Clone, Serialize, Deserialize)]
74pub struct CacheConfig {
75    /// Cache eviction settings
76    pub eviction: CacheEvictionConfig,
77    /// Cache directory path
78    pub directory: PathBuf,
79    /// Maximum cache size in bytes
80    pub max_size_bytes: Option<u64>,
81}
82
83/// Logging configuration
84#[derive(Debug, Clone, Serialize, Deserialize, Default)]
85pub struct LoggingConfig {
86    /// Log level
87    #[serde(default)]
88    pub level: LogLevel,
89    /// Log format (json, pretty, compact)
90    #[serde(default)]
91    pub format: LogFormat,
92    /// Log to file
93    pub file: Option<PathBuf>,
94    /// Enable structured logging
95    pub structured: bool,
96}
97
98impl Default for ServerSettings {
99    fn default() -> Self {
100        Self {
101            host: "0.0.0.0".to_string(),
102            port: modelexpress_common::constants::DEFAULT_GRPC_PORT,
103        }
104    }
105}
106
107impl Default for CacheConfig {
108    fn default() -> Self {
109        Self {
110            eviction: CacheEvictionConfig::default(),
111            directory: PathBuf::from("./cache"),
112            max_size_bytes: None,
113        }
114    }
115}
116
117impl ServerConfig {
118    /// Load configuration from multiple sources in order of precedence:
119    /// 1. Command line arguments (highest priority)
120    /// 2. Environment variables
121    /// 3. Configuration file
122    /// 4. Default values (lowest priority)
123    pub fn load(args: ServerArgs) -> Result<Self, ConfigError> {
124        Self::load_internal(args, false)
125    }
126
127    /// Load and validate configuration file strictly without fallbacks.
128    /// This method should be used when validating configuration files.
129    /// It will return an error if the file has invalid syntax or values.
130    pub fn load_and_validate_strict(args: ServerArgs) -> Result<Self, ConfigError> {
131        Self::load_internal(args, true)
132    }
133
134    /// Internal method to load configuration with optional strict mode
135    fn load_internal(args: ServerArgs, strict_mode: bool) -> Result<Self, ConfigError> {
136        let mut config = if strict_mode {
137            // Use strict loading - fail on any configuration errors
138            if let Some(ref config_file) = args.config {
139                // Load file strictly without fallbacks
140                modelexpress_common::config::validate_config_file(config_file)?
141            } else {
142                // No config file specified, use defaults
143                Self::default()
144            }
145        } else {
146            // Use layered config loading with fallbacks to defaults
147            load_layered_config(args.config.clone(), "MODEL_EXPRESS", Self::default())?
148        };
149
150        // Apply command line overrides (same for both modes)
151        if let Some(port) = args.port {
152            config.server.port = port;
153        }
154
155        if let Some(host) = args.host {
156            config.server.host = host;
157        }
158
159        if let Some(log_level) = args.log_level {
160            config.logging.level = log_level;
161        }
162
163        if let Some(log_format) = args.log_format {
164            config.logging.format = log_format;
165        }
166
167        // Apply cache overrides
168        if let Some(cache_directory) = args.cache_directory {
169            config.cache.directory = cache_directory;
170        }
171
172        if let Some(cache_eviction_enabled) = args.cache_eviction_enabled {
173            config.cache.eviction.enabled = cache_eviction_enabled;
174        }
175
176        // Validate the final configuration
177        config.validate()?;
178
179        Ok(config)
180    }
181
182    /// Validate the configuration
183    pub fn validate(&self) -> Result<(), ConfigError> {
184        // Validate cache directory
185        if let Some(parent) = self.cache.directory.parent()
186            && !parent.exists()
187        {
188            return Err(ConfigError::Message(format!(
189                "Cache directory parent does not exist: {}",
190                parent.display()
191            )));
192        }
193
194        Ok(())
195    }
196
197    /// Get the server socket address
198    pub fn socket_addr(&self) -> Result<SocketAddr, ConfigError> {
199        let addr = format!("{}:{}", self.server.host, self.server.port);
200        addr.parse()
201            .map_err(|e| ConfigError::Message(format!("Invalid server address {addr}: {e}")))
202    }
203
204    /// Get the logging level as a tracing Level
205    pub fn log_level(&self) -> Level {
206        self.logging.level.into()
207    }
208
209    /// Print the configuration for debugging
210    pub fn print_config(&self) {
211        use tracing::info;
212
213        info!("Server Configuration:");
214        info!("  Host: {}", self.server.host);
215        info!("  Port: {}", self.server.port);
216
217        info!("Cache Configuration:");
218        info!("  Directory: {}", self.cache.directory.display());
219        info!("  Max Size: {:?}", self.cache.max_size_bytes);
220        info!("  Eviction Enabled: {}", self.cache.eviction.enabled);
221        info!(
222            "  Eviction Check Interval: {}s",
223            self.cache.eviction.check_interval.num_seconds()
224        );
225
226        info!("Logging Configuration:");
227        info!("  Level: {}", self.logging.level);
228        info!("  Format: {}", self.logging.format);
229        info!("  File: {:?}", self.logging.file);
230        info!("  Structured: {}", self.logging.structured);
231    }
232}
233
234#[cfg(test)]
235mod tests {
236    use super::*;
237    use chrono::Duration;
238    use clap::Parser;
239    use modelexpress_common::config::{DurationConfig, parse_duration_string};
240    use std::fs;
241    use tempfile::tempdir;
242
243    #[test]
244    fn test_log_level_enum_parsing() {
245        // Test valid log levels
246        let valid_levels = [
247            ("trace", LogLevel::Trace),
248            ("debug", LogLevel::Debug),
249            ("info", LogLevel::Info),
250            ("warn", LogLevel::Warn),
251            ("error", LogLevel::Error),
252        ];
253
254        for (level_str, expected_level) in &valid_levels {
255            let args = vec!["test", "--log-level", level_str];
256            if let Ok(parsed_args) = ServerArgs::try_parse_from(args) {
257                assert_eq!(parsed_args.log_level, Some(*expected_level));
258
259                // Test conversion to string
260                assert_eq!(expected_level.to_string(), *level_str);
261
262                // Test conversion to tracing::Level
263                let tracing_level: Level = (*expected_level).into();
264                // Just verify the conversion works - the exact debug format may vary
265                match expected_level {
266                    LogLevel::Trace => assert_eq!(tracing_level, Level::TRACE),
267                    LogLevel::Debug => assert_eq!(tracing_level, Level::DEBUG),
268                    LogLevel::Info => assert_eq!(tracing_level, Level::INFO),
269                    LogLevel::Warn => assert_eq!(tracing_level, Level::WARN),
270                    LogLevel::Error => assert_eq!(tracing_level, Level::ERROR),
271                }
272            } else {
273                panic!("Failed to parse valid log level: {level_str}");
274            }
275        }
276    }
277
278    #[test]
279    fn test_log_level_enum_invalid() {
280        // Test invalid log level
281        let args = vec!["test", "--log-level", "invalid"];
282        let result = ServerArgs::try_parse_from(args);
283        assert!(result.is_err());
284    }
285
286    #[test]
287    fn test_log_level_display() {
288        assert_eq!(LogLevel::Trace.to_string(), "trace");
289        assert_eq!(LogLevel::Debug.to_string(), "debug");
290        assert_eq!(LogLevel::Info.to_string(), "info");
291        assert_eq!(LogLevel::Warn.to_string(), "warn");
292        assert_eq!(LogLevel::Error.to_string(), "error");
293    }
294
295    #[test]
296    #[allow(clippy::expect_used)]
297    fn test_log_level_from_str() {
298        assert_eq!(
299            "trace"
300                .parse::<LogLevel>()
301                .expect("Failed to parse 'trace'"),
302            LogLevel::Trace
303        );
304        assert_eq!(
305            "debug"
306                .parse::<LogLevel>()
307                .expect("Failed to parse 'debug'"),
308            LogLevel::Debug
309        );
310        assert_eq!(
311            "info".parse::<LogLevel>().expect("Failed to parse 'info'"),
312            LogLevel::Info
313        );
314        assert_eq!(
315            "warn".parse::<LogLevel>().expect("Failed to parse 'warn'"),
316            LogLevel::Warn
317        );
318        assert_eq!(
319            "error"
320                .parse::<LogLevel>()
321                .expect("Failed to parse 'error'"),
322            LogLevel::Error
323        );
324
325        // Test case insensitive
326        assert_eq!(
327            "TRACE"
328                .parse::<LogLevel>()
329                .expect("Failed to parse 'TRACE'"),
330            LogLevel::Trace
331        );
332        assert_eq!(
333            "Info".parse::<LogLevel>().expect("Failed to parse 'Info'"),
334            LogLevel::Info
335        );
336
337        // Test invalid
338        assert!("invalid".parse::<LogLevel>().is_err());
339    }
340
341    #[test]
342    fn test_log_format_display() {
343        assert_eq!(LogFormat::Json.to_string(), "json");
344        assert_eq!(LogFormat::Pretty.to_string(), "pretty");
345        assert_eq!(LogFormat::Compact.to_string(), "compact");
346    }
347
348    #[test]
349    #[allow(clippy::expect_used)]
350    fn test_log_format_from_str() {
351        assert_eq!(
352            "json".parse::<LogFormat>().expect("Failed to parse 'json'"),
353            LogFormat::Json
354        );
355        assert_eq!(
356            "pretty"
357                .parse::<LogFormat>()
358                .expect("Failed to parse 'pretty'"),
359            LogFormat::Pretty
360        );
361        assert_eq!(
362            "compact"
363                .parse::<LogFormat>()
364                .expect("Failed to parse 'compact'"),
365            LogFormat::Compact
366        );
367
368        // Test case insensitive
369        assert_eq!(
370            "JSON".parse::<LogFormat>().expect("Failed to parse 'JSON'"),
371            LogFormat::Json
372        );
373        assert_eq!(
374            "Pretty"
375                .parse::<LogFormat>()
376                .expect("Failed to parse 'Pretty'"),
377            LogFormat::Pretty
378        );
379
380        // Test invalid
381        assert!("invalid".parse::<LogFormat>().is_err());
382    }
383
384    #[test]
385    #[allow(clippy::expect_used)]
386    fn test_parse_duration_string() {
387        // Test various duration formats
388        assert_eq!(
389            parse_duration_string("30m")
390                .expect("Failed to parse 30m")
391                .num_seconds(),
392            30 * 60
393        );
394        assert_eq!(
395            parse_duration_string("45s")
396                .expect("Failed to parse 45s")
397                .num_seconds(),
398            45
399        );
400        assert_eq!(
401            parse_duration_string("1d")
402                .expect("Failed to parse 1d")
403                .num_seconds(),
404            24 * 3600
405        );
406        assert_eq!(
407            parse_duration_string("2h")
408                .expect("Failed to parse 2h")
409                .num_seconds(),
410            2 * 3600
411        );
412        assert_eq!(
413            parse_duration_string("2h30m")
414                .expect("Failed to parse 2h30m")
415                .num_seconds(),
416            2 * 3600 + 30 * 60
417        );
418
419        // Test invalid format
420        assert!(parse_duration_string("invalid").is_err());
421    }
422
423    #[test]
424    fn test_duration_config() {
425        // Test creation
426        let duration_config = DurationConfig::new(Duration::hours(2));
427        assert_eq!(duration_config.num_seconds(), 2 * 3600);
428        assert_eq!(duration_config.as_chrono_duration(), Duration::hours(2));
429
430        // Test hours constructor
431        let duration_config = DurationConfig::hours(3);
432        assert_eq!(duration_config.num_seconds(), 3 * 3600);
433
434        // Test display
435        assert_eq!(duration_config.to_string(), "10800s");
436    }
437
438    #[test]
439    #[allow(clippy::expect_used)]
440    fn test_duration_config_serde() {
441        // Test deserializing from string
442        let json = r#""2h""#;
443        let duration_config: DurationConfig = serde_json::from_str(json).expect("Failed to parse");
444        assert_eq!(duration_config.num_seconds(), 2 * 3600);
445
446        // Test deserializing from number (seconds)
447        let json = r#"3600"#;
448        let duration_config: DurationConfig = serde_json::from_str(json).expect("Failed to parse");
449        assert_eq!(duration_config.num_seconds(), 3600);
450
451        // Test serializing (it should serialize as just the number)
452        let duration_config = DurationConfig::hours(1);
453        let serialized = serde_json::to_string(&duration_config).expect("Failed to serialize");
454        assert_eq!(serialized, r#"3600"#);
455    }
456
457    #[test]
458    #[allow(clippy::expect_used)]
459    fn test_server_config_load_and_validate_strict_valid_config() {
460        let temp_dir = tempdir().expect("Failed to create temp dir");
461        let config_file = temp_dir.path().join("valid_server_config.yaml");
462
463        let valid_config = r#"
464            server:
465              host: "127.0.0.1"
466              port: 8002
467            cache:
468              eviction:
469                enabled: false
470                policy:
471                  type: lru
472                  unused_threshold: "3d"
473                  max_models: 10
474                  min_free_space_bytes: 1000000
475                check_interval: "30m"
476              directory: "./test_cache"
477              max_size_bytes: 5000000
478            logging:
479              level: Debug
480              format: Json
481              file: null
482              structured: true
483        "#;
484
485        fs::write(&config_file, valid_config).expect("Failed to write config file");
486
487        let args = ServerArgs {
488            config: Some(config_file),
489            port: None,
490            host: None,
491            log_level: None,
492            log_format: None,
493            cache_directory: None,
494            cache_eviction_enabled: None,
495            validate_config: false,
496        };
497
498        let result = ServerConfig::load_and_validate_strict(args);
499        assert!(result.is_ok());
500
501        let config = result.expect("Expected successful config parsing");
502        assert_eq!(config.server.host, "127.0.0.1");
503        assert_eq!(config.server.port.get(), 8002);
504        assert!(!config.cache.eviction.enabled);
505        assert_eq!(config.logging.level, LogLevel::Debug);
506        assert_eq!(config.logging.format, LogFormat::Json);
507    }
508
509    #[test]
510    #[allow(clippy::expect_used)]
511    fn test_server_config_load_and_validate_strict_invalid_config() {
512        let temp_dir = tempdir().expect("Failed to create temp dir");
513        let config_file = temp_dir.path().join("invalid_server_config.yaml");
514
515        let invalid_config = r#"
516            server:
517              host: "127.0.0.1"
518              port: 8002
519            cache:
520              eviction:
521                enabled: "not_a_boolean"  # Invalid type
522        "#;
523
524        fs::write(&config_file, invalid_config).expect("Failed to write config file");
525
526        let args = ServerArgs {
527            config: Some(config_file),
528            port: None,
529            host: None,
530            log_level: None,
531            log_format: None,
532            cache_directory: None,
533            cache_eviction_enabled: None,
534            validate_config: false,
535        };
536
537        let result = ServerConfig::load_and_validate_strict(args);
538        assert!(result.is_err());
539    }
540
541    #[test]
542    #[allow(clippy::expect_used)]
543    fn test_server_config_load_and_validate_strict_with_cli_overrides() {
544        let temp_dir = tempdir().expect("Failed to create temp dir");
545        let config_file = temp_dir.path().join("override_test_config.yaml");
546
547        let base_config = r#"
548            server:
549              host: "127.0.0.1"
550              port: 8002
551            cache:
552              eviction:
553                enabled: true
554                policy:
555                  type: lru
556                  unused_threshold: "1d"
557                  max_models: null
558                  min_free_space_bytes: null
559                check_interval: "1h"
560              directory: "./cache"
561              max_size_bytes: null
562            logging:
563              level: Info
564              format: Pretty
565              file: null
566              structured: false
567        "#;
568
569        fs::write(&config_file, base_config).expect("Failed to write config file");
570
571        let args = ServerArgs {
572            config: Some(config_file),
573            port: Some(NonZeroU16::new(9000).expect("9000 is non-zero")),
574            host: Some("0.0.0.0".to_string()),
575            log_level: Some(LogLevel::Error),
576            log_format: Some(LogFormat::Json),
577            cache_directory: Some(PathBuf::from("/tmp/override_cache")),
578            cache_eviction_enabled: Some(false),
579            validate_config: false,
580        };
581
582        let result = ServerConfig::load_and_validate_strict(args);
583        assert!(result.is_ok());
584
585        let config = result.expect("Expected successful config parsing");
586        // CLI overrides should be applied
587        assert_eq!(config.server.host, "0.0.0.0");
588        assert_eq!(config.server.port.get(), 9000);
589        assert_eq!(config.logging.level, LogLevel::Error);
590        assert_eq!(config.logging.format, LogFormat::Json);
591        assert_eq!(config.cache.directory, PathBuf::from("/tmp/override_cache"));
592        assert!(!config.cache.eviction.enabled);
593    }
594
595    #[test]
596    #[allow(clippy::expect_used)]
597    fn test_server_config_load_and_validate_strict_no_config_file() {
598        let args = ServerArgs {
599            config: None,
600            port: Some(NonZeroU16::new(9001).expect("9001 is non-zero")),
601            host: Some("localhost".to_string()),
602            log_level: Some(LogLevel::Warn),
603            log_format: None,
604            cache_directory: None,
605            cache_eviction_enabled: None,
606            validate_config: false,
607        };
608
609        // When no config file is specified, it should fall back to normal loading
610        let result = ServerConfig::load_and_validate_strict(args);
611        assert!(result.is_ok());
612
613        let config = result.expect("Expected successful config parsing");
614        assert_eq!(config.server.host, "localhost");
615        assert_eq!(config.server.port.get(), 9001);
616        assert_eq!(config.logging.level, LogLevel::Warn);
617    }
618}