1#![forbid(unsafe_code)]
9
10use std::collections::HashMap;
11use std::sync::{Arc, RwLock as StdRwLock};
12use std::time::{Duration, Instant};
13
14use ipnetwork::IpNetwork;
15use serde::Deserialize;
16use tokio::sync::RwLock;
17
18use rdap_types::error::{RdapError, Result};
19
20#[derive(Debug, Deserialize)]
24struct BootstrapFile {
25 #[allow(dead_code)]
26 version: String,
27 services: Vec<(Vec<String>, Vec<String>)>,
29}
30
/// One downloaded registry, flattened into `(pattern, server)` pairs, plus
/// the moment it was fetched.
#[derive(Debug)]
struct CacheEntry {
    /// Flattened `(pattern, server)` pairs for a single resource type.
    entries: Vec<(String, String)>,
    /// When the registry was downloaded; the basis for TTL expiry.
    fetched_at: Instant,
}

impl CacheEntry {
    /// Reports whether strictly more than `ttl` has elapsed since the fetch.
    fn is_expired(&self, ttl: Duration) -> bool {
        let age = self.fetched_at.elapsed();
        ttl < age
    }
}
45
46#[derive(Debug, Clone)]
53pub struct Bootstrap {
54 base_url: String,
55 client: reqwest::Client,
56 ttl: Duration,
57 cache: Arc<RwLock<HashMap<&'static str, CacheEntry>>>,
58 custom_servers: Arc<StdRwLock<HashMap<String, String>>>,
60}
61
62impl Bootstrap {
63 pub fn new(client: reqwest::Client) -> Self {
65 Self {
66 base_url: "https://data.iana.org/rdap".to_string(),
67 client,
68 ttl: Duration::from_secs(86_400),
69 cache: Arc::new(RwLock::new(HashMap::new())),
70 custom_servers: Arc::new(StdRwLock::new(HashMap::new())),
71 }
72 }
73
74 pub fn with_base_url(base_url: impl Into<String>, client: reqwest::Client) -> Self {
76 Self {
77 base_url: base_url.into().trim_end_matches('/').to_string(),
78 client,
79 ttl: Duration::from_secs(86_400),
80 cache: Arc::new(RwLock::new(HashMap::new())),
81 custom_servers: Arc::new(StdRwLock::new(HashMap::new())),
82 }
83 }
84
85 pub fn set_custom_servers(&mut self, servers: HashMap<String, String>) {
87 let mut guard = self.custom_servers.write().expect("lock poisoned");
88 *guard = servers
89 .into_iter()
90 .map(|(k, v)| (k.to_lowercase(), v))
91 .collect();
92 }
93
94 pub async fn for_domain(&self, domain: &str) -> Result<String> {
98 let tld = extract_tld(domain)?;
99 let tld_lower = tld.to_lowercase();
100
101 {
102 let custom = self.custom_servers.read().expect("lock poisoned");
103 if let Some(server) = custom.get(&tld_lower) {
104 return Ok(server.clone());
105 }
106 }
107
108 let entries = self.get_entries("dns").await?;
109
110 entries
111 .iter()
112 .find(|(pattern, _)| pattern.to_lowercase() == tld_lower)
113 .map(|(_, server)| server.clone())
114 .ok_or_else(|| RdapError::NoServerFound {
115 query: domain.to_string(),
116 })
117 }
118
119 pub async fn for_ipv4(&self, ip: &str) -> Result<String> {
121 let addr: std::net::IpAddr = ip
122 .parse()
123 .map_err(|_| RdapError::InvalidInput(format!("Invalid IPv4 address: {ip}")))?;
124
125 let entries = self.get_entries("ipv4").await?;
126 self.match_ip_entries(&entries, addr, ip)
127 }
128
129 pub async fn for_ipv6(&self, ip: &str) -> Result<String> {
131 let addr: std::net::IpAddr = ip
132 .parse()
133 .map_err(|_| RdapError::InvalidInput(format!("Invalid IPv6 address: {ip}")))?;
134
135 let entries = self.get_entries("ipv6").await?;
136 self.match_ip_entries(&entries, addr, ip)
137 }
138
139 pub async fn for_asn(&self, asn: u32) -> Result<String> {
141 let entries = self.get_entries("asn").await?;
142
143 for (pattern, server) in &entries {
144 if let Some((start, end)) = pattern.split_once('-') {
145 let start: u32 = start.parse().unwrap_or(u32::MAX);
146 let end: u32 = end.parse().unwrap_or(0);
147 if asn >= start && asn <= end {
148 return Ok(server.clone());
149 }
150 } else if let Ok(n) = pattern.parse::<u32>() {
151 if asn == n {
152 return Ok(server.clone());
153 }
154 }
155 }
156
157 Err(RdapError::NoServerFound {
158 query: format!("AS{asn}"),
159 })
160 }
161
162 pub async fn clear_cache(&self) {
164 self.cache.write().await.clear();
165 }
166
167 async fn get_entries(&self, resource: &'static str) -> Result<Vec<(String, String)>> {
170 {
171 let cache = self.cache.read().await;
172 if let Some(entry) = cache.get(resource) {
173 if !entry.is_expired(self.ttl) {
174 return Ok(entry.entries.clone());
175 }
176 }
177 }
178
179 let entries = self.fetch_entries(resource).await?;
180
181 let mut cache = self.cache.write().await;
182 cache.insert(
183 resource,
184 CacheEntry {
185 entries: entries.clone(),
186 fetched_at: Instant::now(),
187 },
188 );
189
190 Ok(entries)
191 }
192
193 async fn fetch_entries(&self, resource: &str) -> Result<Vec<(String, String)>> {
194 let url = format!("{}/{}.json", self.base_url, resource);
195
196 let response = self
197 .client
198 .get(&url)
199 .send()
200 .await
201 .map_err(RdapError::Network)?;
202
203 if !response.status().is_success() {
204 return Err(RdapError::HttpStatus {
205 status: response.status().as_u16(),
206 url,
207 });
208 }
209
210 let file: BootstrapFile = response.json().await.map_err(|e| RdapError::ParseError {
211 reason: e.to_string(),
212 })?;
213
214 let entries = file
215 .services
216 .into_iter()
217 .filter_map(|(patterns, servers)| {
218 let server = servers.into_iter().next()?;
219 let server = server.trim_end_matches('/').to_string();
220 Some(patterns.into_iter().map(move |p| (p, server.clone())))
221 })
222 .flatten()
223 .collect();
224
225 Ok(entries)
226 }
227
228 fn match_ip_entries(
229 &self,
230 entries: &[(String, String)],
231 addr: std::net::IpAddr,
232 original: &str,
233 ) -> Result<String> {
234 for (pattern, server) in entries {
235 if let Ok(network) = pattern.parse::<IpNetwork>() {
236 if network.contains(addr) {
237 return Ok(server.clone());
238 }
239 }
240 }
241 Err(RdapError::NoServerFound {
242 query: original.to_string(),
243 })
244 }
245}
246
247fn extract_tld(domain: &str) -> Result<String> {
250 let domain = domain.trim_end_matches('.').to_lowercase();
251
252 if domain.is_empty() {
253 return Err(RdapError::InvalidInput(
254 "Domain name must not be empty".to_string(),
255 ));
256 }
257
258 let parts: Vec<&str> = domain.split('.').collect();
259
260 match parts.len() {
261 0 => Err(RdapError::InvalidInput(
262 "Domain name must not be empty".to_string(),
263 )),
264 1 => Ok(parts[0].to_string()),
265 _ => Ok(parts.last().unwrap().to_string()),
266 }
267}
268
#[cfg(test)]
mod tests {
    //! Unit tests for `extract_tld`: normal names, case folding, trailing
    //! dots, and rejection of empty input.

    use super::extract_tld;

    #[test]
    fn extracts_simple_tld() {
        assert_eq!(extract_tld("example.com").unwrap(), "com");
        assert_eq!(extract_tld("google.org").unwrap(), "org");
    }

    #[test]
    fn extracts_from_subdomain() {
        assert_eq!(extract_tld("www.example.com").unwrap(), "com");
        assert_eq!(extract_tld("deep.sub.example.net").unwrap(), "net");
    }

    #[test]
    fn handles_single_label() {
        assert_eq!(extract_tld("com").unwrap(), "com");
    }

    #[test]
    fn lowercases_result() {
        assert_eq!(extract_tld("WWW.Example.COM").unwrap(), "com");
    }

    #[test]
    fn ignores_trailing_dots() {
        assert_eq!(extract_tld("example.com.").unwrap(), "com");
        assert_eq!(extract_tld("example.org..").unwrap(), "org");
    }

    #[test]
    fn rejects_empty() {
        assert!(extract_tld("").is_err());
    }

    #[test]
    fn rejects_dots_only() {
        assert!(extract_tld(".").is_err());
        assert!(extract_tld("...").is_err());
    }
}
296}