nntp_proxy/config/types.rs

1//! Configuration type definitions
2//!
3//! This module contains all the core configuration structures used by the proxy.
4
5use super::defaults;
6use crate::types::{
7    CacheCapacity, HostName, MaxConnections, MaxErrors, Port, ServerName, ThreadCount,
8    duration_serde, option_duration_serde,
9};
10use serde::{Deserialize, Serialize};
11use std::time::Duration;
12
13/// Routing mode for the proxy
14#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, clap::ValueEnum)]
15#[serde(rename_all = "lowercase")]
16pub enum RoutingMode {
17    /// Stateful 1:1 mode - each client gets a dedicated backend connection
18    Stateful,
19    /// Per-command routing - each command can use a different backend (stateless only)
20    PerCommand,
21    /// Hybrid mode - starts in per-command routing, auto-switches to stateful on first stateful command
22    Hybrid,
23}
24
25impl Default for RoutingMode {
26    /// Default routing mode is Hybrid, which provides optimal performance and full protocol support.
27    /// This mode automatically starts in per-command routing for efficiency and seamlessly switches
28    /// to stateful mode when commands requiring group context are detected.
29    fn default() -> Self {
30        Self::Hybrid
31    }
32}
33
34impl RoutingMode {
35    /// Check if this mode supports per-command routing
36    #[must_use]
37    pub const fn supports_per_command_routing(&self) -> bool {
38        matches!(self, Self::PerCommand | Self::Hybrid)
39    }
40
41    /// Check if this mode can handle stateful commands
42    #[must_use]
43    pub const fn supports_stateful_commands(&self) -> bool {
44        matches!(self, Self::Stateful | Self::Hybrid)
45    }
46
47    /// Get a human-readable description of this routing mode
48    #[must_use]
49    pub const fn as_str(&self) -> &'static str {
50        match self {
51            Self::Stateful => "stateful 1:1 mode",
52            Self::PerCommand => "per-command routing mode (stateless)",
53            Self::Hybrid => "hybrid routing mode",
54        }
55    }
56}
57
58impl std::fmt::Display for RoutingMode {
59    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
60        f.write_str(self.as_str())
61    }
62}
63
64/// Main proxy configuration
65#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Default)]
66pub struct Config {
67    /// List of backend NNTP servers
68    #[serde(default)]
69    pub servers: Vec<Server>,
70    /// Proxy server settings
71    #[serde(default)]
72    pub proxy: Proxy,
73    /// Health check configuration
74    #[serde(default)]
75    pub health_check: HealthCheck,
76    /// Cache configuration (optional, for caching proxy)
77    #[serde(skip_serializing_if = "Option::is_none")]
78    pub cache: Option<Cache>,
79    /// Client authentication configuration
80    #[serde(default)]
81    pub client_auth: ClientAuth,
82}
83
84/// Proxy server settings
85#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
86#[serde(default)]
87pub struct Proxy {
88    /// Host/IP to bind to (default: 0.0.0.0)
89    pub host: String,
90    /// Port to listen on (default: 8119)
91    pub port: Port,
92    /// Number of worker threads (default: 1, use 0 for CPU cores)
93    pub threads: ThreadCount,
94}
95
96impl Proxy {
97    /// Default listen host (all interfaces)
98    pub const DEFAULT_HOST: &'static str = "0.0.0.0";
99}
100
101impl Default for Proxy {
102    fn default() -> Self {
103        Self {
104            host: Self::DEFAULT_HOST.to_string(),
105            port: Port::default(),
106            threads: ThreadCount::default(),
107        }
108    }
109}
110
111/// Cache configuration for article caching
112#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
113pub struct Cache {
114    /// Maximum number of articles to cache
115    #[serde(default = "super::defaults::cache_max_capacity")]
116    pub max_capacity: CacheCapacity,
117    /// Time-to-live for cached articles
118    #[serde(with = "duration_serde", default = "super::defaults::cache_ttl")]
119    pub ttl: Duration,
120}
121
122impl Default for Cache {
123    fn default() -> Self {
124        Self {
125            max_capacity: defaults::cache_max_capacity(),
126            ttl: defaults::cache_ttl(),
127        }
128    }
129}
130
131/// Health check configuration
132#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
133pub struct HealthCheck {
134    /// Interval between health checks
135    #[serde(
136        with = "duration_serde",
137        default = "super::defaults::health_check_interval"
138    )]
139    pub interval: Duration,
140    /// Timeout for each health check
141    #[serde(
142        with = "duration_serde",
143        default = "super::defaults::health_check_timeout"
144    )]
145    pub timeout: Duration,
146    /// Number of consecutive failures before marking unhealthy
147    #[serde(default = "super::defaults::unhealthy_threshold")]
148    pub unhealthy_threshold: MaxErrors,
149}
150
151impl Default for HealthCheck {
152    fn default() -> Self {
153        Self {
154            interval: super::defaults::health_check_interval(),
155            timeout: super::defaults::health_check_timeout(),
156            unhealthy_threshold: super::defaults::unhealthy_threshold(),
157        }
158    }
159}
160
161/// Client authentication configuration
162#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
163pub struct ClientAuth {
164    /// Required username for client authentication (if set, auth is enabled)
165    /// DEPRECATED: Use `users` instead for multi-user support
166    #[serde(skip_serializing_if = "Option::is_none")]
167    pub username: Option<String>,
168    /// Required password for client authentication
169    /// DEPRECATED: Use `users` instead for multi-user support
170    #[serde(skip_serializing_if = "Option::is_none")]
171    pub password: Option<String>,
172    /// Optional custom greeting message
173    #[serde(skip_serializing_if = "Option::is_none")]
174    pub greeting: Option<String>,
175    /// List of authorized users (replaces username/password for multi-user support)
176    #[serde(default, skip_serializing_if = "Vec::is_empty")]
177    pub users: Vec<UserCredentials>,
178}
179
180/// Individual user credentials
181#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
182pub struct UserCredentials {
183    pub username: String,
184    pub password: String,
185}
186
187impl ClientAuth {
188    /// Check if authentication is enabled
189    pub fn is_enabled(&self) -> bool {
190        // Auth is enabled if either the legacy single-user config or multi-user list is populated
191        (!self.users.is_empty()) || (self.username.is_some() && self.password.is_some())
192    }
193
194    /// Get all users (combines legacy + new format)
195    pub fn all_users(&self) -> Vec<(&str, &str)> {
196        let mut users = Vec::new();
197
198        // Add legacy single user if present
199        if let (Some(u), Some(p)) = (&self.username, &self.password) {
200            users.push((u.as_str(), p.as_str()));
201        }
202
203        // Add multi-user list
204        for user in &self.users {
205            users.push((user.username.as_str(), user.password.as_str()));
206        }
207
208        users
209    }
210}
211
212/// Configuration for a single backend server
213#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
214pub struct Server {
215    pub host: HostName,
216    pub port: Port,
217    pub name: ServerName,
218    #[serde(skip_serializing_if = "Option::is_none")]
219    pub username: Option<String>,
220    #[serde(skip_serializing_if = "Option::is_none")]
221    pub password: Option<String>,
222    /// Maximum number of concurrent connections to this server
223    #[serde(default = "super::defaults::max_connections")]
224    pub max_connections: MaxConnections,
225
226    /// Enable TLS/SSL for this backend connection
227    #[serde(default)]
228    pub use_tls: bool,
229    /// Verify TLS certificates (recommended for production)
230    #[serde(default = "super::defaults::tls_verify_cert")]
231    pub tls_verify_cert: bool,
232    /// Optional path to custom CA certificate
233    #[serde(skip_serializing_if = "Option::is_none")]
234    pub tls_cert_path: Option<String>,
235    /// Interval to send keep-alive commands (DATE) on idle connections
236    /// None disables keep-alive (default)
237    #[serde(
238        with = "option_duration_serde",
239        default,
240        skip_serializing_if = "Option::is_none"
241    )]
242    pub connection_keepalive: Option<Duration>,
243    /// Maximum number of connections to check per health check cycle
244    /// Lower values reduce pool contention but may take longer to detect all stale connections
245    #[serde(default = "super::defaults::health_check_max_per_cycle")]
246    pub health_check_max_per_cycle: usize,
247    /// Timeout when acquiring a connection for health checking
248    /// Short timeout prevents blocking if pool is busy
249    #[serde(
250        with = "duration_serde",
251        default = "super::defaults::health_check_pool_timeout"
252    )]
253    pub health_check_pool_timeout: Duration,
254}
255
/// Fluent builder for [`Server`] configurations.
///
/// Only the host and port are required up front. Every other setting is optional and,
/// when left unset, falls back to a project default in [`ServerBuilder::build`]. This keeps
/// call sites (tests in particular) short compared with spelling out every `Server` field.
///
/// # Examples
///
/// ```
/// use nntp_proxy::config::Server;
///
/// // Minimal configuration
/// let config = Server::builder("news.example.com", 119)
///     .build()
///     .unwrap();
///
/// // With authentication and TLS
/// let config = Server::builder("secure.example.com", 563)
///     .name("Secure Server")
///     .username("user")
///     .password("pass")
///     .max_connections(20)
///     .use_tls(true)
///     .build()
///     .unwrap();
/// ```
pub struct ServerBuilder {
    // Required connection target (validated in `build`).
    host: String,
    port: u16,
    // Optional identity and backend credentials.
    name: Option<String>,
    username: Option<String>,
    password: Option<String>,
    // `None` means "use the default pool size".
    max_connections: Option<usize>,
    // TLS settings.
    use_tls: bool,
    tls_verify_cert: bool,
    tls_cert_path: Option<String>,
    // Connection maintenance; `None` keep-alive means disabled, the other `None`s mean
    // "use the default".
    connection_keepalive: Option<Duration>,
    health_check_max_per_cycle: Option<usize>,
    health_check_pool_timeout: Option<Duration>,
}
295
296impl ServerBuilder {
297    /// Create a new builder with required parameters
298    ///
299    /// # Arguments
300    /// * `host` - Backend server hostname or IP address
301    /// * `port` - Backend server port number
302    #[must_use]
303    pub fn new(host: impl Into<String>, port: u16) -> Self {
304        Self {
305            host: host.into(),
306            port,
307            name: None,
308            username: None,
309            password: None,
310            max_connections: None,
311            use_tls: false,
312            tls_verify_cert: true, // Secure by default
313            tls_cert_path: None,
314            connection_keepalive: None,
315            health_check_max_per_cycle: None,
316            health_check_pool_timeout: None,
317        }
318    }
319
320    /// Set a friendly name for logging (defaults to "host:port")
321    #[must_use]
322    pub fn name(mut self, name: impl Into<String>) -> Self {
323        self.name = Some(name.into());
324        self
325    }
326
327    /// Set authentication username
328    #[must_use]
329    pub fn username(mut self, username: impl Into<String>) -> Self {
330        self.username = Some(username.into());
331        self
332    }
333
334    /// Set authentication password
335    #[must_use]
336    pub fn password(mut self, password: impl Into<String>) -> Self {
337        self.password = Some(password.into());
338        self
339    }
340
341    /// Set maximum number of concurrent connections
342    #[must_use]
343    pub fn max_connections(mut self, max: usize) -> Self {
344        self.max_connections = Some(max);
345        self
346    }
347
348    /// Enable TLS/SSL for this backend connection
349    #[must_use]
350    pub fn use_tls(mut self, enabled: bool) -> Self {
351        self.use_tls = enabled;
352        self
353    }
354
355    /// Set whether to verify TLS certificates
356    #[must_use]
357    pub fn tls_verify_cert(mut self, verify: bool) -> Self {
358        self.tls_verify_cert = verify;
359        self
360    }
361
362    /// Set path to custom CA certificate
363    #[must_use]
364    pub fn tls_cert_path(mut self, path: impl Into<String>) -> Self {
365        self.tls_cert_path = Some(path.into());
366        self
367    }
368
369    /// Set keep-alive interval for idle connections
370    #[must_use]
371    pub fn connection_keepalive(mut self, interval: Duration) -> Self {
372        self.connection_keepalive = Some(interval);
373        self
374    }
375
376    /// Set maximum connections to check per health check cycle
377    #[must_use]
378    pub fn health_check_max_per_cycle(mut self, max: usize) -> Self {
379        self.health_check_max_per_cycle = Some(max);
380        self
381    }
382
383    /// Set timeout for acquiring connections during health checks
384    #[must_use]
385    pub fn health_check_pool_timeout(mut self, timeout: Duration) -> Self {
386        self.health_check_pool_timeout = Some(timeout);
387        self
388    }
389
390    /// Build the Server
391    ///
392    /// # Errors
393    ///
394    /// Returns an error if:
395    /// - Host is empty or invalid
396    /// - Port is 0
397    /// - Name is empty (when explicitly set)
398    /// - Max connections is 0 (when explicitly set)
399    pub fn build(self) -> Result<Server, anyhow::Error> {
400        use crate::types::{HostName, MaxConnections, Port, ServerName};
401
402        let host = HostName::new(self.host.clone())?;
403
404        let port = Port::new(self.port)
405            .ok_or_else(|| anyhow::anyhow!("Invalid port: {} (must be 1-65535)", self.port))?;
406
407        let name_str = self
408            .name
409            .unwrap_or_else(|| format!("{}:{}", self.host, self.port));
410        let name = ServerName::new(name_str)?;
411
412        let max_connections = if let Some(max) = self.max_connections {
413            MaxConnections::new(max)
414                .ok_or_else(|| anyhow::anyhow!("Invalid max_connections: {} (must be > 0)", max))?
415        } else {
416            super::defaults::max_connections()
417        };
418
419        let health_check_max_per_cycle = self
420            .health_check_max_per_cycle
421            .unwrap_or_else(super::defaults::health_check_max_per_cycle);
422
423        let health_check_pool_timeout = self
424            .health_check_pool_timeout
425            .unwrap_or_else(super::defaults::health_check_pool_timeout);
426
427        Ok(Server {
428            host,
429            port,
430            name,
431            username: self.username,
432            password: self.password,
433            max_connections,
434            use_tls: self.use_tls,
435            tls_verify_cert: self.tls_verify_cert,
436            tls_cert_path: self.tls_cert_path,
437            connection_keepalive: self.connection_keepalive,
438            health_check_max_per_cycle,
439            health_check_pool_timeout,
440        })
441    }
442}
443
444impl Server {
445    /// Create a builder for constructing a Server
446    ///
447    /// # Examples
448    ///
449    /// ```
450    /// use nntp_proxy::config::Server;
451    ///
452    /// let config = Server::builder("news.example.com", 119)
453    ///     .name("Example Server")
454    ///     .max_connections(15)
455    ///     .build()
456    ///     .unwrap();
457    /// ```
458    #[must_use]
459    pub fn builder(host: impl Into<String>, port: u16) -> ServerBuilder {
460        ServerBuilder::new(host, port)
461    }
462}
463
#[cfg(test)]
mod tests {
    use super::*;

    // RoutingMode tests
    #[test]
    fn test_routing_mode_default() {
        assert_eq!(RoutingMode::default(), RoutingMode::Hybrid);
    }

    #[test]
    fn test_routing_mode_supports_per_command() {
        assert!(RoutingMode::PerCommand.supports_per_command_routing());
        assert!(RoutingMode::Hybrid.supports_per_command_routing());
        assert!(!RoutingMode::Stateful.supports_per_command_routing());
    }

    #[test]
    fn test_routing_mode_supports_stateful() {
        assert!(RoutingMode::Stateful.supports_stateful_commands());
        assert!(RoutingMode::Hybrid.supports_stateful_commands());
        assert!(!RoutingMode::PerCommand.supports_stateful_commands());
    }

    #[test]
    fn test_routing_mode_as_str() {
        assert_eq!(RoutingMode::Stateful.as_str(), "stateful 1:1 mode");
        assert_eq!(
            RoutingMode::PerCommand.as_str(),
            "per-command routing mode (stateless)"
        );
        assert_eq!(RoutingMode::Hybrid.as_str(), "hybrid routing mode");
    }

    #[test]
    fn test_routing_mode_display() {
        assert_eq!(RoutingMode::Stateful.to_string(), "stateful 1:1 mode");
        assert_eq!(
            RoutingMode::PerCommand.to_string(),
            "per-command routing mode (stateless)"
        );
        assert_eq!(RoutingMode::Hybrid.to_string(), "hybrid routing mode");
    }

    // Proxy tests
    #[test]
    fn test_proxy_default() {
        let proxy = Proxy::default();
        assert_eq!(proxy.host, "0.0.0.0");
        assert_eq!(proxy.port.get(), 8119);
    }

    #[test]
    fn test_proxy_default_host_constant() {
        assert_eq!(Proxy::DEFAULT_HOST, "0.0.0.0");
    }

    // Cache tests
    #[test]
    fn test_cache_default() {
        let cache = Cache::default();
        assert_eq!(cache.max_capacity.get(), 10000);
        assert_eq!(cache.ttl, Duration::from_secs(3600));
    }

    // HealthCheck tests
    #[test]
    fn test_health_check_default() {
        let hc = HealthCheck::default();
        assert_eq!(hc.interval, Duration::from_secs(30));
        assert_eq!(hc.timeout, Duration::from_secs(5));
        assert_eq!(hc.unhealthy_threshold.get(), 3);
    }

    // ClientAuth tests

    /// Build a `ClientAuth` that only uses the legacy single-user fields.
    fn legacy_auth(username: Option<&str>, password: Option<&str>) -> ClientAuth {
        ClientAuth {
            username: username.map(str::to_string),
            password: password.map(str::to_string),
            ..ClientAuth::default()
        }
    }

    /// Build one multi-user credential entry.
    fn user(username: &str, password: &str) -> UserCredentials {
        UserCredentials {
            username: username.to_string(),
            password: password.to_string(),
        }
    }

    #[test]
    fn test_client_auth_is_enabled_legacy() {
        assert!(!ClientAuth::default().is_enabled());
        assert!(legacy_auth(Some("user"), Some("pass")).is_enabled());
    }

    #[test]
    fn test_client_auth_partial_legacy_credentials_not_enabled() {
        // A legacy username without a password (or vice versa) must not enable auth.
        for auth in [
            legacy_auth(Some("user"), None),
            legacy_auth(None, Some("pass")),
        ] {
            assert!(!auth.is_enabled());
            assert!(auth.all_users().is_empty());
        }
    }

    #[test]
    fn test_client_auth_is_enabled_multi_user() {
        let auth = ClientAuth {
            users: vec![user("alice", "secret")],
            ..ClientAuth::default()
        };
        assert!(auth.is_enabled());
    }

    #[test]
    fn test_client_auth_all_users_legacy() {
        let auth = legacy_auth(Some("user"), Some("pass"));

        let users = auth.all_users();
        assert_eq!(users.len(), 1);
        assert_eq!(users[0], ("user", "pass"));
    }

    #[test]
    fn test_client_auth_all_users_multi() {
        let auth = ClientAuth {
            users: vec![user("alice", "alice_pw"), user("bob", "bob_pw")],
            ..ClientAuth::default()
        };

        let users = auth.all_users();
        assert_eq!(users.len(), 2);
        assert_eq!(users[0], ("alice", "alice_pw"));
        assert_eq!(users[1], ("bob", "bob_pw"));
    }

    #[test]
    fn test_client_auth_all_users_combined() {
        let auth = ClientAuth {
            users: vec![user("alice", "alice_pw")],
            ..legacy_auth(Some("legacy"), Some("legacy_pw"))
        };

        let users = auth.all_users();
        assert_eq!(users.len(), 2);
        assert_eq!(users[0], ("legacy", "legacy_pw"));
        assert_eq!(users[1], ("alice", "alice_pw"));
    }

    // ServerBuilder tests
    #[test]
    fn test_server_builder_minimal() {
        let server = Server::builder("news.example.com", 119).build().unwrap();

        assert_eq!(server.host.as_str(), "news.example.com");
        assert_eq!(server.port.get(), 119);
        assert_eq!(server.name.as_str(), "news.example.com:119");
        assert_eq!(server.max_connections.get(), 10);
        assert!(!server.use_tls);
        assert!(server.tls_verify_cert); // Secure by default
    }

    #[test]
    fn test_server_builder_unset_options_use_defaults() {
        let server = Server::builder("localhost", 119).build().unwrap();

        assert_eq!(server.username, None);
        assert_eq!(server.password, None);
        assert_eq!(server.tls_cert_path, None);
        assert_eq!(server.connection_keepalive, None);
        assert_eq!(server.max_connections, defaults::max_connections());
        assert_eq!(
            server.health_check_max_per_cycle,
            defaults::health_check_max_per_cycle()
        );
        assert_eq!(
            server.health_check_pool_timeout,
            defaults::health_check_pool_timeout()
        );
    }

    #[test]
    fn test_server_builder_with_name() {
        let server = Server::builder("localhost", 119)
            .name("Test Server")
            .build()
            .unwrap();

        assert_eq!(server.name.as_str(), "Test Server");
    }

    #[test]
    fn test_server_builder_with_auth() {
        let server = Server::builder("news.example.com", 119)
            .username("testuser")
            .password("testpass")
            .build()
            .unwrap();

        assert_eq!(server.username.as_ref().unwrap(), "testuser");
        assert_eq!(server.password.as_ref().unwrap(), "testpass");
    }

    #[test]
    fn test_server_builder_with_max_connections() {
        let server = Server::builder("localhost", 119)
            .max_connections(20)
            .build()
            .unwrap();

        assert_eq!(server.max_connections.get(), 20);
    }

    #[test]
    fn test_server_builder_with_tls() {
        let server = Server::builder("secure.example.com", 563)
            .use_tls(true)
            .tls_verify_cert(false)
            .tls_cert_path("/path/to/cert.pem")
            .build()
            .unwrap();

        assert!(server.use_tls);
        assert!(!server.tls_verify_cert);
        assert_eq!(server.tls_cert_path.as_ref().unwrap(), "/path/to/cert.pem");
    }

    #[test]
    fn test_server_builder_with_keepalive() {
        let keepalive = Duration::from_secs(300);
        let server = Server::builder("localhost", 119)
            .connection_keepalive(keepalive)
            .build()
            .unwrap();

        assert_eq!(server.connection_keepalive, Some(keepalive));
    }

    #[test]
    fn test_server_builder_with_health_check_settings() {
        let timeout = Duration::from_millis(500);
        let server = Server::builder("localhost", 119)
            .health_check_max_per_cycle(5)
            .health_check_pool_timeout(timeout)
            .build()
            .unwrap();

        assert_eq!(server.health_check_max_per_cycle, 5);
        assert_eq!(server.health_check_pool_timeout, timeout);
    }

    #[test]
    fn test_server_builder_chaining() {
        let server = Server::builder("news.example.com", 563)
            .name("Production Server")
            .username("admin")
            .password("secret")
            .max_connections(25)
            .use_tls(true)
            .tls_verify_cert(true)
            .build()
            .unwrap();

        assert_eq!(server.name.as_str(), "Production Server");
        assert_eq!(server.max_connections.get(), 25);
        assert!(server.use_tls);
    }

    #[test]
    fn test_server_builder_invalid_host() {
        let result = Server::builder("", 119).build();
        assert!(result.is_err());
    }

    #[test]
    fn test_server_builder_invalid_port() {
        let err = Server::builder("localhost", 0).build().unwrap_err();
        assert!(err.to_string().contains("Invalid port"));
    }

    #[test]
    fn test_server_builder_invalid_max_connections() {
        let err = Server::builder("localhost", 119)
            .max_connections(0)
            .build()
            .unwrap_err();
        assert!(err.to_string().contains("Invalid max_connections"));
    }

    #[test]
    fn test_server_builder_from_server_method() {
        let builder = Server::builder("localhost", 119);
        let server = builder.build().unwrap();
        assert_eq!(server.host.as_str(), "localhost");
    }

    // Config tests
    #[test]
    fn test_config_default() {
        let config = Config::default();
        assert!(config.servers.is_empty());
        assert_eq!(config.proxy.host, "0.0.0.0");
        assert!(config.cache.is_none());
        assert!(!config.client_auth.is_enabled());
    }
}