gel_stream/common/
resolver.rs1use std::borrow::Cow;
2use std::net::{IpAddr, SocketAddr, ToSocketAddrs};
3use std::{future::Future, str::FromStr, task::Poll};
4
5use crate::{MaybeResolvedTarget, ResolvedTarget, TargetName, TcpResolve};
6
/// Turns hostnames into socket addresses.
///
/// When the `hickory` feature is enabled this holds a shared, reference-counted
/// hickory resolver; without that feature the type has no fields at all.
#[derive(Clone)]
pub struct Resolver {
    #[cfg(feature = "hickory")]
    resolver: std::sync::Arc<hickory_resolver::TokioResolver>,
}
13
14#[cfg(feature = "tokio")]
15#[allow(unused)]
16async fn resolve_host_to_socket_addrs(host: String) -> std::io::Result<ResolvedTarget> {
17 let res = tokio::task::spawn_blocking(move || format!("{host}:0").to_socket_addrs())
18 .await
19 .map_err(|e| std::io::Error::new(std::io::ErrorKind::Interrupted, e.to_string()))??;
20 res.into_iter()
21 .next()
22 .ok_or(std::io::Error::new(
23 std::io::ErrorKind::NotFound,
24 "No address found",
25 ))
26 .map(ResolvedTarget::SocketAddr)
27}
28
29impl Resolver {
30 pub fn new() -> Result<Self, std::io::Error> {
32 Ok(Self {
33 #[cfg(feature = "hickory")]
34 resolver: hickory_resolver::Resolver::builder_tokio()?.build().into(),
35 })
36 }
37
38 pub(crate) fn resolve_remote(
39 &self,
40 host: &MaybeResolvedTarget,
41 ) -> ResolveResult<ResolvedTarget> {
42 match host {
43 MaybeResolvedTarget::Resolved(resolved) => {
44 ResolveResult::new_sync(Ok(resolved.clone()))
45 }
46 MaybeResolvedTarget::Unresolved(host, port, _) => {
47 if let Ok(ip) = IpAddr::from_str(host) {
48 ResolveResult::new_sync(Ok(ResolvedTarget::SocketAddr(SocketAddr::from((
49 ip, *port,
50 )))))
51 } else {
52 #[cfg(feature = "hickory")]
53 {
54 let resolver = self.resolver.clone();
55 let host = host.to_string();
56 let port = *port;
57 ResolveResult::new_async(async move {
58 let f = resolver.lookup_ip(host);
59 let Some(addr) = f.await?.iter().next() else {
60 return Err(std::io::Error::new(
61 std::io::ErrorKind::NotFound,
62 "No address found",
63 ));
64 };
65 Ok(ResolvedTarget::SocketAddr(SocketAddr::new(addr, port)))
66 })
67 }
68 #[cfg(all(feature = "tokio", not(feature = "hickory")))]
69 {
70 ResolveResult::new_async(resolve_host_to_socket_addrs(host.to_string()))
71 }
72 #[cfg(not(any(feature = "tokio", feature = "hickory")))]
73 {
74 ResolveResult::new_sync(Err(std::io::Error::new(
75 std::io::ErrorKind::Unsupported,
76 "No resolver available",
77 )))
78 }
79 }
80 }
81 }
82 }
83}
84
/// The outcome of a resolution that is either available immediately or
/// produced later by a future.
///
/// An immediately available result can be taken without awaiting via
/// [`ResolveResult::sync`]; otherwise the value is obtained by awaiting. The
/// value is handed out at most once.
pub struct ResolveResult<T> {
    inner: ResolveResultInner<T>,
}

/// Internal state of a [`ResolveResult`].
enum ResolveResultInner<T> {
    /// The result is already available.
    Sync(Result<T, std::io::Error>),
    /// The result will be produced by this boxed future.
    Async(std::pin::Pin<Box<dyn Future<Output = std::io::Result<T>> + Send>>),
    /// The result has already been handed out.
    Fused,
}

impl<T> ResolveResult<T> {
    /// Wraps a result that is already available.
    fn new_sync(result: Result<T, std::io::Error>) -> Self {
        Self {
            inner: ResolveResultInner::Sync(result),
        }
    }

    /// Wraps a future that will produce the result.
    fn new_async(future: impl Future<Output = std::io::Result<T>> + Send + 'static) -> Self {
        Self {
            inner: ResolveResultInner::Async(Box::pin(future)),
        }
    }

    /// Moves a synchronously available result out, leaving the state `Fused`.
    ///
    /// Returns `None` (and changes nothing) if the state is not `Sync`.
    fn take_sync(&mut self) -> Option<Result<T, std::io::Error>> {
        if !matches!(self.inner, ResolveResultInner::Sync(_)) {
            return None;
        }
        match std::mem::replace(&mut self.inner, ResolveResultInner::Fused) {
            ResolveResultInner::Sync(result) => Some(result),
            _ => unreachable!(),
        }
    }

    /// Takes the result if it is available without awaiting.
    ///
    /// Returns `Ok(Some(value))` for an immediately available success.
    /// Returns `Ok(None)` when the result is asynchronous (await it instead) or
    /// has already been taken. Once a synchronous result has been taken,
    /// awaiting this value panics.
    ///
    /// # Errors
    ///
    /// Returns the stored error if the immediately available result was an
    /// error.
    pub fn sync(&mut self) -> Result<Option<T>, std::io::Error> {
        self.take_sync().transpose()
    }

    /// Transforms a successful value with `f`.
    ///
    /// Synchronous results stay synchronous and asynchronous results stay
    /// asynchronous. Errors pass through unchanged. A result that has already
    /// been handed out becomes a synchronous error.
    pub fn map<U>(self, f: impl (FnOnce(T) -> U) + Send + 'static) -> ResolveResult<U>
    where
        T: 'static,
    {
        match self.inner {
            ResolveResultInner::Sync(Ok(t)) => ResolveResult::new_sync(Ok(f(t))),
            ResolveResultInner::Sync(Err(e)) => ResolveResult::new_sync(Err(e)),
            ResolveResultInner::Async(future) => {
                ResolveResult::new_async(async move { Ok(f(future.await?)) })
            }
            ResolveResultInner::Fused => ResolveResult::new_sync(Err(std::io::Error::other(
                "Polled a previously awaited result",
            ))),
        }
    }
}

impl<T> Future for ResolveResult<T>
where
    Self: Unpin,
{
    type Output = std::io::Result<T>;

    /// Polls for the result.
    ///
    /// # Panics
    ///
    /// Panics if polled again after the result has been produced, or after it
    /// was taken with [`ResolveResult::sync`].
    fn poll(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Self::Output> {
        let this = self.get_mut();
        if let Some(result) = this.take_sync() {
            return Poll::Ready(result);
        }
        match &mut this.inner {
            ResolveResultInner::Async(future) => {
                let poll = future.as_mut().poll(cx);
                if poll.is_ready() {
                    // Drop the finished future so any stray re-poll hits the
                    // `Fused` panic below instead of resuming a completed future.
                    this.inner = ResolveResultInner::Fused;
                }
                poll
            }
            ResolveResultInner::Fused => {
                panic!("Polled a previously awaited result")
            }
            ResolveResultInner::Sync(_) => unreachable!(),
        }
    }
}
165
166pub trait Resolvable {
168 type Target;
169
170 fn resolve(&self, resolver: &Resolver) -> ResolveResult<Self::Target>;
171}
172
173impl Resolvable for String {
174 type Target = IpAddr;
175
176 fn resolve(&self, resolver: &Resolver) -> ResolveResult<Self::Target> {
177 resolver
178 .resolve_remote(&MaybeResolvedTarget::Unresolved(
179 Cow::Owned(self.clone()),
180 0,
181 None,
182 ))
183 .map(|target| match target {
184 ResolvedTarget::SocketAddr(addr) => addr.ip(),
185 _ => unreachable!(),
186 })
187 }
188}
189
190impl<T: TcpResolve + Clone> Resolvable for T {
191 type Target = ResolvedTarget;
192
193 fn resolve(&self, resolver: &Resolver) -> ResolveResult<Self::Target> {
194 resolver.resolve_remote(&self.clone().into())
195 }
196}
197
198impl Resolvable for TargetName {
199 type Target = ResolvedTarget;
200
201 fn resolve(&self, resolver: &Resolver) -> ResolveResult<Self::Target> {
202 resolver.resolve_remote(self.maybe_resolved())
203 }
204}
205
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::*;

    #[tokio::test]
    async fn test_resolve_remote() {
        let resolver = Resolver::new().unwrap();
        let target = TargetName::new_tcp(("localhost", 8080));
        let result = target.resolve(&resolver).await.unwrap();
        // `localhost` may map to 127.0.0.1 or ::1 depending on the host's
        // resolver configuration, so only require a loopback address, and check
        // that the requested port survives resolution.
        let addr = match result {
            ResolvedTarget::SocketAddr(addr) => addr,
            other => panic!("expected a socket address, got {other:?}"),
        };
        assert!(addr.ip().is_loopback(), "{addr} is not a loopback address");
        assert_eq!(addr.port(), 8080);
    }

    #[tokio::test]
    async fn test_resolve_ip_literal() {
        let resolver = Resolver::new().unwrap();
        let target = TargetName::new_tcp(("127.0.0.1", 1234));
        let result = target.resolve(&resolver).await.unwrap();
        assert_eq!(
            result,
            ResolvedTarget::SocketAddr(SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 1234))
        );
    }

    #[tokio::test]
    async fn test_resolve_string_ip_literal() {
        let resolver = Resolver::new().unwrap();
        let ip = "::1".to_string().resolve(&resolver).await.unwrap();
        assert_eq!(ip, IpAddr::V6(Ipv6Addr::LOCALHOST));
    }

    #[cfg(feature = "__manual_tests")]
    #[tokio::test]
    async fn test_resolve_real_domain() {
        let resolver = Resolver::new().unwrap();
        let target = TargetName::new_tcp(("www.google.com", 443));
        let result = target.resolve(&resolver).await.unwrap();
        println!("{result:?}");
    }
}