polyfill_rs/
dns_cache.rs

1//! DNS caching to reduce lookup latency
2//!
//! This module caches DNS answers to avoid repeated lookups, each of which
//! can add 10–20 ms to a request.
5
6use std::collections::HashMap;
7use std::net::IpAddr;
8use std::sync::Arc;
9use std::time::{Duration, Instant};
10use tokio::sync::RwLock;
11use hickory_resolver::config::*;
12use hickory_resolver::TokioAsyncResolver;
13
/// One cached resolution: the addresses found for a hostname and the
/// instant from which they are considered stale.
#[derive(Clone, Debug)]
struct DnsCacheEntry {
    /// Addresses produced by the lookup, in resolver order.
    ips: Vec<IpAddr>,
    /// The entry is only served while the current time is strictly before this instant.
    expires_at: Instant,
}
20
21/// DNS cache for resolving hostnames
22pub struct DnsCache {
23    resolver: TokioAsyncResolver,
24    cache: Arc<RwLock<HashMap<String, DnsCacheEntry>>>,
25    default_ttl: Duration,
26}
27
28impl DnsCache {
29    /// Create a new DNS cache with system configuration
30    pub async fn new() -> Result<Self, Box<dyn std::error::Error>> {
31        let resolver =
32            TokioAsyncResolver::tokio(ResolverConfig::default(), ResolverOpts::default());
33
34        Ok(Self {
35            resolver,
36            cache: Arc::new(RwLock::new(HashMap::new())),
37            default_ttl: Duration::from_secs(300), // 5 minutes default TTL
38        })
39    }
40
41    /// Create a DNS cache with custom TTL
42    pub async fn with_ttl(ttl: Duration) -> Result<Self, Box<dyn std::error::Error>> {
43        let resolver =
44            TokioAsyncResolver::tokio(ResolverConfig::default(), ResolverOpts::default());
45
46        Ok(Self {
47            resolver,
48            cache: Arc::new(RwLock::new(HashMap::new())),
49            default_ttl: ttl,
50        })
51    }
52
53    /// Resolve a hostname, using cache if available
54    pub async fn resolve(&self, hostname: &str) -> Result<Vec<IpAddr>, Box<dyn std::error::Error>> {
55        // Check cache first
56        {
57            let cache = self.cache.read().await;
58            if let Some(entry) = cache.get(hostname) {
59                if entry.expires_at > Instant::now() {
60                    return Ok(entry.ips.clone());
61                }
62            }
63        }
64
65        // Cache miss or expired, do actual lookup
66        let lookup = self.resolver.lookup_ip(hostname).await?;
67        let ips: Vec<IpAddr> = lookup.iter().collect();
68
69        // Store in cache
70        let entry = DnsCacheEntry {
71            ips: ips.clone(),
72            expires_at: Instant::now() + self.default_ttl,
73        };
74
75        let mut cache = self.cache.write().await;
76        cache.insert(hostname.to_string(), entry);
77
78        Ok(ips)
79    }
80
81    /// Pre-warm the cache by resolving a hostname
82    pub async fn prewarm(&self, hostname: &str) -> Result<(), Box<dyn std::error::Error>> {
83        self.resolve(hostname).await?;
84        Ok(())
85    }
86
87    /// Clear the cache
88    pub async fn clear(&self) {
89        let mut cache = self.cache.write().await;
90        cache.clear();
91    }
92
93    /// Get cache size
94    pub async fn cache_size(&self) -> usize {
95        let cache = self.cache.read().await;
96        cache.len()
97    }
98}
99
#[cfg(test)]
mod tests {
    use super::*;

    /// Public hostname used by the network-dependent tests.
    const REMOTE_HOST: &str = "clob.polymarket.com";

    // The next three tests perform real DNS queries against a public upstream
    // resolver. They are ignored by default so `cargo test` stays hermetic;
    // run them explicitly with `cargo test -- --ignored`.

    #[tokio::test]
    #[ignore = "requires network access"]
    async fn test_dns_cache_resolve() {
        let cache = DnsCache::new().await.unwrap();
        let ips = cache.resolve(REMOTE_HOST).await.unwrap();
        assert!(!ips.is_empty());
    }

    #[tokio::test]
    #[ignore = "requires network access"]
    async fn test_dns_cache_prewarm() {
        let cache = DnsCache::new().await.unwrap();
        cache.prewarm(REMOTE_HOST).await.unwrap();
        assert_eq!(cache.cache_size().await, 1);
    }

    #[tokio::test]
    #[ignore = "requires network access"]
    async fn test_dns_cache_clear() {
        let cache = DnsCache::new().await.unwrap();
        cache.prewarm(REMOTE_HOST).await.unwrap();
        cache.clear().await;
        assert_eq!(cache.cache_size().await, 0);
    }

    #[tokio::test]
    async fn test_new_cache_is_empty() {
        let cache = DnsCache::new().await.unwrap();
        assert_eq!(cache.cache_size().await, 0);
    }

    // NOTE(review): relies on hickory's `lookup_ip` answering IP literals
    // directly without a network query — confirm for the pinned hickory version.
    #[tokio::test]
    async fn test_ip_literal_is_cached_and_cleared() {
        let cache = DnsCache::new().await.unwrap();
        let expected = vec![IpAddr::from([127, 0, 0, 1])];

        assert_eq!(cache.resolve("127.0.0.1").await.unwrap(), expected);
        assert_eq!(cache.cache_size().await, 1);

        // A second call must not add another entry.
        assert_eq!(cache.resolve("127.0.0.1").await.unwrap(), expected);
        assert_eq!(cache.cache_size().await, 1);

        cache.clear().await;
        assert_eq!(cache.cache_size().await, 0);
    }
}
125}