Skip to main content

nntp_proxy/config/
loading.rs

//! Configuration loading from files and environment variables
//!
//! This module handles loading configuration from TOML files and environment variables.
//! Backend-server variables (`NNTP_SERVER_*`) take precedence over the file's server list,
//! which lets Docker/container deployments override servers without editing the file.
5
6use anyhow::Result;
7
8use super::defaults;
9use super::types::{Config, Server};
10
/// Source of environment-variable values.
///
/// Abstracted behind a trait (dependency injection) so configuration parsing can be
/// exercised against an in-memory map in tests instead of the real process environment.
pub trait EnvProvider {
    /// Return the value of `key`, or `None` if it is not set.
    fn get(&self, key: &str) -> Option<String>;
}

/// [`EnvProvider`] backed by the real process environment via [`std::env::var`].
///
/// A variable whose value is not valid Unicode is reported as unset, because
/// `std::env::var` returns an error for it.
#[derive(Debug, Default, Clone, Copy)]
pub struct StdEnvProvider;

impl EnvProvider for StdEnvProvider {
    fn get(&self, key: &str) -> Option<String> {
        std::env::var(key).ok()
    }
}
25
26/// Parse server configuration from environment variables (pure function, easily testable)
27///
28/// # Arguments
29/// * `index` - Server index (0, 1, 2, ...)
30/// * `env` - Environment variable provider
31///
32/// # Returns
33/// Some(Server) if HOST variable exists, None otherwise
34pub fn parse_server_from_env<E: EnvProvider>(index: usize, env: &E) -> Option<Server> {
35    // Check if this server index exists by looking for HOST
36    let host_key = format!("NNTP_SERVER_{}_HOST", index);
37    let host = env.get(&host_key)?;
38
39    // Parse port (required)
40    let port_key = format!("NNTP_SERVER_{}_PORT", index);
41    let port = env
42        .get(&port_key)
43        .and_then(|p| p.parse::<u16>().ok())
44        .unwrap_or(119); // Default NNTP port
45
46    // Get name (required, use host as fallback)
47    let name_key = format!("NNTP_SERVER_{}_NAME", index);
48    let name = env
49        .get(&name_key)
50        .unwrap_or_else(|| format!("Server {}", index));
51
52    // Optional fields
53    let username_key = format!("NNTP_SERVER_{}_USERNAME", index);
54    let username = env.get(&username_key);
55
56    let password_key = format!("NNTP_SERVER_{}_PASSWORD", index);
57    let password = env.get(&password_key);
58
59    let max_conn_key = format!("NNTP_SERVER_{}_MAX_CONNECTIONS", index);
60    let max_connections = env
61        .get(&max_conn_key)
62        .and_then(|m| m.parse::<usize>().ok())
63        .and_then(|m| crate::types::MaxConnections::try_new(m).ok())
64        .unwrap_or_else(defaults::max_connections);
65
66    // TLS configuration
67    let use_tls_key = format!("NNTP_SERVER_{}_USE_TLS", index);
68    let use_tls = env
69        .get(&use_tls_key)
70        .and_then(|v| v.parse::<bool>().ok())
71        .unwrap_or(false);
72
73    let tls_verify_key = format!("NNTP_SERVER_{}_TLS_VERIFY_CERT", index);
74    let tls_verify_cert = env
75        .get(&tls_verify_key)
76        .and_then(|v| v.parse::<bool>().ok())
77        .unwrap_or_else(defaults::tls_verify_cert);
78
79    let tls_cert_path_key = format!("NNTP_SERVER_{}_TLS_CERT_PATH", index);
80    let tls_cert_path = env.get(&tls_cert_path_key);
81
82    // Connection keepalive (in seconds)
83    let keepalive_key = format!("NNTP_SERVER_{}_CONNECTION_KEEPALIVE", index);
84    let connection_keepalive = env
85        .get(&keepalive_key)
86        .and_then(|k| k.parse::<u64>().ok())
87        .map(std::time::Duration::from_secs);
88
89    // Health check configuration
90    let health_max_key = format!("NNTP_SERVER_{}_HEALTH_CHECK_MAX_PER_CYCLE", index);
91    let health_check_max_per_cycle = env
92        .get(&health_max_key)
93        .and_then(|h| h.parse::<usize>().ok())
94        .unwrap_or_else(defaults::health_check_max_per_cycle);
95
96    let health_timeout_key = format!("NNTP_SERVER_{}_HEALTH_CHECK_POOL_TIMEOUT", index);
97    let health_check_pool_timeout = env
98        .get(&health_timeout_key)
99        .and_then(|h| h.parse::<u64>().ok())
100        .map(std::time::Duration::from_secs)
101        .unwrap_or_else(defaults::health_check_pool_timeout);
102
103    let tier_key = format!("NNTP_SERVER_{}_TIER", index);
104    let tier = match env.get(&tier_key) {
105        Some(tier_str) => tier_str.parse::<u8>().unwrap_or_else(|_| {
106            panic!(
107                "Invalid tier in {}: '{}' (must be 0-255)",
108                tier_key, tier_str
109            )
110        }),
111        None => 0,
112    };
113
114    Some(Server {
115        host: crate::types::HostName::try_new(host.clone())
116            .unwrap_or_else(|_| panic!("Invalid hostname in {}: '{}'", host_key, host)),
117        port: crate::types::Port::try_new(port)
118            .unwrap_or_else(|_| panic!("Invalid port in {}: {}", port_key, port)),
119        name: crate::types::ServerName::try_new(name.clone())
120            .unwrap_or_else(|_| panic!("Invalid server name in {}: '{}'", name_key, name)),
121        username,
122        password,
123        max_connections,
124        use_tls,
125        tls_verify_cert,
126        tls_cert_path,
127        connection_keepalive,
128        health_check_max_per_cycle,
129        health_check_pool_timeout,
130        tier,
131    })
132}
133
134/// Load backend server configuration from environment variables
135///
136/// Supports indexed environment variables for Docker/container deployments:
137/// - `NNTP_SERVER_0_HOST`, `NNTP_SERVER_0_PORT`, `NNTP_SERVER_0_NAME`, etc.
138/// - `NNTP_SERVER_1_HOST`, `NNTP_SERVER_1_PORT`, `NNTP_SERVER_1_NAME`, etc.
139///
140/// Optional per-server variables:
141/// - `NNTP_SERVER_N_USERNAME` - Backend authentication username
142/// - `NNTP_SERVER_N_PASSWORD` - Backend authentication password
143/// - `NNTP_SERVER_N_MAX_CONNECTIONS` - Max connections (default: 10)
144fn load_servers_from_env() -> Option<Vec<Server>> {
145    load_servers_from_env_provider(&StdEnvProvider)
146}
147
148/// Load servers using a custom environment provider (testable version)
149pub fn load_servers_from_env_provider<E: EnvProvider>(env: &E) -> Option<Vec<Server>> {
150    let servers: Vec<Server> = (0..)
151        .map(|i| parse_server_from_env(i, env))
152        .take_while(|s| s.is_some())
153        .flatten()
154        .collect();
155
156    if servers.is_empty() {
157        None
158    } else {
159        Some(servers)
160    }
161}
162
/// Report whether the process environment defines at least one backend server.
///
/// Only `NNTP_SERVER_0_HOST` is inspected. Servers are scanned contiguously from
/// index 0, so no server can be loaded from the environment without it. A value
/// that is not valid Unicode counts as unset.
pub fn has_server_env_vars() -> bool {
    const FIRST_SERVER_HOST: &str = "NNTP_SERVER_0_HOST";
    let lookup = std::env::var(FIRST_SERVER_HOST);
    lookup.is_ok()
}
169
170/// Load configuration from environment variables only
171///
172/// Used when no config file is present. Requires at least NNTP_SERVER_0_HOST to be set.
173///
174/// # Errors
175///
176/// Returns an error if no backend servers are configured via environment variables.
177pub fn load_config_from_env() -> Result<Config> {
178    use anyhow::Context;
179
180    let servers = load_servers_from_env()
181        .context("No backend servers configured via environment variables. Set NNTP_SERVER_0_HOST, NNTP_SERVER_0_PORT, etc.")?;
182
183    let config = Config {
184        servers,
185        ..Default::default()
186    };
187
188    // Validate the loaded configuration
189    config.validate()?;
190
191    Ok(config)
192}
193
194/// Load configuration from a TOML file, with environment variable overrides
195///
196/// Environment variables for backend servers take precedence over config file:
197/// - `NNTP_SERVER_0_HOST`, `NNTP_SERVER_0_PORT`, `NNTP_SERVER_0_NAME`
198/// - `NNTP_SERVER_1_HOST`, `NNTP_SERVER_1_PORT`, `NNTP_SERVER_1_NAME`
199/// - etc.
200///
201/// This allows Docker/container deployments to override servers without
202/// modifying the config file.
203pub fn load_config(config_path: &str) -> Result<Config> {
204    use anyhow::Context;
205
206    let config_content = std::fs::read_to_string(config_path)
207        .with_context(|| format!("Failed to read config file '{}'", config_path))?;
208
209    let mut config: Config = toml::from_str(&config_content)
210        .with_context(|| format!("Failed to parse config file '{}'", config_path))?;
211
212    // Check for environment variable server overrides
213    if let Some(env_servers) = load_servers_from_env() {
214        tracing::info!(
215            "Using {} backend server(s) from environment variables (overriding config file)",
216            env_servers.len()
217        );
218        config.servers = env_servers;
219    }
220
221    // Validate the loaded configuration
222    config.validate()?;
223
224    Ok(config)
225}
226
/// Where the active configuration came from.
///
/// Returned alongside the [`Config`] by [`load_config_with_fallback`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ConfigSource {
    /// Read from an existing TOML file (servers may be overridden by env vars).
    File,
    /// Built from `NNTP_SERVER_*` environment variables because no file existed.
    Environment,
    /// Neither a file nor env vars were found, so a default config file was written and used.
    DefaultCreated,
}

impl ConfigSource {
    /// Short human-readable label, suitable for log messages.
    #[must_use]
    pub fn description(&self) -> &'static str {
        match *self {
            ConfigSource::File => "configuration file",
            ConfigSource::Environment => "environment variables",
            ConfigSource::DefaultCreated => "default configuration (created)",
        }
    }
}
249
250/// Load configuration with automatic fallback logic
251///
252/// Attempts to load configuration in this order:
253/// 1. If config file exists, load from file (with env var overrides)
254/// 2. Else if environment variables exist (`NNTP_SERVER_*`), load from env
255/// 3. Else create default config file and return default config
256///
257/// # Arguments
258/// * `config_path` - Path to configuration file
259///
260/// # Returns
261/// Tuple of (Config, ConfigSource) indicating where config came from
262///
263/// # Errors
264/// Returns error if:
265/// - Config file exists but can't be read or parsed
266/// - Environment variables exist but are invalid
267/// - Default config can't be created
268pub fn load_config_with_fallback(config_path: &str) -> Result<(Config, ConfigSource)> {
269    use anyhow::Context;
270
271    // Check if config file exists
272    if std::path::Path::new(config_path).exists() {
273        match load_config(config_path) {
274            Ok(config) => {
275                tracing::info!("Loaded configuration from file: {}", config_path);
276                return Ok((config, ConfigSource::File));
277            }
278            Err(e) => {
279                tracing::error!(
280                    "Failed to load existing config file '{}': {}",
281                    config_path,
282                    e
283                );
284                tracing::error!("Please check your config file syntax and try again");
285                return Err(e);
286            }
287        }
288    }
289
290    // Config file doesn't exist - check for environment variables
291    if has_server_env_vars() {
292        match load_config_from_env() {
293            Ok(config) => {
294                tracing::info!(
295                    "Using configuration from environment variables (no config file found)"
296                );
297                return Ok((config, ConfigSource::Environment));
298            }
299            Err(e) => {
300                tracing::error!(
301                    "Failed to load configuration from environment variables: {}",
302                    e
303                );
304                return Err(e);
305            }
306        }
307    }
308
309    // No config file and no env vars - create default
310    tracing::warn!(
311        "Config file '{}' not found and no NNTP_SERVER_* environment variables set",
312        config_path
313    );
314    tracing::warn!("Creating default config file - please edit it to add your backend servers");
315
316    let default_config = create_default_config();
317    let config_toml =
318        toml::to_string_pretty(&default_config).context("Failed to serialize default config")?;
319
320    std::fs::write(config_path, &config_toml)
321        .with_context(|| format!("Failed to write default config to '{}'", config_path))?;
322
323    tracing::info!("Created default config file: {}", config_path);
324    Ok((default_config, ConfigSource::DefaultCreated))
325}
326
327/// Create a default configuration for examples/testing
328#[must_use]
329pub fn create_default_config() -> Config {
330    Config {
331        servers: vec![Server {
332            host: crate::types::HostName::try_new("news.example.com".to_string())
333                .expect("Valid hostname"),
334            port: crate::types::Port::try_new(119).expect("Valid port"),
335            name: crate::types::ServerName::try_new("Example News Server".to_string())
336                .expect("Valid server name"),
337            username: None,
338            password: None,
339            max_connections: defaults::max_connections(),
340            use_tls: false,
341            tls_verify_cert: defaults::tls_verify_cert(),
342            tls_cert_path: None,
343            connection_keepalive: None,
344            health_check_max_per_cycle: defaults::health_check_max_per_cycle(),
345            health_check_pool_timeout: defaults::health_check_pool_timeout(),
346            tier: 0,
347        }],
348        ..Default::default()
349    }
350}
351
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// In-memory [`EnvProvider`] so parsing can be tested without touching the real
    /// process environment.
    struct MockEnv {
        vars: HashMap<String, String>,
    }

    impl MockEnv {
        fn new() -> Self {
            Self {
                vars: HashMap::new(),
            }
        }

        /// Insert a variable; returns `&mut Self` so calls can be chained.
        fn set(&mut self, key: impl Into<String>, value: impl Into<String>) -> &mut Self {
            self.vars.insert(key.into(), value.into());
            self
        }
    }

    impl EnvProvider for MockEnv {
        fn get(&self, key: &str) -> Option<String> {
            self.vars.get(key).cloned()
        }
    }

    /// `load_config_with_fallback` consults the real process environment. If the machine
    /// running the tests exports `NNTP_SERVER_0_HOST`, the expected source and servers
    /// legitimately change, so the affected tests skip instead of failing spuriously.
    fn ambient_server_env_present() -> bool {
        let present = has_server_env_vars();
        if present {
            eprintln!("skipping: NNTP_SERVER_0_HOST is set in the test environment");
        }
        present
    }

    #[test]
    fn test_parse_server_from_env_minimal() {
        let mut env = MockEnv::new();
        env.set("NNTP_SERVER_0_HOST", "news.example.com");

        let server = parse_server_from_env(0, &env);
        assert!(server.is_some());

        let server = server.unwrap();
        assert_eq!(server.host.as_str(), "news.example.com");
        assert_eq!(server.port.get(), 119); // Default port
        assert_eq!(server.name.as_str(), "Server 0"); // Default name
        assert!(server.username.is_none());
        assert!(server.password.is_none());
        assert!(!server.use_tls);
        assert!(server.connection_keepalive.is_none());
    }

    #[test]
    fn test_parse_server_from_env_full() {
        let mut env = MockEnv::new();
        env.set("NNTP_SERVER_0_HOST", "secure.example.com")
            .set("NNTP_SERVER_0_PORT", "563")
            .set("NNTP_SERVER_0_NAME", "Secure News")
            .set("NNTP_SERVER_0_USERNAME", "testuser")
            .set("NNTP_SERVER_0_PASSWORD", "testpass")
            .set("NNTP_SERVER_0_MAX_CONNECTIONS", "20")
            .set("NNTP_SERVER_0_USE_TLS", "true")
            .set("NNTP_SERVER_0_TLS_VERIFY_CERT", "false");

        let server = parse_server_from_env(0, &env).unwrap();
        assert_eq!(server.host.as_str(), "secure.example.com");
        assert_eq!(server.port.get(), 563);
        assert_eq!(server.name.as_str(), "Secure News");
        assert_eq!(server.username, Some("testuser".to_string()));
        assert_eq!(server.password, Some("testpass".to_string()));
        assert_eq!(server.max_connections.get(), 20);
        assert!(server.use_tls);
        assert!(!server.tls_verify_cert);
    }

    #[test]
    fn test_parse_server_from_env_no_host() {
        let env = MockEnv::new();
        let server = parse_server_from_env(0, &env);
        assert!(server.is_none());
    }

    #[test]
    fn test_parse_server_from_env_invalid_port() {
        let mut env = MockEnv::new();
        env.set("NNTP_SERVER_0_HOST", "news.example.com")
            .set("NNTP_SERVER_0_PORT", "invalid");

        let server = parse_server_from_env(0, &env).unwrap();
        assert_eq!(server.port.get(), 119); // Falls back to default
    }

    #[test]
    fn test_parse_server_from_env_invalid_max_connections() {
        let mut env = MockEnv::new();
        env.set("NNTP_SERVER_0_HOST", "news.example.com")
            .set("NNTP_SERVER_0_MAX_CONNECTIONS", "not_a_number");

        let server = parse_server_from_env(0, &env).unwrap();
        assert_eq!(server.max_connections.get(), 10); // Default
    }

    #[test]
    fn test_parse_server_from_env_zero_max_connections() {
        let mut env = MockEnv::new();
        env.set("NNTP_SERVER_0_HOST", "news.example.com")
            .set("NNTP_SERVER_0_MAX_CONNECTIONS", "0");

        let server = parse_server_from_env(0, &env).unwrap();
        assert_eq!(server.max_connections.get(), 10); // Falls back to default (NonZero rejects 0)
    }

    #[test]
    fn test_parse_server_from_env_keepalive() {
        let mut env = MockEnv::new();
        env.set("NNTP_SERVER_0_HOST", "news.example.com")
            .set("NNTP_SERVER_0_CONNECTION_KEEPALIVE", "300");

        let server = parse_server_from_env(0, &env).unwrap();
        assert_eq!(
            server.connection_keepalive,
            Some(std::time::Duration::from_secs(300))
        );
    }

    #[test]
    fn test_parse_server_from_env_health_check_config() {
        let mut env = MockEnv::new();
        env.set("NNTP_SERVER_0_HOST", "news.example.com")
            .set("NNTP_SERVER_0_HEALTH_CHECK_MAX_PER_CYCLE", "5")
            .set("NNTP_SERVER_0_HEALTH_CHECK_POOL_TIMEOUT", "15");

        let server = parse_server_from_env(0, &env).unwrap();
        assert_eq!(server.health_check_max_per_cycle, 5);
        assert_eq!(
            server.health_check_pool_timeout,
            std::time::Duration::from_secs(15)
        );
    }

    #[test]
    fn test_parse_server_from_env_tls_cert_path() {
        let mut env = MockEnv::new();
        env.set("NNTP_SERVER_0_HOST", "news.example.com")
            .set("NNTP_SERVER_0_USE_TLS", "true")
            .set("NNTP_SERVER_0_TLS_CERT_PATH", "/path/to/ca.pem");

        let server = parse_server_from_env(0, &env).unwrap();
        assert!(server.use_tls);
        assert_eq!(server.tls_cert_path, Some("/path/to/ca.pem".to_string()));
    }

    #[test]
    fn test_parse_server_from_env_tier() {
        let mut env = MockEnv::new();
        env.set("NNTP_SERVER_0_HOST", "news.example.com")
            .set("NNTP_SERVER_0_TIER", "3");

        let server = parse_server_from_env(0, &env).unwrap();
        assert_eq!(server.tier, 3);
    }

    #[test]
    fn test_parse_server_from_env_default_tier() {
        let mut env = MockEnv::new();
        env.set("NNTP_SERVER_0_HOST", "news.example.com");

        let server = parse_server_from_env(0, &env).unwrap();
        assert_eq!(server.tier, 0);
    }

    #[test]
    #[should_panic(expected = "Invalid tier in NNTP_SERVER_0_TIER")]
    fn test_parse_server_from_env_invalid_tier_panics() {
        let mut env = MockEnv::new();
        env.set("NNTP_SERVER_0_HOST", "news.example.com")
            .set("NNTP_SERVER_0_TIER", "256");

        let _ = parse_server_from_env(0, &env);
    }

    #[test]
    fn test_load_servers_from_env_provider_empty() {
        let env = MockEnv::new();
        let servers = load_servers_from_env_provider(&env);
        assert!(servers.is_none());
    }

    #[test]
    fn test_load_servers_from_env_provider_single() {
        let mut env = MockEnv::new();
        env.set("NNTP_SERVER_0_HOST", "news1.example.com");

        let servers = load_servers_from_env_provider(&env);
        assert!(servers.is_some());

        let servers = servers.unwrap();
        assert_eq!(servers.len(), 1);
        assert_eq!(servers[0].host.as_str(), "news1.example.com");
    }

    #[test]
    fn test_load_servers_from_env_provider_multiple() {
        let mut env = MockEnv::new();
        env.set("NNTP_SERVER_0_HOST", "news1.example.com")
            .set("NNTP_SERVER_0_PORT", "119")
            .set("NNTP_SERVER_1_HOST", "news2.example.com")
            .set("NNTP_SERVER_1_PORT", "563")
            .set("NNTP_SERVER_1_USE_TLS", "true")
            .set("NNTP_SERVER_2_HOST", "news3.example.com");

        let servers = load_servers_from_env_provider(&env);
        assert!(servers.is_some());

        let servers = servers.unwrap();
        assert_eq!(servers.len(), 3);
        assert_eq!(servers[0].host.as_str(), "news1.example.com");
        assert_eq!(servers[1].host.as_str(), "news2.example.com");
        assert_eq!(servers[2].host.as_str(), "news3.example.com");
        assert!(servers[1].use_tls);
        assert!(!servers[0].use_tls);
        // The default name uses the server's own index, not 0.
        assert_eq!(servers[2].name.as_str(), "Server 2");
    }

    #[test]
    fn test_load_servers_from_env_provider_gaps() {
        let mut env = MockEnv::new();
        // Server 0 and 2 defined, but not 1 - should stop at 1
        env.set("NNTP_SERVER_0_HOST", "news1.example.com")
            .set("NNTP_SERVER_2_HOST", "news3.example.com");

        let servers = load_servers_from_env_provider(&env);
        assert!(servers.is_some());

        let servers = servers.unwrap();
        // Should only get server 0, stops at first gap
        assert_eq!(servers.len(), 1);
        assert_eq!(servers[0].host.as_str(), "news1.example.com");
    }

    #[test]
    fn test_parse_server_from_env_bool_variations() {
        let mut env = MockEnv::new();
        env.set("NNTP_SERVER_0_HOST", "news.example.com")
            .set("NNTP_SERVER_0_USE_TLS", "True")
            .set("NNTP_SERVER_0_TLS_VERIFY_CERT", "FALSE");

        let server = parse_server_from_env(0, &env).unwrap();
        // Rust's parse::<bool>() requires exact "true"/"false" lowercase
        // So these should fail to parse and use defaults
        assert!(!server.use_tls); // Defaults to false
        assert!(server.tls_verify_cert); // Defaults to true
    }

    #[test]
    fn test_parse_server_from_env_correct_bool() {
        let mut env = MockEnv::new();
        env.set("NNTP_SERVER_0_HOST", "news.example.com")
            .set("NNTP_SERVER_0_USE_TLS", "true")
            .set("NNTP_SERVER_0_TLS_VERIFY_CERT", "false");

        let server = parse_server_from_env(0, &env).unwrap();
        assert!(server.use_tls);
        assert!(!server.tls_verify_cert);
    }

    #[test]
    fn test_config_source_description() {
        assert_eq!(ConfigSource::File.description(), "configuration file");
        assert_eq!(
            ConfigSource::Environment.description(),
            "environment variables"
        );
        assert_eq!(
            ConfigSource::DefaultCreated.description(),
            "default configuration (created)"
        );
    }

    #[test]
    fn test_config_source_equality() {
        assert_eq!(ConfigSource::File, ConfigSource::File);
        assert_ne!(ConfigSource::File, ConfigSource::Environment);
        assert_ne!(ConfigSource::Environment, ConfigSource::DefaultCreated);
    }

    #[test]
    fn test_load_config_with_fallback_creates_default() {
        if ambient_server_env_present() {
            return;
        }

        // A fresh directory guarantees the path does not exist and is removed on drop.
        // This avoids the race of deleting a NamedTempFile and then reusing its name.
        let dir = tempfile::tempdir().unwrap();
        let config_path = dir.path().join("config.toml");
        let path = config_path.to_str().unwrap();

        // Should create default config
        let (config, source) = load_config_with_fallback(path).unwrap();
        assert_eq!(source, ConfigSource::DefaultCreated);
        assert_eq!(config.servers.len(), 1);
        assert_eq!(config.servers[0].host.as_str(), "news.example.com");

        // The default file must actually have been written to disk.
        assert!(config_path.exists());
    }

    #[test]
    fn test_load_config_with_fallback_reads_existing() {
        use std::io::Write;
        use tempfile::NamedTempFile;

        if ambient_server_env_present() {
            return;
        }

        let mut temp_file = NamedTempFile::new().unwrap();

        // Write a valid config
        let config_content = r#"
[[servers]]
host = "test.example.com"
port = 119
name = "Test Server"
"#;
        temp_file.write_all(config_content.as_bytes()).unwrap();
        temp_file.flush().unwrap();

        let path = temp_file.path().to_str().unwrap();

        let (config, source) = load_config_with_fallback(path).unwrap();
        assert_eq!(source, ConfigSource::File);
        assert_eq!(config.servers.len(), 1);
        assert_eq!(config.servers[0].host.as_str(), "test.example.com");
    }

    #[test]
    fn test_create_default_config() {
        let config = create_default_config();
        assert_eq!(config.servers.len(), 1);
        assert_eq!(config.servers[0].host.as_str(), "news.example.com");
        assert_eq!(config.servers[0].port.get(), 119);
        assert!(!config.servers[0].use_tls);
    }
}