1use anyhow::Result;
7use std::net::SocketAddr;
8use std::sync::Arc;
9use tokio::io::AsyncWriteExt;
10use tokio::net::TcpStream;
11use tracing::{debug, error, info, warn};
12
13use crate::config::{Config, ServerConfig};
14use crate::constants::buffer::{BUFFER_POOL_SIZE, BUFFER_SIZE};
15use crate::constants::stateless_proxy::*;
16use crate::network::{ConnectionOptimizer, NetworkOptimizer, TcpOptimizer};
17use crate::pool::{BufferPool, ConnectionProvider, DeadpoolConnectionProvider, prewarm_pools};
18use crate::router;
19use crate::session::ClientSession;
20use crate::types;
21
/// NNTP proxy state used by the client handlers: the configured backend servers,
/// the router that selects among them, one connection provider per backend, and
/// a buffer pool that is cloned into every client session.
#[derive(Debug, Clone)]
pub struct NntpProxy {
    /// Backend server configurations in configuration order. Entry `i` pairs with
    /// `connection_providers[i]` and with `BackendId::from_index(i)` in `router`.
    servers: Arc<Vec<ServerConfig>>,
    /// Backend selector; `new` registers every configured backend with it.
    router: Arc<router::BackendSelector>,
    /// One deadpool-backed provider per entry of `servers`, in the same order.
    connection_providers: Vec<DeadpoolConnectionProvider>,
    /// Buffer pool handed (by clone) to each `ClientSession`.
    buffer_pool: BufferPool,
}
32
33impl NntpProxy {
34 pub fn new(config: Config) -> Result<Self> {
35 if config.servers.is_empty() {
36 anyhow::bail!("No servers configured in configuration");
37 }
38
39 let connection_providers: Vec<DeadpoolConnectionProvider> = config
41 .servers
42 .iter()
43 .map(|server| {
44 info!(
45 "Configuring deadpool connection provider for '{}'",
46 server.name
47 );
48 DeadpoolConnectionProvider::from_server_config(server)
49 })
50 .collect();
51
52 let buffer_pool = BufferPool::new(BUFFER_SIZE, BUFFER_POOL_SIZE);
53
54 let servers = Arc::new(config.servers);
55
56 let router = Arc::new({
58 use types::BackendId;
59 connection_providers.iter().enumerate().fold(
60 router::BackendSelector::new(),
61 |mut r, (idx, provider)| {
62 let backend_id = BackendId::from_index(idx);
63 r.add_backend(backend_id, servers[idx].name.clone(), provider.clone());
64 r
65 },
66 )
67 });
68
69 Ok(Self {
70 servers,
71 router,
72 connection_providers,
73 buffer_pool,
74 })
75 }
76
77 pub async fn prewarm_connections(&self) -> Result<()> {
80 prewarm_pools(&self.connection_providers, &self.servers).await
81 }
82
83 pub async fn graceful_shutdown(&self) {
85 info!("Initiating graceful shutdown of all connection pools...");
86
87 for provider in &self.connection_providers {
88 provider.graceful_shutdown().await;
89 }
90
91 info!("All connection pools have been shut down gracefully");
92 }
93
94 #[inline]
96 pub fn servers(&self) -> &[ServerConfig] {
97 &self.servers
98 }
99
100 #[inline]
102 pub fn router(&self) -> &Arc<router::BackendSelector> {
103 &self.router
104 }
105
106 #[inline]
108 pub fn connection_providers(&self) -> &[DeadpoolConnectionProvider] {
109 &self.connection_providers
110 }
111
112 #[inline]
114 pub fn buffer_pool(&self) -> &BufferPool {
115 &self.buffer_pool
116 }
117
118 async fn setup_client_connection(
120 &self,
121 client_stream: &mut TcpStream,
122 client_addr: SocketAddr,
123 ) -> Result<()> {
124 crate::protocol::send_proxy_greeting(client_stream, client_addr).await
126 }
127
128 pub async fn handle_client(
129 &self,
130 mut client_stream: TcpStream,
131 client_addr: SocketAddr,
132 ) -> Result<()> {
133 debug!("New client connection from {}", client_addr);
134
135 use types::ClientId;
137 let client_id = ClientId::new();
138
139 let backend_id = self.router.route_command_sync(client_id, "")?;
141 let server_idx = backend_id.as_index();
142 let server = &self.servers[server_idx];
143
144 info!(
145 "Routing client {} to backend {:?} ({}:{})",
146 client_addr, backend_id, server.host, server.port
147 );
148
149 self.setup_client_connection(&mut client_stream, client_addr)
151 .await?;
152
153 let pool_status = self.connection_providers[server_idx].status();
155 debug!(
156 "Pool status for {}: {}/{} available, {} created",
157 server.name, pool_status.available, pool_status.max_size, pool_status.created
158 );
159
160 let mut backend_conn = match self.connection_providers[server_idx]
161 .get_pooled_connection()
162 .await
163 {
164 Ok(conn) => {
165 debug!("Got pooled connection for {}", server.name);
166 conn
167 }
168 Err(e) => {
169 error!(
170 "Failed to get pooled connection for {} (client {}): {}",
171 server.name, client_addr, e
172 );
173 let _ = client_stream.write_all(NNTP_BACKEND_UNAVAILABLE).await;
174 return Err(anyhow::anyhow!(
175 "Failed to get pooled connection for backend '{}' (client {}): {}",
176 server.name,
177 client_addr,
178 e
179 ));
180 }
181 };
182
183 let client_optimizer = TcpOptimizer::new(&client_stream);
185 if let Err(e) = client_optimizer.optimize() {
186 debug!("Failed to optimize client socket: {}", e);
187 }
188
189 let backend_optimizer = ConnectionOptimizer::new(&backend_conn);
190 if let Err(e) = backend_optimizer.optimize() {
191 debug!("Failed to optimize backend socket: {}", e);
192 }
193
194 let session = ClientSession::new(client_addr, self.buffer_pool.clone());
196 debug!("Starting session for client {}", client_addr);
197
198 let copy_result = session
199 .handle_with_pooled_backend(client_stream, &mut *backend_conn)
200 .await;
201
202 debug!("Session completed for client {}", client_addr);
203
204 self.router.complete_command_sync(backend_id);
206
207 match copy_result {
209 Ok((client_to_backend_bytes, backend_to_client_bytes)) => {
210 info!(
211 "Connection closed for client {}: {} bytes sent, {} bytes received",
212 client_addr, client_to_backend_bytes, backend_to_client_bytes
213 );
214 }
215 Err(e) => {
216 warn!("Session error for client {}: {}", client_addr, e);
217 }
218 }
219
220 debug!("Connection returned to pool for {}", server.name);
221 Ok(())
222 }
223
224 pub async fn handle_client_per_command_routing(
229 &self,
230 mut client_stream: TcpStream,
231 client_addr: SocketAddr,
232 ) -> Result<()> {
233 debug!(
234 "New per-command routing client connection from {}",
235 client_addr
236 );
237
238 let _ = client_stream.set_nodelay(true);
240
241 self.setup_client_connection(&mut client_stream, client_addr)
243 .await?;
244
245 let session = ClientSession::new_with_router(
247 client_addr,
248 self.buffer_pool.clone(),
249 self.router.clone(),
250 );
251
252 info!(
253 "Client {} (ID: {}) connected in per-command routing mode",
254 client_addr,
255 session.client_id()
256 );
257
258 let result = session.handle_per_command_routing(client_stream).await;
260
261 match result {
263 Ok((client_to_backend, backend_to_client)) => {
264 info!(
265 "Per-command routing session closed for {} (ID: {}): {} bytes sent, {} bytes received",
266 client_addr,
267 session.client_id(),
268 client_to_backend,
269 backend_to_client
270 );
271 }
272 Err(e) => {
273 let is_broken_pipe = if let Some(io_err) = e.downcast_ref::<std::io::Error>() {
275 matches!(io_err.kind(), std::io::ErrorKind::BrokenPipe | std::io::ErrorKind::ConnectionReset)
276 } else {
277 false
278 };
279
280 if is_broken_pipe {
281 debug!(
282 "Client {} (ID: {}) disconnected during session: {} - This is normal for test connections",
283 client_addr,
284 session.client_id(),
285 e
286 );
287 } else {
288 warn!(
289 "Per-command routing session error for {} (ID: {}): {}",
290 client_addr,
291 session.client_id(),
292 e
293 );
294 }
295
296 debug!(
299 "Session error details for {} (ID: {}): Error occurred during per-command routing. \
300 This may be a client test connection or early disconnection. \
301 Check session debug logs above for command/response details.",
302 client_addr,
303 session.client_id()
304 );
305 }
306 }
307
308 debug!(
309 "Per-command routing connection closed for {} (ID: {})",
310 client_addr,
311 session.client_id()
312 );
313 Ok(())
314 }
315}
316
#[cfg(test)]
mod tests {
    use super::*;

    /// Three plain-TCP servers (`server{n}.example.com:119`) with pool sizes 5, 8
    /// and 12, and every other `Config` field left at its default.
    fn create_test_config() -> Config {
        let servers = [(1, 5), (2, 8), (3, 12)]
            .iter()
            .map(|&(n, max_connections)| ServerConfig {
                host: format!("server{}.example.com", n),
                port: 119,
                name: format!("Test Server {}", n),
                username: None,
                password: None,
                max_connections,
                use_tls: false,
                tls_verify_cert: true,
                tls_cert_path: None,
            })
            .collect();

        Config {
            servers,
            ..Default::default()
        }
    }

    #[test]
    fn test_proxy_creation_with_servers() {
        let proxy = NntpProxy::new(create_test_config()).expect("Failed to create proxy");
        let servers = proxy.servers();

        assert_eq!(servers.len(), 3);
        assert_eq!(servers[0].name, "Test Server 1");
    }

    #[test]
    fn test_proxy_creation_with_empty_servers() {
        let empty = Config {
            servers: Vec::new(),
            ..Default::default()
        };

        let err = NntpProxy::new(empty).expect_err("an empty server list must be rejected");
        assert!(err.to_string().contains("No servers configured"));
    }

    #[test]
    fn test_proxy_has_router() {
        let proxy = NntpProxy::new(create_test_config()).expect("Failed to create proxy");

        assert_eq!(proxy.router().backend_count(), 3);
    }
}