1use std::future::{self, Future};
4use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6};
5use std::pin::Pin;
6use std::task::{ready, Context, Poll};
7
8use hickory_resolver::config::{NameServerConfigGroup, ResolverConfig, ResolverOpts};
9use hickory_resolver::name_server::TokioConnectionProvider;
10use hickory_resolver::TokioResolver;
11use once_cell::sync::Lazy;
12use parking_lot::RwLock;
13use tokio::task::JoinHandle;
14
/// Module-local shorthand for a `Result` whose error type defaults to `std::io::Error`.
type Result<T, E = std::io::Error> = std::result::Result<T, E>;

/// An already-completed future yielding an I/O [`Result`]; used by the
/// `ToSocketAddrs` impls that never need to block.
type ReadyFuture<T> = std::future::Ready<Result<T>>;
17
/// Builds a `std::io::Error` of kind `InvalidInput` whose payload is `$msg`.
///
/// `$msg` may be anything `std::io::Error::new` accepts (a string, or any
/// `Error + Send + Sync` value).
macro_rules! invalid_input {
    ($msg:expr) => {{
        let kind = ::std::io::ErrorKind::InvalidInput;
        ::std::io::Error::new(kind, $msg)
    }};
}
23
/// Unwraps an `Option`. On `None`, it propagates an `InvalidInput` I/O error
/// built from `$msg` out of the enclosing function.
///
/// Propagation goes through `?`, so the error is converted with `From::from`.
/// The enclosing function may therefore return any compatible error type.
macro_rules! try_opt {
    ($call:expr, $msg:expr) => {
        if let Some(value) = $call {
            value
        } else {
            Err(invalid_input!($msg))?
        }
    };
}
32
33pub trait ToSocketAddrs {
40 type Iter: Iterator<Item = SocketAddr> + Send + 'static;
42 type Future: Future<Output = Result<Self::Iter>> + Send + 'static;
44
45 fn to_socket_addrs(&self) -> Self::Future;
47}
48
49impl ToSocketAddrs for SocketAddr {
50 type Future = ReadyFuture<Self::Iter>;
51 type Iter = std::option::IntoIter<SocketAddr>;
52
53 fn to_socket_addrs(&self) -> Self::Future {
54 let iter = Some(*self).into_iter();
55 future::ready(Ok(iter))
56 }
57}
58
59impl ToSocketAddrs for SocketAddrV4 {
60 type Future = ReadyFuture<Self::Iter>;
61 type Iter = std::option::IntoIter<SocketAddr>;
62
63 fn to_socket_addrs(&self) -> Self::Future {
64 SocketAddr::V4(*self).to_socket_addrs()
65 }
66}
67
68impl ToSocketAddrs for SocketAddrV6 {
69 type Future = ReadyFuture<Self::Iter>;
70 type Iter = std::option::IntoIter<SocketAddr>;
71
72 fn to_socket_addrs(&self) -> Self::Future {
73 SocketAddr::V6(*self).to_socket_addrs()
74 }
75}
76
77impl ToSocketAddrs for (IpAddr, u16) {
78 type Future = ReadyFuture<Self::Iter>;
79 type Iter = std::option::IntoIter<SocketAddr>;
80
81 fn to_socket_addrs(&self) -> Self::Future {
82 let iter = Some(SocketAddr::from(*self)).into_iter();
83 future::ready(Ok(iter))
84 }
85}
86
87impl ToSocketAddrs for (Ipv4Addr, u16) {
88 type Future = ReadyFuture<Self::Iter>;
89 type Iter = std::option::IntoIter<SocketAddr>;
90
91 fn to_socket_addrs(&self) -> Self::Future {
92 let (ip, port) = *self;
93 SocketAddrV4::new(ip, port).to_socket_addrs()
94 }
95}
96
97impl ToSocketAddrs for (Ipv6Addr, u16) {
98 type Future = ReadyFuture<Self::Iter>;
99 type Iter = std::option::IntoIter<SocketAddr>;
100
101 fn to_socket_addrs(&self) -> Self::Future {
102 let (ip, port) = *self;
103 SocketAddrV6::new(ip, port, 0, 0).to_socket_addrs()
104 }
105}
106
107impl ToSocketAddrs for &[SocketAddr] {
108 type Future = ReadyFuture<Self::Iter>;
109 type Iter = std::vec::IntoIter<SocketAddr>;
110
111 fn to_socket_addrs(&self) -> Self::Future {
112 #[inline]
113 fn slice_to_vec(addrs: &[SocketAddr]) -> Vec<SocketAddr> {
114 addrs.to_vec()
115 }
116
117 let iter = slice_to_vec(self).into_iter();
123 future::ready(Ok(iter))
124 }
125}
126
/// Address iterator produced by the string-based `ToSocketAddrs` impls.
#[derive(Debug)]
pub enum OneOrMore {
    /// At most one address, obtained by parsing a literal (no lookup performed).
    One(::std::option::IntoIter<SocketAddr>),
    /// Addresses produced by hostname resolution.
    More(::std::vec::IntoIter<SocketAddr>),
}
135
136#[derive(Debug)]
137enum State {
138 Ready(Option<SocketAddr>),
139 Blocking(JoinHandle<Result<std::vec::IntoIter<SocketAddr>>>),
140}
141
142#[derive(Debug)]
144pub struct MaybeReady(State);
145
146impl Future for MaybeReady {
147 type Output = Result<OneOrMore>;
148
149 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
150 match self.0 {
151 State::Ready(ref mut i) => {
152 let iter = OneOrMore::One(i.take().into_iter());
153 Poll::Ready(Ok(iter))
154 }
155 State::Blocking(ref mut rx) => {
156 let res = ready!(Pin::new(rx).poll(cx))?.map(OneOrMore::More);
157
158 Poll::Ready(res)
159 }
160 }
161 }
162}
163
164impl Iterator for OneOrMore {
165 type Item = SocketAddr;
166
167 fn next(&mut self) -> Option<Self::Item> {
168 match self {
169 OneOrMore::One(i) => i.next(),
170 OneOrMore::More(i) => i.next(),
171 }
172 }
173
174 fn size_hint(&self) -> (usize, Option<usize>) {
175 match self {
176 OneOrMore::One(i) => i.size_hint(),
177 OneOrMore::More(i) => i.size_hint(),
178 }
179 }
180}
181
182impl ToSocketAddrs for str {
185 type Future = MaybeReady;
186 type Iter = OneOrMore;
187
188 fn to_socket_addrs(&self) -> Self::Future {
189 let res: Result<SocketAddr, _> = self.parse();
191 if let Ok(addr) = res {
192 return MaybeReady(State::Ready(Some(addr)));
193 }
194
195 let s = self.to_owned();
197
198 MaybeReady(State::Blocking(tokio::task::spawn_blocking(move || {
199 get_socket_addrs_inner(&s).map(|v| v.into_iter())
202 })))
203 }
204}
205
206impl<T> ToSocketAddrs for &T
209where
210 T: ToSocketAddrs + ?Sized,
211{
212 type Future = T::Future;
213 type Iter = T::Iter;
214
215 fn to_socket_addrs(&self) -> Self::Future {
216 (**self).to_socket_addrs()
217 }
218}
219
220impl ToSocketAddrs for (&str, u16) {
223 type Future = MaybeReady;
224 type Iter = OneOrMore;
225
226 fn to_socket_addrs(&self) -> Self::Future {
227 let (host, port) = *self;
228
229 if let Ok(addr) = host.parse::<Ipv4Addr>() {
231 let addr = SocketAddrV4::new(addr, port);
232 let addr = SocketAddr::V4(addr);
233
234 return MaybeReady(State::Ready(Some(addr)));
235 }
236
237 if let Ok(addr) = host.parse::<Ipv6Addr>() {
238 let addr = SocketAddrV6::new(addr, port, 0, 0);
239 let addr = SocketAddr::V6(addr);
240
241 return MaybeReady(State::Ready(Some(addr)));
242 }
243
244 let host = host.to_owned();
245
246 MaybeReady(State::Blocking(tokio::task::spawn_blocking(move || {
247 get_socket_addrs_from_host_port_inner(&host, port).map(|v| v.into_iter())
248 })))
249 }
250}
251
252impl ToSocketAddrs for (String, u16) {
255 type Future = MaybeReady;
256 type Iter = OneOrMore;
257
258 fn to_socket_addrs(&self) -> Self::Future {
259 (self.0.as_str(), self.1).to_socket_addrs()
260 }
261}
262
263impl ToSocketAddrs for String {
266 type Future = <str as ToSocketAddrs>::Future;
267 type Iter = <str as ToSocketAddrs>::Iter;
268
269 fn to_socket_addrs(&self) -> Self::Future {
270 self[..].to_socket_addrs()
271 }
272}
273
/// Initial name servers for the custom resolver, in query order.
///
/// NOTE(review): the operator names below are well-known public resolver
/// assignments from external records, not something this file establishes.
const DEFAULT_DNS_SERVER_GROUP: &[IpAddr] = &[
    // AliDNS
    IpAddr::V4(Ipv4Addr::new(223, 5, 5, 5)),
    IpAddr::V4(Ipv4Addr::new(223, 6, 6, 6)),
    // DNSPod
    IpAddr::V4(Ipv4Addr::new(119, 29, 29, 29)),
    // Google Public DNS (IPv4, then 2001:4860:4860::8888 over IPv6)
    IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8)),
    IpAddr::V6(Ipv6Addr::new(0x2001, 0x4860, 0x4860, 0, 0, 0, 0, 0x8888)),
];
282
283static DNS_SERVER_GROUP: Lazy<RwLock<Vec<IpAddr>>> =
285 Lazy::new(|| RwLock::new(DEFAULT_DNS_SERVER_GROUP.to_vec()));
286
287const DNS_QUERY_PORT: u16 = 53;
288
289#[inline]
290fn get_custom_resolver() -> Result<TokioResolver> {
291 let dns_group = DNS_SERVER_GROUP.read();
292 let config = ResolverConfig::from_parts(
293 None,
294 vec![],
295 NameServerConfigGroup::from_ips_clear(&dns_group, DNS_QUERY_PORT, true),
296 );
297 let mut builder =
298 TokioResolver::builder_with_config(config, TokioConnectionProvider::default());
299 *builder.options_mut() = ResolverOpts::default();
300 Ok(builder.build())
301}
302
303#[inline]
306pub fn set_custom_dns_server(dns_addrs: &[IpAddr]) -> Result<()> {
307 let mut writer = DNS_SERVER_GROUP.write();
308 let servers: &mut Vec<IpAddr> = writer.as_mut();
309 servers.clear();
310 dns_addrs.iter().for_each(|&a| servers.push(a));
311 Ok(())
312}
313
314pub async fn get_ip_addrs(s: &str) -> Result<Vec<IpAddr>> {
317 let s = s.to_owned();
318 tokio::task::spawn_blocking(move || get_ip_addrs_inner(&s))
319 .await
320 .map_err(|_| invalid_input!("get ip addrs"))?
321}
322
323fn get_ip_addrs_inner(s: &str) -> Result<Vec<IpAddr>> {
326 thread_local! {
327 static RESOLVER:Option<TokioResolver> = {
328 match get_custom_resolver(){
329 Ok(v) => Some(v),
330 Err(e) => {
331 tracing::error!("create resolver error:{e}");
332 None
333 },
334 }
335 };
336 }
337 let resolver = RESOLVER.with(|r| r.clone());
338 let resolver = try_opt!(resolver, "custom resolver not exist");
339 let handle = tokio::runtime::Handle::try_current()
340 .map_err(|_| invalid_input!("tokio runtime not found"))?;
341 let lookup = handle
342 .block_on(resolver.lookup_ip(s))
343 .map_err(|e| invalid_input!(e))?;
344 Ok(lookup.into_iter().collect())
345}
346
347#[inline]
349pub async fn get_socket_addrs_from_host_port(s: &str, port: u16) -> Result<Vec<SocketAddr>> {
350 let s = s.to_owned();
351 tokio::task::spawn_blocking(move || get_socket_addrs_from_host_port_inner(&s, port))
352 .await
353 .map_err(|_| invalid_input!("get socket addrs from host port"))?
354}
355
356#[inline]
359fn get_socket_addrs_from_host_port_inner(host: &str, port: u16) -> Result<Vec<SocketAddr>> {
360 match get_ip_addrs_inner(host) {
361 Ok(r) => Ok(r.into_iter().map(|ip| SocketAddr::new(ip, port)).collect()),
362 Err(_) => std::net::ToSocketAddrs::to_socket_addrs(&(host, port)).map(|v| v.collect()),
364 }
365}
366
367#[inline]
369pub async fn get_socket_addrs(s: &str) -> Result<Vec<SocketAddr>> {
370 let s = s.to_owned();
371 tokio::task::spawn_blocking(move || get_socket_addrs_inner(&s))
372 .await
373 .map_err(|_| invalid_input!("get socket addrs"))?
374}
375
376#[inline]
379fn get_socket_addrs_inner(s: &str) -> Result<Vec<SocketAddr>> {
380 let (host, port_str) = try_opt!(s.rsplit_once(':'), "invalid socket address");
381 let port: u16 = try_opt!(port_str.parse().ok(), "invalid port value");
382 get_socket_addrs_from_host_port_inner(host, port)
383}
384
385pub async fn each_addr<A: ToSocketAddrs, F, T, R>(addr: A, f: F) -> Result<T>
387where
388 F: Fn(SocketAddr) -> R,
389 R: std::future::Future<Output = Result<T>>,
390{
391 let addrs = match addr.to_socket_addrs().await {
392 Ok(addrs) => addrs,
393 Err(e) => return Err(e),
394 };
395 let mut last_err = None;
396 for addr in addrs {
397 match f(addr).await {
398 Ok(l) => return Ok(l),
399 Err(e) => last_err = Some(e),
400 }
401 }
402 Err(last_err.unwrap_or_else(|| {
403 std::io::Error::new(
404 std::io::ErrorKind::InvalidInput,
405 "could not resolve to any addresses",
406 )
407 }))
408}