gel_stream/common/
resolver.rs

1use std::borrow::Cow;
2use std::net::{IpAddr, SocketAddr, ToSocketAddrs};
3use std::{future::Future, str::FromStr, task::Poll};
4
5use crate::{MaybeResolvedTarget, ResolvedTarget, TargetName, TcpResolve};
6
/// An async resolver for hostnames to IP addresses.
///
/// With the `hickory` feature enabled, the underlying hickory resolver is held
/// behind an `Arc`, so cloning a `Resolver` only bumps a reference count and all
/// clones share the same resolver. Without `hickory` the struct has no fields.
#[derive(Clone)]
pub struct Resolver {
    /// Shared Tokio-backed hickory resolver; only present with the `hickory` feature.
    #[cfg(feature = "hickory")]
    resolver: std::sync::Arc<hickory_resolver::TokioResolver>,
}
13
/// Resolve `host` with the standard library's blocking resolver
/// ([`ToSocketAddrs`]), run on Tokio's blocking thread pool so the async
/// executor is not stalled.
///
/// The lookup is performed with port `0`, so the returned
/// [`ResolvedTarget::SocketAddr`] always carries port 0; callers must apply the
/// port they actually want. Only the first address returned is used.
///
/// # Errors
///
/// * the lookup itself fails (error propagated from `to_socket_addrs`);
/// * the blocking task does not complete (its `JoinError` is reported as
///   [`std::io::ErrorKind::Interrupted`]);
/// * the lookup succeeds but yields no address ([`std::io::ErrorKind::NotFound`]).
#[cfg(feature = "tokio")]
// Only called when `hickory` is disabled; with both features on this is dead code.
#[allow(unused)]
async fn resolve_host_to_socket_addrs(host: String) -> std::io::Result<ResolvedTarget> {
    // Passing the host and port as a tuple avoids building and re-parsing a
    // "host:0" string.
    let mut addrs = tokio::task::spawn_blocking(move || (host.as_str(), 0u16).to_socket_addrs())
        .await
        .map_err(|e| std::io::Error::new(std::io::ErrorKind::Interrupted, e.to_string()))??;
    addrs
        .next()
        .map(ResolvedTarget::SocketAddr)
        // Build the error lazily: only allocate it when nothing was found.
        .ok_or_else(|| std::io::Error::new(std::io::ErrorKind::NotFound, "No address found"))
}
28
29impl Resolver {
30    /// Create a new resolver.
31    pub fn new() -> Result<Self, std::io::Error> {
32        Ok(Self {
33            #[cfg(feature = "hickory")]
34            resolver: hickory_resolver::Resolver::builder_tokio()?.build().into(),
35        })
36    }
37
38    pub(crate) fn resolve_remote(
39        &self,
40        host: &MaybeResolvedTarget,
41    ) -> ResolveResult<ResolvedTarget> {
42        match host {
43            MaybeResolvedTarget::Resolved(resolved) => {
44                ResolveResult::new_sync(Ok(resolved.clone()))
45            }
46            MaybeResolvedTarget::Unresolved(host, port, _) => {
47                if let Ok(ip) = IpAddr::from_str(&host) {
48                    ResolveResult::new_sync(Ok(ResolvedTarget::SocketAddr(SocketAddr::from((
49                        ip, *port,
50                    )))))
51                } else {
52                    #[cfg(feature = "hickory")]
53                    {
54                        let resolver = self.resolver.clone();
55                        let host = host.to_string();
56                        let port = *port;
57                        ResolveResult::new_async(async move {
58                            let f = resolver.lookup_ip(host);
59                            let Some(addr) = f.await?.iter().next() else {
60                                return Err(std::io::Error::new(
61                                    std::io::ErrorKind::NotFound,
62                                    "No address found",
63                                ));
64                            };
65                            Ok(ResolvedTarget::SocketAddr(SocketAddr::new(addr, port)))
66                        })
67                    }
68                    #[cfg(all(feature = "tokio", not(feature = "hickory")))]
69                    {
70                        ResolveResult::new_async(resolve_host_to_socket_addrs(host.to_string()))
71                    }
72                    #[cfg(not(any(feature = "tokio", feature = "hickory")))]
73                    {
74                        ResolveResult::new_sync(Err(std::io::Error::new(
75                            std::io::ErrorKind::Unsupported,
76                            "No resolver available",
77                        )))
78                    }
79                }
80            }
81        }
82    }
83}
84
/// The result of a resolution. It may be synchronous or asynchronous, but you
/// can always call `.await` on it.
///
/// A synchronously available result can also be taken without awaiting via
/// [`ResolveResult::sync`]. Once the result has been taken (by `sync` or by
/// awaiting it to completion) the value is spent and polling it again panics.
#[must_use = "dropping a ResolveResult discards the resolution; `.await` it or call `sync`"]
pub struct ResolveResult<T> {
    inner: ResolveResultInner<T>,
}

impl<T> ResolveResult<T> {
    /// Wrap a result that is already known.
    fn new_sync(result: Result<T, std::io::Error>) -> Self {
        Self {
            inner: ResolveResultInner::Sync(result),
        }
    }

    /// Wrap a future that produces the result when awaited.
    fn new_async(future: impl Future<Output = std::io::Result<T>> + Send + 'static) -> Self {
        Self {
            inner: ResolveResultInner::Async(Box::pin(future)),
        }
    }

    /// Take the result if it was produced synchronously.
    ///
    /// * Synchronous result: returns `Ok(Some(value))` or the stored error. This
    ///   value is then spent: later `sync` calls return `Ok(None)` and awaiting it
    ///   panics.
    /// * Asynchronous or already-spent result: returns `Ok(None)` and leaves the
    ///   state untouched, so an asynchronous result can still be awaited.
    pub fn sync(&mut self) -> Result<Option<T>, std::io::Error> {
        match std::mem::replace(&mut self.inner, ResolveResultInner::Fused) {
            ResolveResultInner::Sync(result) => result.map(Some),
            other => {
                // Not a synchronous result: restore the original state.
                self.inner = other;
                Ok(None)
            }
        }
    }

    /// Transform the successful value with `f`, preserving sync/async-ness.
    ///
    /// Errors pass through untouched. Mapping an already-spent result yields a
    /// synchronous `Other` error rather than panicking.
    pub fn map<U>(self, f: impl (FnOnce(T) -> U) + Send + 'static) -> ResolveResult<U>
    where
        T: 'static,
    {
        match self.inner {
            ResolveResultInner::Sync(result) => ResolveResult::new_sync(result.map(f)),
            ResolveResultInner::Async(future) => {
                ResolveResult::new_async(async move { future.await.map(f) })
            }
            ResolveResultInner::Fused => ResolveResult::new_sync(Err(std::io::Error::new(
                std::io::ErrorKind::Other,
                "Polled a previously awaited result",
            ))),
        }
    }
}

/// Internal state of a [`ResolveResult`].
enum ResolveResultInner<T> {
    /// The result was available immediately.
    Sync(Result<T, std::io::Error>),
    /// The result will be produced by this boxed future.
    Async(std::pin::Pin<Box<dyn Future<Output = std::io::Result<T>> + Send>>),
    /// The result has already been taken; nothing is left to produce.
    Fused,
}

impl<T> Future for ResolveResult<T>
where
    Self: Unpin,
{
    type Output = std::io::Result<T>;

    /// Drive the resolution.
    ///
    /// # Panics
    ///
    /// Panics with "Polled a previously awaited result" if polled again after it
    /// has returned `Poll::Ready` (for both synchronous and asynchronous results).
    fn poll(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Self::Output> {
        let this = self.get_mut();
        // Take the state out; it is only put back while the inner future is pending,
        // so any completed result (sync or async) leaves the value `Fused`.
        match std::mem::replace(&mut this.inner, ResolveResultInner::Fused) {
            ResolveResultInner::Sync(result) => Poll::Ready(result),
            ResolveResultInner::Async(mut future) => {
                let poll = future.as_mut().poll(cx);
                if poll.is_pending() {
                    this.inner = ResolveResultInner::Async(future);
                }
                // On `Ready` the finished future is dropped here.
                poll
            }
            ResolveResultInner::Fused => {
                panic!("Polled a previously awaited result")
            }
        }
    }
}
166
/// A trait for types that can be resolved to a target.
pub trait Resolvable {
    /// What a successful resolution produces (in this file: an `IpAddr` for
    /// plain hostnames, a `ResolvedTarget` for connection targets).
    type Target;

    /// Start resolving `self` using `resolver`.
    ///
    /// The returned [`ResolveResult`] may already hold the answer (see
    /// [`ResolveResult::sync`]) or may need to be awaited.
    fn resolve(&self, resolver: &Resolver) -> ResolveResult<Self::Target>;
}
173
174impl Resolvable for String {
175    type Target = IpAddr;
176
177    fn resolve(&self, resolver: &Resolver) -> ResolveResult<Self::Target> {
178        resolver
179            .resolve_remote(&MaybeResolvedTarget::Unresolved(
180                Cow::Owned(self.clone()),
181                0,
182                None,
183            ))
184            .map(|target| match target {
185                ResolvedTarget::SocketAddr(addr) => addr.ip(),
186                _ => unreachable!(),
187            })
188    }
189}
190
191impl<T: TcpResolve + Clone> Resolvable for T {
192    type Target = ResolvedTarget;
193
194    fn resolve(&self, resolver: &Resolver) -> ResolveResult<Self::Target> {
195        resolver.resolve_remote(&self.clone().into())
196    }
197}
198
199impl Resolvable for TargetName {
200    type Target = ResolvedTarget;
201
202    fn resolve(&self, resolver: &Resolver) -> ResolveResult<Self::Target> {
203        resolver.resolve_remote(self.maybe_resolved())
204    }
205}
206
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::*;

    /// `localhost` must resolve to a loopback address carrying the requested port.
    /// Whether the IPv4 or IPv6 loopback comes first depends on the host's
    /// resolver configuration, so only loopback-ness and the port are asserted.
    #[tokio::test]
    async fn test_resolve_remote() {
        let resolver = Resolver::new().unwrap();
        let target = TargetName::new_tcp(("localhost", 8080));
        let result = target.resolve(&resolver).await.unwrap();
        let addr = match result {
            ResolvedTarget::SocketAddr(addr) => addr,
            other => panic!("expected a socket address, got {other:?}"),
        };
        assert!(addr.ip().is_loopback(), "{addr} is not a loopback address");
        assert_eq!(addr.port(), 8080);
    }

    /// IP literals bypass DNS entirely and keep the requested port.
    #[tokio::test]
    async fn test_resolve_ip_literal() {
        let resolver = Resolver::new().unwrap();
        let target = TargetName::new_tcp(("127.0.0.1", 5656));
        let result = target.resolve(&resolver).await.unwrap();
        assert_eq!(
            result,
            ResolvedTarget::SocketAddr(SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 5656))
        );
    }

    /// Resolving a bare `String` IP literal yields just the IP address.
    #[tokio::test]
    async fn test_resolve_string_ip_literal() {
        let resolver = Resolver::new().unwrap();
        let ip = "::1".to_string().resolve(&resolver).await.unwrap();
        assert_eq!(ip, IpAddr::V6(Ipv6Addr::LOCALHOST));
    }

    #[cfg(feature = "__manual_tests")]
    #[tokio::test]
    async fn test_resolve_real_domain() {
        let resolver = Resolver::new().unwrap();
        let target = TargetName::new_tcp(("www.google.com", 443));
        let result = target.resolve(&resolver).await.unwrap();
        println!("{:?}", result);
    }
}