nntp_proxy/network/
optimizers.rs

//! Network optimization traits and implementations
//!
//! This module provides a trait-based approach to network optimizations,
//! allowing different optimization strategies for TCP and TLS connections.
5
6use crate::stream::ConnectionStream;
7use crate::tls::TlsStream;
8use anyhow::{Context, Result};
9use socket2::SockRef;
10use std::time::Duration;
11use tokio::net::TcpStream;
12use tracing::debug;
13
/// SO_LINGER timeout applied to every optimized socket.
///
/// Caps how long closing the socket may linger while unsent data is flushed.
// NOTE(review): with SO_LINGER enabled, closing the socket may block the
// calling thread for up to this long - confirm that is acceptable when the
// socket is dropped on an async runtime thread.
const LINGER_TIMEOUT: Duration = Duration::from_secs(5);

/// TCP_USER_TIMEOUT - faster dead connection detection on Linux.
///
/// Only the Linux-specific optimization path reads this value, so it is
/// compiled only there to avoid dead-code warnings on other targets.
/// Keep in sync with the text returned by `platform_optimization_desc`.
#[cfg(target_os = "linux")]
const TCP_USER_TIMEOUT: Duration = Duration::from_secs(30);

/// IP_TOS value for throughput optimization.
///
/// Only the Linux-specific optimization path reads this value, so it is
/// compiled only there to avoid dead-code warnings on other targets.
/// Keep in sync with the text returned by `platform_optimization_desc`.
#[cfg(target_os = "linux")]
const TOS_THROUGHPUT: u32 = 0x08;
22
23/// Trait for network optimization strategies
24pub trait NetworkOptimizer {
25    /// Apply optimizations to improve network performance
26    fn optimize(&self) -> Result<()>;
27
28    /// Get a description of the optimization strategy
29    fn description(&self) -> &'static str;
30}
31
32/// Apply core TCP optimizations to a socket reference
33fn apply_core_optimizations(
34    sock_ref: &SockRef,
35    recv_buffer_size: usize,
36    send_buffer_size: usize,
37) -> Result<()> {
38    sock_ref
39        .set_recv_buffer_size(recv_buffer_size)
40        .context("Failed to set TCP receive buffer size")?;
41
42    sock_ref
43        .set_send_buffer_size(send_buffer_size)
44        .context("Failed to set TCP send buffer size")?;
45
46    sock_ref
47        .set_linger(Some(LINGER_TIMEOUT))
48        .context("Failed to set SO_LINGER timeout")?;
49
50    Ok(())
51}
52
53/// Apply Linux-specific TCP optimizations (best-effort)
54#[cfg(target_os = "linux")]
55fn apply_linux_optimizations(sock_ref: &SockRef, context: &str) {
56    [
57        (
58            "TCP_USER_TIMEOUT",
59            sock_ref.set_tcp_user_timeout(Some(TCP_USER_TIMEOUT)),
60        ),
61        ("IP_TOS", sock_ref.set_tos_v4(TOS_THROUGHPUT)),
62    ]
63    .into_iter()
64    .filter_map(|(name, result)| result.err().map(|e| (name, e)))
65    .for_each(|(name, err)| {
66        debug!("Failed to set {} on {}: {}", name, context, err);
67    });
68}
69
/// Suffix describing the platform-specific extras for the current target.
///
/// Appended to the "Applied ... optimizations" debug lines. The value is
/// selected at compile time from the target operating system: the Linux
/// socket options on Linux, a Windows marker on Windows, and an empty
/// string on every other target.
const fn platform_optimization_desc() -> &'static str {
    if cfg!(target_os = "linux") {
        ", tcp_user_timeout=30s, tos=0x08"
    } else if cfg!(target_os = "windows") {
        " (Windows)"
    } else {
        ""
    }
}
81
82/// TCP-specific optimizations for high-throughput scenarios
83pub struct TcpOptimizer<'a> {
84    stream: &'a TcpStream,
85    recv_buffer_size: usize,
86    send_buffer_size: usize,
87}
88
89impl<'a> TcpOptimizer<'a> {
90    /// Create a new TCP optimizer with default high-throughput settings
91    pub fn new(stream: &'a TcpStream) -> Self {
92        Self {
93            stream,
94            recv_buffer_size: crate::constants::socket::HIGH_THROUGHPUT_RECV_BUFFER,
95            send_buffer_size: crate::constants::socket::HIGH_THROUGHPUT_SEND_BUFFER,
96        }
97    }
98
99    /// Create optimizer with custom buffer sizes using builder pattern
100    pub const fn with_buffer_sizes(
101        stream: &'a TcpStream,
102        recv_size: usize,
103        send_size: usize,
104    ) -> Self {
105        Self {
106            stream,
107            recv_buffer_size: recv_size,
108            send_buffer_size: send_size,
109        }
110    }
111}
112
113impl<'a> NetworkOptimizer for TcpOptimizer<'a> {
114    fn optimize(&self) -> Result<()> {
115        let sock_ref = SockRef::from(self.stream);
116
117        // Core optimizations (required)
118        apply_core_optimizations(&sock_ref, self.recv_buffer_size, self.send_buffer_size)
119            .context("Failed to apply core TCP optimizations")?;
120
121        // Platform-specific optimizations (best-effort)
122        #[cfg(target_os = "linux")]
123        apply_linux_optimizations(&sock_ref, "TCP stream");
124
125        debug!(
126            "Applied TCP optimizations: recv_buffer={}, send_buffer={}, linger={}s{}",
127            self.recv_buffer_size,
128            self.send_buffer_size,
129            LINGER_TIMEOUT.as_secs(),
130            platform_optimization_desc()
131        );
132
133        Ok(())
134    }
135
136    fn description(&self) -> &'static str {
137        "TCP high-throughput optimization"
138    }
139}
140
141/// TLS-specific optimizations that work on the underlying TCP stream
142pub struct TlsOptimizer<'a> {
143    stream: &'a TlsStream<TcpStream>,
144    recv_buffer_size: usize,
145    send_buffer_size: usize,
146}
147
148impl<'a> TlsOptimizer<'a> {
149    /// Create a new TLS optimizer with default settings
150    pub fn new(stream: &'a TlsStream<TcpStream>) -> Self {
151        Self {
152            stream,
153            recv_buffer_size: crate::constants::socket::HIGH_THROUGHPUT_RECV_BUFFER,
154            send_buffer_size: crate::constants::socket::HIGH_THROUGHPUT_SEND_BUFFER,
155        }
156    }
157
158    /// Create optimizer with custom buffer sizes using builder pattern
159    pub const fn with_buffer_sizes(
160        stream: &'a TlsStream<TcpStream>,
161        recv_size: usize,
162        send_size: usize,
163    ) -> Self {
164        Self {
165            stream,
166            recv_buffer_size: recv_size,
167            send_buffer_size: send_size,
168        }
169    }
170}
171
172impl<'a> NetworkOptimizer for TlsOptimizer<'a> {
173    fn optimize(&self) -> Result<()> {
174        // Get the underlying TCP stream for optimization
175        let tcp_stream = self.stream.get_ref().0;
176        let sock_ref = SockRef::from(tcp_stream);
177
178        // Core optimizations (required)
179        apply_core_optimizations(&sock_ref, self.recv_buffer_size, self.send_buffer_size)
180            .context("Failed to apply core TCP optimizations to TLS stream")?;
181
182        // Platform-specific optimizations (best-effort)
183        #[cfg(target_os = "linux")]
184        apply_linux_optimizations(&sock_ref, "TLS stream");
185
186        debug!(
187            "Applied TLS optimizations to underlying TCP stream: recv_buffer={}, send_buffer={}, linger={}s{}",
188            self.recv_buffer_size,
189            self.send_buffer_size,
190            LINGER_TIMEOUT.as_secs(),
191            platform_optimization_desc()
192        );
193
194        Ok(())
195    }
196
197    fn description(&self) -> &'static str {
198        "TLS optimization via underlying TCP stream"
199    }
200}
201
202/// High-level optimizer that works with ConnectionStream
203pub struct ConnectionOptimizer<'a> {
204    stream: &'a ConnectionStream,
205    recv_buffer_size: Option<usize>,
206    send_buffer_size: Option<usize>,
207}
208
209impl<'a> ConnectionOptimizer<'a> {
210    /// Create a new connection optimizer with default buffer sizes
211    pub fn new(stream: &'a ConnectionStream) -> Self {
212        Self {
213            stream,
214            recv_buffer_size: None,
215            send_buffer_size: None,
216        }
217    }
218
219    /// Create a connection optimizer with custom buffer sizes
220    pub fn with_buffer_sizes(
221        stream: &'a ConnectionStream,
222        recv_size: usize,
223        send_size: usize,
224    ) -> Self {
225        Self {
226            stream,
227            recv_buffer_size: Some(recv_size),
228            send_buffer_size: Some(send_size),
229        }
230    }
231}
232
233impl<'a> NetworkOptimizer for ConnectionOptimizer<'a> {
234    fn optimize(&self) -> Result<()> {
235        // Use functional pattern matching to create and optimize in one step
236        let optimize_fn = |desc: &str, result: Result<()>| {
237            debug!("Using {}", desc);
238            result
239        };
240
241        match (self.recv_buffer_size, self.send_buffer_size, self.stream) {
242            // Custom buffer sizes
243            (Some(recv), Some(send), ConnectionStream::Plain(tcp)) => optimize_fn(
244                "TCP high-throughput optimization with custom buffers",
245                TcpOptimizer::with_buffer_sizes(tcp, recv, send).optimize(),
246            ),
247            (Some(recv), Some(send), ConnectionStream::Tls(tls)) => optimize_fn(
248                "TLS optimization via underlying TCP stream with custom buffers",
249                TlsOptimizer::with_buffer_sizes(tls.as_ref(), recv, send).optimize(),
250            ),
251            // Default buffer sizes
252            (_, _, ConnectionStream::Plain(tcp)) => optimize_fn(
253                "TCP high-throughput optimization",
254                TcpOptimizer::new(tcp).optimize(),
255            ),
256            (_, _, ConnectionStream::Tls(tls)) => optimize_fn(
257                "TLS optimization via underlying TCP stream",
258                TlsOptimizer::new(tls.as_ref()).optimize(),
259            ),
260        }
261    }
262
263    fn description(&self) -> &'static str {
264        match self.stream {
265            ConnectionStream::Plain(_) => "Connection-level TCP optimization",
266            ConnectionStream::Tls(_) => "Connection-level TLS optimization",
267        }
268    }
269}
270
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::net::TcpListener;

    /// Bind a listener on an ephemeral loopback port and connect a client to it.
    ///
    /// The listener is handed back with the client so the caller can keep it
    /// alive for the duration of the test.
    async fn loopback_pair() -> (TcpListener, tokio::net::TcpStream) {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let client = tokio::net::TcpStream::connect(addr).await.unwrap();
        (listener, client)
    }

    #[tokio::test]
    async fn test_tcp_optimizer() {
        let (_listener, stream) = loopback_pair().await;
        let optimizer = TcpOptimizer::new(&stream);

        assert_eq!(optimizer.description(), "TCP high-throughput optimization");

        // Must not panic; the socket calls themselves may fail in a test environment.
        let _ = optimizer.optimize();
    }

    #[tokio::test]
    async fn test_connection_optimizer_tcp() {
        let (_listener, tcp_stream) = loopback_pair().await;
        let connection_stream = ConnectionStream::Plain(tcp_stream);
        let optimizer = ConnectionOptimizer::new(&connection_stream);

        assert_eq!(optimizer.description(), "Connection-level TCP optimization");

        // Must not panic; the socket calls themselves may fail in a test environment.
        let _ = optimizer.optimize();
    }

    #[tokio::test]
    async fn test_connection_optimizer_trait_usage() {
        let (_listener, tcp_stream) = loopback_pair().await;
        let connection_stream = ConnectionStream::Plain(tcp_stream);

        // ConnectionOptimizer must be usable through a NetworkOptimizer trait object.
        let optimizer: Box<dyn NetworkOptimizer> =
            Box::new(ConnectionOptimizer::new(&connection_stream));

        assert_eq!(optimizer.description(), "Connection-level TCP optimization");
        let _ = optimizer.optimize();
    }

    #[tokio::test]
    async fn test_connection_optimizer_with_custom_buffers() {
        let (_listener, tcp_stream) = loopback_pair().await;
        let connection_stream = ConnectionStream::Plain(tcp_stream);
        let optimizer = ConnectionOptimizer::with_buffer_sizes(&connection_stream, 4096, 8192);

        assert_eq!(optimizer.description(), "Connection-level TCP optimization");
        let _ = optimizer.optimize();
    }

    #[tokio::test]
    async fn test_optimizer_creation() {
        // Optimizers built from tokio streams pick up the default buffer sizes.
        let (_listener, tokio_stream) = loopback_pair().await;

        let optimizer = TcpOptimizer::new(&tokio_stream);
        assert_eq!(
            optimizer.recv_buffer_size,
            crate::constants::socket::HIGH_THROUGHPUT_RECV_BUFFER
        );
        assert_eq!(
            optimizer.send_buffer_size,
            crate::constants::socket::HIGH_THROUGHPUT_SEND_BUFFER
        );
    }

    #[tokio::test]
    async fn test_custom_buffer_sizes() {
        let (_listener, tokio_stream) = loopback_pair().await;

        let optimizer = TcpOptimizer::with_buffer_sizes(&tokio_stream, 1024, 2048);
        assert_eq!(optimizer.recv_buffer_size, 1024);
        assert_eq!(optimizer.send_buffer_size, 2048);
    }
}