1use anyhow::Result;
7use std::net::SocketAddr;
8use std::sync::Arc;
9use tokio::io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt};
10use tokio::net::TcpStream;
11use tracing::{debug, error, warn};
12
13use crate::auth::AuthHandler;
14use crate::command::{AuthAction, CommandAction, CommandHandler};
15use crate::constants::buffer::{COMMAND_SIZE, STREAMING_CHUNK_SIZE};
16use crate::constants::protocol::{
17 BACKEND_ERROR, CONNECTION_CLOSING, PROXY_GREETING_PCR, TERMINATOR_TAIL_SIZE,
18};
19use crate::constants::stateless_proxy::NNTP_COMMAND_NOT_SUPPORTED;
20use crate::pool::BufferPool;
21use crate::router::BackendSelector;
22use crate::streaming::StreamHandler;
23use crate::types::ClientId;
24
25pub struct ClientSession {
27 client_addr: SocketAddr,
28 buffer_pool: BufferPool,
29 client_id: ClientId,
31 router: Option<Arc<BackendSelector>>,
33}
34
35impl ClientSession {
36 #[must_use]
38 pub fn new(client_addr: SocketAddr, buffer_pool: BufferPool) -> Self {
39 Self {
40 client_addr,
41 buffer_pool,
42 client_id: ClientId::new(),
43 router: None,
44 }
45 }
46
47 #[must_use]
49 pub fn new_with_router(
50 client_addr: SocketAddr,
51 buffer_pool: BufferPool,
52 router: Arc<BackendSelector>,
53 ) -> Self {
54 Self {
55 client_addr,
56 buffer_pool,
57 client_id: ClientId::new(),
58 router: Some(router),
59 }
60 }
61
62 #[must_use]
64 #[inline]
65 pub fn client_id(&self) -> ClientId {
66 self.client_id
67 }
68
69 #[must_use]
71 #[inline]
72 pub fn is_per_command_routing(&self) -> bool {
73 self.router.is_some()
74 }
75
76 pub async fn handle_with_pooled_backend<T>(
80 &self,
81 mut client_stream: TcpStream,
82 backend_conn: T,
83 ) -> Result<(u64, u64)>
84 where
85 T: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
86 {
87 use tokio::io::BufReader;
88
89 let (client_read, mut client_write) = client_stream.split();
91 let (mut backend_read, mut backend_write) = tokio::io::split(backend_conn);
92 let mut client_reader = BufReader::new(client_read);
93
94 let mut client_to_backend_bytes = 0u64;
95 let mut backend_to_client_bytes = 0u64;
96
97 let mut line = String::with_capacity(COMMAND_SIZE);
99
100 debug!("Client {} session loop starting", self.client_addr);
101
102 loop {
104 line.clear();
105 let mut buffer = self.buffer_pool.get_buffer().await;
106
107 tokio::select! {
108 result = client_reader.read_line(&mut line) => {
110 match result {
111 Ok(0) => {
112 debug!("Client {} disconnected (0 bytes read)", self.client_addr);
113 self.buffer_pool.return_buffer(buffer).await;
114 break; }
116 Ok(n) => {
117 debug!("Client {} sent {} bytes: {:?}", self.client_addr, n, line.trim());
118 let trimmed = line.trim();
119 debug!("Client {} command: {}", self.client_addr, trimmed);
120
121 match CommandHandler::handle_command(&line) {
123 CommandAction::InterceptAuth(auth_action) => {
124 let response = match auth_action {
125 AuthAction::RequestPassword => AuthHandler::user_response(),
126 AuthAction::AcceptAuth => AuthHandler::pass_response(),
127 };
128 client_write.write_all(response).await?;
129 backend_to_client_bytes += response.len() as u64;
130 debug!("Intercepted auth command for client {}", self.client_addr);
131 }
132 CommandAction::Reject(_reason) => {
133 warn!("Rejecting command from client {}: {}", self.client_addr, trimmed);
134 client_write.write_all(NNTP_COMMAND_NOT_SUPPORTED).await?;
135 backend_to_client_bytes += NNTP_COMMAND_NOT_SUPPORTED.len() as u64;
136 }
137 CommandAction::ForwardHighThroughput => {
138 backend_write.write_all(line.as_bytes()).await?;
140 client_to_backend_bytes += line.len() as u64;
141 debug!("Client {} switching to high-throughput mode", self.client_addr);
142
143 self.buffer_pool.return_buffer(buffer).await;
145
146 return StreamHandler::high_throughput_transfer(
148 client_reader,
149 client_write,
150 backend_read,
151 backend_write,
152 client_to_backend_bytes,
153 backend_to_client_bytes,
154 ).await;
155 }
156 CommandAction::ForwardStateless => {
157 backend_write.write_all(line.as_bytes()).await?;
159 client_to_backend_bytes += line.len() as u64;
160 }
161 }
162 }
163 Err(e) => {
164 warn!("Error reading from client {}: {}", self.client_addr, e);
165 self.buffer_pool.return_buffer(buffer).await;
166 break;
167 }
168 }
169 }
170
171 result = backend_read.read(&mut buffer) => {
173 match result {
174 Ok(0) => {
175 self.buffer_pool.return_buffer(buffer).await;
176 break; }
178 Ok(n) => {
179 client_write.write_all(&buffer[..n]).await?;
180 backend_to_client_bytes += n as u64;
181 }
182 Err(e) => {
183 warn!("Error reading from backend for client {}: {}", self.client_addr, e);
184 self.buffer_pool.return_buffer(buffer).await;
185 break;
186 }
187 }
188 }
189 }
190
191 self.buffer_pool.return_buffer(buffer).await;
192 }
193
194 Ok((client_to_backend_bytes, backend_to_client_bytes))
195 }
196
197 pub async fn handle_per_command_routing(
200 &self,
201 mut client_stream: TcpStream,
202 ) -> Result<(u64, u64)> {
203 use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
204
205 debug!(
206 "Client {} starting per-command routing session",
207 self.client_addr
208 );
209
210 let router = self
211 .router
212 .as_ref()
213 .ok_or_else(|| anyhow::anyhow!("Per-command routing mode requires a router"))?;
214
215 let (client_read, mut client_write) = client_stream.split();
216 let mut client_reader = BufReader::new(client_read);
217
218 let mut client_to_backend_bytes = 0u64;
219 let mut backend_to_client_bytes = 0u64;
220
221 debug!(
223 "Client {} sending greeting: {} | hex: {:02x?}",
224 self.client_addr,
225 String::from_utf8_lossy(PROXY_GREETING_PCR),
226 PROXY_GREETING_PCR
227 );
228
229 if let Err(e) = client_write.write_all(PROXY_GREETING_PCR).await {
230 debug!(
231 "Client {} failed to send greeting: {} (kind: {:?}). \
232 This suggests the client disconnected immediately after connecting.",
233 self.client_addr, e, e.kind()
234 );
235 return Err(e.into());
236 }
237 backend_to_client_bytes += PROXY_GREETING_PCR.len() as u64;
238
239 debug!(
240 "Client {} sent greeting successfully, entering command loop",
241 self.client_addr
242 );
243
244 let mut command = String::with_capacity(COMMAND_SIZE);
246
247 loop {
249 command.clear();
250
251 match client_reader.read_line(&mut command).await {
252 Ok(0) => {
253 debug!("Client {} disconnected", self.client_addr);
254 break; }
256 Ok(n) => {
257 client_to_backend_bytes += n as u64;
258 let trimmed = command.trim();
259
260 debug!(
261 "Client {} received command ({} bytes): {} | hex: {:02x?}",
262 self.client_addr, n, trimmed, command.as_bytes()
263 );
264
265 if trimmed.eq_ignore_ascii_case("QUIT") {
267 if let Err(e) = client_write.write_all(CONNECTION_CLOSING).await {
269 debug!(
270 "Failed to write CONNECTION_CLOSING to client {}: {}",
271 self.client_addr, e
272 );
273 }
274 backend_to_client_bytes += CONNECTION_CLOSING.len() as u64;
275 debug!("Client {} sent QUIT, closing connection", self.client_addr);
276 break;
277 }
278
279 match CommandHandler::handle_command(&command) {
281 CommandAction::InterceptAuth(auth_action) => {
282 let response = match auth_action {
284 AuthAction::RequestPassword => AuthHandler::user_response(),
285 AuthAction::AcceptAuth => AuthHandler::pass_response(),
286 };
287 client_write.write_all(response).await?;
288 backend_to_client_bytes += response.len() as u64;
289 continue;
290 }
291 CommandAction::Reject(reason) => {
292 warn!(
293 "Rejecting command from client {}: {} ({})",
294 self.client_addr, trimmed, reason
295 );
296 client_write.write_all(NNTP_COMMAND_NOT_SUPPORTED).await?;
297 backend_to_client_bytes += NNTP_COMMAND_NOT_SUPPORTED.len() as u64;
298 continue;
299 }
300 CommandAction::ForwardStateless | CommandAction::ForwardHighThroughput => {
301 match self
303 .route_and_execute_command(
304 router,
305 &command,
306 &mut client_write,
307 &mut client_to_backend_bytes,
308 &mut backend_to_client_bytes,
309 )
310 .await
311 {
312 Ok(()) => {}
313 Err(e) => {
314 if let Some(io_err) = e.downcast_ref::<std::io::Error>() {
316 match io_err.kind() {
317 std::io::ErrorKind::BrokenPipe => {
318 warn!(
319 "Client {} disconnected unexpectedly while routing command '{}' (broken pipe). \
320 Session stats: {} bytes sent to backend, {} bytes received from backend. \
321 This usually indicates the client closed the connection before receiving the response.",
322 self.client_addr, trimmed, client_to_backend_bytes, backend_to_client_bytes
323 );
324 }
325 std::io::ErrorKind::ConnectionReset => {
326 warn!(
327 "Client {} connection reset while routing command '{}'. \
328 Session stats: {} bytes sent to backend, {} bytes received from backend. \
329 This usually indicates a network issue or client crash.",
330 self.client_addr, trimmed, client_to_backend_bytes, backend_to_client_bytes
331 );
332 }
333 std::io::ErrorKind::ConnectionAborted => {
334 warn!(
335 "Client {} connection aborted while routing command '{}'. \
336 Session stats: {} bytes sent to backend, {} bytes received from backend. \
337 This usually indicates the connection was terminated by the local system.",
338 self.client_addr, trimmed, client_to_backend_bytes, backend_to_client_bytes
339 );
340 }
341 _ => {
342 error!(
343 "I/O error routing command '{}' for client {}: {} (kind: {:?}). \
344 Session stats: {} bytes sent to backend, {} bytes received from backend.",
345 trimmed, self.client_addr, e, io_err.kind(), client_to_backend_bytes, backend_to_client_bytes
346 );
347 }
348 }
349 } else {
350 error!(
351 "Error routing command '{}' for client {}: {}. \
352 Session stats: {} bytes sent to backend, {} bytes received from backend.",
353 trimmed, self.client_addr, e, client_to_backend_bytes, backend_to_client_bytes
354 );
355 }
356
357 let _ = client_write.write_all(BACKEND_ERROR).await;
359 backend_to_client_bytes += BACKEND_ERROR.len() as u64;
360
361 if client_to_backend_bytes + backend_to_client_bytes < 500 {
363 debug!(
364 "ERROR SUMMARY for small transfer - Client {}: \
365 Command '{}' failed with {}. \
366 Total session: {} bytes to backend, {} bytes from backend. \
367 This appears to be a short session (test connection?). \
368 Check debug logs above for full command/response hex dumps.",
369 self.client_addr, trimmed, e, client_to_backend_bytes, backend_to_client_bytes
370 );
371 }
372 }
373 }
374 }
375 }
376 }
377 Err(e) => {
378 match e.kind() {
380 std::io::ErrorKind::UnexpectedEof => {
381 debug!(
382 "Client {} closed connection (EOF). Session stats: {} bytes sent to backend, {} bytes received from backend.",
383 self.client_addr, client_to_backend_bytes, backend_to_client_bytes
384 );
385 }
386 std::io::ErrorKind::BrokenPipe => {
387 debug!(
388 "Client {} connection broken pipe while reading. Session stats: {} bytes sent to backend, {} bytes received from backend.",
389 self.client_addr, client_to_backend_bytes, backend_to_client_bytes
390 );
391 }
392 std::io::ErrorKind::ConnectionReset => {
393 warn!(
394 "Client {} connection reset while reading. Session stats: {} bytes sent to backend, {} bytes received from backend.",
395 self.client_addr, client_to_backend_bytes, backend_to_client_bytes
396 );
397 }
398 _ => {
399 warn!(
400 "Error reading from client {}: {} (kind: {:?}). Session stats: {} bytes sent to backend, {} bytes received from backend.",
401 self.client_addr, e, e.kind(), client_to_backend_bytes, backend_to_client_bytes
402 );
403 }
404 }
405 break;
406 }
407 }
408 }
409
410 if client_to_backend_bytes + backend_to_client_bytes < 500 {
412 debug!(
413 "SESSION SUMMARY for Client {}: Small transfer completed successfully. \
414 {} bytes sent to backend, {} bytes received from backend. \
415 This appears to be a short session (likely test connection). \
416 Check debug logs above for individual command/response details.",
417 self.client_addr, client_to_backend_bytes, backend_to_client_bytes
418 );
419 }
420
421 Ok((client_to_backend_bytes, backend_to_client_bytes))
422 }
423
424 async fn route_and_execute_command(
426 &self,
427 router: &BackendSelector,
428 command: &str,
429 client_write: &mut tokio::net::tcp::WriteHalf<'_>,
430 client_to_backend_bytes: &mut u64,
431 backend_to_client_bytes: &mut u64,
432 ) -> Result<()> {
433 use tokio::io::AsyncWriteExt;
434
435 let backend_id = router.route_command_sync(self.client_id, command)?;
437
438 debug!(
439 "Client {} routed command to backend {:?}: {}",
440 self.client_addr,
441 backend_id,
442 command.trim()
443 );
444
445 let provider = router
447 .get_backend_provider(backend_id)
448 .ok_or_else(|| anyhow::anyhow!("Backend {:?} not found", backend_id))?;
449
450 debug!(
451 "Client {} getting pooled connection for backend {:?}",
452 self.client_addr, backend_id
453 );
454 let mut pooled_conn = provider.get_pooled_connection().await?;
458 debug!(
459 "Client {} got pooled connection for backend {:?}",
460 self.client_addr, backend_id
461 );
462
463 debug!(
467 "Client {} forwarding command to backend {:?} ({} bytes): {} | hex: {:02x?}",
468 self.client_addr,
469 backend_id,
470 command.len(),
471 command.trim(),
472 command.as_bytes()
473 );
474 pooled_conn.write_all(command.as_bytes()).await?;
475 *client_to_backend_bytes += command.len() as u64;
476 debug!(
477 "Client {} command sent to backend {:?}",
478 self.client_addr, backend_id
479 );
480
481 debug!(
483 "Client {} reading response from backend {:?}",
484 self.client_addr, backend_id
485 );
486
487 use tokio::io::AsyncReadExt;
489
490 let mut chunk = vec![0u8; STREAMING_CHUNK_SIZE];
491 let mut total_bytes = 0;
492
493 let n = pooled_conn.read(&mut chunk).await?;
495 if n == 0 {
496 return Err(anyhow::anyhow!("Backend connection closed unexpectedly"));
497 }
498
499 debug!(
500 "Client {} received backend response chunk ({} bytes): {} | hex: {:02x?}",
501 self.client_addr, n,
502 String::from_utf8_lossy(&chunk[..n.min(100)]), &chunk[..n.min(32)] );
505
506 let first_newline = chunk[..n].iter().position(|&b| b == b'\n').unwrap_or(n);
508
509 let is_multiline = first_newline >= 3
516 && chunk[0] == b'2'
517 && (chunk[1] == b'1' || chunk[1] == b'2' || chunk[1] == b'3');
518
519 if let Ok(first_line_str) = std::str::from_utf8(&chunk[..first_newline.min(n)]) {
521 debug!(
522 "Client {} got first line from backend {:?}: {}",
523 self.client_addr,
524 backend_id,
525 first_line_str.trim()
526 );
527 }
528
529 debug!(
531 "Client {} sending first chunk ({} bytes): {} | hex: {:02x?}",
532 self.client_addr, n,
533 String::from_utf8_lossy(&chunk[..n.min(100)]), &chunk[..n.min(32)] );
536 client_write.write_all(&chunk[..n]).await?;
537 total_bytes += n;
538
539 if is_multiline {
540 let has_terminator = if n >= 5 {
542 chunk[n - 5..n] == *b"\r\n.\r\n" || (n >= 3 && chunk[n - 3..n] == *b"\n.\n")
543 } else {
544 n >= 3 && chunk[..n] == *b"\n.\n"
545 };
546
547 if !has_terminator {
548 let mut chunk1 = chunk; let mut chunk2 = vec![0u8; STREAMING_CHUNK_SIZE]; let mut tail: [u8; TERMINATOR_TAIL_SIZE] = [0; TERMINATOR_TAIL_SIZE]; let mut tail_len: usize = 0; if n >= TERMINATOR_TAIL_SIZE {
558 tail.copy_from_slice(&chunk1[n - TERMINATOR_TAIL_SIZE..n]);
559 tail_len = TERMINATOR_TAIL_SIZE;
560 } else if n > 0 {
561 tail[..n].copy_from_slice(&chunk1[..n]);
562 tail_len = n;
563 }
564
565 let first_has_term = if n >= 5 {
567 chunk1[n - 5..n] == *b"\r\n.\r\n" || (n >= 3 && chunk1[n - 3..n] == *b"\n.\n")
568 } else {
569 n >= 3 && chunk1[..n] == *b"\n.\n"
570 };
571
572 if !first_has_term {
573 let mut current_chunk = &mut chunk1;
575 let mut next_chunk = &mut chunk2;
576
577 let mut current_n = pooled_conn.read(next_chunk).await?;
579 if current_n > 0 {
580 std::mem::swap(&mut current_chunk, &mut next_chunk);
581
582 loop {
583 debug!(
585 "Client {} sending streaming chunk ({} bytes): {} | hex: {:02x?}",
586 self.client_addr, current_n,
587 String::from_utf8_lossy(¤t_chunk[..current_n.min(100)]), ¤t_chunk[..current_n.min(32)] );
590 client_write.write_all(¤t_chunk[..current_n]).await?;
591 total_bytes += current_n;
592
593 let has_term = if current_n >= 5 {
595 current_chunk[current_n - 5..current_n] == *b"\r\n.\r\n"
596 || (current_n >= 3
597 && current_chunk[current_n - 3..current_n] == *b"\n.\n")
598 } else {
599 current_n >= 3 && current_chunk[..current_n] == *b"\n.\n"
600 };
601
602 if has_term {
603 break; }
605
606 let has_spanning_term = if tail_len >= 2 && (1..=4).contains(¤t_n)
609 {
610 let mut check_buf = [0u8; 9]; check_buf[..tail_len].copy_from_slice(&tail[..tail_len]);
613 let curr_copy = current_n.min(5);
614 check_buf[tail_len..tail_len + curr_copy]
615 .copy_from_slice(¤t_chunk[..curr_copy]);
616 let total = tail_len + curr_copy;
617
618 (total >= 5 && check_buf[total - 5..total] == *b"\r\n.\r\n")
619 || (total >= 3 && check_buf[total - 3..total] == *b"\n.\n")
620 } else {
621 false
622 };
623
624 if has_spanning_term {
625 break; }
627
628 if current_n >= TERMINATOR_TAIL_SIZE {
630 tail.copy_from_slice(
631 ¤t_chunk[current_n - TERMINATOR_TAIL_SIZE..current_n],
632 );
633 tail_len = TERMINATOR_TAIL_SIZE;
634 } else if current_n > 0 {
635 tail[..current_n].copy_from_slice(¤t_chunk[..current_n]);
636 tail_len = current_n;
637 }
638
639 let next_n = pooled_conn.read(next_chunk).await?;
641 if next_n == 0 {
642 break; }
644
645 std::mem::swap(&mut current_chunk, &mut next_chunk);
647 current_n = next_n;
648 }
649 }
650 }
651 }
652 }
653
654 debug!(
655 "Client {} forwarded response ({} bytes) to client",
656 self.client_addr, total_bytes
657 );
658 *backend_to_client_bytes += total_bytes as u64;
659
660 router.complete_command_sync(backend_id);
662
663 Ok(())
664 }
665}
666
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::{IpAddr, Ipv4Addr};

    /// Starts a one-shot mock NNTP backend. It writes `banner` to the first
    /// connection, then waits for a single read. Returns its address.
    ///
    /// The listener is bound before this returns, so the OS queues connection
    /// attempts even before `accept` runs; no startup sleep is needed.
    async fn spawn_mock_backend(banner: &'static [u8]) -> SocketAddr {
        use tokio::io::{AsyncReadExt, AsyncWriteExt};
        use tokio::net::TcpListener;

        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();

        tokio::spawn(async move {
            if let Ok((mut stream, _)) = listener.accept().await {
                let _ = stream.write_all(banner).await;
                let mut buf = [0u8; 1024];
                let _ = stream.read(&mut buf).await;
            }
        });

        addr
    }

    #[test]
    fn test_client_session_creation() {
        let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 8080);
        let buffer_pool = BufferPool::new(1024, 4);
        let session = ClientSession::new(addr, buffer_pool.clone());

        assert_eq!(session.client_addr.port(), 8080);
        assert_eq!(
            session.client_addr.ip(),
            IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1))
        );
    }

    #[test]
    fn test_client_session_with_different_ports() {
        let buffer_pool = BufferPool::new(1024, 4);

        let addr1 = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 8080);
        let session1 = ClientSession::new(addr1, buffer_pool.clone());

        let addr2 = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 9090);
        let session2 = ClientSession::new(addr2, buffer_pool.clone());

        assert_ne!(session1.client_addr.port(), session2.client_addr.port());
        assert_eq!(session1.client_addr.port(), 8080);
        assert_eq!(session2.client_addr.port(), 9090);
    }

    #[test]
    fn test_client_session_with_ipv6() {
        let buffer_pool = BufferPool::new(1024, 4);
        let addr = SocketAddr::new(IpAddr::V6("::1".parse().unwrap()), 8119);
        let session = ClientSession::new(addr, buffer_pool);

        assert_eq!(session.client_addr.port(), 8119);
        assert!(session.client_addr.is_ipv6());
    }

    #[test]
    fn test_buffer_pool_cloning() {
        let buffer_pool = BufferPool::new(8192, 10);
        let buffer_pool_clone = buffer_pool.clone();

        let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1)), 1234);
        let session1 = ClientSession::new(addr, buffer_pool);
        let session2 = ClientSession::new(addr, buffer_pool_clone);

        // Sessions built from clones of one pool, even for the same peer
        // address, are still distinct clients.
        assert_eq!(session1.client_addr, session2.client_addr);
        assert_ne!(session1.client_id(), session2.client_id());
    }

    #[test]
    fn test_session_addr_formatting() {
        let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1)), 5555);
        let buffer_pool = BufferPool::new(1024, 4);
        let session = ClientSession::new(addr, buffer_pool);

        let addr_str = format!("{}", session.client_addr);
        assert!(addr_str.contains("10.0.0.1"));
        assert!(addr_str.contains("5555"));
    }

    #[test]
    fn test_multiple_sessions_same_buffer_pool() {
        let buffer_pool = BufferPool::new(4096, 8);
        let sessions: Vec<_> = (0..5)
            .map(|i| {
                let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 8000 + i);
                ClientSession::new(addr, buffer_pool.clone())
            })
            .collect();

        assert_eq!(sessions.len(), 5);
        for (i, session) in sessions.iter().enumerate() {
            assert_eq!(session.client_addr.port(), 8000 + i as u16);
        }
    }

    #[test]
    fn test_loopback_address() {
        let buffer_pool = BufferPool::new(1024, 4);
        let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 8119);
        let session = ClientSession::new(addr, buffer_pool);

        assert!(session.client_addr.ip().is_loopback());
    }

    #[test]
    fn test_unspecified_address() {
        let buffer_pool = BufferPool::new(1024, 4);
        let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0);
        let session = ClientSession::new(addr, buffer_pool);

        assert!(session.client_addr.ip().is_unspecified());
        assert_eq!(session.client_addr.port(), 0);
    }

    #[test]
    fn test_session_without_router() {
        let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 8080);
        let buffer_pool = BufferPool::new(1024, 4);
        let session = ClientSession::new(addr, buffer_pool);

        assert!(!session.is_per_command_routing());
        assert_eq!(session.client_addr.port(), 8080);
    }

    #[test]
    fn test_session_with_router() {
        let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 8080);
        let buffer_pool = BufferPool::new(1024, 4);
        let router = Arc::new(BackendSelector::new());
        let session = ClientSession::new_with_router(addr, buffer_pool, router);

        assert!(session.is_per_command_routing());
        assert_eq!(session.client_addr.port(), 8080);
    }

    #[test]
    fn test_client_id_uniqueness() {
        let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 8080);
        let buffer_pool = BufferPool::new(1024, 4);

        let session1 = ClientSession::new(addr, buffer_pool.clone());
        let session2 = ClientSession::new(addr, buffer_pool);

        assert_ne!(session1.client_id(), session2.client_id());
    }

    #[tokio::test]
    async fn test_quit_command_per_command_routing() {
        use tokio::io::{AsyncReadExt, AsyncWriteExt};
        use tokio::net::TcpListener;

        let backend_addr = spawn_mock_backend(b"200 Mock Server Ready\r\n").await;

        let mut router = BackendSelector::new();
        let provider = crate::pool::DeadpoolConnectionProvider::new(
            "127.0.0.1".to_string(),
            backend_addr.port(),
            "test-backend".to_string(),
            2,
            None,
            None,
        );
        router.add_backend(
            crate::types::BackendId::from_index(0),
            "test-backend".to_string(),
            provider,
        );

        let client_listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let client_addr = client_listener.local_addr().unwrap();

        let buffer_pool = BufferPool::new(1024, 4);
        let session = ClientSession::new_with_router(client_addr, buffer_pool, Arc::new(router));

        let client_handle = tokio::spawn(async move {
            let mut client = tokio::net::TcpStream::connect(client_addr).await.unwrap();

            let mut greeting = [0u8; 256];
            let n = client.read(&mut greeting).await.unwrap();
            assert!(n > 0);

            client.write_all(b"QUIT\r\n").await.unwrap();

            let mut response = [0u8; 256];
            let _ = client.read(&mut response).await;
        });

        let (client_stream, _) = client_listener.accept().await.unwrap();
        let (sent, received) = session
            .handle_per_command_routing(client_stream)
            .await
            .expect("QUIT handling should not return error");

        assert!(sent > 0, "Should have sent bytes (QUIT command)");
        assert!(received > 0, "Should have received bytes (greeting)");

        // Propagate assertion failures from the client task instead of
        // silently discarding its JoinError.
        client_handle.await.expect("client task panicked");
    }

    #[tokio::test]
    async fn test_quit_command_closes_connection_cleanly() {
        use tokio::io::{AsyncReadExt, AsyncWriteExt};
        use tokio::net::TcpListener;

        let backend_addr = spawn_mock_backend(b"200 Ready\r\n").await;

        let mut router = BackendSelector::new();
        let provider = crate::pool::DeadpoolConnectionProvider::new(
            "127.0.0.1".to_string(),
            backend_addr.port(),
            "test".to_string(),
            1,
            None,
            None,
        );
        router.add_backend(
            crate::types::BackendId::from_index(0),
            "test".to_string(),
            provider,
        );

        let client_listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let client_addr = client_listener.local_addr().unwrap();

        let buffer_pool = BufferPool::new(1024, 4);
        let session = ClientSession::new_with_router(client_addr, buffer_pool, Arc::new(router));

        let client_handle = tokio::spawn(async move {
            let mut client = tokio::net::TcpStream::connect(client_addr).await.unwrap();

            let mut buf = [0u8; 256];
            let n = client.read(&mut buf).await.unwrap();
            assert!(n > 0, "Should receive greeting");

            client.write_all(b"QUIT\r\n").await.unwrap();

            let n = client.read(&mut buf).await.unwrap();
            let response = String::from_utf8_lossy(&buf[..n]);
            assert!(
                response.contains("205"),
                "Should receive 205 closing response"
            );
        });

        let (client_stream, _) = client_listener.accept().await.unwrap();
        let result = session.handle_per_command_routing(client_stream).await;

        assert!(result.is_ok(), "Session should handle QUIT cleanly");

        client_handle.await.expect("client task panicked");
    }
}