1pub mod optimizers;
11
12use crate::constants::socket::{HIGH_THROUGHPUT_RECV_BUFFER, HIGH_THROUGHPUT_SEND_BUFFER};
13use crate::stream::ConnectionStream;
14use std::io;
15use tokio::net::TcpStream;
16use tracing::debug;
17
18pub use optimizers::{ConnectionOptimizer, NetworkOptimizer, TcpOptimizer, TlsOptimizer};
20
/// Stateless entry point for socket-level tuning.
///
/// The type carries no data; its functionality is exposed through associated
/// functions, so it never needs to be stored. The derives make it usable in
/// debug output and trivially copyable/constructible.
#[derive(Debug, Clone, Copy, Default)]
pub struct SocketOptimizer;
23
24impl SocketOptimizer {
25    pub fn optimize_for_throughput(stream: &TcpStream) -> Result<(), io::Error> {
27        use socket2::SockRef;
28
29        let sock_ref = SockRef::from(stream);
30
31        sock_ref.set_recv_buffer_size(HIGH_THROUGHPUT_RECV_BUFFER)?;
33        sock_ref.set_send_buffer_size(HIGH_THROUGHPUT_SEND_BUFFER)?;
34
35        Ok(())
40    }
41
42    pub fn apply_to_connection_streams(
44        client_stream: &ConnectionStream,
45        backend_stream: &ConnectionStream,
46    ) -> anyhow::Result<()> {
47        debug!("Applying connection optimizations with trait-based approach");
48
49        let client_optimizer = ConnectionOptimizer::new(client_stream);
50        let backend_optimizer = ConnectionOptimizer::new(backend_stream);
51
52        client_optimizer.optimize()?;
53        backend_optimizer.optimize()?;
54
55        Ok(())
56    }
57}
58
#[cfg(test)]
mod tests {
    use super::*;

    /// One kibibyte, in bytes.
    const KIB: usize = 1024;
    /// One mebibyte, in bytes; the buffer-size expectations are written in this unit.
    const MIB: usize = 1024 * 1024;

    /// Opens a loopback TCP connection on an ephemeral port.
    ///
    /// Returns the listener along with the connecting and the accepted end, so
    /// the caller keeps the listener alive for as long as it holds the tuple.
    async fn loopback_pair() -> (tokio::net::TcpListener, TcpStream, TcpStream) {
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();

        let connecting = TcpStream::connect(addr).await.unwrap();
        let (accepted, _) = listener.accept().await.unwrap();

        (listener, connecting, accepted)
    }

    /// Both buffer constants are pinned to 16 MiB.
    #[test]
    fn test_constants() {
        assert_eq!(HIGH_THROUGHPUT_RECV_BUFFER, 16 * MIB);
        assert_eq!(HIGH_THROUGHPUT_SEND_BUFFER, 16 * MIB);
    }

    /// Compile-time check that both buffers lie within 1 MiB..=128 MiB.
    #[test]
    fn test_buffer_size_is_reasonable() {
        const _: () = assert!(
            HIGH_THROUGHPUT_RECV_BUFFER >= MIB && HIGH_THROUGHPUT_RECV_BUFFER <= 128 * MIB
        );
        const _: () = assert!(
            HIGH_THROUGHPUT_SEND_BUFFER >= MIB && HIGH_THROUGHPUT_SEND_BUFFER <= 128 * MIB
        );
    }

    /// Receive and send buffers are configured symmetrically.
    #[test]
    fn test_buffer_sizes_are_equal() {
        assert_eq!(HIGH_THROUGHPUT_RECV_BUFFER, HIGH_THROUGHPUT_SEND_BUFFER);
    }

    /// The receive buffer is a whole number of mebibytes.
    #[test]
    fn test_buffer_sizes_are_power_of_two_or_multiple() {
        assert_eq!(HIGH_THROUGHPUT_RECV_BUFFER % MIB, 0);
    }

    /// The unit struct can be named as a value.
    #[test]
    fn test_socket_optimizer_exists() {
        let _ = SocketOptimizer;
    }

    /// Tuning a freshly connected socket either succeeds and leaves it
    /// connected, or fails with an error that is only reported.
    #[tokio::test]
    async fn test_optimize_for_throughput_with_real_socket() {
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let client_stream = tokio::net::TcpStream::connect(listener.local_addr().unwrap())
            .await
            .unwrap();

        if let Err(e) = SocketOptimizer::optimize_for_throughput(&client_stream) {
            println!(
                "Buffer size not supported (expected on some systems): {}",
                e
            );
        } else {
            assert!(client_stream.peer_addr().is_ok());
        }
    }

    /// Optimizing a plain client/server pair succeeds and both ends stay connected.
    #[tokio::test]
    async fn test_apply_to_connection_streams() {
        let (_listener, client_tcp, server_tcp) = loopback_pair().await;

        let client_stream = ConnectionStream::plain(client_tcp);
        let server_stream = ConnectionStream::plain(server_tcp);

        assert!(
            SocketOptimizer::apply_to_connection_streams(&client_stream, &server_stream).is_ok()
        );

        for stream in [&client_stream, &server_stream] {
            assert!(stream.as_tcp_stream().unwrap().peer_addr().is_ok());
        }
    }

    /// A `ConnectionOptimizer` built over a std-originated socket optimizes cleanly.
    #[tokio::test]
    async fn test_connection_optimizer() {
        let listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap();
        let std_stream = std::net::TcpStream::connect(listener.local_addr().unwrap()).unwrap();

        // Tokio requires the std socket to be non-blocking before adoption.
        std_stream.set_nonblocking(true).unwrap();
        let conn_stream = ConnectionStream::plain(TcpStream::from_std(std_stream).unwrap());

        let optimizer = ConnectionOptimizer::new(&conn_stream);
        assert!(optimizer.optimize().is_ok());
    }

    /// The receive buffer equals 16 MiB expressed in bytes, KiB and MiB.
    #[test]
    fn test_buffer_size_calculation() {
        assert_eq!(HIGH_THROUGHPUT_RECV_BUFFER, 16 * MIB);
        assert_eq!(HIGH_THROUGHPUT_RECV_BUFFER, 16_777_216);
        assert_eq!(HIGH_THROUGHPUT_RECV_BUFFER / KIB, 16384);
        assert_eq!(HIGH_THROUGHPUT_RECV_BUFFER / MIB, 16);
    }

    /// The receive buffer sits strictly between 10 MiB and 100 MiB.
    #[test]
    fn test_buffer_size_for_large_articles() {
        let typical = 10 * MIB;
        let very_large = 100 * MIB;

        assert!(HIGH_THROUGHPUT_RECV_BUFFER > typical);
        assert!(HIGH_THROUGHPUT_RECV_BUFFER < very_large);
    }

    /// The combined entry point and the per-stream optimizers both succeed on
    /// the same pair of streams.
    #[tokio::test]
    async fn test_new_trait_based_approach() {
        let (_listener, client_tcp, server_tcp) = loopback_pair().await;

        let client_stream = ConnectionStream::plain(client_tcp);
        let server_stream = ConnectionStream::plain(server_tcp);

        assert!(
            SocketOptimizer::apply_to_connection_streams(&client_stream, &server_stream).is_ok()
        );

        let client_optimizer = ConnectionOptimizer::new(&client_stream);
        let server_optimizer = ConnectionOptimizer::new(&server_stream);

        assert!(client_optimizer.optimize().is_ok());
        assert!(server_optimizer.optimize().is_ok());
    }
}