librqbit_dualstack_sockets/
socket.rs

1use std::{
2    net::{IpAddr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6},
3    task::Poll,
4};
5
6use socket2::{Domain, Socket};
7use tracing::{debug, trace};
8
9use crate::{
10    Error,
11    addr::{ToV6Mapped, TryToV4},
12};
13
/// The address a [`MaybeDualstackSocket`] ended up bound to, tagged by address family.
#[derive(Clone, Copy, Debug)]
pub enum SocketAddrKind {
    /// Bound to an IPv4 socket address.
    V4(SocketAddrV4),
    /// Bound to an IPv6 socket address.
    V6 {
        /// The bound IPv6 socket address.
        addr: SocketAddrV6,
        /// Whether the socket was configured to also carry IPv4 traffic
        /// (`only_v6` disabled on a wildcard `[::]` bind).
        is_dualstack: bool,
    },
}

impl SocketAddrKind {
    /// Converts to a family-agnostic [`SocketAddr`], discarding the dual-stack flag.
    fn as_socketaddr(&self) -> SocketAddr {
        match self {
            Self::V4(v4) => SocketAddr::from(*v4),
            Self::V6 { addr: v6, .. } => SocketAddr::from(*v6),
        }
    }
}
31
32pub struct MaybeDualstackSocket<S> {
33    socket: S,
34    addr_kind: SocketAddrKind,
35}
36
37impl<S> MaybeDualstackSocket<S> {
38    pub fn socket(&self) -> &S {
39        &self.socket
40    }
41
42    pub fn bind_addr(&self) -> SocketAddr {
43        self.addr_kind.as_socketaddr()
44    }
45
46    pub fn is_dualstack(&self) -> bool {
47        matches!(
48            self.addr_kind,
49            SocketAddrKind::V6 {
50                is_dualstack: true,
51                ..
52            }
53        )
54    }
55
56    fn convert_addr_for_send(&self, addr: SocketAddr) -> SocketAddr {
57        if self.is_dualstack() {
58            return SocketAddr::V6(addr.to_ipv6_mapped());
59        }
60        addr
61    }
62}
63
64impl MaybeDualstackSocket<Socket> {
65    fn bind(addr: SocketAddr, request_dualstack: bool, is_udp: bool) -> crate::Result<Self> {
66        let socket = Socket::new(
67            if addr.is_ipv6() {
68                Domain::IPV6
69            } else {
70                Domain::IPV4
71            },
72            if is_udp {
73                socket2::Type::DGRAM
74            } else {
75                socket2::Type::STREAM
76            },
77            Some(if is_udp {
78                socket2::Protocol::UDP
79            } else {
80                socket2::Protocol::TCP
81            }),
82        )
83        .map_err(Error::SocketNew)?;
84
85        let mut set_dualstack = false;
86
87        let addr_kind = match (request_dualstack, addr) {
88            (request_dualstack, SocketAddr::V6(addr))
89                if *addr.ip() == IpAddr::V6(Ipv6Addr::UNSPECIFIED) =>
90            {
91                let value = !request_dualstack;
92                trace!(?addr, only_v6 = value, "setting only_v6");
93                socket
94                    .set_only_v6(value)
95                    .map_err(|e| Error::OnlyV6 { value, source: e })?;
96                #[cfg(not(windows))] // socket.only_v6() panics on windows somehow
97                trace!(?addr, only_v6=?socket.only_v6());
98                set_dualstack = true;
99                SocketAddrKind::V6 {
100                    addr,
101                    is_dualstack: request_dualstack,
102                }
103            }
104            (_, SocketAddr::V6(addr)) => SocketAddrKind::V6 {
105                addr,
106                is_dualstack: false,
107            },
108            (_, SocketAddr::V4(addr)) => SocketAddrKind::V4(addr),
109        };
110
111        if !set_dualstack {
112            debug!(
113                ?addr,
114                "ignored dualstack request as it only applies to [::] address"
115            );
116        }
117
118        #[cfg(not(windows))]
119        {
120            if !is_udp {
121                socket
122                    .set_reuse_address(true)
123                    .map_err(Error::ReuseAddress)?;
124            }
125        }
126
127        socket.bind(&addr.into()).map_err(|e| {
128            trace!(?addr, "error binding: {e:#}");
129            Error::Bind(e)
130        })?;
131
132        let local_addr: SocketAddr = socket
133            .local_addr()
134            .map_err(Error::LocalAddr)?
135            .as_socket()
136            .ok_or(Error::AsSocket)?;
137
138        let addr_kind = match (addr_kind, local_addr) {
139            (SocketAddrKind::V4(..), SocketAddr::V4(received)) => SocketAddrKind::V4(received),
140            (SocketAddrKind::V6 { is_dualstack, .. }, SocketAddr::V6(received)) => {
141                SocketAddrKind::V6 {
142                    addr: received,
143                    is_dualstack,
144                }
145            }
146            _ => {
147                tracing::debug!(?local_addr, bind_addr=?addr, "mismatch between local_addr() and requested bind_addr");
148                return Err(Error::LocalBindAddrMismatch);
149            }
150        };
151
152        socket
153            .set_nonblocking(true)
154            .map_err(Error::SetNonblocking)?;
155
156        Ok(Self { socket, addr_kind })
157    }
158}
159
160impl MaybeDualstackSocket<tokio::net::TcpListener> {
161    pub fn bind_tcp(addr: SocketAddr, request_dualstack: bool) -> crate::Result<Self> {
162        let sock = MaybeDualstackSocket::bind(addr, request_dualstack, false)?;
163
164        debug!(addr=?sock.bind_addr(), requested_addr=?addr, dualstack = sock.is_dualstack(), "listening on TCP");
165        sock.socket().listen(1024).map_err(Error::Listen)?;
166
167        Ok(Self {
168            socket: tokio::net::TcpListener::from_std(std::net::TcpListener::from(sock.socket))
169                .map_err(Error::TokioFromStd)?,
170            addr_kind: sock.addr_kind,
171        })
172    }
173
174    pub async fn accept(&self) -> std::io::Result<(tokio::net::TcpStream, SocketAddr)> {
175        let (s, addr) = self.socket.accept().await?;
176        Ok((s, addr.try_to_ipv4()))
177    }
178}
179
#[cfg(feature = "axum")]
pub mod axum {
    //! Lets a [`MaybeDualstackSocket`] TCP listener be passed to `axum::serve`.
    use std::net::SocketAddr;

    use crate::socket::MaybeDualstackSocket;

    /// Transparent newtype around [`SocketAddr`], used as the listener address and
    /// connect-info type for axum.
    ///
    /// NOTE(review): presumably a newtype (rather than `SocketAddr` itself) so this
    /// crate may implement axum's traits for it under the orphan rule — confirm.
    #[derive(Clone, Copy)]
    pub struct WrappedSocketAddr(pub SocketAddr);

    // Debug output is exactly the inner address, without the wrapper name.
    impl core::fmt::Debug for WrappedSocketAddr {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            write!(f, "{:?}", self.0)
        }
    }
    impl From<SocketAddr> for WrappedSocketAddr {
        fn from(value: SocketAddr) -> Self {
            Self(value)
        }
    }
    impl From<WrappedSocketAddr> for SocketAddr {
        fn from(value: WrappedSocketAddr) -> Self {
            value.0
        }
    }

    impl axum::serve::Listener for MaybeDualstackSocket<tokio::net::TcpListener> {
        type Io = tokio::net::TcpStream;

        type Addr = WrappedSocketAddr;

        /// Accepts the next connection, retrying failed accepts indefinitely with
        /// exponential backoff capped at 5 seconds. Errors are logged at trace level.
        async fn accept(&mut self) -> (Self::Io, Self::Addr) {
            use backon::{ExponentialBuilder, Retryable};
            let (stream, peer) = (|| MaybeDualstackSocket::accept(self))
                .retry(
                    ExponentialBuilder::new()
                        .without_max_times()
                        .with_max_delay(std::time::Duration::from_secs(5)),
                )
                .notify(|e, retry_in| tracing::trace!(?retry_in, "error accepting: {e:#}"))
                .await
                .expect("accept retry has no max attempt count, so it never gives up");
            (stream, peer.into())
        }

        fn local_addr(&self) -> tokio::io::Result<Self::Addr> {
            Ok(self.bind_addr().into())
        }
    }

    /// Exposes the (already normalised) peer address to axum's `ConnectInfo`.
    impl
        axum::extract::connect_info::Connected<
            axum::serve::IncomingStream<'_, MaybeDualstackSocket<tokio::net::TcpListener>>,
        > for WrappedSocketAddr
    {
        fn connect_info(
            stream: axum::serve::IncomingStream<'_, MaybeDualstackSocket<tokio::net::TcpListener>>,
        ) -> Self {
            *stream.remote_addr()
        }
    }
}
240
241impl MaybeDualstackSocket<tokio::net::UdpSocket> {
242    pub fn bind_udp(addr: SocketAddr, request_dualstack: bool) -> crate::Result<Self> {
243        let sock = MaybeDualstackSocket::bind(addr, request_dualstack, true)?;
244
245        debug!(addr=?sock.bind_addr(), requested_addr=?addr, dualstack = sock.is_dualstack(), "listening on UDP");
246
247        Ok(Self {
248            socket: tokio::net::UdpSocket::from_std(std::net::UdpSocket::from(sock.socket))
249                .map_err(Error::TokioFromStd)?,
250            addr_kind: sock.addr_kind,
251        })
252    }
253
254    pub async fn recv_from(&self, buf: &mut [u8]) -> std::io::Result<(usize, SocketAddr)> {
255        let (size, addr) = self.socket.recv_from(buf).await?;
256        Ok((size, addr.try_to_ipv4()))
257    }
258
259    pub async fn send_to(&self, buf: &[u8], target: SocketAddr) -> std::io::Result<usize> {
260        let target = self.convert_addr_for_send(target);
261        self.socket.send_to(buf, target).await
262    }
263
264    pub fn poll_send_to(
265        &self,
266        cx: &mut std::task::Context<'_>,
267        buf: &[u8],
268        target: SocketAddr,
269    ) -> Poll<std::io::Result<usize>> {
270        let target = self.convert_addr_for_send(target);
271        self.socket.poll_send_to(cx, buf, target)
272    }
273}