1use anyhow::Result;
7
8use super::defaults;
9use super::types::{Config, Server};
10
/// Abstraction over a source of environment variables.
///
/// The server-parsing functions in this module are generic over this trait, so
/// they can read from the real process environment or from any other key/value
/// source that implements it.
pub trait EnvProvider {
    /// Returns the value stored under `key`, or `None` when this provider has
    /// no value for it.
    fn get(&self, key: &str) -> Option<String>;
}
15
/// [`EnvProvider`] that reads from the process environment via `std::env::var`.
#[derive(Default)]
pub struct StdEnvProvider;
19
20impl EnvProvider for StdEnvProvider {
21 fn get(&self, key: &str) -> Option<String> {
22 std::env::var(key).ok()
23 }
24}
25
26pub fn parse_server_from_env<E: EnvProvider>(index: usize, env: &E) -> Option<Server> {
35 let host_key = format!("NNTP_SERVER_{}_HOST", index);
37 let host = env.get(&host_key)?;
38
39 let port_key = format!("NNTP_SERVER_{}_PORT", index);
41 let port = env
42 .get(&port_key)
43 .and_then(|p| p.parse::<u16>().ok())
44 .unwrap_or(119); let name_key = format!("NNTP_SERVER_{}_NAME", index);
48 let name = env
49 .get(&name_key)
50 .unwrap_or_else(|| format!("Server {}", index));
51
52 let username_key = format!("NNTP_SERVER_{}_USERNAME", index);
54 let username = env.get(&username_key);
55
56 let password_key = format!("NNTP_SERVER_{}_PASSWORD", index);
57 let password = env.get(&password_key);
58
59 let max_conn_key = format!("NNTP_SERVER_{}_MAX_CONNECTIONS", index);
60 let max_connections = env
61 .get(&max_conn_key)
62 .and_then(|m| m.parse::<usize>().ok())
63 .and_then(crate::types::MaxConnections::new)
64 .unwrap_or_else(defaults::max_connections);
65
66 let use_tls_key = format!("NNTP_SERVER_{}_USE_TLS", index);
68 let use_tls = env
69 .get(&use_tls_key)
70 .and_then(|v| v.parse::<bool>().ok())
71 .unwrap_or(false);
72
73 let tls_verify_key = format!("NNTP_SERVER_{}_TLS_VERIFY_CERT", index);
74 let tls_verify_cert = env
75 .get(&tls_verify_key)
76 .and_then(|v| v.parse::<bool>().ok())
77 .unwrap_or_else(defaults::tls_verify_cert);
78
79 let tls_cert_path_key = format!("NNTP_SERVER_{}_TLS_CERT_PATH", index);
80 let tls_cert_path = env.get(&tls_cert_path_key);
81
82 let keepalive_key = format!("NNTP_SERVER_{}_CONNECTION_KEEPALIVE", index);
84 let connection_keepalive = env
85 .get(&keepalive_key)
86 .and_then(|k| k.parse::<u64>().ok())
87 .map(std::time::Duration::from_secs);
88
89 let health_max_key = format!("NNTP_SERVER_{}_HEALTH_CHECK_MAX_PER_CYCLE", index);
91 let health_check_max_per_cycle = env
92 .get(&health_max_key)
93 .and_then(|h| h.parse::<usize>().ok())
94 .unwrap_or_else(defaults::health_check_max_per_cycle);
95
96 let health_timeout_key = format!("NNTP_SERVER_{}_HEALTH_CHECK_POOL_TIMEOUT", index);
97 let health_check_pool_timeout = env
98 .get(&health_timeout_key)
99 .and_then(|h| h.parse::<u64>().ok())
100 .map(std::time::Duration::from_secs)
101 .unwrap_or_else(defaults::health_check_pool_timeout);
102
103 Some(Server {
104 host: crate::types::HostName::new(host.clone())
105 .unwrap_or_else(|_| panic!("Invalid hostname in {}: '{}'", host_key, host)),
106 port: crate::types::Port::new(port)
107 .unwrap_or_else(|| panic!("Invalid port in {}: {}", port_key, port)),
108 name: crate::types::ServerName::new(name.clone())
109 .unwrap_or_else(|_| panic!("Invalid server name in {}: '{}'", name_key, name)),
110 username,
111 password,
112 max_connections,
113 use_tls,
114 tls_verify_cert,
115 tls_cert_path,
116 connection_keepalive,
117 health_check_max_per_cycle,
118 health_check_pool_timeout,
119 })
120}
121
122fn load_servers_from_env() -> Option<Vec<Server>> {
133 load_servers_from_env_provider(&StdEnvProvider)
134}
135
136pub fn load_servers_from_env_provider<E: EnvProvider>(env: &E) -> Option<Vec<Server>> {
138 let servers: Vec<Server> = (0..)
139 .map(|i| parse_server_from_env(i, env))
140 .take_while(|s| s.is_some())
141 .flatten()
142 .collect();
143
144 if servers.is_empty() {
145 None
146 } else {
147 Some(servers)
148 }
149}
150
/// Reports whether the first server's host variable is readable from the
/// process environment.
///
/// Only `NNTP_SERVER_0_HOST` is checked. The check uses `std::env::var`, so a
/// value that is not valid Unicode counts as "not set".
pub fn has_server_env_vars() -> bool {
    const FIRST_HOST_KEY: &str = "NNTP_SERVER_0_HOST";
    std::env::var(FIRST_HOST_KEY).is_ok()
}
157
158pub fn load_config_from_env() -> Result<Config> {
166 use anyhow::Context;
167
168 let servers = load_servers_from_env()
169 .context("No backend servers configured via environment variables. Set NNTP_SERVER_0_HOST, NNTP_SERVER_0_PORT, etc.")?;
170
171 let config = Config {
172 servers,
173 ..Default::default()
174 };
175
176 config.validate()?;
178
179 Ok(config)
180}
181
182pub fn load_config(config_path: &str) -> Result<Config> {
192 use anyhow::Context;
193
194 let config_content = std::fs::read_to_string(config_path)
195 .with_context(|| format!("Failed to read config file '{}'", config_path))?;
196
197 let mut config: Config = toml::from_str(&config_content)
198 .with_context(|| format!("Failed to parse config file '{}'", config_path))?;
199
200 if let Some(env_servers) = load_servers_from_env() {
202 tracing::info!(
203 "Using {} backend server(s) from environment variables (overriding config file)",
204 env_servers.len()
205 );
206 config.servers = env_servers;
207 }
208
209 config.validate()?;
211
212 Ok(config)
213}
214
/// Identifies where a loaded configuration came from.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ConfigSource {
    /// Read from an existing configuration file.
    File,
    /// Assembled from environment variables.
    Environment,
    /// A default configuration that was newly created.
    DefaultCreated,
}

impl ConfigSource {
    /// Returns a short, human-readable label for this source.
    #[must_use]
    pub fn description(&self) -> &'static str {
        match *self {
            ConfigSource::File => "configuration file",
            ConfigSource::Environment => "environment variables",
            ConfigSource::DefaultCreated => "default configuration (created)",
        }
    }
}
237
238pub fn load_config_with_fallback(config_path: &str) -> Result<(Config, ConfigSource)> {
257 use anyhow::Context;
258
259 if std::path::Path::new(config_path).exists() {
261 match load_config(config_path) {
262 Ok(config) => {
263 tracing::info!("Loaded configuration from file: {}", config_path);
264 return Ok((config, ConfigSource::File));
265 }
266 Err(e) => {
267 tracing::error!(
268 "Failed to load existing config file '{}': {}",
269 config_path,
270 e
271 );
272 tracing::error!("Please check your config file syntax and try again");
273 return Err(e);
274 }
275 }
276 }
277
278 if has_server_env_vars() {
280 match load_config_from_env() {
281 Ok(config) => {
282 tracing::info!(
283 "Using configuration from environment variables (no config file found)"
284 );
285 return Ok((config, ConfigSource::Environment));
286 }
287 Err(e) => {
288 tracing::error!(
289 "Failed to load configuration from environment variables: {}",
290 e
291 );
292 return Err(e);
293 }
294 }
295 }
296
297 tracing::warn!(
299 "Config file '{}' not found and no NNTP_SERVER_* environment variables set",
300 config_path
301 );
302 tracing::warn!("Creating default config file - please edit it to add your backend servers");
303
304 let default_config = create_default_config();
305 let config_toml =
306 toml::to_string_pretty(&default_config).context("Failed to serialize default config")?;
307
308 std::fs::write(config_path, &config_toml)
309 .with_context(|| format!("Failed to write default config to '{}'", config_path))?;
310
311 tracing::info!("Created default config file: {}", config_path);
312 Ok((default_config, ConfigSource::DefaultCreated))
313}
314
315#[must_use]
317pub fn create_default_config() -> Config {
318 Config {
319 servers: vec![Server {
320 host: crate::types::HostName::new("news.example.com".to_string())
321 .expect("Valid hostname"),
322 port: crate::types::Port::new(119).expect("Valid port"),
323 name: crate::types::ServerName::new("Example News Server".to_string())
324 .expect("Valid server name"),
325 username: None,
326 password: None,
327 max_connections: defaults::max_connections(),
328 use_tls: false,
329 tls_verify_cert: defaults::tls_verify_cert(),
330 tls_cert_path: None,
331 connection_keepalive: None,
332 health_check_max_per_cycle: defaults::health_check_max_per_cycle(),
333 health_check_pool_timeout: defaults::health_check_pool_timeout(),
334 }],
335 ..Default::default()
336 }
337}
338
// Unit tests for environment-based server parsing and the config fallback chain.
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// In-memory [`EnvProvider`] so the parsing tests run against a fixed set
    /// of variables.
    struct MockEnv {
        // Variable name -> value.
        vars: HashMap<String, String>,
    }

    impl MockEnv {
        /// Creates a provider with no variables set.
        fn new() -> Self {
            Self {
                vars: HashMap::new(),
            }
        }

        /// Sets one variable; returns `&mut Self` so calls can be chained.
        fn set(&mut self, key: impl Into<String>, value: impl Into<String>) -> &mut Self {
            self.vars.insert(key.into(), value.into());
            self
        }
    }

    impl EnvProvider for MockEnv {
        fn get(&self, key: &str) -> Option<String> {
            self.vars.get(key).cloned()
        }
    }

    // Only the host is set: every optional setting must take its fallback.
    #[test]
    fn test_parse_server_from_env_minimal() {
        let mut env = MockEnv::new();
        env.set("NNTP_SERVER_0_HOST", "news.example.com");

        let server = parse_server_from_env(0, &env);
        assert!(server.is_some());

        let server = server.unwrap();
        assert_eq!(server.host.as_str(), "news.example.com");
        // Fallback port.
        assert_eq!(server.port.get(), 119);
        // Fallback name is derived from the index.
        assert_eq!(server.name.as_str(), "Server 0");
        assert!(server.username.is_none());
        assert!(server.password.is_none());
    }

    // Explicit values for the common settings are all honoured.
    #[test]
    fn test_parse_server_from_env_full() {
        let mut env = MockEnv::new();
        env.set("NNTP_SERVER_0_HOST", "secure.example.com")
            .set("NNTP_SERVER_0_PORT", "563")
            .set("NNTP_SERVER_0_NAME", "Secure News")
            .set("NNTP_SERVER_0_USERNAME", "testuser")
            .set("NNTP_SERVER_0_PASSWORD", "testpass")
            .set("NNTP_SERVER_0_MAX_CONNECTIONS", "20")
            .set("NNTP_SERVER_0_USE_TLS", "true")
            .set("NNTP_SERVER_0_TLS_VERIFY_CERT", "false");

        let server = parse_server_from_env(0, &env).unwrap();
        assert_eq!(server.host.as_str(), "secure.example.com");
        assert_eq!(server.port.get(), 563);
        assert_eq!(server.name.as_str(), "Secure News");
        assert_eq!(server.username, Some("testuser".to_string()));
        assert_eq!(server.password, Some("testpass".to_string()));
        assert_eq!(server.max_connections.get(), 20);
        assert!(server.use_tls);
        assert!(!server.tls_verify_cert);
    }

    // Without a host variable there is no server at that index.
    #[test]
    fn test_parse_server_from_env_no_host() {
        let env = MockEnv::new();
        let server = parse_server_from_env(0, &env);
        assert!(server.is_none());
    }

    // An unparsable port is ignored rather than rejected.
    #[test]
    fn test_parse_server_from_env_invalid_port() {
        let mut env = MockEnv::new();
        env.set("NNTP_SERVER_0_HOST", "news.example.com")
            .set("NNTP_SERVER_0_PORT", "invalid");

        let server = parse_server_from_env(0, &env).unwrap();
        // Falls back to the default port.
        assert_eq!(server.port.get(), 119);
    }

    // An unparsable connection limit is ignored rather than rejected.
    #[test]
    fn test_parse_server_from_env_invalid_max_connections() {
        let mut env = MockEnv::new();
        env.set("NNTP_SERVER_0_HOST", "news.example.com")
            .set("NNTP_SERVER_0_MAX_CONNECTIONS", "not_a_number");

        let server = parse_server_from_env(0, &env).unwrap();
        // Expects the fallback from `defaults::max_connections()` to be 10.
        assert_eq!(server.max_connections.get(), 10);
    }

    // "0" parses as a number, but the test expects it to be replaced by the default.
    #[test]
    fn test_parse_server_from_env_zero_max_connections() {
        let mut env = MockEnv::new();
        env.set("NNTP_SERVER_0_HOST", "news.example.com")
            .set("NNTP_SERVER_0_MAX_CONNECTIONS", "0");

        let server = parse_server_from_env(0, &env).unwrap();
        // Expects the fallback from `defaults::max_connections()` to be 10.
        assert_eq!(server.max_connections.get(), 10);
    }

    // The keepalive value is interpreted as whole seconds.
    #[test]
    fn test_parse_server_from_env_keepalive() {
        let mut env = MockEnv::new();
        env.set("NNTP_SERVER_0_HOST", "news.example.com")
            .set("NNTP_SERVER_0_CONNECTION_KEEPALIVE", "300");

        let server = parse_server_from_env(0, &env).unwrap();
        assert_eq!(
            server.connection_keepalive,
            Some(std::time::Duration::from_secs(300))
        );
    }

    // Health-check settings: a count and a timeout in whole seconds.
    #[test]
    fn test_parse_server_from_env_health_check_config() {
        let mut env = MockEnv::new();
        env.set("NNTP_SERVER_0_HOST", "news.example.com")
            .set("NNTP_SERVER_0_HEALTH_CHECK_MAX_PER_CYCLE", "5")
            .set("NNTP_SERVER_0_HEALTH_CHECK_POOL_TIMEOUT", "15");

        let server = parse_server_from_env(0, &env).unwrap();
        assert_eq!(server.health_check_max_per_cycle, 5);
        assert_eq!(
            server.health_check_pool_timeout,
            std::time::Duration::from_secs(15)
        );
    }

    // The certificate path is passed through as a plain string.
    #[test]
    fn test_parse_server_from_env_tls_cert_path() {
        let mut env = MockEnv::new();
        env.set("NNTP_SERVER_0_HOST", "news.example.com")
            .set("NNTP_SERVER_0_USE_TLS", "true")
            .set("NNTP_SERVER_0_TLS_CERT_PATH", "/path/to/ca.pem");

        let server = parse_server_from_env(0, &env).unwrap();
        assert!(server.use_tls);
        assert_eq!(server.tls_cert_path, Some("/path/to/ca.pem".to_string()));
    }

    // No variables at all: the loader reports `None`, not an empty list.
    #[test]
    fn test_load_servers_from_env_provider_empty() {
        let env = MockEnv::new();
        let servers = load_servers_from_env_provider(&env);
        assert!(servers.is_none());
    }

    #[test]
    fn test_load_servers_from_env_provider_single() {
        let mut env = MockEnv::new();
        env.set("NNTP_SERVER_0_HOST", "news1.example.com");

        let servers = load_servers_from_env_provider(&env);
        assert!(servers.is_some());

        let servers = servers.unwrap();
        assert_eq!(servers.len(), 1);
        assert_eq!(servers[0].host.as_str(), "news1.example.com");
    }

    // Consecutive indices 0, 1, 2 are all collected, in index order.
    #[test]
    fn test_load_servers_from_env_provider_multiple() {
        let mut env = MockEnv::new();
        env.set("NNTP_SERVER_0_HOST", "news1.example.com")
            .set("NNTP_SERVER_0_PORT", "119")
            .set("NNTP_SERVER_1_HOST", "news2.example.com")
            .set("NNTP_SERVER_1_PORT", "563")
            .set("NNTP_SERVER_1_USE_TLS", "true")
            .set("NNTP_SERVER_2_HOST", "news3.example.com");

        let servers = load_servers_from_env_provider(&env);
        assert!(servers.is_some());

        let servers = servers.unwrap();
        assert_eq!(servers.len(), 3);
        assert_eq!(servers[0].host.as_str(), "news1.example.com");
        assert_eq!(servers[1].host.as_str(), "news2.example.com");
        assert_eq!(servers[2].host.as_str(), "news3.example.com");
        assert!(servers[1].use_tls);
        assert!(!servers[0].use_tls);
    }

    // Index 1 is missing, so scanning stops there and server 2 is never read.
    #[test]
    fn test_load_servers_from_env_provider_gaps() {
        let mut env = MockEnv::new();
        env.set("NNTP_SERVER_0_HOST", "news1.example.com")
            .set("NNTP_SERVER_2_HOST", "news3.example.com");

        let servers = load_servers_from_env_provider(&env);
        assert!(servers.is_some());

        let servers = servers.unwrap();
        // Only the server before the gap is returned.
        assert_eq!(servers.len(), 1);
        assert_eq!(servers[0].host.as_str(), "news1.example.com");
    }

    // `str::parse::<bool>` accepts only lowercase "true"/"false", so these
    // differently-cased values fail to parse and the fallbacks apply.
    #[test]
    fn test_parse_server_from_env_bool_variations() {
        let mut env = MockEnv::new();
        env.set("NNTP_SERVER_0_HOST", "news.example.com")
            .set("NNTP_SERVER_0_USE_TLS", "True")
            .set("NNTP_SERVER_0_TLS_VERIFY_CERT", "FALSE");

        let server = parse_server_from_env(0, &env).unwrap();
        // "True" is not parsed -> `use_tls` falls back to false.
        assert!(!server.use_tls);
        // "FALSE" is not parsed -> expects `defaults::tls_verify_cert()` to be true.
        assert!(server.tls_verify_cert);
    }

    // Lowercase booleans are parsed as written.
    #[test]
    fn test_parse_server_from_env_correct_bool() {
        let mut env = MockEnv::new();
        env.set("NNTP_SERVER_0_HOST", "news.example.com")
            .set("NNTP_SERVER_0_USE_TLS", "true")
            .set("NNTP_SERVER_0_TLS_VERIFY_CERT", "false");

        let server = parse_server_from_env(0, &env).unwrap();
        assert!(server.use_tls);
        assert!(!server.tls_verify_cert);
    }

    #[test]
    fn test_config_source_description() {
        assert_eq!(ConfigSource::File.description(), "configuration file");
        assert_eq!(
            ConfigSource::Environment.description(),
            "environment variables"
        );
        assert_eq!(
            ConfigSource::DefaultCreated.description(),
            "default configuration (created)"
        );
    }

    #[test]
    fn test_config_source_equality() {
        assert_eq!(ConfigSource::File, ConfigSource::File);
        assert_ne!(ConfigSource::File, ConfigSource::Environment);
        assert_ne!(ConfigSource::Environment, ConfigSource::DefaultCreated);
    }

    // NOTE(review): this test goes through the real process environment; it
    // expects NNTP_SERVER_0_HOST to be unset, otherwise the environment branch
    // of `load_config_with_fallback` is taken instead.
    #[test]
    fn test_load_config_with_fallback_creates_default() {
        use tempfile::NamedTempFile;

        let temp_file = NamedTempFile::new().unwrap();
        let path = temp_file.path().to_str().unwrap().to_string();

        // Drop the handle so the path no longer exists when the loader checks
        // it. NOTE(review): relies on `NamedTempFile` removing the file on drop.
        drop(temp_file);

        let result = load_config_with_fallback(&path);
        assert!(result.is_ok());

        let (config, source) = result.unwrap();
        assert_eq!(source, ConfigSource::DefaultCreated);
        assert_eq!(config.servers.len(), 1);
        assert_eq!(config.servers[0].host.as_str(), "news.example.com");

        // Best-effort cleanup of the default file the loader just wrote.
        let _ = std::fs::remove_file(&path);
    }

    // NOTE(review): `load_config` lets servers from the real process
    // environment override the file, so this test also expects no
    // NNTP_SERVER_* variables to be set.
    #[test]
    fn test_load_config_with_fallback_reads_existing() {
        use std::io::Write;
        use tempfile::NamedTempFile;

        let mut temp_file = NamedTempFile::new().unwrap();

        let config_content = r#"
[[servers]]
host = "test.example.com"
port = 119
name = "Test Server"
"#;
        temp_file.write_all(config_content.as_bytes()).unwrap();
        temp_file.flush().unwrap();

        // `temp_file` stays alive until the end of the test, so the path exists
        // while the loader runs.
        let path = temp_file.path().to_str().unwrap().to_string();

        let result = load_config_with_fallback(&path);
        assert!(result.is_ok());

        let (config, source) = result.unwrap();
        assert_eq!(source, ConfigSource::File);
        assert_eq!(config.servers.len(), 1);
        assert_eq!(config.servers[0].host.as_str(), "test.example.com");
    }

    #[test]
    fn test_create_default_config() {
        let config = create_default_config();
        assert_eq!(config.servers.len(), 1);
        assert_eq!(config.servers[0].host.as_str(), "news.example.com");
        assert_eq!(config.servers[0].port.get(), 119);
        assert!(!config.servers[0].use_tls);
    }
}