librqbit_dualstack_sockets/
socket.rs1use std::{
2 net::{IpAddr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6},
3 task::Poll,
4};
5
6use socket2::{Domain, Socket};
7use tracing::{debug, trace};
8
9use crate::{
10 Error,
11 addr::{ToV6Mapped, TryToV4},
12};
13
/// The address a socket ended up bound to, tagged with its address family.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum SocketAddrKind {
    /// Bound to an IPv4 address.
    V4(SocketAddrV4),
    /// Bound to an IPv6 address.
    V6 {
        /// The bound IPv6 socket address.
        addr: SocketAddrV6,
        /// `true` when the socket was bound to `[::]` with dual-stack requested,
        /// i.e. `IPV6_V6ONLY` was set to `false` so IPv4 peers are served too.
        is_dualstack: bool,
    },
}
22
23impl SocketAddrKind {
24 fn as_socketaddr(&self) -> SocketAddr {
25 match *self {
26 SocketAddrKind::V4(addr) => SocketAddr::V4(addr),
27 SocketAddrKind::V6 { addr, .. } => SocketAddr::V6(addr),
28 }
29 }
30}
31
32pub struct MaybeDualstackSocket<S> {
33 socket: S,
34 addr_kind: SocketAddrKind,
35}
36
37impl<S> MaybeDualstackSocket<S> {
38 pub fn socket(&self) -> &S {
39 &self.socket
40 }
41
42 pub fn bind_addr(&self) -> SocketAddr {
43 self.addr_kind.as_socketaddr()
44 }
45
46 pub fn is_dualstack(&self) -> bool {
47 matches!(
48 self.addr_kind,
49 SocketAddrKind::V6 {
50 is_dualstack: true,
51 ..
52 }
53 )
54 }
55
56 fn convert_addr_for_send(&self, addr: SocketAddr) -> SocketAddr {
57 if self.is_dualstack() {
58 return SocketAddr::V6(addr.to_ipv6_mapped());
59 }
60 addr
61 }
62}
63
64impl MaybeDualstackSocket<Socket> {
65 fn bind(addr: SocketAddr, request_dualstack: bool, is_udp: bool) -> crate::Result<Self> {
66 let socket = Socket::new(
67 if addr.is_ipv6() {
68 Domain::IPV6
69 } else {
70 Domain::IPV4
71 },
72 if is_udp {
73 socket2::Type::DGRAM
74 } else {
75 socket2::Type::STREAM
76 },
77 Some(if is_udp {
78 socket2::Protocol::UDP
79 } else {
80 socket2::Protocol::TCP
81 }),
82 )
83 .map_err(Error::SocketNew)?;
84
85 let mut set_dualstack = false;
86
87 let addr_kind = match (request_dualstack, addr) {
88 (request_dualstack, SocketAddr::V6(addr))
89 if *addr.ip() == IpAddr::V6(Ipv6Addr::UNSPECIFIED) =>
90 {
91 let value = !request_dualstack;
92 trace!(?addr, only_v6 = value, "setting only_v6");
93 socket
94 .set_only_v6(value)
95 .map_err(|e| Error::OnlyV6 { value, source: e })?;
96 #[cfg(not(windows))] trace!(?addr, only_v6=?socket.only_v6());
98 set_dualstack = true;
99 SocketAddrKind::V6 {
100 addr,
101 is_dualstack: request_dualstack,
102 }
103 }
104 (_, SocketAddr::V6(addr)) => SocketAddrKind::V6 {
105 addr,
106 is_dualstack: false,
107 },
108 (_, SocketAddr::V4(addr)) => SocketAddrKind::V4(addr),
109 };
110
111 if !set_dualstack {
112 debug!(
113 ?addr,
114 "ignored dualstack request as it only applies to [::] address"
115 );
116 }
117
118 #[cfg(not(windows))]
119 {
120 if !is_udp {
121 socket
122 .set_reuse_address(true)
123 .map_err(Error::ReuseAddress)?;
124 }
125 }
126
127 socket.bind(&addr.into()).map_err(|e| {
128 trace!(?addr, "error binding: {e:#}");
129 Error::Bind(e)
130 })?;
131
132 let local_addr: SocketAddr = socket
133 .local_addr()
134 .map_err(Error::LocalAddr)?
135 .as_socket()
136 .ok_or(Error::AsSocket)?;
137
138 let addr_kind = match (addr_kind, local_addr) {
139 (SocketAddrKind::V4(..), SocketAddr::V4(received)) => SocketAddrKind::V4(received),
140 (SocketAddrKind::V6 { is_dualstack, .. }, SocketAddr::V6(received)) => {
141 SocketAddrKind::V6 {
142 addr: received,
143 is_dualstack,
144 }
145 }
146 _ => {
147 tracing::debug!(?local_addr, bind_addr=?addr, "mismatch between local_addr() and requested bind_addr");
148 return Err(Error::LocalBindAddrMismatch);
149 }
150 };
151
152 socket
153 .set_nonblocking(true)
154 .map_err(Error::SetNonblocking)?;
155
156 Ok(Self { socket, addr_kind })
157 }
158}
159
160impl MaybeDualstackSocket<tokio::net::TcpListener> {
161 pub fn bind_tcp(addr: SocketAddr, request_dualstack: bool) -> crate::Result<Self> {
162 let sock = MaybeDualstackSocket::bind(addr, request_dualstack, false)?;
163
164 debug!(addr=?sock.bind_addr(), requested_addr=?addr, dualstack = sock.is_dualstack(), "listening on TCP");
165 sock.socket().listen(1024).map_err(Error::Listen)?;
166
167 Ok(Self {
168 socket: tokio::net::TcpListener::from_std(std::net::TcpListener::from(sock.socket))
169 .map_err(Error::TokioFromStd)?,
170 addr_kind: sock.addr_kind,
171 })
172 }
173
174 pub async fn accept(&self) -> std::io::Result<(tokio::net::TcpStream, SocketAddr)> {
175 let (s, addr) = self.socket.accept().await?;
176 Ok((s, addr.try_to_ipv4()))
177 }
178}
179
#[cfg(feature = "axum")]
pub mod axum {
    //! Lets a dual-stack TCP listener be passed directly to `axum::serve`.

    use std::net::SocketAddr;

    use crate::socket::MaybeDualstackSocket;

    /// Socket address used as the axum listener's `Addr` and connect-info type.
    #[derive(Clone, Copy)]
    pub struct WrappedSocketAddr(pub SocketAddr);

    impl core::fmt::Debug for WrappedSocketAddr {
        // Transparent: print only the inner address, without the wrapper name.
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            write!(f, "{:?}", self.0)
        }
    }

    impl From<SocketAddr> for WrappedSocketAddr {
        fn from(value: SocketAddr) -> Self {
            Self(value)
        }
    }

    impl From<WrappedSocketAddr> for SocketAddr {
        fn from(value: WrappedSocketAddr) -> Self {
            value.0
        }
    }

    impl axum::serve::Listener for MaybeDualstackSocket<tokio::net::TcpListener> {
        type Io = tokio::net::TcpStream;

        type Addr = WrappedSocketAddr;

        /// Accepts the next connection, retrying forever on errors.
        ///
        /// `Listener::accept` has no way to report an error, so failed accepts
        /// are retried with exponential backoff capped at 5 seconds.
        async fn accept(&mut self) -> (Self::Io, Self::Addr) {
            use backon::{ExponentialBuilder, Retryable};
            let (stream, addr) = (|| MaybeDualstackSocket::accept(self))
                .retry(
                    ExponentialBuilder::new()
                        .without_max_times()
                        .with_max_delay(std::time::Duration::from_secs(5)),
                )
                // Logged at debug rather than trace so persistent accept failures
                // are visible without enabling the most verbose level.
                .notify(|e, retry_in| tracing::debug!(?retry_in, "error accepting: {e:#}"))
                .await
                .expect("retry policy has no max attempts, so it only completes on success");
            (stream, addr.into())
        }

        fn local_addr(&self) -> tokio::io::Result<Self::Addr> {
            Ok(self.bind_addr().into())
        }
    }

    impl
        axum::extract::connect_info::Connected<
            axum::serve::IncomingStream<'_, MaybeDualstackSocket<tokio::net::TcpListener>>,
        > for WrappedSocketAddr
    {
        /// Exposes the peer address produced by `accept` as axum connect info.
        fn connect_info(
            stream: axum::serve::IncomingStream<'_, MaybeDualstackSocket<tokio::net::TcpListener>>,
        ) -> Self {
            *stream.remote_addr()
        }
    }
}
240
241impl MaybeDualstackSocket<tokio::net::UdpSocket> {
242 pub fn bind_udp(addr: SocketAddr, request_dualstack: bool) -> crate::Result<Self> {
243 let sock = MaybeDualstackSocket::bind(addr, request_dualstack, true)?;
244
245 debug!(addr=?sock.bind_addr(), requested_addr=?addr, dualstack = sock.is_dualstack(), "listening on UDP");
246
247 Ok(Self {
248 socket: tokio::net::UdpSocket::from_std(std::net::UdpSocket::from(sock.socket))
249 .map_err(Error::TokioFromStd)?,
250 addr_kind: sock.addr_kind,
251 })
252 }
253
254 pub async fn recv_from(&self, buf: &mut [u8]) -> std::io::Result<(usize, SocketAddr)> {
255 let (size, addr) = self.socket.recv_from(buf).await?;
256 Ok((size, addr.try_to_ipv4()))
257 }
258
259 pub async fn send_to(&self, buf: &[u8], target: SocketAddr) -> std::io::Result<usize> {
260 let target = self.convert_addr_for_send(target);
261 self.socket.send_to(buf, target).await
262 }
263
264 pub fn poll_send_to(
265 &self,
266 cx: &mut std::task::Context<'_>,
267 buf: &[u8],
268 target: SocketAddr,
269 ) -> Poll<std::io::Result<usize>> {
270 let target = self.convert_addr_for_send(target);
271 self.socket.poll_send_to(cx, buf, target)
272 }
273}