librqbit_dualstack_sockets/
socket.rs1use std::{
2 net::{IpAddr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6},
3 task::Poll,
4};
5
6use socket2::{Domain, Socket};
7use tracing::{debug, trace};
8
9use crate::{
10 Error,
11 addr::{ToV6Mapped, TryToV4},
12 bind_device::BindDevice,
13};
14
/// The address a socket is bound to, plus (for IPv6) whether the socket is
/// dual-stack, i.e. also serves IPv4 peers through IPv4-mapped IPv6 addresses.
#[derive(Clone, Copy, Debug)]
pub enum SocketAddrKind {
    /// Bound to an IPv4 address.
    V4(SocketAddrV4),
    /// Bound to an IPv6 address.
    V6 {
        /// The bound IPv6 socket address.
        addr: SocketAddrV6,
        /// `true` when the socket is bound to `[::]` with `IPV6_V6ONLY` disabled.
        is_dualstack: bool,
    },
}

impl SocketAddrKind {
    /// Returns `true` for the IPv6 variant, regardless of dual-stack mode.
    fn is_v6(&self) -> bool {
        match self {
            SocketAddrKind::V4(_) => false,
            SocketAddrKind::V6 { .. } => true,
        }
    }

    /// Returns the bound address as a plain [`SocketAddr`], dropping the
    /// dual-stack flag.
    fn as_socketaddr(&self) -> SocketAddr {
        match self {
            SocketAddrKind::V4(v4) => SocketAddr::from(*v4),
            SocketAddrKind::V6 { addr: v6, .. } => SocketAddr::from(*v6),
        }
    }
}
36
37#[derive(Clone, Copy, Debug)]
38pub struct BindOpts<'a> {
39 pub request_dualstack: bool,
40 pub reuseport: bool,
41 pub device: Option<&'a BindDevice>,
42}
43
44impl Default for BindOpts<'_> {
45 fn default() -> Self {
46 Self {
47 request_dualstack: true,
48 reuseport: false,
49 device: None,
50 }
51 }
52}
53
54pub struct MaybeDualstackSocket<S> {
55 socket: S,
56 addr_kind: SocketAddrKind,
57}
58
59impl<S> MaybeDualstackSocket<S> {
60 pub fn socket(&self) -> &S {
61 &self.socket
62 }
63
64 pub fn bind_addr(&self) -> SocketAddr {
65 self.addr_kind.as_socketaddr()
66 }
67
68 pub fn is_dualstack(&self) -> bool {
69 matches!(
70 self.addr_kind,
71 SocketAddrKind::V6 {
72 is_dualstack: true,
73 ..
74 }
75 )
76 }
77
78 pub(crate) fn convert_addr_for_send(&self, addr: SocketAddr) -> SocketAddr {
79 if self.is_dualstack() {
80 return SocketAddr::V6(addr.to_ipv6_mapped());
81 }
82 addr
83 }
84}
85
86impl MaybeDualstackSocket<Socket> {
87 fn bind(addr: SocketAddr, opts: BindOpts, is_udp: bool) -> crate::Result<Self> {
88 let socket = Socket::new(
89 if addr.is_ipv6() {
90 Domain::IPV6
91 } else {
92 Domain::IPV4
93 },
94 if is_udp {
95 socket2::Type::DGRAM
96 } else {
97 socket2::Type::STREAM
98 },
99 Some(if is_udp {
100 socket2::Protocol::UDP
101 } else {
102 socket2::Protocol::TCP
103 }),
104 )
105 .map_err(Error::SocketNew)?;
106
107 let mut set_dualstack = false;
108
109 let addr_kind = match (opts.request_dualstack, addr) {
110 (request_dualstack, SocketAddr::V6(addr))
111 if *addr.ip() == IpAddr::V6(Ipv6Addr::UNSPECIFIED) =>
112 {
113 let value = !request_dualstack;
114 trace!(?addr, only_v6 = value, "setting only_v6");
115 socket
116 .set_only_v6(value)
117 .map_err(|e| Error::OnlyV6 { value, source: e })?;
118 #[cfg(not(windows))] trace!(?addr, only_v6=?socket.only_v6());
120 set_dualstack = true;
121 SocketAddrKind::V6 {
122 addr,
123 is_dualstack: request_dualstack,
124 }
125 }
126 (_, SocketAddr::V6(addr)) => SocketAddrKind::V6 {
127 addr,
128 is_dualstack: false,
129 },
130 (_, SocketAddr::V4(addr)) => SocketAddrKind::V4(addr),
131 };
132
133 if !set_dualstack {
134 debug!(
135 ?addr,
136 "ignored dualstack request as it only applies to [::] address"
137 );
138 }
139
140 #[cfg(not(windows))]
141 {
142 socket
143 .set_reuse_address(true)
144 .map_err(Error::ReuseAddress)?;
145 }
146
147 #[cfg(windows)]
148 if opts.reuseport || !is_udp {
149 socket
150 .set_reuse_address(true)
151 .map_err(Error::ReuseAddress)?;
152 }
153
154 #[cfg(not(windows))]
155 if opts.reuseport {
156 socket.set_reuse_port(true).map_err(Error::ReusePort)?;
157 debug!(reuse_port=?socket.reuse_port());
158 debug!(reuse_addr=?socket.reuse_address());
159 }
160
161 if let Some(bd) = opts.device {
162 bd.bind_sref(&socket, addr_kind.is_v6())?;
163 }
164
165 socket.bind(&addr.into()).map_err(|e| {
166 trace!(?addr, "error binding: {e:#}");
167 Error::Bind(e)
168 })?;
169
170 let local_addr: SocketAddr = socket
171 .local_addr()
172 .map_err(Error::LocalAddr)?
173 .as_socket()
174 .ok_or(Error::AsSocket)?;
175
176 let addr_kind = match (addr_kind, local_addr) {
177 (SocketAddrKind::V4(..), SocketAddr::V4(received)) => SocketAddrKind::V4(received),
178 (SocketAddrKind::V6 { is_dualstack, .. }, SocketAddr::V6(received)) => {
179 SocketAddrKind::V6 {
180 addr: received,
181 is_dualstack,
182 }
183 }
184 _ => {
185 tracing::debug!(?local_addr, bind_addr=?addr, "mismatch between local_addr() and requested bind_addr");
186 return Err(Error::LocalBindAddrMismatch);
187 }
188 };
189
190 socket
191 .set_nonblocking(true)
192 .map_err(Error::SetNonblocking)?;
193
194 Ok(Self { socket, addr_kind })
195 }
196}
197
198#[cfg(target_os = "linux")]
199impl TryFrom<std::os::fd::OwnedFd> for MaybeDualstackSocket<tokio::net::TcpListener> {
200 type Error = crate::Error;
201 fn try_from(fd: std::os::fd::OwnedFd) -> Result<Self, Self::Error> {
206 use std::io;
207 let sock = Socket::from(fd);
208 match sock.protocol().map_err(Error::SocketFromFd)? {
209 Some(socket2::Protocol::TCP) => {}
210 Some(proto) => {
211 return Err(Error::SocketFromFd(io::Error::other(format!(
212 "expected a TCP socket, got a {proto:?} socket"
213 ))));
214 }
215 None => {
216 return Err(Error::SocketFromFd(io::Error::other(
217 "socket has no protocol",
218 )));
219 }
220 };
221
222 if !sock.is_listener().map_err(Error::SocketFromFd)? {
223 return Err(Error::SocketFromFd(io::Error::other(
224 "expected a listening TCP socket",
225 )));
226 }
227
228 let addr_kind = match sock
229 .local_addr()
230 .map_err(Error::LocalAddr)?
231 .as_socket()
232 .ok_or(Error::AsSocket)?
233 {
234 SocketAddr::V4(addr) => SocketAddrKind::V4(addr),
235 SocketAddr::V6(addr) => SocketAddrKind::V6 {
236 addr,
237 is_dualstack: addr.ip().is_unspecified()
238 && !sock.only_v6().map_err(Error::SocketFromFd)?,
239 },
240 };
241
242 sock.set_nonblocking(true).map_err(Error::SetNonblocking)?;
243
244 Ok(Self {
245 addr_kind,
246 socket: tokio::net::TcpListener::from_std(std::net::TcpListener::from(sock))
247 .map_err(Error::TokioFromStd)?,
248 })
249 }
250}
251
252impl MaybeDualstackSocket<tokio::net::TcpListener> {
253 pub fn bind_tcp(addr: SocketAddr, opts: BindOpts) -> crate::Result<Self> {
254 let sock = MaybeDualstackSocket::bind(addr, opts, false)?;
255
256 debug!(addr=?sock.bind_addr(), requested_addr=?addr, dualstack = sock.is_dualstack(), "listening on TCP");
257 sock.socket().listen(1024).map_err(Error::Listen)?;
258
259 Ok(Self {
260 socket: tokio::net::TcpListener::from_std(std::net::TcpListener::from(sock.socket))
261 .map_err(Error::TokioFromStd)?,
262 addr_kind: sock.addr_kind,
263 })
264 }
265
266 pub async fn accept(&self) -> std::io::Result<(tokio::net::TcpStream, SocketAddr)> {
267 let (s, addr) = self.socket.accept().await?;
268 Ok((s, addr.try_to_ipv4()))
269 }
270}
271
#[cfg(feature = "axum")]
pub mod axum {
    //! Lets `MaybeDualstackSocket<tokio::net::TcpListener>` be used with `axum::serve`.
    use std::net::SocketAddr;

    use crate::socket::MaybeDualstackSocket;

    /// Newtype around [`SocketAddr`]. It is used as the listener's `Addr` type
    /// and as the connect info extracted by axum.
    #[derive(Clone, Copy)]
    pub struct WrappedSocketAddr(pub SocketAddr);

    impl core::fmt::Debug for WrappedSocketAddr {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            // Forward to the inner impl with the same formatter so caller flags
            // (width, fill, alignment, `#`) are honored; `write!("{:?}")` dropped them.
            core::fmt::Debug::fmt(&self.0, f)
        }
    }

    impl From<SocketAddr> for WrappedSocketAddr {
        fn from(value: SocketAddr) -> Self {
            Self(value)
        }
    }

    impl From<WrappedSocketAddr> for SocketAddr {
        fn from(value: WrappedSocketAddr) -> Self {
            value.0
        }
    }

    impl axum::serve::Listener for MaybeDualstackSocket<tokio::net::TcpListener> {
        type Io = tokio::net::TcpStream;

        type Addr = WrappedSocketAddr;

        /// Accepts the next connection.
        ///
        /// Failed `accept` calls are retried forever with exponential backoff
        /// capped at 5 seconds; each failure is logged at trace level.
        async fn accept(&mut self) -> (Self::Io, Self::Addr) {
            use backon::{ExponentialBuilder, Retryable};
            let (stream, addr) = (|| MaybeDualstackSocket::accept(self))
                .retry(
                    ExponentialBuilder::new()
                        .without_max_times()
                        .with_max_delay(std::time::Duration::from_secs(5)),
                )
                .notify(|e, retry_in| tracing::trace!(?retry_in, "error accepting: {e:#}"))
                .await
                // Invariant: with `without_max_times()` the retry loop never gives up,
                // so an error can never reach this point.
                .expect("accept retry is configured without a max attempt count");
            (stream, addr.into())
        }

        fn local_addr(&self) -> tokio::io::Result<Self::Addr> {
            Ok(self.bind_addr().into())
        }
    }

    impl
        axum::extract::connect_info::Connected<
            axum::serve::IncomingStream<'_, MaybeDualstackSocket<tokio::net::TcpListener>>,
        > for WrappedSocketAddr
    {
        /// Uses the peer address produced by this listener's `accept` as connect info.
        fn connect_info(
            stream: axum::serve::IncomingStream<'_, MaybeDualstackSocket<tokio::net::TcpListener>>,
        ) -> Self {
            *stream.remote_addr()
        }
    }
}
332
333impl MaybeDualstackSocket<tokio::net::UdpSocket> {
334 pub fn bind_udp(addr: SocketAddr, opts: BindOpts) -> crate::Result<Self> {
335 let sock = MaybeDualstackSocket::bind(addr, opts, true)?;
336
337 debug!(addr=?sock.bind_addr(), requested_addr=?addr, dualstack = sock.is_dualstack(), "listening on UDP");
338
339 Ok(Self {
340 socket: tokio::net::UdpSocket::from_std(std::net::UdpSocket::from(sock.socket))
341 .map_err(Error::TokioFromStd)?,
342 addr_kind: sock.addr_kind,
343 })
344 }
345
346 pub async fn recv_from(&self, buf: &mut [u8]) -> std::io::Result<(usize, SocketAddr)> {
347 let (size, addr) = self.socket.recv_from(buf).await?;
348 Ok((size, addr.try_to_ipv4()))
349 }
350
351 pub async fn send_to(&self, buf: &[u8], target: SocketAddr) -> std::io::Result<usize> {
352 let target = self.convert_addr_for_send(target);
353 self.socket.send_to(buf, target).await
354 }
355
356 pub fn poll_send_to(
357 &self,
358 cx: &mut std::task::Context<'_>,
359 buf: &[u8],
360 target: SocketAddr,
361 ) -> Poll<std::io::Result<usize>> {
362 let target = self.convert_addr_for_send(target);
363 self.socket.poll_send_to(cx, buf, target)
364 }
365}