veilid_tools/
socket_tools.rs1use super::*;
2use async_io::Async;
3use std::io;
4
5cfg_if! {
6 if #[cfg(feature="rt-async-std")] {
7 pub use async_std::net::{TcpStream, TcpListener, UdpSocket};
8 } else if #[cfg(feature="rt-tokio")] {
9 pub use tokio::net::{TcpStream, TcpListener, UdpSocket};
10 pub use tokio_util::compat::*;
11 } else {
12 compile_error!("needs executor implementation");
13 }
14}
15
16use socket2::{Domain, Protocol, SockAddr, Socket, Type};
17
18pub fn bind_async_udp_socket(local_address: SocketAddr) -> io::Result<Option<UdpSocket>> {
21 let Some(socket) = new_bound_default_socket2_udp(local_address)? else {
22 return Ok(None);
23 };
24
25 let std_udp_socket: std::net::UdpSocket = socket.into();
27 cfg_if! {
28 if #[cfg(feature="rt-async-std")] {
29 let udp_socket = UdpSocket::from(std_udp_socket);
30 } else if #[cfg(feature="rt-tokio")] {
31 std_udp_socket.set_nonblocking(true)?;
32 let udp_socket = UdpSocket::from_std(std_udp_socket)?;
33 } else {
34 compile_error!("needs executor implementation");
35 }
36 }
37 Ok(Some(udp_socket))
38}
39
40pub fn bind_async_tcp_listener(local_address: SocketAddr) -> io::Result<Option<TcpListener>> {
41 let Some(socket) = new_bound_default_socket2_tcp(local_address)? else {
43 return Ok(None);
44 };
45
46 drop(socket);
48
49 let Some(socket) = new_bound_shared_socket2_tcp(local_address)? else {
51 return Ok(None);
52 };
53
54 if socket.listen(128).is_err() {
56 return Ok(None);
57 }
58
59 let std_listener: std::net::TcpListener = socket.into();
61 cfg_if! {
62 if #[cfg(feature="rt-async-std")] {
63 let listener = TcpListener::from(std_listener);
64 } else if #[cfg(feature="rt-tokio")] {
65 std_listener.set_nonblocking(true)?;
66 let listener = TcpListener::from_std(std_listener)?;
67 } else {
68 compile_error!("needs executor implementation");
69 }
70 }
71 Ok(Some(listener))
72}
73
74pub async fn connect_async_tcp_stream(
75 local_address: Option<SocketAddr>,
76 remote_address: SocketAddr,
77 timeout_ms: u32,
78) -> io::Result<TimeoutOr<TcpStream>> {
79 let socket = match local_address {
80 Some(a) => {
81 new_bound_shared_socket2_tcp(a)?.ok_or(io::Error::from(io::ErrorKind::AddrInUse))?
82 }
83 None => new_default_socket2_tcp(domain_for_address(remote_address))?,
84 };
85
86 nonblocking_connect(socket, remote_address, timeout_ms).await
88}
89
90pub fn set_tcp_stream_linger(
91 tcp_stream: &TcpStream,
92 linger: Option<core::time::Duration>,
93) -> io::Result<()> {
94 #[cfg(all(feature = "rt-async-std", unix))]
95 {
96 use std::os::fd::{AsRawFd, FromRawFd, IntoRawFd};
98 unsafe {
99 let s = socket2::Socket::from_raw_fd(tcp_stream.as_raw_fd());
100 let res = s.set_linger(linger);
101 let _ = s.into_raw_fd();
102 res
103 }
104 }
105 #[cfg(all(feature = "rt-async-std", windows))]
106 {
107 use std::os::windows::io::{AsRawSocket, FromRawSocket, IntoRawSocket};
109 unsafe {
110 let s = socket2::Socket::from_raw_socket(tcp_stream.as_raw_socket());
111 let res = s.set_linger(linger);
112 let _ = s.into_raw_socket();
113 res
114 }
115 }
116 #[cfg(not(feature = "rt-async-std"))]
117 tcp_stream.set_linger(linger)
118}
119
// Read/write halves of a `TcpStream`, as returned by `split_async_tcp_stream`.
// The concrete types depend on which async-runtime feature is enabled; cfg_if compiles
// only the first branch whose cfg matches.
cfg_if! {
    if #[cfg(feature="rt-async-std")] {
        /// Read half of an async-std `TcpStream`, produced by futures-util's split.
        pub type ReadHalf = futures_util::io::ReadHalf<TcpStream>;
        /// Write half of an async-std `TcpStream`, produced by futures-util's split.
        pub type WriteHalf = futures_util::io::WriteHalf<TcpStream>;
    } else if #[cfg(feature="rt-tokio")] {
        /// Owned read half of a tokio `TcpStream`.
        pub type ReadHalf = tokio::net::tcp::OwnedReadHalf;
        /// Owned write half of a tokio `TcpStream`.
        pub type WriteHalf = tokio::net::tcp::OwnedWriteHalf;
    } else {
        compile_error!("needs executor implementation");
    }
}
131
132pub fn async_tcp_listener_incoming(
133 tcp_listener: TcpListener,
134) -> Pin<Box<impl futures_util::stream::Stream<Item = std::io::Result<TcpStream>> + Send>> {
135 cfg_if! {
136 if #[cfg(feature="rt-async-std")] {
137 Box::pin(tcp_listener.into_incoming())
138 } else if #[cfg(feature="rt-tokio")] {
139 Box::pin(tokio_stream::wrappers::TcpListenerStream::new(tcp_listener))
140 } else {
141 compile_error!("needs executor implementation");
142 }
143 }
144}
145
146pub fn split_async_tcp_stream(tcp_stream: TcpStream) -> (ReadHalf, WriteHalf) {
147 cfg_if! {
148 if #[cfg(feature="rt-async-std")] {
149 use futures_util::AsyncReadExt;
150 tcp_stream.split()
151 } else if #[cfg(feature="rt-tokio")] {
152 tcp_stream.into_split()
153 } else {
154 compile_error!("needs executor implementation");
155 }
156 }
157}
158
159fn new_default_udp_socket(domain: core::ffi::c_int) -> io::Result<Socket> {
162 let domain = Domain::from(domain);
163 let socket = Socket::new(domain, Type::DGRAM, Some(Protocol::UDP))?;
164 if domain == Domain::IPV6 {
165 socket.set_only_v6(true)?;
166 }
167
168 Ok(socket)
169}
170
171fn new_bound_default_socket2_udp(local_address: SocketAddr) -> io::Result<Option<Socket>> {
172 let domain = domain_for_address(local_address);
173 let socket = new_default_udp_socket(domain)?;
174 let socket2_addr = SockAddr::from(local_address);
175
176 if socket.bind(&socket2_addr).is_err() {
177 return Ok(None);
178 }
179
180 Ok(Some(socket))
181}
182
183pub fn new_default_socket2_tcp(domain: core::ffi::c_int) -> io::Result<Socket> {
184 let domain = Domain::from(domain);
185 let socket = Socket::new(domain, Type::STREAM, Some(Protocol::TCP))?;
186 socket.set_linger(Some(core::time::Duration::from_secs(0)))?;
187 socket.set_nodelay(true)?;
188 if domain == Domain::IPV6 {
189 socket.set_only_v6(true)?;
190 }
191 Ok(socket)
192}
193
194fn new_shared_socket2_tcp(domain: core::ffi::c_int) -> io::Result<Socket> {
195 let domain = Domain::from(domain);
196 let socket = Socket::new(domain, Type::STREAM, Some(Protocol::TCP))?;
197 socket.set_linger(Some(core::time::Duration::from_secs(0)))?;
198 socket.set_nodelay(true)?;
199 if domain == Domain::IPV6 {
200 socket.set_only_v6(true)?;
201 }
202 socket.set_reuse_address(true)?;
203 cfg_if! {
204 if #[cfg(unix)] {
205 socket.set_reuse_port(true)?;
206 }
207 }
208
209 Ok(socket)
210}
211
212fn new_bound_default_socket2_tcp(local_address: SocketAddr) -> io::Result<Option<Socket>> {
213 let domain = domain_for_address(local_address);
214 let socket = new_default_socket2_tcp(domain)?;
215 let socket2_addr = SockAddr::from(local_address);
216 if socket.bind(&socket2_addr).is_err() {
217 return Ok(None);
218 }
219
220 Ok(Some(socket))
221}
222
223fn new_bound_shared_socket2_tcp(local_address: SocketAddr) -> io::Result<Option<Socket>> {
224 let domain = domain_for_address(local_address);
226 let socket = new_shared_socket2_tcp(domain)?;
227 let socket2_addr = SockAddr::from(local_address);
228 if socket.bind(&socket2_addr).is_err() {
229 return Ok(None);
230 }
231
232 Ok(Some(socket))
233}
234
235async fn nonblocking_connect(
238 socket: Socket,
239 addr: SocketAddr,
240 timeout_ms: u32,
241) -> io::Result<TimeoutOr<TcpStream>> {
242 socket.set_nonblocking(true)?;
244
245 let socket2_addr = socket2::SockAddr::from(addr);
247
248 match socket.connect(&socket2_addr) {
250 Ok(()) => Ok(()),
251 #[cfg(unix)]
252 Err(err) if err.raw_os_error() == Some(libc::EINPROGRESS) => Ok(()),
253 Err(err) if err.kind() == std::io::ErrorKind::WouldBlock => Ok(()),
254 Err(e) => Err(e),
255 }?;
256 let async_stream = Async::new(std::net::TcpStream::from(socket))?;
257
258 timeout_or_try!(timeout(timeout_ms, async_stream.writable())
260 .await
261 .into_timeout_or()
262 .into_result()?);
263
264 let async_stream = match async_stream.get_ref().take_error()? {
266 None => Ok(async_stream),
267 Some(err) => Err(err),
268 }?;
269
270 cfg_if! {
272 if #[cfg(feature="rt-async-std")] {
273 Ok(TimeoutOr::value(TcpStream::from(async_stream.into_inner()?)))
274 } else if #[cfg(feature="rt-tokio")] {
275 Ok(TimeoutOr::value(TcpStream::from_std(async_stream.into_inner()?)?))
276 } else {
277 compile_error!("needs executor implementation");
278 }
279 }
280}
281
282pub fn domain_for_address(address: SocketAddr) -> core::ffi::c_int {
283 socket2::Domain::for_address(address).into()
284}
285
286cfg_if! {
288 if #[cfg(unix)] {
289 pub fn socket2_operation<S: std::os::fd::AsRawFd, F: FnOnce(&mut socket2::Socket) -> R, R>(
290 s: &S,
291 callback: F,
292 ) -> R {
293 use std::os::fd::{FromRawFd, IntoRawFd};
294 let mut s = unsafe { socket2::Socket::from_raw_fd(s.as_raw_fd()) };
295 let res = callback(&mut s);
296 let _ = s.into_raw_fd();
297 res
298 }
299 } else if #[cfg(windows)] {
300 pub fn socket2_operation<
301 S: std::os::windows::io::AsRawSocket,
302 F: FnOnce(&mut socket2::Socket) -> R,
303 R,
304 >(
305 s: &S,
306 callback: F,
307 ) -> R {
308 use std::os::windows::io::{FromRawSocket, IntoRawSocket};
309 let mut s = unsafe { socket2::Socket::from_raw_socket(s.as_raw_socket()) };
310 let res = callback(&mut s);
311 let _ = s.into_raw_socket();
312 res
313 }
314 } else {
315 #[compile_error("unimplemented")]
316 }
317}