//! nntp_proxy — lib.rs
//!
//! NNTP reverse proxy: round-robin backend selection, lock-free buffer and
//! connection pooling, and zero-copy client<->backend relaying on Linux.

1use anyhow::Result;
2use crossbeam::queue::SegQueue;
3use serde::{Deserialize, Serialize};
4use std::collections::HashMap;
5use std::net::SocketAddr;
6use std::sync::Arc;
7use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
8use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
9use tokio::net::{TcpSocket, TcpStream};
10use tokio::sync::Semaphore;
11use tracing::{debug, error, info, warn};
12
/// Default maximum connections per server
///
/// Serde fallback for `ServerConfig::max_connections` when the field is
/// omitted from the configuration file.
fn default_max_connections() -> u32 {
    // Conservative default cap applied per backend server.
    10
}
17
/// Lock-free buffer pool for reusing large I/O buffers
/// Uses crossbeam's SegQueue for lock-free operations
///
/// Cloning is cheap: clones share the same queue and counter via `Arc`, so a
/// cloned pool hands out and reclaims the same buffers.
#[derive(Debug, Clone)]
pub struct BufferPool {
    // Recycled buffers awaiting reuse (multi-producer / multi-consumer).
    pool: Arc<SegQueue<Vec<u8>>>,
    // Length in bytes every buffer handed out is resized to.
    buffer_size: usize,
    // Soft cap on how many buffers the pool retains (see return_buffer).
    max_pool_size: usize,
    // Approximate count of queued buffers; maintained with relaxed atomics,
    // so it can drift slightly under contention.
    pool_size: Arc<AtomicUsize>,
}
27
28impl BufferPool {
29    /// Create a page-aligned buffer for optimal DMA performance
30    fn create_aligned_buffer(size: usize) -> Vec<u8> {
31        // Align to page boundaries (4KB) for better memory performance
32        let page_size = 4096;
33        let aligned_size = size.div_ceil(page_size) * page_size;
34
35        // Use aligned allocation for better cache performance
36        let mut buffer = Vec::with_capacity(aligned_size);
37        buffer.resize(size, 0);
38        buffer
39    }
40
41    pub fn new(buffer_size: usize, max_pool_size: usize) -> Self {
42        let pool = Arc::new(SegQueue::new());
43        let pool_size = Arc::new(AtomicUsize::new(0));
44
45        // Pre-allocate all buffers at startup to eliminate allocation overhead
46        info!(
47            "Pre-allocating {} buffers of {}KB each ({}MB total)",
48            max_pool_size,
49            buffer_size / 1024,
50            (max_pool_size * buffer_size) / (1024 * 1024)
51        );
52
53        for _ in 0..max_pool_size {
54            let buffer = Self::create_aligned_buffer(buffer_size);
55            pool.push(buffer);
56            pool_size.fetch_add(1, Ordering::Relaxed);
57        }
58
59        info!("Buffer pool pre-allocation complete");
60
61        Self {
62            pool,
63            buffer_size,
64            max_pool_size,
65            pool_size,
66        }
67    }
68
69    /// Get a buffer from the pool or create a new one (lock-free)
70    pub async fn get_buffer(&self) -> Vec<u8> {
71        if let Some(mut buffer) = self.pool.pop() {
72            self.pool_size.fetch_sub(1, Ordering::Relaxed);
73            // Reuse existing buffer, clear it first
74            buffer.clear();
75            buffer.resize(self.buffer_size, 0);
76            buffer
77        } else {
78            // Create new page-aligned buffer for better DMA performance
79            Self::create_aligned_buffer(self.buffer_size)
80        }
81    }
82
83    /// Return a buffer to the pool (lock-free)
84    pub async fn return_buffer(&self, buffer: Vec<u8>) {
85        if buffer.len() == self.buffer_size {
86            let current_size = self.pool_size.load(Ordering::Relaxed);
87            if current_size < self.max_pool_size {
88                self.pool.push(buffer);
89                self.pool_size.fetch_add(1, Ordering::Relaxed);
90            }
91            // If pool is full, just drop the buffer
92        }
93    }
94}
95
/// Top-level proxy configuration, deserialized from the config file.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct Config {
    /// List of backend NNTP servers
    pub servers: Vec<ServerConfig>,
}
101
/// Connection settings for a single backend NNTP server.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ServerConfig {
    /// Hostname or IP address of the backend.
    pub host: String,
    /// TCP port of the backend.
    pub port: u16,
    /// Identifier used in logs and as the per-server semaphore key.
    pub name: String,
    /// Optional AUTHINFO username; omitted from serialized output when absent.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub username: Option<String>,
    /// Optional AUTHINFO password; omitted from serialized output when absent.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub password: Option<String>,
    /// Maximum number of concurrent connections to this server
    #[serde(default = "default_max_connections")]
    pub max_connections: u32,
}
115
/// Pooled connection wrapper
///
/// Carries a backend TCP stream plus the metadata the pool needs when the
/// connection is later returned.
#[derive(Debug)]
pub struct PooledConnection {
    /// The underlying backend stream.
    pub stream: TcpStream,
    /// Name of the server this stream is connected to.
    pub server_name: String,
    /// Whether the AUTHINFO exchange already succeeded on this stream.
    pub authenticated: bool,
}
123
124impl PooledConnection {
125    pub fn new(
126        stream: TcpStream,
127        server_name: String,
128        authenticated: bool,
129    ) -> Self {
130        Self {
131            stream,
132            server_name,
133            authenticated,
134        }
135    }
136
137    pub fn into_stream(self) -> TcpStream {
138        self.stream
139    }
140
141    pub fn is_authenticated(&self) -> bool {
142        self.authenticated
143    }
144
145    pub fn server_name(&self) -> &str {
146        &self.server_name
147    }
148}
149
/// Connection pool for backend servers
///
/// Cloning is cheap: clones share the same queue, counter, and init flag via
/// `Arc`.
#[derive(Debug, Clone)]
pub struct ConnectionPool {
    // Idle, ready-to-reuse backend streams.
    pool: Arc<SegQueue<TcpStream>>,
    // Upper bound used for pre-establishment and the return-path cap.
    max_connections: usize,
    // Approximate number of connections this pool believes are alive;
    // maintained with relaxed atomics (see NOTEs in the impl).
    active_connections: Arc<AtomicUsize>,
    // Flipped once by the first get_connection() call (see initialize_connections).
    initialized: Arc<AtomicBool>,
}
158
impl ConnectionPool {
    /// Create an empty pool; connections are pre-established lazily by the
    /// first call to `get_connection`.
    pub fn new(max_connections: usize) -> Self {
        Self {
            pool: Arc::new(SegQueue::new()),
            max_connections,
            active_connections: Arc::new(AtomicUsize::new(0)),
            initialized: Arc::new(AtomicBool::new(false)),
        }
    }

    /// Pre-establish all connections on first request for maximum performance
    ///
    /// Only one caller wins the CAS and performs the dial-out; all others
    /// skip the branch immediately.
    ///
    /// NOTE(review): `initialized` flips to true *before* any connection is
    /// established and stays true even if every dial fails, so concurrent or
    /// later callers can observe an "initialized" yet empty pool — confirm
    /// this degradation is intended.
    async fn initialize_connections(&self, server: &ServerConfig) -> Result<()> {
        // Use compare_exchange to ensure only one thread initializes
        if self
            .initialized
            .compare_exchange(false, true, Ordering::Acquire, Ordering::Relaxed)
            .is_ok()
        {
            info!(
                "Pre-establishing {} connections to {}",
                self.max_connections, server.name
            );

            // Pre-establish all connections in parallel for faster startup
            let mut tasks = Vec::new();
            for i in 0..self.max_connections {
                let server_addr = format!("{}:{}", server.host, server.port);
                let server_name = server.name.clone();
                let pool = Arc::clone(&self.pool);
                let active_connections = Arc::clone(&self.active_connections);

                // Each dial runs as its own task; results are joined below.
                let task = tokio::spawn(async move {
                    match Self::create_optimized_tcp_stream(&server_addr).await {
                        Ok(stream) => {
                            pool.push(stream);
                            active_connections.fetch_add(1, Ordering::Relaxed);
                            debug!("Pre-established connection {} to {}", i + 1, server_name);
                            Ok(())
                        }
                        Err(e) => {
                            warn!(
                                "Failed to pre-establish connection {} to {}: {}",
                                i + 1,
                                server_name,
                                e
                            );
                            Err(e)
                        }
                    }
                });
                tasks.push(task);
            }

            // Wait for all connections to be established
            // (per-connection failures were already logged; join errors ignored)
            for task in tasks {
                let _ = task.await;
            }

            let established = self.active_connections.load(Ordering::Relaxed);
            info!(
                "Successfully pre-established {}/{} connections to {} in parallel",
                established, self.max_connections, server.name
            );
        }
        Ok(())
    }

    /// Get a connection from the pool or create a new one
    ///
    /// A popped connection is health-checked with a non-blocking read: an
    /// idle connection should have nothing pending, so `WouldBlock` is the
    /// healthy outcome. `_proxy` is currently unused.
    ///
    /// NOTE(review): only a single pooled entry is probed — if it is stale we
    /// dial fresh instead of trying the next queued stream — and the
    /// fresh-dial path below never increments `active_connections`; confirm
    /// the counter's intended semantics.
    pub async fn get_connection(
        &self,
        server: &ServerConfig,
        _proxy: &NntpProxy,
    ) -> Result<PooledConnection> {
        // Pre-establish all connections on first request
        if !self.initialized.load(Ordering::Acquire) {
            self.initialize_connections(server).await?;
        }

        // Try to get a connection from the pool
        if let Some(stream) = self.pool.pop() {
            // Test if the connection is still alive by trying a non-blocking read
            let mut test_buf = [0u8; 1];
            match stream.try_read(&mut test_buf) {
                Ok(0) => {
                    // Connection was closed by server, decrease count and create new one
                    self.active_connections.fetch_sub(1, Ordering::Relaxed);
                    info!(
                        "Pooled connection to {} was closed, creating new one",
                        server.name
                    );
                }
                Ok(_) => {
                    // Got unexpected data, connection might be in use, decrease count and create new one
                    // (the probed byte is discarded along with the stream)
                    self.active_connections.fetch_sub(1, Ordering::Relaxed);
                    info!(
                        "Pooled connection to {} has unexpected data, creating new one",
                        server.name
                    );
                }
                Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => {
                    // Connection is alive and ready
                    info!("Reusing pooled connection to {}", server.name);
                    return Ok(PooledConnection {
                        stream,
                        server_name: server.name.clone(),
                        authenticated: false, // Reset authentication state for safety
                    });
                }
                Err(_) => {
                    // Connection error, decrease count and create new one
                    self.active_connections.fetch_sub(1, Ordering::Relaxed);
                    info!(
                        "Pooled connection to {} has error, creating new one",
                        server.name
                    );
                }
            }
        }

        // Create new connection - don't authenticate here, let the caller handle it
        info!("Creating new connection to {} for pooling", server.name);
        let backend_addr = format!("{}:{}", server.host, server.port);
        let stream = Self::create_optimized_tcp_stream(&backend_addr).await?;

        // Return unauthenticated connection - authentication will be handled by caller
        let pooled_conn = PooledConnection::new(
            stream,
            server.name.clone(),
            false,
        );
        Ok(pooled_conn)
    }

    /// Create an optimized TCP stream with performance tuning
    ///
    /// Resolves `addr` (only the first resolved candidate is tried), builds a
    /// `TcpSocket`, applies latency/throughput socket options (Linux-specific
    /// tuning under `cfg`), then connects.
    async fn create_optimized_tcp_stream(addr: &str) -> Result<TcpStream, std::io::Error> {
        use std::net::{SocketAddr, ToSocketAddrs};

        // Parse the address
        let socket_addr: SocketAddr = addr.to_socket_addrs()?.next().ok_or_else(|| {
            std::io::Error::new(std::io::ErrorKind::InvalidInput, "Invalid address")
        })?;

        // Create socket with optimizations
        let socket = if socket_addr.is_ipv4() {
            TcpSocket::new_v4()?
        } else {
            TcpSocket::new_v6()?
        };

        // Apply TCP optimizations
        socket.set_nodelay(true)?; // Disable Nagle's algorithm for low latency

        // Set socket buffer sizes for high throughput
        #[cfg(target_os = "linux")]
        {
            use std::os::unix::io::AsRawFd;
            let fd = socket.as_raw_fd();

            // Set much larger socket buffers for high-throughput large file transfers (2MB each)
            let buffer_size = 2 * 1024 * 1024i32; // 2MB instead of 512KB
            // SAFETY: `fd` is a valid open socket owned by `socket` for the
            // whole block, and every option pointer/length pair passed below
            // refers to a live local variable of the stated size.
            // NOTE(review): all setsockopt return values except the BBR probe
            // are ignored — best-effort tuning, failures are silent.
            unsafe {
                libc::setsockopt(
                    fd,
                    libc::SOL_SOCKET,
                    libc::SO_RCVBUF,
                    &buffer_size as *const i32 as *const libc::c_void,
                    std::mem::size_of::<i32>() as u32,
                );
                libc::setsockopt(
                    fd,
                    libc::SOL_SOCKET,
                    libc::SO_SNDBUF,
                    &buffer_size as *const i32 as *const libc::c_void,
                    std::mem::size_of::<i32>() as u32,
                );

                // Enable TCP keep-alive for connection reuse
                let keepalive = 1i32;
                libc::setsockopt(
                    fd,
                    libc::SOL_SOCKET,
                    libc::SO_KEEPALIVE,
                    &keepalive as *const i32 as *const libc::c_void,
                    std::mem::size_of::<i32>() as u32,
                );

                // Set aggressive keep-alive timing for high-performance scenarios
                let keepalive_time = 60i32; // Start probes after 60 seconds
                let keepalive_interval = 10i32; // Probe every 10 seconds
                let keepalive_probes = 3i32; // 3 failed probes before considering dead

                libc::setsockopt(
                    fd,
                    libc::IPPROTO_TCP,
                    libc::TCP_KEEPIDLE,
                    &keepalive_time as *const i32 as *const libc::c_void,
                    std::mem::size_of::<i32>() as u32,
                );
                libc::setsockopt(
                    fd,
                    libc::IPPROTO_TCP,
                    libc::TCP_KEEPINTVL,
                    &keepalive_interval as *const i32 as *const libc::c_void,
                    std::mem::size_of::<i32>() as u32,
                );
                libc::setsockopt(
                    fd,
                    libc::IPPROTO_TCP,
                    libc::TCP_KEEPCNT,
                    &keepalive_probes as *const i32 as *const libc::c_void,
                    std::mem::size_of::<i32>() as u32,
                );

                // Enable TCP_CORK for better packet batching at high speeds
                // NOTE(review): the socket is never un-corked, and corking
                // interacts with the set_nodelay(true) above — verify the
                // kernel is not holding back partial frames unintentionally.
                let cork_flag = 1i32;
                libc::setsockopt(
                    fd,
                    libc::IPPROTO_TCP,
                    libc::TCP_CORK,
                    &cork_flag as *const i32 as *const libc::c_void,
                    std::mem::size_of::<i32>() as u32,
                );

                // Advanced TCP congestion control optimizations for large file transfers
                // Set TCP congestion control algorithm to BBR if available (best for high BDP)
                let bbr_name = b"bbr\0";
                let bbr_result = libc::setsockopt(
                    fd,
                    libc::IPPROTO_TCP,
                    libc::TCP_CONGESTION,
                    bbr_name.as_ptr() as *const libc::c_void,
                    bbr_name.len() as u32 - 1, // option length excludes the NUL
                );

                // If BBR fails, try cubic which is also excellent for large transfers
                if bbr_result != 0 {
                    let cubic_name = b"cubic\0";
                    libc::setsockopt(
                        fd,
                        libc::IPPROTO_TCP,
                        libc::TCP_CONGESTION,
                        cubic_name.as_ptr() as *const libc::c_void,
                        cubic_name.len() as u32 - 1,
                    );
                }

                // Enable TCP Fast Open for faster connection establishment
                // NOTE(review): TCP_FASTOPEN is the *listener* option; for an
                // outgoing connect Linux expects TCP_FASTOPEN_CONNECT — verify
                // this has any effect on a client socket.
                let tcp_fastopen = 1i32; // Enable client-side Fast Open
                libc::setsockopt(
                    fd,
                    libc::IPPROTO_TCP,
                    libc::TCP_FASTOPEN,
                    &tcp_fastopen as *const i32 as *const libc::c_void,
                    std::mem::size_of::<i32>() as u32,
                );

                // Enable socket reuse for better connection distribution
                // NOTE(review): SO_REUSEADDR/SO_REUSEPORT primarily matter for
                // bound listeners; this outbound socket never calls bind(), so
                // their effect here is questionable — confirm they are wanted.
                let reuse_addr = 1i32;
                libc::setsockopt(
                    fd,
                    libc::SOL_SOCKET,
                    libc::SO_REUSEADDR,
                    &reuse_addr as *const i32 as *const libc::c_void,
                    std::mem::size_of::<i32>() as u32,
                );

                // Enable port reuse for better performance
                let reuse_port = 1i32;
                libc::setsockopt(
                    fd,
                    libc::SOL_SOCKET,
                    libc::SO_REUSEPORT,
                    &reuse_port as *const i32 as *const libc::c_void,
                    std::mem::size_of::<i32>() as u32,
                );
            }
        }

        // Connect with the optimized socket
        socket.connect(socket_addr).await
    }

    /// Return a connection to the pool
    ///
    /// The stream is health-checked the same way as in `get_connection`:
    /// `WouldBlock` on a non-blocking read means idle and healthy.
    pub async fn return_connection(&self, conn: PooledConnection) {
        if self.active_connections.load(Ordering::Relaxed) >= self.max_connections {
            info!("Pool is full, closing connection to {}", conn.server_name);
            // NOTE(review): unlike the dead-connection paths below, this
            // early drop does not decrement `active_connections` — confirm
            // the counter's intended meaning.
            return; // Pool is full, just drop the connection
        }

        // Test if the connection is still alive before returning to pool
        let mut test_buf = [0u8; 1];
        match conn.stream.try_read(&mut test_buf) {
            Ok(0) => {
                // Connection was closed by server
                info!(
                    "Connection to {} was closed by server, not returning to pool",
                    conn.server_name
                );
                self.active_connections.fetch_sub(1, Ordering::Relaxed);
            }
            Ok(_) => {
                // Got unexpected data, connection might be in use
                info!(
                    "Connection to {} has unexpected data, not returning to pool",
                    conn.server_name
                );
                self.active_connections.fetch_sub(1, Ordering::Relaxed);
            }
            Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => {
                // Connection is alive and ready
                info!("Returning connection to {} to pool", conn.server_name);
                self.pool.push(conn.stream);
                // Don't decrement active_connections here since we're keeping the connection
            }
            Err(_) => {
                // Connection error
                info!(
                    "Connection to {} has error, not returning to pool",
                    conn.server_name
                );
                self.active_connections.fetch_sub(1, Ordering::Relaxed);
            }
        }
    }
}
484
/// Round-robin NNTP proxy front end.
///
/// Cheap to clone: all mutable state lives behind `Arc`s, so clones observe
/// the same counters and pools.
#[derive(Clone, Debug)]
pub struct NntpProxy {
    // Configured backend servers, in round-robin order.
    servers: Vec<ServerConfig>,
    // Monotonic ticket counter driving round-robin selection.
    current_index: Arc<AtomicUsize>,
    /// Connection semaphores per server (server_name -> semaphore)
    connection_semaphores: Arc<HashMap<String, Arc<Semaphore>>>,
    /// Connection pool for backend connections
    connection_pool: ConnectionPool,
    /// Buffer pool for I/O operations
    buffer_pool: BufferPool,
}
496
497impl NntpProxy {
498    pub fn new(config: Config) -> Result<Self> {
499        if config.servers.is_empty() {
500            anyhow::bail!("No servers configured");
501        }
502
503        // Create connection semaphores for each server
504        let mut connection_semaphores = HashMap::new();
505        for server in &config.servers {
506            let semaphore = Arc::new(Semaphore::new(server.max_connections as usize));
507            connection_semaphores.insert(server.name.clone(), semaphore);
508            info!(
509                "Server '{}' configured with max {} connections",
510                server.name, server.max_connections
511            );
512        }
513
514        Ok(Self {
515            servers: config.servers,
516            current_index: Arc::new(AtomicUsize::new(0)),
517            connection_semaphores: Arc::new(connection_semaphores),
518            connection_pool: ConnectionPool::new(32), // Increased to 32 for better small file concurrency
519            buffer_pool: BufferPool::new(256 * 1024, 32), // 256KB buffers better for 100MB files with more available
520        })
521    }
522
523    /// Pre-warm connections to all servers for optimal small file performance
524    pub async fn prewarm_connections(&self) -> Result<()> {
525        info!("Pre-warming connections to all backend servers...");
526        for server in &self.servers {
527            // Pre-establish connections without authentication to avoid complexity
528            for i in 0..4 {
529                // Back to 4 connections per server for stability
530                match self.connection_pool.get_connection(server, self).await {
531                    Ok(conn) => {
532                        info!("Pre-warmed connection {}/4 to {}", i + 1, server.name);
533                        self.connection_pool.return_connection(conn).await;
534                    }
535                    Err(e) => {
536                        warn!("Failed to pre-warm connection to {}: {}", server.name, e);
537                    }
538                }
539            }
540        }
541        info!("Connection pre-warming complete");
542        Ok(())
543    }
544
545    /// Get the next server using round-robin
546    pub fn next_server(&self) -> &ServerConfig {
547        let index = self.current_index.fetch_add(1, Ordering::Relaxed);
548        &self.servers[index % self.servers.len()]
549    }
550
551    /// Get the current server index (for testing)
552    #[cfg(test)]
553    pub fn current_index(&self) -> usize {
554        self.current_index.load(Ordering::Relaxed) % self.servers.len()
555    }
556
    /// Reset the server index (for testing)
    ///
    /// After this call, the next `next_server()` returns the first
    /// configured server again.
    #[cfg(test)]
    pub fn reset_index(&self) {
        self.current_index.store(0, Ordering::Relaxed);
    }
562
563    /// Get the list of servers
564    pub fn servers(&self) -> &[ServerConfig] {
565        &self.servers
566    }
567
    /// Proxy one client session end-to-end.
    ///
    /// Flow: pick a backend (round-robin) -> acquire that server's
    /// connection permit -> obtain a backend stream (pooled or fresh) ->
    /// authenticate if credentials are configured -> send a synthetic
    /// greeting to the client -> relay bytes both ways until either side
    /// closes -> return the backend stream to the pool.
    ///
    /// Returns an error when the server is at its connection cap, the
    /// backend is unreachable, or authentication fails; the client gets a
    /// best-effort NNTP error line in each case.
    pub async fn handle_client(
        &self,
        mut client_stream: TcpStream,
        client_addr: SocketAddr,
    ) -> Result<()> {
        info!("New client connection from {}", client_addr);

        // Get the next backend server
        let server = self.next_server();
        info!(
            "Routing client {} to server {}:{}",
            client_addr, server.host, server.port
        );

        // Acquire connection permit for this server
        // (new() creates a semaphore for every configured server, so the
        // lookup cannot fail — hence the unwrap)
        let semaphore = self.connection_semaphores.get(&server.name).unwrap();
        let _permit = match semaphore.try_acquire() {
            Ok(permit) => {
                info!(
                    "Acquired connection permit for server '{}' ({} remaining)",
                    server.name,
                    semaphore.available_permits()
                );
                permit
            }
            Err(_) => {
                // No permit available: reject immediately rather than queueing.
                warn!(
                    "Server '{}' has reached max connections ({}), rejecting client",
                    server.name, server.max_connections
                );
                let _ = client_stream
                    .write_all(b"400 Server temporarily unavailable - too many connections\r\n")
                    .await;
                return Err(anyhow::anyhow!("Server {} at max connections", server.name));
            }
        };

        // Try to get a pooled connection or create a new one
        let backend_addr = format!("{}:{}", server.host, server.port);

        // Try to get a pooled connection first
        let (mut backend_stream, is_pooled, server_name, pooled_authenticated) =
            match self.connection_pool.get_connection(server, self).await {
                Ok(pooled) => {
                    info!(
                        "Using pooled connection to {} (authenticated: {})",
                        pooled.server_name, pooled.authenticated
                    );
                    (
                        pooled.stream,
                        true,
                        pooled.server_name.clone(),
                        pooled.authenticated,
                    )
                }
                Err(_) => {
                    // If no pooled connection available, create a new one
                    info!("Creating new connection to {}", backend_addr);
                    match ConnectionPool::create_optimized_tcp_stream(&backend_addr).await {
                        Ok(stream) => (stream, false, server.name.clone(), false),
                        Err(e) => {
                            error!("Failed to connect to backend {}: {}", backend_addr, e);
                            let _ = client_stream
                                .write_all(b"400 Backend server unavailable\r\n")
                                .await;
                            return Err(e.into());
                        }
                    }
                }
            };

        info!("Connected to backend server {}", backend_addr);

        // Simplified authentication - focus on speed over complexity
        // NOTE(review): when no credentials are configured, the backend's own
        // greeting line is never consumed here and will be relayed to the
        // client by the copy loop *after* our synthetic greeting below —
        // confirm clients tolerate the extra status line.
        if !is_pooled || !pooled_authenticated {
            // Only authenticate if not already authenticated
            if let (Some(username), Some(password)) = (&server.username, &server.password) {
                if let Err(e) = self
                    .authenticate_backend(&mut backend_stream, username, password)
                    .await
                {
                    error!("Authentication failed for {}: {}", server.name, e);
                    let _ = client_stream
                        .write_all(b"502 Authentication failed\r\n")
                        .await;
                    return Err(e);
                }
            }
        }

        // Simple greeting for protocol compliance
        if let Err(e) = client_stream.write_all(b"200 NNTP Service Ready\r\n").await {
            error!("Failed to send greeting to client: {}", e);
            return Err(e.into());
        }

        // Try zero-copy first (Linux only), then fall back to high-performance buffered copying
        let copy_result = {
            #[cfg(target_os = "linux")]
            {
                match self
                    .copy_bidirectional_zero_copy(&mut client_stream, &mut backend_stream)
                    .await
                {
                    Ok(result) => {
                        debug!("Zero-copy successful");
                        Ok(result)
                    }
                    Err(_) => {
                        debug!("Zero-copy failed, falling back to buffered copy");
                        self.copy_bidirectional_buffered(&mut client_stream, &mut backend_stream)
                            .await
                    }
                }
            }
            #[cfg(not(target_os = "linux"))]
            {
                self.copy_bidirectional_buffered(&mut client_stream, &mut backend_stream)
                    .await
            }
        };

        // Always try to return the connection to the pool (whether it was originally pooled or newly created)
        // Determine if authentication was performed in this session
        let was_authenticated = if is_pooled {
            // If it was pooled, it might already be authenticated OR we just authenticated it
            pooled_authenticated || (server.username.is_some() && server.password.is_some())
        } else {
            // If it was a new connection, it's authenticated if we have credentials
            server.username.is_some() && server.password.is_some()
        };

        // NOTE(review): the stream is returned after an arbitrary client
        // session, so backend-side session state (selected group, reader
        // mode, a command mid-flight at disconnect) carries over to the next
        // client that reuses it — confirm this is acceptable.
        let pooled_conn = PooledConnection::new(
            backend_stream,
            server_name,
            was_authenticated,
        );
        self.connection_pool.return_connection(pooled_conn).await;
        info!(
            "Returned connection to pool for {} (authenticated: {})",
            server.name, was_authenticated
        );

        match copy_result {
            Ok((client_to_backend_bytes, backend_to_client_bytes)) => {
                info!(
                    "Connection closed for client {}: {} bytes client->backend, {} bytes backend->client",
                    client_addr, client_to_backend_bytes, backend_to_client_bytes
                );
            }
            Err(e) => {
                warn!("Bidirectional copy error for client {}: {}", client_addr, e);
            }
        }

        // The permit will be automatically dropped here when _permit goes out of scope
        info!("Connection closed for client {}", client_addr);
        Ok(())
    }
727
728    /// Perform NNTP authentication using AUTHINFO USER/PASS commands
729    async fn authenticate_backend(
730        &self,
731        stream: &mut TcpStream,
732        username: &str,
733        password: &str,
734    ) -> Result<()> {
735        // Use a buffer from our optimized pool instead of small allocations
736        let mut buffer = self.buffer_pool.get_buffer().await;
737
738        // Read the server greeting first
739        let n = stream.read(&mut buffer).await?;
740        let greeting = &buffer[..n];
741        info!(
742            "Server greeting: {}",
743            String::from_utf8_lossy(greeting).trim()
744        );
745
746        // Check if greeting indicates successful connection (200)
747        let greeting_str = String::from_utf8_lossy(greeting);
748        if !greeting_str.starts_with("200") && !greeting_str.starts_with("201") {
749            return Err(anyhow::anyhow!(
750                "Server returned non-success greeting: {}",
751                greeting_str.trim()
752            ));
753        }
754
755        // Send AUTHINFO USER command
756        let user_command = format!("AUTHINFO USER {}\r\n", username);
757        stream.write_all(user_command.as_bytes()).await?;
758
759        // Read response
760        let n = stream.read(&mut buffer).await?;
761        let response = String::from_utf8_lossy(&buffer[..n]);
762        info!("AUTHINFO USER response: {}", response.trim());
763
764        // Should get 381 (password required) or 281 (authenticated)
765        if response.starts_with("281") {
766            // Already authenticated with just username
767            return Ok(());
768        } else if !response.starts_with("381") {
769            return Err(anyhow::anyhow!(
770                "Unexpected response to AUTHINFO USER: {}",
771                response.trim()
772            ));
773        }
774
775        // Send AUTHINFO PASS command
776        let pass_command = format!("AUTHINFO PASS {}\r\n", password);
777        stream.write_all(pass_command.as_bytes()).await?;
778
779        // Read final response
780        let n = stream.read(&mut buffer).await?;
781        let response = String::from_utf8_lossy(&buffer[..n]);
782        info!("AUTHINFO PASS response: {}", response.trim());
783
784        // Should get 281 (authenticated)
785        let result = if response.starts_with("281") {
786            Ok(())
787        } else {
788            Err(anyhow::anyhow!(
789                "Authentication failed: {}",
790                response.trim()
791            ))
792        };
793
794        // Return buffer to pool
795        self.buffer_pool.return_buffer(buffer).await;
796        result
797    }
798
799    /// Zero-copy bidirectional copy specifically for TcpStream pairs (Linux only)
800    #[cfg(target_os = "linux")]
801    async fn copy_bidirectional_zero_copy(
802        &self,
803        client_stream: &mut TcpStream,
804        backend_stream: &mut TcpStream,
805    ) -> Result<(u64, u64), std::io::Error> {
806        debug!("Starting optimized zero-copy bidirectional transfer");
807
808        // Apply aggressive socket optimizations for 1GB+ transfers
809        if let Err(e) = Self::set_high_throughput_optimizations(client_stream) {
810            debug!("Failed to set client socket optimizations: {}", e);
811        }
812        if let Err(e) = Self::set_high_throughput_optimizations(backend_stream) {
813            debug!("Failed to set backend socket optimizations: {}", e);
814        }
815
816        match tokio_splice2::copy_bidirectional(client_stream, backend_stream).await {
817            Ok(traffic_result) => {
818                debug!(
819                    "Zero-copy transfer successful: {} bytes (client->server: {}, server->client: {})",
820                    traffic_result.tx + traffic_result.rx,
821                    traffic_result.tx,
822                    traffic_result.rx
823                );
824                Ok((traffic_result.tx as u64, traffic_result.rx as u64))
825            }
826            Err(e) => {
827                debug!("Zero-copy failed: {}", e);
828                Err(e)
829            }
830        }
831    }
832
    /// Set socket optimizations for high-throughput transfers
    ///
    /// Best-effort tuning of an already-connected TCP socket via raw
    /// `setsockopt(2)` calls. The return value of each call is deliberately
    /// ignored: a kernel that rejects an option only costs throughput, never
    /// correctness, so this function always returns `Ok(())`.
    fn set_high_throughput_optimizations(stream: &TcpStream) -> Result<(), std::io::Error> {
        use std::os::unix::io::AsRawFd;
        // Borrow the raw fd; tokio retains ownership of the socket — we only
        // adjust options, we never close or dup the descriptor.
        let fd = stream.as_raw_fd();

        // SAFETY: `fd` is a valid, open socket descriptor for the lifetime of
        // `stream` (borrowed above via AsRawFd), and each option value passed
        // below is a live, properly sized `c_int` whose address and length are
        // given together, so these calls cannot read or write invalid memory.
        unsafe {
            // Keep Nagle's algorithm enabled for large transfers to reduce packet overhead
            // (opposite of small transfer optimization)

            // Enable TCP_QUICKACK for immediate ACKs
            // NOTE(review): TCP_QUICKACK is not permanent — the kernel can
            // clear it again after subsequent traffic; confirm this one-shot
            // set has the intended lasting effect.
            let quickack: libc::c_int = 1;
            libc::setsockopt(
                fd,
                libc::IPPROTO_TCP,
                libc::TCP_QUICKACK,
                &quickack as *const _ as *const libc::c_void,
                std::mem::size_of::<libc::c_int>() as libc::socklen_t,
            );

            // Set larger TCP window scaling for high bandwidth-delay product
            let window_clamp: libc::c_int = 16 * 1024 * 1024; // 16MB window
            libc::setsockopt(
                fd,
                libc::IPPROTO_TCP,
                libc::TCP_WINDOW_CLAMP,
                &window_clamp as *const _ as *const libc::c_void,
                std::mem::size_of::<libc::c_int>() as libc::socklen_t,
            );

            // Optimize for high throughput with TCP_CORK equivalent (defer small packets)
            let cork: libc::c_int = 1;
            libc::setsockopt(
                fd,
                libc::IPPROTO_TCP,
                libc::TCP_CORK,
                &cork as *const _ as *const libc::c_void,
                std::mem::size_of::<libc::c_int>() as libc::socklen_t,
            );

            // Immediately uncork to flush
            // NOTE(review): cork-then-uncork with no writes in between flushes
            // any pending partial frames but has no lasting corking effect —
            // confirm this is the intended use rather than leaving the socket
            // corked during the transfer.
            let uncork: libc::c_int = 0;
            libc::setsockopt(
                fd,
                libc::IPPROTO_TCP,
                libc::TCP_CORK,
                &uncork as *const _ as *const libc::c_void,
                std::mem::size_of::<libc::c_int>() as libc::socklen_t,
            );
        }

        Ok(())
    }
885
    /// High-performance bidirectional copy with zero-copy optimization
    /// Attempts zero-copy transfer on Linux, falls back to pooled buffers
    ///
    /// Despite the names, both `reader` and `writer` are full-duplex streams
    /// (each bound requires `AsyncRead + AsyncWrite`); data is shuttled in
    /// both directions until either side reaches EOF or an error occurs.
    ///
    /// Returns `(bytes_reader_to_writer, bytes_writer_to_reader)`.
    async fn copy_bidirectional_buffered<R, W>(
        &self,
        mut reader: R,
        mut writer: W,
    ) -> Result<(u64, u64), std::io::Error>
    where
        R: AsyncRead + AsyncWrite + Unpin,
        W: AsyncRead + AsyncWrite + Unpin,
    {
        // Use high-throughput buffered copy with pooled buffers for generic streams
        // Zero-copy is handled by the specialized copy_bidirectional_zero_copy method
        // Optimized for sustained high-throughput transfers
        use std::io::ErrorKind;
        use tokio::io::{AsyncReadExt, AsyncWriteExt};

        // Get larger buffers from the pool (256KB for high throughput)
        // One buffer per direction so the two transfers never alias.
        let mut buf1 = self.buffer_pool.get_buffer().await;
        let mut buf2 = self.buffer_pool.get_buffer().await;

        let mut transferred_a_to_b = 0u64;
        let mut transferred_b_to_a = 0u64;

        // The copy loop runs inside an async block so that buffers can be
        // returned to the pool below regardless of success or error.
        // High-throughput copy with 256KB buffers reduces syscall overhead
        let copy_result = async {
            loop {
                // NOTE(review): select! drops the losing branch's read future
                // each iteration — this assumes `read` is cancel-safe for the
                // concrete stream types (true for tokio TcpStream; confirm
                // for any other callers).
                tokio::select! {
                    // Copy from reader to writer with larger buffers
                    result = reader.read(&mut buf1) => {
                        match result {
                            Ok(0) => break, // EOF
                            Ok(n) => {
                                writer.write_all(&buf1[..n]).await?;
                                transferred_a_to_b += n as u64;
                            }
                            Err(e) if e.kind() == ErrorKind::WouldBlock => continue,
                            Err(e) => return Err(e),
                        }
                    }
                    // Copy from writer to reader with larger buffers
                    result = writer.read(&mut buf2) => {
                        match result {
                            Ok(0) => break, // EOF
                            Ok(n) => {
                                reader.write_all(&buf2[..n]).await?;
                                transferred_b_to_a += n as u64;
                            }
                            Err(e) if e.kind() == ErrorKind::WouldBlock => continue,
                            Err(e) => return Err(e),
                        }
                    }
                }
            }
            // NOTE(review): EOF on either direction ends the whole transfer —
            // there is no half-close propagation (unlike
            // tokio::io::copy_bidirectional); confirm this is acceptable for
            // NNTP session semantics.
            Ok((transferred_a_to_b, transferred_b_to_a))
        }
        .await;

        // Return buffers to the pool
        self.buffer_pool.return_buffer(buf1).await;
        self.buffer_pool.return_buffer(buf2).await;

        copy_result
    }
950}
951
952pub fn load_config(config_path: &str) -> Result<Config> {
953    let config_content = std::fs::read_to_string(config_path)
954        .map_err(|e| anyhow::anyhow!("Failed to read config file '{}': {}", config_path, e))?;
955
956    let config: Config = toml::from_str(&config_content)
957        .map_err(|e| anyhow::anyhow!("Failed to parse config file '{}': {}", config_path, e))?;
958
959    Ok(config)
960}
961
962pub fn create_default_config() -> Config {
963    Config {
964        servers: vec![ServerConfig {
965            host: "news.example.com".to_string(),
966            port: 119,
967            name: "Example News Server".to_string(),
968            username: None,
969            password: None,
970            max_connections: default_max_connections(),
971        }],
972    }
973}
974
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;
    use std::sync::Arc;
    use tempfile::NamedTempFile;

    // Shared fixture: three servers with distinct names and connection
    // limits, used by the round-robin and (de)serialization tests below.
    fn create_test_config() -> Config {
        Config {
            servers: vec![
                ServerConfig {
                    host: "server1.example.com".to_string(),
                    port: 119,
                    name: "Test Server 1".to_string(),
                    username: None,
                    password: None,
                    max_connections: 5,
                },
                ServerConfig {
                    host: "server2.example.com".to_string(),
                    port: 119,
                    name: "Test Server 2".to_string(),
                    username: None,
                    password: None,
                    max_connections: 8,
                },
                ServerConfig {
                    host: "server3.example.com".to_string(),
                    port: 119,
                    name: "Test Server 3".to_string(),
                    username: None,
                    password: None,
                    max_connections: 12,
                },
            ],
        }
    }

    // A ServerConfig literal retains exactly the field values it was given.
    #[test]
    fn test_server_config_creation() {
        let config = ServerConfig {
            host: "news.example.com".to_string(),
            port: 119,
            name: "Example Server".to_string(),
            username: None,
            password: None,
            max_connections: 15,
        };

        assert_eq!(config.host, "news.example.com");
        assert_eq!(config.port, 119);
        assert_eq!(config.name, "Example Server");
        assert_eq!(config.max_connections, 15);
    }

    // Sanity-check the fixture itself so later failures aren't fixture bugs.
    #[test]
    fn test_config_creation() {
        let config = create_test_config();
        assert_eq!(config.servers.len(), 3);
        assert_eq!(config.servers[0].name, "Test Server 1");
        assert_eq!(config.servers[1].name, "Test Server 2");
        assert_eq!(config.servers[2].name, "Test Server 3");
    }

    // Proxy construction preserves the configured server list and order.
    #[test]
    fn test_proxy_creation_with_servers() {
        let config = create_test_config();
        let proxy = NntpProxy::new(config).expect("Failed to create proxy");

        assert_eq!(proxy.servers().len(), 3);
        assert_eq!(proxy.servers()[0].name, "Test Server 1");
    }

    // An empty server list must be rejected at construction time, not later.
    #[test]
    fn test_proxy_creation_with_empty_servers() {
        let config = Config { servers: vec![] };
        let result = NntpProxy::new(config);

        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("No servers configured")
        );
    }

    // next_server() cycles through all servers in order and wraps around.
    #[test]
    fn test_round_robin_server_selection() {
        let config = create_test_config();
        let proxy = NntpProxy::new(config).expect("Failed to create proxy");

        proxy.reset_index();

        // Test first round
        assert_eq!(proxy.next_server().name, "Test Server 1");
        assert_eq!(proxy.next_server().name, "Test Server 2");
        assert_eq!(proxy.next_server().name, "Test Server 3");

        // Test wraparound
        assert_eq!(proxy.next_server().name, "Test Server 1");
        assert_eq!(proxy.next_server().name, "Test Server 2");
    }

    // Degenerate round-robin: with one server every selection is that server.
    #[test]
    fn test_round_robin_with_single_server() {
        let config = Config {
            servers: vec![ServerConfig {
                host: "single.example.com".to_string(),
                port: 119,
                name: "Single Server".to_string(),
                username: None,
                password: None,
                max_connections: 3,
            }],
        };

        let proxy = NntpProxy::new(config).expect("Failed to create proxy");
        proxy.reset_index();

        // All requests should go to the same server
        assert_eq!(proxy.next_server().name, "Single Server");
        assert_eq!(proxy.next_server().name, "Single Server");
        assert_eq!(proxy.next_server().name, "Single Server");
    }

    // 9 selections across 3 servers from concurrent threads must split 3/3/3:
    // the atomic round-robin counter guarantees an exact distribution even
    // though the per-thread ordering is nondeterministic.
    #[test]
    fn test_concurrent_round_robin() {
        let config = create_test_config();
        let proxy = Arc::new(NntpProxy::new(config).expect("Failed to create proxy"));
        proxy.reset_index();

        let mut handles = vec![];
        let servers_selected = Arc::new(std::sync::Mutex::new(Vec::new()));

        // Spawn multiple tasks to test concurrent access
        for _ in 0..9 {
            let proxy_clone = Arc::clone(&proxy);
            let servers_clone = Arc::clone(&servers_selected);

            let handle = std::thread::spawn(move || {
                let server = proxy_clone.next_server();
                servers_clone.lock().unwrap().push(server.name.clone());
            });
            handles.push(handle);
        }

        // Wait for all tasks to complete
        for handle in handles {
            handle.join().unwrap();
        }

        let servers = servers_selected.lock().unwrap();
        assert_eq!(servers.len(), 9);

        // Count occurrences of each server (should be balanced)
        let server1_count = servers.iter().filter(|&s| s == "Test Server 1").count();
        let server2_count = servers.iter().filter(|&s| s == "Test Server 2").count();
        let server3_count = servers.iter().filter(|&s| s == "Test Server 3").count();

        // Each server should be selected 3 times
        assert_eq!(server1_count, 3);
        assert_eq!(server2_count, 3);
        assert_eq!(server3_count, 3);
    }

    // Round-trip: serialize the fixture to TOML, write it to a temp file,
    // and confirm load_config reads it back intact.
    #[test]
    fn test_load_config_from_file() -> Result<()> {
        let config = create_test_config();
        let config_toml = toml::to_string_pretty(&config)?;

        // Create a temporary file
        let mut temp_file = NamedTempFile::new()?;
        write!(temp_file, "{}", config_toml)?;

        // Load config from file
        let loaded_config = load_config(temp_file.path().to_str().unwrap())?;

        assert_eq!(loaded_config.servers.len(), 3);
        assert_eq!(loaded_config.servers[0].name, "Test Server 1");
        assert_eq!(loaded_config.servers[0].host, "server1.example.com");
        assert_eq!(loaded_config.servers[0].port, 119);

        Ok(())
    }

    // A missing file surfaces the "Failed to read" error variant.
    #[test]
    fn test_load_config_nonexistent_file() {
        let result = load_config("/nonexistent/path/config.toml");
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("Failed to read config file")
        );
    }

    // Malformed TOML surfaces the "Failed to parse" error variant.
    #[test]
    fn test_load_config_invalid_toml() -> Result<()> {
        let invalid_toml = "invalid toml content [[[";

        // Create a temporary file with invalid TOML
        let mut temp_file = NamedTempFile::new()?;
        write!(temp_file, "{}", invalid_toml)?;

        let result = load_config(temp_file.path().to_str().unwrap());
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("Failed to parse config file")
        );

        Ok(())
    }

    // The shipped default config contains exactly the example placeholder.
    #[test]
    fn test_create_default_config() {
        let config = create_default_config();

        assert_eq!(config.servers.len(), 1);
        assert_eq!(config.servers[0].host, "news.example.com");
        assert_eq!(config.servers[0].port, 119);
        assert_eq!(config.servers[0].name, "Example News Server");
    }

    // TOML serialization and deserialization are lossless for Config.
    #[test]
    fn test_config_serialization() -> Result<()> {
        let config = create_test_config();

        // Serialize to TOML
        let toml_string = toml::to_string_pretty(&config)?;
        assert!(toml_string.contains("server1.example.com"));
        assert!(toml_string.contains("Test Server 1"));

        // Deserialize back
        let deserialized: Config = toml::from_str(&toml_string)?;
        assert_eq!(deserialized, config);

        Ok(())
    }
}