use core::future::Future;
use core::marker::PhantomData;
use core::pin::Pin;
use std::boxed::Box;
use bytes::Bytes;
use tracing::trace;
use crate::base::name::{FlattenInto, Label};
use crate::base::scan::ScannerError;
use crate::base::{Name, ParsedName, Record, Rtype, ToName};
use crate::rdata::ZoneRecordData;
use crate::zonetree::{Rrset, SharedRrset};
use super::error::OutOfZone;
use super::types::ZoneUpdate;
use super::util::rel_name_rev_iter;
use super::{InMemoryZoneDiff, WritableZone, WritableZoneNode, Zone};
/// Applies a stream of [`ZoneUpdate`]s to a [`Zone`].
///
/// Changes are staged through a zone writer and only become visible once
/// committed, which happens when a `BeginBatchDelete` or `Finished` update
/// is applied (see [`ZoneUpdater::apply`]). Dropping an updater before it
/// reaches `Finished` leaves uncommitted changes unapplied; the
/// `check_rollback` test exercises this.
///
/// `N` is the name type used by incoming records (owner and RDATA names);
/// it defaults to `ParsedName<Bytes>`.
pub struct ZoneUpdater<N = ParsedName<Bytes>> {
    /// The zone being updated; its apex name anchors owner-name lookups.
    zone: Zone,
    /// Writer for `zone`, reopened after each intermediate commit.
    write: ReopenableZoneWriter,
    /// Progress through the update stream; `Finished` rejects further updates.
    state: ZoneUpdaterState,
    /// Ties the struct to `N` without storing a value of that type.
    _phantom: PhantomData<N>,
}
impl<N> ZoneUpdater<N>
where
N: ToName + Clone,
ZoneRecordData<Bytes, N>: FlattenInto<ZoneRecordData<Bytes, Name<Bytes>>>,
{
pub fn new(
zone: Zone,
) -> Pin<Box<dyn Future<Output = Result<Self, Error>> + Send>> {
Box::pin(async move {
let write = ReopenableZoneWriter::new(zone.clone()).await?;
Ok(Self {
zone,
write,
state: Default::default(),
_phantom: PhantomData,
})
})
}
}
impl<N> ZoneUpdater<N>
where
    N: ToName + Clone,
    ZoneRecordData<Bytes, N>: FlattenInto<ZoneRecordData<Bytes, Name<Bytes>>>,
{
    /// Applies a single [`ZoneUpdate`] to the zone.
    ///
    /// Record additions and deletions are staged on the open writer.
    /// `BeginBatchDelete` commits what has been staged so far and reopens
    /// the writer. `Finished` writes the final SOA, commits and closes the
    /// writer. Only these two variants can return `Some(diff)`.
    ///
    /// # Errors
    ///
    /// Returns [`Error::Finished`] once a `Finished` update has been applied
    /// successfully. Otherwise it propagates errors from the zone writer and
    /// the record helpers.
    pub async fn apply(
        &mut self,
        update: ZoneUpdate<Record<N, ZoneRecordData<Bytes, N>>>,
    ) -> Result<Option<InMemoryZoneDiff>, Error> {
        trace!("Update: {update}");

        if self.is_finished() {
            return Err(Error::Finished);
        }

        let diff = match update {
            ZoneUpdate::DeleteAllRecords => {
                self.write.remove_all().await?;
                None
            }
            ZoneUpdate::DeleteRecord(rec) => {
                self.delete_record_from_rrset(rec).await?;
                None
            }
            ZoneUpdate::AddRecord(rec) => {
                self.add_record_to_rrset(rec).await?;
                None
            }
            ZoneUpdate::BeginBatchDelete(_old_soa) => {
                // Everything staged so far forms one committed version; the
                // next batch is staged on a freshly reopened writer.
                let committed = self.write.commit().await?;
                self.write.reopen().await?;
                self.state = ZoneUpdaterState::Batching;
                committed
            }
            ZoneUpdate::BeginBatchAdd(new_soa) => {
                self.update_soa(new_soa).await?;
                self.state = ZoneUpdaterState::Batching;
                None
            }
            ZoneUpdate::Finished(zone_soa) => {
                self.update_soa(zone_soa).await?;
                let committed = self.write.commit().await?;
                self.write.close()?;
                self.state = ZoneUpdaterState::Finished;
                committed
            }
        };

        Ok(diff)
    }

    /// Reports whether a `Finished` update has been applied successfully.
    pub fn is_finished(&self) -> bool {
        matches!(self.state, ZoneUpdaterState::Finished)
    }
}
impl<N> ZoneUpdater<N>
where
    N: ToName + Clone,
    ZoneRecordData<Bytes, N>: FlattenInto<ZoneRecordData<Bytes, Name<Bytes>>>,
{
    /// Returns a writable handle to the node for `rec`'s owner.
    ///
    /// Walks the labels between the apex and the owner, as produced by
    /// `rel_name_rev_iter`, calling `update_child` for each one. Returns
    /// `Ok(None)` when the owner is the apex itself; callers then use the
    /// writer's root node instead.
    ///
    /// # Errors
    ///
    /// Propagates the `rel_name_rev_iter` error (converted via
    /// `From<OutOfZone>`, presumably when the owner lies outside the zone).
    /// Also propagates any I/O error from opening a child node.
    async fn get_writable_child_node_for_owner(
        &mut self,
        rec: &Record<N, ZoneRecordData<Bytes, N>>,
    ) -> Result<Option<Box<dyn WritableZoneNode>>, Error> {
        let mut it = rel_name_rev_iter(self.zone.apex_name(), rec.owner())?;
        let Some(label) = it.next() else {
            return Ok(None);
        };
        let mut child_node = self.write.update_child(label).await?;
        for label in it {
            child_node = child_node.update_child(label).await?;
        }
        Ok(Some(child_node))
    }

    /// Writes an SOA RRset containing only `new_soa` to the root node.
    ///
    /// # Errors
    ///
    /// - [`Error::NotSoaRecord`] if `new_soa` is not of type SOA.
    /// - [`Error::IoError`] if its data cannot be flattened or the write
    ///   fails.
    async fn update_soa(
        &mut self,
        new_soa: Record<N, ZoneRecordData<Bytes, N>>,
    ) -> Result<(), Error> {
        if new_soa.rtype() != Rtype::SOA {
            return Err(Error::NotSoaRecord);
        }
        let mut rrset = Rrset::new(Rtype::SOA, new_soa.ttl());
        let Ok(flattened) = new_soa.data().clone().try_flatten_into() else {
            return Err(Error::IoError(std::io::Error::custom(
                "Unable to flatten bytes",
            )));
        };
        rrset.push_data(flattened);
        self.write
            .update_root_rrset(SharedRrset::new(rrset))
            .await?;
        Ok(())
    }

    /// Removes the single RR `rec` from its RRset.
    ///
    /// The RRset is rebuilt without any data equal to `rec`'s data. If
    /// nothing remains, the RRset is removed from the node.
    async fn delete_record_from_rrset(
        &mut self,
        rec: Record<N, ZoneRecordData<Bytes, N>>,
    ) -> Result<(), Error> {
        let tree_node = self.get_writable_child_node_for_owner(&rec).await?;
        let tree_node = tree_node.as_ref().unwrap_or(self.write.root());

        let rtype = rec.rtype();
        let data = rec.data();
        let existing_rrset = tree_node.get_rrset(rtype).await?;

        // Keep the TTL of the RRset being edited. The TTL carried by a delete
        // instruction (zero in an RFC 2136 UPDATE, for instance) says nothing
        // about the records that remain (RFC 2181 §5.2). The record's own TTL
        // is only a fallback when there is no existing RRset, in which case
        // the result is empty and gets removed anyway.
        let ttl = existing_rrset
            .as_ref()
            .map_or(rec.ttl(), |existing| existing.ttl());

        let mut rrset = Rrset::new(rtype, ttl);
        if let Some(existing_rrset) = &existing_rrset {
            for existing_data in existing_rrset.data() {
                if existing_data != data {
                    rrset.push_data(existing_data.clone());
                }
            }
        }

        if rrset.is_empty() {
            tree_node.remove_rrset(rtype).await?;
        } else {
            tree_node.update_rrset(SharedRrset::new(rrset)).await?;
        }
        Ok(())
    }

    /// Adds the single RR `rec` to its RRset.
    ///
    /// The new record comes first, followed by the existing records. An
    /// existing record identical to the new one is skipped so the RRset
    /// never holds duplicate RRs (RFC 2181 §5). The RRset takes `rec`'s TTL.
    ///
    /// # Errors
    ///
    /// [`Error::IoError`] if the record data cannot be flattened or the
    /// write fails.
    async fn add_record_to_rrset(
        &mut self,
        rec: Record<N, ZoneRecordData<Bytes, N>>,
    ) -> Result<(), Error>
    where
        ZoneRecordData<Bytes, N>:
            FlattenInto<ZoneRecordData<Bytes, Name<Bytes>>>,
    {
        let tree_node = self.get_writable_child_node_for_owner(&rec).await?;
        let tree_node = tree_node.as_ref().unwrap_or(self.write.root());

        let rtype = rec.rtype();
        let ttl = rec.ttl();
        let Ok(data): Result<ZoneRecordData<Bytes, Name<Bytes>>, _> =
            rec.into_data().try_flatten_into()
        else {
            return Err(Error::IoError(std::io::Error::custom(
                "Unable to flatten bytes",
            )));
        };

        let mut rrset = Rrset::new(rtype, ttl);
        rrset.push_data(data.clone());
        if let Some(existing_rrset) = tree_node.get_rrset(rtype).await? {
            for existing_data in existing_rrset.data() {
                // Suppress an existing copy of the RR being added.
                if *existing_data != data {
                    rrset.push_data(existing_data.clone());
                }
            }
        }

        tree_node.update_rrset(SharedRrset::new(rrset)).await?;
        Ok(())
    }
}
/// Lifecycle of a `ZoneUpdater`.
///
/// `Normal` is the initial state. A `BeginBatchDelete` or `BeginBatchAdd`
/// update moves it to `Batching`. A successfully applied `Finished` update
/// moves it to `Finished`, after which further updates are rejected.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
enum ZoneUpdaterState {
    /// Initial state: no batch marker has been applied yet.
    #[default]
    Normal,

    /// At least one batch marker has been applied.
    Batching,

    /// Terminal state: the writer has been committed and closed.
    Finished,
}
/// A zone writer that can be committed and then reopened for further edits.
///
/// Both fields are `None` after `close`. `writable` is also `None` between a
/// commit and the following `reopen`.
struct ReopenableZoneWriter {
    /// Write handle for the zone, as returned by `Zone::write`.
    write: Option<Box<dyn WritableZone>>,
    /// The top-level (apex) node opened from `write`; edits go through it.
    writable: Option<Box<dyn WritableZoneNode>>,
}
impl ReopenableZoneWriter {
    /// Opens `zone` for writing and opens its top-level node.
    ///
    /// NOTE(review): the `true` passed to `open` presumably requests diff
    /// tracking; confirm against `WritableZone::open`.
    async fn new(zone: Zone) -> std::io::Result<Self> {
        let write = zone.write().await;
        let writable = Some(write.open(true).await?);
        let write = Some(write);
        Ok(Self { write, writable })
    }

    /// Commits the edits made through the currently open node.
    ///
    /// The open node is dropped first. The write handle is then committed
    /// and the diff it produces (if any) is returned. If no node is open,
    /// nothing is committed and `Ok(None)` is returned.
    ///
    /// # Errors
    ///
    /// - [`Error::Finished`] if a node was open but the write handle is gone.
    /// - [`Error::IoError`] if the commit fails.
    async fn commit(&mut self) -> Result<Option<InMemoryZoneDiff>, Error> {
        let Some(writable) = self.writable.take() else {
            return Ok(None);
        };
        // Release the node before committing the handle it was opened from.
        drop(writable);
        // NOTE(review): `false` is passed through to `WritableZone::commit`;
        // presumably it disables automatic SOA serial bumping. Confirm.
        let diff = self
            .write
            .as_mut()
            .ok_or(Error::Finished)?
            .commit(false)
            .await?;
        Ok(diff)
    }

    /// Opens a fresh top-level node on the existing write handle.
    ///
    /// # Errors
    ///
    /// [`Error::Finished`] if the writer was closed, or [`Error::IoError`]
    /// if opening fails.
    async fn reopen(&mut self) -> Result<(), Error> {
        self.writable = Some(
            self.write
                .as_mut()
                .ok_or(Error::Finished)?
                .open(true)
                .await?,
        );
        Ok(())
    }

    /// Drops the open node (if any) and the write handle.
    ///
    /// # Errors
    ///
    /// [`Error::Finished`] if the write handle was already gone.
    fn close(&mut self) -> Result<(), Error> {
        self.writable.take();
        self.write.take().ok_or(Error::Finished)?;
        Ok(())
    }

    /// Removes all records through the open node; does nothing if no node
    /// is open.
    async fn remove_all(&mut self) -> std::io::Result<()> {
        if let Some(writable) = &mut self.writable {
            writable.remove_all().await?;
        }
        Ok(())
    }

    /// Opens the child of the top-level node that is labelled `label`.
    ///
    /// # Errors
    ///
    /// Returns an I/O error if no node is currently open (instead of
    /// panicking), or if the underlying `update_child` call fails.
    async fn update_child(
        &self,
        label: &Label,
    ) -> std::io::Result<Box<dyn WritableZoneNode>> {
        self.open_root()?.update_child(label).await
    }

    /// Writes `rrset` to the top-level node via `update_rrset`.
    ///
    /// # Errors
    ///
    /// Returns an I/O error if no node is currently open (instead of
    /// panicking), or if the write fails.
    async fn update_root_rrset(
        &self,
        rrset: SharedRrset,
    ) -> std::io::Result<()> {
        self.open_root()?.update_rrset(rrset).await
    }

    /// Returns the currently open top-level node.
    ///
    /// # Panics
    ///
    /// Panics if no node is open, i.e. after `close`, or after a commit that
    /// was not followed by a successful `reopen`.
    // Returns `&Box<_>` rather than `&dyn _` because callers pair it with
    // `Option<Box<dyn WritableZoneNode>>::as_ref()` in `unwrap_or`.
    #[allow(clippy::borrowed_box)]
    fn root(&self) -> &Box<dyn WritableZoneNode> {
        self.writable
            .as_ref()
            .expect("zone writer has no open node (closed or not reopened)")
    }

    /// Fallible counterpart of `root`: returns the open top-level node, or
    /// an I/O error if none is open.
    fn open_root(&self) -> std::io::Result<&dyn WritableZoneNode> {
        self.writable
            .as_deref()
            .ok_or_else(|| std::io::Error::custom("zone writer is not open"))
    }
}
#[cfg(test)]
mod tests {
    use core::str::FromStr;
    use core::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;
    use std::vec::Vec;

    use bytes::BytesMut;
    use octseq::Octets;

    use crate::base::iana::{Class, Rcode};
    use crate::base::message_builder::{AnswerBuilder, QuestionBuilder};
    use crate::base::net::Ipv4Addr;
    use crate::base::rdata::ComposeRecordData;
    use crate::base::{
        Message, MessageBuilder, Name, ParsedName, Record, Serial, Ttl,
    };
    use crate::logging::init_logging;
    use crate::net::xfr::protocol::XfrResponseInterpreter;
    use crate::rdata::{Ns, Soa, A};
    use crate::zonetree::ZoneBuilder;

    use super::*;

    #[tokio::test]
    async fn write_soa_read_soa() {
        init_logging();

        let zone = mk_empty_zone("example.com");
        let mut updater = ZoneUpdater::new(zone.clone()).await.unwrap();

        let qname: Name<Bytes> = Name::from_str("example.com").unwrap();
        let s = Serial::now();
        let soa = mk_soa(s);
        let soa_data = ZoneRecordData::Soa(soa.clone());
        let soa_rec = Record::new(
            ParsedName::from(qname),
            Class::IN,
            Ttl::from_secs(0),
            soa_data,
        );

        updater
            .apply(ZoneUpdate::AddRecord(soa_rec.clone()))
            .await
            .unwrap();
        let diff = updater
            .apply(ZoneUpdate::Finished(soa_rec.clone()))
            .await
            .unwrap();

        let answer = query_zone(&zone, "example.com", Rtype::SOA);
        let found_soa_rec = answer
            .answer()
            .unwrap()
            .limit_to::<Soa<_>>()
            .next()
            .unwrap()
            .unwrap()
            .into_data();
        assert_eq!(found_soa_rec, soa);
        assert!(diff.is_none());
    }

    #[tokio::test]
    async fn diff_check() {
        init_logging();

        let zone = mk_empty_zone("example.com");
        let mut updater = ZoneUpdater::new(zone.clone()).await.unwrap();

        let qname: Name<Bytes> = Name::from_str("example.com").unwrap();
        let s = Serial(20240922);
        let soa = mk_soa(s);
        let soa_data = ZoneRecordData::Soa(soa.clone());
        let soa_rec = Record::new(
            ParsedName::from(qname.clone()),
            Class::IN,
            Ttl::from_secs(0),
            soa_data,
        );

        updater
            .apply(ZoneUpdate::AddRecord(soa_rec.clone()))
            .await
            .unwrap();
        let diff = updater
            .apply(ZoneUpdate::Finished(soa_rec.clone()))
            .await
            .unwrap();
        assert!(diff.is_none());

        let soa = mk_soa(s.add(1));
        let soa_data = ZoneRecordData::Soa(soa.clone());
        let soa_rec = Record::new(
            ParsedName::from(qname),
            Class::IN,
            Ttl::from_secs(0),
            soa_data,
        );

        // A finished updater must reject further updates.
        assert!(updater.is_finished());
        let res = updater.apply(ZoneUpdate::AddRecord(soa_rec.clone())).await;
        assert!(matches!(res, Err(crate::zonetree::update::Error::Finished)));

        let mut updater = ZoneUpdater::new(zone.clone()).await.unwrap();
        updater
            .apply(ZoneUpdate::AddRecord(soa_rec.clone()))
            .await
            .unwrap();
        let diff = updater
            .apply(ZoneUpdate::Finished(soa_rec.clone()))
            .await
            .unwrap();

        let answer = query_zone(&zone, "example.com", Rtype::SOA);
        let found_soa_rec = answer
            .answer()
            .unwrap()
            .limit_to::<Soa<_>>()
            .next()
            .unwrap()
            .unwrap()
            .into_data();
        assert_eq!(found_soa_rec, soa);

        assert!(diff.is_some());
        let diff = diff.unwrap();
        assert_eq!(diff.start_serial, Serial(20240922));
        assert_eq!(diff.end_serial, Serial(20240923));
    }

    #[tokio::test]
    async fn axfr_response_generates_expected_events() {
        init_logging();

        let zone = mk_empty_zone("example.com");
        let mut updater = ZoneUpdater::new(zone.clone()).await.unwrap();

        let req = mk_request("example.com", Rtype::AXFR).into_message();
        let mut interpreter = XfrResponseInterpreter::new();

        let mut answer = mk_empty_answer(&req, Rcode::NOERROR);
        let serial = Serial::now();
        let soa = mk_soa(serial);
        add_answer_record(&req, &mut answer, soa.clone());
        let a_1 = A::new(Ipv4Addr::LOCALHOST);
        add_answer_record(&req, &mut answer, a_1.clone());
        let a_2 = A::new(Ipv4Addr::BROADCAST);
        add_answer_record(&req, &mut answer, a_2.clone());
        add_answer_record(&req, &mut answer, soa.clone());
        let resp = answer.into_message();

        let it = interpreter.interpret_response(resp).unwrap();
        for update in it {
            let update = update.unwrap();
            updater.apply(update).await.unwrap();
        }

        let answer = query_zone(&zone, "example.com", Rtype::SOA);
        let mut answers = answer.answer().unwrap().limit_to::<Soa<_>>();
        assert_eq!(answers.next().unwrap().unwrap().into_data(), soa);
        assert_eq!(answers.next(), None);

        let answer = query_zone(&zone, "example.com", Rtype::A);
        let mut answers = answer.answer().unwrap().limit_to::<A>();
        assert_eq!(answers.next().unwrap().unwrap().into_data(), a_2);
        assert_eq!(answers.next().unwrap().unwrap().into_data(), a_1);
        assert_eq!(answers.next(), None);
    }

    #[tokio::test]
    async fn rfc_1995_ixfr_example() {
        fn mk_rfc_1995_ixfr_example_soa(
            serial: u32,
        ) -> Record<ParsedName<Bytes>, ZoneRecordData<Bytes, ParsedName<Bytes>>>
        {
            Record::new(
                ParsedName::from(Name::from_str("JAIN.AD.JP.").unwrap()),
                Class::IN,
                Ttl::from_secs(0),
                Soa::new(
                    ParsedName::from(
                        Name::from_str("NS.JAIN.AD.JP.").unwrap(),
                    ),
                    ParsedName::from(
                        Name::from_str("mohta.jain.ad.jp.").unwrap(),
                    ),
                    Serial(serial),
                    Ttl::from_secs(600),
                    Ttl::from_secs(600),
                    Ttl::from_secs(3600000),
                    Ttl::from_secs(604800),
                )
                .into(),
            )
        }

        init_logging();

        let zone = mk_empty_zone("JAIN.AD.JP.");
        let mut updater = ZoneUpdater::new(zone.clone()).await.unwrap();

        // Initial zone content (serial 1).
        let soa_1 = mk_rfc_1995_ixfr_example_soa(1);
        updater
            .apply(ZoneUpdate::AddRecord(soa_1.clone()))
            .await
            .unwrap();

        let ns_1 = Record::new(
            ParsedName::from(Name::from_str("JAIN.AD.JP.").unwrap()),
            Class::IN,
            Ttl::from_secs(0),
            Ns::new(ParsedName::from(
                Name::from_str("NS.JAIN.AD.JP.").unwrap(),
            ))
            .into(),
        );
        updater
            .apply(ZoneUpdate::AddRecord(ns_1.clone()))
            .await
            .unwrap();

        let a_1 = Record::new(
            ParsedName::from(Name::from_str("NS.JAIN.AD.JP.").unwrap()),
            Class::IN,
            Ttl::from_secs(0),
            A::new(Ipv4Addr::new(133, 69, 136, 1)).into(),
        );
        updater
            .apply(ZoneUpdate::AddRecord(a_1.clone()))
            .await
            .unwrap();

        let nezu = Record::new(
            ParsedName::from(Name::from_str("NEZU.JAIN.AD.JP.").unwrap()),
            Class::IN,
            Ttl::from_secs(0),
            A::new(Ipv4Addr::new(133, 69, 136, 5)).into(),
        );
        updater
            .apply(ZoneUpdate::AddRecord(nezu.clone()))
            .await
            .unwrap();

        // Serial 1 -> 2.
        let diff_1 = updater
            .apply(ZoneUpdate::BeginBatchDelete(soa_1.clone()))
            .await
            .unwrap();
        updater
            .apply(ZoneUpdate::DeleteRecord(nezu.clone()))
            .await
            .unwrap();

        let soa_2 = mk_rfc_1995_ixfr_example_soa(2);
        updater
            .apply(ZoneUpdate::BeginBatchAdd(soa_2.clone()))
            .await
            .unwrap();

        let a_2 = Record::new(
            ParsedName::from(Name::from_str("JAIN-BB.JAIN.AD.JP.").unwrap()),
            Class::IN,
            Ttl::from_secs(0),
            A::new(Ipv4Addr::new(133, 69, 136, 4)).into(),
        );
        updater
            .apply(ZoneUpdate::AddRecord(a_2.clone()))
            .await
            .unwrap();

        let a_3 = Record::new(
            ParsedName::from(Name::from_str("JAIN-BB.JAIN.AD.JP.").unwrap()),
            Class::IN,
            Ttl::from_secs(0),
            A::new(Ipv4Addr::new(192, 41, 197, 2)).into(),
        );
        updater
            .apply(ZoneUpdate::AddRecord(a_3.clone()))
            .await
            .unwrap();

        // Serial 2 -> 3.
        let diff_2 = updater
            .apply(ZoneUpdate::BeginBatchDelete(soa_2.clone()))
            .await
            .unwrap();
        updater
            .apply(ZoneUpdate::DeleteRecord(a_2.clone()))
            .await
            .unwrap();

        let soa_3 = mk_rfc_1995_ixfr_example_soa(3);
        updater
            .apply(ZoneUpdate::BeginBatchAdd(soa_3.clone()))
            .await
            .unwrap();

        let a_4 = Record::new(
            ParsedName::from(Name::from_str("JAIN-BB.JAIN.AD.JP.").unwrap()),
            Class::IN,
            Ttl::from_secs(0),
            A::new(Ipv4Addr::new(133, 69, 136, 3)).into(),
        );
        updater
            .apply(ZoneUpdate::AddRecord(a_4.clone()))
            .await
            .unwrap();

        let diff_3 = updater
            .apply(ZoneUpdate::Finished(soa_3.clone()))
            .await
            .unwrap();

        // Final zone content.
        assert_eq!(count_rrsets(&zone), 4);

        let answer = query_zone(&zone, "JAIN.AD.JP.", Rtype::SOA);
        let mut answers =
            answer.answer().unwrap().limit_to::<ZoneRecordData<_, _>>();
        assert_eq!(answers.next().unwrap().unwrap(), soa_3);
        assert_eq!(answers.next(), None);

        let answer = query_zone(&zone, "JAIN.AD.JP.", Rtype::NS);
        let mut answers =
            answer.answer().unwrap().limit_to::<ZoneRecordData<_, _>>();
        assert_eq!(answers.next().unwrap().unwrap(), ns_1);
        assert_eq!(answers.next(), None);

        let answer = query_zone(&zone, "NS.JAIN.AD.JP.", Rtype::A);
        let mut answers =
            answer.answer().unwrap().limit_to::<ZoneRecordData<_, _>>();
        assert_eq!(answers.next().unwrap().unwrap(), a_1);
        assert_eq!(answers.next(), None);

        let answer = query_zone(&zone, "JAIN-BB.JAIN.AD.JP.", Rtype::A);
        let mut answers =
            answer.answer().unwrap().limit_to::<ZoneRecordData<_, _>>();
        assert_eq!(answers.next().unwrap().unwrap(), a_4);
        assert_eq!(answers.next().unwrap().unwrap(), a_3);
        assert_eq!(answers.next(), None);

        // Diffs.
        assert!(diff_1.is_none());

        assert!(diff_2.is_some());
        let diff_2 = diff_2.unwrap();
        assert_eq!(diff_2.start_serial, Serial(1));
        assert_eq!(diff_2.end_serial, Serial(2));
        assert_eq!(diff_2.removed.len(), 2);
        let mut expected = vec![nezu.into_data()];
        let mut actual = diff_2
            .removed
            .get(&(Name::from_str("NEZU.JAIN.AD.JP.").unwrap(), Rtype::A))
            .unwrap()
            .data()
            .to_vec();
        expected.sort();
        actual.sort();
        assert_eq!(expected, actual);
        assert_eq!(diff_2.added.len(), 2);
        let mut expected = vec![a_2.clone().into_data(), a_3.into_data()];
        let mut actual = diff_2
            .added
            .get(&(Name::from_str("JAIN-BB.JAIN.AD.JP.").unwrap(), Rtype::A))
            .unwrap()
            .data()
            .to_vec();
        expected.sort();
        actual.sort();
        assert_eq!(expected, actual);

        assert!(diff_3.is_some());
        let diff_3 = diff_3.unwrap();
        assert_eq!(diff_3.start_serial, Serial(2));
        assert_eq!(diff_3.end_serial, Serial(3));
        assert_eq!(diff_3.removed.len(), 2);
        let mut expected = vec![a_2.into_data()];
        let mut actual = diff_3
            .removed
            .get(&(Name::from_str("JAIN-BB.JAIN.AD.JP.").unwrap(), Rtype::A))
            .unwrap()
            .data()
            .to_vec();
        expected.sort();
        actual.sort();
        assert_eq!(expected, actual);
        assert_eq!(diff_3.added.len(), 2);
        let mut expected = vec![a_4.into_data()];
        let mut actual = diff_3
            .added
            .get(&(Name::from_str("JAIN-BB.JAIN.AD.JP.").unwrap(), Rtype::A))
            .unwrap()
            .data()
            .to_vec();
        expected.sort();
        actual.sort();
        assert_eq!(expected, actual);
    }

    #[tokio::test]
    async fn check_rollback() {
        init_logging();

        let zone = mk_empty_zone("example.com");
        let mut updater = ZoneUpdater::new(zone.clone()).await.unwrap();

        let req = mk_request("example.com", Rtype::AXFR).into_message();
        let mut interpreter = XfrResponseInterpreter::new();

        let mut answer = mk_empty_answer(&req, Rcode::NOERROR);
        let serial = Serial::now();
        let soa = mk_soa(serial);
        add_answer_record(&req, &mut answer, soa.clone());
        let a_1 = A::new(Ipv4Addr::LOCALHOST);
        add_answer_record(&req, &mut answer, a_1.clone());
        let a_2 = A::new(Ipv4Addr::BROADCAST);
        add_answer_record(&req, &mut answer, a_2.clone());
        add_answer_record(&req, &mut answer, soa.clone());
        let resp = answer.into_message();

        // Apply everything except the final `Finished`, so nothing is ever
        // committed.
        let it = interpreter.interpret_response(resp).unwrap();
        for update in it {
            let update = update.unwrap();
            if !matches!(update, ZoneUpdate::Finished(_)) {
                updater.apply(update).await.unwrap();
            }
        }

        drop(updater);

        assert_eq!(count_rrsets(&zone), 0);
    }

    //------------ Helpers ---------------------------------------------------

    /// Queries `zone` for `qname`/`qtype` and renders the answer into a
    /// response message whose answer section the tests can inspect.
    fn query_zone(zone: &Zone, qname: &str, qtype: Rtype) -> Message<Bytes> {
        let qname: Name<Bytes> = Name::from_str(qname).unwrap();
        let mut query = MessageBuilder::new_vec().question();
        query.push((qname.clone(), qtype)).unwrap();
        let message: Message<Vec<u8>> = query.into();
        zone.read()
            .query(qname, qtype)
            .unwrap()
            .to_message(&message, MessageBuilder::new_bytes())
            .into()
    }

    /// Counts how many times `walk` invokes its callback for `zone`. The
    /// callback receives an RRset, so this is the number of RRsets walked.
    fn count_rrsets(zone: &Zone) -> usize {
        let count = Arc::new(AtomicUsize::new(0));
        let cloned_count = count.clone();
        zone.read()
            .walk(Box::new(move |_name, _rrset, _at_zone_cut| {
                cloned_count.fetch_add(1, Ordering::SeqCst);
            }));
        count.load(Ordering::SeqCst)
    }

    fn mk_empty_zone(apex_name: &str) -> Zone {
        ZoneBuilder::new(Name::from_str(apex_name).unwrap(), Class::IN)
            .build()
    }

    fn mk_soa(serial: Serial) -> Soa<ParsedName<Bytes>> {
        let mname = ParsedName::from(Name::from_str("mname").unwrap());
        let rname = ParsedName::from(Name::from_str("rname").unwrap());
        let ttl = Ttl::from_secs(0);
        Soa::new(mname, rname, serial, ttl, ttl, ttl, ttl)
    }

    fn mk_request(qname: &str, qtype: Rtype) -> QuestionBuilder<BytesMut> {
        let req = MessageBuilder::new_bytes();
        let mut req = req.question();
        req.push((Name::vec_from_str(qname).unwrap(), qtype))
            .unwrap();
        req
    }

    fn mk_empty_answer(
        req: &Message<Bytes>,
        rcode: Rcode,
    ) -> AnswerBuilder<BytesMut> {
        let builder = MessageBuilder::new_bytes();
        builder.start_answer(req, rcode).unwrap()
    }

    fn add_answer_record<O: Octets, T: ComposeRecordData>(
        req: &Message<O>,
        answer: &mut AnswerBuilder<BytesMut>,
        item: T,
    ) {
        let question = req.sole_question().unwrap();
        let qname = question.qname();
        let qclass = question.qclass();
        answer
            .push((qname, qclass, Ttl::from_secs(0), item))
            .unwrap();
    }
}
/// Errors reported while applying zone updates.
#[derive(Debug)]
pub enum Error {
    /// Produced from an [`OutOfZone`] error, i.e. a record owner that does
    /// not belong to the zone being updated.
    OutOfZone,

    /// A record that had to be an SOA (e.g. a batch or `Finished` marker)
    /// was of another type.
    NotSoaRecord,

    /// An I/O error from the underlying zone writer, or a failure to
    /// flatten record data.
    IoError(std::io::Error),

    /// The updater has already finished (or its writer was closed) and
    /// cannot accept further updates.
    Finished,
}
impl std::fmt::Display for Error {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
match self {
Error::OutOfZone => f.write_str("OutOfZone"),
Error::NotSoaRecord => f.write_str("NotSoaRecord"),
Error::IoError(err) => write!(f, "I/O error: {err}"),
Error::Finished => f.write_str("Finished"),
}
}
}
impl From<std::io::Error> for Error {
fn from(err: std::io::Error) -> Self {
Self::IoError(err)
}
}
impl From<OutOfZone> for Error {
fn from(_: OutOfZone) -> Self {
Self::OutOfZone
}
}