// rdap_bootstrap/lib.rs
//! IANA Bootstrap service discovery.
//!
//! Implements RFC 9224 — the client fetches IANA bootstrap files to locate
//! the authoritative RDAP server for a given query.
//!
//! Bootstrap files are cached in memory with a 24-hour TTL.

8#![forbid(unsafe_code)]
9
10use std::collections::HashMap;
11use std::sync::{Arc, RwLock as StdRwLock};
12use std::time::{Duration, Instant};
13
14use ipnetwork::IpNetwork;
15use serde::Deserialize;
16use tokio::sync::RwLock;
17
18use rdap_types::error::{RdapError, Result};
19
20// ── IANA bootstrap response format ───────────────────────────────────────────
21
22/// Root structure of every IANA bootstrap JSON file.
23#[derive(Debug, Deserialize)]
24struct BootstrapFile {
25    #[allow(dead_code)]
26    version: String,
27    /// Each entry is `[ [patterns…], [servers…] ]`
28    services: Vec<(Vec<String>, Vec<String>)>,
29}
30
// ── Internal cache entry ──────────────────────────────────────────────────────

/// One downloaded bootstrap registry, flattened for lookup.
#[derive(Debug)]
struct CacheEntry {
    /// `(pattern, server_url)` pairs, where the URL is the base URL selected
    /// for the service entry that listed the pattern.
    entries: Vec<(String, String)>,
    /// When the registry was downloaded; compared against the TTL.
    fetched_at: Instant,
}

impl CacheEntry {
    /// Reports whether strictly more than `ttl` has elapsed since `fetched_at`.
    fn is_expired(&self, ttl: Duration) -> bool {
        ttl < self.fetched_at.elapsed()
    }
}

46// ── Resolver ──────────────────────────────────────────────────────────────────
47
48/// Discovers the authoritative RDAP server URL for a query target.
49///
50/// Thread-safe: the cache is behind a `RwLock`, and a single `Bootstrap`
51/// instance can be shared across tasks via `Arc<Bootstrap>`.
52#[derive(Debug, Clone)]
53pub struct Bootstrap {
54    base_url: String,
55    client: reqwest::Client,
56    ttl: Duration,
57    cache: Arc<RwLock<HashMap<&'static str, CacheEntry>>>,
58    /// Custom TLD → server URL overrides; consulted before IANA lookup.
59    custom_servers: Arc<StdRwLock<HashMap<String, String>>>,
60}
61
62impl Bootstrap {
63    /// Creates a new resolver using the official IANA bootstrap endpoint.
64    pub fn new(client: reqwest::Client) -> Self {
65        Self {
66            base_url: "https://data.iana.org/rdap".to_string(),
67            client,
68            ttl: Duration::from_secs(86_400),
69            cache: Arc::new(RwLock::new(HashMap::new())),
70            custom_servers: Arc::new(StdRwLock::new(HashMap::new())),
71        }
72    }
73
74    /// Creates a resolver with a custom base URL (useful for testing).
75    pub fn with_base_url(base_url: impl Into<String>, client: reqwest::Client) -> Self {
76        Self {
77            base_url: base_url.into().trim_end_matches('/').to_string(),
78            client,
79            ttl: Duration::from_secs(86_400),
80            cache: Arc::new(RwLock::new(HashMap::new())),
81            custom_servers: Arc::new(StdRwLock::new(HashMap::new())),
82        }
83    }
84
85    /// Registers custom TLD → RDAP server URL overrides.
86    pub fn set_custom_servers(&mut self, servers: HashMap<String, String>) {
87        let mut guard = self.custom_servers.write().expect("lock poisoned");
88        *guard = servers
89            .into_iter()
90            .map(|(k, v)| (k.to_lowercase(), v))
91            .collect();
92    }
93
94    // ── Public API ────────────────────────────────────────────────────────────
95
96    /// Returns the RDAP server base URL for a domain (by TLD).
97    pub async fn for_domain(&self, domain: &str) -> Result<String> {
98        let tld = extract_tld(domain)?;
99        let tld_lower = tld.to_lowercase();
100
101        {
102            let custom = self.custom_servers.read().expect("lock poisoned");
103            if let Some(server) = custom.get(&tld_lower) {
104                return Ok(server.clone());
105            }
106        }
107
108        let entries = self.get_entries("dns").await?;
109
110        entries
111            .iter()
112            .find(|(pattern, _)| pattern.to_lowercase() == tld_lower)
113            .map(|(_, server)| server.clone())
114            .ok_or_else(|| RdapError::NoServerFound {
115                query: domain.to_string(),
116            })
117    }
118
119    /// Returns the RDAP server base URL for an IPv4 address.
120    pub async fn for_ipv4(&self, ip: &str) -> Result<String> {
121        let addr: std::net::IpAddr = ip
122            .parse()
123            .map_err(|_| RdapError::InvalidInput(format!("Invalid IPv4 address: {ip}")))?;
124
125        let entries = self.get_entries("ipv4").await?;
126        self.match_ip_entries(&entries, addr, ip)
127    }
128
129    /// Returns the RDAP server base URL for an IPv6 address.
130    pub async fn for_ipv6(&self, ip: &str) -> Result<String> {
131        let addr: std::net::IpAddr = ip
132            .parse()
133            .map_err(|_| RdapError::InvalidInput(format!("Invalid IPv6 address: {ip}")))?;
134
135        let entries = self.get_entries("ipv6").await?;
136        self.match_ip_entries(&entries, addr, ip)
137    }
138
139    /// Returns the RDAP server base URL for an ASN.
140    pub async fn for_asn(&self, asn: u32) -> Result<String> {
141        let entries = self.get_entries("asn").await?;
142
143        for (pattern, server) in &entries {
144            if let Some((start, end)) = pattern.split_once('-') {
145                let start: u32 = start.parse().unwrap_or(u32::MAX);
146                let end: u32 = end.parse().unwrap_or(0);
147                if asn >= start && asn <= end {
148                    return Ok(server.clone());
149                }
150            } else if let Ok(n) = pattern.parse::<u32>() {
151                if asn == n {
152                    return Ok(server.clone());
153                }
154            }
155        }
156
157        Err(RdapError::NoServerFound {
158            query: format!("AS{asn}"),
159        })
160    }
161
162    /// Clears the in-memory bootstrap cache.
163    pub async fn clear_cache(&self) {
164        self.cache.write().await.clear();
165    }
166
167    // ── Private helpers ───────────────────────────────────────────────────────
168
169    async fn get_entries(&self, resource: &'static str) -> Result<Vec<(String, String)>> {
170        {
171            let cache = self.cache.read().await;
172            if let Some(entry) = cache.get(resource) {
173                if !entry.is_expired(self.ttl) {
174                    return Ok(entry.entries.clone());
175                }
176            }
177        }
178
179        let entries = self.fetch_entries(resource).await?;
180
181        let mut cache = self.cache.write().await;
182        cache.insert(
183            resource,
184            CacheEntry {
185                entries: entries.clone(),
186                fetched_at: Instant::now(),
187            },
188        );
189
190        Ok(entries)
191    }
192
193    async fn fetch_entries(&self, resource: &str) -> Result<Vec<(String, String)>> {
194        let url = format!("{}/{}.json", self.base_url, resource);
195
196        let response = self
197            .client
198            .get(&url)
199            .send()
200            .await
201            .map_err(RdapError::Network)?;
202
203        if !response.status().is_success() {
204            return Err(RdapError::HttpStatus {
205                status: response.status().as_u16(),
206                url,
207            });
208        }
209
210        let file: BootstrapFile = response.json().await.map_err(|e| RdapError::ParseError {
211            reason: e.to_string(),
212        })?;
213
214        let entries = file
215            .services
216            .into_iter()
217            .filter_map(|(patterns, servers)| {
218                let server = servers.into_iter().next()?;
219                let server = server.trim_end_matches('/').to_string();
220                Some(patterns.into_iter().map(move |p| (p, server.clone())))
221            })
222            .flatten()
223            .collect();
224
225        Ok(entries)
226    }
227
228    fn match_ip_entries(
229        &self,
230        entries: &[(String, String)],
231        addr: std::net::IpAddr,
232        original: &str,
233    ) -> Result<String> {
234        for (pattern, server) in entries {
235            if let Ok(network) = pattern.parse::<IpNetwork>() {
236                if network.contains(addr) {
237                    return Ok(server.clone());
238                }
239            }
240        }
241        Err(RdapError::NoServerFound {
242            query: original.to_string(),
243        })
244    }
245}
246
247// ── Utilities ─────────────────────────────────────────────────────────────────
248
249fn extract_tld(domain: &str) -> Result<String> {
250    let domain = domain.trim_end_matches('.').to_lowercase();
251
252    if domain.is_empty() {
253        return Err(RdapError::InvalidInput(
254            "Domain name must not be empty".to_string(),
255        ));
256    }
257
258    let parts: Vec<&str> = domain.split('.').collect();
259
260    match parts.len() {
261        0 => Err(RdapError::InvalidInput(
262            "Domain name must not be empty".to_string(),
263        )),
264        1 => Ok(parts[0].to_string()),
265        _ => Ok(parts.last().unwrap().to_string()),
266    }
267}
268
// ── Tests ─────────────────────────────────────────────────────────────────────

#[cfg(test)]
mod tests {
    use super::extract_tld;

    #[test]
    fn extracts_simple_tld() {
        assert_eq!(extract_tld("example.com").unwrap(), "com");
        assert_eq!(extract_tld("google.org").unwrap(), "org");
    }

    #[test]
    fn extracts_from_subdomain() {
        assert_eq!(extract_tld("www.example.com").unwrap(), "com");
        assert_eq!(extract_tld("deep.sub.example.net").unwrap(), "net");
    }

    #[test]
    fn handles_single_label() {
        assert_eq!(extract_tld("com").unwrap(), "com");
    }

    #[test]
    fn ignores_trailing_root_dot() {
        // Fully-qualified names end in "."; the TLD must still be found.
        assert_eq!(extract_tld("example.com.").unwrap(), "com");
        assert_eq!(extract_tld("com.").unwrap(), "com");
    }

    #[test]
    fn lowercases_tld() {
        assert_eq!(extract_tld("WWW.Example.COM").unwrap(), "com");
    }

    #[test]
    fn rejects_empty() {
        assert!(extract_tld("").is_err());
    }

    #[test]
    fn rejects_dots_only() {
        assert!(extract_tld(".").is_err());
        assert!(extract_tld("...").is_err());
    }
}