Skip to main content

polars_io/cloud/
dns.rs

1use std::net::{SocketAddr, ToSocketAddrs};
2use std::sync::Arc;
3use std::time::{Duration, Instant};
4
5use hashbrown::HashMap;
6use reqwest::dns::{Addrs, Name, Resolve, Resolving};
7use tokio::sync::RwLock;
8
/// Boxed, thread-safe error type; used to lift a failed `spawn_blocking` join
/// into the resolver's error type.
type DynErr = Box<dyn std::error::Error + Sync + Send>;

/// DNS cache TTL, in seconds, used when `POLARS_DNS_CACHE_TTL_SECS` is unset
/// or cannot be parsed as a `u64`.
const DEFAULT_DNS_CACHE_TTL_SECS: u64 = 5;
12
13pub(crate) fn get_dns_cache_ttl() -> Duration {
14    let ttl = Duration::from_secs(
15        std::env::var("POLARS_DNS_CACHE_TTL_SECS")
16            .ok()
17            .and_then(|s| s.parse::<u64>().ok())
18            .unwrap_or(DEFAULT_DNS_CACHE_TTL_SECS),
19    );
20
21    if polars_config::config().verbose() {
22        eprintln!("[dns_cache] ttl: {}s", ttl.as_secs());
23    }
24
25    ttl
26}
27
/// A single cached DNS lookup result.
#[derive(Debug)]
struct CachedAddrs {
    /// Resolved socket addresses. Kept behind an `Arc` so that serving a cache
    /// hit only bumps a reference count instead of copying the list.
    addrs: Arc<Vec<SocketAddr>>,
    /// Moment the lookup completed; compared against the resolver's TTL.
    fetched_at: Instant,
}
33
34/// Shuffle resolver with basic DNS cache. TTL is fixed and set by the calling site.
35#[derive(Clone, Debug)]
36pub struct CachingResolver {
37    cache: Arc<RwLock<HashMap<String, CachedAddrs>>>,
38    // Since the OS does not return the TTL as provided by DNS, the calling site
39    // is responsible for providing one.
40    ttl: Duration,
41}
42
43impl CachingResolver {
44    pub fn new(ttl: Duration) -> Self {
45        Self {
46            cache: Arc::new(RwLock::default()),
47            ttl,
48        }
49    }
50}
51
52impl Resolve for CachingResolver {
53    fn resolve(&self, name: Name) -> Resolving {
54        let cache = self.cache.clone();
55        let ttl = self.ttl;
56        let key = name.as_str().to_string();
57
58        Box::pin(async move {
59            {
60                let read_guard = cache.read().await;
61
62                if let Some(entry) = read_guard.get(&key) {
63                    if entry.fetched_at.elapsed() < ttl {
64                        return Ok(shuffle_addrs(&entry.addrs));
65                    }
66                }
67            }
68
69            // Cache miss or expired
70            let key_clone = key.clone();
71            let mut write_guard = cache.write().await;
72
73            // Re-check in case the cache has been populated in the meanwhile
74            if let Some(entry) = write_guard.get(&key) {
75                if entry.fetched_at.elapsed() < ttl {
76                    return Ok(shuffle_addrs(&entry.addrs));
77                }
78            }
79
80            let addrs = Arc::new(
81                polars_core::runtime::ASYNC
82                    .spawn_blocking(move || {
83                        (key_clone.as_str(), 0u16)
84                            .to_socket_addrs()
85                            .map(|it| it.collect::<Vec<_>>())
86                    })
87                    .await
88                    .map_err(DynErr::from)??,
89            );
90
91            write_guard.insert(
92                key,
93                CachedAddrs {
94                    addrs: addrs.clone(),
95                    fetched_at: Instant::now(),
96                },
97            );
98            drop(write_guard);
99
100            Ok(shuffle_addrs(&addrs))
101        })
102    }
103}
104
105fn shuffle_addrs(addrs: &Arc<Vec<SocketAddr>>) -> Addrs {
106    let mut indices: Vec<usize> = (0..addrs.len()).collect();
107    fastrand::shuffle(&mut indices);
108    let addrs = addrs.clone();
109    Box::new(indices.into_iter().map(move |i| addrs[i]))
110}