modelexpress_common/
client_config.rs

1// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4use crate::cache::CacheConfig;
5use crate::config::{ConnectionConfig, LogFormat, LogLevel, load_layered_config};
6use anyhow::Result;
7use clap::Parser;
8use config::ConfigError;
9use serde::{Deserialize, Serialize};
10use std::path::PathBuf;
11
/// Command line arguments for the client
//
// NOTE: the `///` comments on the fields below are consumed by clap's derive
// macro as help text, so rewording them changes user-visible `--help` output.
// Use plain `//` comments (like this one) for maintainer-only notes.
//
// Every option except `quiet` is an `Option<_>` so that `ClientConfig::load`
// can distinguish "not supplied" from an explicit value and only override the
// layered configuration when a value was actually given.
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
pub struct ClientArgs {
    /// Configuration file path
    // Passed to `load_layered_config` by `ClientConfig::load`; no env binding.
    #[arg(short, long, value_name = "FILE")]
    pub config: Option<PathBuf>,

    /// Server endpoint
    // Falls back to the MODEL_EXPRESS_ENDPOINT environment variable.
    // Written to `connection.endpoint` by `ClientConfig::load`.
    #[arg(short, long, env = "MODEL_EXPRESS_ENDPOINT")]
    pub endpoint: Option<String>,

    /// Request timeout in seconds
    // Falls back to MODEL_EXPRESS_TIMEOUT. A value of 0 is rejected later by
    // `ClientConfig::validate`, not by the argument parser.
    #[arg(short, long, env = "MODEL_EXPRESS_TIMEOUT")]
    pub timeout: Option<u64>,

    /// Cache path override
    // Long flag only; falls back to MODEL_EXPRESS_CACHE_PATH.
    // Written to `cache.local_path` by `ClientConfig::load`.
    #[arg(long, env = "MODEL_EXPRESS_CACHE_PATH")]
    pub cache_path: Option<PathBuf>,

    /// Log level
    // The short flag is explicitly `-v`; it takes a `LogLevel` value rather
    // than acting as a repeatable verbosity counter.
    #[arg(short = 'v', long, env = "MODEL_EXPRESS_LOG_LEVEL", value_enum)]
    pub log_level: Option<LogLevel>,

    /// Log format
    // Long flag only; falls back to MODEL_EXPRESS_LOG_FORMAT.
    #[arg(long, env = "MODEL_EXPRESS_LOG_FORMAT", value_enum)]
    pub log_format: Option<LogFormat>,

    /// Quiet mode (suppress all output except errors)
    // Plain boolean switch with no env binding. `ClientConfig::load` only ever
    // turns quiet mode on from this flag; it never turns it off.
    #[arg(long, short = 'q')]
    pub quiet: bool,

    /// Maximum number of retries
    // Long flag only; falls back to MODEL_EXPRESS_MAX_RETRIES.
    #[arg(long, env = "MODEL_EXPRESS_MAX_RETRIES")]
    pub max_retries: Option<u32>,

    /// Retry delay in seconds
    // Long flag only; falls back to MODEL_EXPRESS_RETRY_DELAY.
    #[arg(long, env = "MODEL_EXPRESS_RETRY_DELAY")]
    pub retry_delay: Option<u64>,
}
52
53/// Complete client configuration
54#[derive(Debug, Clone, Serialize, Deserialize, Default)]
55pub struct ClientConfig {
56    /// Connection settings
57    pub connection: ConnectionConfig,
58    /// Cache configuration
59    pub cache: CacheConfig,
60    /// Logging configuration
61    pub logging: LoggingConfig,
62}
63
64/// Logging configuration for the client
65#[derive(Debug, Clone, Serialize, Deserialize, Default)]
66pub struct LoggingConfig {
67    /// Log level
68    #[serde(default)]
69    pub level: LogLevel,
70    /// Log format
71    #[serde(default)]
72    pub format: LogFormat,
73    /// Quiet mode
74    pub quiet: bool,
75}
76
77impl ClientConfig {
78    /// Load configuration from multiple sources in order of precedence:
79    /// 1. Command line arguments (highest priority)
80    /// 2. Environment variables
81    /// 3. Configuration file
82    /// 4. Default values (lowest priority)
83    pub fn load(args: ClientArgs) -> Result<Self, ConfigError> {
84        // Start with layered config loading (file + env + defaults)
85        let mut config =
86            load_layered_config(args.config.clone(), "MODEL_EXPRESS", Self::default())?;
87
88        // Override with command line arguments
89        if let Some(endpoint) = args.endpoint {
90            config.connection.endpoint = endpoint;
91        }
92
93        if let Some(timeout) = args.timeout {
94            config.connection.timeout_secs = Some(timeout);
95        }
96
97        if let Some(max_retries) = args.max_retries {
98            config.connection.max_retries = Some(max_retries);
99        }
100
101        if let Some(retry_delay) = args.retry_delay {
102            config.connection.retry_delay_secs = Some(retry_delay);
103        }
104
105        if let Some(cache_path) = args.cache_path {
106            config.cache.local_path = cache_path;
107        }
108
109        if let Some(log_level) = args.log_level {
110            config.logging.level = log_level;
111        }
112
113        if let Some(log_format) = args.log_format {
114            config.logging.format = log_format;
115        }
116
117        if args.quiet {
118            config.logging.quiet = true;
119        }
120
121        // Validate configuration
122        config.validate()?;
123
124        Ok(config)
125    }
126
127    /// Validate the configuration
128    pub fn validate(&self) -> Result<(), ConfigError> {
129        // Validate endpoint
130        if self.connection.endpoint.is_empty() {
131            return Err(ConfigError::Message(
132                "Server endpoint cannot be empty".to_string(),
133            ));
134        }
135
136        // Validate timeout
137        if let Some(timeout) = self.connection.timeout_secs
138            && timeout == 0
139        {
140            return Err(ConfigError::Message(
141                "Timeout must be greater than 0".to_string(),
142            ));
143        }
144
145        // Validate cache path exists or can be created
146        if !self.cache.local_path.exists()
147            && let Err(e) = std::fs::create_dir_all(&self.cache.local_path)
148        {
149            return Err(ConfigError::Message(format!(
150                "Cannot create cache directory {:?}: {}",
151                self.cache.local_path, e
152            )));
153        }
154
155        Ok(())
156    }
157
158    /// Get the gRPC endpoint for backward compatibility
159    pub fn grpc_endpoint(&self) -> &str {
160        &self.connection.endpoint
161    }
162
163    /// Get the timeout in seconds for backward compatibility
164    pub fn timeout_secs(&self) -> Option<u64> {
165        self.connection.timeout_secs
166    }
167
168    /// Create a simple client config for testing
169    pub fn for_testing(endpoint: impl Into<String>) -> Self {
170        Self {
171            connection: ConnectionConfig::new(endpoint),
172            cache: CacheConfig::default(),
173            logging: LoggingConfig::default(),
174        }
175    }
176
177    /// Apply cache path override if provided
178    pub fn with_cache_path(mut self, cache_path: Option<PathBuf>) -> Self {
179        if let Some(path) = cache_path {
180            self.cache.local_path = path;
181        }
182        self
183    }
184
185    /// Set timeout for the connection
186    pub fn with_timeout(mut self, timeout_secs: u64) -> Self {
187        self.connection.timeout_secs = Some(timeout_secs);
188        self
189    }
190
191    /// Set the server endpoint for both connection and cache
192    pub fn with_endpoint(mut self, endpoint: String) -> Self {
193        self.connection.endpoint = endpoint.clone();
194        self.cache.server_endpoint = endpoint;
195        self
196    }
197}
198
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns a scratch directory path under the system temp dir.
    ///
    /// `validate()` creates the cache directory as a side effect, so tests
    /// point it here instead of at the real default cache location. The
    /// process id and `tag` keep concurrently running tests from colliding.
    fn scratch_cache_dir(tag: &str) -> std::path::PathBuf {
        std::env::temp_dir().join(format!(
            "client_config_test_{}_{}",
            std::process::id(),
            tag
        ))
    }

    #[test]
    fn test_client_config_default() {
        let config = ClientConfig::default();
        assert!(config.connection.endpoint.contains("8001"));
        assert_eq!(config.connection.timeout_secs, Some(30));
        assert!(!config.logging.quiet);
    }

    #[test]
    fn test_client_config_for_testing() {
        let config = ClientConfig::for_testing("http://test.example.com:1234");
        assert_eq!(config.connection.endpoint, "http://test.example.com:1234");
    }

    #[test]
    fn test_client_config_with_endpoint() {
        let config =
            ClientConfig::default().with_endpoint("http://custom.example.com:5678".to_string());

        assert_eq!(config.connection.endpoint, "http://custom.example.com:5678");
        assert_eq!(
            config.cache.server_endpoint,
            "http://custom.example.com:5678"
        );
    }

    #[test]
    fn test_client_config_with_cache_path() {
        let custom = scratch_cache_dir("with_cache_path");

        // `Some` replaces the path (nothing is created until `validate`).
        let config = ClientConfig::default().with_cache_path(Some(custom.clone()));
        assert_eq!(config.cache.local_path, custom);

        // `None` leaves the existing path untouched.
        let untouched = ClientConfig::default().with_cache_path(None);
        assert_eq!(
            untouched.cache.local_path,
            ClientConfig::default().cache.local_path
        );
    }

    #[test]
    fn test_client_config_with_timeout() {
        let config = ClientConfig::default().with_timeout(5);
        assert_eq!(config.timeout_secs(), Some(5));
    }

    #[test]
    fn test_client_config_validation() {
        let cache_dir = scratch_cache_dir("validation");
        let mut config = ClientConfig::default().with_cache_path(Some(cache_dir.clone()));

        assert!(config.validate().is_ok());
        // A successful validation leaves the cache directory in place.
        assert!(cache_dir.is_dir());

        config.connection.endpoint = String::new();
        assert!(config.validate().is_err());

        // Best-effort cleanup; a failure here must not fail the test.
        let _ = std::fs::remove_dir_all(&cache_dir);
    }

    #[test]
    fn test_client_config_rejects_zero_timeout() {
        // The timeout check runs before the cache directory is touched, so
        // this test has no filesystem side effects.
        let config = ClientConfig::default().with_timeout(0);
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_client_config_backward_compatibility() {
        let config = ClientConfig::for_testing("http://test.com:8080");
        assert_eq!(config.grpc_endpoint(), "http://test.com:8080");
        assert_eq!(config.timeout_secs(), Some(30));
    }
}