// essence/utils/dns_cache.rs

//! DNS caching for improved latency
//!
//! This module provides DNS caching using hickory-resolver (formerly trust-dns)
//! so that repeated requests to the same host skip the DNS round trip
//! (typically 10-50 ms). LRU eviction bounds the cache size to prevent
//! unbounded memory growth.

use crate::error::{Result, ScrapeError};
use hickory_resolver::config::{ResolverConfig, ResolverOpts};
use hickory_resolver::TokioAsyncResolver;
use lru::LruCache;
use std::net::IpAddr;
use std::num::NonZeroUsize;
use std::sync::Arc;
use std::time::Instant;
use tokio::sync::Mutex;
use tracing::{debug, trace, warn};

17/// DNS cache with LRU eviction
18///
19/// Caches DNS lookups to reduce latency. Thread-safe via Arc<Mutex<...>>.
20#[derive(Clone)]
21pub struct DnsCache {
22    resolver: Arc<TokioAsyncResolver>,
23    cache: Arc<Mutex<LruCache<String, Vec<IpAddr>>>>,
24    stats: Arc<Mutex<CacheStats>>,
25}
26
27#[derive(Debug, Default, Clone)]
28pub struct CacheStats {
29    pub hits: u64,
30    pub misses: u64,
31    pub lookups: u64,
32}
33
34impl CacheStats {
35    pub fn hit_rate(&self) -> f64 {
36        if self.lookups == 0 {
37            0.0
38        } else {
39            self.hits as f64 / self.lookups as f64
40        }
41    }
42}
43
44impl DnsCache {
45    /// Create a new DNS cache with default capacity (1000 entries)
46    pub fn new() -> Result<Self> {
47        Self::with_capacity(1000)
48    }
49
50    /// Create a new DNS cache with specified capacity
51    pub fn with_capacity(capacity: usize) -> Result<Self> {
52        // Use system DNS configuration
53        let resolver = TokioAsyncResolver::tokio(
54            ResolverConfig::default(),
55            ResolverOpts::default(),
56        );
57
58        Ok(Self {
59            resolver: Arc::new(resolver),
60            cache: Arc::new(Mutex::new(
61                LruCache::new(
62                    NonZeroUsize::new(capacity).unwrap_or(NonZeroUsize::new(1000).unwrap())
63                ),
64            )),
65            stats: Arc::new(Mutex::new(CacheStats::default())),
66        })
67    }
68
69    /// Lookup a domain, using cache if available
70    ///
71    /// Returns a list of IP addresses for the domain. The first IP is typically
72    /// the preferred address for connection.
73    pub async fn lookup(&self, domain: &str) -> Result<Vec<IpAddr>> {
74        // Update stats
75        {
76            let mut stats = self.stats.lock().await;
77            stats.lookups += 1;
78        }
79
80        // Check cache first
81        {
82            let mut cache = self.cache.lock().await;
83            if let Some(ips) = cache.get(domain) {
84                trace!("DNS cache hit for domain: {}", domain);
85                let mut stats = self.stats.lock().await;
86                stats.hits += 1;
87                return Ok(ips.clone());
88            }
89        }
90
91        // Cache miss - perform actual DNS lookup
92        debug!("DNS cache miss for domain: {}, performing lookup", domain);
93        {
94            let mut stats = self.stats.lock().await;
95            stats.misses += 1;
96        }
97
98        let response = self
99            .resolver
100            .lookup_ip(domain)
101            .await
102            .map_err(|e| ScrapeError::Internal(format!("DNS lookup failed for {}: {}", domain, e)))?;
103
104        let ips: Vec<IpAddr> = response.iter().collect();
105
106        if ips.is_empty() {
107            return Err(ScrapeError::Internal(format!(
108                "No IP addresses found for domain: {}",
109                domain
110            )));
111        }
112
113        debug!("DNS lookup resolved {} to {} addresses", domain, ips.len());
114
115        // Store in cache
116        {
117            let mut cache = self.cache.lock().await;
118            cache.put(domain.to_string(), ips.clone());
119        }
120
121        Ok(ips)
122    }
123
124    /// Get cache statistics
125    pub async fn stats(&self) -> CacheStats {
126        self.stats.lock().await.clone()
127    }
128
129    /// Reset cache statistics
130    pub async fn reset_stats(&self) {
131        let mut stats = self.stats.lock().await;
132        *stats = CacheStats::default();
133    }
134
135    /// Clear the cache
136    pub async fn clear(&self) {
137        let mut cache = self.cache.lock().await;
138        cache.clear();
139        debug!("DNS cache cleared");
140    }
141
142    /// Get current cache size
143    pub async fn size(&self) -> usize {
144        let cache = self.cache.lock().await;
145        cache.len()
146    }
147}
148
149impl Default for DnsCache {
150    fn default() -> Self {
151        Self::new().expect("Failed to create default DNS cache")
152    }
153}
154
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_dns_cache_creation() {
        let cache = DnsCache::new();
        assert!(cache.is_ok());
    }

    #[tokio::test]
    async fn test_dns_cache_with_capacity() {
        let cache = DnsCache::with_capacity(500);
        assert!(cache.is_ok());
    }

    #[tokio::test]
    async fn test_dns_cache_with_zero_capacity() {
        // A zero capacity must not panic; the constructor substitutes a default.
        let cache = DnsCache::with_capacity(0);
        assert!(cache.is_ok());
    }

    #[tokio::test]
    #[ignore] // Requires network
    async fn test_dns_lookup_success() {
        let cache = DnsCache::new().unwrap();
        let result = cache.lookup("google.com").await;
        assert!(result.is_ok());
        let ips = result.unwrap();
        assert!(!ips.is_empty());
    }

    #[tokio::test]
    #[ignore] // Requires network
    async fn test_dns_cache_hit() {
        let cache = DnsCache::new().unwrap();

        // First lookup - cache miss
        let result1 = cache.lookup("google.com").await;
        assert!(result1.is_ok());

        // Second lookup - should be cache hit
        let result2 = cache.lookup("google.com").await;
        assert!(result2.is_ok());

        // Verify they return the same IPs
        assert_eq!(result1.unwrap(), result2.unwrap());

        // Check stats
        let stats = cache.stats().await;
        assert_eq!(stats.lookups, 2);
        assert_eq!(stats.hits, 1);
        assert_eq!(stats.misses, 1);
        assert_eq!(stats.hit_rate(), 0.5);
    }

    #[tokio::test]
    #[ignore] // Requires network
    async fn test_dns_cache_multiple_domains() {
        let cache = DnsCache::new().unwrap();

        // Lookup multiple domains
        let _ = cache.lookup("google.com").await;
        let _ = cache.lookup("github.com").await;
        let _ = cache.lookup("google.com").await; // cache hit
        let _ = cache.lookup("github.com").await; // cache hit

        let stats = cache.stats().await;
        assert_eq!(stats.lookups, 4);
        assert_eq!(stats.hits, 2);
        assert_eq!(stats.misses, 2);
    }

    #[tokio::test]
    async fn test_dns_cache_stats() {
        let cache = DnsCache::new().unwrap();
        let stats = cache.stats().await;
        assert_eq!(stats.hits, 0);
        assert_eq!(stats.misses, 0);
        assert_eq!(stats.lookups, 0);
        assert_eq!(stats.hit_rate(), 0.0);
    }

    #[tokio::test]
    async fn test_dns_cache_clear() {
        let cache = DnsCache::new().unwrap();
        cache.clear().await;
        let size = cache.size().await;
        assert_eq!(size, 0);
    }

    #[tokio::test]
    async fn test_dns_cache_reset_stats() {
        let cache = DnsCache::new().unwrap();

        // Manually update stats
        {
            let mut stats = cache.stats.lock().await;
            stats.hits = 10;
            stats.misses = 5;
            stats.lookups = 15;
        }

        // Reset
        cache.reset_stats().await;

        let stats = cache.stats().await;
        assert_eq!(stats.hits, 0);
        assert_eq!(stats.misses, 0);
        assert_eq!(stats.lookups, 0);
    }

    #[tokio::test]
    async fn test_invalid_domain() {
        let cache = DnsCache::new().unwrap();
        // `.invalid` is reserved by RFC 6761 and must never resolve. An
        // arbitrary made-up TLD could start resolving (or be hijacked by an
        // ISP resolver), which would make this test flaky.
        let result = cache.lookup("does.not.exist.invalid").await;
        assert!(result.is_err());

        // A failed lookup is counted as a miss and is not cached.
        let stats = cache.stats().await;
        assert_eq!(stats.lookups, 1);
        assert_eq!(stats.misses, 1);
        assert_eq!(stats.hits, 0);
        assert_eq!(cache.size().await, 0);
    }
}