1use std::future::{self, Future};
4use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6};
5use std::pin::Pin;
6use std::sync::RwLock;
7use std::task::{ready, Context, Poll};
8
9use hickory_resolver::config::{NameServerConfigGroup, ResolverConfig, ResolverOpts};
10use hickory_resolver::name_server::TokioConnectionProvider;
11use hickory_resolver::TokioResolver;
12use once_cell::sync::Lazy;
13use tokio::task::JoinHandle;
14
/// Module-local `Result` whose error type defaults to [`std::io::Error`].
type Result<T, E = std::io::Error> = std::result::Result<T, E>;

/// An already-completed future carrying a `Result<T>`; used by impls that need no lookup.
type ReadyFuture<T> = std::future::Ready<Result<T>>;
17
/// Builds an `std::io::Error` of kind `InvalidInput` from any message accepted by
/// `std::io::Error::new`.
macro_rules! invalid_input {
    ($msg:expr) => {
        ::std::io::Error::new(::std::io::ErrorKind::InvalidInput, $msg)
    };
}
23
/// Unwraps an `Option`. On `None`, propagates (via `?`) an `InvalidInput` error built from `$msg`.
macro_rules! try_opt {
    ($call:expr, $msg:expr) => {
        match $call {
            ::std::option::Option::Some(inner) => inner,
            ::std::option::Option::None => {
                ::std::result::Result::Err(invalid_input!($msg))?
            }
        }
    };
}
32
/// Unwraps a `Result`. On `Err`, propagates (via `?`) an `InvalidInput` error whose message is
/// `$msg` followed by the original error's `Display` text.
macro_rules! try_ret {
    ($call:expr, $msg:expr) => {
        match $call {
            ::std::result::Result::Ok(value) => value,
            ::std::result::Result::Err(err) => {
                ::std::result::Result::Err(invalid_input!(format!("{} ,detail:{}", $msg, err)))?
            }
        }
    };
}
41
42pub trait ToSocketAddrs {
49 type Iter: Iterator<Item = SocketAddr> + Send + 'static;
51 type Future: Future<Output = Result<Self::Iter>> + Send + 'static;
53
54 fn to_socket_addrs(&self) -> Self::Future;
56}
57
58impl ToSocketAddrs for SocketAddr {
59 type Future = ReadyFuture<Self::Iter>;
60 type Iter = std::option::IntoIter<SocketAddr>;
61
62 fn to_socket_addrs(&self) -> Self::Future {
63 let iter = Some(*self).into_iter();
64 future::ready(Ok(iter))
65 }
66}
67
68impl ToSocketAddrs for SocketAddrV4 {
69 type Future = ReadyFuture<Self::Iter>;
70 type Iter = std::option::IntoIter<SocketAddr>;
71
72 fn to_socket_addrs(&self) -> Self::Future {
73 SocketAddr::V4(*self).to_socket_addrs()
74 }
75}
76
77impl ToSocketAddrs for SocketAddrV6 {
78 type Future = ReadyFuture<Self::Iter>;
79 type Iter = std::option::IntoIter<SocketAddr>;
80
81 fn to_socket_addrs(&self) -> Self::Future {
82 SocketAddr::V6(*self).to_socket_addrs()
83 }
84}
85
86impl ToSocketAddrs for (IpAddr, u16) {
87 type Future = ReadyFuture<Self::Iter>;
88 type Iter = std::option::IntoIter<SocketAddr>;
89
90 fn to_socket_addrs(&self) -> Self::Future {
91 let iter = Some(SocketAddr::from(*self)).into_iter();
92 future::ready(Ok(iter))
93 }
94}
95
96impl ToSocketAddrs for (Ipv4Addr, u16) {
97 type Future = ReadyFuture<Self::Iter>;
98 type Iter = std::option::IntoIter<SocketAddr>;
99
100 fn to_socket_addrs(&self) -> Self::Future {
101 let (ip, port) = *self;
102 SocketAddrV4::new(ip, port).to_socket_addrs()
103 }
104}
105
106impl ToSocketAddrs for (Ipv6Addr, u16) {
107 type Future = ReadyFuture<Self::Iter>;
108 type Iter = std::option::IntoIter<SocketAddr>;
109
110 fn to_socket_addrs(&self) -> Self::Future {
111 let (ip, port) = *self;
112 SocketAddrV6::new(ip, port, 0, 0).to_socket_addrs()
113 }
114}
115
116impl ToSocketAddrs for &[SocketAddr] {
117 type Future = ReadyFuture<Self::Iter>;
118 type Iter = std::vec::IntoIter<SocketAddr>;
119
120 fn to_socket_addrs(&self) -> Self::Future {
121 #[inline]
122 fn slice_to_vec(addrs: &[SocketAddr]) -> Vec<SocketAddr> {
123 addrs.to_vec()
124 }
125
126 let iter = slice_to_vec(self).into_iter();
132 future::ready(Ok(iter))
133 }
134}
135
/// Iterator returned by the string-based `ToSocketAddrs` impls.
///
/// A socket-address literal produces `One`; a lookup result produces `More`.
#[derive(Debug)]
pub enum OneOrMore {
    /// At most one address, produced without any lookup.
    One(std::option::IntoIter<SocketAddr>),
    /// All addresses returned by a lookup.
    More(std::vec::IntoIter<SocketAddr>),
}
144
145#[derive(Debug)]
146enum State {
147 Ready(Option<SocketAddr>),
148 Blocking(JoinHandle<Result<std::vec::IntoIter<SocketAddr>>>),
149}
150
151#[derive(Debug)]
153pub struct MaybeReady(State);
154
155impl Future for MaybeReady {
156 type Output = Result<OneOrMore>;
157
158 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
159 match self.0 {
160 State::Ready(ref mut i) => {
161 let iter = OneOrMore::One(i.take().into_iter());
162 Poll::Ready(Ok(iter))
163 }
164 State::Blocking(ref mut rx) => {
165 let res = ready!(Pin::new(rx).poll(cx))?.map(OneOrMore::More);
166
167 Poll::Ready(res)
168 }
169 }
170 }
171}
172
173impl Iterator for OneOrMore {
174 type Item = SocketAddr;
175
176 fn next(&mut self) -> Option<Self::Item> {
177 match self {
178 OneOrMore::One(i) => i.next(),
179 OneOrMore::More(i) => i.next(),
180 }
181 }
182
183 fn size_hint(&self) -> (usize, Option<usize>) {
184 match self {
185 OneOrMore::One(i) => i.size_hint(),
186 OneOrMore::More(i) => i.size_hint(),
187 }
188 }
189}
190
191impl ToSocketAddrs for str {
194 type Future = MaybeReady;
195 type Iter = OneOrMore;
196
197 fn to_socket_addrs(&self) -> Self::Future {
198 let res: Result<SocketAddr, _> = self.parse();
200 if let Ok(addr) = res {
201 return MaybeReady(State::Ready(Some(addr)));
202 }
203
204 let s = self.to_owned();
206
207 MaybeReady(State::Blocking(tokio::task::spawn_blocking(move || {
208 get_socket_addrs_inner(&s).map(|v| v.into_iter())
211 })))
212 }
213}
214
215impl<T> ToSocketAddrs for &T
218where
219 T: ToSocketAddrs + ?Sized,
220{
221 type Future = T::Future;
222 type Iter = T::Iter;
223
224 fn to_socket_addrs(&self) -> Self::Future {
225 (**self).to_socket_addrs()
226 }
227}
228
229impl ToSocketAddrs for (&str, u16) {
232 type Future = MaybeReady;
233 type Iter = OneOrMore;
234
235 fn to_socket_addrs(&self) -> Self::Future {
236 let (host, port) = *self;
237
238 if let Ok(addr) = host.parse::<Ipv4Addr>() {
240 let addr = SocketAddrV4::new(addr, port);
241 let addr = SocketAddr::V4(addr);
242
243 return MaybeReady(State::Ready(Some(addr)));
244 }
245
246 if let Ok(addr) = host.parse::<Ipv6Addr>() {
247 let addr = SocketAddrV6::new(addr, port, 0, 0);
248 let addr = SocketAddr::V6(addr);
249
250 return MaybeReady(State::Ready(Some(addr)));
251 }
252
253 let host = host.to_owned();
254
255 MaybeReady(State::Blocking(tokio::task::spawn_blocking(move || {
256 get_socket_addrs_from_host_port_inner(&host, port).map(|v| v.into_iter())
257 })))
258 }
259}
260
261impl ToSocketAddrs for (String, u16) {
264 type Future = MaybeReady;
265 type Iter = OneOrMore;
266
267 fn to_socket_addrs(&self) -> Self::Future {
268 (self.0.as_str(), self.1).to_socket_addrs()
269 }
270}
271
272impl ToSocketAddrs for String {
275 type Future = <str as ToSocketAddrs>::Future;
276 type Iter = <str as ToSocketAddrs>::Iter;
277
278 fn to_socket_addrs(&self) -> Self::Future {
279 self[..].to_socket_addrs()
280 }
281}
282
/// Name servers used until `set_custom_dns_server` replaces them.
const DEFAULT_DNS_SERVER_GROUP: &[IpAddr] = &[
    // AliDNS
    IpAddr::V4(Ipv4Addr::new(223, 5, 5, 5)),
    IpAddr::V4(Ipv4Addr::new(223, 6, 6, 6)),
    // DNSPod
    IpAddr::V4(Ipv4Addr::new(119, 29, 29, 29)),
    // Google Public DNS (IPv4, then IPv6 2001:4860:4860::8888)
    IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8)),
    IpAddr::V6(Ipv6Addr::new(0x2001, 0x4860, 0x4860, 0, 0, 0, 0, 0x8888)),
];
291
292static DNS_SERVER_GROUP: Lazy<RwLock<Vec<IpAddr>>> =
294 Lazy::new(|| RwLock::new(DEFAULT_DNS_SERVER_GROUP.to_vec()));
295
/// Port the configured name servers are queried on.
const DNS_QUERY_PORT: u16 = 53;
297
298#[inline]
299fn get_custom_resolver() -> Result<TokioResolver> {
300 let dns_group = try_ret!(DNS_SERVER_GROUP.read(), "read dns server");
301 let config = ResolverConfig::from_parts(
302 None,
303 vec![],
304 NameServerConfigGroup::from_ips_clear(&dns_group, DNS_QUERY_PORT, true),
305 );
306 let mut builder =
307 TokioResolver::builder_with_config(config, TokioConnectionProvider::default());
308 *builder.options_mut() = ResolverOpts::default();
309 Ok(builder.build())
310}
311
312#[inline]
315pub fn set_custom_dns_server(dns_addrs: &[IpAddr]) -> Result<()> {
316 let mut writer = DNS_SERVER_GROUP
317 .write()
318 .map_err(|e| invalid_input!(format!("get dns server writer, detail:{e}")))?;
319 let servers: &mut Vec<IpAddr> = writer.as_mut();
320 servers.clear();
321 dns_addrs.iter().for_each(|&a| servers.push(a));
322 Ok(())
323}
324
325pub async fn get_ip_addrs(s: &str) -> Result<Vec<IpAddr>> {
328 let s = s.to_owned();
329 tokio::task::spawn_blocking(move || get_ip_addrs_inner(&s))
330 .await
331 .map_err(|_| invalid_input!("get ip addrs"))?
332}
333
334fn get_ip_addrs_inner(s: &str) -> Result<Vec<IpAddr>> {
337 thread_local! {
338 static RESOLVER:Option<TokioResolver> = {
339 match get_custom_resolver(){
340 Ok(v) => Some(v),
341 Err(e) => {
342 tracing::error!("create resolver error:{e}");
343 None
344 },
345 }
346 };
347 }
348 let resolver = RESOLVER.with(|r| r.clone());
349 let resolver = try_opt!(resolver, "custom resolver not exist");
350 let handle = tokio::runtime::Handle::try_current()
351 .map_err(|_| invalid_input!("tokio runtime not found"))?;
352 let lookup = handle
353 .block_on(resolver.lookup_ip(s))
354 .map_err(|e| invalid_input!(e))?;
355 Ok(lookup.into_iter().collect())
356}
357
358#[inline]
360pub async fn get_socket_addrs_from_host_port(s: &str, port: u16) -> Result<Vec<SocketAddr>> {
361 let s = s.to_owned();
362 tokio::task::spawn_blocking(move || get_socket_addrs_from_host_port_inner(&s, port))
363 .await
364 .map_err(|_| invalid_input!("get socket addrs from host port"))?
365}
366
367#[inline]
370fn get_socket_addrs_from_host_port_inner(host: &str, port: u16) -> Result<Vec<SocketAddr>> {
371 match get_ip_addrs_inner(host) {
372 Ok(r) => Ok(r.into_iter().map(|ip| SocketAddr::new(ip, port)).collect()),
373 Err(_) => std::net::ToSocketAddrs::to_socket_addrs(&(host, port)).map(|v| v.collect()),
375 }
376}
377
378#[inline]
380pub async fn get_socket_addrs(s: &str) -> Result<Vec<SocketAddr>> {
381 let s = s.to_owned();
382 tokio::task::spawn_blocking(move || get_socket_addrs_inner(&s))
383 .await
384 .map_err(|_| invalid_input!("get socket addrs"))?
385}
386
387#[inline]
390fn get_socket_addrs_inner(s: &str) -> Result<Vec<SocketAddr>> {
391 let (host, port_str) = try_opt!(s.rsplit_once(':'), "invalid socket address");
392 let port: u16 = try_opt!(port_str.parse().ok(), "invalid port value");
393 get_socket_addrs_from_host_port_inner(host, port)
394}
395
396pub async fn each_addr<A: ToSocketAddrs, F, T, R>(addr: A, f: F) -> Result<T>
398where
399 F: Fn(SocketAddr) -> R,
400 R: std::future::Future<Output = Result<T>>,
401{
402 let addrs = match addr.to_socket_addrs().await {
403 Ok(addrs) => addrs,
404 Err(e) => return Err(e),
405 };
406 let mut last_err = None;
407 for addr in addrs {
408 match f(addr).await {
409 Ok(l) => return Ok(l),
410 Err(e) => last_err = Some(e),
411 }
412 }
413 Err(last_err.unwrap_or_else(|| {
414 std::io::Error::new(
415 std::io::ErrorKind::InvalidInput,
416 "could not resolve to any addresses",
417 )
418 }))
419}