nntp_proxy/
session.rs

1//! Client session management
2//!
3//! This module handles the lifecycle of a client connection, including
4//! command processing, authentication interception, and data transfer.
5
6use anyhow::Result;
7use std::net::SocketAddr;
8use std::sync::Arc;
9use tokio::io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt};
10use tokio::net::TcpStream;
11use tracing::{debug, error, warn};
12
13use crate::auth::AuthHandler;
14use crate::command::{AuthAction, CommandAction, CommandHandler};
15use crate::constants::buffer::{COMMAND_SIZE, STREAMING_CHUNK_SIZE};
16use crate::constants::protocol::{
17    BACKEND_ERROR, CONNECTION_CLOSING, PROXY_GREETING_PCR, TERMINATOR_TAIL_SIZE,
18};
19use crate::constants::stateless_proxy::NNTP_COMMAND_NOT_SUPPORTED;
20use crate::pool::BufferPool;
21use crate::router::BackendSelector;
22use crate::streaming::StreamHandler;
23use crate::types::ClientId;
24
25/// Represents an active client session
26pub struct ClientSession {
27    client_addr: SocketAddr,
28    buffer_pool: BufferPool,
29    /// Unique identifier for this client
30    client_id: ClientId,
31    /// Optional router for per-command routing mode
32    router: Option<Arc<BackendSelector>>,
33}
34
35impl ClientSession {
36    /// Create a new client session for 1:1 mode
37    #[must_use]
38    pub fn new(client_addr: SocketAddr, buffer_pool: BufferPool) -> Self {
39        Self {
40            client_addr,
41            buffer_pool,
42            client_id: ClientId::new(),
43            router: None,
44        }
45    }
46
47    /// Create a new client session for per-command routing mode
48    #[must_use]
49    pub fn new_with_router(
50        client_addr: SocketAddr,
51        buffer_pool: BufferPool,
52        router: Arc<BackendSelector>,
53    ) -> Self {
54        Self {
55            client_addr,
56            buffer_pool,
57            client_id: ClientId::new(),
58            router: Some(router),
59        }
60    }
61
62    /// Get the unique client ID
63    #[must_use]
64    #[inline]
65    pub fn client_id(&self) -> ClientId {
66        self.client_id
67    }
68
69    /// Check if this session is using per-command routing
70    #[must_use]
71    #[inline]
72    pub fn is_per_command_routing(&self) -> bool {
73        self.router.is_some()
74    }
75
76    /// Handle client connection with a pooled backend connection
77    /// This keeps the pooled connection object alive and returns it to the pool when done
78    /// Intercepts authentication commands since backend connection is already authenticated
79    pub async fn handle_with_pooled_backend<T>(
80        &self,
81        mut client_stream: TcpStream,
82        backend_conn: T,
83    ) -> Result<(u64, u64)>
84    where
85        T: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
86    {
87        use tokio::io::BufReader;
88
89        // Split streams for independent read/write
90        let (client_read, mut client_write) = client_stream.split();
91        let (mut backend_read, mut backend_write) = tokio::io::split(backend_conn);
92        let mut client_reader = BufReader::new(client_read);
93
94        let mut client_to_backend_bytes = 0u64;
95        let mut backend_to_client_bytes = 0u64;
96
97        // Reuse line buffer to avoid per-iteration allocations
98        let mut line = String::with_capacity(COMMAND_SIZE);
99
100        debug!("Client {} session loop starting", self.client_addr);
101
102        // Handle the initial command/response phase where we intercept auth
103        loop {
104            line.clear();
105            let mut buffer = self.buffer_pool.get_buffer().await;
106
107            tokio::select! {
108                // Read command from client
109                result = client_reader.read_line(&mut line) => {
110                    match result {
111                        Ok(0) => {
112                            debug!("Client {} disconnected (0 bytes read)", self.client_addr);
113                            self.buffer_pool.return_buffer(buffer).await;
114                            break; // Client disconnected
115                        }
116                        Ok(n) => {
117                            debug!("Client {} sent {} bytes: {:?}", self.client_addr, n, line.trim());
118                            let trimmed = line.trim();
119                            debug!("Client {} command: {}", self.client_addr, trimmed);
120
121                            // Handle command using CommandHandler
122                            match CommandHandler::handle_command(&line) {
123                                CommandAction::InterceptAuth(auth_action) => {
124                                    let response = match auth_action {
125                                        AuthAction::RequestPassword => AuthHandler::user_response(),
126                                        AuthAction::AcceptAuth => AuthHandler::pass_response(),
127                                    };
128                                    client_write.write_all(response).await?;
129                                    backend_to_client_bytes += response.len() as u64;
130                                    debug!("Intercepted auth command for client {}", self.client_addr);
131                                }
132                                CommandAction::Reject(_reason) => {
133                                    warn!("Rejecting command from client {}: {}", self.client_addr, trimmed);
134                                    client_write.write_all(NNTP_COMMAND_NOT_SUPPORTED).await?;
135                                    backend_to_client_bytes += NNTP_COMMAND_NOT_SUPPORTED.len() as u64;
136                                }
137                                CommandAction::ForwardHighThroughput => {
138                                    // Forward article retrieval by message-ID to backend
139                                    backend_write.write_all(line.as_bytes()).await?;
140                                    client_to_backend_bytes += line.len() as u64;
141                                    debug!("Client {} switching to high-throughput mode", self.client_addr);
142
143                                    // Return the buffer before transitioning
144                                    self.buffer_pool.return_buffer(buffer).await;
145
146                                    // For high-throughput data transfer, use optimized handler
147                                    return StreamHandler::high_throughput_transfer(
148                                        client_reader,
149                                        client_write,
150                                        backend_read,
151                                        backend_write,
152                                        client_to_backend_bytes,
153                                        backend_to_client_bytes,
154                                    ).await;
155                                }
156                                CommandAction::ForwardStateless => {
157                                    // Forward stateless commands to backend
158                                    backend_write.write_all(line.as_bytes()).await?;
159                                    client_to_backend_bytes += line.len() as u64;
160                                }
161                            }
162                        }
163                        Err(e) => {
164                            warn!("Error reading from client {}: {}", self.client_addr, e);
165                            self.buffer_pool.return_buffer(buffer).await;
166                            break;
167                        }
168                    }
169                }
170
171                // Read response from backend and forward to client (for non-auth commands)
172                result = backend_read.read(&mut buffer) => {
173                    match result {
174                        Ok(0) => {
175                            self.buffer_pool.return_buffer(buffer).await;
176                            break; // Backend disconnected
177                        }
178                        Ok(n) => {
179                            client_write.write_all(&buffer[..n]).await?;
180                            backend_to_client_bytes += n as u64;
181                        }
182                        Err(e) => {
183                            warn!("Error reading from backend for client {}: {}", self.client_addr, e);
184                            self.buffer_pool.return_buffer(buffer).await;
185                            break;
186                        }
187                    }
188                }
189            }
190
191            self.buffer_pool.return_buffer(buffer).await;
192        }
193
194        Ok((client_to_backend_bytes, backend_to_client_bytes))
195    }
196
197    /// Handle a client connection with per-command routing
198    /// Each command is routed independently to potentially different backends
199    pub async fn handle_per_command_routing(
200        &self,
201        mut client_stream: TcpStream,
202    ) -> Result<(u64, u64)> {
203        use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
204        
205        debug!(
206            "Client {} starting per-command routing session",
207            self.client_addr
208        );
209
210        let router = self
211            .router
212            .as_ref()
213            .ok_or_else(|| anyhow::anyhow!("Per-command routing mode requires a router"))?;
214
215        let (client_read, mut client_write) = client_stream.split();
216        let mut client_reader = BufReader::new(client_read);
217
218        let mut client_to_backend_bytes = 0u64;
219        let mut backend_to_client_bytes = 0u64;
220
221        // Send initial greeting to client
222        debug!(
223            "Client {} sending greeting: {} | hex: {:02x?}",
224            self.client_addr,
225            String::from_utf8_lossy(PROXY_GREETING_PCR),
226            PROXY_GREETING_PCR
227        );
228        
229        if let Err(e) = client_write.write_all(PROXY_GREETING_PCR).await {
230            debug!(
231                "Client {} failed to send greeting: {} (kind: {:?}). \
232                 This suggests the client disconnected immediately after connecting.",
233                self.client_addr, e, e.kind()
234            );
235            return Err(e.into());
236        }
237        backend_to_client_bytes += PROXY_GREETING_PCR.len() as u64;
238
239        debug!(
240            "Client {} sent greeting successfully, entering command loop",
241            self.client_addr
242        );
243
244        // Reuse command buffer to avoid allocations per command
245        let mut command = String::with_capacity(COMMAND_SIZE);
246
247        // Process commands one at a time
248        loop {
249            command.clear();
250
251            match client_reader.read_line(&mut command).await {
252                Ok(0) => {
253                    debug!("Client {} disconnected", self.client_addr);
254                    break; // Client disconnected
255                }
256                Ok(n) => {
257                    client_to_backend_bytes += n as u64;
258                    let trimmed = command.trim();
259
260                    debug!(
261                        "Client {} received command ({} bytes): {} | hex: {:02x?}",
262                        self.client_addr, n, trimmed, command.as_bytes()
263                    );
264
265                    // Handle QUIT locally
266                    if trimmed.eq_ignore_ascii_case("QUIT") {
267                        // Send closing message - ignore errors if client already disconnected, but log for debugging
268                        if let Err(e) = client_write.write_all(CONNECTION_CLOSING).await {
269                            debug!(
270                                "Failed to write CONNECTION_CLOSING to client {}: {}",
271                                self.client_addr, e
272                            );
273                        }
274                        backend_to_client_bytes += CONNECTION_CLOSING.len() as u64;
275                        debug!("Client {} sent QUIT, closing connection", self.client_addr);
276                        break;
277                    }
278
279                    // Check if command should be rejected (stateful commands)
280                    match CommandHandler::handle_command(&command) {
281                        CommandAction::InterceptAuth(auth_action) => {
282                            // Handle authentication locally
283                            let response = match auth_action {
284                                AuthAction::RequestPassword => AuthHandler::user_response(),
285                                AuthAction::AcceptAuth => AuthHandler::pass_response(),
286                            };
287                            client_write.write_all(response).await?;
288                            backend_to_client_bytes += response.len() as u64;
289                            continue;
290                        }
291                        CommandAction::Reject(reason) => {
292                            warn!(
293                                "Rejecting command from client {}: {} ({})",
294                                self.client_addr, trimmed, reason
295                            );
296                            client_write.write_all(NNTP_COMMAND_NOT_SUPPORTED).await?;
297                            backend_to_client_bytes += NNTP_COMMAND_NOT_SUPPORTED.len() as u64;
298                            continue;
299                        }
300                        CommandAction::ForwardStateless | CommandAction::ForwardHighThroughput => {
301                            // Route this command to a backend
302                            match self
303                                .route_and_execute_command(
304                                    router,
305                                    &command,
306                                    &mut client_write,
307                                    &mut client_to_backend_bytes,
308                                    &mut backend_to_client_bytes,
309                                )
310                                .await
311                            {
312                                Ok(()) => {}
313                                Err(e) => {
314                                    // Provide detailed context for broken pipe errors
315                                    if let Some(io_err) = e.downcast_ref::<std::io::Error>() {
316                                        match io_err.kind() {
317                                            std::io::ErrorKind::BrokenPipe => {
318                                                warn!(
319                                                    "Client {} disconnected unexpectedly while routing command '{}' (broken pipe). \
320                                                     Session stats: {} bytes sent to backend, {} bytes received from backend. \
321                                                     This usually indicates the client closed the connection before receiving the response.",
322                                                    self.client_addr, trimmed, client_to_backend_bytes, backend_to_client_bytes
323                                                );
324                                            }
325                                            std::io::ErrorKind::ConnectionReset => {
326                                                warn!(
327                                                    "Client {} connection reset while routing command '{}'. \
328                                                     Session stats: {} bytes sent to backend, {} bytes received from backend. \
329                                                     This usually indicates a network issue or client crash.",
330                                                    self.client_addr, trimmed, client_to_backend_bytes, backend_to_client_bytes
331                                                );
332                                            }
333                                            std::io::ErrorKind::ConnectionAborted => {
334                                                warn!(
335                                                    "Client {} connection aborted while routing command '{}'. \
336                                                     Session stats: {} bytes sent to backend, {} bytes received from backend. \
337                                                     This usually indicates the connection was terminated by the local system.",
338                                                    self.client_addr, trimmed, client_to_backend_bytes, backend_to_client_bytes
339                                                );
340                                            }
341                                            _ => {
342                                                error!(
343                                                    "I/O error routing command '{}' for client {}: {} (kind: {:?}). \
344                                                     Session stats: {} bytes sent to backend, {} bytes received from backend.",
345                                                    trimmed, self.client_addr, e, io_err.kind(), client_to_backend_bytes, backend_to_client_bytes
346                                                );
347                                            }
348                                        }
349                                    } else {
350                                        error!(
351                                            "Error routing command '{}' for client {}: {}. \
352                                             Session stats: {} bytes sent to backend, {} bytes received from backend.",
353                                            trimmed, self.client_addr, e, client_to_backend_bytes, backend_to_client_bytes
354                                        );
355                                    }
356                                    
357                                    // Try to send error response, but don't log failure if client is gone
358                                    let _ = client_write.write_all(BACKEND_ERROR).await;
359                                    backend_to_client_bytes += BACKEND_ERROR.len() as u64;
360                                    
361                                    // For debugging test connections and small transfers, log detailed info
362                                    if client_to_backend_bytes + backend_to_client_bytes < 500 {
363                                        debug!(
364                                            "ERROR SUMMARY for small transfer - Client {}: \
365                                             Command '{}' failed with {}. \
366                                             Total session: {} bytes to backend, {} bytes from backend. \
367                                             This appears to be a short session (test connection?). \
368                                             Check debug logs above for full command/response hex dumps.",
369                                            self.client_addr, trimmed, e, client_to_backend_bytes, backend_to_client_bytes
370                                        );
371                                    }
372                                }
373                            }
374                        }
375                    }
376                }
377                Err(e) => {
378                    // Provide detailed context for client read errors
379                    match e.kind() {
380                        std::io::ErrorKind::UnexpectedEof => {
381                            debug!(
382                                "Client {} closed connection (EOF). Session stats: {} bytes sent to backend, {} bytes received from backend.",
383                                self.client_addr, client_to_backend_bytes, backend_to_client_bytes
384                            );
385                        }
386                        std::io::ErrorKind::BrokenPipe => {
387                            debug!(
388                                "Client {} connection broken pipe while reading. Session stats: {} bytes sent to backend, {} bytes received from backend.",
389                                self.client_addr, client_to_backend_bytes, backend_to_client_bytes
390                            );
391                        }
392                        std::io::ErrorKind::ConnectionReset => {
393                            warn!(
394                                "Client {} connection reset while reading. Session stats: {} bytes sent to backend, {} bytes received from backend.",
395                                self.client_addr, client_to_backend_bytes, backend_to_client_bytes
396                            );
397                        }
398                        _ => {
399                            warn!(
400                                "Error reading from client {}: {} (kind: {:?}). Session stats: {} bytes sent to backend, {} bytes received from backend.",
401                                self.client_addr, e, e.kind(), client_to_backend_bytes, backend_to_client_bytes
402                            );
403                        }
404                    }
405                    break;
406                }
407            }
408        }
409
410        // Log session summary for debugging, especially useful for test connections
411        if client_to_backend_bytes + backend_to_client_bytes < 500 {
412            debug!(
413                "SESSION SUMMARY for Client {}: Small transfer completed successfully. \
414                 {} bytes sent to backend, {} bytes received from backend. \
415                 This appears to be a short session (likely test connection). \
416                 Check debug logs above for individual command/response details.",
417                self.client_addr, client_to_backend_bytes, backend_to_client_bytes
418            );
419        }
420
421        Ok((client_to_backend_bytes, backend_to_client_bytes))
422    }
423
424    /// Route a single command to a backend and execute it
425    async fn route_and_execute_command(
426        &self,
427        router: &BackendSelector,
428        command: &str,
429        client_write: &mut tokio::net::tcp::WriteHalf<'_>,
430        client_to_backend_bytes: &mut u64,
431        backend_to_client_bytes: &mut u64,
432    ) -> Result<()> {
433        use tokio::io::AsyncWriteExt;
434
435        // Route the command to get a backend (lock-free!)
436        let backend_id = router.route_command_sync(self.client_id, command)?;
437
438        debug!(
439            "Client {} routed command to backend {:?}: {}",
440            self.client_addr,
441            backend_id,
442            command.trim()
443        );
444
445        // Get a connection from the router's backend pool
446        let provider = router
447            .get_backend_provider(backend_id)
448            .ok_or_else(|| anyhow::anyhow!("Backend {:?} not found", backend_id))?;
449
450        debug!(
451            "Client {} getting pooled connection for backend {:?}",
452            self.client_addr, backend_id
453        );
454        // Use get_pooled_connection() to get a connection that auto-returns to pool
455        // The pool's recycle() method will health-check connections before reuse
456        // so we don't get stale connections that timed out on the backend
457        let mut pooled_conn = provider.get_pooled_connection().await?;
458        debug!(
459            "Client {} got pooled connection for backend {:?}",
460            self.client_addr, backend_id
461        );
462
463        // Connection from pool is already authenticated - no need to consume greeting or auth again
464
465        // Forward the command to the backend
466        debug!(
467            "Client {} forwarding command to backend {:?} ({} bytes): {} | hex: {:02x?}",
468            self.client_addr,
469            backend_id,
470            command.len(),
471            command.trim(),
472            command.as_bytes()
473        );
474        pooled_conn.write_all(command.as_bytes()).await?;
475        *client_to_backend_bytes += command.len() as u64;
476        debug!(
477            "Client {} command sent to backend {:?}",
478            self.client_addr, backend_id
479        );
480
481        // Read the response from the backend
482        debug!(
483            "Client {} reading response from backend {:?}",
484            self.client_addr, backend_id
485        );
486
487        // Use direct reading from backend - no split() to avoid mutex overhead
488        use tokio::io::AsyncReadExt;
489
490        let mut chunk = vec![0u8; STREAMING_CHUNK_SIZE];
491        let mut total_bytes = 0;
492
493        // Read first chunk to determine response type
494        let n = pooled_conn.read(&mut chunk).await?;
495        if n == 0 {
496            return Err(anyhow::anyhow!("Backend connection closed unexpectedly"));
497        }
498        
499        debug!(
500            "Client {} received backend response chunk ({} bytes): {} | hex: {:02x?}",
501            self.client_addr, n,
502            String::from_utf8_lossy(&chunk[..n.min(100)]), // Show first 100 bytes max
503            &chunk[..n.min(32)] // Show first 32 bytes in hex
504        );
505
506        // Find first newline to determine if multiline
507        let first_newline = chunk[..n].iter().position(|&b| b == b'\n').unwrap_or(n);
508
509        // Multiline responses have second digit of 1, 2, or 3 (e.g., 215, 220-225, 230-235)
510        // Single-line responses have second digit of 0, 4, or 8 (e.g., 200, 201, 205, 400, 480)
511        // See RFC 3977 Section 3.2: https://tools.ietf.org/html/rfc3977#section-3.2
512        // "Multi-line data blocks are used for responses where the length is not known
513        //  in advance... The response code for a multi-line response will always begin
514        //  with the digit 2 or 3, and the second digit will be 1, 2, or 3."
515        let is_multiline = first_newline >= 3
516            && chunk[0] == b'2'
517            && (chunk[1] == b'1' || chunk[1] == b'2' || chunk[1] == b'3');
518
519        // Log first line (best effort)
520        if let Ok(first_line_str) = std::str::from_utf8(&chunk[..first_newline.min(n)]) {
521            debug!(
522                "Client {} got first line from backend {:?}: {}",
523                self.client_addr,
524                backend_id,
525                first_line_str.trim()
526            );
527        }
528
529        // Write first chunk directly to client
530        debug!(
531            "Client {} sending first chunk ({} bytes): {} | hex: {:02x?}",
532            self.client_addr, n, 
533            String::from_utf8_lossy(&chunk[..n.min(100)]), // Show first 100 bytes max
534            &chunk[..n.min(32)] // Show first 32 bytes in hex
535        );
536        client_write.write_all(&chunk[..n]).await?;
537        total_bytes += n;
538
539        if is_multiline {
540            // Fast check if terminator is in first chunk (check end only)
541            let has_terminator = if n >= 5 {
542                chunk[n - 5..n] == *b"\r\n.\r\n" || (n >= 3 && chunk[n - 3..n] == *b"\n.\n")
543            } else {
544                n >= 3 && chunk[..n] == *b"\n.\n"
545            };
546
547            if !has_terminator {
548                // For multiline responses, use pipelined streaming
549                // Prepare double buffering for concurrent read/write
550                let mut chunk1 = chunk; // Reuse first buffer
551                let mut chunk2 = vec![0u8; STREAMING_CHUNK_SIZE]; // Second buffer for pipelining
552
553                let mut tail: [u8; TERMINATOR_TAIL_SIZE] = [0; TERMINATOR_TAIL_SIZE]; // Fixed-size tail for span detection
554                let mut tail_len: usize = 0; // How much of tail is valid
555
556                // Initialize tail with last bytes of first chunk (already written above)
557                if n >= TERMINATOR_TAIL_SIZE {
558                    tail.copy_from_slice(&chunk1[n - TERMINATOR_TAIL_SIZE..n]);
559                    tail_len = TERMINATOR_TAIL_SIZE;
560                } else if n > 0 {
561                    tail[..n].copy_from_slice(&chunk1[..n]);
562                    tail_len = n;
563                }
564
565                // Check terminator in first chunk (already written)
566                let first_has_term = if n >= 5 {
567                    chunk1[n - 5..n] == *b"\r\n.\r\n" || (n >= 3 && chunk1[n - 3..n] == *b"\n.\n")
568                } else {
569                    n >= 3 && chunk1[..n] == *b"\n.\n"
570                };
571
572                if !first_has_term {
573                    // First chunk didn't have terminator, continue reading
574                    let mut current_chunk = &mut chunk1;
575                    let mut next_chunk = &mut chunk2;
576
577                    // Read next chunk and start loop
578                    let mut current_n = pooled_conn.read(next_chunk).await?;
579                    if current_n > 0 {
580                        std::mem::swap(&mut current_chunk, &mut next_chunk);
581
582                        loop {
583                            // Write current chunk to client
584                            debug!(
585                                "Client {} sending streaming chunk ({} bytes): {} | hex: {:02x?}",
586                                self.client_addr, current_n,
587                                String::from_utf8_lossy(&current_chunk[..current_n.min(100)]), // Show first 100 bytes max
588                                &current_chunk[..current_n.min(32)] // Show first 32 bytes in hex
589                            );
590                            client_write.write_all(&current_chunk[..current_n]).await?;
591                            total_bytes += current_n;
592
593                            // Check terminator in chunk we just wrote
594                            let has_term = if current_n >= 5 {
595                                current_chunk[current_n - 5..current_n] == *b"\r\n.\r\n"
596                                    || (current_n >= 3
597                                        && current_chunk[current_n - 3..current_n] == *b"\n.\n")
598                            } else {
599                                current_n >= 3 && current_chunk[..current_n] == *b"\n.\n"
600                            };
601
602                            if has_term {
603                                break; // Done! We already wrote the final chunk
604                            }
605
606                            // Check boundary spanning terminator (ONLY if current chunk is small enough)
607                            // This is rare - only check if terminator could span from previous chunk
608                            let has_spanning_term = if tail_len >= 2 && (1..=4).contains(&current_n)
609                            {
610                                // Build combined view: tail + start of current chunk
611                                let mut check_buf = [0u8; 9]; // max: 4 tail + 5 current
612                                check_buf[..tail_len].copy_from_slice(&tail[..tail_len]);
613                                let curr_copy = current_n.min(5);
614                                check_buf[tail_len..tail_len + curr_copy]
615                                    .copy_from_slice(&current_chunk[..curr_copy]);
616                                let total = tail_len + curr_copy;
617
618                                (total >= 5 && check_buf[total - 5..total] == *b"\r\n.\r\n")
619                                    || (total >= 3 && check_buf[total - 3..total] == *b"\n.\n")
620                            } else {
621                                false
622                            };
623
624                            if has_spanning_term {
625                                break; // Done! We already wrote the final chunk
626                            }
627
628                            // Update tail for next iteration (only last 4 bytes)
629                            if current_n >= TERMINATOR_TAIL_SIZE {
630                                tail.copy_from_slice(
631                                    &current_chunk[current_n - TERMINATOR_TAIL_SIZE..current_n],
632                                );
633                                tail_len = TERMINATOR_TAIL_SIZE;
634                            } else if current_n > 0 {
635                                tail[..current_n].copy_from_slice(&current_chunk[..current_n]);
636                                tail_len = current_n;
637                            }
638
639                            // Read next chunk
640                            let next_n = pooled_conn.read(next_chunk).await?;
641                            if next_n == 0 {
642                                break; // EOF
643                            }
644
645                            // Swap buffers for next iteration
646                            std::mem::swap(&mut current_chunk, &mut next_chunk);
647                            current_n = next_n;
648                        }
649                    }
650                }
651            }
652        }
653
654        debug!(
655            "Client {} forwarded response ({} bytes) to client",
656            self.client_addr, total_bytes
657        );
658        *backend_to_client_bytes += total_bytes as u64;
659
660        // Complete the request - decrement pending count (lock-free!)
661        router.complete_command_sync(backend_id);
662
663        Ok(())
664    }
665}
666
// Tests for `ClientSession` construction and for QUIT handling in
// per-command routing mode against a mock backend.
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
    use tokio::io::{AsyncReadExt, AsyncWriteExt};
    use tokio::net::TcpListener;

    /// IPv4 loopback (`127.0.0.1`) socket address on `port`.
    fn loopback_v4(port: u16) -> SocketAddr {
        SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), port)
    }

    /// Spawn a mock NNTP backend that accepts a single connection, writes
    /// `greeting`, then reads (and discards) one chunk of input before the
    /// connection is dropped. Returns the address the mock listens on.
    ///
    /// The listener is bound before this function returns, so callers may
    /// connect immediately: the OS queues the connection until the spawned
    /// task calls `accept`. No startup sleep is required.
    async fn spawn_mock_backend(greeting: &'static [u8]) -> SocketAddr {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        tokio::spawn(async move {
            if let Ok((mut stream, _)) = listener.accept().await {
                let _ = stream.write_all(greeting).await;
                // Keep the connection open until the proxy sends something.
                let mut buf = [0u8; 1024];
                let _ = stream.read(&mut buf).await;
            }
        });
        addr
    }

    #[test]
    fn test_client_session_creation() {
        let addr = loopback_v4(8080);
        let buffer_pool = BufferPool::new(1024, 4);
        let session = ClientSession::new(addr, buffer_pool);

        assert_eq!(session.client_addr.port(), 8080);
        assert_eq!(
            session.client_addr.ip(),
            IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1))
        );
    }

    #[test]
    fn test_client_session_with_different_ports() {
        let buffer_pool = BufferPool::new(1024, 4);

        let session1 = ClientSession::new(loopback_v4(8080), buffer_pool.clone());
        let session2 = ClientSession::new(loopback_v4(9090), buffer_pool);

        assert_ne!(session1.client_addr.port(), session2.client_addr.port());
        assert_eq!(session1.client_addr.port(), 8080);
        assert_eq!(session2.client_addr.port(), 9090);
    }

    #[test]
    fn test_client_session_with_ipv6() {
        let buffer_pool = BufferPool::new(1024, 4);
        let addr = SocketAddr::new(IpAddr::V6(Ipv6Addr::LOCALHOST), 8119);
        let session = ClientSession::new(addr, buffer_pool);

        assert_eq!(session.client_addr.port(), 8119);
        assert!(session.client_addr.is_ipv6());
    }

    #[test]
    fn test_buffer_pool_cloning() {
        let buffer_pool = BufferPool::new(8192, 10);
        let buffer_pool_clone = buffer_pool.clone();

        let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1)), 1234);
        let session1 = ClientSession::new(addr, buffer_pool);
        let session2 = ClientSession::new(addr, buffer_pool_clone);

        // Sessions built from clones of one pool keep their own address and
        // must still receive distinct identities.
        assert_eq!(session1.client_addr, session2.client_addr);
        assert_ne!(session1.client_id(), session2.client_id());
    }

    #[test]
    fn test_session_addr_formatting() {
        let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1)), 5555);
        let buffer_pool = BufferPool::new(1024, 4);
        let session = ClientSession::new(addr, buffer_pool);

        let addr_str = session.client_addr.to_string();
        assert!(addr_str.contains("10.0.0.1"));
        assert!(addr_str.contains("5555"));
    }

    #[test]
    fn test_multiple_sessions_same_buffer_pool() {
        let buffer_pool = BufferPool::new(4096, 8);
        let sessions: Vec<_> = (8000u16..8005)
            .map(|port| ClientSession::new(loopback_v4(port), buffer_pool.clone()))
            .collect();

        assert_eq!(sessions.len(), 5);
        for (expected_port, session) in (8000u16..).zip(&sessions) {
            assert_eq!(session.client_addr.port(), expected_port);
        }
    }

    #[test]
    fn test_loopback_address() {
        let buffer_pool = BufferPool::new(1024, 4);
        let session = ClientSession::new(loopback_v4(8119), buffer_pool);

        assert!(session.client_addr.ip().is_loopback());
    }

    #[test]
    fn test_unspecified_address() {
        let buffer_pool = BufferPool::new(1024, 4);
        let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0);
        let session = ClientSession::new(addr, buffer_pool);

        assert!(session.client_addr.ip().is_unspecified());
        assert_eq!(session.client_addr.port(), 0);
    }

    #[test]
    fn test_session_without_router() {
        let buffer_pool = BufferPool::new(1024, 4);
        let session = ClientSession::new(loopback_v4(8080), buffer_pool);

        assert!(!session.is_per_command_routing());
        assert_eq!(session.client_addr.port(), 8080);
    }

    #[test]
    fn test_session_with_router() {
        let buffer_pool = BufferPool::new(1024, 4);
        let router = Arc::new(BackendSelector::new());
        let session = ClientSession::new_with_router(loopback_v4(8080), buffer_pool, router);

        assert!(session.is_per_command_routing());
        assert_eq!(session.client_addr.port(), 8080);
    }

    #[test]
    fn test_client_id_uniqueness() {
        let addr = loopback_v4(8080);
        let buffer_pool = BufferPool::new(1024, 4);

        let session1 = ClientSession::new(addr, buffer_pool.clone());
        let session2 = ClientSession::new(addr, buffer_pool);

        // Each session should have a unique client ID
        assert_ne!(session1.client_id(), session2.client_id());
    }

    #[tokio::test]
    async fn test_quit_command_per_command_routing() {
        let backend_addr = spawn_mock_backend(b"200 Mock Server Ready\r\n").await;

        // Router whose only backend is the mock server.
        let mut router = BackendSelector::new();
        let provider = crate::pool::DeadpoolConnectionProvider::new(
            "127.0.0.1".to_string(),
            backend_addr.port(),
            "test-backend".to_string(),
            2,
            None,
            None,
        );
        router.add_backend(
            crate::types::BackendId::from_index(0),
            "test-backend".to_string(),
            provider,
        );

        // Listener standing in for the proxy's client-facing socket.
        let client_listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let client_addr = client_listener.local_addr().unwrap();

        let buffer_pool = BufferPool::new(1024, 4);
        let session = ClientSession::new_with_router(client_addr, buffer_pool, Arc::new(router));

        // Client that sends QUIT and then hangs up immediately (SABnzbd-style).
        let client_handle = tokio::spawn(async move {
            let mut client = tokio::net::TcpStream::connect(client_addr).await.unwrap();

            let mut greeting = [0u8; 256];
            let n = client.read(&mut greeting).await.unwrap();
            assert!(n > 0, "Should receive greeting");

            client.write_all(b"QUIT\r\n").await.unwrap();

            // The closing response may or may not arrive before we hang up;
            // either outcome is acceptable here.
            let mut response = [0u8; 256];
            let _ = client.read(&mut response).await;

            drop(client);
        });

        let (client_stream, _) = client_listener.accept().await.unwrap();

        // The session must not propagate a broken-pipe error caused by the
        // client closing right after QUIT.
        let (sent, received) = match session.handle_per_command_routing(client_stream).await {
            Ok(counts) => counts,
            Err(e) => panic!("QUIT handling should not return error: {:?}", e),
        };
        assert!(sent > 0, "Should have sent bytes (QUIT command)");
        assert!(received > 0, "Should have received bytes (greeting)");

        // Propagate any assertion failure from inside the client task rather
        // than silently discarding the JoinError.
        client_handle.await.expect("client task panicked");
    }

    #[tokio::test]
    async fn test_quit_command_closes_connection_cleanly() {
        let backend_addr = spawn_mock_backend(b"200 Ready\r\n").await;

        // Router whose only backend is the mock server.
        let mut router = BackendSelector::new();
        let provider = crate::pool::DeadpoolConnectionProvider::new(
            "127.0.0.1".to_string(),
            backend_addr.port(),
            "test".to_string(),
            1,
            None,
            None,
        );
        router.add_backend(
            crate::types::BackendId::from_index(0),
            "test".to_string(),
            provider,
        );

        let client_listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let client_addr = client_listener.local_addr().unwrap();

        let buffer_pool = BufferPool::new(1024, 4);
        let session = ClientSession::new_with_router(client_addr, buffer_pool, Arc::new(router));

        // Client that sends QUIT and waits for the closing response.
        let client_handle = tokio::spawn(async move {
            let mut client = tokio::net::TcpStream::connect(client_addr).await.unwrap();

            let mut buf = [0u8; 256];
            let n = client.read(&mut buf).await.unwrap();
            assert!(n > 0, "Should receive greeting");

            client.write_all(b"QUIT\r\n").await.unwrap();

            let n = client.read(&mut buf).await.unwrap();

            // Should receive "205 Connection closing"
            let response = String::from_utf8_lossy(&buf[..n]);
            assert!(
                response.contains("205"),
                "Should receive 205 closing response"
            );
        });

        let (client_stream, _) = client_listener.accept().await.unwrap();
        let result = session.handle_per_command_routing(client_stream).await;

        assert!(
            result.is_ok(),
            "Session should handle QUIT cleanly: {:?}",
            result
        );

        client_handle.await.expect("client task panicked");
    }
}