fluvio_future/net/
tcp_stream.rs1use std::time::Duration;
2
3use async_net::AsyncToSocketAddrs;
4
5use socket2::SockRef;
6use socket2::TcpKeepalive;
7use tracing::debug;
8
9use crate::net::TcpStream;
10
/// Idle time after which keepalive probes start being sent (two hours).
///
/// This and the constants below are the same values as the Linux kernel's default
/// `tcp_keepalive_time` / `tcp_keepalive_intvl` / `tcp_keepalive_probes` sysctls.
const TCP_KEEPALIVE_TIME: Duration = Duration::from_secs(2 * 60 * 60);
/// Delay between two consecutive keepalive probes.
const TCP_KEEPALIVE_INTERVAL: Duration = Duration::from_secs(75);
/// Maximum number of keepalive probes sent before the connection is dropped.
/// Only defined off Windows, matching the cfg on `KeepaliveOpts::retries`.
#[cfg(not(windows))]
const TCP_KEEPALIVE_PROBES: u32 = 9;
25
26#[derive(Debug, Clone, Default)]
27pub struct SocketOpts {
28 pub nodelay: Option<bool>,
29 pub keepalive: Option<KeepaliveOpts>,
30}
31
32#[derive(Debug, Clone)]
33pub struct KeepaliveOpts {
34 pub time: Option<Duration>,
35 pub interval: Option<Duration>,
36 #[cfg(not(windows))]
37 pub retries: Option<u32>,
38}
39
40#[cfg(not(windows))]
41impl Default for KeepaliveOpts {
42 fn default() -> Self {
43 Self {
44 time: Some(TCP_KEEPALIVE_TIME),
45 interval: Some(TCP_KEEPALIVE_INTERVAL),
46 retries: Some(TCP_KEEPALIVE_PROBES),
47 }
48 }
49}
50
51#[cfg(windows)]
52impl Default for KeepaliveOpts {
53 fn default() -> Self {
54 Self {
55 time: Some(TCP_KEEPALIVE_TIME),
56 interval: Some(TCP_KEEPALIVE_INTERVAL),
57 }
58 }
59}
60
61impl From<&KeepaliveOpts> for TcpKeepalive {
62 fn from(value: &KeepaliveOpts) -> Self {
63 let mut result = TcpKeepalive::new();
64 if let Some(time) = value.time {
65 result = result.with_time(time);
66 }
67 if let Some(interval) = value.interval {
68 result = result.with_interval(interval);
69 }
70 cfg_if::cfg_if! {
71 if #[cfg(not(windows))] {
72 if let Some(retries) = value.retries {
73 result = result.with_retries(retries);
74 }
75 }
76 }
77 result
78 }
79}
80
81pub async fn stream<A: AsyncToSocketAddrs>(addr: A) -> Result<TcpStream, std::io::Error> {
82 let socket_opts = SocketOpts {
83 keepalive: Some(Default::default()),
84 ..Default::default()
85 };
86 stream_with_opts(addr, Some(socket_opts)).await
87}
88
89pub async fn stream_with_opts<A: AsyncToSocketAddrs>(
90 addr: A,
91 socket_opts: Option<SocketOpts>,
92) -> Result<TcpStream, std::io::Error> {
93 debug!(?socket_opts);
94 let tcp_stream = TcpStream::connect(addr).await?;
95 if let Some(socket_opts) = socket_opts {
96 let socket_ref = SockRef::from(&tcp_stream);
97 if let Some(nodelay) = socket_opts.nodelay {
98 socket_ref.set_nodelay(nodelay)?;
99 }
100 if let Some(ref keepalive_opts) = socket_opts.keepalive {
101 let keepalive = TcpKeepalive::from(keepalive_opts);
102 socket_ref.set_tcp_keepalive(&keepalive)?;
103 }
104 }
105 Ok(tcp_stream)
106}
107
#[cfg(test)]
mod tests {
    use std::io::Error;

    use super::*;
    use crate::test_async;
    use async_net::SocketAddr;
    use async_net::TcpListener;
    use bytes::Bytes;
    use futures_lite::future::try_zip;
    use futures_util::SinkExt;
    use futures_util::StreamExt;
    use tokio_util::codec::BytesCodec;
    use tokio_util::codec::Framed;
    use tokio_util::compat::FuturesAsyncReadCompatExt;
    use tracing::debug;

    fn to_bytes(bytes: Vec<u8>) -> Bytes {
        Bytes::from(bytes)
    }

    /// Binds a listener on an OS-assigned loopback port and returns it with its address.
    ///
    /// Binding before the client future starts means the port is already
    /// listening (no startup sleep needed), and port 0 avoids collisions with
    /// other tests or processes that might hold a hard-coded port.
    async fn bind_loopback() -> Result<(TcpListener, SocketAddr), Error> {
        debug!("server: binding");
        let listener = TcpListener::bind("127.0.0.1:0").await?;
        let addr = listener.local_addr()?;
        debug!("server: successfully binding. waiting for incoming");
        Ok((listener, addr))
    }

    #[test_async]
    async fn test_async_tcp() -> Result<(), Error> {
        let (listener, addr) = bind_loopback().await?;

        let server_ft = async {
            let mut incoming = listener.incoming();
            let stream = incoming.next().await.expect("no stream");
            debug!("server: got connection from client");
            let tcp_stream = stream?;
            let mut framed = Framed::new(tcp_stream.compat(), BytesCodec::new());
            debug!("server: sending values to client");
            let data = vec![0x05, 0x0a, 0x63];
            framed.send(to_bytes(data)).await?;
            Ok(()) as Result<(), Error>
        };

        let client_ft = async {
            debug!("client: trying to connect");
            let tcp_stream = stream(&addr).await?;
            let mut framed = Framed::new(tcp_stream.compat(), BytesCodec::new());
            debug!("client: got connection. waiting");
            let value = framed.next().await.expect("no value received");
            debug!("client :received first value from server");
            let bytes = value?;
            debug!("client :received bytes len: {}", bytes.len());
            assert_eq!(&bytes[..], &[0x05u8, 0x0a, 0x63][..]);
            Ok(()) as Result<(), Error>
        };

        // Propagate I/O errors from either side instead of discarding them;
        // `try_zip` also stops waiting on the server if the client fails.
        try_zip(client_ft, server_ft).await?;

        Ok(())
    }

    #[test_async]
    async fn test_tcp_stream_socket_opts() -> Result<(), Error> {
        let (listener, addr) = bind_loopback().await?;

        let server_ft = async {
            let mut incoming = listener.incoming();
            // Accept both client connections, holding them until the server future ends.
            let _first = incoming.next().await.expect("no stream")?;
            let _second = incoming.next().await.expect("no stream")?;
            debug!("server: got connection from client");
            Ok(()) as Result<(), Error>
        };

        let client_ft = async {
            debug!("client: trying to connect");
            {
                let socket_opts = SocketOpts {
                    keepalive: None,
                    nodelay: Some(false),
                };
                let tcp_stream = stream_with_opts(&addr, Some(socket_opts)).await?;
                assert!(!tcp_stream.nodelay()?);
                let socket_ref = SockRef::from(&tcp_stream);
                assert!(!(socket_ref.nodelay()?));
                assert!(!(socket_ref.keepalive()?));
            }
            {
                let time = Duration::from_secs(7201);
                let interval = Duration::from_secs(76);
                cfg_if::cfg_if! {
                    if #[cfg(windows)] {
                        let socket_opts = SocketOpts {
                            keepalive: Some(KeepaliveOpts {
                                time: Some(time),
                                interval: Some(interval),
                            }),
                            nodelay: Some(true),
                        };
                    } else {
                        let retries = 10;
                        let socket_opts = SocketOpts {
                            keepalive: Some(KeepaliveOpts {
                                time: Some(time),
                                interval: Some(interval),
                                retries: Some(retries),
                            }),
                            nodelay: Some(true),
                        };
                    }
                }

                let tcp_stream = stream_with_opts(&addr, Some(socket_opts)).await?;
                assert!(tcp_stream.nodelay()?);
                let socket_ref = SockRef::from(&tcp_stream);
                assert!(socket_ref.nodelay()?);
                assert!(socket_ref.keepalive()?);
                cfg_if::cfg_if! {
                    if #[cfg(not(windows))] {
                        assert_eq!(socket_ref.keepalive_time()?, time);
                        assert_eq!(socket_ref.keepalive_interval()?, interval);
                        assert_eq!(socket_ref.keepalive_retries()?, retries);
                    }
                }
            }

            Ok(()) as Result<(), Error>
        };

        // Surface connect/accept/option errors instead of silently passing.
        try_zip(client_ft, server_ft).await?;

        Ok(())
    }
}