1use crate::stream::ConnectionStream;
7use crate::tls::TlsStream;
8use anyhow::{Context, Result};
9use socket2::SockRef;
10use std::time::Duration;
11use tokio::net::TcpStream;
12use tracing::debug;
13
14const LINGER_TIMEOUT: Duration = Duration::from_secs(5);
16
17const TCP_USER_TIMEOUT: Duration = Duration::from_secs(30);
19
20const TOS_THROUGHPUT: u32 = 0x08;
22
23pub trait NetworkOptimizer {
25 fn optimize(&self) -> Result<()>;
27
28 fn description(&self) -> &'static str;
30}
31
32fn apply_core_optimizations(
34 sock_ref: &SockRef,
35 recv_buffer_size: usize,
36 send_buffer_size: usize,
37) -> Result<()> {
38 sock_ref
39 .set_recv_buffer_size(recv_buffer_size)
40 .context("Failed to set TCP receive buffer size")?;
41
42 sock_ref
43 .set_send_buffer_size(send_buffer_size)
44 .context("Failed to set TCP send buffer size")?;
45
46 sock_ref
47 .set_linger(Some(LINGER_TIMEOUT))
48 .context("Failed to set SO_LINGER timeout")?;
49
50 Ok(())
51}
52
53#[cfg(target_os = "linux")]
55fn apply_linux_optimizations(sock_ref: &SockRef, context: &str) {
56 [
57 (
58 "TCP_USER_TIMEOUT",
59 sock_ref.set_tcp_user_timeout(Some(TCP_USER_TIMEOUT)),
60 ),
61 ("IP_TOS", sock_ref.set_tos_v4(TOS_THROUGHPUT)),
62 ]
63 .into_iter()
64 .filter_map(|(name, result)| result.err().map(|e| (name, e)))
65 .for_each(|(name, err)| {
66 debug!("Failed to set {} on {}: {}", name, context, err);
67 });
68}
69
70const fn platform_optimization_desc() -> &'static str {
72 match () {
73 #[cfg(target_os = "linux")]
74 () => ", tcp_user_timeout=30s, tos=0x08",
75 #[cfg(target_os = "windows")]
76 () => " (Windows)",
77 #[cfg(not(any(target_os = "linux", target_os = "windows")))]
78 () => "",
79 }
80}
81
82pub struct TcpOptimizer<'a> {
84 stream: &'a TcpStream,
85 recv_buffer_size: usize,
86 send_buffer_size: usize,
87}
88
89impl<'a> TcpOptimizer<'a> {
90 pub fn new(stream: &'a TcpStream) -> Self {
92 Self {
93 stream,
94 recv_buffer_size: crate::constants::socket::HIGH_THROUGHPUT_RECV_BUFFER,
95 send_buffer_size: crate::constants::socket::HIGH_THROUGHPUT_SEND_BUFFER,
96 }
97 }
98
99 pub const fn with_buffer_sizes(
101 stream: &'a TcpStream,
102 recv_size: usize,
103 send_size: usize,
104 ) -> Self {
105 Self {
106 stream,
107 recv_buffer_size: recv_size,
108 send_buffer_size: send_size,
109 }
110 }
111}
112
113impl<'a> NetworkOptimizer for TcpOptimizer<'a> {
114 fn optimize(&self) -> Result<()> {
115 let sock_ref = SockRef::from(self.stream);
116
117 apply_core_optimizations(&sock_ref, self.recv_buffer_size, self.send_buffer_size)
119 .context("Failed to apply core TCP optimizations")?;
120
121 #[cfg(target_os = "linux")]
123 apply_linux_optimizations(&sock_ref, "TCP stream");
124
125 debug!(
126 "Applied TCP optimizations: recv_buffer={}, send_buffer={}, linger={}s{}",
127 self.recv_buffer_size,
128 self.send_buffer_size,
129 LINGER_TIMEOUT.as_secs(),
130 platform_optimization_desc()
131 );
132
133 Ok(())
134 }
135
136 fn description(&self) -> &'static str {
137 "TCP high-throughput optimization"
138 }
139}
140
141pub struct TlsOptimizer<'a> {
143 stream: &'a TlsStream<TcpStream>,
144 recv_buffer_size: usize,
145 send_buffer_size: usize,
146}
147
148impl<'a> TlsOptimizer<'a> {
149 pub fn new(stream: &'a TlsStream<TcpStream>) -> Self {
151 Self {
152 stream,
153 recv_buffer_size: crate::constants::socket::HIGH_THROUGHPUT_RECV_BUFFER,
154 send_buffer_size: crate::constants::socket::HIGH_THROUGHPUT_SEND_BUFFER,
155 }
156 }
157
158 pub const fn with_buffer_sizes(
160 stream: &'a TlsStream<TcpStream>,
161 recv_size: usize,
162 send_size: usize,
163 ) -> Self {
164 Self {
165 stream,
166 recv_buffer_size: recv_size,
167 send_buffer_size: send_size,
168 }
169 }
170}
171
172impl<'a> NetworkOptimizer for TlsOptimizer<'a> {
173 fn optimize(&self) -> Result<()> {
174 let tcp_stream = self.stream.get_ref().0;
176 let sock_ref = SockRef::from(tcp_stream);
177
178 apply_core_optimizations(&sock_ref, self.recv_buffer_size, self.send_buffer_size)
180 .context("Failed to apply core TCP optimizations to TLS stream")?;
181
182 #[cfg(target_os = "linux")]
184 apply_linux_optimizations(&sock_ref, "TLS stream");
185
186 debug!(
187 "Applied TLS optimizations to underlying TCP stream: recv_buffer={}, send_buffer={}, linger={}s{}",
188 self.recv_buffer_size,
189 self.send_buffer_size,
190 LINGER_TIMEOUT.as_secs(),
191 platform_optimization_desc()
192 );
193
194 Ok(())
195 }
196
197 fn description(&self) -> &'static str {
198 "TLS optimization via underlying TCP stream"
199 }
200}
201
202pub struct ConnectionOptimizer<'a> {
204 stream: &'a ConnectionStream,
205 recv_buffer_size: Option<usize>,
206 send_buffer_size: Option<usize>,
207}
208
209impl<'a> ConnectionOptimizer<'a> {
210 pub fn new(stream: &'a ConnectionStream) -> Self {
212 Self {
213 stream,
214 recv_buffer_size: None,
215 send_buffer_size: None,
216 }
217 }
218
219 pub fn with_buffer_sizes(
221 stream: &'a ConnectionStream,
222 recv_size: usize,
223 send_size: usize,
224 ) -> Self {
225 Self {
226 stream,
227 recv_buffer_size: Some(recv_size),
228 send_buffer_size: Some(send_size),
229 }
230 }
231}
232
233impl<'a> NetworkOptimizer for ConnectionOptimizer<'a> {
234 fn optimize(&self) -> Result<()> {
235 let optimize_fn = |desc: &str, result: Result<()>| {
237 debug!("Using {}", desc);
238 result
239 };
240
241 match (self.recv_buffer_size, self.send_buffer_size, self.stream) {
242 (Some(recv), Some(send), ConnectionStream::Plain(tcp)) => optimize_fn(
244 "TCP high-throughput optimization with custom buffers",
245 TcpOptimizer::with_buffer_sizes(tcp, recv, send).optimize(),
246 ),
247 (Some(recv), Some(send), ConnectionStream::Tls(tls)) => optimize_fn(
248 "TLS optimization via underlying TCP stream with custom buffers",
249 TlsOptimizer::with_buffer_sizes(tls.as_ref(), recv, send).optimize(),
250 ),
251 (_, _, ConnectionStream::Plain(tcp)) => optimize_fn(
253 "TCP high-throughput optimization",
254 TcpOptimizer::new(tcp).optimize(),
255 ),
256 (_, _, ConnectionStream::Tls(tls)) => optimize_fn(
257 "TLS optimization via underlying TCP stream",
258 TlsOptimizer::new(tls.as_ref()).optimize(),
259 ),
260 }
261 }
262
263 fn description(&self) -> &'static str {
264 match self.stream {
265 ConnectionStream::Plain(_) => "Connection-level TCP optimization",
266 ConnectionStream::Tls(_) => "Connection-level TLS optimization",
267 }
268 }
269}
270
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::net::TcpListener;

    /// Binds a loopback listener on an ephemeral port and connects a client.
    ///
    /// The listener is returned so each test keeps it alive until the end,
    /// just as the inline setup did.
    async fn connected_pair() -> (TcpListener, tokio::net::TcpStream) {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let stream = tokio::net::TcpStream::connect(addr).await.unwrap();
        (listener, stream)
    }

    #[tokio::test]
    async fn test_tcp_optimizer() {
        let (_listener, stream) = connected_pair().await;
        let optimizer = TcpOptimizer::new(&stream);

        assert_eq!(optimizer.description(), "TCP high-throughput optimization");

        // Smoke test only: the default sizes come from crate constants and
        // may exceed per-platform socket-buffer caps, so success is not asserted.
        let _ = optimizer.optimize();
    }

    #[tokio::test]
    async fn test_tcp_optimizer_applies_custom_buffers() {
        let (_listener, stream) = connected_pair().await;

        TcpOptimizer::with_buffer_sizes(&stream, 4096, 8192)
            .optimize()
            .expect("small buffer sizes should be accepted on loopback");

        // The OS may adjust the stored value (Linux doubles it), but such
        // small requests are not clamped below what was asked for.
        let sock_ref = socket2::SockRef::from(&stream);
        assert!(sock_ref.recv_buffer_size().unwrap() >= 4096);
        assert!(sock_ref.send_buffer_size().unwrap() >= 8192);
    }

    #[tokio::test]
    async fn test_connection_optimizer_tcp() {
        let (_listener, tcp_stream) = connected_pair().await;
        let connection_stream = ConnectionStream::Plain(tcp_stream);
        let optimizer = ConnectionOptimizer::new(&connection_stream);

        assert_eq!(optimizer.description(), "Connection-level TCP optimization");

        // Default buffer sizes: smoke test only (see `test_tcp_optimizer`).
        let _ = optimizer.optimize();
    }

    #[tokio::test]
    async fn test_connection_optimizer_trait_usage() {
        let (_listener, tcp_stream) = connected_pair().await;
        let connection_stream = ConnectionStream::Plain(tcp_stream);

        let optimizer: Box<dyn NetworkOptimizer> =
            Box::new(ConnectionOptimizer::new(&connection_stream));

        assert_eq!(optimizer.description(), "Connection-level TCP optimization");

        // Default buffer sizes: smoke test only (see `test_tcp_optimizer`).
        let _ = optimizer.optimize();
    }

    #[tokio::test]
    async fn test_connection_optimizer_with_custom_buffers() {
        let (_listener, tcp_stream) = connected_pair().await;
        let connection_stream = ConnectionStream::Plain(tcp_stream);
        let optimizer = ConnectionOptimizer::with_buffer_sizes(&connection_stream, 4096, 8192);

        assert_eq!(optimizer.description(), "Connection-level TCP optimization");
        optimizer
            .optimize()
            .expect("small custom buffer sizes should be accepted on loopback");
    }

    #[tokio::test]
    async fn test_optimizer_creation() {
        let (_listener, tokio_stream) = connected_pair().await;

        let optimizer = TcpOptimizer::new(&tokio_stream);
        assert_eq!(
            optimizer.recv_buffer_size,
            crate::constants::socket::HIGH_THROUGHPUT_RECV_BUFFER
        );
        assert_eq!(
            optimizer.send_buffer_size,
            crate::constants::socket::HIGH_THROUGHPUT_SEND_BUFFER
        );
    }

    #[tokio::test]
    async fn test_custom_buffer_sizes() {
        let (_listener, tokio_stream) = connected_pair().await;

        let optimizer = TcpOptimizer::with_buffer_sizes(&tokio_stream, 1024, 2048);
        assert_eq!(optimizer.recv_buffer_size, 1024);
        assert_eq!(optimizer.send_buffer_size, 2048);
    }
}