1use chrono::Duration;
5use clap::ValueEnum;
6use config::{Config, ConfigError, Environment, File};
7use serde::{Deserialize, Deserializer, Serialize};
8use std::fmt;
9use std::path::{Path, PathBuf};
10use std::str::FromStr;
11use tracing::{Level, info};
12
13pub fn parse_duration_string(value: &str) -> Result<Duration, String> {
16 use jiff::{Span, SpanRelativeTo};
17 let span = Span::from_str(value).map_err(|err| format!("Invalid duration: {err}"))?;
18
19 let signed_duration = span
22 .to_duration(SpanRelativeTo::days_are_24_hours())
23 .map_err(|err| format!("Invalid duration: {err}"))?;
24
25 let std_duration = std::time::Duration::try_from(signed_duration)
26 .map_err(|err| format!("Invalid duration: {err}"))?;
27
28 Duration::from_std(std_duration).map_err(|err| format!("Duration out of range: {err}"))
29}
30
/// A configuration-friendly wrapper around a [`chrono::Duration`].
///
/// Deserializes from either a duration string (e.g. `"2h"`) or an integer number of
/// seconds. Serializes back to whole seconds, so any sub-second precision is dropped.
#[derive(Debug, Clone)]
pub struct DurationConfig {
    // The wrapped duration; read through `as_chrono_duration` / `num_seconds`.
    duration: Duration,
}
36
37impl DurationConfig {
38 pub fn new(duration: Duration) -> Self {
39 Self { duration }
40 }
41
42 pub fn hours(hours: i64) -> Self {
43 Self {
44 duration: Duration::hours(hours),
45 }
46 }
47
48 pub fn as_chrono_duration(&self) -> Duration {
49 self.duration
50 }
51
52 pub fn num_seconds(&self) -> i64 {
53 self.duration.num_seconds()
54 }
55}
56
57impl fmt::Display for DurationConfig {
58 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
59 write!(f, "{}s", self.duration.num_seconds())
60 }
61}
62
63impl Serialize for DurationConfig {
65 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
66 where
67 S: serde::Serializer,
68 {
69 self.duration.num_seconds().serialize(serializer)
70 }
71}
72
73impl<'de> Deserialize<'de> for DurationConfig {
74 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
75 where
76 D: Deserializer<'de>,
77 {
78 use serde::de::{self, Visitor};
79
80 struct DurationVisitor;
81
82 impl<'de> Visitor<'de> for DurationVisitor {
83 type Value = DurationConfig;
84
85 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
86 formatter
87 .write_str("a duration string like '2h', '30m', '45s' or number of seconds")
88 }
89
90 fn visit_str<E>(self, value: &str) -> Result<DurationConfig, E>
91 where
92 E: de::Error,
93 {
94 parse_duration_string(value)
95 .map(DurationConfig::new)
96 .map_err(de::Error::custom)
97 }
98
99 fn visit_i64<E>(self, value: i64) -> Result<DurationConfig, E>
100 where
101 E: de::Error,
102 {
103 Ok(DurationConfig::new(Duration::seconds(value)))
104 }
105
106 fn visit_u64<E>(self, value: u64) -> Result<DurationConfig, E>
107 where
108 E: de::Error,
109 {
110 Ok(DurationConfig::new(Duration::seconds(value as i64)))
111 }
112 }
113
114 deserializer.deserialize_any(DurationVisitor)
115 }
116}
117
118#[derive(Debug, Clone, Copy, ValueEnum, PartialEq, Serialize, Deserialize, Default)]
120pub enum LogLevel {
121 Trace,
122 Debug,
123 #[default]
124 Info,
125 Warn,
126 Error,
127}
128
129impl fmt::Display for LogLevel {
130 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
131 match self {
132 LogLevel::Trace => write!(f, "trace"),
133 LogLevel::Debug => write!(f, "debug"),
134 LogLevel::Info => write!(f, "info"),
135 LogLevel::Warn => write!(f, "warn"),
136 LogLevel::Error => write!(f, "error"),
137 }
138 }
139}
140
141impl From<LogLevel> for Level {
142 fn from(log_level: LogLevel) -> Self {
143 match log_level {
144 LogLevel::Trace => Level::TRACE,
145 LogLevel::Debug => Level::DEBUG,
146 LogLevel::Info => Level::INFO,
147 LogLevel::Warn => Level::WARN,
148 LogLevel::Error => Level::ERROR,
149 }
150 }
151}
152
153impl FromStr for LogLevel {
154 type Err = String;
155
156 fn from_str(s: &str) -> Result<Self, Self::Err> {
157 match s.to_lowercase().as_str() {
158 "trace" => Ok(LogLevel::Trace),
159 "debug" => Ok(LogLevel::Debug),
160 "info" => Ok(LogLevel::Info),
161 "warn" => Ok(LogLevel::Warn),
162 "error" => Ok(LogLevel::Error),
163 _ => Err(format!("Invalid log level: {s}")),
164 }
165 }
166}
167
168#[derive(Debug, Clone, Copy, ValueEnum, PartialEq, Serialize, Deserialize, Default)]
170pub enum LogFormat {
171 Json,
172 #[default]
173 Pretty,
174 Compact,
175}
176
177impl fmt::Display for LogFormat {
178 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
179 match self {
180 LogFormat::Json => write!(f, "json"),
181 LogFormat::Pretty => write!(f, "pretty"),
182 LogFormat::Compact => write!(f, "compact"),
183 }
184 }
185}
186
187impl FromStr for LogFormat {
188 type Err = String;
189
190 fn from_str(s: &str) -> Result<Self, Self::Err> {
191 match s.to_lowercase().as_str() {
192 "json" => Ok(LogFormat::Json),
193 "pretty" => Ok(LogFormat::Pretty),
194 "compact" => Ok(LogFormat::Compact),
195 _ => Err(format!("Invalid log format: {s}")),
196 }
197 }
198}
199
/// Loader for layered configuration values of type `T`.
///
/// NOTE(review): no implementors are visible in this file. The intended layering is
/// presumably the same as `load_layered_config` (file, then environment, then defaults);
/// confirm against the implementations.
pub trait ConfigLoader<T> {
    /// Builds a `T` from three inputs:
    /// - an optional configuration file path;
    /// - environment variables prefixed with `env_prefix`;
    /// - a `defaults` value.
    ///
    /// This is an associated function (no `self`). The exact precedence between the
    /// sources is up to the implementor.
    fn load_layered(
        config_file: Option<PathBuf>,
        env_prefix: &str,
        defaults: T,
    ) -> Result<T, ConfigError>
    where
        T: serde::de::DeserializeOwned + Default;
}
215
216pub fn load_config_file_strict<T>(config_file: &Path) -> Result<T, ConfigError>
220where
221 T: serde::de::DeserializeOwned,
222{
223 if !config_file.exists() {
224 return Err(ConfigError::Message(format!(
225 "Configuration file not found: {}",
226 config_file.display()
227 )));
228 }
229
230 let config = Config::builder()
231 .add_source(File::from(config_file.to_path_buf()))
232 .build()?;
233
234 config.try_deserialize::<T>()
235}
236
/// Returns the first well-known configuration path that exists.
///
/// The candidates are checked in this order:
/// 1. `model-express.yaml` in the current working directory;
/// 2. `model-express.yml` in the current working directory;
/// 3. `/etc/model-express/config.yaml`;
/// 4. `/etc/model-express/config.yml`.
///
/// Returns `None` when none of them exists.
fn discover_default_config() -> Option<PathBuf> {
    const CANDIDATES: [&str; 4] = [
        "model-express.yaml",
        "model-express.yml",
        "/etc/model-express/config.yaml",
        "/etc/model-express/config.yml",
    ];

    CANDIDATES
        .iter()
        .copied()
        .map(PathBuf::from)
        .find(|candidate| candidate.exists())
}
252
253fn load_config_with_env_strict<T>(
256 config_file: Option<PathBuf>,
257 env_prefix: &str,
258) -> Result<T, ConfigError>
259where
260 T: serde::de::DeserializeOwned,
261{
262 let mut builder = Config::builder();
263
264 if let Some(config_path) = &config_file {
266 if !config_path.exists() {
267 return Err(ConfigError::Message(format!(
268 "Configuration file not found: {}",
269 config_path.display()
270 )));
271 }
272 builder = builder.add_source(File::from(config_path.clone()));
273 } else if let Some(default_path) = discover_default_config() {
274 info!("Using default config: {}", default_path.display());
275 builder = builder.add_source(File::from(default_path));
276 } else {
277 return Err(ConfigError::Message(
278 "No configuration file specified and no default config found. \
279 Please specify a config file with --config or create a default config."
280 .to_string(),
281 ));
282 }
283
284 builder = builder.add_source(
286 Environment::with_prefix(env_prefix)
287 .try_parsing(true)
288 .separator("_"),
289 );
290
291 let config = builder.build()?;
292 config.try_deserialize::<T>()
293}
294
295pub fn validate_config_file<T>(config_file: &Path) -> Result<T, ConfigError>
298where
299 T: serde::de::DeserializeOwned,
300{
301 load_config_file_strict(config_file)
302}
303
304pub fn load_layered_config<T>(
306 config_file: Option<PathBuf>,
307 env_prefix: &str,
308 defaults: T,
309) -> Result<T, ConfigError>
310where
311 T: serde::de::DeserializeOwned + Default,
312{
313 match load_config_with_env_strict(config_file, env_prefix) {
315 Ok(config) => Ok(config),
316 Err(_) => {
317 Ok(defaults)
320 }
321 }
322}
323
/// Connection settings for a remote endpoint. This is presumably a gRPC server, since
/// `Default` builds the URL from `DEFAULT_GRPC_PORT`.
///
/// When deserializing, only `endpoint` is required. A missing optional key becomes
/// `None`, *not* the values used by `Default` / `ConnectionConfig::new`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConnectionConfig {
    /// Endpoint URL to connect to, e.g. `http://localhost:<port>`.
    pub endpoint: String,

    /// Timeout in seconds; `None` when not configured.
    pub timeout_secs: Option<u64>,

    /// Maximum number of retries; `None` when not configured.
    pub max_retries: Option<u32>,

    /// Delay between retries in seconds; `None` when not configured.
    pub retry_delay_secs: Option<u64>,
}
339
340impl Default for ConnectionConfig {
341 fn default() -> Self {
342 Self {
343 endpoint: format!("http://localhost:{}", crate::constants::DEFAULT_GRPC_PORT),
344 timeout_secs: Some(crate::constants::DEFAULT_TIMEOUT_SECS),
345 max_retries: Some(3),
346 retry_delay_secs: Some(1),
347 }
348 }
349}
350
351impl ConnectionConfig {
352 pub fn new(endpoint: impl Into<String>) -> Self {
353 Self {
354 endpoint: endpoint.into(),
355 timeout_secs: Some(crate::constants::DEFAULT_TIMEOUT_SECS),
356 max_retries: Some(3),
357 retry_delay_secs: Some(1),
358 }
359 }
360
361 pub fn with_timeout(mut self, timeout_secs: u64) -> Self {
362 self.timeout_secs = Some(timeout_secs);
363 self
364 }
365
366 pub fn with_retries(mut self, max_retries: u32, delay_secs: u64) -> Self {
367 self.max_retries = Some(max_retries);
368 self.retry_delay_secs = Some(delay_secs);
369 self
370 }
371}
372
#[cfg(test)]
mod tests {
    // Unit tests for duration parsing, the log enums' string conversions, the
    // `ConnectionConfig` builders, and strict config-file loading.
    use super::*;
    use std::fs;
    use tempfile::tempdir;

    // `parse_duration_string` accepts the short "2h" form.
    #[test]
    fn test_duration_config_from_string() {
        match parse_duration_string("2h") {
            Ok(duration) => assert_eq!(duration.num_hours(), 2),
            Err(e) => panic!("Failed to parse duration '2h': {e}"),
        }
    }

    // `LogLevel::from_str` maps lowercase names to the matching variants.
    #[test]
    fn test_log_level_from_string() {
        match "info".parse::<LogLevel>() {
            Ok(level) => assert_eq!(level, LogLevel::Info),
            Err(e) => panic!("Failed to parse 'info' as LogLevel: {e}"),
        }
        match "debug".parse::<LogLevel>() {
            Ok(level) => assert_eq!(level, LogLevel::Debug),
            Err(e) => panic!("Failed to parse 'debug' as LogLevel: {e}"),
        }
    }

    // `LogFormat::from_str` maps lowercase names to the matching variants.
    #[test]
    fn test_log_format_from_string() {
        match "json".parse::<LogFormat>() {
            Ok(format) => assert_eq!(format, LogFormat::Json),
            Err(e) => panic!("Failed to parse 'json' as LogFormat: {e}"),
        }
        match "pretty".parse::<LogFormat>() {
            Ok(format) => assert_eq!(format, LogFormat::Pretty),
            Err(e) => panic!("Failed to parse 'pretty' as LogFormat: {e}"),
        }
    }

    // Expects `crate::constants::DEFAULT_GRPC_PORT` and `DEFAULT_TIMEOUT_SECS` to yield
    // port 8001 and a 30 second timeout; update these assertions if the constants change.
    #[test]
    fn test_connection_config_default() {
        let config = ConnectionConfig::default();
        assert!(config.endpoint.contains("8001"));
        assert_eq!(config.timeout_secs, Some(30));
    }

    // The builder methods override exactly the fields they name.
    #[test]
    fn test_connection_config_builder() {
        let config = ConnectionConfig::new("http://test.com:8080")
            .with_timeout(60)
            .with_retries(5, 2);

        assert_eq!(config.endpoint, "http://test.com:8080");
        assert_eq!(config.timeout_secs, Some(60));
        assert_eq!(config.max_retries, Some(5));
        assert_eq!(config.retry_delay_secs, Some(2));
    }

    // A nonexistent path is reported with the "Configuration file not found" message.
    #[test]
    // NOTE(review): the crate presumably denies `clippy::expect_used`; tests opt out locally.
    #[allow(clippy::expect_used)]
    fn test_load_config_file_strict_missing_file() {
        let non_existent_file = PathBuf::from("/non/existent/file.yaml");
        let result: Result<ConnectionConfig, ConfigError> =
            load_config_file_strict(&non_existent_file);

        assert!(result.is_err());
        let error_message = result
            .expect_err("Expected error for missing file")
            .to_string();
        assert!(error_message.contains("Configuration file not found"));
    }

    // A well-formed YAML file round-trips every `ConnectionConfig` field.
    #[test]
    #[allow(clippy::expect_used)]
    fn test_load_config_file_strict_valid_file() {
        let temp_dir = tempdir().expect("Failed to create temp dir");
        let config_file = temp_dir.path().join("test_config.yaml");

        let valid_config = r#"
 endpoint: "http://localhost:9999"
 timeout_secs: 60
 max_retries: 5
 retry_delay_secs: 2
 "#;

        fs::write(&config_file, valid_config).expect("Failed to write config file");

        let result: Result<ConnectionConfig, ConfigError> = load_config_file_strict(&config_file);
        assert!(result.is_ok());

        let config = result.expect("Expected successful config parsing");
        assert_eq!(config.endpoint, "http://localhost:9999");
        assert_eq!(config.timeout_secs, Some(60));
        assert_eq!(config.max_retries, Some(5));
        assert_eq!(config.retry_delay_secs, Some(2));
    }

    // The fixture is intentionally malformed: a non-numeric `timeout_secs` and a
    // mis-indented line. The test only checks that loading fails.
    #[test]
    #[allow(clippy::expect_used)]
    fn test_load_config_file_strict_invalid_yaml() {
        let temp_dir = tempdir().expect("Failed to create temp dir");
        let config_file = temp_dir.path().join("invalid_config.yaml");

        let invalid_config = r#"
 endpoint: "http://localhost:9999"
 timeout_secs: not_a_number
 invalid_yaml_structure:
 missing_indent
 "#;

        fs::write(&config_file, invalid_config).expect("Failed to write config file");

        let result: Result<ConnectionConfig, ConfigError> = load_config_file_strict(&config_file);
        assert!(result.is_err());
    }

    // `timeout_secs` is a non-numeric string, so deserializing into `Option<u64>` must fail.
    #[test]
    #[allow(clippy::expect_used)]
    fn test_load_config_file_strict_wrong_type() {
        let temp_dir = tempdir().expect("Failed to create temp dir");
        let config_file = temp_dir.path().join("wrong_type_config.yaml");

        let wrong_type_config = r#"
 endpoint: "http://localhost:9999"
 timeout_secs: "this_should_be_a_number"
 "#;

        fs::write(&config_file, wrong_type_config).expect("Failed to write config file");

        let result: Result<ConnectionConfig, ConfigError> = load_config_file_strict(&config_file);
        assert!(result.is_err());
    }

    // `validate_config_file` must produce the same result as `load_config_file_strict`.
    // The fixture omits the optional retry fields, which must still load.
    #[test]
    #[allow(clippy::expect_used)]
    fn test_validate_config_file_same_as_strict() {
        let temp_dir = tempdir().expect("Failed to create temp dir");
        let config_file = temp_dir.path().join("test_config.yaml");

        let valid_config = r#"
 endpoint: "http://localhost:9999"
 timeout_secs: 60
 "#;

        fs::write(&config_file, valid_config).expect("Failed to write config file");

        let strict_result: Result<ConnectionConfig, ConfigError> =
            load_config_file_strict(&config_file);
        let validate_result: Result<ConnectionConfig, ConfigError> =
            validate_config_file(&config_file);

        assert!(strict_result.is_ok());
        assert!(validate_result.is_ok());

        let strict_config = strict_result.expect("Expected successful strict config parsing");
        let validate_config = validate_result.expect("Expected successful validate config parsing");

        assert_eq!(strict_config.endpoint, validate_config.endpoint);
        assert_eq!(strict_config.timeout_secs, validate_config.timeout_secs);
    }
}