use core::future::{ready, Future};
use core::pin::Pin;
use core::str::FromStr;
use std::collections::HashMap;
use std::future::pending;
use std::io::BufReader;
use std::process::exit;
use std::sync::{Arc, Mutex};
use std::time::Duration;
use domain::rdata::{Soa, ZoneRecordData};
use octseq::Octets;
use rand::distr::Alphanumeric;
use rand::RngExt;
use tokio::net::{TcpListener, UdpSocket};
use tracing_subscriber::EnvFilter;
use domain::base::iana::{Class, Rcode};
use domain::base::name::OwnedLabel;
use domain::base::net::IpAddr;
use domain::base::{Name, Rtype, Serial, ToName, Ttl};
use domain::net::server::buf::VecBufSource;
use domain::net::server::dgram::DgramServer;
use domain::net::server::message::Request;
#[cfg(feature = "siphasher")]
use domain::net::server::middleware::cookies::CookiesMiddlewareSvc;
use domain::net::server::middleware::edns::EdnsMiddlewareSvc;
use domain::net::server::middleware::mandatory::MandatoryMiddlewareSvc;
use domain::net::server::middleware::notify::{
Notifiable, NotifyError, NotifyMiddlewareSvc,
};
use domain::net::server::middleware::tsig::TsigMiddlewareSvc;
use domain::net::server::middleware::xfr::{
XfrData, XfrDataProvider, XfrDataProviderError, XfrMiddlewareSvc,
};
use domain::net::server::service::{CallResult, ServiceResult};
use domain::net::server::stream::StreamServer;
use domain::net::server::util::{mk_builder_for_target, service_fn};
use domain::tsig::{Algorithm, Key, KeyName};
use domain::zonefile::inplace;
use domain::zonetree::{
Answer, InMemoryZoneDiff, Rrset, SharedRrset, StoredName,
};
use domain::zonetree::{Zone, ZoneTree};
#[tokio::main()]
async fn main() {
tracing_subscriber::fmt()
.with_env_filter(EnvFilter::from_default_env())
.with_thread_ids(true)
.without_time()
.try_init()
.ok();
let mut key_store = HashMap::<(KeyName, Algorithm), Key>::new();
let key_name = KeyName::from_str("demo-key").unwrap();
let secret = domain::utils::base64::decode::<Vec<u8>>(
"zlCZbVJPIhobIs1gJNQfrsS3xCxxsR9pMUrGwG8OgG8=",
)
.unwrap();
let key =
Key::new(Algorithm::Sha256, &secret, key_name.clone(), None, None)
.unwrap();
key_store.insert((key_name, Algorithm::Sha256), key);
let zone_bytes = include_bytes!("../test-data/zonefiles/nsd-example.txt");
let mut zone_bytes = BufReader::new(&zone_bytes[..]);
let mut zones = ZoneTree::new();
let reader =
inplace::Zonefile::load(&mut zone_bytes).unwrap_or_else(|err| {
eprintln!("Error reading zone file bytes: {err}");
exit(1);
});
let zone = Zone::try_from(reader).unwrap_or_else(|errors| {
eprintln!(
"{} zone file entries could not be parsed, aborting:",
errors.len()
);
for (name, err) in errors {
eprintln!(" {name}: {err}");
}
exit(1);
});
let new_refresh_ttl = Ttl::from_secs(15);
let mut writable_zone = zone.write().await;
let writable_zone_node = writable_zone.open(false).await.unwrap();
let Some(soa_rrset) =
writable_zone_node.get_rrset(Rtype::SOA).await.unwrap()
else {
panic!("Loaded zone has no SOA RRSET at the apex!");
};
let soa_rr = soa_rrset.first().unwrap();
let soa_rdata = soa_rr.data();
let ZoneRecordData::Soa(soa_rdata) = soa_rdata else {
unreachable!();
};
let new_soa_rdata = Soa::new(
soa_rdata.mname().clone(),
soa_rdata.rname().clone(),
soa_rdata.serial(),
new_refresh_ttl,
soa_rdata.retry(),
soa_rdata.expire(),
soa_rdata.minimum(),
);
let mut new_soa_rrset = Rrset::new(Rtype::SOA, soa_rrset.ttl());
new_soa_rrset.push_data(ZoneRecordData::Soa(new_soa_rdata));
writable_zone_node
.update_rrset(new_soa_rrset.into_shared())
.await
.unwrap();
drop(writable_zone_node);
writable_zone.commit(false).await.unwrap();
drop(writable_zone);
zones.insert_zone(zone).unwrap();
let zones = Arc::new(zones);
let zones_and_diffs = ZoneTreeWithDiffs::new(zones.clone());
let addr = "127.0.0.1:8053";
let svc = service_fn(my_service, zones.clone());
#[cfg(feature = "siphasher")]
let svc = XfrMiddlewareSvc::new(svc, zones_and_diffs.clone(), 1);
let svc = NotifyMiddlewareSvc::new(svc, DemoNotifyTarget);
let svc = CookiesMiddlewareSvc::with_random_secret(svc);
let svc = EdnsMiddlewareSvc::new(svc);
let svc = TsigMiddlewareSvc::new(svc, key_store);
let svc = MandatoryMiddlewareSvc::new(svc);
let svc = Arc::new(svc);
let sock = UdpSocket::bind(&addr).await.unwrap();
let sock = Arc::new(sock);
let mut udp_metrics = vec![];
let num_cores = std::thread::available_parallelism().unwrap().get();
for _i in 0..num_cores {
let udp_srv =
DgramServer::new(sock.clone(), VecBufSource, svc.clone());
let metrics = udp_srv.metrics();
udp_metrics.push(metrics);
tokio::spawn(async move { udp_srv.run().await });
}
let sock = TcpListener::bind(addr).await.unwrap();
let tcp_srv = StreamServer::new(sock, VecBufSource, svc);
let tcp_metrics = tcp_srv.metrics();
tokio::spawn(async move { tcp_srv.run().await });
eprintln!("Listening on {addr}");
eprintln!("Try:");
eprintln!(" dig @127.0.0.1 -p 8053 example.com");
eprintln!(" dig @127.0.0.1 -p 8053 example.com AXFR");
eprintln!(" dig @127.0.0.1 -p 8053 -y hmac-sha256:demo-key:zlCZbVJPIhobIs1gJNQfrsS3xCxxsR9pMUrGwG8OgG8= example.com AXFR");
eprintln!(" dig @127.0.0.1 -p 8053 +opcode=notify example.com SOA");
eprintln!(" cargo run --example ixfr-client --all-features -- 127.0.0.1:8053 example.com 2020080302");
eprintln!();
eprintln!("Tip: set env var RUST_LOG=info (or debug or trace) for more log output.");
tokio::spawn(async move {
loop {
tokio::time::sleep(Duration::from_millis(5000)).await;
let mut udp_num_connections = 0;
let mut udp_num_inflight_requests = 0;
let mut udp_num_pending_writes = 0;
let mut udp_num_received_requests = 0;
let mut udp_num_sent_responses = 0;
for metrics in udp_metrics.iter() {
udp_num_connections += metrics.num_connections();
udp_num_inflight_requests += metrics.num_inflight_requests();
udp_num_pending_writes += metrics.num_pending_writes();
udp_num_received_requests += metrics.num_received_requests();
udp_num_sent_responses += metrics.num_sent_responses();
}
eprintln!(
"Server status: #conn/#in-flight/#pending-writes/#msgs-recvd/#msgs-sent: UDP={}/{}/{}/{}/{} TCP={}/{}/{}/{}/{}",
udp_num_connections,
udp_num_inflight_requests,
udp_num_pending_writes,
udp_num_received_requests,
udp_num_sent_responses,
tcp_metrics.num_connections(),
tcp_metrics.num_inflight_requests(),
tcp_metrics.num_pending_writes(),
tcp_metrics.num_received_requests(),
tcp_metrics.num_sent_responses(),
);
}
});
tokio::spawn(async move {
let zone_name = Name::<Vec<u8>>::from_str("example.com").unwrap();
let mut label: Option<OwnedLabel> = None;
loop {
tokio::time::sleep(Duration::from_millis(10000)).await;
let zone = zones.get_zone(&zone_name, Class::IN).unwrap();
let mut writer = zone.write().await;
{
let node = writer.open(true).await.unwrap();
if let Some(old_label) = label {
let node = node.update_child(&old_label).await.unwrap();
node.remove_rrset(Rtype::A).await.unwrap();
}
let random_string: String = rand::rng()
.sample_iter(&Alphanumeric)
.take(7)
.map(char::from)
.collect();
let new_label = OwnedLabel::from_str(&random_string).unwrap();
let node = node.update_child(&new_label).await.unwrap();
let mut rrset = Rrset::new(Rtype::A, Ttl::from_secs(60));
let rec = domain::rdata::A::new("127.0.0.1".parse().unwrap());
rrset.push_data(rec.into());
node.update_rrset(SharedRrset::new(rrset)).await.unwrap();
label = Some(new_label);
}
let diff = writer.commit(true).await.unwrap();
if let Some(diff) = diff {
zones_and_diffs.add_diff(diff);
}
eprintln!(
"Added {} A record to zone example.com",
label.unwrap()
);
}
});
pending::<()>().await;
}
#[allow(clippy::type_complexity)]
fn my_service<RequestMeta>(
request: Request<Vec<u8>, RequestMeta>,
zones: Arc<ZoneTree>,
) -> ServiceResult<Vec<u8>> {
let question = request.message().sole_question().unwrap();
let zone = zones
.find_zone(question.qname(), question.qclass())
.map(|zone| zone.read());
let answer = match zone {
Some(zone) => {
let qname = question.qname().to_bytes();
let qtype = question.qtype();
zone.query(qname, qtype).unwrap()
}
None => Answer::new(Rcode::NXDOMAIN),
};
let builder = mk_builder_for_target();
let additional = answer.to_message(request.message(), builder);
Ok(CallResult::new(additional))
}
#[derive(Copy, Clone, Default, Debug)]
struct DemoNotifyTarget;
impl Notifiable for DemoNotifyTarget {
fn notify_zone_changed(
&self,
class: Class,
apex_name: &StoredName,
serial: Option<Serial>,
source: IpAddr,
) -> Pin<
Box<dyn Future<Output = Result<(), NotifyError>> + Sync + Send + '_>,
> {
eprintln!("Notify received from {source} of change to zone {apex_name} in class {class} with serial {serial:?}");
let res = match apex_name.to_string().to_lowercase().as_str() {
"example.com" => Ok(()),
"othererror.com" => Err(NotifyError::Other),
_ => Err(NotifyError::NotAuthForZone),
};
Box::pin(ready(res))
}
}
#[derive(Clone)]
struct ZoneTreeWithDiffs {
zones: Arc<ZoneTree>,
diffs: Arc<Mutex<Vec<InMemoryZoneDiff>>>,
}
impl ZoneTreeWithDiffs {
fn new(zones: Arc<ZoneTree>) -> Self {
Self {
zones,
diffs: Default::default(),
}
}
fn add_diff(&self, diff: InMemoryZoneDiff) {
self.diffs.lock().unwrap().push(diff);
}
fn get_diffs(&self, diff_from: Option<Serial>) -> Vec<InMemoryZoneDiff> {
let diffs = self.diffs.lock().unwrap();
if let Some(idx) = diffs
.iter()
.position(|diff| Some(diff.start_serial) == diff_from)
{
diffs[idx..].to_vec()
} else {
vec![]
}
}
}
impl XfrDataProvider<Option<Key>> for ZoneTreeWithDiffs {
type Diff = InMemoryZoneDiff;
fn request<Octs>(
&self,
req: &Request<Octs, Option<Key>>,
diff_from: Option<Serial>,
) -> Pin<
Box<
dyn Future<
Output = Result<
XfrData<Self::Diff>,
XfrDataProviderError,
>,
> + Sync
+ Send,
>,
>
where
Octs: Octets + Send + Sync,
{
if req.metadata().is_none() {
eprintln!("Rejecting request due to missing TSIG key");
return Box::pin(ready(Err(XfrDataProviderError::Refused)));
}
let res = req
.message()
.sole_question()
.map_err(XfrDataProviderError::ParseError)
.and_then(|q| {
if let Some(zone) =
self.zones.find_zone(q.qname(), q.qclass())
{
Ok(XfrData::new(
zone.clone(),
self.get_diffs(diff_from),
false,
))
} else {
Err(XfrDataProviderError::UnknownZone)
}
});
Box::pin(ready(res))
}
}