use core::future::{ready, Future};
use core::pin::Pin;
use core::task::{Context, Poll};
use std::boxed::Box;
use std::collections::{hash_map, HashMap};
use std::ops;
use std::sync::Arc;
use std::vec::Vec;
use bytes::Bytes;
use futures_util::stream;
use serde::{Deserialize, Serialize};
use tracing::trace;
use super::traits::{ZoneDiff, ZoneDiffItem};
use crate::base::name::Name;
use crate::base::rdata::RecordData;
use crate::base::record::Record;
use crate::base::{iana::Rtype, Ttl};
use crate::base::{Serial, ToName};
use crate::rdata::ZoneRecordData;
/// Name type used for stored records: a [`Name`] over [`Bytes`] octets.
pub type StoredName = Name<Bytes>;
/// Record data type used for stored records: [`ZoneRecordData`] with
/// [`Bytes`] octets and [`StoredName`] as its name type.
pub type StoredRecordData = ZoneRecordData<Bytes, StoredName>;
/// A [`Record`] whose name type is [`StoredName`] and whose data is
/// [`StoredRecordData`].
pub type StoredRecord = Record<StoredName, StoredRecordData>;
/// A single record's TTL and data, without an owner name.
///
/// NOTE: field order is part of the derived serde representation; do not
/// reorder the fields without considering already-serialized data.
#[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)]
pub struct SharedRr {
/// Time-to-live of the record.
ttl: Ttl,
/// The record data; its record type is reported by [`SharedRr::rtype`].
data: StoredRecordData,
}
impl SharedRr {
    /// Creates a record value from a TTL and record data.
    pub fn new(ttl: Ttl, data: StoredRecordData) -> Self {
        Self { data, ttl }
    }

    /// Returns the record type, as reported by the record data.
    pub fn rtype(&self) -> Rtype {
        self.data().rtype()
    }

    /// Returns the record's TTL.
    pub fn ttl(&self) -> Ttl {
        let SharedRr { ttl, .. } = self;
        *ttl
    }

    /// Borrows the record data.
    pub fn data(&self) -> &StoredRecordData {
        let SharedRr { data, .. } = self;
        data
    }
}
impl From<StoredRecord> for SharedRr {
fn from(record: StoredRecord) -> Self {
SharedRr {
ttl: record.ttl(),
data: record.into_data(),
}
}
}
/// A set of record data items that share one record type and one TTL.
///
/// The type and TTL are stored once for the whole set; each entry in `data`
/// is only the record data. The single-type invariant is enforced by
/// [`Rrset::push_data`] (which asserts it), but NOT by the derived
/// `Deserialize` impl, which accepts whatever is in the input.
///
/// NOTE: field order is part of the derived serde representation.
#[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)]
pub struct Rrset {
/// Record type shared by every item in `data`.
rtype: Rtype,
/// TTL applied to every item in `data`.
ttl: Ttl,
/// The record data items, in insertion order.
data: Vec<StoredRecordData>,
}
impl Rrset {
pub fn new(rtype: Rtype, ttl: Ttl) -> Self {
Rrset {
rtype,
ttl,
data: Vec::new(),
}
}
pub fn rtype(&self) -> Rtype {
self.rtype
}
pub fn ttl(&self) -> Ttl {
self.ttl
}
pub fn data(&self) -> &[StoredRecordData] {
&self.data
}
pub fn is_empty(&self) -> bool {
self.data.is_empty()
}
pub fn first(&self) -> Option<SharedRr> {
self.data.first().map(|data| SharedRr {
ttl: self.ttl,
data: data.clone(),
})
}
pub fn set_ttl(&mut self, ttl: Ttl) {
self.ttl = ttl;
}
pub fn limit_ttl(&mut self, ttl: Ttl) {
if self.ttl > ttl {
self.ttl = ttl
}
}
pub fn push_data(&mut self, data: StoredRecordData) {
assert_eq!(data.rtype(), self.rtype);
self.data.push(data);
}
pub fn push_record(&mut self, record: StoredRecord) {
self.limit_ttl(record.ttl());
self.push_data(record.into_data());
}
pub fn into_shared(self) -> SharedRrset {
SharedRrset::new(self)
}
}
impl From<StoredRecord> for Rrset {
fn from(record: StoredRecord) -> Self {
Rrset {
rtype: record.rtype(),
ttl: record.ttl(),
data: vec![record.into_data()],
}
}
}
/// A shared, immutable handle to an [`Rrset`].
///
/// Cloning only bumps the `Arc` reference count; the records are not
/// copied. Equality compares the wrapped `Rrset` values (as `Arc`'s
/// `PartialEq` does), not pointer identity.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct SharedRrset(Arc<Rrset>);
impl SharedRrset {
pub fn new(rrset: Rrset) -> Self {
SharedRrset(Arc::new(rrset))
}
pub fn as_rrset(&self) -> &Rrset {
self.0.as_ref()
}
}
impl ops::Deref for SharedRrset {
    type Target = Rrset;

    /// Exposes the wrapped [`Rrset`] so its methods can be called directly.
    fn deref(&self) -> &Rrset {
        &self.0
    }
}
impl AsRef<Rrset> for SharedRrset {
    /// Borrows the wrapped [`Rrset`].
    fn as_ref(&self) -> &Rrset {
        &self.0
    }
}
impl<'de> Deserialize<'de> for SharedRrset {
fn deserialize<D: serde::Deserializer<'de>>(
deserializer: D,
) -> Result<Self, D::Error> {
Rrset::deserialize(deserializer).map(SharedRrset::new)
}
}
impl Serialize for SharedRrset {
fn serialize<S: serde::Serializer>(
&self,
serializer: S,
) -> Result<S::Ok, S::Error> {
self.as_rrset().serialize(serializer)
}
}
/// Data recorded for a zone cut at `name`.
///
/// NOTE(review): the record types held by `ns`/`ds` are implied by the
/// field names only; nothing in this type checks them.
#[derive(Clone, Debug)]
pub struct ZoneCut {
/// Owner name of the zone cut.
pub name: StoredName,
/// RRset held under the `ns` field (presumably the NS RRset at the cut).
pub ns: SharedRrset,
/// Optional RRset held under the `ds` field (presumably the DS RRset at
/// the cut, absent when there is none).
pub ds: Option<SharedRrset>,
/// Glue records for the cut — NOTE(review): presumably address records
/// for name servers below the cut; confirm against the code that fills it.
pub glue: Vec<StoredRecord>,
}
/// Collects added and removed RRsets for building an [`InMemoryZoneDiff`].
///
/// Both maps are keyed by `(owner, rtype)`; recording a second RRset under
/// the same key replaces the first (plain `HashMap::insert`).
#[derive(Debug, Default)]
pub struct InMemoryZoneDiffBuilder {
/// RRsets recorded via `add`.
added: HashMap<(StoredName, Rtype), SharedRrset>,
/// RRsets recorded via `remove`.
removed: HashMap<(StoredName, Rtype), SharedRrset>,
}
impl InMemoryZoneDiffBuilder {
    /// Creates a builder with no added or removed RRsets.
    pub fn new() -> Self {
        Self::default()
    }

    /// Records `rrset` as added under `(owner, rtype)`.
    ///
    /// Replaces any RRset previously recorded as added under the same key.
    pub fn add(
        &mut self,
        owner: StoredName,
        rtype: Rtype,
        rrset: SharedRrset,
    ) {
        let key = (owner, rtype);
        self.added.insert(key, rrset);
    }

    /// Records `rrset` as removed under `(owner, rtype)`.
    ///
    /// Replaces any RRset previously recorded as removed under the same key.
    pub fn remove(
        &mut self,
        owner: StoredName,
        rtype: Rtype,
        rrset: SharedRrset,
    ) {
        let key = (owner, rtype);
        self.removed.insert(key, rrset);
    }

    /// Consumes the builder and constructs the diff.
    ///
    /// # Errors
    ///
    /// Returns [`ZoneDiffError::MissingStartSoa`] if no SOA RRset was
    /// recorded as removed, [`ZoneDiffError::MissingEndSoa`] if none was
    /// recorded as added, and [`ZoneDiffError::InvalidSerialRange`] if the
    /// two SOA serials do not form a valid range.
    pub fn build(self) -> Result<InMemoryZoneDiff, ZoneDiffError> {
        let Self { added, removed } = self;
        InMemoryZoneDiff::new(added, removed)
    }
}
/// A diff between two versions of a zone, held entirely in memory.
///
/// The RRset maps sit behind `Arc`s, so cloning a diff does not copy them.
/// The fields are public, so the serial invariants below hold only for
/// values built through `InMemoryZoneDiff::new` (e.g. via
/// [`InMemoryZoneDiffBuilder::build`]).
#[derive(Clone, Debug)]
pub struct InMemoryZoneDiff {
/// Serial of the SOA found among the removed RRsets.
pub start_serial: Serial,
/// Serial of the SOA found among the added RRsets.
pub end_serial: Serial,
/// RRsets added, keyed by `(owner, rtype)`.
pub added: Arc<HashMap<(StoredName, Rtype), SharedRrset>>,
/// RRsets removed, keyed by `(owner, rtype)`.
pub removed: Arc<HashMap<(StoredName, Rtype), SharedRrset>>,
}
impl InMemoryZoneDiff {
fn new(
added: HashMap<(Name<Bytes>, Rtype), SharedRrset>,
removed: HashMap<(Name<Bytes>, Rtype), SharedRrset>,
) -> Result<Self, ZoneDiffError> {
let start_serial = removed
.iter()
.find_map(|((_, rtype), rrset)| {
if *rtype == Rtype::SOA {
if let Some(ZoneRecordData::Soa(soa)) =
rrset.data().first()
{
return Some(soa.serial());
}
}
None
})
.ok_or(ZoneDiffError::MissingStartSoa)?;
let end_serial = added
.iter()
.find_map(|((_, rtype), rrset)| {
if *rtype == Rtype::SOA {
if let Some(ZoneRecordData::Soa(soa)) =
rrset.data().first()
{
return Some(soa.serial());
}
}
None
})
.ok_or(ZoneDiffError::MissingEndSoa)?;
if start_serial == end_serial || end_serial < start_serial {
trace!("Diff construction error: serial {start_serial} -> serial {end_serial}:\nremoved: {removed:#?}\nadded: {added:#?}\n");
return Err(ZoneDiffError::InvalidSerialRange);
}
trace!(
"Built diff from serial {start_serial} to serial {end_serial}"
);
Ok(Self {
start_serial,
end_serial,
added: added.into(),
removed: removed.into(),
})
}
}
impl<'a> ZoneDiffItem for (&'a (StoredName, Rtype), &'a SharedRrset) {
    /// Returns the `(owner, rtype)` half of a map entry.
    fn key(&self) -> &(StoredName, Rtype) {
        let (key, _) = *self;
        key
    }

    /// Returns the RRset half of a map entry.
    fn value(&self) -> &SharedRrset {
        let (_, rrset) = *self;
        rrset
    }
}
/// Exposes an [`InMemoryZoneDiff`] through the [`ZoneDiff`] trait.
///
/// Every future returned here is already complete (`ready`), because all
/// data is held in memory.
impl ZoneDiff for InMemoryZoneDiff {
    type Item<'a>
        = (&'a (StoredName, Rtype), &'a SharedRrset)
    where
        Self: 'a;

    type Stream<'a>
        = futures_util::stream::Iter<
        hash_map::Iter<'a, (StoredName, Rtype), SharedRrset>,
    >
    where
        Self: 'a;

    fn start_serial(
        &self,
    ) -> Pin<Box<dyn Future<Output = Serial> + Send + '_>> {
        let serial = self.start_serial;
        Box::pin(ready(serial))
    }

    fn end_serial(
        &self,
    ) -> Pin<Box<dyn Future<Output = Serial> + Send + '_>> {
        let serial = self.end_serial;
        Box::pin(ready(serial))
    }

    /// Streams the added entries in the map's (unspecified) order.
    fn added(&self) -> Self::Stream<'_> {
        stream::iter(&*self.added)
    }

    /// Streams the removed entries in the map's (unspecified) order.
    fn removed(&self) -> Self::Stream<'_> {
        stream::iter(&*self.removed)
    }

    fn get_added(
        &self,
        name: impl ToName,
        rtype: Rtype,
    ) -> Pin<Box<dyn Future<Output = Option<&SharedRrset>> + Send + '_>> {
        // Keys hold owned `StoredName`s, so the lookup name is converted
        // into one first.
        let key: (StoredName, Rtype) = (name.to_name(), rtype);
        Box::pin(ready(self.added.get(&key)))
    }

    fn get_removed(
        &self,
        name: impl ToName,
        rtype: Rtype,
    ) -> Pin<Box<dyn Future<Output = Option<&SharedRrset>> + Send + '_>> {
        let key: (StoredName, Rtype) = (name.to_name(), rtype);
        Box::pin(ready(self.removed.get(&key)))
    }
}
/// Item type of [`EmptyZoneDiffStream`].
///
/// That stream's `poll_next` always returns `Poll::Ready(None)`, so it never
/// yields a value of this type. The type is still publicly constructible,
/// though, and both accessors panic if called.
pub struct EmptyZoneDiffItem;
impl ZoneDiffItem for EmptyZoneDiffItem {
/// # Panics
///
/// Always panics via `unreachable!()`.
fn key(&self) -> &(StoredName, Rtype) {
unreachable!()
}
/// # Panics
///
/// Always panics via `unreachable!()`.
fn value(&self) -> &SharedRrset {
unreachable!()
}
}
/// A stream of [`EmptyZoneDiffItem`]s that is exhausted from the start.
#[derive(Debug)]
pub struct EmptyZoneDiffStream;

impl futures_util::stream::Stream for EmptyZoneDiffStream {
    type Item = EmptyZoneDiffItem;

    /// Always reports end-of-stream. The context is never used, because the
    /// stream never has to wait for anything.
    fn poll_next(
        self: Pin<&mut Self>,
        _context: &mut Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        let exhausted: Option<EmptyZoneDiffItem> = None;
        Poll::Ready(exhausted)
    }
}
/// A [`ZoneDiff`] that describes no changes.
///
/// Both serials are `Serial(0)`, the added/removed streams end immediately,
/// and lookups always resolve to `None`.
#[derive(Debug)]
pub struct EmptyZoneDiff;

impl ZoneDiff for EmptyZoneDiff {
    type Item<'a>
        = EmptyZoneDiffItem
    where
        Self: 'a;

    type Stream<'a>
        = EmptyZoneDiffStream
    where
        Self: 'a;

    fn start_serial(
        &self,
    ) -> Pin<Box<dyn Future<Output = Serial> + Send + '_>> {
        let zero = Serial(0);
        Box::pin(ready(zero))
    }

    fn end_serial(
        &self,
    ) -> Pin<Box<dyn Future<Output = Serial> + Send + '_>> {
        let zero = Serial(0);
        Box::pin(ready(zero))
    }

    fn added(&self) -> Self::Stream<'_> {
        EmptyZoneDiffStream
    }

    fn removed(&self) -> Self::Stream<'_> {
        EmptyZoneDiffStream
    }

    fn get_added(
        &self,
        _name: impl ToName,
        _rtype: Rtype,
    ) -> Pin<Box<dyn Future<Output = Option<&SharedRrset>> + Send + '_>> {
        let nothing: Option<&SharedRrset> = None;
        Box::pin(ready(nothing))
    }

    fn get_removed(
        &self,
        _name: impl ToName,
        _rtype: Rtype,
    ) -> Pin<Box<dyn Future<Output = Option<&SharedRrset>> + Send + '_>> {
        let nothing: Option<&SharedRrset> = None;
        Box::pin(ready(nothing))
    }
}
/// Errors that can occur while constructing an [`InMemoryZoneDiff`].
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum ZoneDiffError {
    /// The removed RRsets contained no SOA to take the start serial from.
    MissingStartSoa,
    /// The added RRsets contained no SOA to take the end serial from.
    MissingEndSoa,
    /// The start and end serials do not form a valid increasing range
    /// (e.g. they are equal, or the end serial precedes the start serial).
    InvalidSerialRange,
}

impl std::fmt::Display for ZoneDiffError {
    /// Writes the bare variant name, e.g. `MissingStartSoa`.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let label = match *self {
            Self::MissingStartSoa => "MissingStartSoa",
            Self::MissingEndSoa => "MissingEndSoa",
            Self::InvalidSerialRange => "InvalidSerialRange",
        };
        f.write_str(label)
    }
}
/// An update operation on a zone.
///
/// Every variant except [`ZoneUpdate::DeleteAllRecords`] carries a record
/// of type `R`. How each operation is applied is decided by whoever
/// consumes these values; this type only transports them.
#[derive(Clone, Debug, PartialEq, Eq)]
#[non_exhaustive]
pub enum ZoneUpdate<R> {
    /// Delete all records; the only variant without a record payload.
    DeleteAllRecords,
    /// Delete the carried record.
    DeleteRecord(R),
    /// Add the carried record.
    AddRecord(R),
    /// Begin a batch of deletions, carrying a record.
    BeginBatchDelete(R),
    /// Begin a batch of additions, carrying a record.
    BeginBatchAdd(R),
    /// Mark the end of the updates, carrying a record.
    Finished(R),
}

impl<R> std::fmt::Display for ZoneUpdate<R> {
    /// Writes only the variant name; the carried record is never printed,
    /// so `R` needs no formatting bound.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let name = match self {
            Self::DeleteAllRecords => "DeleteAllRecords",
            Self::DeleteRecord(_) => "DeleteRecord",
            Self::AddRecord(_) => "AddRecord",
            Self::BeginBatchDelete(_) => "BeginBatchDelete",
            Self::BeginBatchAdd(_) => "BeginBatchAdd",
            Self::Finished(_) => "Finished",
        };
        f.write_str(name)
    }
}