use std::{
collections::HashMap,
fmt,
net::{IpAddr, SocketAddr},
};
use tokio_util::task::TaskTracker;
use tracing::warn;
use crate::{
codec::{message::Question, name::Name},
resolver::upstream::{
self, DEFAULT_QUERY_TIMEOUT, ForwardResult, RandomSelector, UpstreamConfig, UpstreamHealth,
UpstreamPool, UpstreamTransport,
},
storage::forward_zones::ForwardZone,
};
const DEFAULT_FORWARD_PORT: u16 = 53;
pub struct ForwardZoneSet {
zones: HashMap<Box<str>, SocketAddr>,
forwarders: HashMap<SocketAddr, UpstreamPool>,
health: UpstreamHealth,
}
impl ForwardZoneSet {
#[must_use]
pub fn empty() -> Self {
Self {
zones: HashMap::new(),
forwarders: HashMap::new(),
health: UpstreamHealth::new(),
}
}
pub async fn build(rows: &[ForwardZone], tracker: &TaskTracker) -> Self {
let mut zones: HashMap<Box<str>, SocketAddr> = HashMap::new();
for row in rows {
let Some(target) = row.target.as_deref().and_then(parse_target) else {
warn!(
zone = %row.zone_suffix,
target = ?row.target,
"forward zone has no usable target; skipping"
);
continue;
};
let Ok(suffix) = row.zone_suffix.parse::<Name>() else {
warn!(zone = %row.zone_suffix, "invalid forward-zone suffix; skipping");
continue;
};
zones.insert(suffix.as_str().into(), target);
}
let mut forwarders: HashMap<SocketAddr, UpstreamPool> = HashMap::new();
for &target in zones.values() {
if forwarders.contains_key(&target) {
continue;
}
let config = UpstreamConfig {
addr: target,
transport: UpstreamTransport::Udp,
tls_server_name: None,
http_endpoint: None,
};
let pool = UpstreamPool::connect(
std::slice::from_ref(&config),
tracker,
std::sync::Arc::new(RandomSelector),
0,
DEFAULT_QUERY_TIMEOUT,
)
.await;
forwarders.insert(target, pool);
}
Self {
zones,
forwarders,
health: UpstreamHealth::new(),
}
}
#[must_use]
pub fn match_target(&self, qname: &Name) -> Option<SocketAddr> {
if self.zones.is_empty() {
return None;
}
let mut search = qname.as_str();
loop {
if let Some(&target) = self.zones.get(search) {
return Some(target);
}
match search.find('.') {
Some(pos) => {
search = &search[pos + 1..];
if search.is_empty() || search == "." {
return None;
}
}
None => return None,
}
}
}
pub async fn forward(
&self,
target: SocketAddr,
question: &Question,
) -> upstream::Result<ForwardResult> {
match self.forwarders.get(&target) {
Some(pool) => pool.forward(question, &self.health).await,
None => Err(upstream::Error::AllUpstreamsFailed { attempts: 0 }),
}
}
#[must_use]
pub fn len(&self) -> usize {
self.zones.len()
}
#[must_use]
pub fn is_empty(&self) -> bool {
self.zones.is_empty()
}
}
impl Default for ForwardZoneSet {
fn default() -> Self {
Self::empty()
}
}
impl fmt::Debug for ForwardZoneSet {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("ForwardZoneSet")
.field("zones", &self.zones.len())
.field("forwarders", &self.forwarders.len())
.finish_non_exhaustive()
}
}
fn parse_target(s: &str) -> Option<SocketAddr> {
if let Ok(sa) = s.parse::<SocketAddr>() {
Some(sa)
} else {
s.parse::<IpAddr>()
.ok()
.map(|ip| SocketAddr::new(ip, DEFAULT_FORWARD_PORT))
}
}
#[cfg(test)]
mod tests {
use super::*;
fn name(s: &str) -> Name {
s.parse().expect("valid name")
}
fn zone_set(pairs: &[(&str, &str)]) -> ForwardZoneSet {
let mut zones = HashMap::new();
for (suffix, target) in pairs {
let n: Name = suffix.parse().unwrap();
zones.insert(n.as_str().into(), parse_target(target).unwrap());
}
ForwardZoneSet {
zones,
forwarders: HashMap::new(),
health: UpstreamHealth::new(),
}
}
#[test]
fn parse_target_ip_and_socket() {
assert_eq!(
parse_target("192.168.1.1"),
Some("192.168.1.1:53".parse().unwrap())
);
assert_eq!(
parse_target("10.0.0.1:5353"),
Some("10.0.0.1:5353".parse().unwrap())
);
assert_eq!(
parse_target("fd00::1"),
Some("[fd00::1]:53".parse().unwrap())
);
assert_eq!(parse_target("not-an-ip"), None);
}
#[test]
fn match_subdomain_and_apex() {
let set = zone_set(&[("168.192.in-addr.arpa", "192.168.1.1")]);
let target: SocketAddr = "192.168.1.1:53".parse().unwrap();
assert_eq!(
set.match_target(&name("5.1.168.192.in-addr.arpa")),
Some(target)
);
assert_eq!(
set.match_target(&name("168.192.in-addr.arpa")),
Some(target)
);
}
#[test]
fn non_matching_name_is_none() {
let set = zone_set(&[("168.192.in-addr.arpa", "192.168.1.1")]);
assert_eq!(set.match_target(&name("example.com")), None);
assert_eq!(set.match_target(&name("5.1.10.in-addr.arpa")), None);
}
#[test]
fn empty_set_matches_nothing() {
let set = ForwardZoneSet::empty();
assert!(set.is_empty());
assert_eq!(set.match_target(&name("5.1.168.192.in-addr.arpa")), None);
}
#[test]
fn most_specific_zone_wins() {
let set = zone_set(&[
("10.in-addr.arpa", "10.0.0.1"),
("0.10.in-addr.arpa", "10.0.0.2"),
]);
let specific: SocketAddr = "10.0.0.2:53".parse().unwrap();
let general: SocketAddr = "10.0.0.1:53".parse().unwrap();
assert_eq!(
set.match_target(&name("5.1.0.10.in-addr.arpa")),
Some(specific)
);
assert_eq!(
set.match_target(&name("5.1.9.10.in-addr.arpa")),
Some(general)
);
}
#[tokio::test]
async fn build_dedups_forwarders_by_target() {
let rows = vec![
ForwardZone {
id: 1,
zone_suffix: "10.in-addr.arpa".to_owned(),
target: Some("10.0.0.1".to_owned()),
enabled: true,
sort_order: 0,
},
ForwardZone {
id: 2,
zone_suffix: "168.192.in-addr.arpa".to_owned(),
target: Some("10.0.0.1".to_owned()),
enabled: true,
sort_order: 1,
},
ForwardZone {
id: 3,
zone_suffix: "16.172.in-addr.arpa".to_owned(),
target: Some("10.0.0.2:5353".to_owned()),
enabled: true,
sort_order: 2,
},
];
let tracker = TaskTracker::new();
let set = ForwardZoneSet::build(&rows, &tracker).await;
assert_eq!(set.len(), 3, "three zones mapped");
assert_eq!(
set.forwarders.len(),
2,
"two distinct targets → two forwarders"
);
}
#[tokio::test]
async fn build_skips_bad_target_and_suffix() {
let rows = vec![
ForwardZone {
id: 1,
zone_suffix: "168.192.in-addr.arpa".to_owned(),
target: Some("not-an-ip".to_owned()),
enabled: true,
sort_order: 0,
},
ForwardZone {
id: 2,
zone_suffix: "10.in-addr.arpa".to_owned(),
target: Some("10.0.0.1".to_owned()),
enabled: true,
sort_order: 1,
},
];
let tracker = TaskTracker::new();
let set = ForwardZoneSet::build(&rows, &tracker).await;
assert_eq!(set.len(), 1);
assert!(set.match_target(&name("5.1.10.in-addr.arpa")).is_some());
assert!(
set.match_target(&name("5.1.168.192.in-addr.arpa"))
.is_none()
);
}
}