1use super::{super::super::cache::Cache, ResolveAnswers, ResolveOptions, ResolveResult, Resolver};
2use std::{
3 env::temp_dir,
4 fmt::Debug,
5 path::{Path, PathBuf},
6 sync::Arc,
7 time::Duration,
8};
9
10#[cfg(feature = "async")]
11use {super::super::super::cache::AsyncCache, futures::future::BoxFuture};
12
/// Shrink interval handed to the underlying cache when the builder is not told otherwise (two minutes).
const DEFAULT_SHRINK_INTERVAL: Duration = Duration::from_secs(2 * 60);
/// Lifetime of a cached resolve answer when the builder is not told otherwise (two minutes).
const DEFAULT_CACHE_LIFETIME: Duration = Duration::from_secs(2 * 60);
15
16#[derive(Debug)]
22pub struct CachedResolver<R: ?Sized> {
23 resolver: Arc<R>,
24 cache: Cache<String, ResolveAnswers>,
25
26 #[cfg(feature = "async")]
27 async_cache: AsyncCache<String, ResolveAnswers>,
28}
29
30impl<R> CachedResolver<R> {
31 #[inline]
33 pub fn builder(backend: R) -> CachedResolverBuilder<R> {
34 CachedResolverBuilder::new(backend)
35 }
36
37 #[inline]
41 pub fn persistent_path(&self) -> Option<&Path> {
42 self.cache.persistent_path()
43 }
44
45 #[inline]
47 pub fn auto_persistent(&self) -> Option<bool> {
48 self.cache.auto_persistent()
49 }
50
51 fn default_persistent_path() -> PathBuf {
52 let mut path = dirs::cache_dir().unwrap_or_else(temp_dir);
53 path.push(".qiniu-rust-sdk");
54 path.push("resolver-cache.json");
55 path
56 }
57}
58
59impl<R: Default> Default for CachedResolver<R> {
60 #[inline]
61 fn default() -> Self {
62 Self::builder(R::default()).default_load_or_create_from(true)
63 }
64}
65
66impl<R> Clone for CachedResolver<R> {
67 #[inline]
68 fn clone(&self) -> Self {
69 Self {
70 resolver: self.resolver.clone(),
71 cache: self.cache.clone(),
72
73 #[cfg(feature = "async")]
74 async_cache: self.async_cache.clone(),
75 }
76 }
77}
78
79impl<R: Resolver + 'static> Resolver for CachedResolver<R> {
80 fn resolve(&self, domain: &str, opts: ResolveOptions) -> ResolveResult {
81 self.cache.get(domain, || self.resolver.resolve(domain, opts))
82 }
83
84 #[cfg(feature = "async")]
85 #[cfg_attr(feature = "docs", doc(cfg(feature = "async")))]
86 fn async_resolve<'a>(&'a self, domain: &'a str, opts: ResolveOptions<'a>) -> BoxFuture<'a, ResolveResult> {
87 Box::pin(async move {
88 self.async_cache
89 .get(domain, self.resolver.async_resolve(domain, opts))
90 .await
91 })
92 }
93}
94
/// Builder for [`CachedResolver`]: collects the cache timing parameters and
/// then decides whether the cache is file-backed or kept in memory.
#[derive(Debug)]
pub struct CachedResolverBuilder<R>
where
    R: ?Sized,
{
    /// Lifetime of cached answers.
    cache_lifetime: Duration,
    /// Shrink interval handed to the underlying cache.
    shrink_interval: Duration,
    /// Backend resolver; must stay the last field because `R` may be unsized.
    resolver: R,
}
102
103impl<R> CachedResolverBuilder<R> {
104 #[inline]
106 pub fn new(resolver: R) -> Self {
107 Self {
108 resolver,
109 cache_lifetime: DEFAULT_CACHE_LIFETIME,
110 shrink_interval: DEFAULT_SHRINK_INTERVAL,
111 }
112 }
113
114 #[inline]
116 pub fn cache_lifetime(mut self, cache_lifetime: Duration) -> Self {
117 self.cache_lifetime = cache_lifetime;
118 self
119 }
120
121 #[inline]
123 pub fn shrink_interval(mut self, shrink_interval: Duration) -> Self {
124 self.shrink_interval = shrink_interval;
125 self
126 }
127
128 #[inline]
132 pub fn load_or_create_from(self, path: impl AsRef<Path>, auto_persistent: bool) -> CachedResolver<R> {
133 CachedResolver {
134 resolver: Arc::new(self.resolver),
135 cache: Cache::load_or_create_from(
136 path.as_ref(),
137 auto_persistent,
138 self.cache_lifetime,
139 self.shrink_interval,
140 ),
141
142 #[cfg(feature = "async")]
143 async_cache: AsyncCache::load_or_create_from(
144 path.as_ref(),
145 auto_persistent,
146 self.cache_lifetime,
147 self.shrink_interval,
148 ),
149 }
150 }
151
152 #[inline]
156 pub fn default_load_or_create_from(self, auto_persistent: bool) -> CachedResolver<R> {
157 CachedResolver {
158 resolver: Arc::new(self.resolver),
159 cache: Cache::load_or_create_from(
160 &CachedResolver::<R>::default_persistent_path(),
161 auto_persistent,
162 self.cache_lifetime,
163 self.shrink_interval,
164 ),
165
166 #[cfg(feature = "async")]
167 async_cache: AsyncCache::load_or_create_from(
168 &CachedResolver::<R>::default_persistent_path(),
169 auto_persistent,
170 self.cache_lifetime,
171 self.shrink_interval,
172 ),
173 }
174 }
175
176 #[inline]
180 pub fn in_memory(self) -> CachedResolver<R> {
181 CachedResolver {
182 resolver: Arc::new(self.resolver),
183 cache: Cache::in_memory(self.cache_lifetime, self.shrink_interval),
184
185 #[cfg(feature = "async")]
186 async_cache: AsyncCache::in_memory(self.cache_lifetime, self.shrink_interval),
187 }
188 }
189}
190
#[cfg(test)]
mod tests {
    use super::*;
    use dashmap::DashMap;
    use std::{
        collections::HashMap,
        error::Error,
        fs::File,
        net::{IpAddr, Ipv4Addr},
        sync::Arc,
        thread::{sleep, spawn, JoinHandle},
    };
    use tap::tap::TapOptional;
    use tempfile::tempdir;

    /// Test backend answering from a fixed table and counting, per domain,
    /// how many times it was actually asked (only for domains it knows).
    #[derive(Debug, Clone, Default)]
    struct ResolverFromTable {
        table: HashMap<String, Box<[IpAddr]>>,
        resolved: DashMap<String, usize>,
    }

    impl ResolverFromTable {
        fn add(&mut self, domain: impl Into<String>, ip_addrs: Vec<IpAddr>) {
            self.table.insert(domain.into(), ip_addrs.into_boxed_slice());
        }

        fn resolved(&self, domain: impl AsRef<str>) -> Option<usize> {
            self.resolved.get(domain.as_ref()).map(|v| *v)
        }
    }

    impl Resolver for ResolverFromTable {
        fn resolve(&self, domain: &str, _opts: ResolveOptions) -> ResolveResult {
            let key = domain.to_owned();
            Ok(self
                .table
                .get(&key)
                .tap_some(|_| {
                    self.resolved
                        .entry(key)
                        .and_modify(|resolved| *resolved += 1)
                        .or_insert(1);
                })
                .cloned()
                .unwrap_or_default()
                .into())
        }

        #[cfg(feature = "async")]
        #[cfg_attr(feature = "docs", doc(cfg(feature = "async")))]
        fn async_resolve<'a>(&'a self, _domain: &'a str, _opts: ResolveOptions) -> BoxFuture<'a, ResolveResult> {
            unreachable!()
        }
    }

    /// Spawns `count` threads that each resolve `domain` through the shared
    /// `resolver` and assert the answer is exactly `[expected]`.
    ///
    /// The handles are collected eagerly so every thread is started before any
    /// of them is joined. Spawning from a lazy iterator that is consumed by a
    /// `join` loop would start and join the threads one at a time, which never
    /// exercises concurrent access to the cache.
    fn spawn_resolving_threads(
        resolver: &Arc<CachedResolver<ResolverFromTable>>,
        domain: &'static str,
        expected: IpAddr,
        count: usize,
    ) -> Vec<JoinHandle<()>> {
        (0..count)
            .map(|_| {
                let resolver = Arc::clone(resolver);
                spawn(move || {
                    let result = resolver.resolve(domain, Default::default()).unwrap();
                    assert_eq!(result.ip_addrs(), &[expected]);
                })
            })
            .collect()
    }

    #[test]
    fn test_thread_safe_cached_resolver() -> Result<(), Box<dyn Error>> {
        env_logger::builder().is_test(true).try_init().ok();

        let mut backend = ResolverFromTable::default();
        backend.add("test_domain_1.com", vec![IpAddr::V4(Ipv4Addr::new(192, 168, 0, 1))]);
        backend.add("test_domain_2.com", vec![IpAddr::V4(Ipv4Addr::new(192, 168, 0, 2))]);
        backend.add("test_domain_3.com", vec![IpAddr::V4(Ipv4Addr::new(192, 168, 0, 3))]);
        let resolver = Arc::new(
            CachedResolver::builder(backend)
                .cache_lifetime(Duration::from_secs(5))
                .in_memory(),
        );

        // All 15 threads are running before the first join below.
        let mut threads = spawn_resolving_threads(
            &resolver,
            "test_domain_1.com",
            IpAddr::V4(Ipv4Addr::new(192, 168, 0, 1)),
            3,
        );
        threads.extend(spawn_resolving_threads(
            &resolver,
            "test_domain_2.com",
            IpAddr::V4(Ipv4Addr::new(192, 168, 0, 2)),
            5,
        ));
        threads.extend(spawn_resolving_threads(
            &resolver,
            "test_domain_3.com",
            IpAddr::V4(Ipv4Addr::new(192, 168, 0, 3)),
            7,
        ));
        for thread in threads {
            thread.join().unwrap();
        }

        let resolver = Arc::try_unwrap(resolver).unwrap();
        assert_eq!(resolver.resolver.resolved("test_domain_1.com"), Some(1));
        assert_eq!(resolver.resolver.resolved("test_domain_2.com"), Some(1));
        assert_eq!(resolver.resolver.resolved("test_domain_3.com"), Some(1));
        Ok(())
    }

    #[test]
    fn test_resolver_cache() -> Result<(), Box<dyn Error>> {
        env_logger::builder().is_test(true).try_init().ok();

        let mut backend = ResolverFromTable::default();
        backend.add("test_domain_1.com", vec![IpAddr::V4(Ipv4Addr::new(192, 168, 0, 1))]);
        let resolver = CachedResolver::builder(backend)
            .cache_lifetime(Duration::from_secs(1))
            .in_memory();

        for _ in 0..5 {
            let result = resolver.resolve("test_domain_1.com", Default::default()).unwrap();
            assert_eq!(result.ip_addrs(), &[IpAddr::V4(Ipv4Addr::new(192, 168, 0, 1))]);
        }

        assert_eq!(resolver.resolver.resolved("test_domain_1.com"), Some(1));

        // Outlive the 1s cache lifetime so the next lookup hits the backend again.
        sleep(Duration::from_secs(2));

        for _ in 0..5 {
            let result = resolver.resolve("test_domain_1.com", Default::default()).unwrap();
            assert_eq!(result.ip_addrs(), &[IpAddr::V4(Ipv4Addr::new(192, 168, 0, 1))]);
            sleep(Duration::from_millis(50));
        }

        assert_eq!(resolver.resolver.resolved("test_domain_1.com"), Some(2));
        Ok(())
    }

    #[test]
    fn test_persistent_resolver() -> Result<(), Box<dyn Error>> {
        env_logger::builder().is_test(true).try_init().ok();

        let mut backend = ResolverFromTable::default();
        backend.add("test_domain_1.com", vec![IpAddr::V4(Ipv4Addr::new(192, 168, 0, 1))]);
        backend.add("test_domain_2.com", vec![IpAddr::V4(Ipv4Addr::new(192, 168, 0, 2))]);
        backend.add("test_domain_3.com", vec![IpAddr::V4(Ipv4Addr::new(192, 168, 0, 3))]);

        let tempdir = tempdir()?;
        let tempfile_path = {
            let mut path = tempdir.path().to_owned();
            path.push("resolve_result");
            path
        };

        {
            let resolver = CachedResolver::builder(backend.to_owned()).load_or_create_from(&tempfile_path, true);
            {
                let result = resolver.resolve("test_domain_1.com", Default::default()).unwrap();
                assert_eq!(result.ip_addrs(), &[IpAddr::V4(Ipv4Addr::new(192, 168, 0, 1))]);
            }
            {
                let result = resolver.resolve("test_domain_2.com", Default::default()).unwrap();
                assert_eq!(result.ip_addrs(), &[IpAddr::V4(Ipv4Addr::new(192, 168, 0, 2))]);
            }
            // Give auto-persistence time to write the file before checking it exists.
            sleep(Duration::from_secs(1));
            File::open(resolver.persistent_path().unwrap())?;
        }

        {
            let resolver = CachedResolver::builder(backend.to_owned()).load_or_create_from(&tempfile_path, true);
            {
                let result = resolver.resolve("test_domain_1.com", Default::default()).unwrap();
                assert_eq!(result.ip_addrs(), &[IpAddr::V4(Ipv4Addr::new(192, 168, 0, 1))]);
            }
            {
                let result = resolver.resolve("test_domain_2.com", Default::default()).unwrap();
                assert_eq!(result.ip_addrs(), &[IpAddr::V4(Ipv4Addr::new(192, 168, 0, 2))]);
            }
            {
                let result = resolver.resolve("test_domain_3.com", Default::default()).unwrap();
                assert_eq!(result.ip_addrs(), &[IpAddr::V4(Ipv4Addr::new(192, 168, 0, 3))]);
            }
            assert_eq!(resolver.resolver.resolved("test_domain_1.com"), None);
            assert_eq!(resolver.resolver.resolved("test_domain_2.com"), None);
            assert_eq!(resolver.resolver.resolved("test_domain_3.com"), Some(1));
        }

        sleep(Duration::from_secs(1));

        {
            let resolver = CachedResolver::builder(backend.to_owned()).load_or_create_from(&tempfile_path, true);
            {
                let result = resolver.resolve("test_domain_1.com", Default::default()).unwrap();
                assert_eq!(result.ip_addrs(), &[IpAddr::V4(Ipv4Addr::new(192, 168, 0, 1))]);
            }
            {
                let result = resolver.resolve("test_domain_2.com", Default::default()).unwrap();
                assert_eq!(result.ip_addrs(), &[IpAddr::V4(Ipv4Addr::new(192, 168, 0, 2))]);
            }
            {
                let result = resolver.resolve("test_domain_3.com", Default::default()).unwrap();
                assert_eq!(result.ip_addrs(), &[IpAddr::V4(Ipv4Addr::new(192, 168, 0, 3))]);
            }
            assert_eq!(resolver.resolver.resolved("test_domain_1.com"), None);
            assert_eq!(resolver.resolver.resolved("test_domain_2.com"), None);
            assert_eq!(resolver.resolver.resolved("test_domain_3.com"), None);
        }

        Ok(())
    }
}