gel_stream/common/
resolver.rs

1use std::borrow::Cow;
2use std::net::{IpAddr, SocketAddr, ToSocketAddrs};
3use std::{future::Future, str::FromStr, task::Poll};
4
5use crate::{MaybeResolvedTarget, ResolvedTarget, TargetName, TcpResolve};
6
/// Asynchronous resolver that turns hostnames into IP addresses.
///
/// When the `hickory` feature is enabled the resolver owns an `Arc`-shared
/// hickory Tokio resolver, so cloning only bumps a reference count. Without
/// that feature the struct carries no state at all.
#[derive(Clone)]
pub struct Resolver {
    #[cfg(feature = "hickory")]
    resolver: ::std::sync::Arc<hickory_resolver::TokioResolver>,
}
13
/// Resolve `host` to a single socket address with the standard library's
/// `ToSocketAddrs`, running the blocking lookup on Tokio's blocking pool.
///
/// The lookup is performed for `"{host}:0"`, so the returned address always
/// carries port 0; callers must set the port they actually want.
///
/// # Errors
///
/// - `ErrorKind::Interrupted` if the blocking task fails to complete (its
///   `JoinError` message is preserved).
/// - Any error reported by `to_socket_addrs`.
/// - `ErrorKind::NotFound` if the lookup succeeded but yielded no addresses.
#[cfg(feature = "tokio")]
// Only called when the `hickory` feature is disabled.
#[allow(unused)]
async fn resolve_host_to_socket_addrs(host: String) -> std::io::Result<ResolvedTarget> {
    let mut addrs = tokio::task::spawn_blocking(move || format!("{host}:0").to_socket_addrs())
        .await
        .map_err(|e| std::io::Error::new(std::io::ErrorKind::Interrupted, e.to_string()))??;
    // Build the error lazily so a successful lookup does not allocate one.
    let addr = addrs
        .next()
        .ok_or_else(|| std::io::Error::new(std::io::ErrorKind::NotFound, "No address found"))?;
    Ok(ResolvedTarget::SocketAddr(addr))
}
28
29impl Resolver {
30    /// Create a new resolver.
31    pub fn new() -> Result<Self, std::io::Error> {
32        Ok(Self {
33            #[cfg(feature = "hickory")]
34            resolver: hickory_resolver::Resolver::builder_tokio()?.build().into(),
35        })
36    }
37
38    pub(crate) fn resolve_remote(
39        &self,
40        host: &MaybeResolvedTarget,
41    ) -> ResolveResult<ResolvedTarget> {
42        match host {
43            MaybeResolvedTarget::Resolved(resolved) => {
44                ResolveResult::new_sync(Ok(resolved.clone()))
45            }
46            MaybeResolvedTarget::Unresolved(host, port, _) => {
47                if let Ok(ip) = IpAddr::from_str(host) {
48                    ResolveResult::new_sync(Ok(ResolvedTarget::SocketAddr(SocketAddr::from((
49                        ip, *port,
50                    )))))
51                } else {
52                    #[cfg(feature = "hickory")]
53                    {
54                        let resolver = self.resolver.clone();
55                        let host = host.to_string();
56                        let port = *port;
57                        ResolveResult::new_async(async move {
58                            let f = resolver.lookup_ip(host);
59                            let Some(addr) = f.await?.iter().next() else {
60                                return Err(std::io::Error::new(
61                                    std::io::ErrorKind::NotFound,
62                                    "No address found",
63                                ));
64                            };
65                            Ok(ResolvedTarget::SocketAddr(SocketAddr::new(addr, port)))
66                        })
67                    }
68                    #[cfg(all(feature = "tokio", not(feature = "hickory")))]
69                    {
70                        ResolveResult::new_async(resolve_host_to_socket_addrs(host.to_string()))
71                    }
72                    #[cfg(not(any(feature = "tokio", feature = "hickory")))]
73                    {
74                        ResolveResult::new_sync(Err(std::io::Error::new(
75                            std::io::ErrorKind::Unsupported,
76                            "No resolver available",
77                        )))
78                    }
79                }
80            }
81        }
82    }
83}
84
85/// The result of a resolution. It may be synchronous or asynchronous, but you
86/// can always call `.await` on it.
87pub struct ResolveResult<T> {
88    inner: ResolveResultInner<T>,
89}
90
91impl<T> ResolveResult<T> {
92    fn new_sync(result: Result<T, std::io::Error>) -> Self {
93        Self {
94            inner: ResolveResultInner::Sync(result),
95        }
96    }
97
98    fn new_async(future: impl Future<Output = std::io::Result<T>> + Send + 'static) -> Self {
99        Self {
100            inner: ResolveResultInner::Async(Box::pin(future)),
101        }
102    }
103
104    pub fn sync(&mut self) -> Result<Option<T>, std::io::Error> {
105        if let ResolveResultInner::Sync(_) = &mut self.inner {
106            let this = std::mem::replace(&mut self.inner, ResolveResultInner::Fused);
107            let ResolveResultInner::Sync(result) = this else {
108                unreachable!()
109            };
110            result.map(Some)
111        } else {
112            Ok(None)
113        }
114    }
115
116    pub fn map<U>(self, f: impl (FnOnce(T) -> U) + Send + 'static) -> ResolveResult<U>
117    where
118        T: 'static,
119    {
120        match self.inner {
121            ResolveResultInner::Sync(Ok(t)) => ResolveResult::new_sync(Ok(f(t))),
122            ResolveResultInner::Sync(Err(e)) => ResolveResult::new_sync(Err(e)),
123            ResolveResultInner::Async(future) => {
124                ResolveResult::new_async(async move { Ok(f(future.await?)) })
125            }
126            ResolveResultInner::Fused => ResolveResult::new_sync(Err(std::io::Error::other(
127                "Polled a previously awaited result",
128            ))),
129        }
130    }
131}
132
/// Internal state of a `ResolveResult`.
enum ResolveResultInner<T> {
    /// The outcome is already known and has not been handed out yet.
    Sync(std::io::Result<T>),
    /// The outcome is still being computed by a boxed, pinned future.
    Async(std::pin::Pin<Box<dyn std::future::Future<Output = std::io::Result<T>> + Send>>),
    /// The outcome has already been taken.
    Fused,
}
138
139impl<T> Future for ResolveResult<T>
140where
141    Self: Unpin,
142{
143    type Output = std::io::Result<T>;
144
145    fn poll(
146        self: std::pin::Pin<&mut Self>,
147        cx: &mut std::task::Context<'_>,
148    ) -> std::task::Poll<Self::Output> {
149        let this = self.get_mut();
150        match &mut this.inner {
151            ResolveResultInner::Sync(_) => {
152                let this = std::mem::replace(&mut this.inner, ResolveResultInner::Fused);
153                let ResolveResultInner::Sync(result) = this else {
154                    unreachable!()
155                };
156                Poll::Ready(result)
157            }
158            ResolveResultInner::Async(future) => future.as_mut().poll(cx),
159            ResolveResultInner::Fused => {
160                panic!("Polled a previously awaited result")
161            }
162        }
163    }
164}
165
166/// A trait for types that can be resolved to a target.
167pub trait Resolvable {
168    type Target;
169
170    fn resolve(&self, resolver: &Resolver) -> ResolveResult<Self::Target>;
171}
172
173impl Resolvable for String {
174    type Target = IpAddr;
175
176    fn resolve(&self, resolver: &Resolver) -> ResolveResult<Self::Target> {
177        resolver
178            .resolve_remote(&MaybeResolvedTarget::Unresolved(
179                Cow::Owned(self.clone()),
180                0,
181                None,
182            ))
183            .map(|target| match target {
184                ResolvedTarget::SocketAddr(addr) => addr.ip(),
185                _ => unreachable!(),
186            })
187    }
188}
189
190impl<T: TcpResolve + Clone> Resolvable for T {
191    type Target = ResolvedTarget;
192
193    fn resolve(&self, resolver: &Resolver) -> ResolveResult<Self::Target> {
194        resolver.resolve_remote(&self.clone().into())
195    }
196}
197
198impl Resolvable for TargetName {
199    type Target = ResolvedTarget;
200
201    fn resolve(&self, resolver: &Resolver) -> ResolveResult<Self::Target> {
202        resolver.resolve_remote(self.maybe_resolved())
203    }
204}
205
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_resolve_remote() {
        let resolver = Resolver::new().unwrap();
        let target = TargetName::new_tcp(("localhost", 8080));
        let result = target.resolve(&resolver).await.unwrap();
        // `localhost` may resolve to 127.0.0.1 or ::1 depending on the host's
        // configuration, so only check that it is loopback and the port is kept.
        let ResolvedTarget::SocketAddr(addr) = &result else {
            panic!("expected a socket address, got {result:?}");
        };
        assert!(addr.ip().is_loopback(), "expected a loopback address, got {addr}");
        assert_eq!(addr.port(), 8080);
    }

    #[cfg(feature = "__manual_tests")]
    #[tokio::test]
    async fn test_resolve_real_domain() {
        let resolver = Resolver::new().unwrap();
        let target = TargetName::new_tcp(("www.google.com", 443));
        let result = target.resolve(&resolver).await.unwrap();
        println!("{result:?}");
    }
}
233}