use std::{
collections::HashMap,
net::{Ipv4Addr, Ipv6Addr},
sync::Arc,
};
use arc_swap::ArcSwap;
use crate::codec::{message::Qtype, name::Name};
/// Address payload of a locally configured record.
///
/// Only the two address families are representable; the variant chosen also
/// decides which slot of a [`NameEntry`] the builder fills.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum RecordData {
/// An IPv4 address (answers `Qtype::A` queries).
A(Ipv4Addr),
/// An IPv6 address (answers `Qtype::Aaaa` queries).
Aaaa(Ipv6Addr),
}
/// Outcome of looking a query up in the local records.
#[derive(Debug, PartialEq, Eq)]
pub enum LocalMatch {
/// The matched entry holds a record of the queried type.
Answer {
/// The address stored for the queried type.
data: RecordData,
/// The TTL that was supplied when the record was added.
ttl: u32,
},
/// An entry (exact or wildcard) matched the name, but it has no record of
/// the queried type.
NameExistsNoData,
/// Neither an exact entry nor a wildcard suffix matched the name.
Miss,
}
/// Records stored under a single name (or wildcard suffix).
///
/// Each address family has at most one slot; the default value has both
/// slots empty.
#[derive(Debug, Clone, Default)]
pub struct NameEntry {
/// IPv4 address and its TTL, if an A record was added.
pub a: Option<(Ipv4Addr, u32)>,
/// IPv6 address and its TTL, if an AAAA record was added.
pub aaaa: Option<(Ipv6Addr, u32)>,
}
impl NameEntry {
    /// Produces the lookup outcome for a name that matched this entry.
    ///
    /// Returns [`LocalMatch::Answer`] when the entry stores an address of the
    /// requested type and [`LocalMatch::NameExistsNoData`] otherwise. Every
    /// `Qtype::Other` query yields `NameExistsNoData`, because only A and
    /// AAAA data are stored here.
    fn resolve(&self, qtype: Qtype) -> LocalMatch {
        // Select the stored record (if any) that corresponds to the query type.
        let found = match qtype {
            Qtype::A => self.a.map(|(v4, ttl)| (RecordData::A(v4), ttl)),
            Qtype::Aaaa => self.aaaa.map(|(v6, ttl)| (RecordData::Aaaa(v6), ttl)),
            Qtype::Other(_) => None,
        };

        found.map_or(LocalMatch::NameExistsNoData, |(data, ttl)| {
            LocalMatch::Answer { data, ttl }
        })
    }
}
/// Errors reported by [`LocalRecordsBuilder::add`].
#[derive(Debug, thiserror::Error)]
pub enum BuildError {
/// The pattern was a bare `*`, or `*.` followed by nothing or by a lone `.`.
/// Carries the offending input verbatim.
#[error("invalid wildcard pattern {0:?}: must have at least one label after `*.`")]
InvalidWildcard(String),
/// The name (or the suffix of a wildcard pattern) failed to parse as a
/// `Name`. Carries the full input string and the codec error.
#[error("invalid domain name {0:?}: {1}")]
InvalidName(String, #[source] crate::codec::Error),
}
/// Set of local records, produced by [`LocalRecordsBuilder::build`].
///
/// The default value has no records, so every lookup is a miss.
#[derive(Debug, Default)]
pub struct LocalRecords {
// Entries that must match the query name exactly.
exact: HashMap<Name, NameEntry>,
// Wildcard entries, keyed by the suffix that follows `*.` in the pattern
// (stored as the string form of the parsed `Name`).
wildcard: HashMap<Box<str>, NameEntry>,
}
impl LocalRecords {
    /// Starts an empty [`LocalRecordsBuilder`].
    pub fn builder() -> LocalRecordsBuilder {
        LocalRecordsBuilder::default()
    }

    /// Resolves `qname`/`qtype` against the local records.
    ///
    /// An exact entry for `qname` always wins. Otherwise every proper suffix
    /// of `qname` (the text after its first dot, then after its second, and
    /// so on) is tried against the wildcard table, longest first, so the most
    /// specific wildcard wins. The full name itself is never tried as a
    /// wildcard suffix. Returns [`LocalMatch::Miss`] when nothing matches.
    pub fn lookup(&self, qname: &Name, qtype: Qtype) -> LocalMatch {
        let matched = self.exact.get(qname).or_else(|| {
            // Skip the suffix walk entirely when no wildcard is configured.
            if self.wildcard.is_empty() {
                return None;
            }

            let full = qname.as_str();
            full.match_indices('.')
                // The candidate suffix starts right after each dot.
                .map(|(dot, _)| &full[dot + 1..])
                // Stop once only the root (empty or a lone ".") is left.
                .take_while(|suffix| !suffix.is_empty() && *suffix != ".")
                .find_map(|suffix| self.wildcard.get(suffix))
        });

        matched.map_or(LocalMatch::Miss, |entry| entry.resolve(qtype))
    }
}
/// Accumulates records via [`LocalRecordsBuilder::add`] and turns them into
/// a [`LocalRecords`] with [`LocalRecordsBuilder::build`].
#[derive(Debug, Default)]
pub struct LocalRecordsBuilder {
// Entries added under a plain (non-wildcard) name.
exact: HashMap<Name, NameEntry>,
// Entries added under a `*.suffix` pattern, keyed by the parsed suffix.
wildcard: HashMap<Box<str>, NameEntry>,
}
impl LocalRecordsBuilder {
    /// Registers one record under `name`.
    ///
    /// A `name` beginning with `*.` is a wildcard pattern: the remainder is
    /// parsed as a `Name` and the record is filed under that suffix. Any
    /// other `name` is parsed as a `Name` and filed as an exact entry.
    /// Adding a second record of the same address family under the same key
    /// replaces the earlier one.
    ///
    /// # Errors
    ///
    /// * [`BuildError::InvalidWildcard`] when `name` is `*`, `*.` or `*..`.
    /// * [`BuildError::InvalidName`] when the name (or wildcard suffix) does
    ///   not parse; the error carries the full `name` as given.
    pub fn add(&mut self, name: &str, data: RecordData, ttl: u32) -> Result<(), BuildError> {
        // Both branches report parse failures against the original input.
        let parse_name = |raw: &str| {
            raw.parse::<Name>()
                .map_err(|e| BuildError::InvalidName(name.to_owned(), e))
        };

        let slot = match name.strip_prefix("*.") {
            // `*.` with nothing (or only the root dot) after it.
            Some("" | ".") => return Err(BuildError::InvalidWildcard(name.to_owned())),
            Some(rest) => {
                let suffix = parse_name(rest)?;
                self.wildcard
                    .entry(Box::<str>::from(suffix.as_str()))
                    .or_default()
            }
            None if name == "*" => return Err(BuildError::InvalidWildcard(name.to_owned())),
            None => self.exact.entry(parse_name(name)?).or_default(),
        };

        Self::fill_entry(slot, data, ttl);
        Ok(())
    }

    /// Consumes the builder and returns the finished [`LocalRecords`].
    pub fn build(self) -> LocalRecords {
        let Self { exact, wildcard } = self;
        LocalRecords { exact, wildcard }
    }

    /// Writes `data` and `ttl` into the slot of `entry` that matches the
    /// address family, overwriting whatever was there.
    fn fill_entry(entry: &mut NameEntry, data: RecordData, ttl: u32) {
        match data {
            RecordData::A(v4) => entry.a = Some((v4, ttl)),
            RecordData::Aaaa(v6) => entry.aaaa = Some((v6, ttl)),
        }
    }
}
/// Shared handle to the current [`LocalRecords`].
///
/// The records sit behind an `ArcSwap`, so `store` can replace the whole set
/// through a shared reference while `lookup` reads whichever set is loaded.
pub struct LocalMatcher {
// The currently installed record set.
inner: ArcSwap<LocalRecords>,
}
impl LocalMatcher {
    /// Creates a matcher that serves `records`.
    pub fn new(records: LocalRecords) -> Self {
        let inner = ArcSwap::new(Arc::new(records));
        Self { inner }
    }

    /// Creates a matcher with no records; every lookup is a miss until
    /// [`LocalMatcher::store`] installs some.
    pub fn empty() -> Self {
        let nothing = LocalRecords::default();
        Self::new(nothing)
    }

    /// Looks `qname`/`qtype` up in the record set that is installed at the
    /// time of the call. See [`LocalRecords::lookup`] for the matching rules.
    pub fn lookup(&self, qname: &Name, qtype: Qtype) -> LocalMatch {
        let snapshot = self.inner.load();
        snapshot.lookup(qname, qtype)
    }

    /// Replaces the installed record set with `records`.
    pub fn store(&self, records: LocalRecords) {
        let fresh = Arc::new(records);
        self.inner.store(fresh);
    }
}
impl Default for LocalMatcher {
fn default() -> Self {
Self::empty()
}
}
impl std::fmt::Debug for LocalMatcher {
    /// Prints only the sizes of the two tables of the currently installed
    /// record set, not their contents.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let current = self.inner.load();
        let exact_len = current.exact.len();
        let wildcard_len = current.wildcard.len();

        let mut out = f.debug_struct("LocalMatcher");
        out.field("exact_len", &exact_len);
        out.field("wildcard_len", &wildcard_len);
        out.finish()
    }
}
// Unit tests for exact lookups, wildcard lookups, the swappable matcher and
// builder validation.
#[cfg(test)]
mod tests {
use std::net::{Ipv4Addr, Ipv6Addr};
use super::*;
use crate::codec::message::Qtype;
// Parses a domain name, panicking if the test literal is invalid.
fn name(s: &str) -> Name {
s.parse().expect("valid domain name in test")
}
// Parses an IPv4 literal, panicking on bad input.
fn ipv4(s: &str) -> Ipv4Addr {
s.parse().unwrap()
}
// Parses an IPv6 literal, panicking on bad input.
fn ipv6(s: &str) -> Ipv6Addr {
s.parse().unwrap()
}
// Builds a `LocalRecords` from `(name, data, ttl)` triples, panicking if any
// entry is rejected by the builder.
fn records(entries: &[(&str, RecordData, u32)]) -> LocalRecords {
let mut b = LocalRecords::builder();
for (n, d, ttl) in entries {
b.add(n, d.clone(), *ttl).expect("valid entry in test");
}
b.build()
}
// --- Exact-name lookups ---
#[test]
fn exact_a_hit() {
let r = records(&[("router.home.lan", RecordData::A(ipv4("192.168.1.1")), 300)]);
let result = r.lookup(&name("router.home.lan"), Qtype::A);
assert_eq!(
result,
LocalMatch::Answer {
data: RecordData::A(ipv4("192.168.1.1")),
ttl: 300,
}
);
}
#[test]
fn exact_aaaa_hit() {
let r = records(&[("host.lan", RecordData::Aaaa(ipv6("::1")), 60)]);
let result = r.lookup(&name("host.lan"), Qtype::Aaaa);
assert_eq!(
result,
LocalMatch::Answer {
data: RecordData::Aaaa(ipv6("::1")),
ttl: 60,
}
);
}
// A and AAAA added under the same name keep their own address and TTL.
#[test]
fn both_a_and_aaaa_on_same_name() {
let mut b = LocalRecords::builder();
b.add("dual.lan", RecordData::A(ipv4("10.0.0.1")), 100)
.unwrap();
b.add("dual.lan", RecordData::Aaaa(ipv6("2001:db8::1")), 200)
.unwrap();
let r = b.build();
assert_eq!(
r.lookup(&name("dual.lan"), Qtype::A),
LocalMatch::Answer {
data: RecordData::A(ipv4("10.0.0.1")),
ttl: 100,
}
);
assert_eq!(
r.lookup(&name("dual.lan"), Qtype::Aaaa),
LocalMatch::Answer {
data: RecordData::Aaaa(ipv6("2001:db8::1")),
ttl: 200,
}
);
}
// --- Name matches but the queried type has no data ---
#[test]
fn only_a_query_aaaa_is_name_exists_no_data() {
let r = records(&[("a-only.lan", RecordData::A(ipv4("1.2.3.4")), 300)]);
assert_eq!(
r.lookup(&name("a-only.lan"), Qtype::Aaaa),
LocalMatch::NameExistsNoData
);
}
#[test]
fn only_aaaa_query_a_is_name_exists_no_data() {
let r = records(&[("aaaa-only.lan", RecordData::Aaaa(ipv6("::1")), 300)]);
assert_eq!(
r.lookup(&name("aaaa-only.lan"), Qtype::A),
LocalMatch::NameExistsNoData
);
}
// Any `Qtype::Other` value on a known name is reported as no-data, not a miss.
#[test]
fn non_a_aaaa_qtype_on_local_name_is_name_exists_no_data() {
let r = records(&[("host.lan", RecordData::A(ipv4("1.2.3.4")), 300)]);
for qtype_val in [2u16, 15, 16, 255] {
assert_eq!(
r.lookup(&name("host.lan"), Qtype::Other(qtype_val)),
LocalMatch::NameExistsNoData,
"expected NameExistsNoData for Other({qtype_val})"
);
}
}
#[test]
fn unrelated_name_is_miss() {
let r = records(&[("known.lan", RecordData::A(ipv4("1.1.1.1")), 60)]);
assert_eq!(r.lookup(&name("unknown.lan"), Qtype::A), LocalMatch::Miss);
assert_eq!(r.lookup(&name("example.com"), Qtype::A), LocalMatch::Miss);
}
// --- Wildcard lookups ---
#[test]
fn wildcard_a_hit() {
let r = records(&[("*.home.lan", RecordData::A(ipv4("10.0.0.99")), 120)]);
assert_eq!(
r.lookup(&name("device.home.lan"), Qtype::A),
LocalMatch::Answer {
data: RecordData::A(ipv4("10.0.0.99")),
ttl: 120,
}
);
}
// A wildcard covers names more than one label below its suffix.
#[test]
fn wildcard_multi_label_qname() {
let r = records(&[("*.home.lan", RecordData::A(ipv4("10.0.0.1")), 300)]);
assert_eq!(
r.lookup(&name("a.b.home.lan"), Qtype::A),
LocalMatch::Answer {
data: RecordData::A(ipv4("10.0.0.1")),
ttl: 300,
}
);
}
// With overlapping wildcards, the longest matching suffix is used.
#[test]
fn wildcard_most_specific_wins() {
let mut b = LocalRecords::builder();
b.add("*.home.lan", RecordData::A(ipv4("10.0.0.1")), 300)
.unwrap();
b.add("*.a.home.lan", RecordData::A(ipv4("10.0.0.2")), 300)
.unwrap();
let r = b.build();
assert_eq!(
r.lookup(&name("x.a.home.lan"), Qtype::A),
LocalMatch::Answer {
data: RecordData::A(ipv4("10.0.0.2")),
ttl: 300,
},
"x.a.home.lan should resolve via the more specific *.a.home.lan"
);
assert_eq!(
r.lookup(&name("y.home.lan"), Qtype::A),
LocalMatch::Answer {
data: RecordData::A(ipv4("10.0.0.1")),
ttl: 300,
},
"y.home.lan should resolve via *.home.lan"
);
}
#[test]
fn wildcard_does_not_match_the_suffix_itself() {
let r = records(&[("*.home.lan", RecordData::A(ipv4("10.0.0.1")), 300)]);
assert_eq!(
r.lookup(&name("home.lan"), Qtype::A),
LocalMatch::Miss,
"home.lan itself must be a Miss, not matched by *.home.lan"
);
}
// Matching is label-based: a name that merely ends with the suffix text
// without a separating dot must not match.
#[test]
fn wildcard_does_not_match_substring_prefix() {
let r = records(&[("*.home.lan", RecordData::A(ipv4("10.0.0.1")), 300)]);
assert_eq!(
r.lookup(&name("evilhome.lan"), Qtype::A),
LocalMatch::Miss,
"evilhome.lan must be a Miss — substring match is forbidden"
);
}
#[test]
fn wildcard_matches_direct_child() {
let r = records(&[("*.home.lan", RecordData::A(ipv4("10.0.0.1")), 300)]);
assert!(
matches!(
r.lookup(&name("x.home.lan"), Qtype::A),
LocalMatch::Answer { .. }
),
"x.home.lan should be an Answer via *.home.lan"
);
}
#[test]
fn exact_beats_wildcard() {
let mut b = LocalRecords::builder();
b.add("*.home.lan", RecordData::A(ipv4("10.0.0.99")), 300)
.unwrap();
b.add("router.home.lan", RecordData::A(ipv4("192.168.1.1")), 3600)
.unwrap();
let r = b.build();
assert_eq!(
r.lookup(&name("router.home.lan"), Qtype::A),
LocalMatch::Answer {
data: RecordData::A(ipv4("192.168.1.1")),
ttl: 3600,
},
"exact entry must win over matching wildcard"
);
}
#[test]
fn wildcard_qtype_mismatch_is_name_exists_no_data() {
let r = records(&[("*.home.lan", RecordData::A(ipv4("10.0.0.1")), 300)]);
assert_eq!(
r.lookup(&name("device.home.lan"), Qtype::Aaaa),
LocalMatch::NameExistsNoData
);
}
// --- LocalMatcher: swapping and empty states ---
// After `store`, lookups see only the new record set.
#[test]
fn matcher_store_swaps_snapshot() {
let r1 = records(&[("old.lan", RecordData::A(ipv4("1.1.1.1")), 60)]);
let matcher = LocalMatcher::new(r1);
assert!(matches!(
matcher.lookup(&name("old.lan"), Qtype::A),
LocalMatch::Answer { .. }
));
assert_eq!(matcher.lookup(&name("new.lan"), Qtype::A), LocalMatch::Miss);
let r2 = records(&[("new.lan", RecordData::A(ipv4("2.2.2.2")), 120)]);
matcher.store(r2);
assert_eq!(
matcher.lookup(&name("new.lan"), Qtype::A),
LocalMatch::Answer {
data: RecordData::A(ipv4("2.2.2.2")),
ttl: 120,
}
);
assert_eq!(matcher.lookup(&name("old.lan"), Qtype::A), LocalMatch::Miss);
}
#[test]
fn matcher_empty_returns_miss() {
let m = LocalMatcher::empty();
assert_eq!(m.lookup(&name("anything.com"), Qtype::A), LocalMatch::Miss);
}
#[test]
fn matcher_default_is_empty() {
let m = LocalMatcher::default();
assert_eq!(m.lookup(&name("anything.com"), Qtype::A), LocalMatch::Miss);
}
// --- Builder validation ---
#[test]
fn builder_rejects_bare_star() {
let mut b = LocalRecords::builder();
let err = b.add("*", RecordData::A(ipv4("1.1.1.1")), 300).unwrap_err();
assert!(
matches!(err, BuildError::InvalidWildcard(_)),
"expected InvalidWildcard, got: {err}"
);
}
#[test]
fn builder_rejects_star_dot_only() {
let mut b = LocalRecords::builder();
let err = b
.add("*.", RecordData::A(ipv4("1.1.1.1")), 300)
.unwrap_err();
assert!(
matches!(err, BuildError::InvalidWildcard(_)),
"expected InvalidWildcard, got: {err}"
);
}
// An empty label in the middle of a name is expected to fail parsing.
#[test]
fn builder_rejects_invalid_name() {
let mut b = LocalRecords::builder();
let err = b
.add("foo..bar", RecordData::A(ipv4("1.1.1.1")), 300)
.unwrap_err();
assert!(
matches!(err, BuildError::InvalidName(_, _)),
"expected InvalidName, got: {err}"
);
}
// A 64-character label is expected to fail parsing.
#[test]
fn builder_rejects_label_too_long() {
let long_label = "a".repeat(64);
let name_str = format!("{long_label}.lan");
let mut b = LocalRecords::builder();
let err = b
.add(&name_str, RecordData::A(ipv4("1.1.1.1")), 300)
.unwrap_err();
assert!(
matches!(err, BuildError::InvalidName(_, _)),
"expected InvalidName, got: {err}"
);
}
// --- Debug output and case handling ---
#[test]
fn matcher_debug_shows_lengths() {
let mut b = LocalRecords::builder();
b.add("a.lan", RecordData::A(ipv4("1.1.1.1")), 60).unwrap();
b.add("*.b.lan", RecordData::A(ipv4("2.2.2.2")), 60)
.unwrap();
let m = LocalMatcher::new(b.build());
let s = format!("{m:?}");
assert!(s.contains("exact_len: 1"), "debug: {s}");
assert!(s.contains("wildcard_len: 1"), "debug: {s}");
}
// Names that differ only in letter case must resolve to the same entry.
#[test]
fn lookup_is_case_insensitive() {
let r = records(&[("Host.LAN", RecordData::A(ipv4("1.2.3.4")), 300)]);
assert!(
matches!(
r.lookup(&name("host.lan"), Qtype::A),
LocalMatch::Answer { .. }
),
"lookup for lowercase must hit"
);
assert!(
matches!(
r.lookup(&name("HOST.LAN"), Qtype::A),
LocalMatch::Answer { .. }
),
"lookup for uppercase must hit"
);
}
#[test]
fn wildcard_lookup_is_case_insensitive() {
let r = records(&[("*.Home.LAN", RecordData::A(ipv4("10.0.0.1")), 300)]);
assert!(
matches!(
r.lookup(&name("DEVICE.home.lan"), Qtype::A),
LocalMatch::Answer { .. }
),
"wildcard lookup must be case-insensitive"
);
}
}