// fluvio_future/net/tcp_stream.rs

1use std::time::Duration;
2
3use async_net::AsyncToSocketAddrs;
4
5use socket2::SockRef;
6use socket2::TcpKeepalive;
7use tracing::debug;
8
9use crate::net::TcpStream;
10
/// Idle time a connection must accumulate before the first keepalive probe is sent.
///
/// Two hours (7200 s), the same value as the conventional Linux
/// `tcp_keepalive_time` default.
const TCP_KEEPALIVE_TIME: Duration = Duration::from_secs(2 * 60 * 60);
/// Delay between successive keepalive probes while the peer does not answer.
///
/// 75 s, the same value as the conventional Linux `tcp_keepalive_intvl` default.
const TCP_KEEPALIVE_INTERVAL: Duration = Duration::from_secs(75);
/// Number of unanswered keepalive probes after which the connection is considered dead.
///
/// 9, the same value as the conventional Linux `tcp_keepalive_probes` default.
/// Only defined off Windows, matching the availability of `KeepaliveOpts::retries`.
#[cfg(not(windows))]
const TCP_KEEPALIVE_PROBES: u32 = 9;
25
26#[derive(Debug, Clone, Default)]
27pub struct SocketOpts {
28    pub nodelay: Option<bool>,
29    pub keepalive: Option<KeepaliveOpts>,
30}
31
/// Keepalive parameters applied to a socket when keepalive is requested.
///
/// Only the fields that are `Some` are forwarded to the socket configuration;
/// `None` fields are not set explicitly.
#[derive(Clone, Debug)]
pub struct KeepaliveOpts {
    /// Idle time before the first keepalive probe is sent.
    pub time: Option<Duration>,
    /// Delay between successive unanswered keepalive probes.
    pub interval: Option<Duration>,
    /// Number of unanswered probes before the connection is considered dead.
    /// Not available on Windows.
    #[cfg(not(windows))]
    pub retries: Option<u32>,
}
39
40#[cfg(not(windows))]
41impl Default for KeepaliveOpts {
42    fn default() -> Self {
43        Self {
44            time: Some(TCP_KEEPALIVE_TIME),
45            interval: Some(TCP_KEEPALIVE_INTERVAL),
46            retries: Some(TCP_KEEPALIVE_PROBES),
47        }
48    }
49}
50
51#[cfg(windows)]
52impl Default for KeepaliveOpts {
53    fn default() -> Self {
54        Self {
55            time: Some(TCP_KEEPALIVE_TIME),
56            interval: Some(TCP_KEEPALIVE_INTERVAL),
57        }
58    }
59}
60
61impl From<&KeepaliveOpts> for TcpKeepalive {
62    fn from(value: &KeepaliveOpts) -> Self {
63        let mut result = TcpKeepalive::new();
64        if let Some(time) = value.time {
65            result = result.with_time(time);
66        }
67        if let Some(interval) = value.interval {
68            result = result.with_interval(interval);
69        }
70        cfg_if::cfg_if! {
71            if #[cfg(not(windows))] {
72                if let Some(retries) = value.retries {
73                    result = result.with_retries(retries);
74                }
75            }
76        }
77        result
78    }
79}
80
81pub async fn stream<A: AsyncToSocketAddrs>(addr: A) -> Result<TcpStream, std::io::Error> {
82    let socket_opts = SocketOpts {
83        keepalive: Some(Default::default()),
84        ..Default::default()
85    };
86    stream_with_opts(addr, Some(socket_opts)).await
87}
88
89pub async fn stream_with_opts<A: AsyncToSocketAddrs>(
90    addr: A,
91    socket_opts: Option<SocketOpts>,
92) -> Result<TcpStream, std::io::Error> {
93    debug!(?socket_opts);
94    let tcp_stream = TcpStream::connect(addr).await?;
95    if let Some(socket_opts) = socket_opts {
96        let socket_ref = SockRef::from(&tcp_stream);
97        if let Some(nodelay) = socket_opts.nodelay {
98            socket_ref.set_nodelay(nodelay)?;
99        }
100        if let Some(ref keepalive_opts) = socket_opts.keepalive {
101            let keepalive = TcpKeepalive::from(keepalive_opts);
102            socket_ref.set_tcp_keepalive(&keepalive)?;
103        }
104    }
105    Ok(tcp_stream)
106}
107
#[cfg(test)]
mod tests {
    use std::io::Error;

    use super::*;
    use crate::test_async;
    use async_net::SocketAddr;
    use async_net::TcpListener;
    use bytes::Bytes;
    use futures_lite::future::zip;
    use futures_util::SinkExt;
    use futures_util::StreamExt;
    use tokio_util::codec::BytesCodec;
    use tokio_util::codec::Framed;
    use tokio_util::compat::FuturesAsyncReadCompatExt;
    use tracing::debug;

    fn to_bytes(bytes: Vec<u8>) -> Bytes {
        Bytes::from(bytes)
    }

    /// Binds a listener on an OS-assigned loopback port.
    ///
    /// Binding before the client starts means the client can connect right away
    /// (no sleep-based race), and tests never collide on a fixed port.
    async fn bind_loopback() -> Result<(TcpListener, SocketAddr), Error> {
        let addr = "127.0.0.1:0".parse::<SocketAddr>().expect("parse");
        let listener = TcpListener::bind(&addr).await?;
        let local_addr = listener.local_addr()?;
        Ok((listener, local_addr))
    }

    /// A client created with `stream` receives the bytes the server sends.
    #[test_async]
    async fn test_async_tcp() -> Result<(), Error> {
        let (listener, addr) = bind_loopback().await?;

        let server_ft = async {
            debug!("server: waiting for incoming");
            let mut incoming = listener.incoming();
            let stream = incoming.next().await.expect("no stream");
            debug!("server: got connection from client");
            let tcp_stream = stream?;
            let mut framed = Framed::new(tcp_stream.compat(), BytesCodec::new());
            debug!("server: sending values to client");
            framed.send(to_bytes(vec![0x05, 0x0a, 0x63])).await?;
            Ok(()) as Result<(), Error>
        };

        let client_ft = async {
            debug!("client: trying to connect");
            let tcp_stream = stream(&addr).await?;
            let mut framed = Framed::new(tcp_stream.compat(), BytesCodec::new());
            debug!("client: got connection. waiting");
            let value = framed.next().await.expect("no value received");
            debug!("client :received first value from server");
            let bytes = value?;
            debug!("client :received bytes len: {}", bytes.len());
            assert_eq!(&bytes[..], &[0x05_u8, 0x0a, 0x63][..]);
            Ok(()) as Result<(), Error>
        };

        // Propagate both sides' errors: discarding them would let connect/bind/IO
        // failures pass the test silently.
        let (client_result, server_result) = zip(client_ft, server_ft).await;
        client_result?;
        server_result?;

        Ok(())
    }

    /// `stream_with_opts` applies `nodelay` and keepalive settings to the socket.
    #[test_async]
    async fn test_tcp_stream_socket_opts() -> Result<(), Error> {
        let (listener, addr) = bind_loopback().await?;

        let server_ft = async {
            debug!("server: waiting for incoming");
            let mut incoming = listener.incoming();
            // One accept per client connection made below.
            let _first = incoming.next().await.expect("no stream")?;
            let _second = incoming.next().await.expect("no stream")?;
            debug!("server: got connections from client");
            Ok(()) as Result<(), Error>
        };

        let client_ft = async {
            debug!("client: trying to connect");
            {
                let socket_opts = SocketOpts {
                    keepalive: None,
                    nodelay: Some(false),
                };
                let tcp_stream = stream_with_opts(&addr, Some(socket_opts)).await?;
                assert!(!tcp_stream.nodelay()?);
                let socket_ref = SockRef::from(&tcp_stream);
                assert!(!(socket_ref.nodelay()?));
                assert!(!(socket_ref.keepalive()?));
            }
            {
                let time = Duration::from_secs(7201);
                let interval = Duration::from_secs(76);
                cfg_if::cfg_if! {
                    if #[cfg(windows)] {
                        let socket_opts = SocketOpts {
                            keepalive: Some(KeepaliveOpts {
                                time: Some(time),
                                interval: Some(interval),
                            }),
                            nodelay: Some(true),
                        };
                    } else {
                        let retries = 10;
                        let socket_opts = SocketOpts {
                            keepalive: Some(KeepaliveOpts {
                                time: Some(time),
                                interval: Some(interval),
                                retries: Some(retries),
                            }),
                            nodelay: Some(true),
                        };
                    }
                }

                let tcp_stream = stream_with_opts(&addr, Some(socket_opts)).await?;
                assert!(tcp_stream.nodelay()?);
                let socket_ref = SockRef::from(&tcp_stream);
                assert!(socket_ref.nodelay()?);
                assert!(socket_ref.keepalive()?);
                cfg_if::cfg_if! {
                    if #[cfg(not(windows))] {
                        assert_eq!(socket_ref.keepalive_time()?, time);
                        assert_eq!(socket_ref.keepalive_interval()?, interval);
                        assert_eq!(socket_ref.keepalive_retries()?, retries);
                    }
                }
            }

            Ok(()) as Result<(), Error>
        };

        // Propagate both sides' errors instead of discarding them.
        let (client_result, server_result) = zip(client_ft, server_ft).await;
        client_result?;
        server_result?;

        Ok(())
    }
}