veilid_tools/
socket_tools.rs

1use super::*;
2use async_io::Async;
3use std::io;
4
5cfg_if! {
6    if #[cfg(feature="rt-async-std")] {
7        pub use async_std::net::{TcpStream, TcpListener, UdpSocket};
8    } else if #[cfg(feature="rt-tokio")] {
9        pub use tokio::net::{TcpStream, TcpListener, UdpSocket};
10        pub use tokio_util::compat::*;
11    } else {
12        compile_error!("needs executor implementation");
13    }
14}
15
16use socket2::{Domain, Protocol, SockAddr, Socket, Type};
17
18//////////////////////////////////////////////////////////////////////////////////////////
19
20pub fn bind_async_udp_socket(local_address: SocketAddr) -> io::Result<Option<UdpSocket>> {
21    let Some(socket) = new_bound_default_socket2_udp(local_address)? else {
22        return Ok(None);
23    };
24
25    // Make an async UdpSocket from the socket2 socket
26    let std_udp_socket: std::net::UdpSocket = socket.into();
27    cfg_if! {
28        if #[cfg(feature="rt-async-std")] {
29            let udp_socket = UdpSocket::from(std_udp_socket);
30        } else if #[cfg(feature="rt-tokio")] {
31            std_udp_socket.set_nonblocking(true)?;
32            let udp_socket = UdpSocket::from_std(std_udp_socket)?;
33        } else {
34            compile_error!("needs executor implementation");
35        }
36    }
37    Ok(Some(udp_socket))
38}
39
40pub fn bind_async_tcp_listener(local_address: SocketAddr) -> io::Result<Option<TcpListener>> {
41    // Create a default non-shared socket and bind it
42    let Some(socket) = new_bound_default_socket2_tcp(local_address)? else {
43        return Ok(None);
44    };
45
46    // Drop the socket so we can make another shared socket in its place
47    drop(socket);
48
49    // Create a shared socket and bind it now we have determined the port is free
50    let Some(socket) = new_bound_shared_socket2_tcp(local_address)? else {
51        return Ok(None);
52    };
53
54    // Listen on the socket
55    if socket.listen(128).is_err() {
56        return Ok(None);
57    }
58
59    // Make an async tcplistener from the socket2 socket
60    let std_listener: std::net::TcpListener = socket.into();
61    cfg_if! {
62        if #[cfg(feature="rt-async-std")] {
63            let listener = TcpListener::from(std_listener);
64        } else if #[cfg(feature="rt-tokio")] {
65            std_listener.set_nonblocking(true)?;
66            let listener = TcpListener::from_std(std_listener)?;
67        } else {
68            compile_error!("needs executor implementation");
69        }
70    }
71    Ok(Some(listener))
72}
73
74pub async fn connect_async_tcp_stream(
75    local_address: Option<SocketAddr>,
76    remote_address: SocketAddr,
77    timeout_ms: u32,
78) -> io::Result<TimeoutOr<TcpStream>> {
79    let socket = match local_address {
80        Some(a) => {
81            new_bound_shared_socket2_tcp(a)?.ok_or(io::Error::from(io::ErrorKind::AddrInUse))?
82        }
83        None => new_default_socket2_tcp(domain_for_address(remote_address))?,
84    };
85
86    // Non-blocking connect to remote address
87    nonblocking_connect(socket, remote_address, timeout_ms).await
88}
89
90pub fn set_tcp_stream_linger(
91    tcp_stream: &TcpStream,
92    linger: Option<core::time::Duration>,
93) -> io::Result<()> {
94    #[cfg(all(feature = "rt-async-std", unix))]
95    {
96        // async-std does not directly support linger on TcpStream yet
97        use std::os::fd::{AsRawFd, FromRawFd, IntoRawFd};
98        unsafe {
99            let s = socket2::Socket::from_raw_fd(tcp_stream.as_raw_fd());
100            let res = s.set_linger(linger);
101            let _ = s.into_raw_fd();
102            res
103        }
104    }
105    #[cfg(all(feature = "rt-async-std", windows))]
106    {
107        // async-std does not directly support linger on TcpStream yet
108        use std::os::windows::io::{AsRawSocket, FromRawSocket, IntoRawSocket};
109        unsafe {
110            let s = socket2::Socket::from_raw_socket(tcp_stream.as_raw_socket());
111            let res = s.set_linger(linger);
112            let _ = s.into_raw_socket();
113            res
114        }
115    }
116    #[cfg(not(feature = "rt-async-std"))]
117    tcp_stream.set_linger(linger)
118}
119
120cfg_if! {
121    if #[cfg(feature="rt-async-std")] {
122        pub type ReadHalf = futures_util::io::ReadHalf<TcpStream>;
123        pub type WriteHalf = futures_util::io::WriteHalf<TcpStream>;
124    } else if #[cfg(feature="rt-tokio")] {
125        pub type ReadHalf = tokio::net::tcp::OwnedReadHalf;
126        pub type WriteHalf = tokio::net::tcp::OwnedWriteHalf;
127    } else {
128        compile_error!("needs executor implementation");
129    }
130}
131
132#[must_use]
133pub fn async_tcp_listener_incoming(
134    tcp_listener: TcpListener,
135) -> Pin<Box<impl futures_util::stream::Stream<Item = std::io::Result<TcpStream>> + Send>> {
136    cfg_if! {
137        if #[cfg(feature="rt-async-std")] {
138            Box::pin(tcp_listener.into_incoming())
139        } else if #[cfg(feature="rt-tokio")] {
140            Box::pin(tokio_stream::wrappers::TcpListenerStream::new(tcp_listener))
141        } else {
142            compile_error!("needs executor implementation");
143        }
144    }
145}
146
147#[must_use]
148pub fn split_async_tcp_stream(tcp_stream: TcpStream) -> (ReadHalf, WriteHalf) {
149    cfg_if! {
150        if #[cfg(feature="rt-async-std")] {
151            use futures_util::AsyncReadExt;
152            tcp_stream.split()
153        } else if #[cfg(feature="rt-tokio")] {
154            tcp_stream.into_split()
155        } else {
156            compile_error!("needs executor implementation");
157        }
158    }
159}
160
161//////////////////////////////////////////////////////////////////////////////////////////
162
163fn new_default_udp_socket(domain: core::ffi::c_int) -> io::Result<Socket> {
164    let domain = Domain::from(domain);
165    let socket = Socket::new(domain, Type::DGRAM, Some(Protocol::UDP))?;
166    if domain == Domain::IPV6 {
167        socket.set_only_v6(true)?;
168    }
169
170    Ok(socket)
171}
172
173fn new_bound_default_socket2_udp(local_address: SocketAddr) -> io::Result<Option<Socket>> {
174    let domain = domain_for_address(local_address);
175    let socket = new_default_udp_socket(domain)?;
176    let socket2_addr = SockAddr::from(local_address);
177
178    if socket.bind(&socket2_addr).is_err() {
179        return Ok(None);
180    }
181
182    Ok(Some(socket))
183}
184
185pub fn new_default_socket2_tcp(domain: core::ffi::c_int) -> io::Result<Socket> {
186    let domain = Domain::from(domain);
187    let socket = Socket::new(domain, Type::STREAM, Some(Protocol::TCP))?;
188    socket.set_linger(Some(core::time::Duration::from_secs(0)))?;
189    socket.set_nodelay(true)?;
190    if domain == Domain::IPV6 {
191        socket.set_only_v6(true)?;
192    }
193    Ok(socket)
194}
195
196fn new_shared_socket2_tcp(domain: core::ffi::c_int) -> io::Result<Socket> {
197    let domain = Domain::from(domain);
198    let socket = Socket::new(domain, Type::STREAM, Some(Protocol::TCP))?;
199    socket.set_linger(Some(core::time::Duration::from_secs(0)))?;
200    socket.set_nodelay(true)?;
201    if domain == Domain::IPV6 {
202        socket.set_only_v6(true)?;
203    }
204    socket.set_reuse_address(true)?;
205    cfg_if! {
206        if #[cfg(unix)] {
207            socket.set_reuse_port(true)?;
208        }
209    }
210
211    Ok(socket)
212}
213
214fn new_bound_default_socket2_tcp(local_address: SocketAddr) -> io::Result<Option<Socket>> {
215    let domain = domain_for_address(local_address);
216    let socket = new_default_socket2_tcp(domain)?;
217    let socket2_addr = SockAddr::from(local_address);
218    if socket.bind(&socket2_addr).is_err() {
219        return Ok(None);
220    }
221
222    Ok(Some(socket))
223}
224
225fn new_bound_shared_socket2_tcp(local_address: SocketAddr) -> io::Result<Option<Socket>> {
226    // Create the reuseaddr/reuseport socket now that we've asserted the port is free
227    let domain = domain_for_address(local_address);
228    let socket = new_shared_socket2_tcp(domain)?;
229    let socket2_addr = SockAddr::from(local_address);
230    if socket.bind(&socket2_addr).is_err() {
231        return Ok(None);
232    }
233
234    Ok(Some(socket))
235}
236
237// Non-blocking connect is tricky when you want to start with a prepared socket
238// Errors should not be logged as they are valid conditions for this function
239async fn nonblocking_connect(
240    socket: Socket,
241    addr: SocketAddr,
242    timeout_ms: u32,
243) -> io::Result<TimeoutOr<TcpStream>> {
244    // Set for non blocking connect
245    socket.set_nonblocking(true)?;
246
247    // Make socket2 SockAddr
248    let socket2_addr = socket2::SockAddr::from(addr);
249
250    // Connect to the remote address
251    match socket.connect(&socket2_addr) {
252        Ok(()) => Ok(()),
253        #[cfg(unix)]
254        Err(err) if err.raw_os_error() == Some(libc::EINPROGRESS) => Ok(()),
255        Err(err) if err.kind() == std::io::ErrorKind::WouldBlock => Ok(()),
256        Err(e) => Err(e),
257    }?;
258    let async_stream = Async::new(std::net::TcpStream::from(socket))?;
259
260    // The stream becomes writable when connected
261    timeout_or_try!(timeout(timeout_ms, async_stream.writable())
262        .await
263        .into_timeout_or()
264        .into_result()?);
265
266    // Check low level error
267    let async_stream = match async_stream.get_ref().take_error()? {
268        None => Ok(async_stream),
269        Some(err) => Err(err),
270    }?;
271
272    // Convert back to inner and then return async version
273    cfg_if! {
274        if #[cfg(feature="rt-async-std")] {
275            Ok(TimeoutOr::value(TcpStream::from(async_stream.into_inner()?)))
276        } else if #[cfg(feature="rt-tokio")] {
277            Ok(TimeoutOr::value(TcpStream::from_std(async_stream.into_inner()?)?))
278        } else {
279            compile_error!("needs executor implementation");
280        }
281    }
282}
283
284#[must_use]
285pub fn domain_for_address(address: SocketAddr) -> core::ffi::c_int {
286    socket2::Domain::for_address(address).into()
287}
288
289// Run operations on underlying socket
290cfg_if! {
291    if #[cfg(unix)] {
292        pub fn socket2_operation<S: std::os::fd::AsRawFd, F: FnOnce(&mut socket2::Socket) -> R, R>(
293            s: &S,
294            callback: F,
295        ) -> R {
296            use std::os::fd::{FromRawFd, IntoRawFd};
297            let mut s = unsafe { socket2::Socket::from_raw_fd(s.as_raw_fd()) };
298            let res = callback(&mut s);
299            let _ = s.into_raw_fd();
300            res
301        }
302    } else if #[cfg(windows)] {
303        pub fn socket2_operation<
304            S: std::os::windows::io::AsRawSocket,
305            F: FnOnce(&mut socket2::Socket) -> R,
306            R,
307        >(
308            s: &S,
309            callback: F,
310        ) -> R {
311            use std::os::windows::io::{FromRawSocket, IntoRawSocket};
312            let mut s = unsafe { socket2::Socket::from_raw_socket(s.as_raw_socket()) };
313            let res = callback(&mut s);
314            let _ = s.into_raw_socket();
315            res
316        }
317    } else {
318        #[compile_error("unimplemented")]
319    }
320}