lmrc_http_common/
config.rs1use serde::{Deserialize, Serialize};
22use std::env;
23use std::net::{SocketAddr, IpAddr};
24use std::str::FromStr;
25use thiserror::Error;
26
27#[derive(Debug, Error)]
29pub enum ConfigError {
30 #[error("Missing required environment variable: {0}")]
32 MissingEnvVar(String),
33
34 #[error("Invalid configuration value for {key}: {message}")]
36 InvalidValue { key: String, message: String },
37
38 #[error("Failed to parse {key}: {source}")]
40 ParseError {
41 key: String,
42 #[source]
43 source: Box<dyn std::error::Error + Send + Sync>,
44 },
45}
46
47pub type ConfigResult<T> = Result<T, ConfigError>;
49
50#[derive(Debug, Clone, Serialize, Deserialize)]
69pub struct ServerConfig {
70 pub host: String,
72 pub port: u16,
74 pub cors_origins: Vec<String>,
76}
77
78impl ServerConfig {
79 pub fn from_env() -> ConfigResult<Self> {
86 Self::from_env_with_prefix("SERVER_")
87 }
88
89 pub fn from_env_with_prefix(prefix: &str) -> ConfigResult<Self> {
95 let host = env::var(format!("{}HOST", prefix))
96 .unwrap_or_else(|_| "0.0.0.0".to_string());
97
98 let port = env::var(format!("{}PORT", prefix))
99 .ok()
100 .and_then(|s| s.parse().ok())
101 .unwrap_or(8080);
102
103 let cors_origins = env::var("CORS_ORIGINS")
104 .unwrap_or_else(|_| "http://localhost:3000".to_string())
105 .split(',')
106 .map(|s| s.trim().to_string())
107 .filter(|s| !s.is_empty())
108 .collect();
109
110 Ok(Self {
111 host,
112 port,
113 cors_origins,
114 })
115 }
116
117 pub fn bind_addr(&self) -> ConfigResult<SocketAddr> {
119 let ip = IpAddr::from_str(&self.host).map_err(|e| ConfigError::ParseError {
120 key: "host".to_string(),
121 source: Box::new(e),
122 })?;
123
124 Ok(SocketAddr::new(ip, self.port))
125 }
126}
127
128#[derive(Debug, Clone, Serialize, Deserialize)]
148pub struct DatabaseConfig {
149 pub url: String,
151 pub max_connections: u32,
153 pub connect_timeout: u64,
155}
156
157impl DatabaseConfig {
158 pub fn from_env(prefix: Option<&str>) -> ConfigResult<Self> {
169 let prefix = prefix.unwrap_or("DATABASE_");
170
171 let url = env::var(format!("{}URL", prefix)).map_err(|_| {
172 ConfigError::MissingEnvVar(format!("{}URL", prefix))
173 })?;
174
175 let max_connections = env::var(format!("{}MAX_CONNECTIONS", prefix))
176 .ok()
177 .and_then(|s| s.parse().ok())
178 .unwrap_or(10);
179
180 let connect_timeout = env::var(format!("{}CONNECT_TIMEOUT", prefix))
181 .ok()
182 .and_then(|s| s.parse().ok())
183 .unwrap_or(30);
184
185 Ok(Self {
186 url,
187 max_connections,
188 connect_timeout,
189 })
190 }
191}
192
193pub trait ConfigLoader: Sized {
195 fn from_env() -> ConfigResult<Self>;
197
198 fn validate(&self) -> ConfigResult<()> {
200 Ok(())
201 }
202}
203
204impl ConfigLoader for ServerConfig {
205 fn from_env() -> ConfigResult<Self> {
206 Self::from_env()
207 }
208
209 fn validate(&self) -> ConfigResult<()> {
210 if self.port == 0 {
211 return Err(ConfigError::InvalidValue {
212 key: "port".to_string(),
213 message: "Port cannot be 0".to_string(),
214 });
215 }
216
217 if self.host.is_empty() {
218 return Err(ConfigError::InvalidValue {
219 key: "host".to_string(),
220 message: "Host cannot be empty".to_string(),
221 });
222 }
223
224 if IpAddr::from_str(&self.host).is_err() {
226 return Err(ConfigError::InvalidValue {
227 key: "host".to_string(),
228 message: format!("Invalid IP address: {}", self.host),
229 });
230 }
231
232 Ok(())
233 }
234}
235
236impl ConfigLoader for DatabaseConfig {
237 fn from_env() -> ConfigResult<Self> {
238 Self::from_env(None)
239 }
240
241 fn validate(&self) -> ConfigResult<()> {
242 if self.url.is_empty() {
243 return Err(ConfigError::InvalidValue {
244 key: "url".to_string(),
245 message: "Database URL cannot be empty".to_string(),
246 });
247 }
248
249 if self.max_connections == 0 {
250 return Err(ConfigError::InvalidValue {
251 key: "max_connections".to_string(),
252 message: "Max connections must be greater than 0".to_string(),
253 });
254 }
255
256 Ok(())
257 }
258}
259
#[cfg(test)]
mod tests {
    use super::*;
    use std::env;
    use std::sync::{Mutex, MutexGuard, PoisonError};

    /// Serializes every test in this module that reads or writes environment variables.
    ///
    /// The default test harness runs tests on parallel threads, but the environment is
    /// process-global. Without this lock, `test_server_config_defaults` and
    /// `test_server_config_custom_values` (and likewise the two `DATABASE_URL` tests)
    /// can interleave and fail intermittently.
    static ENV_LOCK: Mutex<()> = Mutex::new(());

    /// Acquires [`ENV_LOCK`] for as long as the returned guard lives.
    ///
    /// A poisoned lock (a previous env test panicked while holding it) is recovered rather
    /// than propagated, so one failing test does not cascade into the others.
    fn lock_env() -> MutexGuard<'static, ()> {
        ENV_LOCK.lock().unwrap_or_else(PoisonError::into_inner)
    }

    /// Sets each `(key, value)` pair. Callers must hold the guard from [`lock_env`].
    fn set_vars(vars: &[(&str, &str)]) {
        for &(key, value) in vars {
            // SAFETY: the caller holds `ENV_LOCK`, so no other test in this module reads or
            // writes the environment concurrently. NOTE(review): tests in other modules of
            // this crate are not covered by this lock.
            unsafe { env::set_var(key, value) };
        }
    }

    /// Removes each variable in `keys`. Callers must hold the guard from [`lock_env`].
    fn remove_vars(keys: &[&str]) {
        for &key in keys {
            // SAFETY: as in `set_vars`, `ENV_LOCK` serializes environment access among the
            // tests in this module.
            unsafe { env::remove_var(key) };
        }
    }

    const SERVER_VARS: [&str; 3] = ["SERVER_HOST", "SERVER_PORT", "CORS_ORIGINS"];
    const DATABASE_VARS: [&str; 3] = [
        "DATABASE_URL",
        "DATABASE_MAX_CONNECTIONS",
        "DATABASE_CONNECT_TIMEOUT",
    ];

    #[test]
    fn test_server_config_defaults() {
        // Bound to a named variable (not `_`) so the lock is held until the test ends.
        let _env = lock_env();
        remove_vars(&SERVER_VARS);

        let config = ServerConfig::from_env().unwrap();
        assert_eq!(config.host, "0.0.0.0");
        assert_eq!(config.port, 8080);
        assert_eq!(config.cors_origins, vec!["http://localhost:3000"]);
    }

    #[test]
    fn test_server_config_custom_values() {
        let _env = lock_env();
        set_vars(&[
            ("SERVER_HOST", "127.0.0.1"),
            ("SERVER_PORT", "3000"),
            ("CORS_ORIGINS", "http://example.com,http://test.com"),
        ]);

        let result = ServerConfig::from_env();
        // Clean up before asserting so a failed assertion cannot leak values into later tests.
        remove_vars(&SERVER_VARS);

        let config = result.unwrap();
        assert_eq!(config.host, "127.0.0.1");
        assert_eq!(config.port, 3000);
        assert_eq!(
            config.cors_origins,
            vec!["http://example.com", "http://test.com"]
        );
    }

    #[test]
    fn test_server_config_bind_addr() {
        let config = ServerConfig {
            host: "127.0.0.1".to_string(),
            port: 8080,
            cors_origins: vec![],
        };

        let addr = config.bind_addr().unwrap();
        assert_eq!(addr.to_string(), "127.0.0.1:8080");
    }

    #[test]
    fn test_server_config_validation() {
        let mut config = ServerConfig {
            host: "127.0.0.1".to_string(),
            port: 8080,
            cors_origins: vec![],
        };

        assert!(config.validate().is_ok());

        config.port = 0;
        assert!(config.validate().is_err());

        config.port = 8080;
        config.host = "invalid".to_string();
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_database_config_from_env() {
        let _env = lock_env();
        set_vars(&[
            ("DATABASE_URL", "postgres://localhost/test"),
            ("DATABASE_MAX_CONNECTIONS", "20"),
            ("DATABASE_CONNECT_TIMEOUT", "60"),
        ]);

        let result = DatabaseConfig::from_env(None);
        remove_vars(&DATABASE_VARS);

        let config = result.unwrap();
        assert_eq!(config.url, "postgres://localhost/test");
        assert_eq!(config.max_connections, 20);
        assert_eq!(config.connect_timeout, 60);
    }

    #[test]
    fn test_database_config_missing_url() {
        let _env = lock_env();
        remove_vars(&["DATABASE_URL"]);

        let result = DatabaseConfig::from_env(None);
        assert!(matches!(result, Err(ConfigError::MissingEnvVar(_))));
    }

    #[test]
    fn test_database_config_validation() {
        let mut config = DatabaseConfig {
            url: "postgres://localhost/test".to_string(),
            max_connections: 10,
            connect_timeout: 30,
        };

        assert!(config.validate().is_ok());

        config.url = String::new();
        assert!(config.validate().is_err());

        config.url = "postgres://localhost/test".to_string();
        config.max_connections = 0;
        assert!(config.validate().is_err());
    }
}