librqbit_dualstack_sockets/
socket.rs1use std::{
2 net::{IpAddr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6},
3 task::Poll,
4};
5
6use socket2::{Domain, Socket};
7use tracing::{debug, trace};
8
9use crate::{
10 Error,
11 addr::{ToV6Mapped, TryToV4},
12};
13
/// The address family a socket is associated with.
///
/// The IPv6 variant additionally carries a flag saying whether the socket is
/// treated as dual-stack.
#[derive(Clone, Copy, Debug)]
pub enum SocketAddrKind {
    /// An IPv4 socket address.
    V4(SocketAddrV4),
    /// An IPv6 socket address.
    V6 {
        /// The IPv6 socket address itself.
        addr: SocketAddrV6,
        /// `true` when the socket is flagged as dual-stack.
        is_dualstack: bool,
    },
}

impl SocketAddrKind {
    /// Returns the wrapped address as a plain [`SocketAddr`], dropping the
    /// dual-stack flag.
    fn as_socketaddr(&self) -> SocketAddr {
        match self {
            Self::V4(v4) => SocketAddr::from(*v4),
            Self::V6 { addr: v6, .. } => SocketAddr::from(*v6),
        }
    }
}
31
/// Options controlling how a socket is configured before it is bound.
#[derive(Clone, Copy, Debug)]
pub struct BindOpts {
    /// Ask for dual-stack behaviour. The bind routine in this file only acts
    /// on it for the unspecified IPv6 address (`[::]`).
    pub request_dualstack: bool,
    /// Ask for the reuse-port socket option (applied on non-Windows targets).
    pub reuseport: bool,
}

impl Default for BindOpts {
    /// Defaults: dual-stack is requested, port reuse is off.
    fn default() -> Self {
        BindOpts {
            reuseport: false,
            request_dualstack: true,
        }
    }
}
46
47pub struct MaybeDualstackSocket<S> {
48 socket: S,
49 addr_kind: SocketAddrKind,
50}
51
52impl<S> MaybeDualstackSocket<S> {
53 pub fn socket(&self) -> &S {
54 &self.socket
55 }
56
57 pub fn bind_addr(&self) -> SocketAddr {
58 self.addr_kind.as_socketaddr()
59 }
60
61 pub fn is_dualstack(&self) -> bool {
62 matches!(
63 self.addr_kind,
64 SocketAddrKind::V6 {
65 is_dualstack: true,
66 ..
67 }
68 )
69 }
70
71 fn convert_addr_for_send(&self, addr: SocketAddr) -> SocketAddr {
72 if self.is_dualstack() {
73 return SocketAddr::V6(addr.to_ipv6_mapped());
74 }
75 addr
76 }
77}
78
79impl MaybeDualstackSocket<Socket> {
80 fn bind(addr: SocketAddr, opts: BindOpts, is_udp: bool) -> crate::Result<Self> {
81 let socket = Socket::new(
82 if addr.is_ipv6() {
83 Domain::IPV6
84 } else {
85 Domain::IPV4
86 },
87 if is_udp {
88 socket2::Type::DGRAM
89 } else {
90 socket2::Type::STREAM
91 },
92 Some(if is_udp {
93 socket2::Protocol::UDP
94 } else {
95 socket2::Protocol::TCP
96 }),
97 )
98 .map_err(Error::SocketNew)?;
99
100 let mut set_dualstack = false;
101
102 let addr_kind = match (opts.request_dualstack, addr) {
103 (request_dualstack, SocketAddr::V6(addr))
104 if *addr.ip() == IpAddr::V6(Ipv6Addr::UNSPECIFIED) =>
105 {
106 let value = !request_dualstack;
107 trace!(?addr, only_v6 = value, "setting only_v6");
108 socket
109 .set_only_v6(value)
110 .map_err(|e| Error::OnlyV6 { value, source: e })?;
111 #[cfg(not(windows))] trace!(?addr, only_v6=?socket.only_v6());
113 set_dualstack = true;
114 SocketAddrKind::V6 {
115 addr,
116 is_dualstack: request_dualstack,
117 }
118 }
119 (_, SocketAddr::V6(addr)) => SocketAddrKind::V6 {
120 addr,
121 is_dualstack: false,
122 },
123 (_, SocketAddr::V4(addr)) => SocketAddrKind::V4(addr),
124 };
125
126 if !set_dualstack {
127 debug!(
128 ?addr,
129 "ignored dualstack request as it only applies to [::] address"
130 );
131 }
132
133 #[cfg(not(windows))]
134 {
135 socket
136 .set_reuse_address(true)
137 .map_err(Error::ReuseAddress)?;
138 }
139
140 #[cfg(not(windows))]
141 if opts.reuseport {
142 socket.set_reuse_port(true).map_err(Error::ReusePort)?;
143 debug!(reuse_port=?socket.reuse_port());
144 debug!(reuse_addr=?socket.reuse_address());
145 }
146
147 socket.bind(&addr.into()).map_err(|e| {
148 trace!(?addr, "error binding: {e:#}");
149 Error::Bind(e)
150 })?;
151
152 let local_addr: SocketAddr = socket
153 .local_addr()
154 .map_err(Error::LocalAddr)?
155 .as_socket()
156 .ok_or(Error::AsSocket)?;
157
158 let addr_kind = match (addr_kind, local_addr) {
159 (SocketAddrKind::V4(..), SocketAddr::V4(received)) => SocketAddrKind::V4(received),
160 (SocketAddrKind::V6 { is_dualstack, .. }, SocketAddr::V6(received)) => {
161 SocketAddrKind::V6 {
162 addr: received,
163 is_dualstack,
164 }
165 }
166 _ => {
167 tracing::debug!(?local_addr, bind_addr=?addr, "mismatch between local_addr() and requested bind_addr");
168 return Err(Error::LocalBindAddrMismatch);
169 }
170 };
171
172 socket
173 .set_nonblocking(true)
174 .map_err(Error::SetNonblocking)?;
175
176 Ok(Self { socket, addr_kind })
177 }
178}
179
180impl MaybeDualstackSocket<tokio::net::TcpListener> {
181 pub fn bind_tcp(addr: SocketAddr, opts: BindOpts) -> crate::Result<Self> {
182 let sock = MaybeDualstackSocket::bind(addr, opts, false)?;
183
184 debug!(addr=?sock.bind_addr(), requested_addr=?addr, dualstack = sock.is_dualstack(), "listening on TCP");
185 sock.socket().listen(1024).map_err(Error::Listen)?;
186
187 Ok(Self {
188 socket: tokio::net::TcpListener::from_std(std::net::TcpListener::from(sock.socket))
189 .map_err(Error::TokioFromStd)?,
190 addr_kind: sock.addr_kind,
191 })
192 }
193
194 pub async fn accept(&self) -> std::io::Result<(tokio::net::TcpStream, SocketAddr)> {
195 let (s, addr) = self.socket.accept().await?;
196 Ok((s, addr.try_to_ipv4()))
197 }
198}
199
/// Integration with `axum::serve`, compiled only with the `axum` feature.
#[cfg(feature = "axum")]
pub mod axum {
    use std::net::SocketAddr;

    use crate::socket::MaybeDualstackSocket;

    /// Newtype around [`SocketAddr`] used as the listener address type and as
    /// connect info for axum.
    #[derive(Clone, Copy)]
    pub struct WrappedSocketAddr(pub SocketAddr);

    impl core::fmt::Debug for WrappedSocketAddr {
        /// Formats exactly like the inner [`SocketAddr`]'s `{:?}` output.
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            let Self(inner) = self;
            write!(f, "{inner:?}")
        }
    }

    impl From<SocketAddr> for WrappedSocketAddr {
        fn from(value: SocketAddr) -> Self {
            WrappedSocketAddr(value)
        }
    }

    impl From<WrappedSocketAddr> for SocketAddr {
        fn from(value: WrappedSocketAddr) -> Self {
            let WrappedSocketAddr(inner) = value;
            inner
        }
    }

    impl axum::serve::Listener for MaybeDualstackSocket<tokio::net::TcpListener> {
        type Io = tokio::net::TcpStream;

        type Addr = WrappedSocketAddr;

        /// Accepts the next connection, retrying failed accepts with
        /// exponential backoff (delay capped at 5 seconds, no retry limit).
        async fn accept(&mut self) -> (Self::Io, Self::Addr) {
            use backon::{ExponentialBuilder, Retryable};

            let policy = ExponentialBuilder::new()
                .without_max_times()
                .with_max_delay(std::time::Duration::from_secs(5));

            // The path call selects the inherent `accept`, not this trait
            // method. The retry policy has no attempt limit.
            let (stream, peer) = (|| MaybeDualstackSocket::accept(self))
                .retry(policy)
                .notify(|e, retry_in| tracing::trace!(?retry_in, "error accepting: {e:#}"))
                .await
                .unwrap();

            (stream, WrappedSocketAddr::from(peer))
        }

        /// Reports the address recorded when the socket was bound.
        fn local_addr(&self) -> tokio::io::Result<Self::Addr> {
            Ok(WrappedSocketAddr(self.bind_addr()))
        }
    }

    impl
        axum::extract::connect_info::Connected<
            axum::serve::IncomingStream<'_, MaybeDualstackSocket<tokio::net::TcpListener>>,
        > for WrappedSocketAddr
    {
        /// Uses the stream's remote address as the connect info.
        fn connect_info(
            stream: axum::serve::IncomingStream<'_, MaybeDualstackSocket<tokio::net::TcpListener>>,
        ) -> Self {
            *stream.remote_addr()
        }
    }
}
260
261impl MaybeDualstackSocket<tokio::net::UdpSocket> {
262 pub fn bind_udp(addr: SocketAddr, opts: BindOpts) -> crate::Result<Self> {
263 let sock = MaybeDualstackSocket::bind(addr, opts, true)?;
264
265 debug!(addr=?sock.bind_addr(), requested_addr=?addr, dualstack = sock.is_dualstack(), "listening on UDP");
266
267 Ok(Self {
268 socket: tokio::net::UdpSocket::from_std(std::net::UdpSocket::from(sock.socket))
269 .map_err(Error::TokioFromStd)?,
270 addr_kind: sock.addr_kind,
271 })
272 }
273
274 pub async fn recv_from(&self, buf: &mut [u8]) -> std::io::Result<(usize, SocketAddr)> {
275 let (size, addr) = self.socket.recv_from(buf).await?;
276 Ok((size, addr.try_to_ipv4()))
277 }
278
279 pub async fn send_to(&self, buf: &[u8], target: SocketAddr) -> std::io::Result<usize> {
280 let target = self.convert_addr_for_send(target);
281 self.socket.send_to(buf, target).await
282 }
283
284 pub fn poll_send_to(
285 &self,
286 cx: &mut std::task::Context<'_>,
287 buf: &[u8],
288 target: SocketAddr,
289 ) -> Poll<std::io::Result<usize>> {
290 let target = self.convert_addr_for_send(target);
291 self.socket.poll_send_to(cx, buf, target)
292 }
293}