Skip to main content

async_rs/implementors/
hickory.rs

1use crate::{
2    traits::AsyncToSocketAddrs,
3    util::{self, SocketAddrsFromIpAddrs},
4};
5use hickory_resolver::{TokioResolver, proto::rr::IntoName};
6use std::{
7    io,
8    net::{IpAddr, SocketAddr, ToSocketAddrs},
9    str::FromStr,
10    sync::OnceLock,
11    vec,
12};
13
14static RESOLVER: OnceLock<TokioResolver> = OnceLock::new();
15
16fn get_or_init_resolver() -> io::Result<&'static TokioResolver> {
17    // FIXME: replace with RESOLVER.get_or_try_init(...) once it stabilises (rust#109737)
18    if let Some(r) = RESOLVER.get() {
19        return Ok(r);
20    }
21    let resolver = TokioResolver::builder_tokio()
22        .map_err(io::Error::other)?
23        .build()
24        .map_err(io::Error::other)?;
25    Ok(RESOLVER.get_or_init(|| resolver))
26}
27
/// Perform async DNS resolution using hickory-dns.
///
/// Holds a host (anything hickory can turn into a DNS name) and a port. Resolution happens
/// only when one of the `to_socket_addrs` implementations is invoked.
///
/// Trait bounds are deliberately kept off the type definition and placed on the `impl`
/// blocks that actually need them, so the type itself stays maximally permissive.
#[derive(Debug, Clone)]
pub struct HickoryToSocketAddrs<T> {
    /// Host to resolve.
    host: T,
    /// Port attached to every resolved IP address.
    port: u16,
}
34
35impl<H: IntoName + Send + 'static> HickoryToSocketAddrs<H> {
36    /// Create a `HickoryToSocketAddrs` from split host and port components.
37    pub fn new(host: H, port: u16) -> Self {
38        Self { host, port }
39    }
40
41    async fn lookup(self) -> io::Result<SocketAddrsFromIpAddrs<vec::IntoIter<IpAddr>>> {
42        if !util::inside_tokio() {
43            return Err(io::Error::other(
44                "hickory-dns is only supported in a tokio context",
45            ));
46        }
47
48        let resolver = get_or_init_resolver()?;
49
50        Ok(SocketAddrsFromIpAddrs(
51            resolver
52                .lookup_ip(self.host)
53                .await
54                .map_err(io::Error::other)?
55                .iter()
56                .collect::<Vec<_>>() // FIXME: don't collect if we get back into_iter
57                .into_iter(),
58            self.port,
59        ))
60    }
61}
62
63impl FromStr for HickoryToSocketAddrs<String> {
64    type Err = io::Error;
65
66    fn from_str(s: &str) -> io::Result<Self> {
67        let (host, port_str) = s
68            .rsplit_once(':')
69            .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidInput, "invalid socket address"))?;
70        let port = port_str
71            .parse()
72            .map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "invalid port value"))?;
73        Ok(Self::new(host.to_owned(), port))
74    }
75}
76
77impl<T: IntoName + Clone + Send + 'static> ToSocketAddrs for HickoryToSocketAddrs<T> {
78    type Iter = SocketAddrsFromIpAddrs<vec::IntoIter<IpAddr>>;
79
80    fn to_socket_addrs(&self) -> io::Result<Self::Iter> {
81        util::block_on_tokio(self.clone().lookup())
82    }
83}
84
85impl<T: IntoName + Send + 'static> AsyncToSocketAddrs for HickoryToSocketAddrs<T> {
86    fn to_socket_addrs(
87        self,
88    ) -> impl Future<Output = io::Result<impl Iterator<Item = SocketAddr> + Send + 'static>>
89    + Send
90    + 'static {
91        self.lookup()
92    }
93}