use std::{net::IpAddr, sync::Arc, time::Duration};
use bytes::Bytes;
use moka::future::Cache;
use tower::{Service, ServiceExt as _};
use crate::{
codec::{
header::Header,
message::{Qtype, Query},
name::Name,
writer::Writer,
},
resolver::pipeline::{BoxError, DnsRequest, PipelineResponse},
};
/// Maximum number of IP → hostname entries kept by a resolver built with
/// `ReverseResolver::new`.
const DEFAULT_CAPACITY: u64 = 4096;
/// Lifetime of every cached answer (hostname or negative) built with
/// `ReverseResolver::new`: fifteen minutes.
const DEFAULT_TTL: Duration = Duration::from_secs(900);
/// Client address attached to the PTR queries this module builds itself.
///
/// These requests have no real peer, so a fixed placeholder is used: IPv4
/// loopback, port 0.
fn synthetic_client() -> std::net::SocketAddr {
    let loopback = IpAddr::V4(std::net::Ipv4Addr::LOCALHOST);
    std::net::SocketAddr::new(loopback, 0)
}
/// Caching PTR (reverse DNS) resolver layered over an internal DNS [`Service`].
///
/// Answers are cached per IP address in a bounded [`Cache`] with a fixed
/// time-to-live. Cached answers include "no PTR record". Failed service calls
/// are *not* cached, so a transient pipeline error does not hide a hostname
/// for a whole TTL.
pub struct ReverseResolver<S> {
    /// Service used to issue PTR queries; cloned for each cache miss.
    /// NOTE(review): the mutex presumably exists so the resolver is `Sync`
    /// even when `S` is only `Send` (e.g. a boxed clone service) — confirm.
    service: std::sync::Mutex<S>,
    /// Per-IP answers; a `None` value records a negative ("no PTR") answer.
    cache: Cache<IpAddr, Option<Name>>,
}

impl<S> ReverseResolver<S>
where
    S: Service<DnsRequest, Response = PipelineResponse, Error = BoxError> + Clone + Send + 'static,
    S::Future: Send + 'static,
{
    /// Creates a resolver holding up to [`DEFAULT_CAPACITY`] entries, each
    /// living for [`DEFAULT_TTL`].
    pub fn new(service: S) -> Self {
        Self::with_bounds(service, DEFAULT_CAPACITY, DEFAULT_TTL)
    }

    /// Creates a resolver whose cache holds at most `capacity` entries, each
    /// expiring `ttl` after insertion.
    pub fn with_bounds(service: S, capacity: u64, ttl: Duration) -> Self {
        let cache = Cache::builder()
            .max_capacity(capacity)
            .time_to_live(ttl)
            .build();
        Self {
            service: std::sync::Mutex::new(service),
            cache,
        }
    }

    /// Returns the hostname for `ip`, querying the service on a cache miss.
    ///
    /// Caching behaviour:
    /// - A successful answer (a hostname, or "no PTR record") is cached.
    /// - A failed service call returns `None` without being cached, so the
    ///   next lookup retries.
    ///
    /// # Panics
    ///
    /// Panics if the service mutex is poisoned.
    pub async fn lookup(&self, ip: IpAddr) -> Option<Name> {
        // Fast path: serve hits without locking the mutex or cloning the service.
        if let Some(answer) = self.cache.get(&ip).await {
            return answer;
        }
        // The guard is a temporary dropped at the end of this statement, so it
        // is never held across an `.await`.
        let service = self
            .service
            .lock()
            .expect("reverse-lookup service mutex poisoned")
            .clone();
        // `resolve` yields `None` for uncacheable failures; `optionally_get_with`
        // stores only `Some(answer)`.
        self.cache
            .optionally_get_with(ip, Self::resolve(service, ip))
            .await
            .flatten()
    }

    /// Returns the cached hostname for `ip` without issuing a query.
    ///
    /// Returns `None` in two cases: nothing is cached, or a negative answer
    /// is cached.
    pub async fn cached(&self, ip: IpAddr) -> Option<Name> {
        self.cache.get(&ip).await.flatten()
    }

    /// Drops every cached entry, positive and negative, so later lookups re-query.
    pub fn clear(&self) {
        self.cache.invalidate_all();
    }

    /// Starts a detached background [`lookup`](Self::lookup) for `ip`, so that a
    /// later [`cached`](Self::cached) call can hit.
    ///
    /// Uses `tokio::spawn`, so it must be called from within a Tokio runtime.
    pub fn warm(self: &Arc<Self>, ip: IpAddr) {
        let this = Arc::clone(self);
        tokio::spawn(async move {
            this.lookup(ip).await;
        });
    }

    /// Issues one PTR query for `ip` through `service`.
    ///
    /// The outer `Option` says whether the result may be cached:
    /// - `None`: the service call failed.
    /// - `Some(answer)`: carries the first PTR target, or `None` when the
    ///   response contains no usable PTR record.
    async fn resolve(service: S, ip: IpAddr) -> Option<Option<Name>> {
        let raw = ptr_query_datagram(&Name::reverse_query(ip));
        // The datagram is built locally, so a parse failure would recur on every
        // retry; record it as a negative answer rather than re-querying.
        let query = match Query::try_from(raw) {
            Ok(query) => query,
            Err(_) => return Some(None),
        };
        let request = DnsRequest::new(query, synthetic_client());
        // A service error is transient from our point of view: do not cache it.
        let response = service.oneshot(request).await.ok()?;
        Some(extract_ptr(&response.bytes))
    }
}
/// Encodes a DNS query datagram asking for the `PTR` record of `name`.
///
/// The message has transaction id 0, recursion desired, and a single question
/// of class `IN`.
fn ptr_query_datagram(name: &Name) -> Bytes {
    /// DNS class `IN` (Internet).
    const CLASS_IN: u16 = 1;

    let mut writer = Writer::with_capacity(64);
    let header = Header::new(0).with_qdcount(1).with_rd(true);
    header.write(&mut writer);
    // Question section: QNAME, QTYPE, QCLASS.
    name.write(&mut writer);
    writer.write_u16(u16::from(Qtype::Ptr));
    writer.write_u16(CLASS_IN);
    writer.finish()
}
/// Parses a DNS response and returns the target of its first usable `PTR` answer.
///
/// Non-`PTR` answers are skipped, as are `PTR` answers whose target fails to
/// convert to a [`Name`]. Returns `None` in two cases: `bytes` is not a
/// parseable DNS message, or no answer qualifies.
fn extract_ptr(bytes: &[u8]) -> Option<Name> {
    use hickory_net::proto::op::Message;
    use hickory_net::proto::rr::RData;

    let message = Message::from_vec(bytes).ok()?;
    let mut ptr_targets = message.answers.iter().filter_map(|record| match &record.data {
        RData::PTR(ptr) => Some(ptr),
        _ => None,
    });
    // Round-trip through text to convert hickory's name type into ours.
    ptr_targets.find_map(|ptr| ptr.to_string().parse::<Name>().ok())
}
/// Type-erased, cloneable internal DNS pipeline service that backs a
/// [`SharedReverseResolver`].
pub type InternalDnsService =
    tower::util::BoxCloneService<DnsRequest, PipelineResponse, BoxError>;

/// A reference-counted [`ReverseResolver`] over [`InternalDnsService`].
pub type SharedReverseResolver = std::sync::Arc<ReverseResolver<InternalDnsService>>;
#[cfg(test)]
mod tests {
    //! Tests for `ReverseResolver`, mostly against in-process stub services,
    //! plus one test through the real internal engine.

    use std::net::Ipv4Addr;
    use std::sync::atomic::{AtomicUsize, Ordering};

    use bytes::Bytes;
    use tower::ServiceExt as _;

    use super::*;
    use crate::codec::{
        header::{Header, Rcode},
        name::Name,
        writer::Writer,
    };
    use crate::resolver::pipeline::{DnsRequest, Outcome, PipelineResponse};

    /// DNS class `IN`, used for every question and answer the stubs write.
    const CLASS_IN: u16 = 1;

    /// Builds a `NOERROR` response to `req` carrying one PTR answer pointing at `target`.
    ///
    /// # Panics
    ///
    /// Panics if `target` does not parse as a [`Name`].
    fn ptr_response(req: &DnsRequest, target: &str) -> Bytes {
        let mut w = Writer::with_capacity(128);
        Header::new(req.header().id)
            .with_qr(true)
            .with_rcode(Rcode::NoError)
            .with_qdcount(1)
            .with_ancount(1)
            .write(&mut w);
        // Question section, echoed from the request.
        req.question().name.write(&mut w);
        w.write_u16(u16::from(Qtype::Ptr));
        w.write_u16(CLASS_IN);
        // Answer owner: compression pointer to offset 12, i.e. the question
        // name right after the 12-byte header.
        w.write_u8(0xC0);
        w.write_u8(0x0C);
        w.write_u16(u16::from(Qtype::Ptr));
        w.write_u16(CLASS_IN);
        w.write_u32(300); // TTL in seconds.
        let target_name: Name = target.parse().unwrap();
        let mut rdata = Writer::with_capacity(64);
        target_name.write(&mut rdata);
        let rdata = rdata.finish();
        let rdlength = u16::try_from(rdata.len()).expect("PTR rdata fits in RDLENGTH");
        w.write_u16(rdlength);
        w.write_slice(&rdata);
        w.finish()
    }

    /// Builds a response to `req` that echoes the question but has no answers.
    fn no_answer_response(req: &DnsRequest) -> Bytes {
        let mut w = Writer::with_capacity(32);
        Header::new(req.header().id)
            .with_qr(true)
            .with_qdcount(1)
            .write(&mut w);
        req.question().name.write(&mut w);
        w.write_u16(u16::from(Qtype::Ptr));
        w.write_u16(CLASS_IN);
        w.finish()
    }

    /// Always answers with a PTR to `target`, counting calls.
    #[derive(Clone)]
    struct StubService {
        target: &'static str,
        calls: Arc<AtomicUsize>,
    }

    impl Service<DnsRequest> for StubService {
        type Response = PipelineResponse;
        type Error = BoxError;
        type Future = std::future::Ready<Result<PipelineResponse, BoxError>>;

        fn poll_ready(
            &mut self,
            _cx: &mut std::task::Context<'_>,
        ) -> std::task::Poll<Result<(), Self::Error>> {
            std::task::Poll::Ready(Ok(()))
        }

        fn call(&mut self, req: DnsRequest) -> Self::Future {
            self.calls.fetch_add(1, Ordering::SeqCst);
            let bytes = ptr_response(&req, self.target);
            std::future::ready(Ok(PipelineResponse::new(bytes, Outcome::Local)))
        }
    }

    /// Always answers with an answer-less response tagged `Outcome::Servfail`, counting calls.
    #[derive(Clone)]
    struct EmptyService {
        calls: Arc<AtomicUsize>,
    }

    impl Service<DnsRequest> for EmptyService {
        type Response = PipelineResponse;
        type Error = BoxError;
        type Future = std::future::Ready<Result<PipelineResponse, BoxError>>;

        fn poll_ready(
            &mut self,
            _cx: &mut std::task::Context<'_>,
        ) -> std::task::Poll<Result<(), Self::Error>> {
            std::task::Poll::Ready(Ok(()))
        }

        fn call(&mut self, req: DnsRequest) -> Self::Future {
            self.calls.fetch_add(1, Ordering::SeqCst);
            let bytes = no_answer_response(&req);
            std::future::ready(Ok(PipelineResponse::new(bytes, Outcome::Servfail)))
        }
    }

    #[tokio::test]
    async fn resolves_and_caches_a_hostname() {
        let calls = Arc::new(AtomicUsize::new(0));
        let svc = StubService {
            target: "router.home.lan",
            calls: Arc::clone(&calls),
        };
        let resolver = ReverseResolver::new(svc);
        let ip = IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1));
        let name = resolver.lookup(ip).await.expect("hostname");
        assert_eq!(name.to_string(), "router.home.lan.");
        assert_eq!(calls.load(Ordering::SeqCst), 1);
        let again = resolver.lookup(ip).await.expect("cached hostname");
        assert_eq!(again.to_string(), "router.home.lan.");
        assert_eq!(calls.load(Ordering::SeqCst), 1, "must not re-query");
    }

    #[tokio::test]
    async fn negative_result_is_cached() {
        let calls = Arc::new(AtomicUsize::new(0));
        let svc = EmptyService {
            calls: Arc::clone(&calls),
        };
        let resolver = ReverseResolver::new(svc);
        let ip = IpAddr::V4(Ipv4Addr::new(203, 0, 113, 7));
        assert!(resolver.lookup(ip).await.is_none(), "no PTR → None");
        assert!(resolver.lookup(ip).await.is_none(), "still None");
        assert_eq!(
            calls.load(Ordering::SeqCst),
            1,
            "negative result must be cached (one query only)"
        );
    }

    /// Answers with a PTR to the current `target`, or with no answer while it
    /// is `None`; counting calls.
    #[derive(Clone)]
    struct ToggleService {
        target: Arc<std::sync::Mutex<Option<&'static str>>>,
        calls: Arc<AtomicUsize>,
    }

    impl Service<DnsRequest> for ToggleService {
        type Response = PipelineResponse;
        type Error = BoxError;
        type Future = std::future::Ready<Result<PipelineResponse, BoxError>>;

        fn poll_ready(
            &mut self,
            _cx: &mut std::task::Context<'_>,
        ) -> std::task::Poll<Result<(), Self::Error>> {
            std::task::Poll::Ready(Ok(()))
        }

        fn call(&mut self, req: DnsRequest) -> Self::Future {
            self.calls.fetch_add(1, Ordering::SeqCst);
            let bytes = match *self.target.lock().unwrap() {
                Some(target) => ptr_response(&req, target),
                None => no_answer_response(&req),
            };
            std::future::ready(Ok(PipelineResponse::new(bytes, Outcome::Local)))
        }
    }

    #[tokio::test]
    async fn clear_drops_a_sticky_negative_entry() {
        let target = Arc::new(std::sync::Mutex::new(None));
        let calls = Arc::new(AtomicUsize::new(0));
        let resolver = ReverseResolver::new(ToggleService {
            target: Arc::clone(&target),
            calls: Arc::clone(&calls),
        });
        let ip = IpAddr::V4(Ipv4Addr::new(192, 168, 1, 9));
        assert!(resolver.lookup(ip).await.is_none());
        assert_eq!(calls.load(Ordering::SeqCst), 1);
        *target.lock().unwrap() = Some("late.lan");
        assert!(
            resolver.lookup(ip).await.is_none(),
            "stale negative result is served from the cache"
        );
        assert_eq!(
            calls.load(Ordering::SeqCst),
            1,
            "must not re-query while the negative entry is cached"
        );
        resolver.clear();
        assert_eq!(
            resolver.lookup(ip).await.map(|n| n.to_string()),
            Some("late.lan.".to_owned())
        );
        assert_eq!(calls.load(Ordering::SeqCst), 2, "clear() forces a re-query");
    }

    #[tokio::test]
    async fn cached_does_not_trigger_a_lookup() {
        let calls = Arc::new(AtomicUsize::new(0));
        let svc = StubService {
            target: "host.lan",
            calls: Arc::clone(&calls),
        };
        let resolver = ReverseResolver::new(svc);
        let ip = IpAddr::V4(Ipv4Addr::new(10, 0, 0, 5));
        assert!(resolver.cached(ip).await.is_none());
        assert_eq!(calls.load(Ordering::SeqCst), 0, "cached() must not query");
        resolver.lookup(ip).await.expect("hostname");
        assert_eq!(
            resolver.cached(ip).await.map(|n| n.to_string()),
            Some("host.lan.".to_owned())
        );
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn warm_populates_the_cache() {
        let calls = Arc::new(AtomicUsize::new(0));
        let svc = StubService {
            target: "warm.lan",
            calls: Arc::clone(&calls),
        };
        let resolver = Arc::new(ReverseResolver::new(svc));
        let ip = IpAddr::V4(Ipv4Addr::new(10, 1, 2, 3));
        resolver.warm(ip);
        // Give the spawned task a bounded number of chances to run.
        let mut name = None;
        for _ in 0..100 {
            if let Some(n) = resolver.cached(ip).await {
                name = Some(n);
                break;
            }
            tokio::task::yield_now().await;
        }
        assert_eq!(
            name.map(|n| n.to_string()),
            Some("warm.lan.".to_owned()),
            "warm must populate the cache"
        );
    }

    #[tokio::test]
    async fn resolves_local_record_through_real_engine() {
        use std::time::Duration;

        use crate::resolver::{
            local::{LocalRecords, RecordData},
            pipeline::engine::build_internal_service,
            state::ResolverState,
            upstream::{RandomSelector, SharedUpstreamPool, UpstreamPool},
        };
        use tokio_util::task::TaskTracker;

        let (_dir, db) = crate::test_support::temp_db().await;
        let state = ResolverState::hydrate(&db).await.expect("hydrate");
        let mut b = LocalRecords::builder();
        b.add(
            "router.home.lan",
            RecordData::A("192.168.1.1".parse().unwrap()),
            300,
        )
        .unwrap();
        state.local().store(b.build());
        let tracker = TaskTracker::new();
        let pool = Arc::new(SharedUpstreamPool::new(
            UpstreamPool::connect(
                &[],
                &tracker,
                Arc::new(RandomSelector),
                0,
                Duration::from_millis(500),
            )
            .await,
        ));
        let resolver = ReverseResolver::new(build_internal_service(state, pool));
        let ip = IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1));
        assert_eq!(
            resolver.lookup(ip).await.map(|n| n.to_string()),
            Some("router.home.lan.".to_owned()),
            "local IP must resolve via the E13 reverse index"
        );
        let unknown = IpAddr::V4(Ipv4Addr::new(192, 168, 1, 200));
        assert!(resolver.lookup(unknown).await.is_none());
    }

    #[tokio::test]
    async fn stub_service_shape_round_trips() {
        let svc = StubService {
            target: "x.lan",
            calls: Arc::new(AtomicUsize::new(0)),
        };
        let raw = ptr_query_datagram(&Name::reverse_query(IpAddr::V4(Ipv4Addr::new(1, 2, 3, 4))));
        let req = DnsRequest::new(Query::try_from(raw).unwrap(), synthetic_client());
        let resp = svc.oneshot(req).await.unwrap();
        assert_eq!(
            extract_ptr(&resp.bytes).map(|n| n.to_string()),
            Some("x.lan.".to_owned())
        );
    }
}