librqbit_dualstack_sockets/
socket.rs

1use std::{
2    net::{IpAddr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6},
3    task::Poll,
4};
5
6use socket2::{Domain, Socket};
7use tracing::{debug, trace};
8
9use crate::{
10    Error,
11    addr::{ToV6Mapped, TryToV4},
12    bind_device::BindDevice,
13};
14
/// The address a socket ended up bound to, split by address family.
///
/// For IPv6 it also records whether the socket was configured as dualstack
/// (i.e. `IPV6_V6ONLY` was turned off so IPv4 peers can use it too).
#[derive(Clone, Copy, Debug)]
pub enum SocketAddrKind {
    /// Bound to an IPv4 address.
    V4(SocketAddrV4),
    /// Bound to an IPv6 address.
    V6 {
        /// The IPv6 address the socket is bound to.
        addr: SocketAddrV6,
        /// `true` when the socket also accepts IPv4 traffic.
        is_dualstack: bool,
    },
}

impl SocketAddrKind {
    /// Returns `true` for the IPv6 variant, whether or not it is dualstack.
    fn is_v6(&self) -> bool {
        match self {
            Self::V4(_) => false,
            Self::V6 { .. } => true,
        }
    }

    /// Drops the family/dualstack distinction and yields a plain [`SocketAddr`].
    fn as_socketaddr(&self) -> SocketAddr {
        match *self {
            Self::V4(v4) => v4.into(),
            Self::V6 { addr: v6, .. } => v6.into(),
        }
    }
}
36
37#[derive(Clone, Copy, Debug)]
38pub struct BindOpts<'a> {
39    pub request_dualstack: bool,
40    pub reuseport: bool,
41    pub device: Option<&'a BindDevice>,
42}
43
44impl Default for BindOpts<'_> {
45    fn default() -> Self {
46        Self {
47            request_dualstack: true,
48            reuseport: false,
49            device: None,
50        }
51    }
52}
53
54pub struct MaybeDualstackSocket<S> {
55    socket: S,
56    addr_kind: SocketAddrKind,
57}
58
59impl<S> MaybeDualstackSocket<S> {
60    pub fn socket(&self) -> &S {
61        &self.socket
62    }
63
64    pub fn bind_addr(&self) -> SocketAddr {
65        self.addr_kind.as_socketaddr()
66    }
67
68    pub fn is_dualstack(&self) -> bool {
69        matches!(
70            self.addr_kind,
71            SocketAddrKind::V6 {
72                is_dualstack: true,
73                ..
74            }
75        )
76    }
77
78    pub(crate) fn convert_addr_for_send(&self, addr: SocketAddr) -> SocketAddr {
79        if self.is_dualstack() {
80            return SocketAddr::V6(addr.to_ipv6_mapped());
81        }
82        addr
83    }
84}
85
86impl MaybeDualstackSocket<Socket> {
87    fn bind(addr: SocketAddr, opts: BindOpts, is_udp: bool) -> crate::Result<Self> {
88        let socket = Socket::new(
89            if addr.is_ipv6() {
90                Domain::IPV6
91            } else {
92                Domain::IPV4
93            },
94            if is_udp {
95                socket2::Type::DGRAM
96            } else {
97                socket2::Type::STREAM
98            },
99            Some(if is_udp {
100                socket2::Protocol::UDP
101            } else {
102                socket2::Protocol::TCP
103            }),
104        )
105        .map_err(Error::SocketNew)?;
106
107        let mut set_dualstack = false;
108
109        let addr_kind = match (opts.request_dualstack, addr) {
110            (request_dualstack, SocketAddr::V6(addr))
111                if *addr.ip() == IpAddr::V6(Ipv6Addr::UNSPECIFIED) =>
112            {
113                let value = !request_dualstack;
114                trace!(?addr, only_v6 = value, "setting only_v6");
115                socket
116                    .set_only_v6(value)
117                    .map_err(|e| Error::OnlyV6 { value, source: e })?;
118                #[cfg(not(windows))] // socket.only_v6() panics on windows somehow
119                trace!(?addr, only_v6=?socket.only_v6());
120                set_dualstack = true;
121                SocketAddrKind::V6 {
122                    addr,
123                    is_dualstack: request_dualstack,
124                }
125            }
126            (_, SocketAddr::V6(addr)) => SocketAddrKind::V6 {
127                addr,
128                is_dualstack: false,
129            },
130            (_, SocketAddr::V4(addr)) => SocketAddrKind::V4(addr),
131        };
132
133        if !set_dualstack {
134            debug!(
135                ?addr,
136                "ignored dualstack request as it only applies to [::] address"
137            );
138        }
139
140        #[cfg(not(windows))]
141        {
142            socket
143                .set_reuse_address(true)
144                .map_err(Error::ReuseAddress)?;
145        }
146
147        #[cfg(windows)]
148        if opts.reuseport || !is_udp {
149            socket
150                .set_reuse_address(true)
151                .map_err(Error::ReuseAddress)?;
152        }
153
154        #[cfg(not(windows))]
155        if opts.reuseport {
156            socket.set_reuse_port(true).map_err(Error::ReusePort)?;
157            debug!(reuse_port=?socket.reuse_port());
158            debug!(reuse_addr=?socket.reuse_address());
159        }
160
161        if let Some(bd) = opts.device {
162            bd.bind_sref(&socket, addr_kind.is_v6())?;
163        }
164
165        socket.bind(&addr.into()).map_err(|e| {
166            trace!(?addr, "error binding: {e:#}");
167            Error::Bind(e)
168        })?;
169
170        let local_addr: SocketAddr = socket
171            .local_addr()
172            .map_err(Error::LocalAddr)?
173            .as_socket()
174            .ok_or(Error::AsSocket)?;
175
176        let addr_kind = match (addr_kind, local_addr) {
177            (SocketAddrKind::V4(..), SocketAddr::V4(received)) => SocketAddrKind::V4(received),
178            (SocketAddrKind::V6 { is_dualstack, .. }, SocketAddr::V6(received)) => {
179                SocketAddrKind::V6 {
180                    addr: received,
181                    is_dualstack,
182                }
183            }
184            _ => {
185                tracing::debug!(?local_addr, bind_addr=?addr, "mismatch between local_addr() and requested bind_addr");
186                return Err(Error::LocalBindAddrMismatch);
187            }
188        };
189
190        socket
191            .set_nonblocking(true)
192            .map_err(Error::SetNonblocking)?;
193
194        Ok(Self { socket, addr_kind })
195    }
196}
197
198impl MaybeDualstackSocket<tokio::net::TcpListener> {
199    pub fn bind_tcp(addr: SocketAddr, opts: BindOpts) -> crate::Result<Self> {
200        let sock = MaybeDualstackSocket::bind(addr, opts, false)?;
201
202        debug!(addr=?sock.bind_addr(), requested_addr=?addr, dualstack = sock.is_dualstack(), "listening on TCP");
203        sock.socket().listen(1024).map_err(Error::Listen)?;
204
205        Ok(Self {
206            socket: tokio::net::TcpListener::from_std(std::net::TcpListener::from(sock.socket))
207                .map_err(Error::TokioFromStd)?,
208            addr_kind: sock.addr_kind,
209        })
210    }
211
212    pub async fn accept(&self) -> std::io::Result<(tokio::net::TcpStream, SocketAddr)> {
213        let (s, addr) = self.socket.accept().await?;
214        Ok((s, addr.try_to_ipv4()))
215    }
216}
217
#[cfg(feature = "axum")]
pub mod axum {
    //! `axum::serve` integration for TCP [`MaybeDualstackSocket`] listeners.

    use std::net::SocketAddr;

    use crate::socket::MaybeDualstackSocket;

    /// Local newtype around [`SocketAddr`], so the axum traits below can be
    /// implemented for it. Converts to and from `SocketAddr` freely.
    #[derive(Clone, Copy)]
    pub struct WrappedSocketAddr(pub SocketAddr);

    impl core::fmt::Debug for WrappedSocketAddr {
        fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
            // Delegate directly rather than `write!(f, "{:?}", ..)` so that
            // caller-supplied formatter flags (width, alternate, ...) are kept.
            core::fmt::Debug::fmt(&self.0, f)
        }
    }

    impl From<SocketAddr> for WrappedSocketAddr {
        fn from(value: SocketAddr) -> Self {
            Self(value)
        }
    }

    impl From<WrappedSocketAddr> for SocketAddr {
        fn from(value: WrappedSocketAddr) -> Self {
            value.0
        }
    }

    impl axum::serve::Listener for MaybeDualstackSocket<tokio::net::TcpListener> {
        type Io = tokio::net::TcpStream;

        type Addr = WrappedSocketAddr;

        /// Accepts a connection and retries failed accepts indefinitely.
        ///
        /// Retries use exponential backoff capped at 5 seconds. Each failure is
        /// logged at trace level.
        async fn accept(&mut self) -> (Self::Io, Self::Addr) {
            use backon::{ExponentialBuilder, Retryable};
            let (stream, addr) = (|| MaybeDualstackSocket::accept(self))
                .retry(
                    ExponentialBuilder::new()
                        .without_max_times()
                        .with_max_delay(std::time::Duration::from_secs(5)),
                )
                .notify(|e, retry_in| tracing::trace!(?retry_in, "error accepting: {e:#}"))
                .await
                .expect("accept is retried without a limit, so it never yields an error");
            (stream, addr.into())
        }

        fn local_addr(&self) -> tokio::io::Result<Self::Addr> {
            Ok(self.bind_addr().into())
        }
    }

    impl
        axum::extract::connect_info::Connected<
            axum::serve::IncomingStream<'_, MaybeDualstackSocket<tokio::net::TcpListener>>,
        > for WrappedSocketAddr
    {
        /// Exposes the peer address produced by `accept` as axum connect info.
        fn connect_info(
            stream: axum::serve::IncomingStream<'_, MaybeDualstackSocket<tokio::net::TcpListener>>,
        ) -> Self {
            *stream.remote_addr()
        }
    }
}
278
279impl MaybeDualstackSocket<tokio::net::UdpSocket> {
280    pub fn bind_udp(addr: SocketAddr, opts: BindOpts) -> crate::Result<Self> {
281        let sock = MaybeDualstackSocket::bind(addr, opts, true)?;
282
283        debug!(addr=?sock.bind_addr(), requested_addr=?addr, dualstack = sock.is_dualstack(), "listening on UDP");
284
285        Ok(Self {
286            socket: tokio::net::UdpSocket::from_std(std::net::UdpSocket::from(sock.socket))
287                .map_err(Error::TokioFromStd)?,
288            addr_kind: sock.addr_kind,
289        })
290    }
291
292    pub async fn recv_from(&self, buf: &mut [u8]) -> std::io::Result<(usize, SocketAddr)> {
293        let (size, addr) = self.socket.recv_from(buf).await?;
294        Ok((size, addr.try_to_ipv4()))
295    }
296
297    pub async fn send_to(&self, buf: &[u8], target: SocketAddr) -> std::io::Result<usize> {
298        let target = self.convert_addr_for_send(target);
299        self.socket.send_to(buf, target).await
300    }
301
302    pub fn poll_send_to(
303        &self,
304        cx: &mut std::task::Context<'_>,
305        buf: &[u8],
306        target: SocketAddr,
307    ) -> Poll<std::io::Result<usize>> {
308        let target = self.convert_addr_for_send(target);
309        self.socket.poll_send_to(cx, buf, target)
310    }
311}