ribbit_client/
dns_cache.rs

1//! DNS caching for Ribbit client connections
2
3use std::collections::HashMap;
4use std::net::{IpAddr, SocketAddr, ToSocketAddrs};
5use std::sync::Arc;
6use std::time::{Duration, Instant};
7use tokio::sync::RwLock;
8
/// Lifetime, in seconds, given to cached DNS entries when no TTL is specified (five minutes).
const DEFAULT_DNS_TTL_SECS: u64 = 5 * 60;
11
/// One hostname's resolution result as stored in the cache.
#[derive(Clone, Debug)]
struct DnsEntry {
    /// IP addresses produced by the resolver; ports are not stored and are
    /// supplied per lookup.
    addresses: Vec<IpAddr>,
    /// Point in time at or after which this entry is considered stale.
    expires_at: Instant,
}
20
21/// DNS cache for resolving hostnames to IP addresses
22#[derive(Debug, Clone)]
23pub struct DnsCache {
24    /// Cache storage
25    cache: Arc<RwLock<HashMap<String, DnsEntry>>>,
26    /// TTL for cache entries
27    ttl: Duration,
28}
29
30impl DnsCache {
31    /// Create a new DNS cache with default TTL
32    #[must_use]
33    pub fn new() -> Self {
34        Self::with_ttl(Duration::from_secs(DEFAULT_DNS_TTL_SECS))
35    }
36
37    /// Create a new DNS cache with specified TTL
38    #[must_use]
39    pub fn with_ttl(ttl: Duration) -> Self {
40        Self {
41            cache: Arc::new(RwLock::new(HashMap::new())),
42            ttl,
43        }
44    }
45
46    /// Resolve a hostname, using cache if available
47    /// Resolve hostname and port to socket addresses with caching
48    ///
49    /// # Errors
50    /// Returns an error if DNS resolution fails
51    pub async fn resolve(&self, hostname: &str, port: u16) -> std::io::Result<Vec<SocketAddr>> {
52        // Check cache first
53        {
54            let cache = self.cache.read().await;
55            if let Some(entry) = cache.get(hostname) {
56                if entry.expires_at > Instant::now() {
57                    // Cache hit - return cached addresses
58                    let socket_addrs: Vec<SocketAddr> = entry
59                        .addresses
60                        .iter()
61                        .map(|&ip| SocketAddr::new(ip, port))
62                        .collect();
63                    return Ok(socket_addrs);
64                }
65            }
66        }
67
68        // Cache miss or expired - resolve and update cache
69        self.resolve_and_cache(hostname, port).await
70    }
71
72    /// Resolve hostname and update cache
73    async fn resolve_and_cache(
74        &self,
75        hostname: &str,
76        port: u16,
77    ) -> std::io::Result<Vec<SocketAddr>> {
78        // Perform DNS resolution synchronously in a blocking task
79        let hostname_string = hostname.to_string();
80        let addrs = tokio::task::spawn_blocking(move || {
81            format!("{hostname_string}:{port}").to_socket_addrs()
82        })
83        .await
84        .map_err(std::io::Error::other)??;
85
86        let socket_addrs: Vec<SocketAddr> = addrs.collect();
87
88        if !socket_addrs.is_empty() {
89            // Extract IP addresses and cache them
90            let ip_addrs: Vec<IpAddr> = socket_addrs.iter().map(std::net::SocketAddr::ip).collect();
91
92            let entry = DnsEntry {
93                addresses: ip_addrs.clone(),
94                expires_at: Instant::now() + self.ttl,
95            };
96
97            let mut cache = self.cache.write().await;
98            cache.insert(hostname.to_string(), entry);
99        }
100
101        Ok(socket_addrs)
102    }
103
104    /// Clear the DNS cache
105    pub async fn clear(&self) {
106        let mut cache = self.cache.write().await;
107        cache.clear();
108    }
109
110    /// Remove expired entries from the cache
111    pub async fn remove_expired(&self) {
112        let mut cache = self.cache.write().await;
113        let now = Instant::now();
114        cache.retain(|_, entry| entry.expires_at > now);
115    }
116
117    /// Get the number of cached entries
118    pub async fn size(&self) -> usize {
119        let cache = self.cache.read().await;
120        cache.len()
121    }
122}
123
124impl Default for DnsCache {
125    fn default() -> Self {
126        Self::new()
127    }
128}
129
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly constructed cache holds no entries.
    #[tokio::test]
    async fn test_dns_cache_creation() {
        let dns = DnsCache::new();
        assert_eq!(dns.size().await, 0);
    }

    /// A cache built with a custom TTL also starts empty.
    #[tokio::test]
    async fn test_dns_cache_with_ttl() {
        let dns = DnsCache::with_ttl(Duration::from_secs(60));
        assert_eq!(dns.size().await, 0);
    }

    /// Resolving populates the cache; a repeat lookup yields the same addresses.
    #[tokio::test]
    async fn test_dns_resolution() {
        let dns = DnsCache::new();

        // `localhost` is expected to resolve on any test machine.
        let first = dns.resolve("localhost", 80).await.unwrap();
        assert!(!first.is_empty());

        // The lookup above should have stored exactly one entry.
        assert_eq!(dns.size().await, 1);

        // This call should be answered from the cached entry.
        let second = dns.resolve("localhost", 80).await.unwrap();
        assert_eq!(first, second);
    }

    /// Entries older than the TTL are purged by `remove_expired`.
    #[tokio::test]
    async fn test_cache_expiration() {
        let dns = DnsCache::with_ttl(Duration::from_millis(10));

        let _resolved = dns.resolve("localhost", 80).await.unwrap();
        assert_eq!(dns.size().await, 1);

        // Outlive the 10 ms TTL, then purge.
        tokio::time::sleep(Duration::from_millis(20)).await;
        dns.remove_expired().await;
        assert_eq!(dns.size().await, 0);
    }

    /// `clear` empties a populated cache.
    #[tokio::test]
    async fn test_clear_cache() {
        let dns = DnsCache::new();

        let _by_name = dns.resolve("localhost", 80).await.unwrap();
        let _by_ip = dns.resolve("127.0.0.1", 80).await.unwrap();
        assert!(dns.size().await >= 1);

        dns.clear().await;
        assert_eq!(dns.size().await, 0);
    }
}