// modelexpress_common/config.rs

// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

use chrono::Duration;
use clap::ValueEnum;
use config::{Config, ConfigError, Environment, File};
use serde::{Deserialize, Deserializer, Serialize};
use std::fmt;
use std::path::{Path, PathBuf};
use std::str::FromStr;
use tracing::{Level, info, warn};

13/// Parse a duration string into a `chrono::Duration`.
14/// Supports formats like "2h", "30m", "45s", "1d", etc.
15pub fn parse_duration_string(value: &str) -> Result<Duration, String> {
16    use jiff::{Span, SpanRelativeTo};
17    let span = Span::from_str(value).map_err(|err| format!("Invalid duration: {err}"))?;
18
19    // Convert jiff::Span to chrono::Duration
20    // For spans with days, we need to specify that days are 24 hours
21    let signed_duration = span
22        .to_duration(SpanRelativeTo::days_are_24_hours())
23        .map_err(|err| format!("Invalid duration: {err}"))?;
24
25    let std_duration = std::time::Duration::try_from(signed_duration)
26        .map_err(|err| format!("Invalid duration: {err}"))?;
27
28    Duration::from_std(std_duration).map_err(|err| format!("Duration out of range: {err}"))
29}
30
31/// A wrapper around chrono::Duration that can be deserialized from string or seconds
32#[derive(Debug, Clone)]
33pub struct DurationConfig {
34    duration: Duration,
35}
36
37impl DurationConfig {
38    pub fn new(duration: Duration) -> Self {
39        Self { duration }
40    }
41
42    pub fn hours(hours: i64) -> Self {
43        Self {
44            duration: Duration::hours(hours),
45        }
46    }
47
48    pub fn as_chrono_duration(&self) -> Duration {
49        self.duration
50    }
51
52    pub fn num_seconds(&self) -> i64 {
53        self.duration.num_seconds()
54    }
55}
56
57impl fmt::Display for DurationConfig {
58    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
59        write!(f, "{}s", self.duration.num_seconds())
60    }
61}
62
63// Serialize as just the number of seconds (not as a struct)
64impl Serialize for DurationConfig {
65    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
66    where
67        S: serde::Serializer,
68    {
69        self.duration.num_seconds().serialize(serializer)
70    }
71}
72
73impl<'de> Deserialize<'de> for DurationConfig {
74    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
75    where
76        D: Deserializer<'de>,
77    {
78        use serde::de::{self, Visitor};
79
80        struct DurationVisitor;
81
82        impl<'de> Visitor<'de> for DurationVisitor {
83            type Value = DurationConfig;
84
85            fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
86                formatter
87                    .write_str("a duration string like '2h', '30m', '45s' or number of seconds")
88            }
89
90            fn visit_str<E>(self, value: &str) -> Result<DurationConfig, E>
91            where
92                E: de::Error,
93            {
94                parse_duration_string(value)
95                    .map(DurationConfig::new)
96                    .map_err(de::Error::custom)
97            }
98
99            fn visit_i64<E>(self, value: i64) -> Result<DurationConfig, E>
100            where
101                E: de::Error,
102            {
103                Ok(DurationConfig::new(Duration::seconds(value)))
104            }
105
106            fn visit_u64<E>(self, value: u64) -> Result<DurationConfig, E>
107            where
108                E: de::Error,
109            {
110                Ok(DurationConfig::new(Duration::seconds(value as i64)))
111            }
112        }
113
114        deserializer.deserialize_any(DurationVisitor)
115    }
116}
117
118/// Log level wrapper for clap ValueEnum
119#[derive(Debug, Clone, Copy, ValueEnum, PartialEq, Serialize, Deserialize, Default)]
120pub enum LogLevel {
121    Trace,
122    Debug,
123    #[default]
124    Info,
125    Warn,
126    Error,
127}
128
129impl fmt::Display for LogLevel {
130    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
131        match self {
132            LogLevel::Trace => write!(f, "trace"),
133            LogLevel::Debug => write!(f, "debug"),
134            LogLevel::Info => write!(f, "info"),
135            LogLevel::Warn => write!(f, "warn"),
136            LogLevel::Error => write!(f, "error"),
137        }
138    }
139}
140
141impl From<LogLevel> for Level {
142    fn from(log_level: LogLevel) -> Self {
143        match log_level {
144            LogLevel::Trace => Level::TRACE,
145            LogLevel::Debug => Level::DEBUG,
146            LogLevel::Info => Level::INFO,
147            LogLevel::Warn => Level::WARN,
148            LogLevel::Error => Level::ERROR,
149        }
150    }
151}
152
153impl FromStr for LogLevel {
154    type Err = String;
155
156    fn from_str(s: &str) -> Result<Self, Self::Err> {
157        match s.to_lowercase().as_str() {
158            "trace" => Ok(LogLevel::Trace),
159            "debug" => Ok(LogLevel::Debug),
160            "info" => Ok(LogLevel::Info),
161            "warn" => Ok(LogLevel::Warn),
162            "error" => Ok(LogLevel::Error),
163            _ => Err(format!("Invalid log level: {s}")),
164        }
165    }
166}
167
168/// Log format wrapper for clap ValueEnum
169#[derive(Debug, Clone, Copy, ValueEnum, PartialEq, Serialize, Deserialize, Default)]
170pub enum LogFormat {
171    Json,
172    #[default]
173    Pretty,
174    Compact,
175}
176
177impl fmt::Display for LogFormat {
178    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
179        match self {
180            LogFormat::Json => write!(f, "json"),
181            LogFormat::Pretty => write!(f, "pretty"),
182            LogFormat::Compact => write!(f, "compact"),
183        }
184    }
185}
186
187impl FromStr for LogFormat {
188    type Err = String;
189
190    fn from_str(s: &str) -> Result<Self, Self::Err> {
191        match s.to_lowercase().as_str() {
192            "json" => Ok(LogFormat::Json),
193            "pretty" => Ok(LogFormat::Pretty),
194            "compact" => Ok(LogFormat::Compact),
195            _ => Err(format!("Invalid log format: {s}")),
196        }
197    }
198}
199
200/// Base trait for configuration loading with layered approach
201pub trait ConfigLoader<T> {
202    /// Load configuration from multiple sources in order of precedence:
203    /// 1. Command line arguments (highest priority)
204    /// 2. Environment variables
205    /// 3. Configuration file
206    /// 4. Default values (lowest priority)
207    fn load_layered(
208        config_file: Option<PathBuf>,
209        env_prefix: &str,
210        defaults: T,
211    ) -> Result<T, ConfigError>
212    where
213        T: serde::de::DeserializeOwned + Default;
214}
215
216/// Load configuration file strictly without any fallbacks to defaults.
217/// This function will return an error if the file doesn't exist, has invalid syntax,
218/// or contains invalid values. Use this for validation purposes.
219pub fn load_config_file_strict<T>(config_file: &Path) -> Result<T, ConfigError>
220where
221    T: serde::de::DeserializeOwned,
222{
223    if !config_file.exists() {
224        return Err(ConfigError::Message(format!(
225            "Configuration file not found: {}",
226            config_file.display()
227        )));
228    }
229
230    let config = Config::builder()
231        .add_source(File::from(config_file.to_path_buf()))
232        .build()?;
233
234    config.try_deserialize::<T>()
235}
236
/// Return the first well-known configuration file that exists, or `None`.
///
/// Candidates are checked in this order:
/// 1. `model-express.yaml` (relative to the current working directory)
/// 2. `model-express.yml` (relative to the current working directory)
/// 3. `/etc/model-express/config.yaml`
/// 4. `/etc/model-express/config.yml`
fn discover_default_config() -> Option<PathBuf> {
    const CANDIDATES: [&str; 4] = [
        "model-express.yaml",
        "model-express.yml",
        "/etc/model-express/config.yaml",
        "/etc/model-express/config.yml",
    ];

    CANDIDATES
        .iter()
        .map(PathBuf::from)
        .find(|candidate| candidate.exists())
}

253/// Load configuration with strict file parsing but with environment variable overrides.
254/// This is used internally by both strict validation and normal loading with fallbacks.
255fn load_config_with_env_strict<T>(
256    config_file: Option<PathBuf>,
257    env_prefix: &str,
258) -> Result<T, ConfigError>
259where
260    T: serde::de::DeserializeOwned,
261{
262    let mut builder = Config::builder();
263
264    // Only load config file if explicitly provided
265    if let Some(config_path) = &config_file {
266        if !config_path.exists() {
267            return Err(ConfigError::Message(format!(
268                "Configuration file not found: {}",
269                config_path.display()
270            )));
271        }
272        builder = builder.add_source(File::from(config_path.clone()));
273    } else if let Some(default_path) = discover_default_config() {
274        info!("Using default config: {}", default_path.display());
275        builder = builder.add_source(File::from(default_path));
276    } else {
277        return Err(ConfigError::Message(
278            "No configuration file specified and no default config found. \
279             Please specify a config file with --config or create a default config."
280                .to_string(),
281        ));
282    }
283
284    // Add environment variables
285    builder = builder.add_source(
286        Environment::with_prefix(env_prefix)
287            .try_parsing(true)
288            .separator("_"),
289    );
290
291    let config = builder.build()?;
292    config.try_deserialize::<T>()
293}
294
295/// Validate a configuration file by attempting to parse it strictly.
296/// Returns detailed error information if the file is invalid.
297pub fn validate_config_file<T>(config_file: &Path) -> Result<T, ConfigError>
298where
299    T: serde::de::DeserializeOwned,
300{
301    load_config_file_strict(config_file)
302}
303
304/// Default implementation of layered configuration loading with fallback to defaults
305pub fn load_layered_config<T>(
306    config_file: Option<PathBuf>,
307    env_prefix: &str,
308    defaults: T,
309) -> Result<T, ConfigError>
310where
311    T: serde::de::DeserializeOwned + Default,
312{
313    // Try to load configuration strictly first
314    match load_config_with_env_strict(config_file, env_prefix) {
315        Ok(config) => Ok(config),
316        Err(_) => {
317            // If strict loading fails, fall back to defaults
318            // This provides a safe fallback for partial configurations or errors
319            Ok(defaults)
320        }
321    }
322}
323
324/// Common configuration for client connections
325#[derive(Debug, Clone, Serialize, Deserialize)]
326pub struct ConnectionConfig {
327    /// The endpoint to connect to
328    pub endpoint: String,
329
330    /// Timeout in seconds for requests
331    pub timeout_secs: Option<u64>,
332
333    /// Maximum retries for failed requests
334    pub max_retries: Option<u32>,
335
336    /// Retry delay in seconds
337    pub retry_delay_secs: Option<u64>,
338}
339
340impl Default for ConnectionConfig {
341    fn default() -> Self {
342        Self {
343            endpoint: format!("http://localhost:{}", crate::constants::DEFAULT_GRPC_PORT),
344            timeout_secs: Some(crate::constants::DEFAULT_TIMEOUT_SECS),
345            max_retries: Some(3),
346            retry_delay_secs: Some(1),
347        }
348    }
349}
350
351impl ConnectionConfig {
352    pub fn new(endpoint: impl Into<String>) -> Self {
353        Self {
354            endpoint: endpoint.into(),
355            timeout_secs: Some(crate::constants::DEFAULT_TIMEOUT_SECS),
356            max_retries: Some(3),
357            retry_delay_secs: Some(1),
358        }
359    }
360
361    pub fn with_timeout(mut self, timeout_secs: u64) -> Self {
362        self.timeout_secs = Some(timeout_secs);
363        self
364    }
365
366    pub fn with_retries(mut self, max_retries: u32, delay_secs: u64) -> Self {
367        self.max_retries = Some(max_retries);
368        self.retry_delay_secs = Some(delay_secs);
369        self
370    }
371}
372
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;
    use tempfile::tempdir;

    /// Create a fresh temporary directory holding `file_name` with `contents`.
    ///
    /// Returns the directory guard together with the file path. Keep the guard alive while
    /// the file is in use, because dropping it removes the directory.
    #[allow(clippy::expect_used)]
    fn write_temp_config(file_name: &str, contents: &str) -> (tempfile::TempDir, PathBuf) {
        let temp_dir = tempdir().expect("Failed to create temp dir");
        let config_file = temp_dir.path().join(file_name);
        fs::write(&config_file, contents).expect("Failed to write config file");
        (temp_dir, config_file)
    }

    #[test]
    fn test_duration_config_from_string() {
        let duration = parse_duration_string("2h")
            .unwrap_or_else(|e| panic!("Failed to parse duration '2h': {e}"));
        assert_eq!(duration.num_hours(), 2);
    }

    #[test]
    fn test_log_level_from_string() {
        let cases = [("info", LogLevel::Info), ("debug", LogLevel::Debug)];
        for (input, expected) in cases {
            let level = input
                .parse::<LogLevel>()
                .unwrap_or_else(|e| panic!("Failed to parse '{input}' as LogLevel: {e}"));
            assert_eq!(level, expected);
        }
    }

    #[test]
    fn test_log_format_from_string() {
        let cases = [("json", LogFormat::Json), ("pretty", LogFormat::Pretty)];
        for (input, expected) in cases {
            let format = input
                .parse::<LogFormat>()
                .unwrap_or_else(|e| panic!("Failed to parse '{input}' as LogFormat: {e}"));
            assert_eq!(format, expected);
        }
    }

    #[test]
    fn test_connection_config_default() {
        let config = ConnectionConfig::default();
        assert!(config.endpoint.contains("8001"));
        assert_eq!(config.timeout_secs, Some(30));
    }

    #[test]
    fn test_connection_config_builder() {
        let config = ConnectionConfig::new("http://test.com:8080")
            .with_timeout(60)
            .with_retries(5, 2);

        assert_eq!(config.endpoint, "http://test.com:8080");
        assert_eq!(config.timeout_secs, Some(60));
        assert_eq!(config.max_retries, Some(5));
        assert_eq!(config.retry_delay_secs, Some(2));
    }

    #[test]
    #[allow(clippy::expect_used)]
    fn test_load_config_file_strict_missing_file() {
        let missing = PathBuf::from("/non/existent/file.yaml");
        let result = load_config_file_strict::<ConnectionConfig>(&missing);

        assert!(result.is_err());
        let message = result
            .expect_err("Expected error for missing file")
            .to_string();
        assert!(message.contains("Configuration file not found"));
    }

    #[test]
    #[allow(clippy::expect_used)]
    fn test_load_config_file_strict_valid_file() {
        let valid_config = r#"
            endpoint: "http://localhost:9999"
            timeout_secs: 60
            max_retries: 5
            retry_delay_secs: 2
        "#;
        let (_temp_dir, config_file) = write_temp_config("test_config.yaml", valid_config);

        let result = load_config_file_strict::<ConnectionConfig>(&config_file);
        assert!(result.is_ok());

        let config = result.expect("Expected successful config parsing");
        assert_eq!(config.endpoint, "http://localhost:9999");
        assert_eq!(config.timeout_secs, Some(60));
        assert_eq!(config.max_retries, Some(5));
        assert_eq!(config.retry_delay_secs, Some(2));
    }

    #[test]
    #[allow(clippy::expect_used)]
    fn test_load_config_file_strict_invalid_yaml() {
        let invalid_config = r#"
            endpoint: "http://localhost:9999"
            timeout_secs: not_a_number
            invalid_yaml_structure:
                missing_indent
        "#;
        let (_temp_dir, config_file) = write_temp_config("invalid_config.yaml", invalid_config);

        let result = load_config_file_strict::<ConnectionConfig>(&config_file);
        assert!(result.is_err());
    }

    #[test]
    #[allow(clippy::expect_used)]
    fn test_load_config_file_strict_wrong_type() {
        let wrong_type_config = r#"
            endpoint: "http://localhost:9999"
            timeout_secs: "this_should_be_a_number"
        "#;
        let (_temp_dir, config_file) =
            write_temp_config("wrong_type_config.yaml", wrong_type_config);

        let result = load_config_file_strict::<ConnectionConfig>(&config_file);
        assert!(result.is_err());
    }

    #[test]
    #[allow(clippy::expect_used)]
    fn test_validate_config_file_same_as_strict() {
        let valid_config = r#"
            endpoint: "http://localhost:9999"
            timeout_secs: 60
        "#;
        let (_temp_dir, config_file) = write_temp_config("test_config.yaml", valid_config);

        let strict_result = load_config_file_strict::<ConnectionConfig>(&config_file);
        let validate_result = validate_config_file::<ConnectionConfig>(&config_file);

        assert!(strict_result.is_ok());
        assert!(validate_result.is_ok());

        let strict_config = strict_result.expect("Expected successful strict config parsing");
        let validate_config = validate_result.expect("Expected successful validate config parsing");

        assert_eq!(strict_config.endpoint, validate_config.endpoint);
        assert_eq!(strict_config.timeout_secs, validate_config.timeout_secs);
    }
}