1use anyhow::{Context, Result};
7use std::net::SocketAddr;
8use std::sync::Arc;
9use tokio::io::AsyncWriteExt;
10use tokio::net::TcpStream;
11use tracing::{debug, error, info, warn};
12
13use crate::config::{Config, RoutingMode, ServerConfig};
14use crate::constants::buffer::{POOL, POOL_COUNT};
15use crate::network::{ConnectionOptimizer, NetworkOptimizer, TcpOptimizer};
16use crate::pool::{BufferPool, ConnectionProvider, DeadpoolConnectionProvider, prewarm_pools};
17use crate::protocol::BACKEND_UNAVAILABLE;
18use crate::router;
19use crate::session::ClientSession;
20use crate::types::{self, BufferSize};
21
22#[derive(Debug)]
54pub struct NntpProxyBuilder {
55 config: Config,
56 routing_mode: RoutingMode,
57 buffer_size: Option<usize>,
58 buffer_count: Option<usize>,
59}
60
61impl NntpProxyBuilder {
62 #[must_use]
66 pub fn new(config: Config) -> Self {
67 Self {
68 config,
69 routing_mode: RoutingMode::Standard,
70 buffer_size: None,
71 buffer_count: None,
72 }
73 }
74
75 #[must_use]
82 pub fn with_routing_mode(mut self, mode: RoutingMode) -> Self {
83 self.routing_mode = mode;
84 self
85 }
86
87 #[must_use]
92 pub fn with_buffer_pool_size(mut self, size: usize) -> Self {
93 self.buffer_size = Some(size);
94 self
95 }
96
97 #[must_use]
102 pub fn with_buffer_pool_count(mut self, count: usize) -> Self {
103 self.buffer_count = Some(count);
104 self
105 }
106
107 pub fn build(self) -> Result<NntpProxy> {
116 if self.config.servers.is_empty() {
117 anyhow::bail!("No servers configured in configuration");
118 }
119
120 let buffer_size = self.buffer_size.unwrap_or(POOL);
122 let buffer_count = self.buffer_count.unwrap_or(POOL_COUNT);
123
124 let connection_providers: Result<Vec<DeadpoolConnectionProvider>> = self
126 .config
127 .servers
128 .iter()
129 .map(|server| {
130 info!(
131 "Configuring deadpool connection provider for '{}'",
132 server.name
133 );
134 DeadpoolConnectionProvider::from_server_config(server)
135 })
136 .collect();
137
138 let connection_providers = connection_providers?;
139
140 let buffer_pool = BufferPool::new(
141 BufferSize::new(buffer_size)
142 .ok_or_else(|| anyhow::anyhow!("Buffer size must be non-zero"))?,
143 buffer_count,
144 );
145
146 let servers = Arc::new(self.config.servers);
147
148 let router = Arc::new({
150 use types::BackendId;
151 connection_providers.iter().enumerate().fold(
152 router::BackendSelector::new(),
153 |mut r, (idx, provider)| {
154 let backend_id = BackendId::from_index(idx);
155 r.add_backend(
156 backend_id,
157 servers[idx].name.as_str().to_string(),
158 provider.clone(),
159 );
160 r
161 },
162 )
163 });
164
165 Ok(NntpProxy {
166 servers,
167 router,
168 connection_providers,
169 buffer_pool,
170 routing_mode: self.routing_mode,
171 })
172 }
173}
174
175#[derive(Debug, Clone)]
176pub struct NntpProxy {
177 servers: Arc<Vec<ServerConfig>>,
178 router: Arc<router::BackendSelector>,
180 connection_providers: Vec<DeadpoolConnectionProvider>,
182 buffer_pool: BufferPool,
184 routing_mode: RoutingMode,
186}
187
188impl NntpProxy {
189 pub fn new(config: Config, routing_mode: RoutingMode) -> Result<Self> {
206 NntpProxyBuilder::new(config)
207 .with_routing_mode(routing_mode)
208 .build()
209 }
210
211 #[must_use]
228 pub fn builder(config: Config) -> NntpProxyBuilder {
229 NntpProxyBuilder::new(config)
230 }
231
232 pub async fn prewarm_connections(&self) -> Result<()> {
235 prewarm_pools(&self.connection_providers, &self.servers).await
236 }
237
238 pub async fn graceful_shutdown(&self) {
240 info!("Initiating graceful shutdown of all connection pools...");
241
242 for provider in &self.connection_providers {
243 provider.graceful_shutdown().await;
244 }
245
246 info!("All connection pools have been shut down gracefully");
247 }
248
249 #[inline]
251 pub fn servers(&self) -> &[ServerConfig] {
252 &self.servers
253 }
254
255 #[inline]
257 pub fn router(&self) -> &Arc<router::BackendSelector> {
258 &self.router
259 }
260
261 #[inline]
263 pub fn connection_providers(&self) -> &[DeadpoolConnectionProvider] {
264 &self.connection_providers
265 }
266
267 #[inline]
269 pub fn buffer_pool(&self) -> &BufferPool {
270 &self.buffer_pool
271 }
272
273 async fn setup_client_connection(
275 &self,
276 client_stream: &mut TcpStream,
277 client_addr: SocketAddr,
278 ) -> Result<()> {
279 crate::protocol::send_proxy_greeting(client_stream, client_addr).await
281 }
282
283 pub async fn handle_client(
284 &self,
285 mut client_stream: TcpStream,
286 client_addr: SocketAddr,
287 ) -> Result<()> {
288 debug!("New client connection from {}", client_addr);
289
290 use types::ClientId;
292 let client_id = ClientId::new();
293
294 let backend_id = self.router.route_command_sync(client_id, "")?;
296 let server_idx = backend_id.as_index();
297 let server = &self.servers[server_idx];
298
299 info!(
300 "Routing client {} to backend {:?} ({}:{})",
301 client_addr, backend_id, server.host, server.port
302 );
303
304 self.setup_client_connection(&mut client_stream, client_addr)
306 .await?;
307
308 let pool_status = self.connection_providers[server_idx].status();
310 debug!(
311 "Pool status for {}: {}/{} available, {} created",
312 server.name, pool_status.available, pool_status.max_size, pool_status.created
313 );
314
315 let mut backend_conn = match self.connection_providers[server_idx]
316 .get_pooled_connection()
317 .await
318 {
319 Ok(conn) => {
320 debug!("Got pooled connection for {}", server.name);
321 conn
322 }
323 Err(e) => {
324 error!(
325 "Failed to get pooled connection for {} (client {}): {}",
326 server.name, client_addr, e
327 );
328 let _ = client_stream.write_all(BACKEND_UNAVAILABLE).await;
329 return Err(anyhow::anyhow!(
330 "Failed to get pooled connection for backend '{}' (client {}): {}",
331 server.name,
332 client_addr,
333 e
334 ));
335 }
336 };
337
338 let client_optimizer = TcpOptimizer::new(&client_stream);
340 if let Err(e) = client_optimizer.optimize() {
341 debug!("Failed to optimize client socket: {}", e);
342 }
343
344 let backend_optimizer = ConnectionOptimizer::new(&backend_conn);
345 if let Err(e) = backend_optimizer.optimize() {
346 debug!("Failed to optimize backend socket: {}", e);
347 }
348
349 let session = ClientSession::new(client_addr, self.buffer_pool.clone());
351 debug!("Starting session for client {}", client_addr);
352
353 let copy_result = session
354 .handle_with_pooled_backend(client_stream, &mut *backend_conn)
355 .await;
356
357 debug!("Session completed for client {}", client_addr);
358
359 self.router.complete_command_sync(backend_id);
361
362 match copy_result {
364 Ok((client_to_backend_bytes, backend_to_client_bytes)) => {
365 info!(
366 "Connection closed for client {}: {} bytes sent, {} bytes received",
367 client_addr, client_to_backend_bytes, backend_to_client_bytes
368 );
369 }
370 Err(e) => {
371 if crate::pool::is_connection_error(&e) {
373 warn!(
374 "Backend connection error for client {}: {} - removing connection from pool",
375 client_addr, e
376 );
377 crate::pool::remove_from_pool(backend_conn);
378 return Err(e);
379 }
380 warn!("Session error for client {}: {}", client_addr, e);
381 }
382 }
383
384 debug!("Connection returned to pool for {}", server.name);
385 Ok(())
386 }
387
388 pub async fn handle_client_per_command_routing(
393 &self,
394 client_stream: TcpStream,
395 client_addr: SocketAddr,
396 ) -> Result<()> {
397 debug!(
398 "New per-command routing client connection from {}",
399 client_addr
400 );
401
402 if let Err(e) = client_stream.set_nodelay(true) {
404 debug!("Failed to set TCP_NODELAY for {}: {}", client_addr, e);
405 }
406
407 let session = ClientSession::new_with_router(
413 client_addr,
414 self.buffer_pool.clone(),
415 self.router.clone(),
416 self.routing_mode,
417 );
418
419 let session_id = crate::formatting::short_id(session.client_id().as_uuid());
420
421 info!(
422 "Client {} [{}] connected in per-command routing mode",
423 client_addr, session_id
424 );
425
426 let result = session
428 .handle_per_command_routing(client_stream)
429 .await
430 .with_context(|| {
431 format!(
432 "Per-command routing session failed for {} [{}]",
433 client_addr, session_id
434 )
435 });
436
437 match result {
439 Ok((client_to_backend, backend_to_client)) => {
440 info!(
441 "Session closed {} [{}] ↑{} ↓{}",
442 client_addr,
443 session_id,
444 crate::formatting::format_bytes(client_to_backend),
445 crate::formatting::format_bytes(backend_to_client)
446 );
447 }
448 Err(e) => {
449 let is_broken_pipe = e.downcast_ref::<std::io::Error>().is_some_and(|io_err| {
451 matches!(
452 io_err.kind(),
453 std::io::ErrorKind::BrokenPipe | std::io::ErrorKind::ConnectionReset
454 )
455 });
456
457 if is_broken_pipe {
458 debug!(
459 "Client {} [{}] disconnected: {} (normal for test connections)",
460 client_addr, session_id, e
461 );
462 } else {
463 warn!("Session error {} [{}]: {}", client_addr, session_id, e);
464 }
465 }
466 }
467
468 debug!(
469 "Per-command routing connection closed for {} (ID: {})",
470 client_addr,
471 session.client_id()
472 );
473 Ok(())
474 }
475}
476
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    /// Configuration with three non-TLS, unauthenticated servers on port 119
    /// that differ only in host, name and connection limit.
    fn create_test_config() -> Config {
        use crate::config::{health_check_max_per_cycle, health_check_pool_timeout};
        use crate::types::{HostName, MaxConnections, Port, ServerName};

        // Everything except host, name and connection limit is shared.
        let backend = |host: &str, name: &str, max_connections| ServerConfig {
            host: HostName::new(host.to_string()).unwrap(),
            port: Port::new(119).unwrap(),
            name: ServerName::new(name.to_string()).unwrap(),
            username: None,
            password: None,
            max_connections,
            use_tls: false,
            tls_verify_cert: true,
            tls_cert_path: None,
            connection_keepalive: None,
            health_check_max_per_cycle: health_check_max_per_cycle(),
            health_check_pool_timeout: health_check_pool_timeout(),
        };

        Config {
            servers: vec![
                backend(
                    "server1.example.com",
                    "Test Server 1",
                    MaxConnections::new(5).unwrap(),
                ),
                backend(
                    "server2.example.com",
                    "Test Server 2",
                    MaxConnections::new(8).unwrap(),
                ),
                backend(
                    "server3.example.com",
                    "Test Server 3",
                    MaxConnections::new(12).unwrap(),
                ),
            ],
            ..Default::default()
        }
    }

    /// Configuration without any servers.
    fn empty_config() -> Config {
        Config {
            servers: Vec::new(),
            ..Default::default()
        }
    }

    #[test]
    fn test_proxy_creation_with_servers() {
        let proxy = Arc::new(
            NntpProxy::new(create_test_config(), RoutingMode::Standard)
                .expect("Failed to create proxy"),
        );

        let servers = proxy.servers();
        assert_eq!(servers.len(), 3);
        assert_eq!(servers[0].name.as_str(), "Test Server 1");
    }

    #[test]
    fn test_proxy_creation_with_empty_servers() {
        let err = NntpProxy::new(empty_config(), RoutingMode::Standard)
            .expect_err("an empty server list must be rejected");

        assert!(err.to_string().contains("No servers configured"));
    }

    #[test]
    fn test_proxy_has_router() {
        let proxy = Arc::new(
            NntpProxy::new(create_test_config(), RoutingMode::Standard)
                .expect("Failed to create proxy"),
        );

        assert_eq!(proxy.router().backend_count(), 3);
    }

    #[test]
    fn test_builder_basic_usage() {
        let proxy = NntpProxy::builder(create_test_config())
            .build()
            .expect("Failed to build proxy");

        assert_eq!(proxy.servers().len(), 3);
        assert_eq!(proxy.router().backend_count(), 3);
    }

    #[test]
    fn test_builder_with_routing_mode() {
        let proxy = NntpProxy::builder(create_test_config())
            .with_routing_mode(RoutingMode::PerCommand)
            .build()
            .expect("Failed to build proxy");

        assert_eq!(proxy.servers().len(), 3);
    }

    #[test]
    fn test_builder_with_custom_buffer_pool() {
        let proxy = NntpProxy::builder(create_test_config())
            .with_buffer_pool_size(512 * 1024)
            .with_buffer_pool_count(64)
            .build()
            .expect("Failed to build proxy");

        assert_eq!(proxy.servers().len(), 3);
    }

    #[test]
    fn test_builder_with_all_options() {
        let proxy = NntpProxy::builder(create_test_config())
            .with_routing_mode(RoutingMode::Hybrid)
            .with_buffer_pool_size(1024 * 1024)
            .with_buffer_pool_count(16)
            .build()
            .expect("Failed to build proxy");

        assert_eq!(proxy.servers().len(), 3);
        assert_eq!(proxy.router().backend_count(), 3);
    }

    #[test]
    fn test_builder_empty_servers_error() {
        let err = NntpProxy::builder(empty_config())
            .build()
            .expect_err("an empty server list must be rejected");

        assert!(err.to_string().contains("No servers configured"));
    }

    #[test]
    fn test_backward_compatibility_new() {
        // `NntpProxy::new` must keep working alongside the builder API.
        let proxy = NntpProxy::new(create_test_config(), RoutingMode::Standard)
            .expect("Failed to create proxy with new()");

        assert_eq!(proxy.servers().len(), 3);
        assert_eq!(proxy.router().backend_count(), 3);
    }
}