modelexpress_common/
client_config.rs1use crate::cache::CacheConfig;
5use crate::config::{ConnectionConfig, LogFormat, LogLevel, load_layered_config};
6use anyhow::Result;
7use clap::Parser;
8use config::ConfigError;
9use serde::{Deserialize, Serialize};
10use std::path::PathBuf;
11
12#[derive(Parser, Debug)]
14#[command(author, version, about, long_about = None)]
15pub struct ClientArgs {
16 #[arg(short, long, value_name = "FILE")]
18 pub config: Option<PathBuf>,
19
20 #[arg(short, long, env = "MODEL_EXPRESS_ENDPOINT")]
22 pub endpoint: Option<String>,
23
24 #[arg(short, long, env = "MODEL_EXPRESS_TIMEOUT")]
26 pub timeout: Option<u64>,
27
28 #[arg(long, env = "MODEL_EXPRESS_CACHE_PATH")]
30 pub cache_path: Option<PathBuf>,
31
32 #[arg(short = 'v', long, env = "MODEL_EXPRESS_LOG_LEVEL", value_enum)]
34 pub log_level: Option<LogLevel>,
35
36 #[arg(long, env = "MODEL_EXPRESS_LOG_FORMAT", value_enum)]
38 pub log_format: Option<LogFormat>,
39
40 #[arg(long, short = 'q')]
42 pub quiet: bool,
43
44 #[arg(long, env = "MODEL_EXPRESS_MAX_RETRIES")]
46 pub max_retries: Option<u32>,
47
48 #[arg(long, env = "MODEL_EXPRESS_RETRY_DELAY")]
50 pub retry_delay: Option<u64>,
51}
52
53#[derive(Debug, Clone, Serialize, Deserialize, Default)]
55pub struct ClientConfig {
56 pub connection: ConnectionConfig,
58 pub cache: CacheConfig,
60 pub logging: LoggingConfig,
62}
63
64#[derive(Debug, Clone, Serialize, Deserialize, Default)]
66pub struct LoggingConfig {
67 #[serde(default)]
69 pub level: LogLevel,
70 #[serde(default)]
72 pub format: LogFormat,
73 pub quiet: bool,
75}
76
77impl ClientConfig {
78 pub fn load(args: ClientArgs) -> Result<Self, ConfigError> {
84 let mut config =
86 load_layered_config(args.config.clone(), "MODEL_EXPRESS", Self::default())?;
87
88 if let Some(endpoint) = args.endpoint {
90 config.connection.endpoint = endpoint;
91 }
92
93 if let Some(timeout) = args.timeout {
94 config.connection.timeout_secs = Some(timeout);
95 }
96
97 if let Some(max_retries) = args.max_retries {
98 config.connection.max_retries = Some(max_retries);
99 }
100
101 if let Some(retry_delay) = args.retry_delay {
102 config.connection.retry_delay_secs = Some(retry_delay);
103 }
104
105 if let Some(cache_path) = args.cache_path {
106 config.cache.local_path = cache_path;
107 }
108
109 if let Some(log_level) = args.log_level {
110 config.logging.level = log_level;
111 }
112
113 if let Some(log_format) = args.log_format {
114 config.logging.format = log_format;
115 }
116
117 if args.quiet {
118 config.logging.quiet = true;
119 }
120
121 config.validate()?;
123
124 Ok(config)
125 }
126
127 pub fn validate(&self) -> Result<(), ConfigError> {
129 if self.connection.endpoint.is_empty() {
131 return Err(ConfigError::Message(
132 "Server endpoint cannot be empty".to_string(),
133 ));
134 }
135
136 if let Some(timeout) = self.connection.timeout_secs
138 && timeout == 0
139 {
140 return Err(ConfigError::Message(
141 "Timeout must be greater than 0".to_string(),
142 ));
143 }
144
145 if !self.cache.local_path.exists()
147 && let Err(e) = std::fs::create_dir_all(&self.cache.local_path)
148 {
149 return Err(ConfigError::Message(format!(
150 "Cannot create cache directory {:?}: {}",
151 self.cache.local_path, e
152 )));
153 }
154
155 Ok(())
156 }
157
158 pub fn grpc_endpoint(&self) -> &str {
160 &self.connection.endpoint
161 }
162
163 pub fn timeout_secs(&self) -> Option<u64> {
165 self.connection.timeout_secs
166 }
167
168 pub fn for_testing(endpoint: impl Into<String>) -> Self {
170 Self {
171 connection: ConnectionConfig::new(endpoint),
172 cache: CacheConfig::default(),
173 logging: LoggingConfig::default(),
174 }
175 }
176
177 pub fn with_cache_path(mut self, cache_path: Option<PathBuf>) -> Self {
179 if let Some(path) = cache_path {
180 self.cache.local_path = path;
181 }
182 self
183 }
184
185 pub fn with_timeout(mut self, timeout_secs: u64) -> Self {
187 self.connection.timeout_secs = Some(timeout_secs);
188 self
189 }
190
191 pub fn with_endpoint(mut self, endpoint: String) -> Self {
193 self.connection.endpoint = endpoint.clone();
194 self.cache.server_endpoint = endpoint;
195 self
196 }
197}
198
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_client_config_default() {
        let config = ClientConfig::default();
        assert!(config.connection.endpoint.contains("8001"));
        assert_eq!(config.connection.timeout_secs, Some(30));
        assert!(!config.logging.quiet);
    }

    #[test]
    fn test_client_config_for_testing() {
        let config = ClientConfig::for_testing("http://test.example.com:1234");
        assert_eq!(config.connection.endpoint, "http://test.example.com:1234");
    }

    #[test]
    fn test_client_config_with_endpoint() {
        let config =
            ClientConfig::default().with_endpoint("http://custom.example.com:5678".to_string());

        assert_eq!(config.connection.endpoint, "http://custom.example.com:5678");
        assert_eq!(
            config.cache.server_endpoint,
            "http://custom.example.com:5678"
        );
    }

    #[test]
    fn test_client_config_validation() {
        // Point the cache at the system temp dir, which already exists. This way
        // `validate()` does not create the default cache directory on the developer's
        // machine as a side effect of running the tests.
        let mut config = ClientConfig::default().with_cache_path(Some(std::env::temp_dir()));
        assert!(config.validate().is_ok());

        config.connection.endpoint = String::new();
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_client_config_rejects_zero_timeout() {
        let config = ClientConfig::default()
            .with_cache_path(Some(std::env::temp_dir()))
            .with_timeout(0);
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_client_config_builders() {
        let dir = std::env::temp_dir();
        let config = ClientConfig::for_testing("http://builder.example.com:1")
            .with_timeout(5)
            .with_cache_path(Some(dir.clone()));
        assert_eq!(config.timeout_secs(), Some(5));
        assert_eq!(config.cache.local_path, dir);

        // `None` leaves the existing cache path untouched.
        let unchanged = config.clone().with_cache_path(None);
        assert_eq!(unchanged.cache.local_path, dir);
    }

    #[test]
    fn test_client_config_backward_compatibility() {
        let config = ClientConfig::for_testing("http://test.com:8080");
        assert_eq!(config.grpc_endpoint(), "http://test.com:8080");
        assert_eq!(config.timeout_secs(), Some(30));
    }
}