1use anyhow::Result;
7
8use super::defaults;
9use super::types::{Config, Server};
10
/// Source of configuration values, looked up by environment-variable name.
///
/// Abstracting the lookup lets the parsing functions in this module run against the real
/// process environment ([`StdEnvProvider`]) or against an in-memory map in tests.
pub trait EnvProvider {
    /// Returns the value stored under `key`, or `None` when no usable value exists.
    fn get(&self, key: &str) -> Option<String>;
}

/// [`EnvProvider`] backed by the current process environment via [`std::env::var`].
#[derive(Debug, Default, Clone, Copy)]
pub struct StdEnvProvider;

impl EnvProvider for StdEnvProvider {
    /// Returns `None` both when `key` is unset and when its value is not valid Unicode,
    /// because [`std::env::var`] reports either case as an error.
    fn get(&self, key: &str) -> Option<String> {
        std::env::var(key).ok()
    }
}
25
26pub fn parse_server_from_env<E: EnvProvider>(index: usize, env: &E) -> Option<Server> {
35 let host_key = format!("NNTP_SERVER_{}_HOST", index);
37 let host = env.get(&host_key)?;
38
39 let port_key = format!("NNTP_SERVER_{}_PORT", index);
41 let port = env
42 .get(&port_key)
43 .and_then(|p| p.parse::<u16>().ok())
44 .unwrap_or(119); let name_key = format!("NNTP_SERVER_{}_NAME", index);
48 let name = env
49 .get(&name_key)
50 .unwrap_or_else(|| format!("Server {}", index));
51
52 let username_key = format!("NNTP_SERVER_{}_USERNAME", index);
54 let username = env.get(&username_key);
55
56 let password_key = format!("NNTP_SERVER_{}_PASSWORD", index);
57 let password = env.get(&password_key);
58
59 let max_conn_key = format!("NNTP_SERVER_{}_MAX_CONNECTIONS", index);
60 let max_connections = env
61 .get(&max_conn_key)
62 .and_then(|m| m.parse::<usize>().ok())
63 .and_then(|m| crate::types::MaxConnections::try_new(m).ok())
64 .unwrap_or_else(defaults::max_connections);
65
66 let use_tls_key = format!("NNTP_SERVER_{}_USE_TLS", index);
68 let use_tls = env
69 .get(&use_tls_key)
70 .and_then(|v| v.parse::<bool>().ok())
71 .unwrap_or(false);
72
73 let tls_verify_key = format!("NNTP_SERVER_{}_TLS_VERIFY_CERT", index);
74 let tls_verify_cert = env
75 .get(&tls_verify_key)
76 .and_then(|v| v.parse::<bool>().ok())
77 .unwrap_or_else(defaults::tls_verify_cert);
78
79 let tls_cert_path_key = format!("NNTP_SERVER_{}_TLS_CERT_PATH", index);
80 let tls_cert_path = env.get(&tls_cert_path_key);
81
82 let keepalive_key = format!("NNTP_SERVER_{}_CONNECTION_KEEPALIVE", index);
84 let connection_keepalive = env
85 .get(&keepalive_key)
86 .and_then(|k| k.parse::<u64>().ok())
87 .map(std::time::Duration::from_secs);
88
89 let health_max_key = format!("NNTP_SERVER_{}_HEALTH_CHECK_MAX_PER_CYCLE", index);
91 let health_check_max_per_cycle = env
92 .get(&health_max_key)
93 .and_then(|h| h.parse::<usize>().ok())
94 .unwrap_or_else(defaults::health_check_max_per_cycle);
95
96 let health_timeout_key = format!("NNTP_SERVER_{}_HEALTH_CHECK_POOL_TIMEOUT", index);
97 let health_check_pool_timeout = env
98 .get(&health_timeout_key)
99 .and_then(|h| h.parse::<u64>().ok())
100 .map(std::time::Duration::from_secs)
101 .unwrap_or_else(defaults::health_check_pool_timeout);
102
103 let tier_key = format!("NNTP_SERVER_{}_TIER", index);
104 let tier = match env.get(&tier_key) {
105 Some(tier_str) => tier_str.parse::<u8>().unwrap_or_else(|_| {
106 panic!(
107 "Invalid tier in {}: '{}' (must be 0-255)",
108 tier_key, tier_str
109 )
110 }),
111 None => 0,
112 };
113
114 Some(Server {
115 host: crate::types::HostName::try_new(host.clone())
116 .unwrap_or_else(|_| panic!("Invalid hostname in {}: '{}'", host_key, host)),
117 port: crate::types::Port::try_new(port)
118 .unwrap_or_else(|_| panic!("Invalid port in {}: {}", port_key, port)),
119 name: crate::types::ServerName::try_new(name.clone())
120 .unwrap_or_else(|_| panic!("Invalid server name in {}: '{}'", name_key, name)),
121 username,
122 password,
123 max_connections,
124 use_tls,
125 tls_verify_cert,
126 tls_cert_path,
127 connection_keepalive,
128 health_check_max_per_cycle,
129 health_check_pool_timeout,
130 tier,
131 })
132}
133
134fn load_servers_from_env() -> Option<Vec<Server>> {
145 load_servers_from_env_provider(&StdEnvProvider)
146}
147
148pub fn load_servers_from_env_provider<E: EnvProvider>(env: &E) -> Option<Vec<Server>> {
150 let servers: Vec<Server> = (0..)
151 .map(|i| parse_server_from_env(i, env))
152 .take_while(|s| s.is_some())
153 .flatten()
154 .collect();
155
156 if servers.is_empty() {
157 None
158 } else {
159 Some(servers)
160 }
161}
162
/// Reports whether backend servers are configured in the process environment.
///
/// Only `NNTP_SERVER_0_HOST` is inspected, because server loading scans indices upward from 0
/// and loads nothing without it. A value that is not valid Unicode counts as absent.
pub fn has_server_env_vars() -> bool {
    const FIRST_HOST_VAR: &str = "NNTP_SERVER_0_HOST";
    std::env::var(FIRST_HOST_VAR).is_ok()
}
169
170pub fn load_config_from_env() -> Result<Config> {
178 use anyhow::Context;
179
180 let servers = load_servers_from_env()
181 .context("No backend servers configured via environment variables. Set NNTP_SERVER_0_HOST, NNTP_SERVER_0_PORT, etc.")?;
182
183 let config = Config {
184 servers,
185 ..Default::default()
186 };
187
188 config.validate()?;
190
191 Ok(config)
192}
193
194pub fn load_config(config_path: &str) -> Result<Config> {
204 use anyhow::Context;
205
206 let config_content = std::fs::read_to_string(config_path)
207 .with_context(|| format!("Failed to read config file '{}'", config_path))?;
208
209 let mut config: Config = toml::from_str(&config_content)
210 .with_context(|| format!("Failed to parse config file '{}'", config_path))?;
211
212 if let Some(env_servers) = load_servers_from_env() {
214 tracing::info!(
215 "Using {} backend server(s) from environment variables (overriding config file)",
216 env_servers.len()
217 );
218 config.servers = env_servers;
219 }
220
221 config.validate()?;
223
224 Ok(config)
225}
226
/// Where the active configuration was obtained from.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ConfigSource {
    /// Loaded from an existing configuration file.
    File,
    /// Built from `NNTP_SERVER_*` environment variables.
    Environment,
    /// A default configuration was generated and written out.
    DefaultCreated,
}

impl ConfigSource {
    /// Human-readable label for this source, suitable for log messages.
    #[must_use]
    pub fn description(&self) -> &'static str {
        match *self {
            ConfigSource::File => "configuration file",
            ConfigSource::Environment => "environment variables",
            ConfigSource::DefaultCreated => "default configuration (created)",
        }
    }
}
249
250pub fn load_config_with_fallback(config_path: &str) -> Result<(Config, ConfigSource)> {
269 use anyhow::Context;
270
271 if std::path::Path::new(config_path).exists() {
273 match load_config(config_path) {
274 Ok(config) => {
275 tracing::info!("Loaded configuration from file: {}", config_path);
276 return Ok((config, ConfigSource::File));
277 }
278 Err(e) => {
279 tracing::error!(
280 "Failed to load existing config file '{}': {}",
281 config_path,
282 e
283 );
284 tracing::error!("Please check your config file syntax and try again");
285 return Err(e);
286 }
287 }
288 }
289
290 if has_server_env_vars() {
292 match load_config_from_env() {
293 Ok(config) => {
294 tracing::info!(
295 "Using configuration from environment variables (no config file found)"
296 );
297 return Ok((config, ConfigSource::Environment));
298 }
299 Err(e) => {
300 tracing::error!(
301 "Failed to load configuration from environment variables: {}",
302 e
303 );
304 return Err(e);
305 }
306 }
307 }
308
309 tracing::warn!(
311 "Config file '{}' not found and no NNTP_SERVER_* environment variables set",
312 config_path
313 );
314 tracing::warn!("Creating default config file - please edit it to add your backend servers");
315
316 let default_config = create_default_config();
317 let config_toml =
318 toml::to_string_pretty(&default_config).context("Failed to serialize default config")?;
319
320 std::fs::write(config_path, &config_toml)
321 .with_context(|| format!("Failed to write default config to '{}'", config_path))?;
322
323 tracing::info!("Created default config file: {}", config_path);
324 Ok((default_config, ConfigSource::DefaultCreated))
325}
326
327#[must_use]
329pub fn create_default_config() -> Config {
330 Config {
331 servers: vec![Server {
332 host: crate::types::HostName::try_new("news.example.com".to_string())
333 .expect("Valid hostname"),
334 port: crate::types::Port::try_new(119).expect("Valid port"),
335 name: crate::types::ServerName::try_new("Example News Server".to_string())
336 .expect("Valid server name"),
337 username: None,
338 password: None,
339 max_connections: defaults::max_connections(),
340 use_tls: false,
341 tls_verify_cert: defaults::tls_verify_cert(),
342 tls_cert_path: None,
343 connection_keepalive: None,
344 health_check_max_per_cycle: defaults::health_check_max_per_cycle(),
345 health_check_pool_timeout: defaults::health_check_pool_timeout(),
346 tier: 0,
347 }],
348 ..Default::default()
349 }
350}
351
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// In-memory [`EnvProvider`] so tests never read or mutate the real process environment.
    struct MockEnv {
        vars: HashMap<String, String>,
    }

    impl MockEnv {
        fn new() -> Self {
            Self {
                vars: HashMap::new(),
            }
        }

        /// Sets `key` to `value`, returning `self` so calls can be chained.
        fn set(&mut self, key: impl Into<String>, value: impl Into<String>) -> &mut Self {
            self.vars.insert(key.into(), value.into());
            self
        }
    }

    impl EnvProvider for MockEnv {
        fn get(&self, key: &str) -> Option<String> {
            self.vars.get(key).cloned()
        }
    }

    #[test]
    fn test_parse_server_from_env_minimal() {
        let mut env = MockEnv::new();
        env.set("NNTP_SERVER_0_HOST", "news.example.com");

        let server = parse_server_from_env(0, &env).expect("HOST alone defines a server");
        assert_eq!(server.host.as_str(), "news.example.com");
        assert_eq!(server.port.get(), 119); // standard NNTP port
        assert_eq!(server.name.as_str(), "Server 0"); // generated from the index
        assert!(server.username.is_none());
        assert!(server.password.is_none());
    }

    #[test]
    fn test_parse_server_from_env_full() {
        let mut env = MockEnv::new();
        env.set("NNTP_SERVER_0_HOST", "secure.example.com")
            .set("NNTP_SERVER_0_PORT", "563")
            .set("NNTP_SERVER_0_NAME", "Secure News")
            .set("NNTP_SERVER_0_USERNAME", "testuser")
            .set("NNTP_SERVER_0_PASSWORD", "testpass")
            .set("NNTP_SERVER_0_MAX_CONNECTIONS", "20")
            .set("NNTP_SERVER_0_USE_TLS", "true")
            .set("NNTP_SERVER_0_TLS_VERIFY_CERT", "false");

        let server = parse_server_from_env(0, &env).expect("server 0 is defined");
        assert_eq!(server.host.as_str(), "secure.example.com");
        assert_eq!(server.port.get(), 563);
        assert_eq!(server.name.as_str(), "Secure News");
        assert_eq!(server.username, Some("testuser".to_string()));
        assert_eq!(server.password, Some("testpass".to_string()));
        assert_eq!(server.max_connections.get(), 20);
        assert!(server.use_tls);
        assert!(!server.tls_verify_cert);
    }

    #[test]
    fn test_parse_server_from_env_no_host() {
        let env = MockEnv::new();
        assert!(parse_server_from_env(0, &env).is_none());
    }

    #[test]
    fn test_parse_server_from_env_invalid_port() {
        let mut env = MockEnv::new();
        env.set("NNTP_SERVER_0_HOST", "news.example.com")
            .set("NNTP_SERVER_0_PORT", "invalid");

        let server = parse_server_from_env(0, &env).expect("server 0 is defined");
        // An unparseable port falls back to the standard NNTP port.
        assert_eq!(server.port.get(), 119);
    }

    #[test]
    fn test_parse_server_from_env_invalid_max_connections() {
        let mut env = MockEnv::new();
        env.set("NNTP_SERVER_0_HOST", "news.example.com")
            .set("NNTP_SERVER_0_MAX_CONNECTIONS", "not_a_number");

        let server = parse_server_from_env(0, &env).expect("server 0 is defined");
        assert_eq!(
            server.max_connections.get(),
            defaults::max_connections().get()
        );
    }

    #[test]
    fn test_parse_server_from_env_zero_max_connections() {
        let mut env = MockEnv::new();
        env.set("NNTP_SERVER_0_HOST", "news.example.com")
            .set("NNTP_SERVER_0_MAX_CONNECTIONS", "0");

        let server = parse_server_from_env(0, &env).expect("server 0 is defined");
        // Zero is rejected by `MaxConnections::try_new`, so the default applies.
        assert_eq!(
            server.max_connections.get(),
            defaults::max_connections().get()
        );
    }

    #[test]
    fn test_parse_server_from_env_keepalive() {
        let mut env = MockEnv::new();
        env.set("NNTP_SERVER_0_HOST", "news.example.com")
            .set("NNTP_SERVER_0_CONNECTION_KEEPALIVE", "300");

        let server = parse_server_from_env(0, &env).expect("server 0 is defined");
        assert_eq!(
            server.connection_keepalive,
            Some(std::time::Duration::from_secs(300))
        );
    }

    #[test]
    fn test_parse_server_from_env_health_check_config() {
        let mut env = MockEnv::new();
        env.set("NNTP_SERVER_0_HOST", "news.example.com")
            .set("NNTP_SERVER_0_HEALTH_CHECK_MAX_PER_CYCLE", "5")
            .set("NNTP_SERVER_0_HEALTH_CHECK_POOL_TIMEOUT", "15");

        let server = parse_server_from_env(0, &env).expect("server 0 is defined");
        assert_eq!(server.health_check_max_per_cycle, 5);
        assert_eq!(
            server.health_check_pool_timeout,
            std::time::Duration::from_secs(15)
        );
    }

    #[test]
    fn test_parse_server_from_env_tls_cert_path() {
        let mut env = MockEnv::new();
        env.set("NNTP_SERVER_0_HOST", "news.example.com")
            .set("NNTP_SERVER_0_USE_TLS", "true")
            .set("NNTP_SERVER_0_TLS_CERT_PATH", "/path/to/ca.pem");

        let server = parse_server_from_env(0, &env).expect("server 0 is defined");
        assert!(server.use_tls);
        assert_eq!(server.tls_cert_path, Some("/path/to/ca.pem".to_string()));
    }

    #[test]
    fn test_load_servers_from_env_provider_empty() {
        let env = MockEnv::new();
        assert!(load_servers_from_env_provider(&env).is_none());
    }

    #[test]
    fn test_load_servers_from_env_provider_single() {
        let mut env = MockEnv::new();
        env.set("NNTP_SERVER_0_HOST", "news1.example.com");

        let servers = load_servers_from_env_provider(&env).expect("server 0 is defined");
        assert_eq!(servers.len(), 1);
        assert_eq!(servers[0].host.as_str(), "news1.example.com");
    }

    #[test]
    fn test_load_servers_from_env_provider_multiple() {
        let mut env = MockEnv::new();
        env.set("NNTP_SERVER_0_HOST", "news1.example.com")
            .set("NNTP_SERVER_0_PORT", "119")
            .set("NNTP_SERVER_1_HOST", "news2.example.com")
            .set("NNTP_SERVER_1_PORT", "563")
            .set("NNTP_SERVER_1_USE_TLS", "true")
            .set("NNTP_SERVER_2_HOST", "news3.example.com");

        let servers = load_servers_from_env_provider(&env).expect("servers 0-2 are defined");
        assert_eq!(servers.len(), 3);
        assert_eq!(servers[0].host.as_str(), "news1.example.com");
        assert_eq!(servers[1].host.as_str(), "news2.example.com");
        assert_eq!(servers[2].host.as_str(), "news3.example.com");
        assert!(servers[1].use_tls);
        assert!(!servers[0].use_tls);
    }

    #[test]
    fn test_load_servers_from_env_provider_gaps() {
        let mut env = MockEnv::new();
        env.set("NNTP_SERVER_0_HOST", "news1.example.com")
            .set("NNTP_SERVER_2_HOST", "news3.example.com");

        let servers = load_servers_from_env_provider(&env).expect("server 0 is defined");
        // Index 1 is missing, so scanning stops there and server 2 is never reached.
        assert_eq!(servers.len(), 1);
        assert_eq!(servers[0].host.as_str(), "news1.example.com");
    }

    #[test]
    fn test_parse_server_from_env_bool_variations() {
        let mut env = MockEnv::new();
        env.set("NNTP_SERVER_0_HOST", "news.example.com")
            .set("NNTP_SERVER_0_USE_TLS", "True")
            .set("NNTP_SERVER_0_TLS_VERIFY_CERT", "FALSE");

        let server = parse_server_from_env(0, &env).expect("server 0 is defined");
        // `bool::from_str` is case-sensitive: both values are rejected and defaults apply.
        assert!(!server.use_tls);
        assert_eq!(server.tls_verify_cert, defaults::tls_verify_cert());
    }

    #[test]
    fn test_parse_server_from_env_correct_bool() {
        let mut env = MockEnv::new();
        env.set("NNTP_SERVER_0_HOST", "news.example.com")
            .set("NNTP_SERVER_0_USE_TLS", "true")
            .set("NNTP_SERVER_0_TLS_VERIFY_CERT", "false");

        let server = parse_server_from_env(0, &env).expect("server 0 is defined");
        assert!(server.use_tls);
        assert!(!server.tls_verify_cert);
    }

    #[test]
    fn test_config_source_description() {
        assert_eq!(ConfigSource::File.description(), "configuration file");
        assert_eq!(
            ConfigSource::Environment.description(),
            "environment variables"
        );
        assert_eq!(
            ConfigSource::DefaultCreated.description(),
            "default configuration (created)"
        );
    }

    #[test]
    fn test_config_source_equality() {
        assert_eq!(ConfigSource::File, ConfigSource::File);
        assert_ne!(ConfigSource::File, ConfigSource::Environment);
        assert_ne!(ConfigSource::Environment, ConfigSource::DefaultCreated);
    }

    #[test]
    fn test_load_config_with_fallback_creates_default() {
        // A fresh temporary directory guarantees the config path does not exist yet. It is
        // removed on drop, even if an assertion below fails.
        let dir = tempfile::tempdir().unwrap();
        let path_buf = dir.path().join("config.toml");
        let path = path_buf.to_str().unwrap();

        let (config, source) =
            load_config_with_fallback(path).expect("default config should be created");
        assert_eq!(source, ConfigSource::DefaultCreated);
        assert_eq!(config.servers.len(), 1);
        assert_eq!(config.servers[0].host.as_str(), "news.example.com");
        assert!(path_buf.exists(), "default config file should be written");
    }

    #[test]
    fn test_load_config_with_fallback_reads_existing() {
        use std::io::Write;
        use tempfile::NamedTempFile;

        let mut temp_file = NamedTempFile::new().unwrap();

        let config_content = r#"
[[servers]]
host = "test.example.com"
port = 119
name = "Test Server"
"#;
        temp_file.write_all(config_content.as_bytes()).unwrap();
        temp_file.flush().unwrap();

        let path = temp_file.path().to_str().unwrap().to_string();

        let (config, source) =
            load_config_with_fallback(&path).expect("existing config should load");
        assert_eq!(source, ConfigSource::File);
        assert_eq!(config.servers.len(), 1);
        assert_eq!(config.servers[0].host.as_str(), "test.example.com");
    }

    #[test]
    fn test_create_default_config() {
        let config = create_default_config();
        assert_eq!(config.servers.len(), 1);
        assert_eq!(config.servers[0].host.as_str(), "news.example.com");
        assert_eq!(config.servers[0].port.get(), 119);
        assert!(!config.servers[0].use_tls);
    }
}