Skip to main content

uni_stream/
addr.rs

1//! Provide domain name resolution service
2
3use std::future::{self, Future};
4use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6};
5use std::pin::Pin;
6use std::task::{ready, Context, Poll};
7
8use hickory_resolver::config::{NameServerConfigGroup, ResolverConfig, ResolverOpts};
9use hickory_resolver::name_server::TokioConnectionProvider;
10use hickory_resolver::TokioResolver;
11use once_cell::sync::Lazy;
12use parking_lot::RwLock;
13use tokio::task::JoinHandle;
14
/// Module-wide result type; the error defaults to [`std::io::Error`].
type Result<T, E = std::io::Error> = std::result::Result<T, E>;
/// An already-completed future that carries a `Result<T>`.
type ReadyFuture<T> = future::Ready<Result<T>>;

/// Builds an `std::io::Error` of kind `InvalidInput` carrying `$msg`.
macro_rules! invalid_input {
    ($msg:expr) => {
        ::std::io::Error::new(::std::io::ErrorKind::InvalidInput, $msg)
    };
}

/// Unwraps the `Option` produced by `$call`, or returns early from the enclosing
/// function with an `InvalidInput` error built from `$msg` (converted with `From`, as `?`
/// does). `$msg` is only evaluated on the `None` path.
macro_rules! try_opt {
    ($call:expr, $msg:expr) => {
        ::std::option::Option::ok_or_else($call, || invalid_input!($msg))?
    };
}
32
/// Converts or resolves, without blocking the async executor, to one or more
/// [`SocketAddr`] values.
///
/// Literal inputs (socket addresses, `(ip, port)` tuples, address slices) are returned
/// immediately. String inputs that are not literal addresses are looked up on tokio's
/// blocking pool.
///
/// # DNS
///
/// String-based implementations resolve through a custom DNS resolver first and fall back
/// to the standard library's resolver if that fails. The DNS servers used by the custom
/// resolver can be changed via [`set_custom_dns_server`].
pub trait ToSocketAddrs {
    /// Iterator over the resulting `SocketAddr`s.
    type Iter: Iterator<Item = SocketAddr> + Send + 'static;
    /// Future that completes with [`Self::Iter`] or an I/O error.
    type Future: Future<Output = Result<Self::Iter>> + Send + 'static;

    /// Starts the conversion and returns a future yielding the addresses.
    fn to_socket_addrs(&self) -> Self::Future;
}
48
49impl ToSocketAddrs for SocketAddr {
50    type Future = ReadyFuture<Self::Iter>;
51    type Iter = std::option::IntoIter<SocketAddr>;
52
53    fn to_socket_addrs(&self) -> Self::Future {
54        let iter = Some(*self).into_iter();
55        future::ready(Ok(iter))
56    }
57}
58
59impl ToSocketAddrs for SocketAddrV4 {
60    type Future = ReadyFuture<Self::Iter>;
61    type Iter = std::option::IntoIter<SocketAddr>;
62
63    fn to_socket_addrs(&self) -> Self::Future {
64        SocketAddr::V4(*self).to_socket_addrs()
65    }
66}
67
68impl ToSocketAddrs for SocketAddrV6 {
69    type Future = ReadyFuture<Self::Iter>;
70    type Iter = std::option::IntoIter<SocketAddr>;
71
72    fn to_socket_addrs(&self) -> Self::Future {
73        SocketAddr::V6(*self).to_socket_addrs()
74    }
75}
76
77impl ToSocketAddrs for (IpAddr, u16) {
78    type Future = ReadyFuture<Self::Iter>;
79    type Iter = std::option::IntoIter<SocketAddr>;
80
81    fn to_socket_addrs(&self) -> Self::Future {
82        let iter = Some(SocketAddr::from(*self)).into_iter();
83        future::ready(Ok(iter))
84    }
85}
86
87impl ToSocketAddrs for (Ipv4Addr, u16) {
88    type Future = ReadyFuture<Self::Iter>;
89    type Iter = std::option::IntoIter<SocketAddr>;
90
91    fn to_socket_addrs(&self) -> Self::Future {
92        let (ip, port) = *self;
93        SocketAddrV4::new(ip, port).to_socket_addrs()
94    }
95}
96
97impl ToSocketAddrs for (Ipv6Addr, u16) {
98    type Future = ReadyFuture<Self::Iter>;
99    type Iter = std::option::IntoIter<SocketAddr>;
100
101    fn to_socket_addrs(&self) -> Self::Future {
102        let (ip, port) = *self;
103        SocketAddrV6::new(ip, port, 0, 0).to_socket_addrs()
104    }
105}
106
107impl ToSocketAddrs for &[SocketAddr] {
108    type Future = ReadyFuture<Self::Iter>;
109    type Iter = std::vec::IntoIter<SocketAddr>;
110
111    fn to_socket_addrs(&self) -> Self::Future {
112        #[inline]
113        fn slice_to_vec(addrs: &[SocketAddr]) -> Vec<SocketAddr> {
114            addrs.to_vec()
115        }
116
117        // This uses a helper method because clippy doesn't like the `to_vec()`
118        // call here (it will allocate, whereas `self.iter().copied()` would
119        // not), but it's actually necessary in order to ensure that the
120        // returned iterator is valid for the `'static` lifetime, which the
121        // borrowed `slice::Iter` iterator would not be.
122        let iter = slice_to_vec(self).into_iter();
123        future::ready(Ok(iter))
124    }
125}
126
/// Iterator over the `SocketAddr`s produced by resolving a string-like address.
///
/// A string may be either a literal address or a domain name, so resolution yields
/// either a single address or the (possibly several) results of a DNS lookup.
#[derive(Debug)]
pub enum OneOrMore {
    /// A literal address that needed no lookup.
    One(std::option::IntoIter<SocketAddr>),
    /// Addresses returned by a DNS lookup.
    More(std::vec::IntoIter<SocketAddr>),
}
135
/// Internal progress of a [`MaybeReady`] future.
#[derive(Debug)]
enum State {
    /// The address was already known (a literal); it is taken out on the first poll.
    Ready(Option<SocketAddr>),
    /// A lookup spawned with `tokio::task::spawn_blocking`; its successful result is
    /// surfaced as `OneOrMore::More`.
    Blocking(JoinHandle<Result<std::vec::IntoIter<SocketAddr>>>),
}

/// Future returned by the string-based [`ToSocketAddrs`] impls.
///
/// It completes immediately for literal addresses and otherwise waits for the
/// background DNS lookup task to finish.
#[derive(Debug)]
pub struct MaybeReady(State);
145
146impl Future for MaybeReady {
147    type Output = Result<OneOrMore>;
148
149    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
150        match self.0 {
151            State::Ready(ref mut i) => {
152                let iter = OneOrMore::One(i.take().into_iter());
153                Poll::Ready(Ok(iter))
154            }
155            State::Blocking(ref mut rx) => {
156                let res = ready!(Pin::new(rx).poll(cx))?.map(OneOrMore::More);
157
158                Poll::Ready(res)
159            }
160        }
161    }
162}
163
164impl Iterator for OneOrMore {
165    type Item = SocketAddr;
166
167    fn next(&mut self) -> Option<Self::Item> {
168        match self {
169            OneOrMore::One(i) => i.next(),
170            OneOrMore::More(i) => i.next(),
171        }
172    }
173
174    fn size_hint(&self) -> (usize, Option<usize>) {
175        match self {
176            OneOrMore::One(i) => i.size_hint(),
177            OneOrMore::More(i) => i.size_hint(),
178        }
179    }
180}
181
182// ===== impl &str =====
183
184impl ToSocketAddrs for str {
185    type Future = MaybeReady;
186    type Iter = OneOrMore;
187
188    fn to_socket_addrs(&self) -> Self::Future {
189        // First check if the input parses as a socket address
190        let res: Result<SocketAddr, _> = self.parse();
191        if let Ok(addr) = res {
192            return MaybeReady(State::Ready(Some(addr)));
193        }
194
195        // Run DNS lookup on the blocking pool
196        let s = self.to_owned();
197
198        MaybeReady(State::Blocking(tokio::task::spawn_blocking(move || {
199            // Customized dns resolvers are preferred, if a custom resolver does not exist then the
200            // standard library's
201            get_socket_addrs_inner(&s).map(|v| v.into_iter())
202        })))
203    }
204}
205
206/// Implement this trait for &T of type !Sized(such as str), since &T of type Sized all implement it
207/// by default.
208impl<T> ToSocketAddrs for &T
209where
210    T: ToSocketAddrs + ?Sized,
211{
212    type Future = T::Future;
213    type Iter = T::Iter;
214
215    fn to_socket_addrs(&self) -> Self::Future {
216        (**self).to_socket_addrs()
217    }
218}
219
220// ===== impl (&str,u16) =====
221
222impl ToSocketAddrs for (&str, u16) {
223    type Future = MaybeReady;
224    type Iter = OneOrMore;
225
226    fn to_socket_addrs(&self) -> Self::Future {
227        let (host, port) = *self;
228
229        // try to parse the host as a regular IP address first
230        if let Ok(addr) = host.parse::<Ipv4Addr>() {
231            let addr = SocketAddrV4::new(addr, port);
232            let addr = SocketAddr::V4(addr);
233
234            return MaybeReady(State::Ready(Some(addr)));
235        }
236
237        if let Ok(addr) = host.parse::<Ipv6Addr>() {
238            let addr = SocketAddrV6::new(addr, port, 0, 0);
239            let addr = SocketAddr::V6(addr);
240
241            return MaybeReady(State::Ready(Some(addr)));
242        }
243
244        let host = host.to_owned();
245
246        MaybeReady(State::Blocking(tokio::task::spawn_blocking(move || {
247            get_socket_addrs_from_host_port_inner(&host, port).map(|v| v.into_iter())
248        })))
249    }
250}
251
252// ===== impl (String,u16) =====
253
254impl ToSocketAddrs for (String, u16) {
255    type Future = MaybeReady;
256    type Iter = OneOrMore;
257
258    fn to_socket_addrs(&self) -> Self::Future {
259        (self.0.as_str(), self.1).to_socket_addrs()
260    }
261}
262
263// ===== impl String =====
264
265impl ToSocketAddrs for String {
266    type Future = <str as ToSocketAddrs>::Future;
267    type Iter = <str as ToSocketAddrs>::Iter;
268
269    fn to_socket_addrs(&self) -> Self::Future {
270        self[..].to_socket_addrs()
271    }
272}
273
/// Public DNS servers used by the custom resolver until [`set_custom_dns_server`]
/// replaces them.
const DEFAULT_DNS_SERVER_GROUP: &[IpAddr] = &[
    IpAddr::V4(Ipv4Addr::new(223, 5, 5, 5)), // alibaba
    IpAddr::V4(Ipv4Addr::new(223, 6, 6, 6)), // alibaba
    IpAddr::V4(Ipv4Addr::new(119, 29, 29, 29)), // tencent
    IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8)),      // google
    IpAddr::V6(Ipv6Addr::new(0x2001, 0x4860, 0x4860, 0, 0, 0, 0, 0x8888)), // google
];

/// DNS servers the custom resolver is built from.
///
/// Starts as a copy of `DEFAULT_DNS_SERVER_GROUP` and is overwritten by
/// [`set_custom_dns_server`]. It is read in `get_custom_resolver`, which
/// `get_ip_addrs_inner` calls once per thread.
static DNS_SERVER_GROUP: Lazy<RwLock<Vec<IpAddr>>> =
    Lazy::new(|| RwLock::new(DEFAULT_DNS_SERVER_GROUP.to_vec()));

/// Destination port used when querying every configured DNS server.
const DNS_QUERY_PORT: u16 = 53;
288
289#[inline]
290fn get_custom_resolver() -> Result<TokioResolver> {
291    let dns_group = DNS_SERVER_GROUP.read();
292    let config = ResolverConfig::from_parts(
293        None,
294        vec![],
295        NameServerConfigGroup::from_ips_clear(&dns_group, DNS_QUERY_PORT, true),
296    );
297    let mut builder =
298        TokioResolver::builder_with_config(config, TokioConnectionProvider::default());
299    *builder.options_mut() = ResolverOpts::default();
300    Ok(builder.build())
301}
302
303/// Set up DNS servers, use `DEFAULT_DNS_SERVER_GROUP` by default
304/// Note: must be called before the first network connection to be effective
305#[inline]
306pub fn set_custom_dns_server(dns_addrs: &[IpAddr]) -> Result<()> {
307    let mut writer = DNS_SERVER_GROUP.write();
308    let servers: &mut Vec<IpAddr> = writer.as_mut();
309    servers.clear();
310    dns_addrs.iter().for_each(|&a| servers.push(a));
311    Ok(())
312}
313
314/// Resolving domain to get `IpAddr`
315/// Note: must run as async runtime,such as [`tokio::task::spawn_blocking`]
316pub async fn get_ip_addrs(s: &str) -> Result<Vec<IpAddr>> {
317    let s = s.to_owned();
318    tokio::task::spawn_blocking(move || get_ip_addrs_inner(&s))
319        .await
320        .map_err(|_| invalid_input!("get ip addrs"))?
321}
322
323/// Resolving domain to get `IpAddr`
324/// Note: must run as async runtime,such as [`tokio::task::spawn_blocking`]
325fn get_ip_addrs_inner(s: &str) -> Result<Vec<IpAddr>> {
326    thread_local! {
327        static RESOLVER:Option<TokioResolver> = {
328            match get_custom_resolver(){
329                Ok(v) => Some(v),
330                Err(e) => {
331                    tracing::error!("create resolver error:{e}");
332                    None
333                },
334            }
335        };
336    }
337    let resolver = RESOLVER.with(|r| r.clone());
338    let resolver = try_opt!(resolver, "custom resolver not exist");
339    let handle = tokio::runtime::Handle::try_current()
340        .map_err(|_| invalid_input!("tokio runtime not found"))?;
341    let lookup = handle
342        .block_on(resolver.lookup_ip(s))
343        .map_err(|e| invalid_input!(e))?;
344    Ok(lookup.into_iter().collect())
345}
346
347/// Resolving domain and port to get `SocketAddr`
348#[inline]
349pub async fn get_socket_addrs_from_host_port(s: &str, port: u16) -> Result<Vec<SocketAddr>> {
350    let s = s.to_owned();
351    tokio::task::spawn_blocking(move || get_socket_addrs_from_host_port_inner(&s, port))
352        .await
353        .map_err(|_| invalid_input!("get socket addrs from host port"))?
354}
355
356/// Resolving domain and port to get `SocketAddr`
357/// Note: must run as async runtime,such as [`tokio::task::spawn_blocking`]
358#[inline]
359fn get_socket_addrs_from_host_port_inner(host: &str, port: u16) -> Result<Vec<SocketAddr>> {
360    match get_ip_addrs_inner(host) {
361        Ok(r) => Ok(r.into_iter().map(|ip| SocketAddr::new(ip, port)).collect()),
362        // Resolve dns properly with the standard library
363        Err(_) => std::net::ToSocketAddrs::to_socket_addrs(&(host, port)).map(|v| v.collect()),
364    }
365}
366
367/// Resolving `domain:port` forms,such as bilibili.com:1080
368#[inline]
369pub async fn get_socket_addrs(s: &str) -> Result<Vec<SocketAddr>> {
370    let s = s.to_owned();
371    tokio::task::spawn_blocking(move || get_socket_addrs_inner(&s))
372        .await
373        .map_err(|_| invalid_input!("get socket addrs"))?
374}
375
376/// Resolving `domain:port` forms,such as bilibili.com:1080
377/// Note: must run as async runtime,such as [`tokio::task::spawn_blocking`]
378#[inline]
379fn get_socket_addrs_inner(s: &str) -> Result<Vec<SocketAddr>> {
380    let (host, port_str) = try_opt!(s.rsplit_once(':'), "invalid socket address");
381    let port: u16 = try_opt!(port_str.parse().ok(), "invalid port value");
382    get_socket_addrs_from_host_port_inner(host, port)
383}
384
385/// Look up all the socket addr's and pass in the method to get the result
386pub async fn each_addr<A: ToSocketAddrs, F, T, R>(addr: A, f: F) -> Result<T>
387where
388    F: Fn(SocketAddr) -> R,
389    R: std::future::Future<Output = Result<T>>,
390{
391    let addrs = match addr.to_socket_addrs().await {
392        Ok(addrs) => addrs,
393        Err(e) => return Err(e),
394    };
395    let mut last_err = None;
396    for addr in addrs {
397        match f(addr).await {
398            Ok(l) => return Ok(l),
399            Err(e) => last_err = Some(e),
400        }
401    }
402    Err(last_err.unwrap_or_else(|| {
403        std::io::Error::new(
404            std::io::ErrorKind::InvalidInput,
405            "could not resolve to any addresses",
406        )
407    }))
408}