Skip to main content

modelexpress_common/
client_config.rs

1// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4use crate::cache::CacheConfig;
5use crate::config::{ConnectionConfig, LogFormat, LogLevel, load_layered_config};
6use anyhow::Result;
7use clap::Parser;
8use config::ConfigError;
9use serde::{Deserialize, Serialize};
10use std::path::PathBuf;
11
/// Shared command line arguments for the ModelExpress client.
///
/// # Adding New Arguments
///
/// This struct is the **single source of truth** for client CLI arguments and environment
/// variables. It is shared between:
/// - The `modelexpress-cli` binary (via `#[command(flatten)]` in the `Cli` struct)
/// - Any other client binaries that need these arguments
/// - The `ClientConfig::load()` function which applies these values
///
/// When adding a new argument:
/// 1. Add the field here with appropriate `#[arg(...)]` attributes
/// 2. Include `env = "MODEL_EXPRESS_..."` for environment variable support
/// 3. Update `ClientConfig::load()` to apply the new argument to the config
/// 4. Add tests in the `tests` module below
/// 5. Update CLI.md documentation if applicable
///
/// # Short Flags
///
/// Avoid using `-v` as a short flag here - it's reserved for the CLI's `--verbose` flag
/// which uses `-v`, `-vv`, `-vvv` counting. The CLI embeds this struct via flatten,
/// so short flag conflicts will cause runtime panics.
///
/// # Optional Values vs. Flags
///
/// Value-carrying arguments are `Option<_>` so that `ClientConfig::load()` can tell
/// "not provided" (keep the file/default value) apart from an explicit override.
/// The two `bool` flags (`quiet`, `no_shared_storage`) are only acted on by
/// `ClientConfig::load()` when they are `true`; leaving them off does not reset a
/// value that came from another source.
// NOTE: the `///` comments on the fields below double as clap help text, so edit
// them with the rendered `--help` output in mind.
#[derive(Parser, Debug, Clone)]
#[command(author, version, about, long_about = None)]
pub struct ClientArgs {
    /// Configuration file path
    // Short flag is derived from the field name (`-c`); no environment variable.
    #[arg(short, long, value_name = "FILE")]
    pub config: Option<PathBuf>,

    /// Server endpoint
    // Short flag `-e` (exercised in the tests module).
    #[arg(short, long, env = "MODEL_EXPRESS_ENDPOINT")]
    pub endpoint: Option<String>,

    /// Request timeout in seconds
    // Short flag `-t`. A value of 0 parses here but is rejected later by
    // `ClientConfig::validate()`.
    #[arg(short, long, env = "MODEL_EXPRESS_TIMEOUT")]
    pub timeout: Option<u64>,

    /// Cache path override
    #[arg(long, env = "MODEL_EXPRESS_CACHE_PATH")]
    pub cache_path: Option<PathBuf>,

    /// Log level (no short flag to avoid conflict with CLI's -v/--verbose)
    #[arg(long, env = "MODEL_EXPRESS_LOG_LEVEL", value_enum)]
    pub log_level: Option<LogLevel>,

    /// Log format
    #[arg(long, env = "MODEL_EXPRESS_LOG_FORMAT", value_enum)]
    pub log_format: Option<LogFormat>,

    /// Quiet mode (suppress all output except errors)
    // Command-line only: unlike most other arguments there is no `env` attribute here.
    #[arg(long, short = 'q')]
    pub quiet: bool,

    /// Maximum number of retries
    #[arg(long, env = "MODEL_EXPRESS_MAX_RETRIES")]
    pub max_retries: Option<u32>,

    /// Retry delay in seconds
    #[arg(long, env = "MODEL_EXPRESS_RETRY_DELAY")]
    pub retry_delay: Option<u64>,

    /// Disable shared storage mode (will transfer files from server to client)
    // Negative flag: when set, `ClientConfig::load()` forces `cache.shared_storage = false`.
    #[arg(long, env = "MODEL_EXPRESS_NO_SHARED_STORAGE")]
    pub no_shared_storage: bool,

    /// Chunk size in bytes for file transfer when shared storage is disabled
    #[arg(long, env = "MODEL_EXPRESS_TRANSFER_CHUNK_SIZE")]
    pub transfer_chunk_size: Option<usize>,
}
81
/// Complete client configuration.
///
/// This is the merged result produced by `ClientConfig::load()`: a base configuration
/// from `load_layered_config` with `ClientArgs` overrides applied on top. It can also
/// be built directly via `Default`, `ClientConfig::for_testing()` and the `with_*`
/// builder methods.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct ClientConfig {
    /// Connection settings (endpoint, timeout, retry count and retry delay)
    pub connection: ConnectionConfig,
    /// Cache configuration (local path, shared-storage mode, transfer chunk size,
    /// and the server endpoint recorded for the cache)
    pub cache: CacheConfig,
    /// Logging configuration (level, format, quiet mode)
    pub logging: LoggingConfig,
}
92
93/// Logging configuration for the client
94#[derive(Debug, Clone, Serialize, Deserialize, Default)]
95pub struct LoggingConfig {
96    /// Log level
97    #[serde(default)]
98    pub level: LogLevel,
99    /// Log format
100    #[serde(default)]
101    pub format: LogFormat,
102    /// Quiet mode
103    pub quiet: bool,
104}
105
106impl ClientConfig {
107    /// Load configuration from multiple sources in order of precedence:
108    /// 1. Command line arguments (highest priority)
109    /// 2. Environment variables (handled by clap's `env` attribute on `ClientArgs`)
110    /// 3. Configuration file
111    /// 4. Default values (lowest priority)
112    ///
113    /// # Adding New Arguments
114    ///
115    /// When you add a new field to `ClientArgs`:
116    /// 1. Add the corresponding override logic below in the "Apply CLI argument overrides" section
117    /// 2. Map the `ClientArgs` field to the appropriate `ClientConfig` field
118    /// 3. Add a test in the `tests` module to verify the override works
119    pub fn load(args: ClientArgs) -> Result<Self, ConfigError> {
120        // Start with layered config loading (file + env + defaults)
121        let mut config =
122            load_layered_config(args.config.clone(), "MODEL_EXPRESS", Self::default())?;
123
124        // ==================== APPLY CLI ARGUMENT OVERRIDES ====================
125        // When adding a new field to ClientArgs, add the override logic here.
126        // These overrides apply CLI arguments (which include env vars via clap)
127        // on top of the config file values.
128
129        // Connection settings
130        if let Some(endpoint) = args.endpoint {
131            config.connection.endpoint = endpoint;
132        }
133
134        if let Some(timeout) = args.timeout {
135            config.connection.timeout_secs = Some(timeout);
136        }
137
138        if let Some(max_retries) = args.max_retries {
139            config.connection.max_retries = Some(max_retries);
140        }
141
142        if let Some(retry_delay) = args.retry_delay {
143            config.connection.retry_delay_secs = Some(retry_delay);
144        }
145
146        // Cache settings
147        if let Some(cache_path) = args.cache_path {
148            config.cache.local_path = cache_path;
149        }
150
151        if args.no_shared_storage {
152            config.cache.shared_storage = false;
153        }
154
155        if let Some(chunk_size) = args.transfer_chunk_size {
156            config.cache.transfer_chunk_size = chunk_size;
157        }
158
159        // Logging settings
160        if let Some(log_level) = args.log_level {
161            config.logging.level = log_level;
162        }
163
164        if let Some(log_format) = args.log_format {
165            config.logging.format = log_format;
166        }
167
168        if args.quiet {
169            config.logging.quiet = true;
170        }
171
172        // ==================== END CLI ARGUMENT OVERRIDES ====================
173
174        // Validate configuration
175        config.validate()?;
176
177        Ok(config)
178    }
179
180    /// Validate the configuration
181    pub fn validate(&self) -> Result<(), ConfigError> {
182        // Validate endpoint
183        if self.connection.endpoint.is_empty() {
184            return Err(ConfigError::Message(
185                "Server endpoint cannot be empty".to_string(),
186            ));
187        }
188
189        // Validate timeout
190        if let Some(timeout) = self.connection.timeout_secs
191            && timeout == 0
192        {
193            return Err(ConfigError::Message(
194                "Timeout must be greater than 0".to_string(),
195            ));
196        }
197
198        // Validate cache path exists or can be created
199        if !self.cache.local_path.exists()
200            && let Err(e) = std::fs::create_dir_all(&self.cache.local_path)
201        {
202            return Err(ConfigError::Message(format!(
203                "Cannot create cache directory {:?}: {}",
204                self.cache.local_path, e
205            )));
206        }
207
208        Ok(())
209    }
210
211    /// Get the gRPC endpoint for backward compatibility
212    pub fn grpc_endpoint(&self) -> &str {
213        &self.connection.endpoint
214    }
215
216    /// Get the timeout in seconds for backward compatibility
217    pub fn timeout_secs(&self) -> Option<u64> {
218        self.connection.timeout_secs
219    }
220
221    /// Create a simple client config for testing
222    pub fn for_testing(endpoint: impl Into<String>) -> Self {
223        Self {
224            connection: ConnectionConfig::new(endpoint),
225            cache: CacheConfig::default(),
226            logging: LoggingConfig::default(),
227        }
228    }
229
230    /// Apply cache path override if provided
231    pub fn with_cache_path(mut self, cache_path: Option<PathBuf>) -> Self {
232        if let Some(path) = cache_path {
233            self.cache.local_path = path;
234        }
235        self
236    }
237
238    /// Set timeout for the connection
239    pub fn with_timeout(mut self, timeout_secs: u64) -> Self {
240        self.connection.timeout_secs = Some(timeout_secs);
241        self
242    }
243
244    /// Set the server endpoint for both connection and cache
245    pub fn with_endpoint(mut self, endpoint: String) -> Self {
246        self.connection.endpoint = endpoint.clone();
247        self.cache.server_endpoint = endpoint;
248        self
249    }
250}
251
#[cfg(test)]
// `expect` is acceptable in tests: a failed expectation should abort the test with
// a descriptive message.
#[allow(clippy::expect_used)]
mod tests {
    use super::*;
    use crate::constants;

    /// Build a scratch cache directory path under the system temp dir.
    ///
    /// The label and the process id keep paths distinct between tests running in
    /// parallel and between concurrent test processes. The directory itself is not
    /// created here.
    fn scratch_cache_dir(label: &str) -> PathBuf {
        std::env::temp_dir().join(format!(
            "modelexpress-client-config-{label}-{}",
            std::process::id()
        ))
    }

    #[test]
    fn test_client_config_default() {
        let config = ClientConfig::default();
        assert!(config.connection.endpoint.contains("8001"));
        assert_eq!(config.connection.timeout_secs, Some(30));
        assert!(!config.logging.quiet);
    }

    #[test]
    fn test_client_config_for_testing() {
        let config = ClientConfig::for_testing("http://test.example.com:1234");
        assert_eq!(config.connection.endpoint, "http://test.example.com:1234");
    }

    #[test]
    fn test_client_config_with_endpoint() {
        let config =
            ClientConfig::default().with_endpoint("http://custom.example.com:5678".to_string());

        assert_eq!(config.connection.endpoint, "http://custom.example.com:5678");
        assert_eq!(
            config.cache.server_endpoint,
            "http://custom.example.com:5678"
        );
    }

    #[test]
    fn test_client_config_with_cache_path() {
        // Pure builder check: nothing is created on disk because `validate` is not called.
        let custom = scratch_cache_dir("with-cache-path");
        let config = ClientConfig::default().with_cache_path(Some(custom.clone()));
        assert_eq!(config.cache.local_path, custom);
    }

    #[test]
    fn test_client_config_validation() {
        // NOTE(review): validating the default config may create the default cache
        // directory on the machine running the tests - confirm this is acceptable.
        let mut config = ClientConfig::default();
        assert!(config.validate().is_ok());

        config.connection.endpoint = String::new();
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_client_config_validation_rejects_zero_timeout() {
        // The timeout check runs before the cache directory is touched, so this test
        // has no filesystem side effects.
        let config = ClientConfig::default().with_timeout(0);
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_client_config_backward_compatibility() {
        let config = ClientConfig::for_testing("http://test.com:8080");
        assert_eq!(config.grpc_endpoint(), "http://test.com:8080");
        assert_eq!(config.timeout_secs(), Some(30));
    }

    #[test]
    fn test_client_config_shared_storage_defaults() {
        let config = ClientConfig::default();
        assert!(config.cache.shared_storage);
        assert_eq!(
            config.cache.transfer_chunk_size,
            constants::DEFAULT_TRANSFER_CHUNK_SIZE
        );
    }

    #[test]
    fn test_client_config_shared_storage_override() {
        let mut config = ClientConfig::default();
        config.cache.shared_storage = false;
        config.cache.transfer_chunk_size = 64 * 1024;

        assert!(!config.cache.shared_storage);
        assert_eq!(config.cache.transfer_chunk_size, 64 * 1024);
    }

    #[test]
    fn test_client_args_parse_defaults() {
        // Test that ClientArgs can be parsed with no arguments (uses defaults)
        let args = ClientArgs::try_parse_from(["test"]).expect("Failed to parse empty args");

        assert!(args.endpoint.is_none());
        assert!(args.timeout.is_none());
        assert!(args.cache_path.is_none());
        assert!(!args.quiet);
        assert!(!args.no_shared_storage);
        assert!(args.transfer_chunk_size.is_none());
    }

    #[test]
    fn test_client_args_parse_cli_flags() {
        // Test parsing various CLI flags
        let args = ClientArgs::try_parse_from([
            "test",
            "--endpoint",
            "http://custom:9000",
            "--timeout",
            "60",
            "--quiet",
            "--no-shared-storage",
            "--transfer-chunk-size",
            "1048576",
        ])
        .expect("Failed to parse CLI args");

        assert_eq!(args.endpoint, Some("http://custom:9000".to_string()));
        assert_eq!(args.timeout, Some(60));
        assert!(args.quiet);
        assert!(args.no_shared_storage);
        assert_eq!(args.transfer_chunk_size, Some(1048576));
    }

    #[test]
    fn test_client_args_short_flags() {
        // Test short flag variants (-e for endpoint, -t for timeout, -q for quiet)
        let args =
            ClientArgs::try_parse_from(["test", "-e", "http://short:8000", "-t", "45", "-q"])
                .expect("Failed to parse short flags");

        assert_eq!(args.endpoint, Some("http://short:8000".to_string()));
        assert_eq!(args.timeout, Some(45));
        assert!(args.quiet);
    }

    #[test]
    fn test_client_args_log_level() {
        // Test --log-level flag (no short flag to avoid conflict with CLI's -v)
        let args = ClientArgs::try_parse_from(["test", "--log-level", "debug"])
            .expect("Failed to parse log level");

        assert_eq!(args.log_level, Some(LogLevel::Debug));
    }

    #[test]
    fn test_client_config_load_applies_cli_args() {
        // Test that ClientConfig::load() properly applies CLI arguments.
        // The cache path points at a scratch directory so that the directory
        // creation performed by `validate()` stays inside the system temp dir.
        let cache_dir = scratch_cache_dir("load-cli-args");
        let args = ClientArgs {
            config: None,
            endpoint: Some("http://cli-override:7777".to_string()),
            timeout: Some(120),
            cache_path: Some(cache_dir.clone()),
            log_level: None,
            log_format: None,
            quiet: true,
            max_retries: Some(5),
            retry_delay: Some(10),
            no_shared_storage: true,
            transfer_chunk_size: Some(2097152),
        };

        let config = ClientConfig::load(args).expect("Failed to load config");

        assert_eq!(config.connection.endpoint, "http://cli-override:7777");
        assert_eq!(config.connection.timeout_secs, Some(120));
        assert!(config.logging.quiet);
        assert_eq!(config.connection.max_retries, Some(5));
        assert_eq!(config.connection.retry_delay_secs, Some(10));
        assert!(!config.cache.shared_storage);
        assert_eq!(config.cache.transfer_chunk_size, 2097152);
        assert_eq!(config.cache.local_path, cache_dir);
        // `load` validates the config, which creates the missing cache directory.
        assert!(cache_dir.is_dir());

        std::fs::remove_dir_all(&cache_dir).expect("Failed to remove scratch cache directory");
    }

    #[test]
    fn test_client_config_load_applies_log_level() {
        // Test that the --log-level override reaches the logging section.
        let cache_dir = scratch_cache_dir("load-log-level");
        let args = ClientArgs {
            config: None,
            endpoint: None,
            timeout: None,
            cache_path: Some(cache_dir.clone()),
            log_level: Some(LogLevel::Debug),
            log_format: None,
            quiet: false,
            max_retries: None,
            retry_delay: None,
            no_shared_storage: false,
            transfer_chunk_size: None,
        };

        let config = ClientConfig::load(args).expect("Failed to load config");

        assert_eq!(config.logging.level, LogLevel::Debug);

        std::fs::remove_dir_all(&cache_dir).expect("Failed to remove scratch cache directory");
    }
}