//! NNTP Proxy implementation
//!
//! This module contains the main `NntpProxy` struct which orchestrates
//! connection handling, routing, and resource management.
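//!
//! # Example
//!
//! A minimal sketch of wiring the proxy into an accept loop. The bind
//! address is illustrative, the crate paths assume this module lives in the
//! `nntp_proxy` crate, and a real `Config` would come from a configuration
//! loader rather than `Default::default()`:
//!
//! ```no_run
//! use std::sync::Arc;
//! use nntp_proxy::{config::Config, proxy::NntpProxy};
//! use tokio::net::TcpListener;
//!
//! #[tokio::main]
//! async fn main() -> anyhow::Result<()> {
//!     let config = Config::default(); // placeholder: load the real config here
//!     let proxy = Arc::new(NntpProxy::new(config)?);
//!
//!     // Establish backend connections before accepting clients
//!     proxy.prewarm_connections().await?;
//!
//!     let listener = TcpListener::bind("127.0.0.1:1190").await?;
//!     loop {
//!         let (stream, addr) = listener.accept().await?;
//!         let proxy = Arc::clone(&proxy);
//!         tokio::spawn(async move {
//!             if let Err(e) = proxy.handle_client(stream, addr).await {
//!                 tracing::warn!("client session ended with error: {}", e);
//!             }
//!         });
//!     }
//! }
//! ```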

use anyhow::Result;
use std::net::SocketAddr;
use std::sync::Arc;
use tokio::io::AsyncWriteExt;
use tokio::net::TcpStream;
use tracing::{debug, error, info, warn};

use crate::config::{Config, ServerConfig};
use crate::constants::buffer::{BUFFER_POOL_SIZE, BUFFER_SIZE};
use crate::constants::stateless_proxy::*;
use crate::network::{ConnectionOptimizer, NetworkOptimizer, TcpOptimizer};
use crate::pool::{BufferPool, ConnectionProvider, DeadpoolConnectionProvider, prewarm_pools};
use crate::router;
use crate::session::ClientSession;
use crate::types;

/// Stateless NNTP proxy that load-balances client sessions across backends.
#[derive(Debug, Clone)]
pub struct NntpProxy {
    /// Configured backend servers
    servers: Arc<Vec<ServerConfig>>,
    /// Backend selector for round-robin load balancing
    router: Arc<router::BackendSelector>,
    /// Connection providers per server - easily swappable implementation
    connection_providers: Vec<DeadpoolConnectionProvider>,
    /// Buffer pool for I/O operations
    buffer_pool: BufferPool,
}

impl NntpProxy {
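    /// Create a proxy from `config`, building one pooled connection provider
    /// per backend and registering each with the round-robin selector.
    ///
    /// Returns an error if no servers are configured. A small sketch (crate
    /// paths assume this crate's layout):
    ///
    /// ```no_run
    /// use nntp_proxy::{config::Config, proxy::NntpProxy};
    ///
    /// let config = Config { servers: vec![], ..Default::default() };
    /// assert!(NntpProxy::new(config).is_err());
    /// ```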
    pub fn new(config: Config) -> Result<Self> {
        if config.servers.is_empty() {
            anyhow::bail!("No servers configured");
        }

        // Create a deadpool connection provider for each server
        let connection_providers: Vec<DeadpoolConnectionProvider> = config
            .servers
            .iter()
            .map(|server| {
                info!(
                    "Configuring deadpool connection provider for '{}'",
                    server.name
                );
                DeadpoolConnectionProvider::from_server_config(server)
            })
            .collect();

        let buffer_pool = BufferPool::new(BUFFER_SIZE, BUFFER_POOL_SIZE);

        let servers = Arc::new(config.servers);

        // Create the backend selector and register every backend with it
        let router = Arc::new({
            use types::BackendId;
            connection_providers.iter().enumerate().fold(
                router::BackendSelector::new(),
                |mut r, (idx, provider)| {
                    let backend_id = BackendId::from_index(idx);
                    r.add_backend(backend_id, servers[idx].name.clone(), provider.clone());
                    r
                },
            )
        });

        Ok(Self {
            servers,
            router,
            connection_providers,
            buffer_pool,
        })
    }

    /// Prewarm all connection pools before accepting clients.
    ///
    /// Creates all connections concurrently and returns once every pool is ready.
    pub async fn prewarm_connections(&self) -> Result<()> {
        prewarm_pools(&self.connection_providers, &self.servers).await
    }

    /// Gracefully shut down all connection pools.
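    ///
    /// Typically driven by a shutdown signal; a sketch (crate paths assume
    /// this crate's layout, and tokio's `signal` feature must be enabled):
    ///
    /// ```no_run
    /// # async fn shutdown(proxy: nntp_proxy::proxy::NntpProxy) -> anyhow::Result<()> {
    /// tokio::signal::ctrl_c().await?;
    /// proxy.graceful_shutdown().await;
    /// # Ok(())
    /// # }
    /// ```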
    pub async fn graceful_shutdown(&self) {
        info!("Initiating graceful shutdown of all connection pools...");

        for provider in &self.connection_providers {
            provider.graceful_shutdown().await;
        }

        info!("All connection pools have been shut down gracefully");
    }

    /// Get the list of servers
    #[inline]
    pub fn servers(&self) -> &[ServerConfig] {
        &self.servers
    }

    /// Get the router
    #[inline]
    pub fn router(&self) -> &Arc<router::BackendSelector> {
        &self.router
    }

    /// Get the connection providers
    #[inline]
    pub fn connection_providers(&self) -> &[DeadpoolConnectionProvider] {
        &self.connection_providers
    }

    /// Get the buffer pool
    #[inline]
    pub fn buffer_pool(&self) -> &BufferPool {
        &self.buffer_pool
    }

    /// Common setup for client connections (greeting only, prewarming done at startup)
    async fn setup_client_connection(
        &self,
        client_stream: &mut TcpStream,
        client_addr: SocketAddr,
    ) -> Result<()> {
        // Send proxy greeting
        crate::protocol::send_proxy_greeting(client_stream, client_addr).await
    }

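    /// Handle a client connection in 1:1 mapping mode.
    ///
    /// The client is pinned to a single round-robin-selected backend for the
    /// entire session; the pooled backend connection is returned to its pool
    /// when the session ends.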
    pub async fn handle_client(
        &self,
        mut client_stream: TcpStream,
        client_addr: SocketAddr,
    ) -> Result<()> {
        debug!("New client connection from {}", client_addr);

        // Use a fresh ClientId and an empty command for routing (synchronous 1:1 mapping)
        use types::ClientId;
        let client_id = ClientId::new();

        // Select a backend using the router's round-robin
        let backend_id = self.router.route_command_sync(client_id, "")?;
        let server_idx = backend_id.as_index();
        let server = &self.servers[server_idx];

        info!(
            "Routing client {} to backend {:?} ({}:{})",
            client_addr, backend_id, server.host, server.port
        );

        // Send the proxy greeting (pools are prewarmed at startup)
        self.setup_client_connection(&mut client_stream, client_addr)
            .await?;

        // Get a pooled backend connection
        let pool_status = self.connection_providers[server_idx].status();
        debug!(
            "Pool status for {}: {}/{} available, {} created",
            server.name, pool_status.available, pool_status.max_size, pool_status.created
        );

        let mut backend_conn = match self.connection_providers[server_idx]
            .get_pooled_connection()
            .await
        {
            Ok(conn) => {
                debug!("Got pooled connection for {}", server.name);
                conn
            }
            Err(e) => {
                error!(
                    "Failed to get pooled connection for {} (client {}): {}",
                    server.name, client_addr, e
                );
                let _ = client_stream.write_all(NNTP_BACKEND_UNAVAILABLE).await;
                return Err(anyhow::anyhow!(
                    "Failed to get pooled connection for backend '{}' (client {}): {}",
                    server.name,
                    client_addr,
                    e
                ));
            }
        };

        // Apply socket optimizations for high throughput
        let client_optimizer = TcpOptimizer::new(&client_stream);
        if let Err(e) = client_optimizer.optimize() {
            debug!("Failed to optimize client socket: {}", e);
        }

        let backend_optimizer = ConnectionOptimizer::new(&backend_conn);
        if let Err(e) = backend_optimizer.optimize() {
            debug!("Failed to optimize backend socket: {}", e);
        }

        // Create a session and proxy bytes between client and backend
        let session = ClientSession::new(client_addr, self.buffer_pool.clone());
        debug!("Starting session for client {}", client_addr);

        let copy_result = session
            .handle_with_pooled_backend(client_stream, &mut *backend_conn)
            .await;

        debug!("Session completed for client {}", client_addr);

        // Complete the routing (decrement the backend's pending count)
        self.router.complete_command_sync(backend_id);

        // Log session results
        match copy_result {
            Ok((client_to_backend_bytes, backend_to_client_bytes)) => {
                info!(
                    "Connection closed for client {}: {} bytes sent, {} bytes received",
                    client_addr, client_to_backend_bytes, backend_to_client_bytes
                );
            }
            Err(e) => {
                warn!("Session error for client {}: {}", client_addr, e);
            }
        }

        debug!("Connection returned to pool for {}", server.name);
        Ok(())
    }
223
224    /// Handle client connection using per-command routing mode
225    ///
226    /// This creates a session with the router, allowing commands from this client
227    /// to be routed to different backends based on load balancing.
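    ///
    /// Choosing between the two modes at the accept site might look like
    /// this (the `per_command` flag is illustrative, not part of this module,
    /// and crate paths assume this crate's layout):
    ///
    /// ```no_run
    /// # async fn accept_one(
    /// #     proxy: std::sync::Arc<nntp_proxy::proxy::NntpProxy>,
    /// #     stream: tokio::net::TcpStream,
    /// #     addr: std::net::SocketAddr,
    /// #     per_command: bool,
    /// # ) -> anyhow::Result<()> {
    /// if per_command {
    ///     proxy.handle_client_per_command_routing(stream, addr).await?;
    /// } else {
    ///     proxy.handle_client(stream, addr).await?;
    /// }
    /// # Ok(())
    /// # }
    /// ```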
    pub async fn handle_client_per_command_routing(
        &self,
        mut client_stream: TcpStream,
        client_addr: SocketAddr,
    ) -> Result<()> {
        debug!(
            "New per-command routing client connection from {}",
            client_addr
        );

        // Enable TCP_NODELAY for low latency
        let _ = client_stream.set_nodelay(true);

        // Send the proxy greeting (pools are prewarmed at startup)
        self.setup_client_connection(&mut client_stream, client_addr)
            .await?;

        // Create a session with the router for per-command routing
        let session = ClientSession::new_with_router(
            client_addr,
            self.buffer_pool.clone(),
            self.router.clone(),
        );

        info!(
            "Client {} (ID: {}) connected in per-command routing mode",
            client_addr,
            session.client_id()
        );

        // Handle the session with per-command routing
        let result = session.handle_per_command_routing(client_stream).await;

        // Log session results
        match result {
            Ok((client_to_backend, backend_to_client)) => {
                info!(
                    "Per-command routing session closed for {} (ID: {}): {} bytes sent, {} bytes received",
                    client_addr,
                    session.client_id(),
                    client_to_backend,
                    backend_to_client
                );
            }
            Err(e) => {
                // A broken pipe is normal for quick disconnections (e.g. SABnzbd test connections)
                let is_broken_pipe = if let Some(io_err) = e.downcast_ref::<std::io::Error>() {
                    matches!(
                        io_err.kind(),
                        std::io::ErrorKind::BrokenPipe | std::io::ErrorKind::ConnectionReset
                    )
                } else {
                    false
                };

                if is_broken_pipe {
                    debug!(
                        "Client {} (ID: {}) disconnected during session: {} - this is normal for test connections",
                        client_addr,
                        session.client_id(),
                        e
                    );
                } else {
                    warn!(
                        "Per-command routing session error for {} (ID: {}): {}",
                        client_addr,
                        session.client_id(),
                        e
                    );
                }

                // For debugging short sessions (e.g. SABnzbd test connections),
                // log additional context; small transfers usually indicate a test scenario
                debug!(
                    "Session error details for {} (ID: {}): error occurred during per-command routing. \
                     This may be a client test connection or an early disconnection. \
                     Check the session debug logs above for command/response details.",
                    client_addr,
                    session.client_id()
                );
            }
        }

        debug!(
            "Per-command routing connection closed for {} (ID: {})",
            client_addr,
            session.client_id()
        );
        Ok(())
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    fn create_test_config() -> Config {
        Config {
            servers: vec![
                ServerConfig {
                    host: "server1.example.com".to_string(),
                    port: 119,
                    name: "Test Server 1".to_string(),
                    username: None,
                    password: None,
                    max_connections: 5,
                    use_tls: false,
                    tls_verify_cert: true,
                    tls_cert_path: None,
                },
                ServerConfig {
                    host: "server2.example.com".to_string(),
                    port: 119,
                    name: "Test Server 2".to_string(),
                    username: None,
                    password: None,
                    max_connections: 8,
                    use_tls: false,
                    tls_verify_cert: true,
                    tls_cert_path: None,
                },
                ServerConfig {
                    host: "server3.example.com".to_string(),
                    port: 119,
                    name: "Test Server 3".to_string(),
                    username: None,
                    password: None,
                    max_connections: 12,
                    use_tls: false,
                    tls_verify_cert: true,
                    tls_cert_path: None,
                },
            ],
            ..Default::default()
        }
    }

    #[test]
    fn test_proxy_creation_with_servers() {
        let config = create_test_config();
        let proxy = Arc::new(NntpProxy::new(config).expect("Failed to create proxy"));

        assert_eq!(proxy.servers().len(), 3);
        assert_eq!(proxy.servers()[0].name, "Test Server 1");
    }

    #[test]
    fn test_proxy_creation_with_empty_servers() {
        let config = Config {
            servers: vec![],
            ..Default::default()
        };
        let result = NntpProxy::new(config);

        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("No servers configured")
        );
    }

    #[test]
    fn test_proxy_has_router() {
        let config = create_test_config();
        let proxy = Arc::new(NntpProxy::new(config).expect("Failed to create proxy"));

        // Proxy should have a router with backends
        assert_eq!(proxy.router.backend_count(), 3);
    }
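
    #[test]
    fn test_proxy_has_connection_providers() {
        // Sketch of an added invariant check: one pooled connection provider
        // should exist per configured server (construction does not connect,
        // as the other tests with fake hostnames demonstrate).
        let config = create_test_config();
        let proxy = NntpProxy::new(config).expect("Failed to create proxy");

        assert_eq!(proxy.connection_providers().len(), 3);
    }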
}