nntp_proxy/
proxy.rs

1//! NNTP Proxy implementation
2//!
3//! This module contains the main `NntpProxy` struct which orchestrates
4//! connection handling, routing, and resource management.
5
6use anyhow::{Context, Result};
7use std::net::SocketAddr;
8use std::sync::Arc;
9use tokio::io::AsyncWriteExt;
10use tokio::net::TcpStream;
11use tracing::{debug, error, info, warn};
12
13use crate::config::{Config, RoutingMode, ServerConfig};
14use crate::constants::buffer::{POOL, POOL_COUNT};
15use crate::network::{ConnectionOptimizer, NetworkOptimizer, TcpOptimizer};
16use crate::pool::{BufferPool, ConnectionProvider, DeadpoolConnectionProvider, prewarm_pools};
17use crate::protocol::BACKEND_UNAVAILABLE;
18use crate::router;
19use crate::session::ClientSession;
20use crate::types::{self, BufferSize};
21
22/// Builder for constructing an `NntpProxy` with optional configuration overrides
23///
24/// # Examples
25///
26/// Basic usage with defaults:
27/// ```no_run
28/// # use nntp_proxy::{NntpProxyBuilder, Config, RoutingMode};
29/// # use nntp_proxy::config::load_config;
30/// # fn main() -> anyhow::Result<()> {
31/// let config = load_config("config.toml")?;
32/// let proxy = NntpProxyBuilder::new(config)
33///     .with_routing_mode(RoutingMode::Hybrid)
34///     .build()?;
35/// # Ok(())
36/// # }
37/// ```
38///
39/// With custom buffer pool size:
40/// ```no_run
41/// # use nntp_proxy::{NntpProxyBuilder, Config, RoutingMode};
42/// # use nntp_proxy::config::load_config;
43/// # fn main() -> anyhow::Result<()> {
44/// let config = load_config("config.toml")?;
45/// let proxy = NntpProxyBuilder::new(config)
46///     .with_routing_mode(RoutingMode::PerCommand)
47///     .with_buffer_pool_size(512 * 1024)  // 512KB buffers
48///     .with_buffer_pool_count(64)         // 64 buffers
49///     .build()?;
50/// # Ok(())
51/// # }
52/// ```
#[derive(Debug)]
pub struct NntpProxyBuilder {
    /// Base configuration; `build` fails if it contains no servers
    config: Config,
    /// Routing strategy; `new` initializes this to `RoutingMode::Standard`
    routing_mode: RoutingMode,
    /// Per-buffer size override; `None` falls back to the `POOL` default
    buffer_size: Option<usize>,
    /// Buffer count override; `None` falls back to the `POOL_COUNT` default
    buffer_count: Option<usize>,
}
60
61impl NntpProxyBuilder {
62    /// Create a new builder with the given configuration
63    ///
64    /// The routing mode defaults to `Standard` (1:1) mode.
65    #[must_use]
66    pub fn new(config: Config) -> Self {
67        Self {
68            config,
69            routing_mode: RoutingMode::Standard,
70            buffer_size: None,
71            buffer_count: None,
72        }
73    }
74
75    /// Set the routing mode
76    ///
77    /// Available modes:
78    /// - `Standard`: 1:1 client-to-backend mapping (default)
79    /// - `PerCommand`: Each command routes to a different backend
80    /// - `Hybrid`: Starts in per-command mode, switches to stateful when needed
81    #[must_use]
82    pub fn with_routing_mode(mut self, mode: RoutingMode) -> Self {
83        self.routing_mode = mode;
84        self
85    }
86
87    /// Override the default buffer pool size (256KB)
88    ///
89    /// This affects the size of each buffer in the pool. Larger buffers
90    /// can improve throughput for large article transfers but use more memory.
91    #[must_use]
92    pub fn with_buffer_pool_size(mut self, size: usize) -> Self {
93        self.buffer_size = Some(size);
94        self
95    }
96
97    /// Override the default buffer pool count (32)
98    ///
99    /// This affects how many buffers are pre-allocated. Should roughly match
100    /// the expected number of concurrent connections.
101    #[must_use]
102    pub fn with_buffer_pool_count(mut self, count: usize) -> Self {
103        self.buffer_count = Some(count);
104        self
105    }
106
107    /// Build the `NntpProxy` instance
108    ///
109    /// # Errors
110    ///
111    /// Returns an error if:
112    /// - No servers are configured
113    /// - Connection providers cannot be created
114    /// - Buffer size is zero
115    pub fn build(self) -> Result<NntpProxy> {
116        if self.config.servers.is_empty() {
117            anyhow::bail!("No servers configured in configuration");
118        }
119
120        // Use provided values or defaults
121        let buffer_size = self.buffer_size.unwrap_or(POOL);
122        let buffer_count = self.buffer_count.unwrap_or(POOL_COUNT);
123
124        // Create deadpool connection providers for each server
125        let connection_providers: Result<Vec<DeadpoolConnectionProvider>> = self
126            .config
127            .servers
128            .iter()
129            .map(|server| {
130                info!(
131                    "Configuring deadpool connection provider for '{}'",
132                    server.name
133                );
134                DeadpoolConnectionProvider::from_server_config(server)
135            })
136            .collect();
137
138        let connection_providers = connection_providers?;
139
140        let buffer_pool = BufferPool::new(
141            BufferSize::new(buffer_size)
142                .ok_or_else(|| anyhow::anyhow!("Buffer size must be non-zero"))?,
143            buffer_count,
144        );
145
146        let servers = Arc::new(self.config.servers);
147
148        // Create backend selector and add all backends
149        let router = Arc::new({
150            use types::BackendId;
151            connection_providers.iter().enumerate().fold(
152                router::BackendSelector::new(),
153                |mut r, (idx, provider)| {
154                    let backend_id = BackendId::from_index(idx);
155                    r.add_backend(
156                        backend_id,
157                        servers[idx].name.as_str().to_string(),
158                        provider.clone(),
159                    );
160                    r
161                },
162            )
163        });
164
165        Ok(NntpProxy {
166            servers,
167            router,
168            connection_providers,
169            buffer_pool,
170            routing_mode: self.routing_mode,
171        })
172    }
173}
174
#[derive(Debug, Clone)]
pub struct NntpProxy {
    /// Configured backend servers, shared cheaply across connection handlers
    servers: Arc<Vec<ServerConfig>>,
    /// Backend selector for round-robin load balancing
    router: Arc<router::BackendSelector>,
    /// Connection providers per server - easily swappable implementation
    connection_providers: Vec<DeadpoolConnectionProvider>,
    /// Buffer pool for I/O operations
    buffer_pool: BufferPool,
    /// Routing mode (Standard, PerCommand, or Hybrid)
    routing_mode: RoutingMode,
}
187
188impl NntpProxy {
189    /// Create a new `NntpProxy` with the given configuration and routing mode
190    ///
191    /// This is a convenience method that uses the builder internally.
192    /// For more control over configuration, use [`NntpProxy::builder`].
193    ///
194    /// # Examples
195    ///
196    /// ```no_run
197    /// # use nntp_proxy::{NntpProxy, Config, RoutingMode};
198    /// # use nntp_proxy::config::load_config;
199    /// # fn main() -> anyhow::Result<()> {
200    /// let config = load_config("config.toml")?;
201    /// let proxy = NntpProxy::new(config, RoutingMode::Hybrid)?;
202    /// # Ok(())
203    /// # }
204    /// ```
205    pub fn new(config: Config, routing_mode: RoutingMode) -> Result<Self> {
206        NntpProxyBuilder::new(config)
207            .with_routing_mode(routing_mode)
208            .build()
209    }
210
211    /// Create a builder for more fine-grained control over proxy configuration
212    ///
213    /// # Examples
214    ///
215    /// ```no_run
216    /// # use nntp_proxy::{NntpProxy, Config, RoutingMode};
217    /// # use nntp_proxy::config::load_config;
218    /// # fn main() -> anyhow::Result<()> {
219    /// let config = load_config("config.toml")?;
220    /// let proxy = NntpProxy::builder(config)
221    ///     .with_routing_mode(RoutingMode::Hybrid)
222    ///     .with_buffer_pool_size(512 * 1024)
223    ///     .build()?;
224    /// # Ok(())
225    /// # }
226    /// ```
227    #[must_use]
228    pub fn builder(config: Config) -> NntpProxyBuilder {
229        NntpProxyBuilder::new(config)
230    }
231
232    /// Prewarm all connection pools before accepting clients
233    /// Creates all connections concurrently and returns when ready
234    pub async fn prewarm_connections(&self) -> Result<()> {
235        prewarm_pools(&self.connection_providers, &self.servers).await
236    }
237
238    /// Gracefully shutdown all connection pools
239    pub async fn graceful_shutdown(&self) {
240        info!("Initiating graceful shutdown of all connection pools...");
241
242        for provider in &self.connection_providers {
243            provider.graceful_shutdown().await;
244        }
245
246        info!("All connection pools have been shut down gracefully");
247    }
248
249    /// Get the list of servers
250    #[inline]
251    pub fn servers(&self) -> &[ServerConfig] {
252        &self.servers
253    }
254
255    /// Get the router
256    #[inline]
257    pub fn router(&self) -> &Arc<router::BackendSelector> {
258        &self.router
259    }
260
261    /// Get the connection providers
262    #[inline]
263    pub fn connection_providers(&self) -> &[DeadpoolConnectionProvider] {
264        &self.connection_providers
265    }
266
267    /// Get the buffer pool
268    #[inline]
269    pub fn buffer_pool(&self) -> &BufferPool {
270        &self.buffer_pool
271    }
272
273    /// Common setup for client connections (greeting only, prewarming done at startup)
274    async fn setup_client_connection(
275        &self,
276        client_stream: &mut TcpStream,
277        client_addr: SocketAddr,
278    ) -> Result<()> {
279        // Send proxy greeting
280        crate::protocol::send_proxy_greeting(client_stream, client_addr).await
281    }
282
283    pub async fn handle_client(
284        &self,
285        mut client_stream: TcpStream,
286        client_addr: SocketAddr,
287    ) -> Result<()> {
288        debug!("New client connection from {}", client_addr);
289
290        // Use a dummy ClientId and command for routing (synchronous 1:1 mapping)
291        use types::ClientId;
292        let client_id = ClientId::new();
293
294        // Select backend using router's round-robin
295        let backend_id = self.router.route_command_sync(client_id, "")?;
296        let server_idx = backend_id.as_index();
297        let server = &self.servers[server_idx];
298
299        info!(
300            "Routing client {} to backend {:?} ({}:{})",
301            client_addr, backend_id, server.host, server.port
302        );
303
304        // Setup connection (prewarm and greeting)
305        self.setup_client_connection(&mut client_stream, client_addr)
306            .await?;
307
308        // Get pooled backend connection
309        let pool_status = self.connection_providers[server_idx].status();
310        debug!(
311            "Pool status for {}: {}/{} available, {} created",
312            server.name, pool_status.available, pool_status.max_size, pool_status.created
313        );
314
315        let mut backend_conn = match self.connection_providers[server_idx]
316            .get_pooled_connection()
317            .await
318        {
319            Ok(conn) => {
320                debug!("Got pooled connection for {}", server.name);
321                conn
322            }
323            Err(e) => {
324                error!(
325                    "Failed to get pooled connection for {} (client {}): {}",
326                    server.name, client_addr, e
327                );
328                let _ = client_stream.write_all(BACKEND_UNAVAILABLE).await;
329                return Err(anyhow::anyhow!(
330                    "Failed to get pooled connection for backend '{}' (client {}): {}",
331                    server.name,
332                    client_addr,
333                    e
334                ));
335            }
336        };
337
338        // Apply socket optimizations for high-throughput
339        let client_optimizer = TcpOptimizer::new(&client_stream);
340        if let Err(e) = client_optimizer.optimize() {
341            debug!("Failed to optimize client socket: {}", e);
342        }
343
344        let backend_optimizer = ConnectionOptimizer::new(&backend_conn);
345        if let Err(e) = backend_optimizer.optimize() {
346            debug!("Failed to optimize backend socket: {}", e);
347        }
348
349        // Create session and handle connection
350        let session = ClientSession::new(client_addr, self.buffer_pool.clone());
351        debug!("Starting session for client {}", client_addr);
352
353        let copy_result = session
354            .handle_with_pooled_backend(client_stream, &mut *backend_conn)
355            .await;
356
357        debug!("Session completed for client {}", client_addr);
358
359        // Complete the routing (decrement pending count)
360        self.router.complete_command_sync(backend_id);
361
362        // Log session results and handle backend connection errors
363        match copy_result {
364            Ok((client_to_backend_bytes, backend_to_client_bytes)) => {
365                info!(
366                    "Connection closed for client {}: {} bytes sent, {} bytes received",
367                    client_addr, client_to_backend_bytes, backend_to_client_bytes
368                );
369            }
370            Err(e) => {
371                // Check if this is a backend I/O error - if so, remove connection from pool
372                if crate::pool::is_connection_error(&e) {
373                    warn!(
374                        "Backend connection error for client {}: {} - removing connection from pool",
375                        client_addr, e
376                    );
377                    crate::pool::remove_from_pool(backend_conn);
378                    return Err(e);
379                }
380                warn!("Session error for client {}: {}", client_addr, e);
381            }
382        }
383
384        debug!("Connection returned to pool for {}", server.name);
385        Ok(())
386    }
387
388    /// Handle client connection using per-command routing mode
389    ///
390    /// This creates a session with the router, allowing commands from this client
391    /// to be routed to different backends based on load balancing.
392    pub async fn handle_client_per_command_routing(
393        &self,
394        client_stream: TcpStream,
395        client_addr: SocketAddr,
396    ) -> Result<()> {
397        debug!(
398            "New per-command routing client connection from {}",
399            client_addr
400        );
401
402        // Enable TCP_NODELAY for low latency
403        if let Err(e) = client_stream.set_nodelay(true) {
404            debug!("Failed to set TCP_NODELAY for {}: {}", client_addr, e);
405        }
406
407        // NOTE: Don't call setup_client_connection here because handle_per_command_routing
408        // sends its own greeting ("200 NNTP Proxy Ready (Per-Command Routing)")
409        // Calling setup_client_connection would send a duplicate greeting
410
411        // Create session with router for per-command routing
412        let session = ClientSession::new_with_router(
413            client_addr,
414            self.buffer_pool.clone(),
415            self.router.clone(),
416            self.routing_mode,
417        );
418
419        let session_id = crate::formatting::short_id(session.client_id().as_uuid());
420
421        info!(
422            "Client {} [{}] connected in per-command routing mode",
423            client_addr, session_id
424        );
425
426        // Handle the session with per-command routing
427        let result = session
428            .handle_per_command_routing(client_stream)
429            .await
430            .with_context(|| {
431                format!(
432                    "Per-command routing session failed for {} [{}]",
433                    client_addr, session_id
434                )
435            });
436
437        // Log session results
438        match result {
439            Ok((client_to_backend, backend_to_client)) => {
440                info!(
441                    "Session closed {} [{}] ↑{} ↓{}",
442                    client_addr,
443                    session_id,
444                    crate::formatting::format_bytes(client_to_backend),
445                    crate::formatting::format_bytes(backend_to_client)
446                );
447            }
448            Err(e) => {
449                // Check if this is a broken pipe error (normal for quick disconnections like SABnzbd tests)
450                let is_broken_pipe = e.downcast_ref::<std::io::Error>().is_some_and(|io_err| {
451                    matches!(
452                        io_err.kind(),
453                        std::io::ErrorKind::BrokenPipe | std::io::ErrorKind::ConnectionReset
454                    )
455                });
456
457                if is_broken_pipe {
458                    debug!(
459                        "Client {} [{}] disconnected: {} (normal for test connections)",
460                        client_addr, session_id, e
461                    );
462                } else {
463                    warn!("Session error {} [{}]: {}", client_addr, session_id, e);
464                }
465            }
466        }
467
468        debug!(
469            "Per-command routing connection closed for {} (ID: {})",
470            client_addr,
471            session.client_id()
472        );
473        Ok(())
474    }
475}
476
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::{health_check_max_per_cycle, health_check_pool_timeout};
    use crate::types::{HostName, MaxConnections, Port, ServerName};

    /// Build one test `ServerConfig`; factors out the 14-field literal that
    /// was previously repeated verbatim for each server.
    fn make_test_server(host: &str, name: &str, max_connections: MaxConnections) -> ServerConfig {
        ServerConfig {
            host: HostName::new(host.to_string()).unwrap(),
            port: Port::new(119).unwrap(),
            name: ServerName::new(name.to_string()).unwrap(),
            username: None,
            password: None,
            max_connections,
            use_tls: false,
            tls_verify_cert: true,
            tls_cert_path: None,
            connection_keepalive: None,
            health_check_max_per_cycle: health_check_max_per_cycle(),
            health_check_pool_timeout: health_check_pool_timeout(),
        }
    }

    /// Three-backend config used by most tests below.
    fn create_test_config() -> Config {
        Config {
            servers: vec![
                make_test_server(
                    "server1.example.com",
                    "Test Server 1",
                    MaxConnections::new(5).unwrap(),
                ),
                make_test_server(
                    "server2.example.com",
                    "Test Server 2",
                    MaxConnections::new(8).unwrap(),
                ),
                make_test_server(
                    "server3.example.com",
                    "Test Server 3",
                    MaxConnections::new(12).unwrap(),
                ),
            ],
            ..Default::default()
        }
    }

    #[test]
    fn test_proxy_creation_with_servers() {
        let config = create_test_config();
        let proxy = NntpProxy::new(config, RoutingMode::Standard).expect("Failed to create proxy");

        assert_eq!(proxy.servers().len(), 3);
        assert_eq!(proxy.servers()[0].name.as_str(), "Test Server 1");
    }

    #[test]
    fn test_proxy_creation_with_empty_servers() {
        let config = Config {
            servers: vec![],
            ..Default::default()
        };
        let result = NntpProxy::new(config, RoutingMode::Standard);

        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("No servers configured")
        );
    }

    #[test]
    fn test_proxy_has_router() {
        let config = create_test_config();
        let proxy = NntpProxy::new(config, RoutingMode::Standard).expect("Failed to create proxy");

        // Proxy should have a router with backends
        assert_eq!(proxy.router.backend_count(), 3);
    }

    #[test]
    fn test_builder_basic_usage() {
        let config = create_test_config();
        let proxy = NntpProxy::builder(config)
            .build()
            .expect("Failed to build proxy");

        assert_eq!(proxy.servers().len(), 3);
        assert_eq!(proxy.router.backend_count(), 3);
    }

    #[test]
    fn test_builder_with_routing_mode() {
        let config = create_test_config();
        let proxy = NntpProxy::builder(config)
            .with_routing_mode(RoutingMode::PerCommand)
            .build()
            .expect("Failed to build proxy");

        assert_eq!(proxy.servers().len(), 3);
    }

    #[test]
    fn test_builder_with_custom_buffer_pool() {
        let config = create_test_config();
        let proxy = NntpProxy::builder(config)
            .with_buffer_pool_size(512 * 1024)
            .with_buffer_pool_count(64)
            .build()
            .expect("Failed to build proxy");

        assert_eq!(proxy.servers().len(), 3);
        // Pool size and count are used internally but not exposed for verification
    }

    #[test]
    fn test_builder_with_all_options() {
        let config = create_test_config();
        let proxy = NntpProxy::builder(config)
            .with_routing_mode(RoutingMode::Hybrid)
            .with_buffer_pool_size(1024 * 1024)
            .with_buffer_pool_count(16)
            .build()
            .expect("Failed to build proxy");

        assert_eq!(proxy.servers().len(), 3);
        assert_eq!(proxy.router.backend_count(), 3);
    }

    #[test]
    fn test_builder_empty_servers_error() {
        let config = Config {
            servers: vec![],
            ..Default::default()
        };
        let result = NntpProxy::builder(config).build();

        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("No servers configured")
        );
    }

    #[test]
    fn test_backward_compatibility_new() {
        // Ensure NntpProxy::new() still works (it uses builder internally)
        let config = create_test_config();
        let proxy = NntpProxy::new(config, RoutingMode::Standard)
            .expect("Failed to create proxy with new()");

        assert_eq!(proxy.servers().len(), 3);
        assert_eq!(proxy.router.backend_count(), 3);
    }
}