nodecraft/resolver/impls/
address.rs

1use core::{net::SocketAddr, time::Duration};
2
3use super::{super::AddressResolver, CachedSocketAddr};
4use crate::address::{Domain, HostAddr};
5
6use crossbeam_skiplist::SkipMap;
7
/// The options used to construct a [`HostAddrResolver`].
///
/// When the `serde` feature is enabled, `record_ttl` is (de)serialized through
/// `humantime_serde` and falls back to `default_record_ttl` if the field is absent.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct HostAddrResolverOptions {
  /// The ttl attached to every record the resolver stores in its cache.
  #[cfg_attr(
    feature = "serde",
    serde(with = "humantime_serde", default = "default_record_ttl")
  )]
  record_ttl: Duration,
}
18
19impl Default for HostAddrResolverOptions {
20  fn default() -> Self {
21    Self::new()
22  }
23}
24
/// Returns the default record ttl: 60 seconds.
///
/// Kept as a free `const fn` because it is referenced both by
/// `HostAddrResolverOptions::new` and by the serde `default` attribute on the
/// `record_ttl` field.
const fn default_record_ttl() -> Duration {
  const DEFAULT_TTL_SECS: u64 = 60;
  Duration::new(DEFAULT_TTL_SECS, 0)
}
28
29impl HostAddrResolverOptions {
30  /// Create a new [`HostAddrResolverOptions`].
31  #[inline]
32  pub const fn new() -> Self {
33    Self {
34      record_ttl: default_record_ttl(),
35    }
36  }
37
38  /// Set the DNS record ttl in builder pattern
39  #[inline]
40  pub const fn with_record_ttl(mut self, val: Duration) -> Self {
41    self.record_ttl = val;
42    self
43  }
44
45  /// Set the DNS record ttl
46  #[inline]
47  pub fn set_record_ttl(&mut self, val: Duration) -> &mut Self {
48    self.record_ttl = val;
49    self
50  }
51
52  /// Returns the DNS record ttl
53  #[inline]
54  pub const fn record_ttl(&self) -> Duration {
55    self.record_ttl
56  }
57}
58
59pub use resolver::HostAddrResolver;
60
#[cfg(feature = "agnostic")]
mod resolver {
  use super::*;

  use agnostic::{RuntimeLite, net::ToSocketAddrs};
  use hostaddr::Host;

  /// A resolver that accepts both `domain:port` and plain socket addresses.
  /// Domains are resolved exclusively through [`ToSocketAddrs`].
  ///
  /// - If you are sure you will only ever deal with [`SocketAddr`], you may want to
  ///   use [`SocketAddrResolver`](crate::resolver::socket_addr::SocketAddrResolver).
  /// - If you want to send DNS queries, you may want to use [`DnsResolver`](crate::resolver::dns::DnsResolver).
  ///
  /// **N.B.** If a domain maps to multiple ip addresses, there is no guarantee
  /// which one will be used (the first address yielded by the lookup is taken).
  /// Users should make sure that the domain only maps to one ip address, so
  /// that [`AddressResolver`] can work properly.
  ///
  /// e.g. valid address format:
  /// 1. `www.example.com:8080` // domain
  /// 2. `[::1]:8080` // ipv6
  /// 3. `127.0.0.1:8080` // ipv4
  ///
  pub struct HostAddrResolver<R> {
    /// Resolved socket addresses, keyed by domain.
    cache: SkipMap<Domain, CachedSocketAddr>,
    /// The ttl handed to every newly cached entry.
    record_ttl: Duration,
    /// Ties the resolver to its runtime type without storing a value of it.
    _marker: std::marker::PhantomData<R>,
  }

  impl<R> HostAddrResolver<R> {
    /// Create a new [`HostAddrResolver`] with the given options.
    ///
    /// The resolver starts with an empty cache.
    pub fn new(opts: HostAddrResolverOptions) -> Self {
      let HostAddrResolverOptions { record_ttl } = opts;
      Self {
        cache: SkipMap::default(),
        record_ttl,
        _marker: std::marker::PhantomData,
      }
    }
  }

  impl<R> Default for HostAddrResolver<R> {
    /// Builds a resolver from the default [`HostAddrResolverOptions`].
    fn default() -> Self {
      Self::new(HostAddrResolverOptions::default())
    }
  }

  impl<R: RuntimeLite> AddressResolver for HostAddrResolver<R> {
    type Address = HostAddr;
    type ResolvedAddress = SocketAddr;
    type Error = std::io::Error;
    type Runtime = R;
    type Options = HostAddrResolverOptions;

    /// Creates the resolver with an empty cache; this never fails.
    #[inline]
    async fn new(opts: Self::Options) -> Result<Self, Self::Error> {
      let resolver = Self {
        cache: SkipMap::default(),
        record_ttl: opts.record_ttl,
        _marker: std::marker::PhantomData,
      };
      Ok(resolver)
    }

    /// Resolves `address` to a [`SocketAddr`].
    ///
    /// - An address without a port is rejected with `InvalidInput`.
    /// - An ip host is returned directly, without touching the cache.
    /// - A domain host is served from the cache while the entry has not expired;
    ///   otherwise it is looked up again and the first returned address is cached.
    ///   A lookup yielding no address produces a `NotFound` error.
    async fn resolve(&self, address: &Self::Address) -> Result<SocketAddr, Self::Error> {
      // A port is mandatory to build the resulting socket address.
      let port = address.port().ok_or_else(|| {
        std::io::Error::new(std::io::ErrorKind::InvalidInput, "address missing port")
      })?;

      let address: hostaddr::HostAddr<&Domain> = address.into();
      match address.host() {
        Host::Ip(ip) => Ok(SocketAddr::new(*ip, port)),
        Host::Domain(name) => {
          // Serve from the cache when possible, evicting the entry if it is stale.
          if let Some(entry) = self.cache.get(name.as_inner()) {
            let cached = entry.value();
            if cached.is_expired() {
              entry.remove();
            } else {
              return Ok(cached.val);
            }
          }

          // Cache miss (or stale entry): ask the runtime to resolve the name.
          let resolved =
            ToSocketAddrs::<Self::Runtime>::to_socket_addrs(&(name.as_inner().as_str(), port))
              .await?;

          // Only the first address is used; see the note on the struct docs.
          match resolved.into_iter().next() {
            Some(addr) => {
              self.cache.insert(
                (*name).clone(),
                CachedSocketAddr::new(addr, self.record_ttl),
              );
              Ok(addr)
            }
            None => Err(std::io::Error::new(
              std::io::ErrorKind::NotFound,
              format!("failed to resolve {}", name.as_inner().as_str()),
            )),
          }
        }
      }
    }
  }

  #[cfg(test)]
  mod tests {
    use super::*;

    /// Smoke test: a well-known domain resolves (requires network access).
    #[tokio::test]
    async fn test_dns_resolver() {
      use agnostic::tokio::TokioRuntime;

      let resolver = HostAddrResolver::<TokioRuntime>::default();
      let target = HostAddr::try_from("google.com:8080").unwrap();
      let resolved = resolver.resolve(&target).await.unwrap();
      println!("google.com:8080 resolved to: {}", resolved);
    }

    /// Checks cache population, expiry after the ttl, re-resolution, ip
    /// pass-through and failure on an unresolvable domain (requires network access).
    #[tokio::test]
    async fn test_dns_resolver_with_record_ttl() {
      use agnostic::tokio::TokioRuntime;

      let ttl = Duration::from_millis(100);
      let resolver =
        HostAddrResolver::<TokioRuntime>::new(HostAddrResolverOptions::new().with_record_ttl(ttl));

      // Two resolutions in a row: the second one is expected to hit the cache.
      let domain_addr = HostAddr::try_from("google.com:8080").unwrap();
      resolver.resolve(&domain_addr).await.unwrap();
      resolver.resolve(&domain_addr).await.unwrap();

      // Ip addresses are resolved without any lookup.
      let literal_addr = HostAddr::try_from(("127.0.0.1", 8080)).unwrap();
      resolver.resolve(&literal_addr).await.unwrap();

      let key = Domain::try_from("google.com").unwrap();
      assert!(!resolver.cache.get(&key).unwrap().value().is_expired());

      // Once the ttl has elapsed the entry is stale and gets refreshed on demand.
      tokio::time::sleep(Duration::from_millis(100)).await;
      assert!(resolver.cache.get(&key).unwrap().value().is_expired());
      resolver.resolve(&domain_addr).await.unwrap();

      let unresolvable =
        HostAddr::try_from("adasdjkljasidjaosdjaisudnaisudibasd.com:8080").unwrap();
      assert!(resolver.resolve(&unresolvable).await.is_err());
    }
  }
}
206
207#[cfg(not(feature = "agnostic"))]
208mod resolver {
209  use super::*;
210
211  /// A resolver which supports both `domain:port` and socket address. However,
212  /// it will only use [`ToSocketAddrs`](std::net::ToSocketAddrs)
213  /// to resolve the address.
214  ///
215  /// - If you can make sure, you always play with [`SocketAddr`], you may want to
216  ///   use [`SocketAddrResolver`](crate::resolver::socket_addr::SocketAddrResolver).
217  /// - If you want to send DNS queries, you may want to use [`DnsResolver`](crate::resolver::dns::DnsResolver).
218  ///
219  /// **N.B.** If a domain contains multiple ip addresses, there is no guarantee that
220  /// which one will be used. Users should make sure that the domain only contains
221  /// one ip address, to make sure that [`AddressResolver`] can work properly.
222  ///
223  /// e.g. valid address format:
224  /// 1. `www.example.com:8080` // domain
225  /// 2. `[::1]:8080` // ipv6
226  /// 3. `127.0.0.1:8080` // ipv4
227  ///
228  pub struct HostAddrResolver {
229    cache: SkipMap<Domain, CachedSocketAddr>,
230    record_ttl: Duration,
231  }
232
233  impl AddressResolver for HostAddrResolver {
234    type Address = HostAddr;
235    type ResolvedAddress = SocketAddr;
236    type Error = std::io::Error;
237    type Options = HostAddrResolverOptions;
238
239    #[inline]
240    async fn new(opts: Self::Options) -> Result<Self, Self::Error> {
241      Ok(Self {
242        record_ttl: opts.record_ttl,
243        cache: Default::default(),
244      })
245    }
246
247    async fn resolve(&self, address: &Self::Address) -> Result<SocketAddr, Self::Error> {
248      match address.as_inner() {
249        Either::Left(addr) => Ok(addr),
250        Either::Right((port, name)) => {
251          // First, check cache
252          if let Some(ent) = self.cache.get(name) {
253            let val = ent.value();
254            if !val.is_expired() {
255              return Ok(val.val);
256            } else {
257              ent.remove();
258            }
259          }
260
261          // Finally, try to find the socket addr locally
262          let res = ToSocketAddrs::to_socket_addrs(&(name.as_str(), port))?;
263          if let Some(addr) = res.into_iter().next() {
264            self
265              .cache
266              .insert(name.clone(), CachedSocketAddr::new(addr, self.record_ttl));
267            return Ok(addr);
268          }
269
270          Err(std::io::Error::new(
271            std::io::ErrorKind::NotFound,
272            format!("failed to resolve {}", name),
273          ))
274        }
275      }
276    }
277  }
278
279  impl Default for HostAddrResolver {
280    fn default() -> Self {
281      Self::new(Default::default())
282    }
283  }
284
285  impl HostAddrResolver {
286    /// Create a new [`HostAddrResolver`] with the given options.
287    pub fn new(opts: HostAddrResolverOptions) -> Self {
288      Self {
289        record_ttl: opts.record_ttl,
290        cache: Default::default(),
291      }
292    }
293  }
294
295  #[cfg(test)]
296  mod tests {
297    use super::*;
298
299    #[tokio::test]
300    async fn test_dns_resolver() {
301      let resolver = HostAddrResolver::default();
302      let google_addr = HostAddr::try_from("google.com:8080").unwrap();
303      let ip = resolver.resolve(&google_addr).await.unwrap();
304      println!("google.com:8080 resolved to: {}", ip);
305    }
306
307    #[tokio::test]
308    async fn test_dns_resolver_with_record_ttl() {
309      let resolver = HostAddrResolver::new(
310        HostAddrResolverOptions::new().with_record_ttl(Duration::from_millis(100)),
311      );
312      let google_addr = HostAddr::try_from("google.com:8080").unwrap();
313      resolver.resolve(&google_addr).await.unwrap();
314      let dns_name = Domain::try_from("google.com").unwrap();
315      assert!(!resolver.cache.get(&dns_name).unwrap().value().is_expired());
316
317      tokio::time::sleep(Duration::from_millis(100)).await;
318      assert!(resolver.cache.get(&dns_name).unwrap().value().is_expired());
319    }
320  }
321}
322
#[cfg(test)]
mod tests {
  use super::*;

  /// Exercises the default value, the builder-style setter and the in-place
  /// setter of [`HostAddrResolverOptions`].
  #[test]
  fn test_opts() {
    // Defaults come from `default_record_ttl`.
    let defaults = HostAddrResolverOptions::default();
    assert_eq!(defaults.record_ttl(), default_record_ttl());

    // The builder-style setter replaces the ttl and returns the options by value.
    let ten_secs = Duration::from_secs(10);
    let mut opts = defaults.with_record_ttl(ten_secs);
    assert_eq!(opts.record_ttl(), ten_secs);

    // The in-place setter mutates the existing value.
    let eleven_secs = Duration::from_secs(11);
    opts.set_record_ttl(eleven_secs);
    assert_eq!(opts.record_ttl(), eleven_secs);
  }
}
336}