1use std::collections::HashMap;
7use std::net::IpAddr;
8use std::sync::Arc;
9use std::time::{Duration, Instant};
10use tokio::sync::RwLock;
11use trust_dns_resolver::TokioAsyncResolver;
12use trust_dns_resolver::config::*;
13
/// One memoized DNS answer: the addresses that were resolved and the
/// instant after which they must no longer be served from the cache.
#[derive(Debug, Clone)]
struct DnsCacheEntry {
    /// Addresses produced by the resolver for the cached hostname.
    ips: Vec<IpAddr>,
    /// Point in time at which this entry becomes stale.
    expires_at: Instant,
}
20
21pub struct DnsCache {
23 resolver: TokioAsyncResolver,
24 cache: Arc<RwLock<HashMap<String, DnsCacheEntry>>>,
25 default_ttl: Duration,
26}
27
28impl DnsCache {
29 pub async fn new() -> Result<Self, Box<dyn std::error::Error>> {
31 let resolver = TokioAsyncResolver::tokio(
32 ResolverConfig::default(),
33 ResolverOpts::default(),
34 );
35
36 Ok(Self {
37 resolver,
38 cache: Arc::new(RwLock::new(HashMap::new())),
39 default_ttl: Duration::from_secs(300), })
41 }
42
43 pub async fn with_ttl(ttl: Duration) -> Result<Self, Box<dyn std::error::Error>> {
45 let resolver = TokioAsyncResolver::tokio(
46 ResolverConfig::default(),
47 ResolverOpts::default(),
48 );
49
50 Ok(Self {
51 resolver,
52 cache: Arc::new(RwLock::new(HashMap::new())),
53 default_ttl: ttl,
54 })
55 }
56
57 pub async fn resolve(&self, hostname: &str) -> Result<Vec<IpAddr>, Box<dyn std::error::Error>> {
59 {
61 let cache = self.cache.read().await;
62 if let Some(entry) = cache.get(hostname) {
63 if entry.expires_at > Instant::now() {
64 return Ok(entry.ips.clone());
65 }
66 }
67 }
68
69 let lookup = self.resolver.lookup_ip(hostname).await?;
71 let ips: Vec<IpAddr> = lookup.iter().collect();
72
73 let entry = DnsCacheEntry {
75 ips: ips.clone(),
76 expires_at: Instant::now() + self.default_ttl,
77 };
78
79 let mut cache = self.cache.write().await;
80 cache.insert(hostname.to_string(), entry);
81
82 Ok(ips)
83 }
84
85 pub async fn prewarm(&self, hostname: &str) -> Result<(), Box<dyn std::error::Error>> {
87 self.resolve(hostname).await?;
88 Ok(())
89 }
90
91 pub async fn clear(&self) {
93 let mut cache = self.cache.write().await;
94 cache.clear();
95 }
96
97 pub async fn cache_size(&self) -> usize {
99 let cache = self.cache.read().await;
100 cache.len()
101 }
102}
103
#[cfg(test)]
mod tests {
    use super::*;

    // These tests resolve a real public hostname, so they need outbound
    // network access to pass.
    const HOST: &str = "clob.polymarket.com";

    /// Builds a cache with default settings, panicking if construction fails.
    async fn fresh_cache() -> DnsCache {
        DnsCache::new().await.unwrap()
    }

    #[tokio::test]
    async fn test_dns_cache_resolve() {
        let resolved = fresh_cache().await.resolve(HOST).await.unwrap();
        assert!(!resolved.is_empty());
    }

    #[tokio::test]
    async fn test_dns_cache_prewarm() {
        let dns = fresh_cache().await;
        dns.prewarm(HOST).await.unwrap();

        let size = dns.cache_size().await;
        assert_eq!(size, 1);
    }

    #[tokio::test]
    async fn test_dns_cache_clear() {
        let dns = fresh_cache().await;
        dns.prewarm(HOST).await.unwrap();
        dns.clear().await;

        let size = dns.cache_size().await;
        assert_eq!(size, 0);
    }
}
130