polyfill_rs/
dns_cache.rs

1//! DNS caching to reduce lookup latency
2//!
3//! This module provides DNS caching functionality to avoid repeated DNS lookups
4//! which can add 10-20ms per request.
5
6use std::collections::HashMap;
7use std::net::IpAddr;
8use std::sync::Arc;
9use std::time::{Duration, Instant};
10use tokio::sync::RwLock;
11use trust_dns_resolver::TokioAsyncResolver;
12use trust_dns_resolver::config::*;
13
/// One cached DNS answer: the resolved addresses plus the instant at which
/// they stop being eligible to be served.
#[derive(Clone, Debug)]
struct DnsCacheEntry {
    /// Addresses collected from the resolver lookup.
    ips: Vec<IpAddr>,
    /// Deadline after which this entry is treated as stale and re-resolved.
    expires_at: Instant,
}
20
21/// DNS cache for resolving hostnames
22pub struct DnsCache {
23    resolver: TokioAsyncResolver,
24    cache: Arc<RwLock<HashMap<String, DnsCacheEntry>>>,
25    default_ttl: Duration,
26}
27
28impl DnsCache {
29    /// Create a new DNS cache with system configuration
30    pub async fn new() -> Result<Self, Box<dyn std::error::Error>> {
31        let resolver = TokioAsyncResolver::tokio(
32            ResolverConfig::default(),
33            ResolverOpts::default(),
34        );
35
36        Ok(Self {
37            resolver,
38            cache: Arc::new(RwLock::new(HashMap::new())),
39            default_ttl: Duration::from_secs(300), // 5 minutes default TTL
40        })
41    }
42
43    /// Create a DNS cache with custom TTL
44    pub async fn with_ttl(ttl: Duration) -> Result<Self, Box<dyn std::error::Error>> {
45        let resolver = TokioAsyncResolver::tokio(
46            ResolverConfig::default(),
47            ResolverOpts::default(),
48        );
49
50        Ok(Self {
51            resolver,
52            cache: Arc::new(RwLock::new(HashMap::new())),
53            default_ttl: ttl,
54        })
55    }
56
57    /// Resolve a hostname, using cache if available
58    pub async fn resolve(&self, hostname: &str) -> Result<Vec<IpAddr>, Box<dyn std::error::Error>> {
59        // Check cache first
60        {
61            let cache = self.cache.read().await;
62            if let Some(entry) = cache.get(hostname) {
63                if entry.expires_at > Instant::now() {
64                    return Ok(entry.ips.clone());
65                }
66            }
67        }
68
69        // Cache miss or expired, do actual lookup
70        let lookup = self.resolver.lookup_ip(hostname).await?;
71        let ips: Vec<IpAddr> = lookup.iter().collect();
72
73        // Store in cache
74        let entry = DnsCacheEntry {
75            ips: ips.clone(),
76            expires_at: Instant::now() + self.default_ttl,
77        };
78
79        let mut cache = self.cache.write().await;
80        cache.insert(hostname.to_string(), entry);
81
82        Ok(ips)
83    }
84
85    /// Pre-warm the cache by resolving a hostname
86    pub async fn prewarm(&self, hostname: &str) -> Result<(), Box<dyn std::error::Error>> {
87        self.resolve(hostname).await?;
88        Ok(())
89    }
90
91    /// Clear the cache
92    pub async fn clear(&self) {
93        let mut cache = self.cache.write().await;
94        cache.clear();
95    }
96
97    /// Get cache size
98    pub async fn cache_size(&self) -> usize {
99        let cache = self.cache.read().await;
100        cache.len()
101    }
102}
103
#[cfg(test)]
mod tests {
    use super::*;

    // Tests against a real hostname need working DNS and network access, so they
    // are ignored by default; run them with `cargo test -- --ignored`.

    #[tokio::test]
    #[ignore = "requires network access"]
    async fn test_dns_cache_resolve() {
        let cache = DnsCache::new().await.unwrap();
        let ips = cache.resolve("clob.polymarket.com").await.unwrap();
        assert!(!ips.is_empty());
    }

    #[tokio::test]
    #[ignore = "requires network access"]
    async fn test_dns_cache_prewarm() {
        let cache = DnsCache::new().await.unwrap();
        cache.prewarm("clob.polymarket.com").await.unwrap();
        assert_eq!(cache.cache_size().await, 1);
    }

    #[tokio::test]
    #[ignore = "requires network access"]
    async fn test_dns_cache_clear() {
        let cache = DnsCache::new().await.unwrap();
        cache.prewarm("clob.polymarket.com").await.unwrap();
        cache.clear().await;
        assert_eq!(cache.cache_size().await, 0);
    }

    // Offline tests: trust-dns answers IP literals directly, without a network
    // round-trip. NOTE(review): this relies on the default `ndots` (<= 4);
    // confirm for the pinned trust-dns version.

    #[tokio::test]
    async fn test_dns_cache_ip_literal_is_cached() {
        let cache = DnsCache::new().await.unwrap();
        let expected = vec![IpAddr::from([127, 0, 0, 1])];

        assert_eq!(cache.resolve("127.0.0.1").await.unwrap(), expected);
        assert_eq!(cache.cache_size().await, 1);

        // A repeated lookup returns the same answer and adds no second entry.
        assert_eq!(cache.resolve("127.0.0.1").await.unwrap(), expected);
        assert_eq!(cache.cache_size().await, 1);
    }

    #[tokio::test]
    async fn test_dns_cache_clear_offline() {
        let cache = DnsCache::new().await.unwrap();
        cache.prewarm("127.0.0.1").await.unwrap();
        assert_eq!(cache.cache_size().await, 1);
        cache.clear().await;
        assert_eq!(cache.cache_size().await, 0);
    }
}
130