librqbit_dualstack_sockets/
socket.rs1use std::{
2 net::{IpAddr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6},
3 task::Poll,
4};
5
6use socket2::{Domain, Socket};
7use tracing::{debug, trace};
8
9use crate::{
10 Error,
11 addr::{ToV6Mapped, TryToV4},
12 bind_device::BindDevice,
13};
14
/// The local address a [`MaybeDualstackSocket`] is bound to, tagged with its family.
#[derive(Debug, Clone, Copy)]
pub enum SocketAddrKind {
    /// Bound to an IPv4 address.
    V4(SocketAddrV4),
    /// Bound to an IPv6 address.
    V6 {
        /// The bound IPv6 socket address.
        addr: SocketAddrV6,
        /// `true` when the socket was bound to `[::]` with a dual-stack request,
        /// i.e. `IPV6_V6ONLY` was switched off so IPv4 peers are accepted too.
        is_dualstack: bool,
    },
}

impl SocketAddrKind {
    /// Returns `true` for the IPv6 variant, regardless of the dual-stack flag.
    fn is_v6(&self) -> bool {
        match self {
            SocketAddrKind::V4(_) => false,
            SocketAddrKind::V6 { .. } => true,
        }
    }

    /// Converts into a plain [`SocketAddr`], discarding the dual-stack flag.
    fn as_socketaddr(&self) -> SocketAddr {
        match self {
            SocketAddrKind::V4(v4) => SocketAddr::from(*v4),
            SocketAddrKind::V6 { addr: v6, .. } => SocketAddr::from(*v6),
        }
    }
}
36
37#[derive(Clone, Copy, Debug)]
38pub struct BindOpts<'a> {
39 pub request_dualstack: bool,
40 pub reuseport: bool,
41 pub device: Option<&'a BindDevice>,
42}
43
44impl Default for BindOpts<'_> {
45 fn default() -> Self {
46 Self {
47 request_dualstack: true,
48 reuseport: false,
49 device: None,
50 }
51 }
52}
53
54pub struct MaybeDualstackSocket<S> {
55 socket: S,
56 addr_kind: SocketAddrKind,
57}
58
59impl<S> MaybeDualstackSocket<S> {
60 pub fn socket(&self) -> &S {
61 &self.socket
62 }
63
64 pub fn bind_addr(&self) -> SocketAddr {
65 self.addr_kind.as_socketaddr()
66 }
67
68 pub fn is_dualstack(&self) -> bool {
69 matches!(
70 self.addr_kind,
71 SocketAddrKind::V6 {
72 is_dualstack: true,
73 ..
74 }
75 )
76 }
77
78 pub(crate) fn convert_addr_for_send(&self, addr: SocketAddr) -> SocketAddr {
79 if self.is_dualstack() {
80 return SocketAddr::V6(addr.to_ipv6_mapped());
81 }
82 addr
83 }
84}
85
86impl MaybeDualstackSocket<Socket> {
87 fn bind(addr: SocketAddr, opts: BindOpts, is_udp: bool) -> crate::Result<Self> {
88 let socket = Socket::new(
89 if addr.is_ipv6() {
90 Domain::IPV6
91 } else {
92 Domain::IPV4
93 },
94 if is_udp {
95 socket2::Type::DGRAM
96 } else {
97 socket2::Type::STREAM
98 },
99 Some(if is_udp {
100 socket2::Protocol::UDP
101 } else {
102 socket2::Protocol::TCP
103 }),
104 )
105 .map_err(Error::SocketNew)?;
106
107 let mut set_dualstack = false;
108
109 let addr_kind = match (opts.request_dualstack, addr) {
110 (request_dualstack, SocketAddr::V6(addr))
111 if *addr.ip() == IpAddr::V6(Ipv6Addr::UNSPECIFIED) =>
112 {
113 let value = !request_dualstack;
114 trace!(?addr, only_v6 = value, "setting only_v6");
115 socket
116 .set_only_v6(value)
117 .map_err(|e| Error::OnlyV6 { value, source: e })?;
118 #[cfg(not(windows))] trace!(?addr, only_v6=?socket.only_v6());
120 set_dualstack = true;
121 SocketAddrKind::V6 {
122 addr,
123 is_dualstack: request_dualstack,
124 }
125 }
126 (_, SocketAddr::V6(addr)) => SocketAddrKind::V6 {
127 addr,
128 is_dualstack: false,
129 },
130 (_, SocketAddr::V4(addr)) => SocketAddrKind::V4(addr),
131 };
132
133 if !set_dualstack {
134 debug!(
135 ?addr,
136 "ignored dualstack request as it only applies to [::] address"
137 );
138 }
139
140 #[cfg(not(windows))]
141 {
142 socket
143 .set_reuse_address(true)
144 .map_err(Error::ReuseAddress)?;
145 }
146
147 #[cfg(windows)]
148 if opts.reuseport || !is_udp {
149 socket
150 .set_reuse_address(true)
151 .map_err(Error::ReuseAddress)?;
152 }
153
154 #[cfg(not(windows))]
155 if opts.reuseport {
156 socket.set_reuse_port(true).map_err(Error::ReusePort)?;
157 debug!(reuse_port=?socket.reuse_port());
158 debug!(reuse_addr=?socket.reuse_address());
159 }
160
161 if let Some(bd) = opts.device {
162 bd.bind_sref(&socket, addr_kind.is_v6())?;
163 }
164
165 socket.bind(&addr.into()).map_err(|e| {
166 trace!(?addr, "error binding: {e:#}");
167 Error::Bind(e)
168 })?;
169
170 let local_addr: SocketAddr = socket
171 .local_addr()
172 .map_err(Error::LocalAddr)?
173 .as_socket()
174 .ok_or(Error::AsSocket)?;
175
176 let addr_kind = match (addr_kind, local_addr) {
177 (SocketAddrKind::V4(..), SocketAddr::V4(received)) => SocketAddrKind::V4(received),
178 (SocketAddrKind::V6 { is_dualstack, .. }, SocketAddr::V6(received)) => {
179 SocketAddrKind::V6 {
180 addr: received,
181 is_dualstack,
182 }
183 }
184 _ => {
185 tracing::debug!(?local_addr, bind_addr=?addr, "mismatch between local_addr() and requested bind_addr");
186 return Err(Error::LocalBindAddrMismatch);
187 }
188 };
189
190 socket
191 .set_nonblocking(true)
192 .map_err(Error::SetNonblocking)?;
193
194 Ok(Self { socket, addr_kind })
195 }
196}
197
198impl MaybeDualstackSocket<tokio::net::TcpListener> {
199 pub fn bind_tcp(addr: SocketAddr, opts: BindOpts) -> crate::Result<Self> {
200 let sock = MaybeDualstackSocket::bind(addr, opts, false)?;
201
202 debug!(addr=?sock.bind_addr(), requested_addr=?addr, dualstack = sock.is_dualstack(), "listening on TCP");
203 sock.socket().listen(1024).map_err(Error::Listen)?;
204
205 Ok(Self {
206 socket: tokio::net::TcpListener::from_std(std::net::TcpListener::from(sock.socket))
207 .map_err(Error::TokioFromStd)?,
208 addr_kind: sock.addr_kind,
209 })
210 }
211
212 pub async fn accept(&self) -> std::io::Result<(tokio::net::TcpStream, SocketAddr)> {
213 let (s, addr) = self.socket.accept().await?;
214 Ok((s, addr.try_to_ipv4()))
215 }
216}
217
/// Integration with `axum::serve`: lets a TCP [`MaybeDualstackSocket`] act as an
/// axum listener, and implements axum's `Connected` trait so the peer address can
/// be used as connect info.
#[cfg(feature = "axum")]
pub mod axum {
    use std::net::SocketAddr;

    use crate::socket::MaybeDualstackSocket;

    /// Newtype around [`SocketAddr`]; it is the listener's `Addr` type below.
    ///
    /// NOTE(review): presumably a local newtype is required so this crate may
    /// implement axum's `Connected` trait for it (orphan rule) — confirm.
    #[derive(Clone, Copy)]
    pub struct WrappedSocketAddr(pub SocketAddr);
    // Debug prints only the inner address, without the `WrappedSocketAddr(...)`
    // wrapper that a derived impl would add.
    impl core::fmt::Debug for WrappedSocketAddr {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            write!(f, "{:?}", self.0)
        }
    }
    impl From<SocketAddr> for WrappedSocketAddr {
        fn from(value: SocketAddr) -> Self {
            Self(value)
        }
    }
    impl From<WrappedSocketAddr> for SocketAddr {
        fn from(value: WrappedSocketAddr) -> Self {
            value.0
        }
    }

    impl axum::serve::Listener for MaybeDualstackSocket<tokio::net::TcpListener> {
        type Io = tokio::net::TcpStream;

        type Addr = WrappedSocketAddr;

        /// Accepts the next connection, retrying failed `accept()` calls with
        /// exponential backoff (each delay capped at 5 seconds) until one succeeds.
        ///
        /// This trait method's return type has no error case, so failures are only
        /// logged (at `trace` level) and retried.
        async fn accept(&mut self) -> (Self::Io, Self::Addr) {
            use backon::{ExponentialBuilder, Retryable};
            // Calls the inherent `MaybeDualstackSocket::accept` (not this trait method),
            // which already passes the peer address through `try_to_ipv4`.
            let (l, a) = (|| MaybeDualstackSocket::accept(self))
                .retry(
                    ExponentialBuilder::new()
                        .without_max_times()
                        .with_max_delay(std::time::Duration::from_secs(5)),
                )
                .notify(|e, retry_in| tracing::trace!(?retry_in, "error accepting: {e:#}"))
                .await
                // NOTE(review): with no retry limit the retry future presumably only
                // resolves on success, making this unwrap unreachable — confirm
                // against the backon docs.
                .unwrap();
            (l, a.into())
        }

        /// Reports the address recorded at bind time; never returns an error.
        fn local_addr(&self) -> tokio::io::Result<Self::Addr> {
            Ok(self.bind_addr().into())
        }
    }

    // Lets the listener's `Addr` (already a `WrappedSocketAddr`) serve as connect info.
    impl
        axum::extract::connect_info::Connected<
            axum::serve::IncomingStream<'_, MaybeDualstackSocket<tokio::net::TcpListener>>,
        > for WrappedSocketAddr
    {
        /// Copies the incoming stream's remote address.
        fn connect_info(
            stream: axum::serve::IncomingStream<'_, MaybeDualstackSocket<tokio::net::TcpListener>>,
        ) -> Self {
            *stream.remote_addr()
        }
    }
}
278
279impl MaybeDualstackSocket<tokio::net::UdpSocket> {
280 pub fn bind_udp(addr: SocketAddr, opts: BindOpts) -> crate::Result<Self> {
281 let sock = MaybeDualstackSocket::bind(addr, opts, true)?;
282
283 debug!(addr=?sock.bind_addr(), requested_addr=?addr, dualstack = sock.is_dualstack(), "listening on UDP");
284
285 Ok(Self {
286 socket: tokio::net::UdpSocket::from_std(std::net::UdpSocket::from(sock.socket))
287 .map_err(Error::TokioFromStd)?,
288 addr_kind: sock.addr_kind,
289 })
290 }
291
292 pub async fn recv_from(&self, buf: &mut [u8]) -> std::io::Result<(usize, SocketAddr)> {
293 let (size, addr) = self.socket.recv_from(buf).await?;
294 Ok((size, addr.try_to_ipv4()))
295 }
296
297 pub async fn send_to(&self, buf: &[u8], target: SocketAddr) -> std::io::Result<usize> {
298 let target = self.convert_addr_for_send(target);
299 self.socket.send_to(buf, target).await
300 }
301
302 pub fn poll_send_to(
303 &self,
304 cx: &mut std::task::Context<'_>,
305 buf: &[u8],
306 target: SocketAddr,
307 ) -> Poll<std::io::Result<usize>> {
308 let target = self.convert_addr_for_send(target);
309 self.socket.poll_send_to(cx, buf, target)
310 }
311}