librqbit_dualstack_sockets/
socket.rs

1use std::{
2    net::{IpAddr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6},
3    task::Poll,
4};
5
6use socket2::{Domain, Socket};
7use tracing::{debug, trace};
8
9use crate::{
10    Error,
11    addr::{ToV6Mapped, TryToV4},
12};
13
/// The address a socket ended up bound to, split by address family.
///
/// IPv6 entries also remember whether the socket was configured as dual-stack.
#[derive(Clone, Copy, Debug)]
pub enum SocketAddrKind {
    /// An IPv4 socket address.
    V4(SocketAddrV4),
    /// An IPv6 socket address.
    V6 {
        /// The IPv6 address and port.
        addr: SocketAddrV6,
        /// `true` when `IPV6_V6ONLY` was cleared on an `[::]` socket (dual-stack).
        is_dualstack: bool,
    },
}

impl SocketAddrKind {
    /// Returns the generic [`SocketAddr`] form, discarding the dual-stack flag.
    fn as_socketaddr(&self) -> SocketAddr {
        match self {
            Self::V4(v4) => SocketAddr::from(*v4),
            Self::V6 { addr: v6, .. } => SocketAddr::from(*v6),
        }
    }
}
31
/// Options controlling how a socket is created and bound.
#[derive(Clone, Copy, Debug)]
pub struct BindOpts {
    /// Ask for a dual-stack socket. Only takes effect when binding to the
    /// unspecified IPv6 address `[::]`; ignored for any other address.
    pub request_dualstack: bool,
    /// Enable `SO_REUSEPORT` before binding (not applied on Windows).
    pub reuseport: bool,
}

impl Default for BindOpts {
    /// Dual-stack requested, port reuse disabled.
    fn default() -> Self {
        let request_dualstack = true;
        let reuseport = false;
        BindOpts {
            request_dualstack,
            reuseport,
        }
    }
}
46
47pub struct MaybeDualstackSocket<S> {
48    socket: S,
49    addr_kind: SocketAddrKind,
50}
51
52impl<S> MaybeDualstackSocket<S> {
53    pub fn socket(&self) -> &S {
54        &self.socket
55    }
56
57    pub fn bind_addr(&self) -> SocketAddr {
58        self.addr_kind.as_socketaddr()
59    }
60
61    pub fn is_dualstack(&self) -> bool {
62        matches!(
63            self.addr_kind,
64            SocketAddrKind::V6 {
65                is_dualstack: true,
66                ..
67            }
68        )
69    }
70
71    fn convert_addr_for_send(&self, addr: SocketAddr) -> SocketAddr {
72        if self.is_dualstack() {
73            return SocketAddr::V6(addr.to_ipv6_mapped());
74        }
75        addr
76    }
77}
78
79impl MaybeDualstackSocket<Socket> {
80    fn bind(addr: SocketAddr, opts: BindOpts, is_udp: bool) -> crate::Result<Self> {
81        let socket = Socket::new(
82            if addr.is_ipv6() {
83                Domain::IPV6
84            } else {
85                Domain::IPV4
86            },
87            if is_udp {
88                socket2::Type::DGRAM
89            } else {
90                socket2::Type::STREAM
91            },
92            Some(if is_udp {
93                socket2::Protocol::UDP
94            } else {
95                socket2::Protocol::TCP
96            }),
97        )
98        .map_err(Error::SocketNew)?;
99
100        let mut set_dualstack = false;
101
102        let addr_kind = match (opts.request_dualstack, addr) {
103            (request_dualstack, SocketAddr::V6(addr))
104                if *addr.ip() == IpAddr::V6(Ipv6Addr::UNSPECIFIED) =>
105            {
106                let value = !request_dualstack;
107                trace!(?addr, only_v6 = value, "setting only_v6");
108                socket
109                    .set_only_v6(value)
110                    .map_err(|e| Error::OnlyV6 { value, source: e })?;
111                #[cfg(not(windows))] // socket.only_v6() panics on windows somehow
112                trace!(?addr, only_v6=?socket.only_v6());
113                set_dualstack = true;
114                SocketAddrKind::V6 {
115                    addr,
116                    is_dualstack: request_dualstack,
117                }
118            }
119            (_, SocketAddr::V6(addr)) => SocketAddrKind::V6 {
120                addr,
121                is_dualstack: false,
122            },
123            (_, SocketAddr::V4(addr)) => SocketAddrKind::V4(addr),
124        };
125
126        if !set_dualstack {
127            debug!(
128                ?addr,
129                "ignored dualstack request as it only applies to [::] address"
130            );
131        }
132
133        #[cfg(not(windows))]
134        {
135            if !is_udp {
136                socket
137                    .set_reuse_address(true)
138                    .map_err(Error::ReuseAddress)?;
139            }
140        }
141
142        #[cfg(not(windows))]
143        if opts.reuseport {
144            socket.set_reuse_port(true).map_err(Error::ReusePort)?;
145        }
146
147        socket.bind(&addr.into()).map_err(|e| {
148            trace!(?addr, "error binding: {e:#}");
149            Error::Bind(e)
150        })?;
151
152        let local_addr: SocketAddr = socket
153            .local_addr()
154            .map_err(Error::LocalAddr)?
155            .as_socket()
156            .ok_or(Error::AsSocket)?;
157
158        let addr_kind = match (addr_kind, local_addr) {
159            (SocketAddrKind::V4(..), SocketAddr::V4(received)) => SocketAddrKind::V4(received),
160            (SocketAddrKind::V6 { is_dualstack, .. }, SocketAddr::V6(received)) => {
161                SocketAddrKind::V6 {
162                    addr: received,
163                    is_dualstack,
164                }
165            }
166            _ => {
167                tracing::debug!(?local_addr, bind_addr=?addr, "mismatch between local_addr() and requested bind_addr");
168                return Err(Error::LocalBindAddrMismatch);
169            }
170        };
171
172        socket
173            .set_nonblocking(true)
174            .map_err(Error::SetNonblocking)?;
175
176        Ok(Self { socket, addr_kind })
177    }
178}
179
180impl MaybeDualstackSocket<tokio::net::TcpListener> {
181    pub fn bind_tcp(addr: SocketAddr, opts: BindOpts) -> crate::Result<Self> {
182        let sock = MaybeDualstackSocket::bind(addr, opts, false)?;
183
184        debug!(addr=?sock.bind_addr(), requested_addr=?addr, dualstack = sock.is_dualstack(), "listening on TCP");
185        sock.socket().listen(1024).map_err(Error::Listen)?;
186
187        Ok(Self {
188            socket: tokio::net::TcpListener::from_std(std::net::TcpListener::from(sock.socket))
189                .map_err(Error::TokioFromStd)?,
190            addr_kind: sock.addr_kind,
191        })
192    }
193
194    pub async fn accept(&self) -> std::io::Result<(tokio::net::TcpStream, SocketAddr)> {
195        let (s, addr) = self.socket.accept().await?;
196        Ok((s, addr.try_to_ipv4()))
197    }
198}
199
#[cfg(feature = "axum")]
/// axum integration: lets a TCP `MaybeDualstackSocket` be passed to `axum::serve`,
/// with addresses exposed through `WrappedSocketAddr`.
pub mod axum {
    use std::net::SocketAddr;

    use crate::socket::MaybeDualstackSocket;

    /// Newtype around [`SocketAddr`], used as the listener address type and as
    /// axum connect-info for `MaybeDualstackSocket<tokio::net::TcpListener>`.
    ///
    /// Its `Debug` output is that of the wrapped address, without a
    /// `WrappedSocketAddr(..)` wrapper.
    #[derive(Clone, Copy)]
    pub struct WrappedSocketAddr(pub SocketAddr);

    impl std::fmt::Debug for WrappedSocketAddr {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            // Delegate rather than `write!(f, "{:?}", ..)` so formatter flags
            // (width, fill, alternate, ...) are honoured as well.
            std::fmt::Debug::fmt(&self.0, f)
        }
    }

    impl From<SocketAddr> for WrappedSocketAddr {
        fn from(value: SocketAddr) -> Self {
            Self(value)
        }
    }

    impl From<WrappedSocketAddr> for SocketAddr {
        fn from(value: WrappedSocketAddr) -> Self {
            value.0
        }
    }

    impl axum::serve::Listener for MaybeDualstackSocket<tokio::net::TcpListener> {
        type Io = tokio::net::TcpStream;

        type Addr = WrappedSocketAddr;

        /// Accepts the next connection.
        ///
        /// This trait method cannot return an error, so failed accepts are retried
        /// indefinitely with exponential backoff (capped at 5 seconds between
        /// attempts). Each failure is logged at trace level.
        async fn accept(&mut self) -> (Self::Io, Self::Addr) {
            use backon::{ExponentialBuilder, Retryable};
            let (stream, remote) = (|| MaybeDualstackSocket::accept(self))
                .retry(
                    ExponentialBuilder::new()
                        .without_max_times()
                        .with_max_delay(std::time::Duration::from_secs(5)),
                )
                .notify(|e, retry_in| tracing::trace!(?retry_in, "error accepting: {e:#}"))
                .await
                .expect("accept is retried without a maximum number of attempts, so it cannot fail");
            (stream, remote.into())
        }

        fn local_addr(&self) -> tokio::io::Result<Self::Addr> {
            Ok(self.bind_addr().into())
        }
    }

    impl
        axum::extract::connect_info::Connected<
            axum::serve::IncomingStream<'_, MaybeDualstackSocket<tokio::net::TcpListener>>,
        > for WrappedSocketAddr
    {
        /// Uses the remote address returned by `Listener::accept` as connect-info.
        fn connect_info(
            stream: axum::serve::IncomingStream<'_, MaybeDualstackSocket<tokio::net::TcpListener>>,
        ) -> Self {
            *stream.remote_addr()
        }
    }
}
260
261impl MaybeDualstackSocket<tokio::net::UdpSocket> {
262    pub fn bind_udp(addr: SocketAddr, opts: BindOpts) -> crate::Result<Self> {
263        let sock = MaybeDualstackSocket::bind(addr, opts, true)?;
264
265        debug!(addr=?sock.bind_addr(), requested_addr=?addr, dualstack = sock.is_dualstack(), "listening on UDP");
266
267        Ok(Self {
268            socket: tokio::net::UdpSocket::from_std(std::net::UdpSocket::from(sock.socket))
269                .map_err(Error::TokioFromStd)?,
270            addr_kind: sock.addr_kind,
271        })
272    }
273
274    pub async fn recv_from(&self, buf: &mut [u8]) -> std::io::Result<(usize, SocketAddr)> {
275        let (size, addr) = self.socket.recv_from(buf).await?;
276        Ok((size, addr.try_to_ipv4()))
277    }
278
279    pub async fn send_to(&self, buf: &[u8], target: SocketAddr) -> std::io::Result<usize> {
280        let target = self.convert_addr_for_send(target);
281        trace!(?target, "sending");
282        self.socket.send_to(buf, target).await
283    }
284
285    pub fn poll_send_to(
286        &self,
287        cx: &mut std::task::Context<'_>,
288        buf: &[u8],
289        target: SocketAddr,
290    ) -> Poll<std::io::Result<usize>> {
291        let target = self.convert_addr_for_send(target);
292        self.socket.poll_send_to(cx, buf, target)
293    }
294}