ribbit_client/
dns_cache.rs1use std::collections::HashMap;
4use std::net::{IpAddr, SocketAddr, ToSocketAddrs};
5use std::sync::Arc;
6use std::time::{Duration, Instant};
7use tokio::sync::RwLock;
8
/// Default lifetime of a cached resolution, in seconds (five minutes).
const DEFAULT_DNS_TTL_SECS: u64 = 5 * 60;

/// One cached resolution: the IP addresses a hostname resolved to, plus the
/// moment after which they are no longer served from the cache.
#[derive(Clone, Debug)]
struct DnsEntry {
    /// Resolved addresses, kept without a port so lookups on any port can reuse them.
    addresses: Vec<IpAddr>,
    /// The entry counts as stale once `Instant::now()` is no longer before this point.
    expires_at: Instant,
}
20
21#[derive(Debug, Clone)]
23pub struct DnsCache {
24 cache: Arc<RwLock<HashMap<String, DnsEntry>>>,
26 ttl: Duration,
28}
29
30impl DnsCache {
31 #[must_use]
33 pub fn new() -> Self {
34 Self::with_ttl(Duration::from_secs(DEFAULT_DNS_TTL_SECS))
35 }
36
37 #[must_use]
39 pub fn with_ttl(ttl: Duration) -> Self {
40 Self {
41 cache: Arc::new(RwLock::new(HashMap::new())),
42 ttl,
43 }
44 }
45
46 pub async fn resolve(&self, hostname: &str, port: u16) -> std::io::Result<Vec<SocketAddr>> {
52 {
54 let cache = self.cache.read().await;
55 if let Some(entry) = cache.get(hostname) {
56 if entry.expires_at > Instant::now() {
57 let socket_addrs: Vec<SocketAddr> = entry
59 .addresses
60 .iter()
61 .map(|&ip| SocketAddr::new(ip, port))
62 .collect();
63 return Ok(socket_addrs);
64 }
65 }
66 }
67
68 self.resolve_and_cache(hostname, port).await
70 }
71
72 async fn resolve_and_cache(
74 &self,
75 hostname: &str,
76 port: u16,
77 ) -> std::io::Result<Vec<SocketAddr>> {
78 let hostname_string = hostname.to_string();
80 let addrs = tokio::task::spawn_blocking(move || {
81 format!("{hostname_string}:{port}").to_socket_addrs()
82 })
83 .await
84 .map_err(std::io::Error::other)??;
85
86 let socket_addrs: Vec<SocketAddr> = addrs.collect();
87
88 if !socket_addrs.is_empty() {
89 let ip_addrs: Vec<IpAddr> = socket_addrs.iter().map(std::net::SocketAddr::ip).collect();
91
92 let entry = DnsEntry {
93 addresses: ip_addrs.clone(),
94 expires_at: Instant::now() + self.ttl,
95 };
96
97 let mut cache = self.cache.write().await;
98 cache.insert(hostname.to_string(), entry);
99 }
100
101 Ok(socket_addrs)
102 }
103
104 pub async fn clear(&self) {
106 let mut cache = self.cache.write().await;
107 cache.clear();
108 }
109
110 pub async fn remove_expired(&self) {
112 let mut cache = self.cache.write().await;
113 let now = Instant::now();
114 cache.retain(|_, entry| entry.expires_at > now);
115 }
116
117 pub async fn size(&self) -> usize {
119 let cache = self.cache.read().await;
120 cache.len()
121 }
122}
123
124impl Default for DnsCache {
125 fn default() -> Self {
126 Self::new()
127 }
128}
129
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly constructed cache holds no entries.
    #[tokio::test]
    async fn test_dns_cache_creation() {
        assert_eq!(DnsCache::new().size().await, 0);
    }

    /// A cache built with a custom TTL also starts out empty.
    #[tokio::test]
    async fn test_dns_cache_with_ttl() {
        let cache = DnsCache::with_ttl(Duration::from_secs(60));
        assert_eq!(cache.size().await, 0);
    }

    /// Resolving a host creates one entry. A repeat lookup within the TTL is
    /// served from the cache and yields the same addresses.
    #[tokio::test]
    async fn test_dns_resolution() {
        let cache = DnsCache::new();

        let first = cache.resolve("localhost", 80).await.unwrap();
        assert!(!first.is_empty());
        assert_eq!(cache.size().await, 1);

        let second = cache.resolve("localhost", 80).await.unwrap();
        assert_eq!(first, second);
    }

    /// Entries older than the TTL are dropped by `remove_expired`.
    #[tokio::test]
    async fn test_cache_expiration() {
        let ttl = Duration::from_millis(10);
        let cache = DnsCache::with_ttl(ttl);

        cache.resolve("localhost", 80).await.unwrap();
        assert_eq!(cache.size().await, 1);

        // Wait twice the TTL so the entry is definitely stale.
        tokio::time::sleep(ttl * 2).await;
        cache.remove_expired().await;
        assert_eq!(cache.size().await, 0);
    }

    /// `clear` empties the cache however many hosts were resolved into it.
    #[tokio::test]
    async fn test_clear_cache() {
        let cache = DnsCache::new();

        for host in ["localhost", "127.0.0.1"] {
            cache.resolve(host, 80).await.unwrap();
        }
        assert!(cache.size().await >= 1);

        cache.clear().await;
        assert_eq!(cache.size().await, 0);
    }
}