// nntp_proxy/session/handlers/per_command.rs

//! Per-command routing mode handler and shared command execution.
//!
//! This module implements independent per-command routing, where each command
//! can be routed to a different backend. It also contains the core command
//! execution logic that other routing-mode handlers (such as `hybrid.rs`) reuse.
6
7use super::common::{SMALL_TRANSFER_THRESHOLD, extract_message_id};
8use crate::pool::PooledBuffer;
9use crate::session::{ClientSession, backend, connection, streaming};
10use anyhow::Result;
11use tokio::io::{AsyncBufReadExt, AsyncWriteExt};
12use tokio::net::TcpStream;
13use tracing::{debug, error, info, warn};
14
15use crate::command::{CommandHandler, NntpCommand};
16use crate::config::RoutingMode;
17use crate::constants::buffer::COMMAND;
18use crate::protocol::{BACKEND_ERROR, CONNECTION_CLOSING, PROXY_GREETING_PCR};
19use crate::router::BackendSelector;
20use crate::types::{BytesTransferred, TransferMetrics};
21
22impl ClientSession {
23    /// Handle a client connection with per-command routing
24    /// Each command is routed independently to potentially different backends
25    pub async fn handle_per_command_routing(
26        &self,
27        mut client_stream: TcpStream,
28    ) -> Result<(u64, u64)> {
29        use tokio::io::BufReader;
30
31        debug!(
32            "Client {} starting per-command routing session",
33            self.client_addr
34        );
35
36        let router = self
37            .router
38            .as_ref()
39            .ok_or_else(|| anyhow::anyhow!("Per-command routing mode requires a router"))?;
40
41        let (client_read, mut client_write) = client_stream.split();
42        let mut client_reader = BufReader::new(client_read);
43
44        let mut client_to_backend_bytes = BytesTransferred::zero();
45        let mut backend_to_client_bytes = BytesTransferred::zero();
46
47        // Auth state: username from AUTHINFO USER command
48        let mut auth_username: Option<String> = None;
49
50        // Send initial greeting to client
51        debug!(
52            "Client {} sending greeting: {} | hex: {:02x?}",
53            self.client_addr,
54            String::from_utf8_lossy(PROXY_GREETING_PCR),
55            PROXY_GREETING_PCR
56        );
57
58        if let Err(e) = client_write.write_all(PROXY_GREETING_PCR).await {
59            debug!(
60                "Client {} failed to send greeting: {} (kind: {:?}). \
61                 This suggests the client disconnected immediately after connecting.",
62                self.client_addr,
63                e,
64                e.kind()
65            );
66            return Err(e.into());
67        }
68        backend_to_client_bytes.add(PROXY_GREETING_PCR.len());
69
70        debug!(
71            "Client {} sent greeting successfully, entering command loop",
72            self.client_addr
73        );
74
75        // Reuse command buffer to avoid allocations per command
76        let mut command = String::with_capacity(COMMAND);
77
78        // PERFORMANCE: Cache authenticated state to avoid atomic loads after auth succeeds
79        // If auth is disabled, skip checks from the start
80        let mut skip_auth_check = !self.auth_handler.is_enabled();
81
82        // Process commands one at a time
83        loop {
84            command.clear();
85
86            let n = match client_reader.read_line(&mut command).await {
87                Ok(0) => {
88                    debug!("Client {} disconnected", self.client_addr);
89                    break;
90                }
91                Ok(n) => n,
92                Err(e) => {
93                    connection::log_client_error(
94                        self.client_addr,
95                        &e,
96                        TransferMetrics {
97                            client_to_backend: client_to_backend_bytes,
98                            backend_to_client: backend_to_client_bytes,
99                        },
100                    );
101                    break;
102                }
103            };
104
105            client_to_backend_bytes.add(n);
106            let trimmed = command.trim();
107
108            debug!(
109                "Client {} received command ({} bytes): {} | hex: {:02x?}",
110                self.client_addr,
111                n,
112                trimmed,
113                command.as_bytes()
114            );
115
116            // Handle QUIT locally
117            if trimmed.eq_ignore_ascii_case("QUIT") {
118                let _ = client_write
119                    .write_all(CONNECTION_CLOSING)
120                    .await
121                    .inspect_err(|e| {
122                        debug!(
123                            "Failed to write CONNECTION_CLOSING to client {}: {}",
124                            self.client_addr, e
125                        );
126                    });
127                backend_to_client_bytes.add(CONNECTION_CLOSING.len());
128                debug!("Client {} sent QUIT, closing connection", self.client_addr);
129                break;
130            }
131
132            let action = CommandHandler::handle_command(&command);
133
134            // ALWAYS intercept auth commands, even when auth is disabled
135            // Auth commands must NEVER be forwarded to backend
136            if matches!(action, CommandAction::InterceptAuth(_)) {
137                match action {
138                    CommandAction::InterceptAuth(auth_action) => {
139                        // Store username if this is AUTHINFO USER
140                        if let crate::command::AuthAction::RequestPassword(ref username) =
141                            auth_action
142                        {
143                            auth_username = Some(username.clone());
144                        }
145
146                        // Handle auth and validate
147                        let (bytes, auth_success) = self
148                            .auth_handler
149                            .handle_auth_command(
150                                auth_action,
151                                &mut client_write,
152                                auth_username.as_deref(),
153                            )
154                            .await?;
155                        backend_to_client_bytes.add(bytes);
156
157                        if auth_success {
158                            skip_auth_check = true;
159                            self.authenticated
160                                .store(true, std::sync::atomic::Ordering::Release);
161                        }
162                    }
163                    _ => unreachable!(),
164                }
165                continue;
166            }
167
168            // PERFORMANCE OPTIMIZATION: Fast path after authentication
169            // Once authenticated, skip classification (just route everything)
170            // Cache check to avoid atomic load on hot path
171            skip_auth_check = skip_auth_check
172                || self
173                    .authenticated
174                    .load(std::sync::atomic::Ordering::Acquire);
175            if skip_auth_check {
176                // Already authenticated - just route the command (HOT PATH after auth)
177                self.route_command_with_error_handling(
178                    router,
179                    &command,
180                    &mut client_write,
181                    &mut client_to_backend_bytes,
182                    &mut backend_to_client_bytes,
183                    trimmed,
184                )
185                .await?;
186                continue;
187            }
188
189            // Not yet authenticated - need to check for auth/stateful commands
190
191            // In hybrid mode, stateful commands trigger a switch to stateful connection
192            if self.routing_mode == RoutingMode::Hybrid
193                && matches!(action, CommandAction::Reject(_))
194                && NntpCommand::classify(&command).is_stateful()
195            {
196                info!(
197                    "Client {} switching to stateful mode (command: {})",
198                    self.client_addr, trimmed
199                );
200                return self
201                    .switch_to_stateful_mode(
202                        client_reader,
203                        client_write,
204                        &command,
205                        client_to_backend_bytes,
206                        backend_to_client_bytes,
207                    )
208                    .await;
209            }
210
211            // Handle the command - inline for performance
212            // Check ForwardStateless FIRST (70%+ of traffic)
213            use crate::command::CommandAction;
214            match action {
215                CommandAction::ForwardStateless => {
216                    // Check if auth is required but not completed
217                    if self.auth_handler.is_enabled() {
218                        // Reject all non-auth commands before authentication
219                        let response = b"480 Authentication required\r\n";
220                        client_write.write_all(response).await?;
221                        backend_to_client_bytes.add(response.len());
222                    } else {
223                        // Auth disabled - forward to backend via router (HOT PATH - 70%+ of commands)
224                        self.route_command_with_error_handling(
225                            router,
226                            &command,
227                            &mut client_write,
228                            &mut client_to_backend_bytes,
229                            &mut backend_to_client_bytes,
230                            trimmed,
231                        )
232                        .await?;
233                    }
234                }
235                CommandAction::Reject(response) => {
236                    // Send rejection response inline
237                    client_write.write_all(response.as_bytes()).await?;
238                    backend_to_client_bytes.add(response.len());
239                }
240                CommandAction::InterceptAuth(_) => {
241                    // Already handled above before fast path check
242                    unreachable!("Auth commands should be handled before reaching here");
243                }
244            }
245        }
246
247        // Log session summary for debugging, especially useful for test connections
248        if (client_to_backend_bytes + backend_to_client_bytes).as_u64() < SMALL_TRANSFER_THRESHOLD {
249            debug!(
250                "Session summary {} | ↑{} ↓{} | Short session (likely test connection)",
251                self.client_addr,
252                crate::formatting::format_bytes(client_to_backend_bytes.as_u64()),
253                crate::formatting::format_bytes(backend_to_client_bytes.as_u64())
254            );
255        }
256
257        Ok((
258            client_to_backend_bytes.as_u64(),
259            backend_to_client_bytes.as_u64(),
260        ))
261    }
262
263    /// Route a single command to a backend and execute it
264    ///
265    /// This function is `pub(super)` to allow reuse of per-command routing logic by sibling handler modules
266    /// (such as `hybrid.rs`) that also need to route commands.
267    pub(super) async fn route_and_execute_command(
268        &self,
269        router: &BackendSelector,
270        command: &str,
271        client_write: &mut tokio::net::tcp::WriteHalf<'_>,
272        client_to_backend_bytes: &mut BytesTransferred,
273        backend_to_client_bytes: &mut BytesTransferred,
274    ) -> Result<crate::types::BackendId> {
275        use crate::pool::{is_connection_error, remove_from_pool};
276
277        // Get reusable buffer from pool (eliminates 64KB Vec allocation on every command!)
278        let mut buffer = self.buffer_pool.get_buffer().await;
279
280        // Route the command to get a backend (lock-free!)
281        let backend_id = router.route_command_sync(self.client_id, command)?;
282
283        debug!(
284            "Client {} routed command to backend {:?}: {}",
285            self.client_addr,
286            backend_id,
287            command.trim()
288        );
289
290        // Get a connection from the router's backend pool
291        let provider = router
292            .get_backend_provider(backend_id)
293            .ok_or_else(|| anyhow::anyhow!("Backend {:?} not found", backend_id))?;
294
295        debug!(
296            "Client {} getting pooled connection for backend {:?}",
297            self.client_addr, backend_id
298        );
299
300        let mut pooled_conn = provider.get_pooled_connection().await?;
301
302        debug!(
303            "Client {} got pooled connection for backend {:?}",
304            self.client_addr, backend_id
305        );
306
307        // Execute the command - returns (result, got_backend_data)
308        // If got_backend_data is true, we successfully communicated with backend
309        let (result, got_backend_data) = self
310            .execute_command_on_backend(
311                &mut pooled_conn,
312                command,
313                client_write,
314                backend_id,
315                client_to_backend_bytes,
316                backend_to_client_bytes,
317                &mut buffer, // Pass reusable buffer
318            )
319            .await;
320
321        // Return buffer to pool before handling result
322
323        // Only remove backend connection if error occurred AND we didn't get data from backend
324        // If we got data from backend, then any error is from writing to client
325        let _ = result
326            .as_ref()
327            .err()
328            .filter(|e| is_connection_error(e))
329            .inspect(|e| {
330                match got_backend_data {
331                    true => debug!(
332                        "Client {} disconnected while receiving data from backend {:?} - backend connection is healthy",
333                        self.client_addr, backend_id
334                    ),
335                    false => warn!(
336                        "Backend connection error for client {}, backend {:?}: {} - removing connection from pool",
337                        self.client_addr, backend_id, e
338                    ),
339                }
340            })
341            .filter(|_| !got_backend_data)
342            .is_some_and(|_| { remove_from_pool(pooled_conn); true });
343
344        // Complete the request - decrement pending count (lock-free!)
345        router.complete_command_sync(backend_id);
346
347        result.map(|_| backend_id)
348    }
349
350    /// Execute a command on a backend connection and stream the response to the client
351    ///
352    /// # Performance Critical Hot Path
353    ///
354    /// This function implements **pipelined streaming** for NNTP responses, which is essential
355    /// for high-throughput article downloads. The double-buffering approach allows reading the
356    /// next chunk from the backend while writing the current chunk to the client concurrently.
357    ///
358    /// **DO NOT refactor this to buffer entire responses** - that would kill performance:
359    /// - Large articles (50MB+) would be fully buffered before streaming to client
360    /// - No concurrent I/O = sequential read-then-write instead of pipelined read+write
361    /// - Performance drops from 100+ MB/s to < 1 MB/s
362    ///
363    /// The complexity here is justified by the 100x+ performance gain on large transfers.
364    ///
365    /// # Return Value
366    ///
367    /// Returns `(Result<()>, got_backend_data)` where:
368    /// - `got_backend_data = true` means we successfully read from backend before any error
369    /// - This distinguishes backend failures (remove from pool) from client disconnects (keep backend)
370    ///
371    /// This function is `pub(super)` and is intended for use by `hybrid.rs` for stateful mode command execution.
372    #[allow(clippy::too_many_arguments)]
373    pub(super) async fn execute_command_on_backend(
374        &self,
375        pooled_conn: &mut deadpool::managed::Object<crate::pool::deadpool_connection::TcpManager>,
376        command: &str,
377        client_write: &mut tokio::net::tcp::WriteHalf<'_>,
378        backend_id: crate::types::BackendId,
379        client_to_backend_bytes: &mut BytesTransferred,
380        backend_to_client_bytes: &mut BytesTransferred,
381        chunk_buffer: &mut PooledBuffer, // Reusable buffer from pool
382    ) -> (Result<()>, bool) {
383        let mut got_backend_data = false;
384
385        // Send command and read first chunk into reusable buffer
386        let (n, _response_code, is_multiline) = match backend::send_command_and_read_first_chunk(
387            &mut **pooled_conn,
388            command,
389            backend_id,
390            self.client_addr,
391            chunk_buffer,
392        )
393        .await
394        {
395            Ok(result) => {
396                got_backend_data = true;
397                result
398            }
399            Err(e) => return (Err(e), got_backend_data),
400        };
401
402        client_to_backend_bytes.add(command.len());
403
404        // Extract message-ID from command if present (for correlation with SABnzbd errors)
405        let msgid = extract_message_id(command);
406
407        // For multiline responses, use pipelined streaming
408        let bytes_written = if is_multiline {
409            let log_msg = if let Some(id) = msgid {
410                format!(
411                    "Client {} ARTICLE {} → multiline ({:?}), streaming {}",
412                    self.client_addr,
413                    id,
414                    _response_code.status_code(),
415                    crate::formatting::format_bytes(n as u64)
416                )
417            } else {
418                format!(
419                    "Client {} '{}' → multiline ({:?}), streaming {}",
420                    self.client_addr,
421                    command.trim(),
422                    _response_code.status_code(),
423                    crate::formatting::format_bytes(n as u64)
424                )
425            };
426            debug!("{}", log_msg);
427
428            match streaming::stream_multiline_response(
429                &mut **pooled_conn,
430                client_write,
431                &chunk_buffer[..n], // Use buffer slice instead of owned Vec
432                n,
433                self.client_addr,
434                backend_id,
435                &self.buffer_pool,
436            )
437            .await
438            {
439                Ok(bytes) => bytes,
440                Err(e) => return (Err(e), got_backend_data),
441            }
442        } else {
443            // Single-line response - just write the first chunk
444            let log_msg = if let Some(id) = msgid {
445                // 430 (No such article) and other 4xx errors are expected single-line responses
446                if let Some(code) = _response_code.status_code() {
447                    if (400..500).contains(&code) {
448                        format!(
449                            "Client {} ARTICLE {} → error {} (single-line), writing {}",
450                            self.client_addr,
451                            id,
452                            code,
453                            crate::formatting::format_bytes(n as u64)
454                        )
455                    } else {
456                        format!(
457                            "Client {} ARTICLE {} → UNUSUAL single-line ({:?}), writing {}: {:02x?}",
458                            self.client_addr,
459                            id,
460                            _response_code.status_code(),
461                            crate::formatting::format_bytes(n as u64),
462                            &chunk_buffer[..n.min(50)]
463                        )
464                    }
465                } else {
466                    format!(
467                        "Client {} ARTICLE {} → UNUSUAL single-line ({:?}), writing {}: {:02x?}",
468                        self.client_addr,
469                        id,
470                        _response_code.status_code(),
471                        crate::formatting::format_bytes(n as u64),
472                        &chunk_buffer[..n.min(50)]
473                    )
474                }
475            } else {
476                format!(
477                    "Client {} '{}' → single-line ({:?}), writing {}: {:02x?}",
478                    self.client_addr,
479                    command.trim(),
480                    _response_code.status_code(),
481                    crate::formatting::format_bytes(n as u64),
482                    &chunk_buffer[..n.min(50)]
483                )
484            };
485
486            // Only warn if it's truly unusual (not a 4xx/5xx error response)
487            if let Some(code) = _response_code.status_code() {
488                if code >= 400 {
489                    debug!("{}", log_msg); // Errors are expected, just debug
490                } else if msgid.is_some() {
491                    warn!("{}", log_msg); // ARTICLE with 2xx/3xx single-line is unusual
492                } else {
493                    debug!("{}", log_msg);
494                }
495            } else {
496                debug!("{}", log_msg);
497            }
498
499            match client_write.write_all(&chunk_buffer[..n]).await {
500                Ok(_) => n as u64,
501                Err(e) => return (Err(e.into()), got_backend_data),
502            }
503        };
504
505        backend_to_client_bytes.add(bytes_written as usize);
506
507        if let Some(id) = msgid {
508            debug!(
509                "Client {} ARTICLE {} completed: wrote {} bytes to client",
510                self.client_addr, id, bytes_written
511            );
512        }
513
514        (Ok(()), got_backend_data)
515    }
516
517    /// Route a command and handle any errors with appropriate logging and client responses
518    ///
519    /// This helper consolidates the error handling logic that was duplicated in multiple places.
520    /// It routes the command via the router, and if an error occurs, it:
521    /// - Classifies the error (client disconnect vs auth error vs other)
522    /// - Logs appropriately based on error type
523    /// - Sends BACKEND_ERROR response to client (if not disconnected)
524    /// - Includes debug logging for small transfers
525    async fn route_command_with_error_handling(
526        &self,
527        router: &BackendSelector,
528        command: &str,
529        client_write: &mut tokio::net::tcp::WriteHalf<'_>,
530        client_to_backend_bytes: &mut BytesTransferred,
531        backend_to_client_bytes: &mut BytesTransferred,
532        trimmed: &str,
533    ) -> Result<()> {
534        if let Err(e) = self
535            .route_and_execute_command(
536                router,
537                command,
538                client_write,
539                client_to_backend_bytes,
540                backend_to_client_bytes,
541            )
542            .await
543        {
544            use crate::session::error_classification::ErrorClassifier;
545            let (up, down) = (
546                crate::formatting::format_bytes(client_to_backend_bytes.as_u64()),
547                crate::formatting::format_bytes(backend_to_client_bytes.as_u64()),
548            );
549
550            // Log based on error type, send error response if client still connected
551            if ErrorClassifier::is_client_disconnect(&e) {
552                debug!(
553                    "Client {} command '{}' resulted in disconnect (already logged by streaming layer) | ↑{} ↓{}",
554                    self.client_addr, trimmed, up, down
555                );
556            } else {
557                if ErrorClassifier::is_authentication_error(&e) {
558                    error!(
559                        "Client {} command '{}' authentication error: {} | ↑{} ↓{}",
560                        self.client_addr, trimmed, e, up, down
561                    );
562                } else {
563                    warn!(
564                        "Client {} error routing '{}': {} | ↑{} ↓{}",
565                        self.client_addr, trimmed, e, up, down
566                    );
567                }
568                let _ = client_write.write_all(BACKEND_ERROR).await;
569                backend_to_client_bytes.add(BACKEND_ERROR.len());
570            }
571
572            // Debug logging for small transfers
573            if (*client_to_backend_bytes + *backend_to_client_bytes).as_u64()
574                < SMALL_TRANSFER_THRESHOLD
575            {
576                debug!(
577                    "ERROR SUMMARY for small transfer - Client {}: Command '{}' failed with {}. \
578                     Total session: {} bytes to backend, {} bytes from backend. \
579                     This appears to be a short session (test connection?). \
580                     Check debug logs above for full command/response hex dumps.",
581                    self.client_addr, trimmed, e, client_to_backend_bytes, backend_to_client_bytes
582                );
583            }
584
585            return Err(e);
586        }
587        Ok(())
588    }
589}
590
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `extract_message_id` yields exactly `expected` for `command`.
    fn assert_msgid(command: &str, expected: Option<&str>) {
        let extracted = extract_message_id(command);
        assert_eq!(extracted, expected);
    }

    #[test]
    fn test_extract_message_id_valid() {
        assert_msgid("BODY <test@example.com>", Some("<test@example.com>"));
    }

    #[test]
    fn test_extract_message_id_article_command() {
        assert_msgid("ARTICLE <1234@news.server>", Some("<1234@news.server>"));
    }

    #[test]
    fn test_extract_message_id_head_command() {
        assert_msgid("HEAD <article@host.domain>", Some("<article@host.domain>"));
    }

    #[test]
    fn test_extract_message_id_stat_command() {
        assert_msgid("STAT <msg@example.org>", Some("<msg@example.org>"));
    }

    #[test]
    fn test_extract_message_id_no_brackets() {
        // Commands without angle brackets carry no message-ID.
        assert_msgid("GROUP comp.lang.rust", None);
    }

    #[test]
    fn test_extract_message_id_malformed() {
        // An unterminated `<` is not a message-ID.
        assert_msgid("BODY <incomplete", None);
    }

    #[test]
    fn test_extract_message_id_with_extra_text() {
        assert_msgid("BODY <msg@host> extra stuff", Some("<msg@host>"));
    }

    #[test]
    fn test_extract_message_id_empty_brackets() {
        assert_msgid("BODY <>", Some("<>"));
    }

    #[test]
    fn test_extract_message_id_lowercase_command() {
        assert_msgid("body <test@example.com>", Some("<test@example.com>"));
    }

    #[test]
    fn test_extract_message_id_mixed_case() {
        // The message-ID is returned with its original casing preserved.
        assert_msgid("BoDy <TeSt@ExAmPlE.cOm>", Some("<TeSt@ExAmPlE.cOm>"));
    }
}