use anyhow::{Context as AnyhowContext, Result};
use chrono::Utc;
use serde::{Deserialize, Serialize};
use std::{net::IpAddr, sync::Arc};
use tokio::sync::Mutex;
use tracing::{debug, warn};
use trust_dns_resolver::{
config::{ResolverConfig, ResolverOpts},
TokioAsyncResolver,
};
/// DNS resolver that remembers answers in a persistent `sled` tree.
///
/// Cloning is cheap: the resolver and both counters sit behind `Arc`, so every
/// clone shares them with the original.
#[derive(Debug, Clone)]
pub struct DnsResolver {
    /// Asynchronous resolver used when the cache has no fresh answer.
    resolver: Arc<TokioAsyncResolver>,
    /// Tree holding JSON-serialized [`ResolverResult`] values keyed by the
    /// domain name's bytes.
    cache: sled::Tree,
    /// Number of lookups that were answered from the cache.
    cache_hits: Arc<Mutex<u64>>,
    /// Number of lookups that were not answered from the cache.
    cache_misses: Arc<Mutex<u64>>,
    /// When `true`, lookups return a fixed address instead of querying DNS.
    is_test: bool,
}
/// One cached resolution outcome, stored in the cache as JSON.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ResolverResult {
    /// Addresses recorded for the domain; empty when resolution failed.
    pub ips: Vec<IpAddr>,
    /// Unix time, in seconds, at which the entry was written.
    pub timestamp: u64,
    /// Number of seconds after `timestamp` during which the entry may be served.
    pub ttl: u64,
}
// NOTE(review): the leading characters of the log messages in this impl look
// like mis-encoded emoji. They are runtime strings and are left untouched here.
impl DnsResolver {
    /// Time-to-live, in seconds, stamped on every successfully resolved entry.
    const POSITIVE_TTL_SECS: u64 = 3600;

    /// Fixed address the test resolver hands out instead of querying DNS.
    const TEST_IP: IpAddr = IpAddr::V4(std::net::Ipv4Addr::new(192, 0, 2, 1));

    /// Creates a resolver backed by a persistent cache under `cache_dir`.
    ///
    /// The database lives in the `dns_cache` sub-directory of `cache_dir`, and
    /// `cache_size` is the sled page-cache budget in mebibytes. The resolver
    /// itself uses the default `ResolverConfig` / `ResolverOpts`.
    ///
    /// # Errors
    ///
    /// Returns an error when the sled database or its `dns_cache` tree cannot
    /// be opened.
    pub async fn new(cache_dir: &str, cache_size: usize) -> Result<Self> {
        let resolver =
            TokioAsyncResolver::tokio(ResolverConfig::default(), ResolverOpts::default());

        // Widen before multiplying and saturate, so a large `cache_size` cannot
        // overflow (a `usize` multiplication could, notably on 32-bit targets).
        let cache_bytes = (cache_size as u64).saturating_mul(1024 * 1024);

        let db = sled::Config::new()
            .path(std::path::Path::new(cache_dir).join("dns_cache"))
            .cache_capacity(cache_bytes)
            .mode(sled::Mode::HighThroughput)
            .open()
            .context("Failed to open DNS cache database")?;
        let cache = db
            .open_tree("dns_cache")
            .context("Failed to open DNS cache tree")?;

        Ok(Self {
            resolver: Arc::new(resolver),
            cache,
            cache_hits: Arc::new(Mutex::new(0)),
            cache_misses: Arc::new(Mutex::new(0)),
            is_test: false,
        })
    }

    /// Creates a resolver for tests, backed by a temporary sled database.
    ///
    /// The returned value has `is_test` set, so [`DnsResolver::lookup`] answers
    /// with a fixed address and never queries the resolver built here.
    ///
    /// # Errors
    ///
    /// Returns an error when the temporary database, its tree, or the resolver
    /// (built from the system configuration) cannot be created.
    #[allow(dead_code)]
    pub fn new_for_testing() -> Result<Self> {
        let db = sled::Config::new()
            .temporary(true)
            .open()
            .context("Failed to create temporary DNS cache database")?;
        let cache = db
            .open_tree("dns_cache")
            .context("Failed to open DNS cache tree")?;
        let resolver = TokioAsyncResolver::tokio_from_system_conf()
            .context("Failed to create DNS resolver from system configuration")?;

        Ok(Self {
            resolver: Arc::new(resolver),
            cache,
            cache_hits: Arc::new(Mutex::new(0)),
            cache_misses: Arc::new(Mutex::new(0)),
            is_test: true,
        })
    }

    /// Reports whether this resolver was built by [`DnsResolver::new_for_testing`].
    #[allow(dead_code)]
    pub fn is_test_resolver(&self) -> bool {
        self.is_test
    }

    /// Resolves `domain` to the textual form of its first IP address.
    ///
    /// A fresh cache entry is served directly. Otherwise the domain is resolved
    /// (or, for a test resolver, mapped to a fixed address) and the first
    /// address is cached for [`Self::POSITIVE_TTL_SECS`] seconds.
    ///
    /// Returns `Ok(None)` when resolution fails or yields no address.
    ///
    /// # Errors
    ///
    /// Only cache failures (read, deserialize, serialize, write) are reported
    /// as errors; a DNS failure is logged and turned into `Ok(None)`.
    pub async fn lookup(&self, domain: &str) -> Result<Option<String>> {
        if let Some(cached) = self.get_from_cache(domain)? {
            // The guard is a temporary, released at the end of this statement.
            *self.cache_hits.lock().await += 1;
            debug!("๐ Cache hit for domain: {}", domain);
            return Ok(cached.ips.first().map(|ip| ip.to_string()));
        }

        debug!("๐ Resolving domain: {}", domain);
        // Release the counter lock before the network round trip below, so
        // concurrent cache-miss lookups are not serialized behind this mutex.
        *self.cache_misses.lock().await += 1;

        if self.is_test {
            let test_ip = Self::TEST_IP;
            debug!("๐ Test resolver returning {} for {}", test_ip, domain);
            let entry = ResolverResult {
                ips: vec![test_ip],
                timestamp: Self::now_secs(),
                ttl: Self::POSITIVE_TTL_SECS,
            };
            self.add_to_cache(domain, &entry)?;
            return Ok(Some(test_ip.to_string()));
        }

        let resolved: Option<IpAddr> = match self.resolver.lookup_ip(domain).await {
            Ok(lookup) => lookup.iter().next(),
            Err(e) => {
                warn!("โ Failed to resolve domain {}: {}", domain, e);
                // A TTL of 0 can never satisfy the freshness check in
                // `get_from_cache`, so this record is never served; it only
                // replaces whatever was stored for the domain before.
                let failure = ResolverResult {
                    ips: vec![],
                    timestamp: Self::now_secs(),
                    ttl: 0,
                };
                self.add_to_cache(domain, &failure)?;
                None
            }
        };

        let lookup_result = resolved.map(|ip| ip.to_string());
        debug!("๐ Resolved domain {} to {:?}", domain, lookup_result);

        if let Some(ip) = resolved {
            // Store the address as resolved; no string round trip, no unwrap.
            let entry = ResolverResult {
                ips: vec![ip],
                timestamp: Self::now_secs(),
                ttl: Self::POSITIVE_TTL_SECS,
            };
            self.add_to_cache(domain, &entry)?;
        }

        Ok(lookup_result)
    }

    /// Current Unix time in whole seconds, clamped to zero for a pre-epoch clock.
    fn now_secs() -> u64 {
        Utc::now().timestamp().max(0) as u64
    }

    /// Serializes `result` as JSON and stores it under the bytes of `domain`.
    ///
    /// # Errors
    ///
    /// Returns an error when serialization or the cache write fails.
    fn add_to_cache(&self, domain: &str, result: &ResolverResult) -> Result<()> {
        let serialized =
            serde_json::to_vec(result).context("Failed to serialize resolver result")?;
        self.cache
            .insert(domain.as_bytes(), serialized)
            .context("Failed to write to cache")?;
        Ok(())
    }

    /// Returns the cached entry for `domain` when one exists and is still fresh.
    ///
    /// An entry is fresh while its age (now minus `timestamp`) is strictly less
    /// than its `ttl`. Missing and expired entries both yield `Ok(None)`;
    /// expired entries are left in place.
    ///
    /// # Errors
    ///
    /// Returns an error when the cache read fails or the stored bytes cannot be
    /// deserialized.
    fn get_from_cache(&self, domain: &str) -> Result<Option<ResolverResult>> {
        let cached_bytes = match self
            .cache
            .get(domain.as_bytes())
            .context("Failed to read from cache")?
        {
            Some(bytes) => bytes,
            None => return Ok(None),
        };

        let result: ResolverResult = serde_json::from_slice(&cached_bytes)
            .context("Failed to deserialize cached resolver result")?;

        // `checked_sub` avoids the overflow panic a plain `now - timestamp`
        // raises in debug builds when the stored timestamp is ahead of the
        // clock; such an entry is treated as stale.
        let is_fresh = Self::now_secs()
            .checked_sub(result.timestamp)
            .map_or(false, |age| age < result.ttl);

        Ok(if is_fresh { Some(result) } else { None })
    }

    /// Removes every entry from this resolver's cache tree.
    ///
    /// # Errors
    ///
    /// Returns an error when the tree cannot be cleared.
    #[allow(dead_code)]
    pub async fn flush_cache(&self) -> Result<()> {
        self.cache.clear().context("Failed to clear DNS cache")?;
        debug!("๐งน DNS cache flushed");
        Ok(())
    }

    /// Logs, at debug level, how many entries the cache tree holds.
    ///
    /// The count includes expired entries, because nothing here removes them.
    #[allow(dead_code)]
    pub async fn show_cache_status(&self) -> Result<()> {
        let count = self.cache.len();
        debug!("๐ DNS cache contains {} entries", count);
        Ok(())
    }
}
/// Clears the DNS cache stored at `./cache/dns_cache`.
///
/// The path is relative to the process working directory. It matches the
/// location `DnsResolver::new` uses only when that resolver was created with
/// a `cache_dir` of `./cache`.
///
/// No DNS resolver is constructed: flushing needs only the database, and
/// building one would make the flush fail whenever the system resolver
/// configuration cannot be read.
///
/// # Errors
///
/// Returns an error when the database or its `dns_cache` tree cannot be
/// opened, or when the tree cannot be cleared.
pub async fn flush_cache() -> Result<()> {
    let db = sled::Config::new()
        .path("./cache/dns_cache")
        .open()
        .context("Failed to open DNS cache database")?;
    let cache = db
        .open_tree("dns_cache")
        .context("Failed to open DNS cache tree")?;

    cache.clear().context("Failed to clear DNS cache")?;
    debug!("๐งน DNS cache flushed");
    Ok(())
}
/// Logs, at debug level, how many entries the DNS cache at
/// `./cache/dns_cache` holds.
///
/// The path is relative to the process working directory. The count covers
/// every stored entry, whether or not it is still within its TTL.
///
/// No DNS resolver is constructed: reading the entry count needs only the
/// database, and building one would make this call fail whenever the system
/// resolver configuration cannot be read.
///
/// # Errors
///
/// Returns an error when the database or its `dns_cache` tree cannot be
/// opened.
pub async fn show_cache_status() -> Result<()> {
    let db = sled::Config::new()
        .path("./cache/dns_cache")
        .open()
        .context("Failed to open DNS cache database")?;
    let cache = db
        .open_tree("dns_cache")
        .context("Failed to open DNS cache tree")?;

    let count = cache.len();
    debug!("๐ DNS cache contains {} entries", count);
    Ok(())
}