1use anyhow::Result;
2use crossbeam::queue::SegQueue;
3use serde::{Deserialize, Serialize};
4use std::collections::HashMap;
5use std::net::SocketAddr;
6use std::sync::Arc;
7use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
8use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
9use tokio::net::{TcpSocket, TcpStream};
10use tokio::sync::Semaphore;
11use tracing::{debug, error, info, warn};
12
/// Serde default for `ServerConfig::max_connections`: the number of concurrent
/// client sessions a backend accepts when the config file omits the field.
fn default_max_connections() -> u32 {
    const DEFAULT_MAX_CONNECTIONS: u32 = 10;
    DEFAULT_MAX_CONNECTIONS
}
17
18#[derive(Debug, Clone)]
21pub struct BufferPool {
22 pool: Arc<SegQueue<Vec<u8>>>,
23 buffer_size: usize,
24 max_pool_size: usize,
25 pool_size: Arc<AtomicUsize>,
26}
27
28impl BufferPool {
29 fn create_aligned_buffer(size: usize) -> Vec<u8> {
31 let page_size = 4096;
33 let aligned_size = size.div_ceil(page_size) * page_size;
34
35 let mut buffer = Vec::with_capacity(aligned_size);
37 buffer.resize(size, 0);
38 buffer
39 }
40
41 pub fn new(buffer_size: usize, max_pool_size: usize) -> Self {
42 let pool = Arc::new(SegQueue::new());
43 let pool_size = Arc::new(AtomicUsize::new(0));
44
45 info!(
47 "Pre-allocating {} buffers of {}KB each ({}MB total)",
48 max_pool_size,
49 buffer_size / 1024,
50 (max_pool_size * buffer_size) / (1024 * 1024)
51 );
52
53 for _ in 0..max_pool_size {
54 let buffer = Self::create_aligned_buffer(buffer_size);
55 pool.push(buffer);
56 pool_size.fetch_add(1, Ordering::Relaxed);
57 }
58
59 info!("Buffer pool pre-allocation complete");
60
61 Self {
62 pool,
63 buffer_size,
64 max_pool_size,
65 pool_size,
66 }
67 }
68
69 pub async fn get_buffer(&self) -> Vec<u8> {
71 if let Some(mut buffer) = self.pool.pop() {
72 self.pool_size.fetch_sub(1, Ordering::Relaxed);
73 buffer.clear();
75 buffer.resize(self.buffer_size, 0);
76 buffer
77 } else {
78 Self::create_aligned_buffer(self.buffer_size)
80 }
81 }
82
83 pub async fn return_buffer(&self, buffer: Vec<u8>) {
85 if buffer.len() == self.buffer_size {
86 let current_size = self.pool_size.load(Ordering::Relaxed);
87 if current_size < self.max_pool_size {
88 self.pool.push(buffer);
89 self.pool_size.fetch_add(1, Ordering::Relaxed);
90 }
91 }
93 }
94}
95
/// Top-level proxy configuration, deserialized from TOML by `load_config`.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct Config {
    /// Backend servers. `NntpProxy::new` rejects an empty list, and clients are
    /// spread across the entries round-robin.
    pub servers: Vec<ServerConfig>,
}
101
/// Connection settings for a single backend NNTP server.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ServerConfig {
    /// Hostname or IP address. It is joined with `port` as `"{host}:{port}"`
    /// when connecting.
    pub host: String,
    /// TCP port of the backend server.
    pub port: u16,
    /// Identifier used in log messages and as the key for this server's
    /// connection limiter.
    pub name: String,
    /// Optional `AUTHINFO USER` name. Backend authentication is attempted only
    /// when both this and `password` are set.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub username: Option<String>,
    /// Optional `AUTHINFO PASS` secret.
    /// NOTE(review): the derived `Debug` impl prints this field in clear text;
    /// confirm `ServerConfig` is never `{:?}`-logged.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub password: Option<String>,
    /// Maximum number of concurrent client sessions routed to this server.
    /// When the field is omitted, it falls back to `default_max_connections()`.
    #[serde(default = "default_max_connections")]
    pub max_connections: u32,
}
115
116#[derive(Debug)]
118pub struct PooledConnection {
119 pub stream: TcpStream,
120 pub server_name: String,
121 pub authenticated: bool,
122}
123
124impl PooledConnection {
125 pub fn new(
126 stream: TcpStream,
127 server_name: String,
128 authenticated: bool,
129 ) -> Self {
130 Self {
131 stream,
132 server_name,
133 authenticated,
134 }
135 }
136
137 pub fn into_stream(self) -> TcpStream {
138 self.stream
139 }
140
141 pub fn is_authenticated(&self) -> bool {
142 self.authenticated
143 }
144
145 pub fn server_name(&self) -> &str {
146 &self.server_name
147 }
148}
149
150#[derive(Debug, Clone)]
152pub struct ConnectionPool {
153 pool: Arc<SegQueue<TcpStream>>,
154 max_connections: usize,
155 active_connections: Arc<AtomicUsize>,
156 initialized: Arc<AtomicBool>,
157}
158
159impl ConnectionPool {
160 pub fn new(max_connections: usize) -> Self {
161 Self {
162 pool: Arc::new(SegQueue::new()),
163 max_connections,
164 active_connections: Arc::new(AtomicUsize::new(0)),
165 initialized: Arc::new(AtomicBool::new(false)),
166 }
167 }
168
169 async fn initialize_connections(&self, server: &ServerConfig) -> Result<()> {
171 if self
173 .initialized
174 .compare_exchange(false, true, Ordering::Acquire, Ordering::Relaxed)
175 .is_ok()
176 {
177 info!(
178 "Pre-establishing {} connections to {}",
179 self.max_connections, server.name
180 );
181
182 let mut tasks = Vec::new();
184 for i in 0..self.max_connections {
185 let server_addr = format!("{}:{}", server.host, server.port);
186 let server_name = server.name.clone();
187 let pool = Arc::clone(&self.pool);
188 let active_connections = Arc::clone(&self.active_connections);
189
190 let task = tokio::spawn(async move {
191 match Self::create_optimized_tcp_stream(&server_addr).await {
192 Ok(stream) => {
193 pool.push(stream);
194 active_connections.fetch_add(1, Ordering::Relaxed);
195 debug!("Pre-established connection {} to {}", i + 1, server_name);
196 Ok(())
197 }
198 Err(e) => {
199 warn!(
200 "Failed to pre-establish connection {} to {}: {}",
201 i + 1,
202 server_name,
203 e
204 );
205 Err(e)
206 }
207 }
208 });
209 tasks.push(task);
210 }
211
212 for task in tasks {
214 let _ = task.await;
215 }
216
217 let established = self.active_connections.load(Ordering::Relaxed);
218 info!(
219 "Successfully pre-established {}/{} connections to {} in parallel",
220 established, self.max_connections, server.name
221 );
222 }
223 Ok(())
224 }
225
226 pub async fn get_connection(
228 &self,
229 server: &ServerConfig,
230 _proxy: &NntpProxy,
231 ) -> Result<PooledConnection> {
232 if !self.initialized.load(Ordering::Acquire) {
234 self.initialize_connections(server).await?;
235 }
236
237 if let Some(stream) = self.pool.pop() {
239 let mut test_buf = [0u8; 1];
241 match stream.try_read(&mut test_buf) {
242 Ok(0) => {
243 self.active_connections.fetch_sub(1, Ordering::Relaxed);
245 info!(
246 "Pooled connection to {} was closed, creating new one",
247 server.name
248 );
249 }
250 Ok(_) => {
251 self.active_connections.fetch_sub(1, Ordering::Relaxed);
253 info!(
254 "Pooled connection to {} has unexpected data, creating new one",
255 server.name
256 );
257 }
258 Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => {
259 info!("Reusing pooled connection to {}", server.name);
261 return Ok(PooledConnection {
262 stream,
263 server_name: server.name.clone(),
264 authenticated: false, });
266 }
267 Err(_) => {
268 self.active_connections.fetch_sub(1, Ordering::Relaxed);
270 info!(
271 "Pooled connection to {} has error, creating new one",
272 server.name
273 );
274 }
275 }
276 }
277
278 info!("Creating new connection to {} for pooling", server.name);
280 let backend_addr = format!("{}:{}", server.host, server.port);
281 let stream = Self::create_optimized_tcp_stream(&backend_addr).await?;
282
283 let pooled_conn = PooledConnection::new(
285 stream,
286 server.name.clone(),
287 false,
288 );
289 Ok(pooled_conn)
290 }
291
292 async fn create_optimized_tcp_stream(addr: &str) -> Result<TcpStream, std::io::Error> {
294 use std::net::{SocketAddr, ToSocketAddrs};
295
296 let socket_addr: SocketAddr = addr.to_socket_addrs()?.next().ok_or_else(|| {
298 std::io::Error::new(std::io::ErrorKind::InvalidInput, "Invalid address")
299 })?;
300
301 let socket = if socket_addr.is_ipv4() {
303 TcpSocket::new_v4()?
304 } else {
305 TcpSocket::new_v6()?
306 };
307
308 socket.set_nodelay(true)?; #[cfg(target_os = "linux")]
313 {
314 use std::os::unix::io::AsRawFd;
315 let fd = socket.as_raw_fd();
316
317 let buffer_size = 2 * 1024 * 1024i32; unsafe {
320 libc::setsockopt(
321 fd,
322 libc::SOL_SOCKET,
323 libc::SO_RCVBUF,
324 &buffer_size as *const i32 as *const libc::c_void,
325 std::mem::size_of::<i32>() as u32,
326 );
327 libc::setsockopt(
328 fd,
329 libc::SOL_SOCKET,
330 libc::SO_SNDBUF,
331 &buffer_size as *const i32 as *const libc::c_void,
332 std::mem::size_of::<i32>() as u32,
333 );
334
335 let keepalive = 1i32;
337 libc::setsockopt(
338 fd,
339 libc::SOL_SOCKET,
340 libc::SO_KEEPALIVE,
341 &keepalive as *const i32 as *const libc::c_void,
342 std::mem::size_of::<i32>() as u32,
343 );
344
345 let keepalive_time = 60i32; let keepalive_interval = 10i32; let keepalive_probes = 3i32; libc::setsockopt(
351 fd,
352 libc::IPPROTO_TCP,
353 libc::TCP_KEEPIDLE,
354 &keepalive_time as *const i32 as *const libc::c_void,
355 std::mem::size_of::<i32>() as u32,
356 );
357 libc::setsockopt(
358 fd,
359 libc::IPPROTO_TCP,
360 libc::TCP_KEEPINTVL,
361 &keepalive_interval as *const i32 as *const libc::c_void,
362 std::mem::size_of::<i32>() as u32,
363 );
364 libc::setsockopt(
365 fd,
366 libc::IPPROTO_TCP,
367 libc::TCP_KEEPCNT,
368 &keepalive_probes as *const i32 as *const libc::c_void,
369 std::mem::size_of::<i32>() as u32,
370 );
371
372 let cork_flag = 1i32;
374 libc::setsockopt(
375 fd,
376 libc::IPPROTO_TCP,
377 libc::TCP_CORK,
378 &cork_flag as *const i32 as *const libc::c_void,
379 std::mem::size_of::<i32>() as u32,
380 );
381
382 let bbr_name = b"bbr\0";
385 let bbr_result = libc::setsockopt(
386 fd,
387 libc::IPPROTO_TCP,
388 libc::TCP_CONGESTION,
389 bbr_name.as_ptr() as *const libc::c_void,
390 bbr_name.len() as u32 - 1,
391 );
392
393 if bbr_result != 0 {
395 let cubic_name = b"cubic\0";
396 libc::setsockopt(
397 fd,
398 libc::IPPROTO_TCP,
399 libc::TCP_CONGESTION,
400 cubic_name.as_ptr() as *const libc::c_void,
401 cubic_name.len() as u32 - 1,
402 );
403 }
404
405 let tcp_fastopen = 1i32; libc::setsockopt(
408 fd,
409 libc::IPPROTO_TCP,
410 libc::TCP_FASTOPEN,
411 &tcp_fastopen as *const i32 as *const libc::c_void,
412 std::mem::size_of::<i32>() as u32,
413 );
414
415 let reuse_addr = 1i32;
417 libc::setsockopt(
418 fd,
419 libc::SOL_SOCKET,
420 libc::SO_REUSEADDR,
421 &reuse_addr as *const i32 as *const libc::c_void,
422 std::mem::size_of::<i32>() as u32,
423 );
424
425 let reuse_port = 1i32;
427 libc::setsockopt(
428 fd,
429 libc::SOL_SOCKET,
430 libc::SO_REUSEPORT,
431 &reuse_port as *const i32 as *const libc::c_void,
432 std::mem::size_of::<i32>() as u32,
433 );
434 }
435 }
436
437 socket.connect(socket_addr).await
439 }
440
441 pub async fn return_connection(&self, conn: PooledConnection) {
443 if self.active_connections.load(Ordering::Relaxed) >= self.max_connections {
444 info!("Pool is full, closing connection to {}", conn.server_name);
445 return; }
447
448 let mut test_buf = [0u8; 1];
450 match conn.stream.try_read(&mut test_buf) {
451 Ok(0) => {
452 info!(
454 "Connection to {} was closed by server, not returning to pool",
455 conn.server_name
456 );
457 self.active_connections.fetch_sub(1, Ordering::Relaxed);
458 }
459 Ok(_) => {
460 info!(
462 "Connection to {} has unexpected data, not returning to pool",
463 conn.server_name
464 );
465 self.active_connections.fetch_sub(1, Ordering::Relaxed);
466 }
467 Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => {
468 info!("Returning connection to {} to pool", conn.server_name);
470 self.pool.push(conn.stream);
471 }
473 Err(_) => {
474 info!(
476 "Connection to {} has error, not returning to pool",
477 conn.server_name
478 );
479 self.active_connections.fetch_sub(1, Ordering::Relaxed);
480 }
481 }
482 }
483}
484
485#[derive(Clone, Debug)]
486pub struct NntpProxy {
487 servers: Vec<ServerConfig>,
488 current_index: Arc<AtomicUsize>,
489 connection_semaphores: Arc<HashMap<String, Arc<Semaphore>>>,
491 connection_pool: ConnectionPool,
493 buffer_pool: BufferPool,
495}
496
497impl NntpProxy {
498 pub fn new(config: Config) -> Result<Self> {
499 if config.servers.is_empty() {
500 anyhow::bail!("No servers configured");
501 }
502
503 let mut connection_semaphores = HashMap::new();
505 for server in &config.servers {
506 let semaphore = Arc::new(Semaphore::new(server.max_connections as usize));
507 connection_semaphores.insert(server.name.clone(), semaphore);
508 info!(
509 "Server '{}' configured with max {} connections",
510 server.name, server.max_connections
511 );
512 }
513
514 Ok(Self {
515 servers: config.servers,
516 current_index: Arc::new(AtomicUsize::new(0)),
517 connection_semaphores: Arc::new(connection_semaphores),
518 connection_pool: ConnectionPool::new(32), buffer_pool: BufferPool::new(256 * 1024, 32), })
521 }
522
523 pub async fn prewarm_connections(&self) -> Result<()> {
525 info!("Pre-warming connections to all backend servers...");
526 for server in &self.servers {
527 for i in 0..4 {
529 match self.connection_pool.get_connection(server, self).await {
531 Ok(conn) => {
532 info!("Pre-warmed connection {}/4 to {}", i + 1, server.name);
533 self.connection_pool.return_connection(conn).await;
534 }
535 Err(e) => {
536 warn!("Failed to pre-warm connection to {}: {}", server.name, e);
537 }
538 }
539 }
540 }
541 info!("Connection pre-warming complete");
542 Ok(())
543 }
544
545 pub fn next_server(&self) -> &ServerConfig {
547 let index = self.current_index.fetch_add(1, Ordering::Relaxed);
548 &self.servers[index % self.servers.len()]
549 }
550
551 #[cfg(test)]
553 pub fn current_index(&self) -> usize {
554 self.current_index.load(Ordering::Relaxed) % self.servers.len()
555 }
556
557 #[cfg(test)]
559 pub fn reset_index(&self) {
560 self.current_index.store(0, Ordering::Relaxed);
561 }
562
563 pub fn servers(&self) -> &[ServerConfig] {
565 &self.servers
566 }
567
568 pub async fn handle_client(
569 &self,
570 mut client_stream: TcpStream,
571 client_addr: SocketAddr,
572 ) -> Result<()> {
573 info!("New client connection from {}", client_addr);
574
575 let server = self.next_server();
577 info!(
578 "Routing client {} to server {}:{}",
579 client_addr, server.host, server.port
580 );
581
582 let semaphore = self.connection_semaphores.get(&server.name).unwrap();
584 let _permit = match semaphore.try_acquire() {
585 Ok(permit) => {
586 info!(
587 "Acquired connection permit for server '{}' ({} remaining)",
588 server.name,
589 semaphore.available_permits()
590 );
591 permit
592 }
593 Err(_) => {
594 warn!(
595 "Server '{}' has reached max connections ({}), rejecting client",
596 server.name, server.max_connections
597 );
598 let _ = client_stream
599 .write_all(b"400 Server temporarily unavailable - too many connections\r\n")
600 .await;
601 return Err(anyhow::anyhow!("Server {} at max connections", server.name));
602 }
603 };
604
605 let backend_addr = format!("{}:{}", server.host, server.port);
607
608 let (mut backend_stream, is_pooled, server_name, pooled_authenticated) =
610 match self.connection_pool.get_connection(server, self).await {
611 Ok(pooled) => {
612 info!(
613 "Using pooled connection to {} (authenticated: {})",
614 pooled.server_name, pooled.authenticated
615 );
616 (
617 pooled.stream,
618 true,
619 pooled.server_name.clone(),
620 pooled.authenticated,
621 )
622 }
623 Err(_) => {
624 info!("Creating new connection to {}", backend_addr);
626 match ConnectionPool::create_optimized_tcp_stream(&backend_addr).await {
627 Ok(stream) => (stream, false, server.name.clone(), false),
628 Err(e) => {
629 error!("Failed to connect to backend {}: {}", backend_addr, e);
630 let _ = client_stream
631 .write_all(b"400 Backend server unavailable\r\n")
632 .await;
633 return Err(e.into());
634 }
635 }
636 }
637 };
638
639 info!("Connected to backend server {}", backend_addr);
640
641 if !is_pooled || !pooled_authenticated {
643 if let (Some(username), Some(password)) = (&server.username, &server.password) {
645 if let Err(e) = self
646 .authenticate_backend(&mut backend_stream, username, password)
647 .await
648 {
649 error!("Authentication failed for {}: {}", server.name, e);
650 let _ = client_stream
651 .write_all(b"502 Authentication failed\r\n")
652 .await;
653 return Err(e);
654 }
655 }
656 }
657
658 if let Err(e) = client_stream.write_all(b"200 NNTP Service Ready\r\n").await {
660 error!("Failed to send greeting to client: {}", e);
661 return Err(e.into());
662 }
663
664 let copy_result = {
666 #[cfg(target_os = "linux")]
667 {
668 match self
669 .copy_bidirectional_zero_copy(&mut client_stream, &mut backend_stream)
670 .await
671 {
672 Ok(result) => {
673 debug!("Zero-copy successful");
674 Ok(result)
675 }
676 Err(_) => {
677 debug!("Zero-copy failed, falling back to buffered copy");
678 self.copy_bidirectional_buffered(&mut client_stream, &mut backend_stream)
679 .await
680 }
681 }
682 }
683 #[cfg(not(target_os = "linux"))]
684 {
685 self.copy_bidirectional_buffered(&mut client_stream, &mut backend_stream)
686 .await
687 }
688 };
689
690 let was_authenticated = if is_pooled {
693 pooled_authenticated || (server.username.is_some() && server.password.is_some())
695 } else {
696 server.username.is_some() && server.password.is_some()
698 };
699
700 let pooled_conn = PooledConnection::new(
701 backend_stream,
702 server_name,
703 was_authenticated,
704 );
705 self.connection_pool.return_connection(pooled_conn).await;
706 info!(
707 "Returned connection to pool for {} (authenticated: {})",
708 server.name, was_authenticated
709 );
710
711 match copy_result {
712 Ok((client_to_backend_bytes, backend_to_client_bytes)) => {
713 info!(
714 "Connection closed for client {}: {} bytes client->backend, {} bytes backend->client",
715 client_addr, client_to_backend_bytes, backend_to_client_bytes
716 );
717 }
718 Err(e) => {
719 warn!("Bidirectional copy error for client {}: {}", client_addr, e);
720 }
721 }
722
723 info!("Connection closed for client {}", client_addr);
725 Ok(())
726 }
727
728 async fn authenticate_backend(
730 &self,
731 stream: &mut TcpStream,
732 username: &str,
733 password: &str,
734 ) -> Result<()> {
735 let mut buffer = self.buffer_pool.get_buffer().await;
737
738 let n = stream.read(&mut buffer).await?;
740 let greeting = &buffer[..n];
741 info!(
742 "Server greeting: {}",
743 String::from_utf8_lossy(greeting).trim()
744 );
745
746 let greeting_str = String::from_utf8_lossy(greeting);
748 if !greeting_str.starts_with("200") && !greeting_str.starts_with("201") {
749 return Err(anyhow::anyhow!(
750 "Server returned non-success greeting: {}",
751 greeting_str.trim()
752 ));
753 }
754
755 let user_command = format!("AUTHINFO USER {}\r\n", username);
757 stream.write_all(user_command.as_bytes()).await?;
758
759 let n = stream.read(&mut buffer).await?;
761 let response = String::from_utf8_lossy(&buffer[..n]);
762 info!("AUTHINFO USER response: {}", response.trim());
763
764 if response.starts_with("281") {
766 return Ok(());
768 } else if !response.starts_with("381") {
769 return Err(anyhow::anyhow!(
770 "Unexpected response to AUTHINFO USER: {}",
771 response.trim()
772 ));
773 }
774
775 let pass_command = format!("AUTHINFO PASS {}\r\n", password);
777 stream.write_all(pass_command.as_bytes()).await?;
778
779 let n = stream.read(&mut buffer).await?;
781 let response = String::from_utf8_lossy(&buffer[..n]);
782 info!("AUTHINFO PASS response: {}", response.trim());
783
784 let result = if response.starts_with("281") {
786 Ok(())
787 } else {
788 Err(anyhow::anyhow!(
789 "Authentication failed: {}",
790 response.trim()
791 ))
792 };
793
794 self.buffer_pool.return_buffer(buffer).await;
796 result
797 }
798
799 #[cfg(target_os = "linux")]
801 async fn copy_bidirectional_zero_copy(
802 &self,
803 client_stream: &mut TcpStream,
804 backend_stream: &mut TcpStream,
805 ) -> Result<(u64, u64), std::io::Error> {
806 debug!("Starting optimized zero-copy bidirectional transfer");
807
808 if let Err(e) = Self::set_high_throughput_optimizations(client_stream) {
810 debug!("Failed to set client socket optimizations: {}", e);
811 }
812 if let Err(e) = Self::set_high_throughput_optimizations(backend_stream) {
813 debug!("Failed to set backend socket optimizations: {}", e);
814 }
815
816 match tokio_splice2::copy_bidirectional(client_stream, backend_stream).await {
817 Ok(traffic_result) => {
818 debug!(
819 "Zero-copy transfer successful: {} bytes (client->server: {}, server->client: {})",
820 traffic_result.tx + traffic_result.rx,
821 traffic_result.tx,
822 traffic_result.rx
823 );
824 Ok((traffic_result.tx as u64, traffic_result.rx as u64))
825 }
826 Err(e) => {
827 debug!("Zero-copy failed: {}", e);
828 Err(e)
829 }
830 }
831 }
832
833 fn set_high_throughput_optimizations(stream: &TcpStream) -> Result<(), std::io::Error> {
835 use std::os::unix::io::AsRawFd;
836 let fd = stream.as_raw_fd();
837
838 unsafe {
839 let quickack: libc::c_int = 1;
844 libc::setsockopt(
845 fd,
846 libc::IPPROTO_TCP,
847 libc::TCP_QUICKACK,
848 &quickack as *const _ as *const libc::c_void,
849 std::mem::size_of::<libc::c_int>() as libc::socklen_t,
850 );
851
852 let window_clamp: libc::c_int = 16 * 1024 * 1024; libc::setsockopt(
855 fd,
856 libc::IPPROTO_TCP,
857 libc::TCP_WINDOW_CLAMP,
858 &window_clamp as *const _ as *const libc::c_void,
859 std::mem::size_of::<libc::c_int>() as libc::socklen_t,
860 );
861
862 let cork: libc::c_int = 1;
864 libc::setsockopt(
865 fd,
866 libc::IPPROTO_TCP,
867 libc::TCP_CORK,
868 &cork as *const _ as *const libc::c_void,
869 std::mem::size_of::<libc::c_int>() as libc::socklen_t,
870 );
871
872 let uncork: libc::c_int = 0;
874 libc::setsockopt(
875 fd,
876 libc::IPPROTO_TCP,
877 libc::TCP_CORK,
878 &uncork as *const _ as *const libc::c_void,
879 std::mem::size_of::<libc::c_int>() as libc::socklen_t,
880 );
881 }
882
883 Ok(())
884 }
885
886 async fn copy_bidirectional_buffered<R, W>(
889 &self,
890 mut reader: R,
891 mut writer: W,
892 ) -> Result<(u64, u64), std::io::Error>
893 where
894 R: AsyncRead + AsyncWrite + Unpin,
895 W: AsyncRead + AsyncWrite + Unpin,
896 {
897 use std::io::ErrorKind;
901 use tokio::io::{AsyncReadExt, AsyncWriteExt};
902
903 let mut buf1 = self.buffer_pool.get_buffer().await;
905 let mut buf2 = self.buffer_pool.get_buffer().await;
906
907 let mut transferred_a_to_b = 0u64;
908 let mut transferred_b_to_a = 0u64;
909
910 let copy_result = async {
912 loop {
913 tokio::select! {
914 result = reader.read(&mut buf1) => {
916 match result {
917 Ok(0) => break, Ok(n) => {
919 writer.write_all(&buf1[..n]).await?;
920 transferred_a_to_b += n as u64;
921 }
922 Err(e) if e.kind() == ErrorKind::WouldBlock => continue,
923 Err(e) => return Err(e),
924 }
925 }
926 result = writer.read(&mut buf2) => {
928 match result {
929 Ok(0) => break, Ok(n) => {
931 reader.write_all(&buf2[..n]).await?;
932 transferred_b_to_a += n as u64;
933 }
934 Err(e) if e.kind() == ErrorKind::WouldBlock => continue,
935 Err(e) => return Err(e),
936 }
937 }
938 }
939 }
940 Ok((transferred_a_to_b, transferred_b_to_a))
941 }
942 .await;
943
944 self.buffer_pool.return_buffer(buf1).await;
946 self.buffer_pool.return_buffer(buf2).await;
947
948 copy_result
949 }
950}
951
952pub fn load_config(config_path: &str) -> Result<Config> {
953 let config_content = std::fs::read_to_string(config_path)
954 .map_err(|e| anyhow::anyhow!("Failed to read config file '{}': {}", config_path, e))?;
955
956 let config: Config = toml::from_str(&config_content)
957 .map_err(|e| anyhow::anyhow!("Failed to parse config file '{}': {}", config_path, e))?;
958
959 Ok(config)
960}
961
962pub fn create_default_config() -> Config {
963 Config {
964 servers: vec![ServerConfig {
965 host: "news.example.com".to_string(),
966 port: 119,
967 name: "Example News Server".to_string(),
968 username: None,
969 password: None,
970 max_connections: default_max_connections(),
971 }],
972 }
973}
974
/// Unit tests for configuration handling and round-robin server selection.
/// None of these tests open network connections.
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;
    use std::sync::Arc;
    use tempfile::NamedTempFile;

    /// Three servers with distinct names and connection limits (5, 8, 12).
    fn create_test_config() -> Config {
        Config {
            servers: vec![
                ServerConfig {
                    host: "server1.example.com".to_string(),
                    port: 119,
                    name: "Test Server 1".to_string(),
                    username: None,
                    password: None,
                    max_connections: 5,
                },
                ServerConfig {
                    host: "server2.example.com".to_string(),
                    port: 119,
                    name: "Test Server 2".to_string(),
                    username: None,
                    password: None,
                    max_connections: 8,
                },
                ServerConfig {
                    host: "server3.example.com".to_string(),
                    port: 119,
                    name: "Test Server 3".to_string(),
                    username: None,
                    password: None,
                    max_connections: 12,
                },
            ],
        }
    }

    #[test]
    fn test_server_config_creation() {
        let config = ServerConfig {
            host: "news.example.com".to_string(),
            port: 119,
            name: "Example Server".to_string(),
            username: None,
            password: None,
            max_connections: 15,
        };

        assert_eq!(config.host, "news.example.com");
        assert_eq!(config.port, 119);
        assert_eq!(config.name, "Example Server");
        assert_eq!(config.max_connections, 15);
    }

    #[test]
    fn test_config_creation() {
        let config = create_test_config();
        assert_eq!(config.servers.len(), 3);
        assert_eq!(config.servers[0].name, "Test Server 1");
        assert_eq!(config.servers[1].name, "Test Server 2");
        assert_eq!(config.servers[2].name, "Test Server 3");
    }

    #[test]
    fn test_proxy_creation_with_servers() {
        let config = create_test_config();
        let proxy = NntpProxy::new(config).expect("Failed to create proxy");

        // Server order from the config is preserved.
        assert_eq!(proxy.servers().len(), 3);
        assert_eq!(proxy.servers()[0].name, "Test Server 1");
    }

    #[test]
    fn test_proxy_creation_with_empty_servers() {
        let config = Config { servers: vec![] };
        let result = NntpProxy::new(config);

        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("No servers configured")
        );
    }

    #[test]
    fn test_round_robin_server_selection() {
        let config = create_test_config();
        let proxy = NntpProxy::new(config).expect("Failed to create proxy");

        proxy.reset_index();

        // One full cycle through the three servers...
        assert_eq!(proxy.next_server().name, "Test Server 1");
        assert_eq!(proxy.next_server().name, "Test Server 2");
        assert_eq!(proxy.next_server().name, "Test Server 3");

        // ...then selection wraps back to the first server.
        assert_eq!(proxy.next_server().name, "Test Server 1");
        assert_eq!(proxy.next_server().name, "Test Server 2");
    }

    #[test]
    fn test_round_robin_with_single_server() {
        let config = Config {
            servers: vec![ServerConfig {
                host: "single.example.com".to_string(),
                port: 119,
                name: "Single Server".to_string(),
                username: None,
                password: None,
                max_connections: 3,
            }],
        };

        let proxy = NntpProxy::new(config).expect("Failed to create proxy");
        proxy.reset_index();

        // With one server every selection returns it.
        assert_eq!(proxy.next_server().name, "Single Server");
        assert_eq!(proxy.next_server().name, "Single Server");
        assert_eq!(proxy.next_server().name, "Single Server");
    }

    #[test]
    fn test_concurrent_round_robin() {
        let config = create_test_config();
        let proxy = Arc::new(NntpProxy::new(config).expect("Failed to create proxy"));
        proxy.reset_index();

        let mut handles = vec![];
        let servers_selected = Arc::new(std::sync::Mutex::new(Vec::new()));

        // Nine concurrent selections across three servers.
        for _ in 0..9 {
            let proxy_clone = Arc::clone(&proxy);
            let servers_clone = Arc::clone(&servers_selected);

            let handle = std::thread::spawn(move || {
                let server = proxy_clone.next_server();
                servers_clone.lock().unwrap().push(server.name.clone());
            });
            handles.push(handle);
        }

        for handle in handles {
            handle.join().unwrap();
        }

        let servers = servers_selected.lock().unwrap();
        assert_eq!(servers.len(), 9);

        let server1_count = servers.iter().filter(|&s| s == "Test Server 1").count();
        let server2_count = servers.iter().filter(|&s| s == "Test Server 2").count();
        let server3_count = servers.iter().filter(|&s| s == "Test Server 3").count();

        // The atomic counter guarantees an even split regardless of thread order.
        assert_eq!(server1_count, 3);
        assert_eq!(server2_count, 3);
        assert_eq!(server3_count, 3);
    }

    #[test]
    fn test_load_config_from_file() -> Result<()> {
        let config = create_test_config();
        let config_toml = toml::to_string_pretty(&config)?;

        // Round-trip through a real file on disk.
        let mut temp_file = NamedTempFile::new()?;
        write!(temp_file, "{}", config_toml)?;

        let loaded_config = load_config(temp_file.path().to_str().unwrap())?;

        assert_eq!(loaded_config.servers.len(), 3);
        assert_eq!(loaded_config.servers[0].name, "Test Server 1");
        assert_eq!(loaded_config.servers[0].host, "server1.example.com");
        assert_eq!(loaded_config.servers[0].port, 119);

        Ok(())
    }

    #[test]
    fn test_load_config_nonexistent_file() {
        let result = load_config("/nonexistent/path/config.toml");
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("Failed to read config file")
        );
    }

    #[test]
    fn test_load_config_invalid_toml() -> Result<()> {
        let invalid_toml = "invalid toml content [[[";

        let mut temp_file = NamedTempFile::new()?;
        write!(temp_file, "{}", invalid_toml)?;

        // A readable but malformed file must surface the parse-error message.
        let result = load_config(temp_file.path().to_str().unwrap());
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("Failed to parse config file")
        );

        Ok(())
    }

    #[test]
    fn test_create_default_config() {
        let config = create_default_config();

        assert_eq!(config.servers.len(), 1);
        assert_eq!(config.servers[0].host, "news.example.com");
        assert_eq!(config.servers[0].port, 119);
        assert_eq!(config.servers[0].name, "Example News Server");
    }

    #[test]
    fn test_config_serialization() -> Result<()> {
        let config = create_test_config();

        let toml_string = toml::to_string_pretty(&config)?;
        assert!(toml_string.contains("server1.example.com"));
        assert!(toml_string.contains("Test Server 1"));

        // Serialize -> deserialize must reproduce an equal Config.
        let deserialized: Config = toml::from_str(&toml_string)?;
        assert_eq!(deserialized, config);

        Ok(())
    }
}