1use std::collections::HashMap;
7use std::net::IpAddr;
8use std::sync::Arc;
9use std::time::{Duration, Instant};
10use tokio::sync::RwLock;
11use hickory_resolver::config::*;
12use hickory_resolver::TokioAsyncResolver;
13
/// One cached lookup result: the resolved addresses plus the instant at which they go stale.
#[derive(Debug, Clone)]
struct DnsCacheEntry {
    /// Addresses collected from the resolver's `lookup_ip` answer.
    ips: Vec<IpAddr>,
    /// Freshness deadline; the entry is only served while `Instant::now()` is before this.
    expires_at: Instant,
}
20
21pub struct DnsCache {
23 resolver: TokioAsyncResolver,
24 cache: Arc<RwLock<HashMap<String, DnsCacheEntry>>>,
25 default_ttl: Duration,
26}
27
28impl DnsCache {
29 pub async fn new() -> Result<Self, Box<dyn std::error::Error>> {
31 let resolver =
32 TokioAsyncResolver::tokio(ResolverConfig::default(), ResolverOpts::default());
33
34 Ok(Self {
35 resolver,
36 cache: Arc::new(RwLock::new(HashMap::new())),
37 default_ttl: Duration::from_secs(300), })
39 }
40
41 pub async fn with_ttl(ttl: Duration) -> Result<Self, Box<dyn std::error::Error>> {
43 let resolver =
44 TokioAsyncResolver::tokio(ResolverConfig::default(), ResolverOpts::default());
45
46 Ok(Self {
47 resolver,
48 cache: Arc::new(RwLock::new(HashMap::new())),
49 default_ttl: ttl,
50 })
51 }
52
53 pub async fn resolve(&self, hostname: &str) -> Result<Vec<IpAddr>, Box<dyn std::error::Error>> {
55 {
57 let cache = self.cache.read().await;
58 if let Some(entry) = cache.get(hostname) {
59 if entry.expires_at > Instant::now() {
60 return Ok(entry.ips.clone());
61 }
62 }
63 }
64
65 let lookup = self.resolver.lookup_ip(hostname).await?;
67 let ips: Vec<IpAddr> = lookup.iter().collect();
68
69 let entry = DnsCacheEntry {
71 ips: ips.clone(),
72 expires_at: Instant::now() + self.default_ttl,
73 };
74
75 let mut cache = self.cache.write().await;
76 cache.insert(hostname.to_string(), entry);
77
78 Ok(ips)
79 }
80
81 pub async fn prewarm(&self, hostname: &str) -> Result<(), Box<dyn std::error::Error>> {
83 self.resolve(hostname).await?;
84 Ok(())
85 }
86
87 pub async fn clear(&self) {
89 let mut cache = self.cache.write().await;
90 cache.clear();
91 }
92
93 pub async fn cache_size(&self) -> usize {
95 let cache = self.cache.read().await;
96 cache.len()
97 }
98}
99
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::Ipv4Addr;

    /// IP literal used by the offline tests. Hickory's `lookup_ip` is expected to answer IP
    /// literals directly without contacting a DNS server. NOTE(review): confirm this still
    /// holds if the resolver options (e.g. `ndots`) are changed from their defaults.
    const LOOPBACK: &str = "127.0.0.1";

    #[tokio::test]
    async fn resolves_ip_literal_offline() {
        let cache = DnsCache::new().await.unwrap();
        let ips = cache.resolve(LOOPBACK).await.unwrap();
        assert_eq!(ips, vec![IpAddr::V4(Ipv4Addr::LOCALHOST)]);
    }

    #[tokio::test]
    async fn repeated_resolve_reuses_single_entry() {
        let cache = DnsCache::new().await.unwrap();
        let first = cache.resolve(LOOPBACK).await.unwrap();
        let second = cache.resolve(LOOPBACK).await.unwrap();
        assert_eq!(first, second);
        assert_eq!(cache.cache_size().await, 1);
    }

    #[tokio::test]
    async fn prewarm_populates_cache_offline() {
        let cache = DnsCache::with_ttl(Duration::from_secs(60)).await.unwrap();
        cache.prewarm(LOOPBACK).await.unwrap();
        assert_eq!(cache.cache_size().await, 1);
    }

    #[tokio::test]
    async fn clear_empties_cache_offline() {
        let cache = DnsCache::new().await.unwrap();
        cache.prewarm(LOOPBACK).await.unwrap();
        cache.clear().await;
        assert_eq!(cache.cache_size().await, 0);
    }

    // The tests below query a real public hostname, so they need working DNS and network
    // access. They are ignored by default to keep CI and offline runs deterministic; run
    // them with `cargo test -- --ignored`.

    #[tokio::test]
    #[ignore = "requires network access"]
    async fn test_dns_cache_resolve() {
        let cache = DnsCache::new().await.unwrap();
        let ips = cache.resolve("clob.polymarket.com").await.unwrap();
        assert!(!ips.is_empty());
    }

    #[tokio::test]
    #[ignore = "requires network access"]
    async fn test_dns_cache_prewarm() {
        let cache = DnsCache::new().await.unwrap();
        cache.prewarm("clob.polymarket.com").await.unwrap();
        assert_eq!(cache.cache_size().await, 1);
    }

    #[tokio::test]
    #[ignore = "requires network access"]
    async fn test_dns_cache_clear() {
        let cache = DnsCache::new().await.unwrap();
        cache.prewarm("clob.polymarket.com").await.unwrap();
        cache.clear().await;
        assert_eq!(cache.cache_size().await, 0);
    }
}