1use anyhow::{Context, Result};
7use std::net::SocketAddr;
8use std::sync::Arc;
9use tokio::io::AsyncWriteExt;
10use tokio::net::TcpStream;
11use tracing::{debug, error, info, warn};
12
13use crate::auth::AuthHandler;
14use crate::config::{Config, RoutingMode, ServerConfig};
15use crate::constants::buffer::{POOL, POOL_COUNT};
16use crate::network::{ConnectionOptimizer, NetworkOptimizer, TcpOptimizer};
17use crate::pool::{BufferPool, ConnectionProvider, DeadpoolConnectionProvider, prewarm_pools};
18use crate::protocol::BACKEND_UNAVAILABLE;
19use crate::router;
20use crate::session::ClientSession;
21use crate::types::{self, BufferSize};
22
23#[derive(Debug)]
55pub struct NntpProxyBuilder {
56 config: Config,
57 routing_mode: RoutingMode,
58 buffer_size: Option<usize>,
59 buffer_count: Option<usize>,
60}
61
62impl NntpProxyBuilder {
63 #[must_use]
67 pub fn new(config: Config) -> Self {
68 Self {
69 config,
70 routing_mode: RoutingMode::Standard,
71 buffer_size: None,
72 buffer_count: None,
73 }
74 }
75
76 #[must_use]
83 pub fn with_routing_mode(mut self, mode: RoutingMode) -> Self {
84 self.routing_mode = mode;
85 self
86 }
87
88 #[must_use]
93 pub fn with_buffer_pool_size(mut self, size: usize) -> Self {
94 self.buffer_size = Some(size);
95 self
96 }
97
98 #[must_use]
103 pub fn with_buffer_pool_count(mut self, count: usize) -> Self {
104 self.buffer_count = Some(count);
105 self
106 }
107
108 pub fn build(self) -> Result<NntpProxy> {
117 if self.config.servers.is_empty() {
118 anyhow::bail!("No servers configured in configuration");
119 }
120
121 let buffer_size = self.buffer_size.unwrap_or(POOL);
123 let buffer_count = self.buffer_count.unwrap_or(POOL_COUNT);
124
125 let connection_providers: Result<Vec<DeadpoolConnectionProvider>> = self
127 .config
128 .servers
129 .iter()
130 .map(|server| {
131 info!(
132 "Configuring deadpool connection provider for '{}'",
133 server.name
134 );
135 DeadpoolConnectionProvider::from_server_config(server)
136 })
137 .collect();
138
139 let connection_providers = connection_providers?;
140
141 let buffer_pool = BufferPool::new(
142 BufferSize::new(buffer_size)
143 .ok_or_else(|| anyhow::anyhow!("Buffer size must be non-zero"))?,
144 buffer_count,
145 );
146
147 let servers = Arc::new(self.config.servers);
148
149 let router = Arc::new({
151 use types::BackendId;
152 connection_providers.iter().enumerate().fold(
153 router::BackendSelector::new(),
154 |mut r, (idx, provider)| {
155 let backend_id = BackendId::from_index(idx);
156 r.add_backend(
157 backend_id,
158 servers[idx].name.as_str().to_string(),
159 provider.clone(),
160 );
161 r
162 },
163 )
164 });
165
166 let auth_handler = Arc::new(AuthHandler::new(
168 self.config.client_auth.username.clone(),
169 self.config.client_auth.password.clone(),
170 ));
171
172 Ok(NntpProxy {
173 servers,
174 router,
175 connection_providers,
176 buffer_pool,
177 routing_mode: self.routing_mode,
178 auth_handler,
179 })
180 }
181}
182
183#[derive(Debug, Clone)]
184pub struct NntpProxy {
185 servers: Arc<Vec<ServerConfig>>,
186 router: Arc<router::BackendSelector>,
188 connection_providers: Vec<DeadpoolConnectionProvider>,
190 buffer_pool: BufferPool,
192 routing_mode: RoutingMode,
194 auth_handler: Arc<AuthHandler>,
196}
197
198impl NntpProxy {
199 pub fn new(config: Config, routing_mode: RoutingMode) -> Result<Self> {
216 NntpProxyBuilder::new(config)
217 .with_routing_mode(routing_mode)
218 .build()
219 }
220
221 #[must_use]
238 pub fn builder(config: Config) -> NntpProxyBuilder {
239 NntpProxyBuilder::new(config)
240 }
241
242 pub async fn prewarm_connections(&self) -> Result<()> {
245 prewarm_pools(&self.connection_providers, &self.servers).await
246 }
247
248 pub async fn graceful_shutdown(&self) {
250 info!("Initiating graceful shutdown of all connection pools...");
251
252 for provider in &self.connection_providers {
253 provider.graceful_shutdown().await;
254 }
255
256 info!("All connection pools have been shut down gracefully");
257 }
258
259 #[inline]
261 pub fn servers(&self) -> &[ServerConfig] {
262 &self.servers
263 }
264
265 #[inline]
267 pub fn router(&self) -> &Arc<router::BackendSelector> {
268 &self.router
269 }
270
271 #[inline]
273 pub fn connection_providers(&self) -> &[DeadpoolConnectionProvider] {
274 &self.connection_providers
275 }
276
277 #[inline]
279 pub fn buffer_pool(&self) -> &BufferPool {
280 &self.buffer_pool
281 }
282
283 async fn setup_client_connection(
285 &self,
286 client_stream: &mut TcpStream,
287 client_addr: SocketAddr,
288 ) -> Result<()> {
289 crate::protocol::send_proxy_greeting(client_stream, client_addr).await
291 }
292
293 pub async fn handle_client(
294 &self,
295 mut client_stream: TcpStream,
296 client_addr: SocketAddr,
297 ) -> Result<()> {
298 debug!("New client connection from {}", client_addr);
299
300 use types::ClientId;
302 let client_id = ClientId::new();
303
304 let backend_id = self.router.route_command_sync(client_id, "")?;
306 let server_idx = backend_id.as_index();
307 let server = &self.servers[server_idx];
308
309 info!(
310 "Routing client {} to backend {:?} ({}:{})",
311 client_addr, backend_id, server.host, server.port
312 );
313
314 self.setup_client_connection(&mut client_stream, client_addr)
316 .await?;
317
318 let pool_status = self.connection_providers[server_idx].status();
320 debug!(
321 "Pool status for {}: {}/{} available, {} created",
322 server.name, pool_status.available, pool_status.max_size, pool_status.created
323 );
324
325 let mut backend_conn = match self.connection_providers[server_idx]
326 .get_pooled_connection()
327 .await
328 {
329 Ok(conn) => {
330 debug!("Got pooled connection for {}", server.name);
331 conn
332 }
333 Err(e) => {
334 error!(
335 "Failed to get pooled connection for {} (client {}): {}",
336 server.name, client_addr, e
337 );
338 let _ = client_stream.write_all(BACKEND_UNAVAILABLE).await;
339 return Err(anyhow::anyhow!(
340 "Failed to get pooled connection for backend '{}' (client {}): {}",
341 server.name,
342 client_addr,
343 e
344 ));
345 }
346 };
347
348 let client_optimizer = TcpOptimizer::new(&client_stream);
350 if let Err(e) = client_optimizer.optimize() {
351 debug!("Failed to optimize client socket: {}", e);
352 }
353
354 let backend_optimizer = ConnectionOptimizer::new(&backend_conn);
355 if let Err(e) = backend_optimizer.optimize() {
356 debug!("Failed to optimize backend socket: {}", e);
357 }
358
359 let session = ClientSession::new(
361 client_addr,
362 self.buffer_pool.clone(),
363 self.auth_handler.clone(),
364 );
365 debug!("Starting session for client {}", client_addr);
366
367 let copy_result = session
368 .handle_with_pooled_backend(client_stream, &mut *backend_conn)
369 .await;
370
371 debug!("Session completed for client {}", client_addr);
372
373 self.router.complete_command_sync(backend_id);
375
376 match copy_result {
378 Ok((client_to_backend_bytes, backend_to_client_bytes)) => {
379 info!(
380 "Connection closed for client {}: {} bytes sent, {} bytes received",
381 client_addr, client_to_backend_bytes, backend_to_client_bytes
382 );
383 }
384 Err(e) => {
385 if crate::pool::is_connection_error(&e) {
387 warn!(
388 "Backend connection error for client {}: {} - removing connection from pool",
389 client_addr, e
390 );
391 crate::pool::remove_from_pool(backend_conn);
392 return Err(e);
393 }
394 warn!("Session error for client {}: {}", client_addr, e);
395 }
396 }
397
398 debug!("Connection returned to pool for {}", server.name);
399 Ok(())
400 }
401
402 pub async fn handle_client_per_command_routing(
407 &self,
408 client_stream: TcpStream,
409 client_addr: SocketAddr,
410 ) -> Result<()> {
411 debug!(
412 "New per-command routing client connection from {}",
413 client_addr
414 );
415
416 if let Err(e) = client_stream.set_nodelay(true) {
418 debug!("Failed to set TCP_NODELAY for {}: {}", client_addr, e);
419 }
420
421 let session = ClientSession::new_with_router(
427 client_addr,
428 self.buffer_pool.clone(),
429 self.router.clone(),
430 self.routing_mode,
431 self.auth_handler.clone(),
432 );
433
434 let session_id = crate::formatting::short_id(session.client_id().as_uuid());
435
436 info!(
437 "Client {} [{}] connected in per-command routing mode",
438 client_addr, session_id
439 );
440
441 let result = session
443 .handle_per_command_routing(client_stream)
444 .await
445 .with_context(|| {
446 format!(
447 "Per-command routing session failed for {} [{}]",
448 client_addr, session_id
449 )
450 });
451
452 match result {
454 Ok((client_to_backend, backend_to_client)) => {
455 info!(
456 "Session closed {} [{}] ↑{} ↓{}",
457 client_addr,
458 session_id,
459 crate::formatting::format_bytes(client_to_backend),
460 crate::formatting::format_bytes(backend_to_client)
461 );
462 }
463 Err(e) => {
464 let is_broken_pipe = e.downcast_ref::<std::io::Error>().is_some_and(|io_err| {
466 matches!(
467 io_err.kind(),
468 std::io::ErrorKind::BrokenPipe | std::io::ErrorKind::ConnectionReset
469 )
470 });
471
472 if is_broken_pipe {
473 debug!(
474 "Client {} [{}] disconnected: {} (normal for test connections)",
475 client_addr, session_id, e
476 );
477 } else {
478 warn!("Session error {} [{}]: {}", client_addr, session_id, e);
479 }
480 }
481 }
482
483 debug!(
484 "Per-command routing connection closed for {} (ID: {})",
485 client_addr,
486 session.client_id()
487 );
488 Ok(())
489 }
490}
491
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::MaxConnections;
    use std::sync::Arc;

    /// Builds the `index`-th test backend: plaintext on port 119, no
    /// credentials, default health-check settings.
    fn test_server(index: u8, max_connections: MaxConnections) -> ServerConfig {
        use crate::config::{health_check_max_per_cycle, health_check_pool_timeout};
        use crate::types::{HostName, Port, ServerName};

        ServerConfig {
            host: HostName::new(format!("server{index}.example.com")).unwrap(),
            port: Port::new(119).unwrap(),
            name: ServerName::new(format!("Test Server {index}")).unwrap(),
            username: None,
            password: None,
            max_connections,
            use_tls: false,
            tls_verify_cert: true,
            tls_cert_path: None,
            connection_keepalive: None,
            health_check_max_per_cycle: health_check_max_per_cycle(),
            health_check_pool_timeout: health_check_pool_timeout(),
        }
    }

    /// Configuration with three backends that differ only in name, host and
    /// connection limit.
    fn create_test_config() -> Config {
        Config {
            servers: vec![
                test_server(1, MaxConnections::new(5).unwrap()),
                test_server(2, MaxConnections::new(8).unwrap()),
                test_server(3, MaxConnections::new(12).unwrap()),
            ],
            ..Default::default()
        }
    }

    /// Configuration without any backend servers.
    fn empty_config() -> Config {
        Config {
            servers: vec![],
            ..Default::default()
        }
    }

    /// Asserts that proxy construction failed with the "no servers" error.
    fn assert_no_servers_error(result: Result<NntpProxy>) {
        assert!(result.is_err());
        let message = result.unwrap_err().to_string();
        assert!(message.contains("No servers configured"));
    }

    #[test]
    fn test_proxy_creation_with_servers() {
        let proxy = Arc::new(
            NntpProxy::new(create_test_config(), RoutingMode::Standard)
                .expect("Failed to create proxy"),
        );

        let servers = proxy.servers();
        assert_eq!(servers.len(), 3);
        assert_eq!(servers[0].name.as_str(), "Test Server 1");
    }

    #[test]
    fn test_proxy_creation_with_empty_servers() {
        assert_no_servers_error(NntpProxy::new(empty_config(), RoutingMode::Standard));
    }

    #[test]
    fn test_proxy_has_router() {
        let proxy = Arc::new(
            NntpProxy::new(create_test_config(), RoutingMode::Standard)
                .expect("Failed to create proxy"),
        );

        assert_eq!(proxy.router.backend_count(), 3);
    }

    #[test]
    fn test_builder_basic_usage() {
        let builder = NntpProxy::builder(create_test_config());
        let proxy = builder.build().expect("Failed to build proxy");

        assert_eq!(proxy.servers().len(), 3);
        assert_eq!(proxy.router.backend_count(), 3);
    }

    #[test]
    fn test_builder_with_routing_mode() {
        let builder =
            NntpProxy::builder(create_test_config()).with_routing_mode(RoutingMode::PerCommand);
        let proxy = builder.build().expect("Failed to build proxy");

        assert_eq!(proxy.servers().len(), 3);
    }

    #[test]
    fn test_builder_with_custom_buffer_pool() {
        let builder = NntpProxy::builder(create_test_config())
            .with_buffer_pool_size(512 * 1024)
            .with_buffer_pool_count(64);
        let proxy = builder.build().expect("Failed to build proxy");

        assert_eq!(proxy.servers().len(), 3);
    }

    #[test]
    fn test_builder_with_all_options() {
        let builder = NntpProxy::builder(create_test_config())
            .with_routing_mode(RoutingMode::Hybrid)
            .with_buffer_pool_size(1024 * 1024)
            .with_buffer_pool_count(16);
        let proxy = builder.build().expect("Failed to build proxy");

        assert_eq!(proxy.servers().len(), 3);
        assert_eq!(proxy.router.backend_count(), 3);
    }

    #[test]
    fn test_builder_empty_servers_error() {
        assert_no_servers_error(NntpProxy::builder(empty_config()).build());
    }

    #[test]
    fn test_backward_compatibility_new() {
        // The plain constructor must keep producing the same proxy as the builder.
        let proxy = NntpProxy::new(create_test_config(), RoutingMode::Standard)
            .expect("Failed to create proxy with new()");

        assert_eq!(proxy.servers().len(), 3);
        assert_eq!(proxy.router.backend_count(), 3);
    }
}