// gel_stream/common/resolver.rs
use std::borrow::Cow;
use std::net::{IpAddr, SocketAddr, ToSocketAddrs};
use std::pin::Pin;
use std::{future::Future, str::FromStr, task::Poll};

use crate::{MaybeResolvedTarget, ResolvedTarget, TargetName, TcpResolve};

/// Resolves host names (and IP literals) into [`ResolvedTarget`]s.
///
/// With the `hickory` feature enabled this holds a shared hickory tokio
/// resolver; without it the struct carries no state and resolution falls back
/// to whatever mechanism the enabled features provide.
#[derive(Clone)]
pub struct Resolver {
    /// Shared hickory resolver; kept behind an `Arc` so cloning a `Resolver`
    /// only bumps a reference count.
    #[cfg(feature = "hickory")]
    resolver: std::sync::Arc<hickory_resolver::TokioResolver>,
}

/// Resolves `host` with the operating system resolver (`ToSocketAddrs`),
/// running the blocking lookup on tokio's blocking thread pool.
///
/// Only the first address reported by the system is used. The lookup is
/// performed as `"{host}:0"`, so the returned address always carries port `0`;
/// callers are responsible for substituting the real port.
///
/// # Errors
///
/// - Any I/O error reported by the system lookup.
/// - `ErrorKind::NotFound` if the lookup succeeded but produced no address.
/// - `ErrorKind::Interrupted` if the blocking task did not complete
///   (tokio `JoinError`).
#[cfg(feature = "tokio")]
// Only called when `tokio` is enabled without `hickory`; silence the unused
// lint for the other feature combinations.
#[allow(unused)]
async fn resolve_host_to_socket_addrs(host: String) -> std::io::Result<ResolvedTarget> {
    let mut addrs = tokio::task::spawn_blocking(move || format!("{host}:0").to_socket_addrs())
        .await
        .map_err(|e| std::io::Error::new(std::io::ErrorKind::Interrupted, e.to_string()))??;
    addrs
        .next()
        .map(ResolvedTarget::SocketAddr)
        // Build the error lazily so the success path does not allocate it.
        .ok_or_else(|| std::io::Error::new(std::io::ErrorKind::NotFound, "No address found"))
}

29impl Resolver {
30 pub fn new() -> Result<Self, std::io::Error> {
32 Ok(Self {
33 #[cfg(feature = "hickory")]
34 resolver: hickory_resolver::Resolver::builder_tokio()?.build().into(),
35 })
36 }
37
38 pub(crate) fn resolve_remote(
39 &self,
40 host: &MaybeResolvedTarget,
41 ) -> ResolveResult<ResolvedTarget> {
42 match host {
43 MaybeResolvedTarget::Resolved(resolved) => {
44 ResolveResult::new_sync(Ok(resolved.clone()))
45 }
46 MaybeResolvedTarget::Unresolved(host, port, _) => {
47 if let Ok(ip) = IpAddr::from_str(&host) {
48 ResolveResult::new_sync(Ok(ResolvedTarget::SocketAddr(SocketAddr::from((
49 ip, *port,
50 )))))
51 } else {
52 #[cfg(feature = "hickory")]
53 {
54 let resolver = self.resolver.clone();
55 let host = host.to_string();
56 let port = *port;
57 ResolveResult::new_async(async move {
58 let f = resolver.lookup_ip(host);
59 let Some(addr) = f.await?.iter().next() else {
60 return Err(std::io::Error::new(
61 std::io::ErrorKind::NotFound,
62 "No address found",
63 ));
64 };
65 Ok(ResolvedTarget::SocketAddr(SocketAddr::new(addr, port)))
66 })
67 }
68 #[cfg(all(feature = "tokio", not(feature = "hickory")))]
69 {
70 ResolveResult::new_async(resolve_host_to_socket_addrs(host.to_string()))
71 }
72 #[cfg(not(any(feature = "tokio", feature = "hickory")))]
73 {
74 ResolveResult::new_sync(Err(std::io::Error::new(
75 std::io::ErrorKind::Unsupported,
76 "No resolver available",
77 )))
78 }
79 }
80 }
81 }
82 }
83}
84
85pub struct ResolveResult<T> {
88 inner: ResolveResultInner<T>,
89}
90
91impl<T> ResolveResult<T> {
92 fn new_sync(result: Result<T, std::io::Error>) -> Self {
93 Self {
94 inner: ResolveResultInner::Sync(result),
95 }
96 }
97
98 fn new_async(future: impl Future<Output = std::io::Result<T>> + Send + 'static) -> Self {
99 Self {
100 inner: ResolveResultInner::Async(Box::pin(future)),
101 }
102 }
103
104 pub fn sync(&mut self) -> Result<Option<T>, std::io::Error> {
105 if let ResolveResultInner::Sync(_) = &mut self.inner {
106 let this = std::mem::replace(&mut self.inner, ResolveResultInner::Fused);
107 let ResolveResultInner::Sync(result) = this else {
108 unreachable!()
109 };
110 result.map(Some)
111 } else {
112 Ok(None)
113 }
114 }
115
116 pub fn map<U>(self, f: impl (FnOnce(T) -> U) + Send + 'static) -> ResolveResult<U>
117 where
118 T: 'static,
119 {
120 match self.inner {
121 ResolveResultInner::Sync(Ok(t)) => ResolveResult::new_sync(Ok(f(t))),
122 ResolveResultInner::Sync(Err(e)) => ResolveResult::new_sync(Err(e)),
123 ResolveResultInner::Async(future) => {
124 ResolveResult::new_async(async move { Ok(f(future.await?)) })
125 }
126 ResolveResultInner::Fused => ResolveResult::new_sync(Err(std::io::Error::new(
127 std::io::ErrorKind::Other,
128 "Polled a previously awaited result",
129 ))),
130 }
131 }
132}
133
/// Internal state of a [`ResolveResult`].
enum ResolveResultInner<T> {
    /// The result was known when the `ResolveResult` was created.
    Sync(Result<T, std::io::Error>),
    /// The result is produced by a boxed, pinned future.
    Async(Pin<Box<dyn Future<Output = std::io::Result<T>> + Send>>),
    /// The result has already been taken and cannot be produced again.
    Fused,
}

140impl<T> Future for ResolveResult<T>
141where
142 Self: Unpin,
143{
144 type Output = std::io::Result<T>;
145
146 fn poll(
147 self: std::pin::Pin<&mut Self>,
148 cx: &mut std::task::Context<'_>,
149 ) -> std::task::Poll<Self::Output> {
150 let this = self.get_mut();
151 match &mut this.inner {
152 ResolveResultInner::Sync(_) => {
153 let this = std::mem::replace(&mut this.inner, ResolveResultInner::Fused);
154 let ResolveResultInner::Sync(result) = this else {
155 unreachable!()
156 };
157 Poll::Ready(result)
158 }
159 ResolveResultInner::Async(future) => future.as_mut().poll(cx),
160 ResolveResultInner::Fused => {
161 panic!("Polled a previously awaited result")
162 }
163 }
164 }
165}
166
167pub trait Resolvable {
169 type Target;
170
171 fn resolve(&self, resolver: &Resolver) -> ResolveResult<Self::Target>;
172}
173
174impl Resolvable for String {
175 type Target = IpAddr;
176
177 fn resolve(&self, resolver: &Resolver) -> ResolveResult<Self::Target> {
178 resolver
179 .resolve_remote(&MaybeResolvedTarget::Unresolved(
180 Cow::Owned(self.clone()),
181 0,
182 None,
183 ))
184 .map(|target| match target {
185 ResolvedTarget::SocketAddr(addr) => addr.ip(),
186 _ => unreachable!(),
187 })
188 }
189}
190
191impl<T: TcpResolve + Clone> Resolvable for T {
192 type Target = ResolvedTarget;
193
194 fn resolve(&self, resolver: &Resolver) -> ResolveResult<Self::Target> {
195 resolver.resolve_remote(&self.clone().into())
196 }
197}
198
199impl Resolvable for TargetName {
200 type Target = ResolvedTarget;
201
202 fn resolve(&self, resolver: &Resolver) -> ResolveResult<Self::Target> {
203 resolver.resolve_remote(self.maybe_resolved())
204 }
205}
206
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::{IpAddr, Ipv4Addr, SocketAddr};

    /// `localhost` may resolve to either `127.0.0.1` or `::1` depending on the
    /// machine's configuration, so only require a loopback address carrying
    /// the requested port.
    #[tokio::test]
    async fn test_resolve_remote() {
        let resolver = Resolver::new().unwrap();
        let target = TargetName::new_tcp(("localhost", 8080));
        let result = target.resolve(&resolver).await.unwrap();
        match result {
            ResolvedTarget::SocketAddr(addr) => {
                assert_eq!(addr.port(), 8080);
                assert!(addr.ip().is_loopback(), "expected a loopback address, got {addr}");
            }
            other => panic!("expected a socket address, got {other:?}"),
        }
    }

    /// IP literals need no lookup and must come back exactly as given.
    #[tokio::test]
    async fn test_resolve_ip_literal() {
        let resolver = Resolver::new().unwrap();
        let target = TargetName::new_tcp(("127.0.0.1", 8080));
        let result = target.resolve(&resolver).await.unwrap();
        assert_eq!(
            result,
            ResolvedTarget::SocketAddr(SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 8080))
        );
    }

    /// Resolving an IP-literal `String` is answered synchronously.
    #[tokio::test]
    async fn test_resolve_string_ip_literal_is_sync() {
        let resolver = Resolver::new().unwrap();
        let mut result = "127.0.0.1".to_string().resolve(&resolver);
        assert_eq!(result.sync().unwrap(), Some(IpAddr::V4(Ipv4Addr::LOCALHOST)));
    }

    #[cfg(feature = "__manual_tests")]
    #[tokio::test]
    async fn test_resolve_real_domain() {
        let resolver = Resolver::new().unwrap();
        let target = TargetName::new_tcp(("www.google.com", 443));
        let result = target.resolve(&resolver).await.unwrap();
        println!("{:?}", result);
    }
}