nodecraft/resolver/impls/
address.rs1use core::{net::SocketAddr, time::Duration};
2
3use super::{super::AddressResolver, CachedSocketAddr};
4use crate::address::{Domain, HostAddr};
5
6use crossbeam_skiplist::SkipMap;
7
/// Configuration for `HostAddrResolver`.
///
/// Currently the only knob is the TTL applied to cached domain lookups.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct HostAddrResolverOptions {
    /// Lifetime given to every cached domain lookup.
    ///
    /// With the `serde` feature this field is (de)serialized through
    /// `humantime_serde` and falls back to `default_record_ttl()` when absent.
    #[cfg_attr(
        feature = "serde",
        serde(with = "humantime_serde", default = "default_record_ttl")
    )]
    record_ttl: Duration,
}

/// The TTL used when none is configured: sixty seconds.
const fn default_record_ttl() -> Duration {
    Duration::from_secs(60)
}

impl HostAddrResolverOptions {
    /// Returns options populated with the default record TTL (60 seconds).
    #[inline]
    pub const fn new() -> Self {
        let record_ttl = default_record_ttl();
        Self { record_ttl }
    }

    /// Builder-style setter: consumes `self` and returns it with the record
    /// TTL replaced by `val`.
    #[inline]
    pub const fn with_record_ttl(self, val: Duration) -> Self {
        let mut updated = self;
        updated.record_ttl = val;
        updated
    }

    /// In-place setter: replaces the record TTL with `val` and returns
    /// `&mut self` so calls can be chained.
    #[inline]
    pub fn set_record_ttl(&mut self, val: Duration) -> &mut Self {
        self.record_ttl = val;
        self
    }

    /// Returns the configured record TTL.
    #[inline]
    pub const fn record_ttl(&self) -> Duration {
        self.record_ttl
    }
}

impl Default for HostAddrResolverOptions {
    /// Same as [`HostAddrResolverOptions::new`].
    fn default() -> Self {
        Self::new()
    }
}
58
59pub use resolver::HostAddrResolver;
60
#[cfg(feature = "agnostic")]
mod resolver {
    use super::*;

    use agnostic::{RuntimeLite, net::ToSocketAddrs};
    use hostaddr::Host;

    /// Resolves [`HostAddr`]s into [`SocketAddr`]s using the asynchronous
    /// `ToSocketAddrs` implementation of the runtime `R`.
    ///
    /// Domain lookups are cached per domain name for `record_ttl`; IP-literal
    /// hosts are converted directly and never touch the cache.
    pub struct HostAddrResolver<R> {
        /// Most recent successful lookup for each domain. The cache is keyed by
        /// domain only, so only the IP of the stored address is reused; the port
        /// is always taken from the request being resolved.
        cache: SkipMap<Domain, CachedSocketAddr>,
        /// Lifetime given to each new cache entry.
        record_ttl: Duration,
        _marker: std::marker::PhantomData<R>,
    }

    impl<R> Default for HostAddrResolver<R> {
        /// Builds a resolver from `HostAddrResolverOptions::default()`.
        fn default() -> Self {
            Self::new(Default::default())
        }
    }

    impl<R: RuntimeLite> AddressResolver for HostAddrResolver<R> {
        type Address = HostAddr;
        type ResolvedAddress = SocketAddr;
        type Error = std::io::Error;
        type Runtime = R;
        type Options = HostAddrResolverOptions;

        /// Creates a resolver with an empty cache. This never returns `Err`.
        #[inline]
        async fn new(opts: Self::Options) -> Result<Self, Self::Error> {
            Ok(Self {
                record_ttl: opts.record_ttl,
                cache: Default::default(),
                _marker: Default::default(),
            })
        }

        /// Resolves `address` to a single socket address.
        ///
        /// IP hosts are combined with the requested port directly. Domain hosts
        /// are answered from the cache while the entry is unexpired; otherwise
        /// they are looked up through the runtime and the first returned address
        /// is cached.
        ///
        /// # Errors
        ///
        /// - `InvalidInput` if `address` carries no port.
        /// - `NotFound` if the lookup succeeded but yielded no addresses.
        /// - Any I/O error produced by the underlying lookup.
        async fn resolve(&self, address: &Self::Address) -> Result<SocketAddr, Self::Error> {
            let Some(port) = address.port() else {
                return Err(std::io::Error::new(
                    std::io::ErrorKind::InvalidInput,
                    "address missing port",
                ));
            };
            let address: hostaddr::HostAddr<&Domain> = address.into();
            match address.host() {
                Host::Ip(ip) => Ok(SocketAddr::new(*ip, port)),
                Host::Domain(name) => {
                    if let Some(ent) = self.cache.get(name.as_inner()) {
                        let cached = ent.value();
                        if !cached.is_expired() {
                            // The entry may have been populated by a request for a
                            // different port; only its IP is valid for this request.
                            return Ok(SocketAddr::new(cached.val.ip(), port));
                        }
                        // Stale record: evict it and fall through to a fresh lookup.
                        ent.remove();
                    }

                    let res = ToSocketAddrs::<Self::Runtime>::to_socket_addrs(&(
                        name.as_inner().as_str(),
                        port,
                    ))
                    .await?;

                    if let Some(addr) = res.into_iter().next() {
                        self.cache.insert(
                            (*name).clone(),
                            CachedSocketAddr::new(addr, self.record_ttl),
                        );
                        return Ok(addr);
                    }

                    Err(std::io::Error::new(
                        std::io::ErrorKind::NotFound,
                        format!("failed to resolve {}", name.as_inner().as_str()),
                    ))
                }
            }
        }
    }

    impl<R> HostAddrResolver<R> {
        /// Creates a resolver with an empty cache and the TTL from `opts`.
        pub fn new(opts: HostAddrResolverOptions) -> Self {
            Self {
                record_ttl: opts.record_ttl,
                cache: Default::default(),
                _marker: Default::default(),
            }
        }
    }

    #[cfg(test)]
    mod tests {
        use super::*;

        #[tokio::test]
        async fn test_dns_resolver() {
            use agnostic::tokio::TokioRuntime;

            let resolver = HostAddrResolver::<TokioRuntime>::default();
            let google_addr = HostAddr::try_from("google.com:8080").unwrap();
            let ip = resolver.resolve(&google_addr).await.unwrap();
            println!("google.com:8080 resolved to: {}", ip);
        }

        #[tokio::test]
        async fn test_dns_resolver_with_record_ttl() {
            use agnostic::tokio::TokioRuntime;

            let resolver = HostAddrResolver::<TokioRuntime>::new(
                HostAddrResolverOptions::new().with_record_ttl(Duration::from_millis(100)),
            );
            let google_addr = HostAddr::try_from("google.com:8080").unwrap();
            resolver.resolve(&google_addr).await.unwrap();
            resolver.resolve(&google_addr).await.unwrap();
            let ip_addr = HostAddr::try_from(("127.0.0.1", 8080)).unwrap();
            resolver.resolve(&ip_addr).await.unwrap();
            let dns_name = Domain::try_from("google.com").unwrap();
            assert!(!resolver.cache.get(&dns_name).unwrap().value().is_expired());

            tokio::time::sleep(Duration::from_millis(100)).await;
            assert!(resolver.cache.get(&dns_name).unwrap().value().is_expired());
            resolver.resolve(&google_addr).await.unwrap();

            let bad_addr =
                HostAddr::try_from("adasdjkljasidjaosdjaisudnaisudibasd.com:8080").unwrap();
            assert!(resolver.resolve(&bad_addr).await.is_err());
        }

        /// A cache hit must honour the port of the current request, not the
        /// port of the request that populated the cache. Needs no network.
        #[tokio::test]
        async fn test_cached_entry_uses_requested_port() {
            use agnostic::tokio::TokioRuntime;

            let resolver = HostAddrResolver::<TokioRuntime>::default();
            let dns_name = Domain::try_from("example.com").unwrap();
            resolver.cache.insert(
                dns_name,
                CachedSocketAddr::new(
                    SocketAddr::from(([127, 0, 0, 1], 80)),
                    Duration::from_secs(60),
                ),
            );

            let addr = HostAddr::try_from("example.com:443").unwrap();
            let resolved = resolver.resolve(&addr).await.unwrap();
            assert_eq!(resolved, SocketAddr::from(([127, 0, 0, 1], 443)));
        }
    }
}
206
207#[cfg(not(feature = "agnostic"))]
208mod resolver {
209 use super::*;
210
211 pub struct HostAddrResolver {
229 cache: SkipMap<Domain, CachedSocketAddr>,
230 record_ttl: Duration,
231 }
232
233 impl AddressResolver for HostAddrResolver {
234 type Address = HostAddr;
235 type ResolvedAddress = SocketAddr;
236 type Error = std::io::Error;
237 type Options = HostAddrResolverOptions;
238
239 #[inline]
240 async fn new(opts: Self::Options) -> Result<Self, Self::Error> {
241 Ok(Self {
242 record_ttl: opts.record_ttl,
243 cache: Default::default(),
244 })
245 }
246
247 async fn resolve(&self, address: &Self::Address) -> Result<SocketAddr, Self::Error> {
248 match address.as_inner() {
249 Either::Left(addr) => Ok(addr),
250 Either::Right((port, name)) => {
251 if let Some(ent) = self.cache.get(name) {
253 let val = ent.value();
254 if !val.is_expired() {
255 return Ok(val.val);
256 } else {
257 ent.remove();
258 }
259 }
260
261 let res = ToSocketAddrs::to_socket_addrs(&(name.as_str(), port))?;
263 if let Some(addr) = res.into_iter().next() {
264 self
265 .cache
266 .insert(name.clone(), CachedSocketAddr::new(addr, self.record_ttl));
267 return Ok(addr);
268 }
269
270 Err(std::io::Error::new(
271 std::io::ErrorKind::NotFound,
272 format!("failed to resolve {}", name),
273 ))
274 }
275 }
276 }
277 }
278
279 impl Default for HostAddrResolver {
280 fn default() -> Self {
281 Self::new(Default::default())
282 }
283 }
284
285 impl HostAddrResolver {
286 pub fn new(opts: HostAddrResolverOptions) -> Self {
288 Self {
289 record_ttl: opts.record_ttl,
290 cache: Default::default(),
291 }
292 }
293 }
294
295 #[cfg(test)]
296 mod tests {
297 use super::*;
298
299 #[tokio::test]
300 async fn test_dns_resolver() {
301 let resolver = HostAddrResolver::default();
302 let google_addr = HostAddr::try_from("google.com:8080").unwrap();
303 let ip = resolver.resolve(&google_addr).await.unwrap();
304 println!("google.com:8080 resolved to: {}", ip);
305 }
306
307 #[tokio::test]
308 async fn test_dns_resolver_with_record_ttl() {
309 let resolver = HostAddrResolver::new(
310 HostAddrResolverOptions::new().with_record_ttl(Duration::from_millis(100)),
311 );
312 let google_addr = HostAddr::try_from("google.com:8080").unwrap();
313 resolver.resolve(&google_addr).await.unwrap();
314 let dns_name = Domain::try_from("google.com").unwrap();
315 assert!(!resolver.cache.get(&dns_name).unwrap().value().is_expired());
316
317 tokio::time::sleep(Duration::from_millis(100)).await;
318 assert!(resolver.cache.get(&dns_name).unwrap().value().is_expired());
319 }
320 }
321}
322
#[cfg(test)]
mod tests {
    use super::*;

    /// Exercises construction, the builder setter and the in-place setter.
    #[test]
    fn test_opts() {
        let defaults = HostAddrResolverOptions::default();
        assert_eq!(defaults.record_ttl(), default_record_ttl());

        let mut tuned = defaults.with_record_ttl(Duration::from_secs(10));
        assert_eq!(tuned.record_ttl(), Duration::from_secs(10));

        tuned.set_record_ttl(Duration::from_secs(11));
        assert_eq!(tuned.record_ttl(), Duration::from_secs(11));
    }
}