1use std::net::{SocketAddr, ToSocketAddrs};
2use std::sync::Arc;
3use std::time::{Duration, Instant};
4
5use hashbrown::HashMap;
6use reqwest::dns::{Addrs, Name, Resolve, Resolving};
7use tokio::sync::RwLock;
8
/// Type-erased error that is `Send + Sync`, so it can be carried across
/// threads and async task boundaries.
type DynErr = Box<dyn std::error::Error + Sync + Send>;
10
11const DEFAULT_DNS_CACHE_TTL_SECS: u64 = 5;
12
13pub(crate) fn get_dns_cache_ttl() -> Duration {
14 let ttl = Duration::from_secs(
15 std::env::var("POLARS_DNS_CACHE_TTL_SECS")
16 .ok()
17 .and_then(|s| s.parse::<u64>().ok())
18 .unwrap_or(DEFAULT_DNS_CACHE_TTL_SECS),
19 );
20
21 if polars_config::config().verbose() {
22 eprintln!("[dns_cache] ttl: {}s", ttl.as_secs());
23 }
24
25 ttl
26}
27
/// One cache entry: the addresses a host name resolved to, plus when the lookup finished.
#[derive(Debug)]
struct CachedAddrs {
    /// Resolved socket addresses, shared so cache hits avoid copying the list.
    addrs: Arc<Vec<SocketAddr>>,
    /// Moment the lookup completed; compared against the TTL to decide freshness.
    fetched_at: Instant,
}
33
34#[derive(Clone, Debug)]
36pub struct CachingResolver {
37 cache: Arc<RwLock<HashMap<String, CachedAddrs>>>,
38 ttl: Duration,
41}
42
43impl CachingResolver {
44 pub fn new(ttl: Duration) -> Self {
45 Self {
46 cache: Arc::new(RwLock::default()),
47 ttl,
48 }
49 }
50}
51
52impl Resolve for CachingResolver {
53 fn resolve(&self, name: Name) -> Resolving {
54 let cache = self.cache.clone();
55 let ttl = self.ttl;
56 let key = name.as_str().to_string();
57
58 Box::pin(async move {
59 {
60 let read_guard = cache.read().await;
61
62 if let Some(entry) = read_guard.get(&key) {
63 if entry.fetched_at.elapsed() < ttl {
64 return Ok(shuffle_addrs(&entry.addrs));
65 }
66 }
67 }
68
69 let key_clone = key.clone();
71 let mut write_guard = cache.write().await;
72
73 if let Some(entry) = write_guard.get(&key) {
75 if entry.fetched_at.elapsed() < ttl {
76 return Ok(shuffle_addrs(&entry.addrs));
77 }
78 }
79
80 let addrs = Arc::new(
81 polars_core::runtime::ASYNC
82 .spawn_blocking(move || {
83 (key_clone.as_str(), 0u16)
84 .to_socket_addrs()
85 .map(|it| it.collect::<Vec<_>>())
86 })
87 .await
88 .map_err(DynErr::from)??,
89 );
90
91 write_guard.insert(
92 key,
93 CachedAddrs {
94 addrs: addrs.clone(),
95 fetched_at: Instant::now(),
96 },
97 );
98 drop(write_guard);
99
100 Ok(shuffle_addrs(&addrs))
101 })
102 }
103}
104
105fn shuffle_addrs(addrs: &Arc<Vec<SocketAddr>>) -> Addrs {
106 let mut indices: Vec<usize> = (0..addrs.len()).collect();
107 fastrand::shuffle(&mut indices);
108 let addrs = addrs.clone();
109 Box::new(indices.into_iter().map(move |i| addrs[i]))
110}