1use super::common::{SMALL_TRANSFER_THRESHOLD, extract_message_id};
8use crate::pool::PooledBuffer;
9use crate::session::{ClientSession, backend, connection, streaming};
10use anyhow::Result;
11use tokio::io::{AsyncBufReadExt, AsyncWriteExt};
12use tokio::net::TcpStream;
13use tracing::{debug, error, info, warn};
14
15use crate::command::{CommandHandler, NntpCommand};
16use crate::config::RoutingMode;
17use crate::constants::buffer::COMMAND;
18use crate::protocol::{BACKEND_ERROR, CONNECTION_CLOSING, PROXY_GREETING_PCR};
19use crate::router::BackendSelector;
20use crate::types::{BytesTransferred, TransferMetrics};
21
22impl ClientSession {
23 pub async fn handle_per_command_routing(
26 &self,
27 mut client_stream: TcpStream,
28 ) -> Result<(u64, u64)> {
29 use tokio::io::BufReader;
30
31 debug!(
32 "Client {} starting per-command routing session",
33 self.client_addr
34 );
35
36 let router = self
37 .router
38 .as_ref()
39 .ok_or_else(|| anyhow::anyhow!("Per-command routing mode requires a router"))?;
40
41 let (client_read, mut client_write) = client_stream.split();
42 let mut client_reader = BufReader::new(client_read);
43
44 let mut client_to_backend_bytes = BytesTransferred::zero();
45 let mut backend_to_client_bytes = BytesTransferred::zero();
46
47 let mut auth_username: Option<String> = None;
49
50 debug!(
52 "Client {} sending greeting: {} | hex: {:02x?}",
53 self.client_addr,
54 String::from_utf8_lossy(PROXY_GREETING_PCR),
55 PROXY_GREETING_PCR
56 );
57
58 if let Err(e) = client_write.write_all(PROXY_GREETING_PCR).await {
59 debug!(
60 "Client {} failed to send greeting: {} (kind: {:?}). \
61 This suggests the client disconnected immediately after connecting.",
62 self.client_addr,
63 e,
64 e.kind()
65 );
66 return Err(e.into());
67 }
68 backend_to_client_bytes.add(PROXY_GREETING_PCR.len());
69
70 debug!(
71 "Client {} sent greeting successfully, entering command loop",
72 self.client_addr
73 );
74
75 let mut command = String::with_capacity(COMMAND);
77
78 let mut skip_auth_check = !self.auth_handler.is_enabled();
81
82 loop {
84 command.clear();
85
86 let n = match client_reader.read_line(&mut command).await {
87 Ok(0) => {
88 debug!("Client {} disconnected", self.client_addr);
89 break;
90 }
91 Ok(n) => n,
92 Err(e) => {
93 connection::log_client_error(
94 self.client_addr,
95 &e,
96 TransferMetrics {
97 client_to_backend: client_to_backend_bytes,
98 backend_to_client: backend_to_client_bytes,
99 },
100 );
101 break;
102 }
103 };
104
105 client_to_backend_bytes.add(n);
106 let trimmed = command.trim();
107
108 debug!(
109 "Client {} received command ({} bytes): {} | hex: {:02x?}",
110 self.client_addr,
111 n,
112 trimmed,
113 command.as_bytes()
114 );
115
116 if trimmed.eq_ignore_ascii_case("QUIT") {
118 let _ = client_write
119 .write_all(CONNECTION_CLOSING)
120 .await
121 .inspect_err(|e| {
122 debug!(
123 "Failed to write CONNECTION_CLOSING to client {}: {}",
124 self.client_addr, e
125 );
126 });
127 backend_to_client_bytes.add(CONNECTION_CLOSING.len());
128 debug!("Client {} sent QUIT, closing connection", self.client_addr);
129 break;
130 }
131
132 let action = CommandHandler::handle_command(&command);
133
134 if matches!(action, CommandAction::InterceptAuth(_)) {
137 match action {
138 CommandAction::InterceptAuth(auth_action) => {
139 if let crate::command::AuthAction::RequestPassword(ref username) =
141 auth_action
142 {
143 auth_username = Some(username.clone());
144 }
145
146 let (bytes, auth_success) = self
148 .auth_handler
149 .handle_auth_command(
150 auth_action,
151 &mut client_write,
152 auth_username.as_deref(),
153 )
154 .await?;
155 backend_to_client_bytes.add(bytes);
156
157 if auth_success {
158 skip_auth_check = true;
159 self.authenticated
160 .store(true, std::sync::atomic::Ordering::Release);
161 }
162 }
163 _ => unreachable!(),
164 }
165 continue;
166 }
167
168 skip_auth_check = skip_auth_check
172 || self
173 .authenticated
174 .load(std::sync::atomic::Ordering::Acquire);
175 if skip_auth_check {
176 self.route_command_with_error_handling(
178 router,
179 &command,
180 &mut client_write,
181 &mut client_to_backend_bytes,
182 &mut backend_to_client_bytes,
183 trimmed,
184 )
185 .await?;
186 continue;
187 }
188
189 if self.routing_mode == RoutingMode::Hybrid
193 && matches!(action, CommandAction::Reject(_))
194 && NntpCommand::classify(&command).is_stateful()
195 {
196 info!(
197 "Client {} switching to stateful mode (command: {})",
198 self.client_addr, trimmed
199 );
200 return self
201 .switch_to_stateful_mode(
202 client_reader,
203 client_write,
204 &command,
205 client_to_backend_bytes,
206 backend_to_client_bytes,
207 )
208 .await;
209 }
210
211 use crate::command::CommandAction;
214 match action {
215 CommandAction::ForwardStateless => {
216 if self.auth_handler.is_enabled() {
218 let response = b"480 Authentication required\r\n";
220 client_write.write_all(response).await?;
221 backend_to_client_bytes.add(response.len());
222 } else {
223 self.route_command_with_error_handling(
225 router,
226 &command,
227 &mut client_write,
228 &mut client_to_backend_bytes,
229 &mut backend_to_client_bytes,
230 trimmed,
231 )
232 .await?;
233 }
234 }
235 CommandAction::Reject(response) => {
236 client_write.write_all(response.as_bytes()).await?;
238 backend_to_client_bytes.add(response.len());
239 }
240 CommandAction::InterceptAuth(_) => {
241 unreachable!("Auth commands should be handled before reaching here");
243 }
244 }
245 }
246
247 if (client_to_backend_bytes + backend_to_client_bytes).as_u64() < SMALL_TRANSFER_THRESHOLD {
249 debug!(
250 "Session summary {} | ↑{} ↓{} | Short session (likely test connection)",
251 self.client_addr,
252 crate::formatting::format_bytes(client_to_backend_bytes.as_u64()),
253 crate::formatting::format_bytes(backend_to_client_bytes.as_u64())
254 );
255 }
256
257 Ok((
258 client_to_backend_bytes.as_u64(),
259 backend_to_client_bytes.as_u64(),
260 ))
261 }
262
263 pub(super) async fn route_and_execute_command(
268 &self,
269 router: &BackendSelector,
270 command: &str,
271 client_write: &mut tokio::net::tcp::WriteHalf<'_>,
272 client_to_backend_bytes: &mut BytesTransferred,
273 backend_to_client_bytes: &mut BytesTransferred,
274 ) -> Result<crate::types::BackendId> {
275 use crate::pool::{is_connection_error, remove_from_pool};
276
277 let mut buffer = self.buffer_pool.get_buffer().await;
279
280 let backend_id = router.route_command_sync(self.client_id, command)?;
282
283 debug!(
284 "Client {} routed command to backend {:?}: {}",
285 self.client_addr,
286 backend_id,
287 command.trim()
288 );
289
290 let provider = router
292 .get_backend_provider(backend_id)
293 .ok_or_else(|| anyhow::anyhow!("Backend {:?} not found", backend_id))?;
294
295 debug!(
296 "Client {} getting pooled connection for backend {:?}",
297 self.client_addr, backend_id
298 );
299
300 let mut pooled_conn = provider.get_pooled_connection().await?;
301
302 debug!(
303 "Client {} got pooled connection for backend {:?}",
304 self.client_addr, backend_id
305 );
306
307 let (result, got_backend_data) = self
310 .execute_command_on_backend(
311 &mut pooled_conn,
312 command,
313 client_write,
314 backend_id,
315 client_to_backend_bytes,
316 backend_to_client_bytes,
317 &mut buffer, )
319 .await;
320
321 let _ = result
326 .as_ref()
327 .err()
328 .filter(|e| is_connection_error(e))
329 .inspect(|e| {
330 match got_backend_data {
331 true => debug!(
332 "Client {} disconnected while receiving data from backend {:?} - backend connection is healthy",
333 self.client_addr, backend_id
334 ),
335 false => warn!(
336 "Backend connection error for client {}, backend {:?}: {} - removing connection from pool",
337 self.client_addr, backend_id, e
338 ),
339 }
340 })
341 .filter(|_| !got_backend_data)
342 .is_some_and(|_| { remove_from_pool(pooled_conn); true });
343
344 router.complete_command_sync(backend_id);
346
347 result.map(|_| backend_id)
348 }
349
350 #[allow(clippy::too_many_arguments)]
373 pub(super) async fn execute_command_on_backend(
374 &self,
375 pooled_conn: &mut deadpool::managed::Object<crate::pool::deadpool_connection::TcpManager>,
376 command: &str,
377 client_write: &mut tokio::net::tcp::WriteHalf<'_>,
378 backend_id: crate::types::BackendId,
379 client_to_backend_bytes: &mut BytesTransferred,
380 backend_to_client_bytes: &mut BytesTransferred,
381 chunk_buffer: &mut PooledBuffer, ) -> (Result<()>, bool) {
383 let mut got_backend_data = false;
384
385 let (n, _response_code, is_multiline) = match backend::send_command_and_read_first_chunk(
387 &mut **pooled_conn,
388 command,
389 backend_id,
390 self.client_addr,
391 chunk_buffer,
392 )
393 .await
394 {
395 Ok(result) => {
396 got_backend_data = true;
397 result
398 }
399 Err(e) => return (Err(e), got_backend_data),
400 };
401
402 client_to_backend_bytes.add(command.len());
403
404 let msgid = extract_message_id(command);
406
407 let bytes_written = if is_multiline {
409 let log_msg = if let Some(id) = msgid {
410 format!(
411 "Client {} ARTICLE {} → multiline ({:?}), streaming {}",
412 self.client_addr,
413 id,
414 _response_code.status_code(),
415 crate::formatting::format_bytes(n as u64)
416 )
417 } else {
418 format!(
419 "Client {} '{}' → multiline ({:?}), streaming {}",
420 self.client_addr,
421 command.trim(),
422 _response_code.status_code(),
423 crate::formatting::format_bytes(n as u64)
424 )
425 };
426 debug!("{}", log_msg);
427
428 match streaming::stream_multiline_response(
429 &mut **pooled_conn,
430 client_write,
431 &chunk_buffer[..n], n,
433 self.client_addr,
434 backend_id,
435 &self.buffer_pool,
436 )
437 .await
438 {
439 Ok(bytes) => bytes,
440 Err(e) => return (Err(e), got_backend_data),
441 }
442 } else {
443 let log_msg = if let Some(id) = msgid {
445 if let Some(code) = _response_code.status_code() {
447 if (400..500).contains(&code) {
448 format!(
449 "Client {} ARTICLE {} → error {} (single-line), writing {}",
450 self.client_addr,
451 id,
452 code,
453 crate::formatting::format_bytes(n as u64)
454 )
455 } else {
456 format!(
457 "Client {} ARTICLE {} → UNUSUAL single-line ({:?}), writing {}: {:02x?}",
458 self.client_addr,
459 id,
460 _response_code.status_code(),
461 crate::formatting::format_bytes(n as u64),
462 &chunk_buffer[..n.min(50)]
463 )
464 }
465 } else {
466 format!(
467 "Client {} ARTICLE {} → UNUSUAL single-line ({:?}), writing {}: {:02x?}",
468 self.client_addr,
469 id,
470 _response_code.status_code(),
471 crate::formatting::format_bytes(n as u64),
472 &chunk_buffer[..n.min(50)]
473 )
474 }
475 } else {
476 format!(
477 "Client {} '{}' → single-line ({:?}), writing {}: {:02x?}",
478 self.client_addr,
479 command.trim(),
480 _response_code.status_code(),
481 crate::formatting::format_bytes(n as u64),
482 &chunk_buffer[..n.min(50)]
483 )
484 };
485
486 if let Some(code) = _response_code.status_code() {
488 if code >= 400 {
489 debug!("{}", log_msg); } else if msgid.is_some() {
491 warn!("{}", log_msg); } else {
493 debug!("{}", log_msg);
494 }
495 } else {
496 debug!("{}", log_msg);
497 }
498
499 match client_write.write_all(&chunk_buffer[..n]).await {
500 Ok(_) => n as u64,
501 Err(e) => return (Err(e.into()), got_backend_data),
502 }
503 };
504
505 backend_to_client_bytes.add(bytes_written as usize);
506
507 if let Some(id) = msgid {
508 debug!(
509 "Client {} ARTICLE {} completed: wrote {} bytes to client",
510 self.client_addr, id, bytes_written
511 );
512 }
513
514 (Ok(()), got_backend_data)
515 }
516
517 async fn route_command_with_error_handling(
526 &self,
527 router: &BackendSelector,
528 command: &str,
529 client_write: &mut tokio::net::tcp::WriteHalf<'_>,
530 client_to_backend_bytes: &mut BytesTransferred,
531 backend_to_client_bytes: &mut BytesTransferred,
532 trimmed: &str,
533 ) -> Result<()> {
534 if let Err(e) = self
535 .route_and_execute_command(
536 router,
537 command,
538 client_write,
539 client_to_backend_bytes,
540 backend_to_client_bytes,
541 )
542 .await
543 {
544 use crate::session::error_classification::ErrorClassifier;
545 let (up, down) = (
546 crate::formatting::format_bytes(client_to_backend_bytes.as_u64()),
547 crate::formatting::format_bytes(backend_to_client_bytes.as_u64()),
548 );
549
550 if ErrorClassifier::is_client_disconnect(&e) {
552 debug!(
553 "Client {} command '{}' resulted in disconnect (already logged by streaming layer) | ↑{} ↓{}",
554 self.client_addr, trimmed, up, down
555 );
556 } else {
557 if ErrorClassifier::is_authentication_error(&e) {
558 error!(
559 "Client {} command '{}' authentication error: {} | ↑{} ↓{}",
560 self.client_addr, trimmed, e, up, down
561 );
562 } else {
563 warn!(
564 "Client {} error routing '{}': {} | ↑{} ↓{}",
565 self.client_addr, trimmed, e, up, down
566 );
567 }
568 let _ = client_write.write_all(BACKEND_ERROR).await;
569 backend_to_client_bytes.add(BACKEND_ERROR.len());
570 }
571
572 if (*client_to_backend_bytes + *backend_to_client_bytes).as_u64()
574 < SMALL_TRANSFER_THRESHOLD
575 {
576 debug!(
577 "ERROR SUMMARY for small transfer - Client {}: Command '{}' failed with {}. \
578 Total session: {} bytes to backend, {} bytes from backend. \
579 This appears to be a short session (test connection?). \
580 Check debug logs above for full command/response hex dumps.",
581 self.client_addr, trimmed, e, client_to_backend_bytes, backend_to_client_bytes
582 );
583 }
584
585 return Err(e);
586 }
587 Ok(())
588 }
589}
590
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `extract_message_id` yields `expected` for `command`.
    fn check_msgid(command: &str, expected: Option<&str>) {
        assert_eq!(extract_message_id(command), expected);
    }

    #[test]
    fn test_extract_message_id_valid() {
        check_msgid("BODY <test@example.com>", Some("<test@example.com>"));
    }

    #[test]
    fn test_extract_message_id_article_command() {
        check_msgid("ARTICLE <1234@news.server>", Some("<1234@news.server>"));
    }

    #[test]
    fn test_extract_message_id_head_command() {
        check_msgid("HEAD <article@host.domain>", Some("<article@host.domain>"));
    }

    #[test]
    fn test_extract_message_id_stat_command() {
        check_msgid("STAT <msg@example.org>", Some("<msg@example.org>"));
    }

    #[test]
    fn test_extract_message_id_no_brackets() {
        // Group names carry no angle-bracketed id.
        check_msgid("GROUP comp.lang.rust", None);
    }

    #[test]
    fn test_extract_message_id_malformed() {
        // An unterminated id must not be reported.
        check_msgid("BODY <incomplete", None);
    }

    #[test]
    fn test_extract_message_id_with_extra_text() {
        check_msgid("BODY <msg@host> extra stuff", Some("<msg@host>"));
    }

    #[test]
    fn test_extract_message_id_empty_brackets() {
        check_msgid("BODY <>", Some("<>"));
    }

    #[test]
    fn test_extract_message_id_lowercase_command() {
        check_msgid("body <test@example.com>", Some("<test@example.com>"));
    }

    #[test]
    fn test_extract_message_id_mixed_case() {
        // The id is returned verbatim; its case is preserved.
        check_msgid("BoDy <TeSt@ExAmPlE.cOm>", Some("<TeSt@ExAmPlE.cOm>"));
    }
}