use std::fmt::Debug;
use bytes::Bytes;
use crate::base::iana::Opcode;
use crate::base::{Message, ParsedName, Rtype};
use crate::rdata::{Soa, ZoneRecordData};
use crate::zonetree::types::ZoneUpdate;
use super::iterator::XfrZoneUpdateIterator;
use super::types::{Error, IxfrUpdateMode, ParsedRecord, XfrType};
use super::IterationError;
/// Interprets the response messages of a single zone transfer (AXFR or
/// IXFR).
///
/// Each response of the transfer is passed in turn to
/// [`interpret_response`](Self::interpret_response). State carried between
/// calls is created from the first accepted response.
#[derive(Default)]
pub struct XfrResponseInterpreter {
    /// State for the transfer in progress; `None` until the first response
    /// has been accepted.
    inner: Option<Inner>,
}

impl XfrResponseInterpreter {
    /// Creates an interpreter that has not yet seen any response.
    pub fn new() -> Self {
        Self { inner: None }
    }
}
impl XfrResponseInterpreter {
    /// Interprets the next response message of the transfer.
    ///
    /// The first accepted response establishes the transfer state. Later
    /// responses replace the stored message and continue with the same
    /// record processor.
    ///
    /// Returns an [`XfrZoneUpdateIterator`] over `resp` that drives the
    /// shared record processor.
    ///
    /// # Errors
    ///
    /// - [`Error::Finished`] if the transfer has already finished.
    /// - Any error reported while validating `resp` or while setting up the
    ///   transfer state from the first response.
    /// - Any error reported by [`XfrZoneUpdateIterator::new`].
    pub fn interpret_response(
        &mut self,
        resp: Message<Bytes>,
    ) -> Result<XfrZoneUpdateIterator<'_, '_>, Error> {
        if self.is_finished() {
            return Err(Error::Finished);
        }

        self.check_response(&resp)?;

        match self.inner.as_mut() {
            None => self.initialize(resp)?,
            Some(state) => state.resp = resp,
        }

        // Both branches above leave `self.inner` populated. A single match
        // cannot hand out the borrow directly because the returned iterator
        // borrows `self` for the whole output lifetime.
        let state = self.inner.as_mut().unwrap();
        XfrZoneUpdateIterator::new(&mut state.processor, &state.resp)
    }

    /// Returns `true` once the transfer has been marked finished.
    ///
    /// Always `false` before the first response has been accepted.
    pub fn is_finished(&self) -> bool {
        match &self.inner {
            Some(state) => state.processor.is_finished(),
            None => false,
        }
    }
}
impl XfrResponseInterpreter {
fn initialize(&mut self, resp: Message<Bytes>) -> Result<(), Error> {
self.inner = Some(Inner::new(resp)?);
Ok(())
}
fn check_response(&self, resp: &Message<Bytes>) -> Result<(), Error> {
let resp_header = resp.header();
let resp_counts = resp.header_counts();
if resp.is_error()
|| !resp_header.qr()
|| resp_header.opcode() != Opcode::QUERY
|| resp_header.tc()
|| resp_counts.ancount() == 0
|| resp_counts.nscount() != 0
{
return Err(Error::NotValidXfrResponse);
}
let qdcount = resp_counts.qdcount();
let first_message = self.inner.is_none();
if (first_message && qdcount != 1) || (!first_message && qdcount > 1)
{
return Err(Error::NotValidXfrResponse);
}
Ok(())
}
}
struct Inner {
resp: Message<Bytes>,
processor: RecordProcessor,
}
impl Inner {
fn new(resp: Message<Bytes>) -> Result<Self, Error> {
let answer = resp.answer().map_err(Error::ParseError)?;
let mut records = answer.limit_to();
let xfr_type = match resp.qtype() {
Some(Rtype::AXFR) => XfrType::Axfr,
Some(Rtype::IXFR) => XfrType::Ixfr,
_ => unreachable!(),
};
let Some(Ok(record)) = records.next() else {
return Err(Error::Malformed);
};
let ZoneRecordData::Soa(soa) = record.into_data() else {
return Err(Error::NotValidXfrResponse);
};
let state = RecordProcessor::new(xfr_type, soa);
Ok(Inner {
resp,
processor: state,
})
}
}
/// Record-level state machine for a zone transfer.
///
/// Turns each answer record of the transfer into zero, one or two
/// [`ZoneUpdate`]s. It tracks whether the transfer is being handled as AXFR
/// or IXFR; an IXFR response may turn out to be AXFR-style. For IXFR it also
/// tracks whether records are currently being deleted or added.
#[derive(Debug)]
pub(super) struct RecordProcessor {
    /// The transfer type actually in use. It starts as the requested type
    /// and is switched to AXFR when an IXFR response turns out to be
    /// AXFR-style.
    actual_xfr_type: XfrType,

    /// The SOA record that opened the transfer. A later SOA equal to it can
    /// mark the end of the transfer.
    initial_soa: Soa<ParsedName<Bytes>>,

    /// The SOA most recently seen while processing an IXFR.
    current_soa: Soa<ParsedName<Bytes>>,

    /// Whether IXFR records are currently being deleted or added. Toggled on
    /// every SOA record in an IXFR body.
    ixfr_update_mode: IxfrUpdateMode,

    /// Number of records accepted by `process_record` so far.
    rr_count: usize,

    /// Whether `ZoneUpdate::DeleteAllRecords` has already been emitted for an
    /// AXFR (or AXFR-style IXFR) transfer.
    axfr_delete_already_returned: bool,

    /// Set once the transfer is complete; further records are rejected.
    finished: bool,
}

impl RecordProcessor {
    /// Creates a processor for a transfer of type `initial_xfr_type` that
    /// was opened by `initial_soa`.
    fn new(
        initial_xfr_type: XfrType,
        initial_soa: Soa<ParsedName<Bytes>>,
    ) -> Self {
        Self {
            actual_xfr_type: initial_xfr_type,
            current_soa: initial_soa.clone(),
            initial_soa,
            ixfr_update_mode: IxfrUpdateMode::Adding,
            rr_count: 0,
            axfr_delete_already_returned: false,
            finished: false,
        }
    }

    /// Marks the transfer as complete; later records are rejected.
    pub(super) fn finish(&mut self) {
        self.finished = true;
    }

    /// Processes the next record of the transfer.
    ///
    /// Returns:
    /// - `Ok(None)` for the first record (the opening SOA), which yields no
    ///   update of its own.
    /// - `Ok(Some((ZoneUpdate::DeleteAllRecords, Some(update))))` for the
    ///   first update of a transfer handled as AXFR.
    /// - `Ok(Some((update, None)))` otherwise.
    ///
    /// # Errors
    ///
    /// - [`IterationError::AlreadyFinished`] if the transfer has already
    ///   finished.
    /// - [`IterationError::MissingInitialSoa`] if the first record is not an
    ///   SOA record.
    // The nested return type trips clippy's `type_complexity` lint; it is
    // deliberately kept inline rather than hidden behind an alias.
    #[allow(clippy::type_complexity)]
    pub(super) fn process_record(
        &mut self,
        rec: ParsedRecord,
    ) -> Result<
        Option<(ZoneUpdate<ParsedRecord>, Option<ZoneUpdate<ParsedRecord>>)>,
        IterationError,
    > {
        if self.finished {
            return Err(IterationError::AlreadyFinished);
        }
        self.rr_count += 1;

        let rec_soa = if let ZoneRecordData::Soa(soa) = rec.data() {
            Some(soa)
        } else {
            None
        };

        // The opening SOA produces no update itself; it only has to be
        // present.
        if self.rr_count == 1 {
            return match rec_soa {
                Some(_) => Ok(None),
                None => Err(IterationError::MissingInitialSoa),
            };
        }

        let is_initial_soa = rec_soa == Some(&self.initial_soa);

        let update = match self.actual_xfr_type {
            XfrType::Axfr => {
                // An AXFR is closed by a repeat of the opening SOA.
                if is_initial_soa {
                    ZoneUpdate::Finished(rec)
                } else {
                    ZoneUpdate::AddRecord(rec)
                }
            }
            // A non-SOA second record means the IXFR was answered with an
            // AXFR-style full zone, so switch to AXFR handling.
            XfrType::Ixfr
                if self.rr_count == 2 && rec.rtype() != Rtype::SOA =>
            {
                self.actual_xfr_type = XfrType::Axfr;
                ZoneUpdate::AddRecord(rec)
            }
            XfrType::Ixfr => match rec_soa {
                // Every SOA starts a new batch, alternating between
                // deletions and additions.
                Some(batch_soa) => {
                    self.ixfr_update_mode.toggle();
                    self.current_soa = batch_soa.clone();
                    match self.ixfr_update_mode {
                        IxfrUpdateMode::Deleting if is_initial_soa => {
                            ZoneUpdate::Finished(rec)
                        }
                        IxfrUpdateMode::Deleting => {
                            ZoneUpdate::BeginBatchDelete(rec)
                        }
                        IxfrUpdateMode::Adding => {
                            ZoneUpdate::BeginBatchAdd(rec)
                        }
                    }
                }
                None => match self.ixfr_update_mode {
                    IxfrUpdateMode::Deleting => ZoneUpdate::DeleteRecord(rec),
                    IxfrUpdateMode::Adding => ZoneUpdate::AddRecord(rec),
                },
            },
        };

        self.finished |= matches!(update, ZoneUpdate::Finished(_));

        let emit_delete_all = self.actual_xfr_type == XfrType::Axfr
            && !self.axfr_delete_already_returned;
        if emit_delete_all {
            self.axfr_delete_already_returned = true;
            return Ok(Some((ZoneUpdate::DeleteAllRecords, Some(update))));
        }
        Ok(Some((update, None)))
    }

    /// Returns how many records have been processed so far.
    pub fn rr_count(&self) -> usize {
        self.rr_count
    }

    /// Returns the transfer type in effect. This may be AXFR even when the
    /// response question was IXFR.
    pub fn actual_xfr_type(&self) -> XfrType {
        self.actual_xfr_type
    }

    /// Returns `true` once the transfer has finished.
    pub fn is_finished(&self) -> bool {
        self.finished
    }
}