nntp_proxy/
proxy.rs

1//! NNTP Proxy implementation
2//!
3//! This module contains the main `NntpProxy` struct which orchestrates
4//! connection handling, routing, and resource management.
5
6use anyhow::{Context, Result};
7use std::net::SocketAddr;
8use std::sync::Arc;
9use tokio::io::AsyncWriteExt;
10use tokio::net::TcpStream;
11use tracing::{debug, error, info, warn};
12
13use crate::auth::AuthHandler;
14use crate::config::{Config, RoutingMode, ServerConfig};
15use crate::constants::buffer::{POOL, POOL_COUNT};
16use crate::network::{ConnectionOptimizer, NetworkOptimizer, TcpOptimizer};
17use crate::pool::{BufferPool, ConnectionProvider, DeadpoolConnectionProvider, prewarm_pools};
18use crate::protocol::BACKEND_UNAVAILABLE;
19use crate::router;
20use crate::session::ClientSession;
21use crate::types::{self, BufferSize};
22
23/// Builder for constructing an `NntpProxy` with optional configuration overrides
24///
25/// # Examples
26///
27/// Basic usage with defaults:
28/// ```no_run
29/// # use nntp_proxy::{NntpProxyBuilder, Config, RoutingMode};
30/// # use nntp_proxy::config::load_config;
31/// # fn main() -> anyhow::Result<()> {
32/// let config = load_config("config.toml")?;
33/// let proxy = NntpProxyBuilder::new(config)
34///     .with_routing_mode(RoutingMode::Hybrid)
35///     .build()?;
36/// # Ok(())
37/// # }
38/// ```
39///
40/// With custom buffer pool size:
41/// ```no_run
42/// # use nntp_proxy::{NntpProxyBuilder, Config, RoutingMode};
43/// # use nntp_proxy::config::load_config;
44/// # fn main() -> anyhow::Result<()> {
45/// let config = load_config("config.toml")?;
46/// let proxy = NntpProxyBuilder::new(config)
47///     .with_routing_mode(RoutingMode::PerCommand)
48///     .with_buffer_pool_size(512 * 1024)  // 512KB buffers
49///     .with_buffer_pool_count(64)         // 64 buffers
50///     .build()?;
51/// # Ok(())
52/// # }
53/// ```
54#[derive(Debug)]
55pub struct NntpProxyBuilder {
56    config: Config,
57    routing_mode: RoutingMode,
58    buffer_size: Option<usize>,
59    buffer_count: Option<usize>,
60}
61
62impl NntpProxyBuilder {
63    /// Create a new builder with the given configuration
64    ///
65    /// The routing mode defaults to `Standard` (1:1) mode.
66    #[must_use]
67    pub fn new(config: Config) -> Self {
68        Self {
69            config,
70            routing_mode: RoutingMode::Standard,
71            buffer_size: None,
72            buffer_count: None,
73        }
74    }
75
76    /// Set the routing mode
77    ///
78    /// Available modes:
79    /// - `Standard`: 1:1 client-to-backend mapping (default)
80    /// - `PerCommand`: Each command routes to a different backend
81    /// - `Hybrid`: Starts in per-command mode, switches to stateful when needed
82    #[must_use]
83    pub fn with_routing_mode(mut self, mode: RoutingMode) -> Self {
84        self.routing_mode = mode;
85        self
86    }
87
88    /// Override the default buffer pool size (256KB)
89    ///
90    /// This affects the size of each buffer in the pool. Larger buffers
91    /// can improve throughput for large article transfers but use more memory.
92    #[must_use]
93    pub fn with_buffer_pool_size(mut self, size: usize) -> Self {
94        self.buffer_size = Some(size);
95        self
96    }
97
98    /// Override the default buffer pool count (32)
99    ///
100    /// This affects how many buffers are pre-allocated. Should roughly match
101    /// the expected number of concurrent connections.
102    #[must_use]
103    pub fn with_buffer_pool_count(mut self, count: usize) -> Self {
104        self.buffer_count = Some(count);
105        self
106    }
107
108    /// Build the `NntpProxy` instance
109    ///
110    /// # Errors
111    ///
112    /// Returns an error if:
113    /// - No servers are configured
114    /// - Connection providers cannot be created
115    /// - Buffer size is zero
116    pub fn build(self) -> Result<NntpProxy> {
117        if self.config.servers.is_empty() {
118            anyhow::bail!("No servers configured in configuration");
119        }
120
121        // Use provided values or defaults
122        let buffer_size = self.buffer_size.unwrap_or(POOL);
123        let buffer_count = self.buffer_count.unwrap_or(POOL_COUNT);
124
125        // Create deadpool connection providers for each server
126        let connection_providers: Result<Vec<DeadpoolConnectionProvider>> = self
127            .config
128            .servers
129            .iter()
130            .map(|server| {
131                info!(
132                    "Configuring deadpool connection provider for '{}'",
133                    server.name
134                );
135                DeadpoolConnectionProvider::from_server_config(server)
136            })
137            .collect();
138
139        let connection_providers = connection_providers?;
140
141        let buffer_pool = BufferPool::new(
142            BufferSize::new(buffer_size)
143                .ok_or_else(|| anyhow::anyhow!("Buffer size must be non-zero"))?,
144            buffer_count,
145        );
146
147        let servers = Arc::new(self.config.servers);
148
149        // Create backend selector and add all backends
150        let router = Arc::new({
151            use types::BackendId;
152            connection_providers.iter().enumerate().fold(
153                router::BackendSelector::new(),
154                |mut r, (idx, provider)| {
155                    let backend_id = BackendId::from_index(idx);
156                    r.add_backend(
157                        backend_id,
158                        servers[idx].name.as_str().to_string(),
159                        provider.clone(),
160                    );
161                    r
162                },
163            )
164        });
165
166        // Create auth handler from config
167        let auth_handler = Arc::new(AuthHandler::new(
168            self.config.client_auth.username.clone(),
169            self.config.client_auth.password.clone(),
170        ));
171
172        Ok(NntpProxy {
173            servers,
174            router,
175            connection_providers,
176            buffer_pool,
177            routing_mode: self.routing_mode,
178            auth_handler,
179        })
180    }
181}
182
183#[derive(Debug, Clone)]
184pub struct NntpProxy {
185    servers: Arc<Vec<ServerConfig>>,
186    /// Backend selector for round-robin load balancing
187    router: Arc<router::BackendSelector>,
188    /// Connection providers per server - easily swappable implementation
189    connection_providers: Vec<DeadpoolConnectionProvider>,
190    /// Buffer pool for I/O operations
191    buffer_pool: BufferPool,
192    /// Routing mode (Standard, PerCommand, or Hybrid)
193    routing_mode: RoutingMode,
194    /// Authentication handler for client auth interception
195    auth_handler: Arc<AuthHandler>,
196}
197
198impl NntpProxy {
199    /// Create a new `NntpProxy` with the given configuration and routing mode
200    ///
201    /// This is a convenience method that uses the builder internally.
202    /// For more control over configuration, use [`NntpProxy::builder`].
203    ///
204    /// # Examples
205    ///
206    /// ```no_run
207    /// # use nntp_proxy::{NntpProxy, Config, RoutingMode};
208    /// # use nntp_proxy::config::load_config;
209    /// # fn main() -> anyhow::Result<()> {
210    /// let config = load_config("config.toml")?;
211    /// let proxy = NntpProxy::new(config, RoutingMode::Hybrid)?;
212    /// # Ok(())
213    /// # }
214    /// ```
215    pub fn new(config: Config, routing_mode: RoutingMode) -> Result<Self> {
216        NntpProxyBuilder::new(config)
217            .with_routing_mode(routing_mode)
218            .build()
219    }
220
221    /// Create a builder for more fine-grained control over proxy configuration
222    ///
223    /// # Examples
224    ///
225    /// ```no_run
226    /// # use nntp_proxy::{NntpProxy, Config, RoutingMode};
227    /// # use nntp_proxy::config::load_config;
228    /// # fn main() -> anyhow::Result<()> {
229    /// let config = load_config("config.toml")?;
230    /// let proxy = NntpProxy::builder(config)
231    ///     .with_routing_mode(RoutingMode::Hybrid)
232    ///     .with_buffer_pool_size(512 * 1024)
233    ///     .build()?;
234    /// # Ok(())
235    /// # }
236    /// ```
237    #[must_use]
238    pub fn builder(config: Config) -> NntpProxyBuilder {
239        NntpProxyBuilder::new(config)
240    }
241
242    /// Prewarm all connection pools before accepting clients
243    /// Creates all connections concurrently and returns when ready
244    pub async fn prewarm_connections(&self) -> Result<()> {
245        prewarm_pools(&self.connection_providers, &self.servers).await
246    }
247
248    /// Gracefully shutdown all connection pools
249    pub async fn graceful_shutdown(&self) {
250        info!("Initiating graceful shutdown of all connection pools...");
251
252        for provider in &self.connection_providers {
253            provider.graceful_shutdown().await;
254        }
255
256        info!("All connection pools have been shut down gracefully");
257    }
258
259    /// Get the list of servers
260    #[inline]
261    pub fn servers(&self) -> &[ServerConfig] {
262        &self.servers
263    }
264
265    /// Get the router
266    #[inline]
267    pub fn router(&self) -> &Arc<router::BackendSelector> {
268        &self.router
269    }
270
271    /// Get the connection providers
272    #[inline]
273    pub fn connection_providers(&self) -> &[DeadpoolConnectionProvider] {
274        &self.connection_providers
275    }
276
277    /// Get the buffer pool
278    #[inline]
279    pub fn buffer_pool(&self) -> &BufferPool {
280        &self.buffer_pool
281    }
282
283    /// Common setup for client connections (greeting only, prewarming done at startup)
284    async fn setup_client_connection(
285        &self,
286        client_stream: &mut TcpStream,
287        client_addr: SocketAddr,
288    ) -> Result<()> {
289        // Send proxy greeting
290        crate::protocol::send_proxy_greeting(client_stream, client_addr).await
291    }
292
293    pub async fn handle_client(
294        &self,
295        mut client_stream: TcpStream,
296        client_addr: SocketAddr,
297    ) -> Result<()> {
298        debug!("New client connection from {}", client_addr);
299
300        // Use a dummy ClientId and command for routing (synchronous 1:1 mapping)
301        use types::ClientId;
302        let client_id = ClientId::new();
303
304        // Select backend using router's round-robin
305        let backend_id = self.router.route_command_sync(client_id, "")?;
306        let server_idx = backend_id.as_index();
307        let server = &self.servers[server_idx];
308
309        info!(
310            "Routing client {} to backend {:?} ({}:{})",
311            client_addr, backend_id, server.host, server.port
312        );
313
314        // Setup connection (prewarm and greeting)
315        self.setup_client_connection(&mut client_stream, client_addr)
316            .await?;
317
318        // Get pooled backend connection
319        let pool_status = self.connection_providers[server_idx].status();
320        debug!(
321            "Pool status for {}: {}/{} available, {} created",
322            server.name, pool_status.available, pool_status.max_size, pool_status.created
323        );
324
325        let mut backend_conn = match self.connection_providers[server_idx]
326            .get_pooled_connection()
327            .await
328        {
329            Ok(conn) => {
330                debug!("Got pooled connection for {}", server.name);
331                conn
332            }
333            Err(e) => {
334                error!(
335                    "Failed to get pooled connection for {} (client {}): {}",
336                    server.name, client_addr, e
337                );
338                let _ = client_stream.write_all(BACKEND_UNAVAILABLE).await;
339                return Err(anyhow::anyhow!(
340                    "Failed to get pooled connection for backend '{}' (client {}): {}",
341                    server.name,
342                    client_addr,
343                    e
344                ));
345            }
346        };
347
348        // Apply socket optimizations for high-throughput
349        let client_optimizer = TcpOptimizer::new(&client_stream);
350        if let Err(e) = client_optimizer.optimize() {
351            debug!("Failed to optimize client socket: {}", e);
352        }
353
354        let backend_optimizer = ConnectionOptimizer::new(&backend_conn);
355        if let Err(e) = backend_optimizer.optimize() {
356            debug!("Failed to optimize backend socket: {}", e);
357        }
358
359        // Create session and handle connection
360        let session = ClientSession::new(
361            client_addr,
362            self.buffer_pool.clone(),
363            self.auth_handler.clone(),
364        );
365        debug!("Starting session for client {}", client_addr);
366
367        let copy_result = session
368            .handle_with_pooled_backend(client_stream, &mut *backend_conn)
369            .await;
370
371        debug!("Session completed for client {}", client_addr);
372
373        // Complete the routing (decrement pending count)
374        self.router.complete_command_sync(backend_id);
375
376        // Log session results and handle backend connection errors
377        match copy_result {
378            Ok((client_to_backend_bytes, backend_to_client_bytes)) => {
379                info!(
380                    "Connection closed for client {}: {} bytes sent, {} bytes received",
381                    client_addr, client_to_backend_bytes, backend_to_client_bytes
382                );
383            }
384            Err(e) => {
385                // Check if this is a backend I/O error - if so, remove connection from pool
386                if crate::pool::is_connection_error(&e) {
387                    warn!(
388                        "Backend connection error for client {}: {} - removing connection from pool",
389                        client_addr, e
390                    );
391                    crate::pool::remove_from_pool(backend_conn);
392                    return Err(e);
393                }
394                warn!("Session error for client {}: {}", client_addr, e);
395            }
396        }
397
398        debug!("Connection returned to pool for {}", server.name);
399        Ok(())
400    }
401
402    /// Handle client connection using per-command routing mode
403    ///
404    /// This creates a session with the router, allowing commands from this client
405    /// to be routed to different backends based on load balancing.
406    pub async fn handle_client_per_command_routing(
407        &self,
408        client_stream: TcpStream,
409        client_addr: SocketAddr,
410    ) -> Result<()> {
411        debug!(
412            "New per-command routing client connection from {}",
413            client_addr
414        );
415
416        // Enable TCP_NODELAY for low latency
417        if let Err(e) = client_stream.set_nodelay(true) {
418            debug!("Failed to set TCP_NODELAY for {}: {}", client_addr, e);
419        }
420
421        // NOTE: Don't call setup_client_connection here because handle_per_command_routing
422        // sends its own greeting ("200 NNTP Proxy Ready (Per-Command Routing)")
423        // Calling setup_client_connection would send a duplicate greeting
424
425        // Create session with router for per-command routing
426        let session = ClientSession::new_with_router(
427            client_addr,
428            self.buffer_pool.clone(),
429            self.router.clone(),
430            self.routing_mode,
431            self.auth_handler.clone(),
432        );
433
434        let session_id = crate::formatting::short_id(session.client_id().as_uuid());
435
436        info!(
437            "Client {} [{}] connected in per-command routing mode",
438            client_addr, session_id
439        );
440
441        // Handle the session with per-command routing
442        let result = session
443            .handle_per_command_routing(client_stream)
444            .await
445            .with_context(|| {
446                format!(
447                    "Per-command routing session failed for {} [{}]",
448                    client_addr, session_id
449                )
450            });
451
452        // Log session results
453        match result {
454            Ok((client_to_backend, backend_to_client)) => {
455                info!(
456                    "Session closed {} [{}] ↑{} ↓{}",
457                    client_addr,
458                    session_id,
459                    crate::formatting::format_bytes(client_to_backend),
460                    crate::formatting::format_bytes(backend_to_client)
461                );
462            }
463            Err(e) => {
464                // Check if this is a broken pipe error (normal for quick disconnections like SABnzbd tests)
465                let is_broken_pipe = e.downcast_ref::<std::io::Error>().is_some_and(|io_err| {
466                    matches!(
467                        io_err.kind(),
468                        std::io::ErrorKind::BrokenPipe | std::io::ErrorKind::ConnectionReset
469                    )
470                });
471
472                if is_broken_pipe {
473                    debug!(
474                        "Client {} [{}] disconnected: {} (normal for test connections)",
475                        client_addr, session_id, e
476                    );
477                } else {
478                    warn!("Session error {} [{}]: {}", client_addr, session_id, e);
479                }
480            }
481        }
482
483        debug!(
484            "Per-command routing connection closed for {} (ID: {})",
485            client_addr,
486            session.client_id()
487        );
488        Ok(())
489    }
490}
491
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    /// Build a three-server configuration; the servers differ only in host,
    /// display name and connection limit.
    fn create_test_config() -> Config {
        use crate::config::{health_check_max_per_cycle, health_check_pool_timeout};
        use crate::types::{HostName, MaxConnections, Port, ServerName};

        let servers = vec![
            (
                "server1.example.com",
                "Test Server 1",
                MaxConnections::new(5).unwrap(),
            ),
            (
                "server2.example.com",
                "Test Server 2",
                MaxConnections::new(8).unwrap(),
            ),
            (
                "server3.example.com",
                "Test Server 3",
                MaxConnections::new(12).unwrap(),
            ),
        ]
        .into_iter()
        .map(|(host, name, max_connections)| ServerConfig {
            host: HostName::new(host.to_string()).unwrap(),
            port: Port::new(119).unwrap(),
            name: ServerName::new(name.to_string()).unwrap(),
            username: None,
            password: None,
            max_connections,
            use_tls: false,
            tls_verify_cert: true,
            tls_cert_path: None,
            connection_keepalive: None,
            health_check_max_per_cycle: health_check_max_per_cycle(),
            health_check_pool_timeout: health_check_pool_timeout(),
        })
        .collect();

        Config {
            servers,
            ..Default::default()
        }
    }

    /// Configuration with an empty server list, which construction must reject.
    fn empty_config() -> Config {
        Config {
            servers: vec![],
            ..Default::default()
        }
    }

    /// Assert that construction failed with the "no servers" error.
    fn assert_no_servers_error(result: Result<NntpProxy>) {
        assert!(result.is_err());
        let message = result.unwrap_err().to_string();
        assert!(message.contains("No servers configured"));
    }

    #[test]
    fn test_proxy_creation_with_servers() {
        let proxy = Arc::new(
            NntpProxy::new(create_test_config(), RoutingMode::Standard)
                .expect("Failed to create proxy"),
        );

        let servers = proxy.servers();
        assert_eq!(servers.len(), 3);
        assert_eq!(servers[0].name.as_str(), "Test Server 1");
    }

    #[test]
    fn test_proxy_creation_with_empty_servers() {
        assert_no_servers_error(NntpProxy::new(empty_config(), RoutingMode::Standard));
    }

    #[test]
    fn test_proxy_has_router() {
        let proxy = Arc::new(
            NntpProxy::new(create_test_config(), RoutingMode::Standard)
                .expect("Failed to create proxy"),
        );

        // Every configured server must have been registered as a backend.
        assert_eq!(proxy.router.backend_count(), 3);
    }

    #[test]
    fn test_builder_basic_usage() {
        let proxy = NntpProxy::builder(create_test_config())
            .build()
            .expect("Failed to build proxy");

        assert_eq!(proxy.servers().len(), 3);
        assert_eq!(proxy.router.backend_count(), 3);
    }

    #[test]
    fn test_builder_with_routing_mode() {
        let proxy = NntpProxy::builder(create_test_config())
            .with_routing_mode(RoutingMode::PerCommand)
            .build()
            .expect("Failed to build proxy");

        assert_eq!(proxy.servers().len(), 3);
    }

    #[test]
    fn test_builder_with_custom_buffer_pool() {
        let proxy = NntpProxy::builder(create_test_config())
            .with_buffer_pool_size(512 * 1024)
            .with_buffer_pool_count(64)
            .build()
            .expect("Failed to build proxy");

        // The pool dimensions are not exposed, so only successful construction is checked.
        assert_eq!(proxy.servers().len(), 3);
    }

    #[test]
    fn test_builder_with_all_options() {
        let proxy = NntpProxy::builder(create_test_config())
            .with_routing_mode(RoutingMode::Hybrid)
            .with_buffer_pool_size(1024 * 1024)
            .with_buffer_pool_count(16)
            .build()
            .expect("Failed to build proxy");

        assert_eq!(proxy.servers().len(), 3);
        assert_eq!(proxy.router.backend_count(), 3);
    }

    #[test]
    fn test_builder_empty_servers_error() {
        assert_no_servers_error(NntpProxy::builder(empty_config()).build());
    }

    #[test]
    fn test_backward_compatibility_new() {
        // NntpProxy::new() goes through the builder internally and must keep working.
        let proxy = NntpProxy::new(create_test_config(), RoutingMode::Standard)
            .expect("Failed to create proxy with new()");

        assert_eq!(proxy.servers().len(), 3);
        assert_eq!(proxy.router.backend_count(), 3);
    }
}