1use crate::error::DomainCheckError;
7use std::collections::HashMap;
8use std::sync::Mutex;
9use std::time::{Duration, Instant};
10
/// In-memory cache of RDAP endpoints discovered through the IANA bootstrap registry.
///
/// Every entry carries its own insertion time and expires independently after
/// [`BootstrapCache::TTL`]. `last_update` additionally records the most recent write, so the
/// cache as a whole can still report staleness.
struct BootstrapCache {
    /// Discovered endpoint per TLD, paired with the instant it was inserted.
    endpoints: HashMap<String, (String, Instant)>,
    /// Instant of the most recent insert (or of construction).
    last_update: Instant,
}

impl BootstrapCache {
    /// How long a discovered endpoint is trusted before it has to be re-discovered.
    const TTL: Duration = Duration::from_secs(3600);

    /// Creates an empty cache whose `last_update` is "now".
    fn new() -> Self {
        Self {
            endpoints: HashMap::new(),
            last_update: Instant::now(),
        }
    }

    /// Returns the cached endpoint for `tld`, or `None` when it is absent or older than
    /// [`Self::TTL`].
    ///
    /// Expiry is checked per entry, so inserting a newly discovered TLD never extends the
    /// lifetime of older entries.
    fn get(&self, tld: &str) -> Option<String> {
        self.endpoints
            .get(tld)
            .filter(|(_, inserted_at)| inserted_at.elapsed() <= Self::TTL)
            .map(|(endpoint, _)| endpoint.clone())
    }

    /// Stores (or replaces) the endpoint for `tld`, stamping both the entry and the cache with
    /// the current instant.
    fn insert(&mut self, tld: String, endpoint: String) {
        let now = Instant::now();
        self.endpoints.insert(tld, (endpoint, now));
        self.last_update = now;
    }

    /// Returns `true` when nothing has been written for longer than [`Self::TTL`]. Because every
    /// entry is stamped no later than `last_update`, this implies all entries have expired.
    fn is_stale(&self) -> bool {
        self.last_update.elapsed() > Self::TTL
    }
}
39
// Process-wide cache of RDAP endpoints discovered via the IANA bootstrap registry.
// Guarded by a std `Mutex`: every user in this file locks it only for a short synchronous
// section (never across an `.await`) and maps lock poisoning to `DomainCheckError::internal`.
lazy_static::lazy_static! {
    static ref BOOTSTRAP_CACHE: Mutex<BootstrapCache> = Mutex::new(BootstrapCache::new());
}
44
/// Returns the built-in table of RDAP domain-lookup endpoints, keyed by lowercase TLD.
///
/// Every value ends in `/domain/`, the same shape `discover_rdap_endpoint` produces for TLDs
/// resolved through the IANA bootstrap registry. A fresh map is built on each call.
pub fn get_rdap_registry_map() -> HashMap<&'static str, &'static str> {
    const ENDPOINTS: &[(&str, &str)] = &[
        // Generic TLDs.
        ("com", "https://rdap.verisign.com/com/v1/domain/"),
        ("net", "https://rdap.verisign.com/net/v1/domain/"),
        ("org", "https://rdap.publicinterestregistry.org/rdap/domain/"),
        ("info", "https://rdap.identitydigital.services/rdap/domain/"),
        ("biz", "https://rdap.nic.biz/domain/"),
        ("app", "https://rdap.nic.google/domain/"),
        ("dev", "https://rdap.nic.google/domain/"),
        ("page", "https://rdap.nic.google/domain/"),
        ("blog", "https://rdap.nic.blog/domain/"),
        ("shop", "https://rdap.nic.shop/domain/"),
        ("xyz", "https://rdap.nic.xyz/domain/"),
        ("tech", "https://rdap.nic.tech/domain/"),
        ("online", "https://rdap.nic.online/domain/"),
        ("site", "https://rdap.nic.site/domain/"),
        ("website", "https://rdap.nic.website/domain/"),
        // Country-code and remaining TLDs.
        ("io", "https://rdap.identitydigital.services/rdap/domain/"),
        ("ai", "https://rdap.nic.ai/domain/"),
        ("co", "https://rdap.nic.co/domain/"),
        ("me", "https://rdap.nic.me/domain/"),
        ("us", "https://rdap.nic.us/domain/"),
        ("uk", "https://rdap.nominet.uk/domain/"),
        // NOTE(review): `rdap.eu.org` looks like the EU.org free-subdomain service rather than
        // EURid (the .eu registry) — verify this endpoint before relying on it.
        ("eu", "https://rdap.eu.org/domain/"),
        ("de", "https://rdap.denic.de/domain/"),
        ("ca", "https://rdap.cira.ca/domain/"),
        ("au", "https://rdap.auda.org.au/domain/"),
        ("fr", "https://rdap.nic.fr/domain/"),
        ("es", "https://rdap.nic.es/domain/"),
        ("it", "https://rdap.nic.it/domain/"),
        ("nl", "https://rdap.domain-registry.nl/domain/"),
        ("jp", "https://rdap.jprs.jp/domain/"),
        ("br", "https://rdap.registro.br/domain/"),
        ("in", "https://rdap.registry.in/domain/"),
        ("cn", "https://rdap.cnnic.cn/domain/"),
        ("tv", "https://rdap.verisign.com/tv/v1/domain/"),
        ("cc", "https://rdap.verisign.com/cc/v1/domain/"),
        ("zone", "https://rdap.nic.zone/domain/"),
        ("cloud", "https://rdap.nic.cloud/domain/"),
        ("digital", "https://rdap.nic.digital/domain/"),
    ];

    ENDPOINTS.iter().copied().collect()
}
104
105pub fn get_all_known_tlds() -> Vec<String> {
116 let registry = get_rdap_registry_map();
117 let mut tlds: Vec<String> = registry.keys().map(|k| k.to_string()).collect();
118 tlds.sort(); tlds
120}
121
/// Returns the TLD list of a built-in preset, matching the preset name case-insensitively.
///
/// The recognised presets are `"startup"`, `"enterprise"` and `"country"`; any other name
/// yields `None`.
pub fn get_preset_tlds(preset: &str) -> Option<Vec<String>> {
    let members: &[&str] = match preset.to_lowercase().as_str() {
        "startup" => &["com", "org", "io", "ai", "tech", "app", "dev", "xyz"],
        "enterprise" => &["com", "org", "net", "info", "biz", "us"],
        "country" => &["us", "uk", "de", "fr", "ca", "au", "jp", "br", "in"],
        _ => return None,
    };
    Some(members.iter().map(|tld| tld.to_string()).collect())
}
177
178pub fn get_preset_tlds_with_custom(
204 preset: &str,
205 custom_presets: Option<&std::collections::HashMap<String, Vec<String>>>,
206) -> Option<Vec<String>> {
207 let preset_lower = preset.to_lowercase();
208
209 if let Some(custom_map) = custom_presets {
211 if let Some(custom_tlds) = custom_map
213 .get(preset)
214 .or_else(|| custom_map.get(&preset_lower))
215 {
216 return Some(custom_tlds.clone());
217 }
218 }
219
220 get_preset_tlds(&preset_lower)
222}
223
/// Lists the names of the built-in presets understood by [`get_preset_tlds`].
pub fn get_available_presets() -> Vec<&'static str> {
    ["startup", "enterprise", "country"].to_vec()
}
234
235#[allow(dead_code)]
248pub fn validate_preset_tlds(preset_tlds: &[String]) -> bool {
249 let registry = get_rdap_registry_map();
250 preset_tlds
251 .iter()
252 .all(|tld| registry.contains_key(tld.as_str()))
253}
254
255pub async fn get_rdap_endpoint(tld: &str, use_bootstrap: bool) -> Result<String, DomainCheckError> {
269 let tld_lower = tld.to_lowercase();
270
271 let registry = get_rdap_registry_map();
273 if let Some(endpoint) = registry.get(tld_lower.as_str()) {
274 return Ok(endpoint.to_string());
275 }
276
277 {
279 let cache = BOOTSTRAP_CACHE
280 .lock()
281 .map_err(|_| DomainCheckError::internal("Failed to acquire bootstrap cache lock"))?;
282
283 if !cache.is_stale() {
284 if let Some(endpoint) = cache.get(&tld_lower) {
285 return Ok(endpoint);
286 }
287 }
288 }
289
290 if use_bootstrap {
292 discover_rdap_endpoint(&tld_lower).await
293 } else {
294 Err(DomainCheckError::bootstrap(
295 &tld_lower,
296 "No known RDAP endpoint and bootstrap disabled",
297 ))
298 }
299}
300
301async fn discover_rdap_endpoint(tld: &str) -> Result<String, DomainCheckError> {
314 const BOOTSTRAP_URL: &str = "https://data.iana.org/rdap/dns.json";
315
316 let client = reqwest::Client::builder()
318 .timeout(Duration::from_secs(5))
319 .build()
320 .map_err(|e| {
321 DomainCheckError::network_with_source("Failed to create HTTP client", e.to_string())
322 })?;
323
324 let response = client.get(BOOTSTRAP_URL).send().await.map_err(|e| {
326 DomainCheckError::bootstrap(tld, format!("Failed to fetch bootstrap registry: {}", e))
327 })?;
328
329 if !response.status().is_success() {
330 return Err(DomainCheckError::bootstrap(
331 tld,
332 format!("Bootstrap registry returned HTTP {}", response.status()),
333 ));
334 }
335
336 let json: serde_json::Value = response.json().await.map_err(|e| {
337 DomainCheckError::bootstrap(tld, format!("Failed to parse bootstrap JSON: {}", e))
338 })?;
339
340 if let Some(services) = json.get("services").and_then(|s| s.as_array()) {
342 for service in services {
343 if let Some(service_array) = service.as_array() {
344 if service_array.len() >= 2 {
345 if let Some(tlds) = service_array[0].as_array() {
347 for t in tlds {
348 if let Some(t_str) = t.as_str() {
349 if t_str.to_lowercase() == tld.to_lowercase() {
350 if let Some(urls) = service_array[1].as_array() {
352 if let Some(url) = urls.first().and_then(|u| u.as_str()) {
353 let endpoint =
354 format!("{}/domain/", url.trim_end_matches('/'));
355
356 cache_discovered_endpoint(tld, &endpoint)?;
358
359 return Ok(endpoint);
360 }
361 }
362 }
363 }
364 }
365 }
366 }
367 }
368 }
369 }
370
371 Err(DomainCheckError::bootstrap(
372 tld,
373 "TLD not found in IANA bootstrap registry",
374 ))
375}
376
377fn cache_discovered_endpoint(tld: &str, endpoint: &str) -> Result<(), DomainCheckError> {
379 let mut cache = BOOTSTRAP_CACHE.lock().map_err(|_| {
380 DomainCheckError::internal("Failed to acquire bootstrap cache lock for writing")
381 })?;
382
383 cache.insert(tld.to_string(), endpoint.to_string());
384 Ok(())
385}
386
387pub fn extract_tld(domain: &str) -> Result<String, DomainCheckError> {
400 let parts: Vec<&str> = domain.split('.').collect();
401
402 if parts.len() < 2 {
403 return Err(DomainCheckError::invalid_domain(
404 domain,
405 "Domain must contain at least one dot",
406 ));
407 }
408
409 Ok(parts.last().unwrap().to_lowercase())
413}
414
415#[allow(dead_code)]
417pub fn clear_bootstrap_cache() -> Result<(), DomainCheckError> {
418 let mut cache = BOOTSTRAP_CACHE.lock().map_err(|_| {
419 DomainCheckError::internal("Failed to acquire bootstrap cache lock for clearing")
420 })?;
421
422 cache.endpoints.clear();
423 cache.last_update = Instant::now();
424 Ok(())
425}
426
427#[allow(dead_code)]
429pub fn get_bootstrap_cache_stats() -> Result<(usize, bool), DomainCheckError> {
430 let cache = BOOTSTRAP_CACHE.lock().map_err(|_| {
431 DomainCheckError::internal("Failed to acquire bootstrap cache lock for stats")
432 })?;
433
434 Ok((cache.endpoints.len(), cache.is_stale()))
435}
436
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_extract_tld() {
        let valid = [
            ("example.com", "com"),
            ("test.org", "org"),
            ("sub.example.com", "com"),
        ];
        for (domain, expected) in valid {
            assert_eq!(extract_tld(domain).unwrap(), expected, "domain: {}", domain);
        }

        for domain in ["invalid", ""] {
            assert!(extract_tld(domain).is_err(), "expected error for {:?}", domain);
        }
    }

    #[test]
    fn test_registry_map_contains_common_tlds() {
        let registry = get_rdap_registry_map();
        for tld in ["com", "org", "net", "io"] {
            assert!(registry.contains_key(tld), "missing built-in TLD: {}", tld);
        }
    }

    // `get_rdap_endpoint` is async, so these cases run on the tokio test runtime.
    #[tokio::test]
    async fn test_get_rdap_endpoint_builtin() {
        let endpoint = get_rdap_endpoint("com", false)
            .await
            .expect("built-in TLD should resolve without bootstrap");
        assert!(endpoint.contains("verisign.com"));
    }

    #[tokio::test]
    async fn test_get_rdap_endpoint_unknown_no_bootstrap() {
        assert!(get_rdap_endpoint("unknowntld123", false).await.is_err());
    }
}
471
#[cfg(test)]
mod preset_tests {
    use super::*;

    /// Asserts that every TLD in `expected` appears somewhere in `tlds`.
    fn assert_contains_all(tlds: &[String], expected: &[&str]) {
        for tld in expected {
            assert!(tlds.iter().any(|t| t == tld), "missing TLD: {}", tld);
        }
    }

    #[test]
    fn test_get_all_known_tlds() {
        let tlds = get_all_known_tlds();

        assert!(tlds.len() >= 30);
        assert_contains_all(&tlds, &["com", "org", "io", "ai"]);
        assert!(
            tlds.windows(2).all(|pair| pair[0] <= pair[1]),
            "known TLDs should be sorted"
        );
    }

    #[test]
    fn test_startup_preset() {
        let tlds = get_preset_tlds("startup").unwrap();

        assert_eq!(tlds.len(), 8);
        assert_contains_all(&tlds, &["com", "io", "ai", "tech"]);
        // Preset names are case-insensitive.
        assert_eq!(get_preset_tlds("STARTUP"), get_preset_tlds("startup"));
    }

    #[test]
    fn test_enterprise_preset() {
        let tlds = get_preset_tlds("enterprise").unwrap();

        assert_eq!(tlds.len(), 6);
        assert_contains_all(&tlds, &["com", "org", "biz"]);
    }

    #[test]
    fn test_country_preset() {
        let tlds = get_preset_tlds("country").unwrap();

        assert_eq!(tlds.len(), 9);
        assert_contains_all(&tlds, &["us", "uk", "de"]);
    }

    #[test]
    fn test_invalid_preset() {
        for name in ["invalid", ""] {
            assert!(get_preset_tlds(name).is_none(), "unexpected preset: {:?}", name);
        }
    }

    #[test]
    fn test_available_presets() {
        let presets = get_available_presets();

        assert_eq!(presets.len(), 3);
        for name in ["startup", "enterprise", "country"] {
            assert!(presets.contains(&name), "missing preset: {}", name);
        }
    }

    #[test]
    fn test_validate_preset_tlds() {
        for preset_name in get_available_presets() {
            let tlds = get_preset_tlds(preset_name).unwrap();
            assert!(
                validate_preset_tlds(&tlds),
                "Preset '{}' contains TLDs without RDAP endpoints",
                preset_name
            );
        }
    }

    #[test]
    fn test_preset_tlds_subset_of_known() {
        let all_tlds = get_all_known_tlds();

        for preset_name in get_available_presets() {
            for tld in get_preset_tlds(preset_name).unwrap() {
                assert!(
                    all_tlds.contains(&tld),
                    "Preset '{}' contains unknown TLD: {}",
                    preset_name,
                    tld
                );
            }
        }
    }
}
571}