use core::any::Any;
use std::boxed::Box;
use std::collections::{hash_map, HashMap};
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use parking_lot::{
RwLock, RwLockReadGuard, RwLockUpgradableReadGuard, RwLockWriteGuard,
};
use tokio::sync::Mutex;
use crate::base::iana::{Class, Rtype};
use crate::base::name::{Label, OwnedLabel, ToName};
use crate::zonetree::error::{CnameError, OutOfZone, ZoneCutError};
use crate::zonetree::types::{StoredName, ZoneCut};
use crate::zonetree::util::rel_name_rev_iter;
use crate::zonetree::walk::WalkState;
use crate::zonetree::{
ReadableZone, SharedRr, SharedRrset, WritableZone, ZoneStore,
};
use super::read::ReadZone;
use super::versioned::{Version, Versioned};
use super::write::{WriteZone, ZoneVersions};
/// The apex node of an in-memory zone together with the state needed to hand
/// out readers (`ZoneStore::read`) and writers (`ZoneStore::write`).
#[derive(Debug)]
pub struct ZoneApex {
    /// The owner name of the zone apex.
    apex_name: StoredName,

    /// The class of the zone.
    class: Class,

    /// The resource record sets stored at the apex itself.
    rrsets: NodeRrsets,

    /// The nodes directly below the apex.
    children: NodeChildren,

    /// Serializes writers: `ZoneStore::write` awaits this mutex and hands the
    /// owned guard to the `WriteZone` (presumably held until the write
    /// finishes — confirm in `WriteZone`).
    update_lock: Arc<Mutex<()>>,

    /// Version history of the zone. `ZoneStore::read` consults it for the
    /// current version, and a clone of the `Arc` is shared with each
    /// `WriteZone`.
    versions: Arc<RwLock<ZoneVersions>>,
}
impl ZoneApex {
pub fn new(apex_name: StoredName, class: Class) -> Self {
ZoneApex {
apex_name,
class,
rrsets: Default::default(),
children: Default::default(),
update_lock: Default::default(),
versions: Default::default(),
}
}
pub fn from_parts(
apex_name: StoredName,
class: Class,
rrsets: NodeRrsets,
children: NodeChildren,
versions: ZoneVersions,
) -> Self {
ZoneApex {
apex_name,
class,
rrsets,
children,
update_lock: Default::default(),
versions: Arc::new(RwLock::new(versions)),
}
}
pub fn prepare_name<'l>(
&self,
qname: &'l impl ToName,
) -> Result<impl Iterator<Item = &'l Label> + Clone, OutOfZone> {
rel_name_rev_iter(&self.apex_name, qname)
}
pub fn rrsets(&self) -> &NodeRrsets {
&self.rrsets
}
pub fn get_soa(&self, version: Version) -> Option<SharedRr> {
self.rrsets()
.get(Rtype::SOA, version)
.and_then(|rrset| rrset.first())
}
pub fn children(&self) -> &NodeChildren {
&self.children
}
pub fn rollback(&self, version: Version) {
self.rrsets.rollback(version);
self.children.rollback(version);
}
pub fn remove_all(&self, version: Version) {
self.rrsets.remove_all(version);
self.children.remove_all(version);
}
pub fn versions(&self) -> &RwLock<ZoneVersions> {
&self.versions
}
pub fn name(&self) -> &StoredName {
&self.apex_name
}
}
impl ZoneStore for ZoneApex {
    /// The class of this zone.
    fn class(&self) -> Class {
        self.class
    }

    /// The owner name of the zone apex.
    fn apex_name(&self) -> &StoredName {
        &self.apex_name
    }

    /// Returns a reader for the version that is current at the time of the
    /// call.
    fn read(self: Arc<Self>) -> Box<dyn ReadableZone> {
        // Copy the current (version, marker) pair out in its own scope so the
        // read guard is released before `self` is moved into the reader.
        let (version, marker) = {
            let versions = self.versions.read();
            versions.current().clone()
        };
        Box::new(ReadZone::new(self, version, marker))
    }

    /// Waits for exclusive update access, then returns a writer targeting the
    /// version after the current one.
    fn write(
        self: Arc<Self>,
    ) -> Pin<
        Box<
            dyn Future<Output = Box<dyn WritableZone + 'static>>
                + Send
                + Sync
                + 'static,
        >,
    > {
        Box::pin(async move {
            // The owned guard is handed to the writer. The next version is
            // computed only after the lock is held, presumably so concurrent
            // writers cannot pick the same version.
            let update_guard =
                Arc::clone(&self.update_lock).lock_owned().await;
            let next_version = self.versions.read().current().0.next();
            let shared_versions = Arc::clone(&self.versions);
            let writer: Box<dyn WritableZone> = Box::new(WriteZone::new(
                self,
                update_guard,
                next_version,
                shared_versions,
            ));
            writer
        })
    }

    /// Exposes the concrete type for downcasting.
    fn as_any(&self) -> &dyn Any {
        self
    }
}
impl<'a> From<&'a ZoneApex> for CnameError {
fn from(_: &'a ZoneApex) -> CnameError {
CnameError::CnameAtApex
}
}
impl<'a> From<&'a ZoneApex> for ZoneCutError {
fn from(_: &'a ZoneApex) -> ZoneCutError {
ZoneCutError::ZoneCutAtApex
}
}
/// A non-apex node of the zone tree.
#[derive(Debug, Default)]
pub struct ZoneNode {
    /// The resource record sets owned by this node.
    rrsets: NodeRrsets,

    /// An optional, versioned marker (zone cut, CNAME or NXDOMAIN) for this
    /// node.
    special: RwLock<Versioned<Option<Special>>>,

    /// The nodes directly below this one.
    children: NodeChildren,
}
impl ZoneNode {
    /// The resource record sets owned by this node.
    pub fn rrsets(&self) -> &NodeRrsets {
        &self.rrsets
    }

    /// Returns `true` if this node carries `Special::NxDomain` at `version`.
    pub fn is_nx_domain(&self, version: Version) -> bool {
        self.with_special(version, |current| {
            matches!(current, Some(Special::NxDomain))
        })
    }

    /// Calls `op` with this node's special marker as seen at `version`
    /// (`None` when there is none). The marker stays read-locked while `op`
    /// runs.
    pub fn with_special<R>(
        &self,
        version: Version,
        op: impl FnOnce(Option<&Special>) -> R,
    ) -> R {
        let guard = self.special.read();
        let current = guard.get(version).and_then(Option::as_ref);
        op(current)
    }

    /// Records `special` as this node's marker for `version`.
    pub fn update_special(&self, version: Version, special: Option<Special>) {
        let mut guard = self.special.write();
        guard.update(version, special);
    }

    /// The nodes directly below this one.
    pub fn children(&self) -> &NodeChildren {
        &self.children
    }

    /// Calls `rollback(version)` on the rrsets, the special marker and then
    /// all children, in that order.
    pub fn rollback(&self, version: Version) {
        self.rrsets().rollback(version);
        let mut special = self.special.write();
        special.rollback(version);
        drop(special);
        self.children().rollback(version);
    }

    /// Calls `remove_all`/`remove` for `version` on the rrsets, the special
    /// marker and then all children, in that order.
    pub fn remove_all(&self, version: Version) {
        self.rrsets().remove_all(version);
        let mut special = self.special.write();
        special.remove(version);
        drop(special);
        self.children().remove_all(version);
    }
}
/// The resource record sets of a single node, keyed by record type; each
/// entry keeps its own version history.
#[derive(Debug, Default)]
pub struct NodeRrsets {
    rrsets: RwLock<HashMap<Rtype, NodeRrset>>,
}
impl NodeRrsets {
    /// Returns `true` if no record type has an rrset visible at `version`.
    ///
    /// Map entries that yield nothing for `version` are ignored, so this can
    /// return `true` even when the underlying map is non-empty.
    pub fn is_empty(&self, version: Version) -> bool {
        // Take the read lock only once. Acquiring a second `read()` guard
        // while the first is still held can deadlock with parking_lot if a
        // writer queues up between the two acquisitions.
        self.rrsets
            .read()
            .values()
            .all(|rrset| rrset.get(version).is_none())
    }

    /// Returns a clone of the rrset of type `rtype` visible at `version`,
    /// if any.
    pub fn get(&self, rtype: Rtype, version: Version) -> Option<SharedRrset> {
        self.rrsets
            .read()
            .get(&rtype)
            .and_then(|rrsets| rrsets.get(version))
            .cloned()
    }

    /// Stores `rrset` for its record type at `version`.
    ///
    /// An empty `rrset` is treated as a removal of that record type.
    pub fn update(&self, rrset: SharedRrset, version: Version) {
        if rrset.is_empty() {
            self.remove_rtype(rrset.rtype(), version);
        } else {
            self.rrsets
                .write()
                .entry(rrset.rtype())
                .or_default()
                .update(rrset, version);
        }
    }

    /// Records the removal of record type `rtype` at `version`.
    ///
    /// A default entry for `rtype` is created first if none exists yet.
    pub fn remove_rtype(&self, rtype: Rtype, version: Version) {
        self.rrsets
            .write()
            .entry(rtype)
            .or_default()
            .remove(version);
    }

    /// Calls `rollback(version)` on every record type's history.
    pub fn rollback(&self, version: Version) {
        self.rrsets
            .write()
            .values_mut()
            .for_each(|rrset| rrset.rollback(version));
    }

    /// Calls `remove(version)` on every record type's history.
    pub fn remove_all(&self, version: Version) {
        self.rrsets
            .write()
            .values_mut()
            .for_each(|rrset| rrset.remove(version));
    }

    /// Returns a wrapper that holds the read lock so all entries can be
    /// iterated; the lock is released when the wrapper is dropped.
    pub(super) fn iter(&self) -> NodeRrsetsIter<'_> {
        NodeRrsetsIter::new(self.rrsets.read())
    }
}
/// Holds the read lock on a node's rrset map so its entries can be iterated.
///
/// The lock remains held for as long as this value lives.
pub(super) struct NodeRrsetsIter<'a> {
    guard: RwLockReadGuard<'a, HashMap<Rtype, NodeRrset>>,
}
impl<'a> NodeRrsetsIter<'a> {
fn new(guard: RwLockReadGuard<'a, HashMap<Rtype, NodeRrset>>) -> Self {
Self { guard }
}
pub fn iter(&self) -> hash_map::Iter<'_, Rtype, NodeRrset> {
self.guard.iter()
}
}
/// The version history of the rrset for one record type at one node.
#[derive(Debug, Default)]
pub(crate) struct NodeRrset {
    rrsets: Versioned<SharedRrset>,
}
impl NodeRrset {
    /// The rrset visible at `version`, if any.
    pub fn get(&self, version: Version) -> Option<&SharedRrset> {
        let Self { rrsets } = self;
        rrsets.get(version)
    }

    /// Records `rrset` as the value for `version`.
    fn update(&mut self, rrset: SharedRrset, version: Version) {
        let Self { rrsets } = self;
        rrsets.update(version, rrset);
    }

    /// Records a removal at `version`.
    fn remove(&mut self, version: Version) {
        let Self { rrsets } = self;
        rrsets.remove(version);
    }

    /// Delegates `rollback(version)` to the underlying history.
    pub fn rollback(&mut self, version: Version) {
        let Self { rrsets } = self;
        rrsets.rollback(version);
    }
}
/// Per-node markers kept alongside, but separate from, the ordinary rrsets.
#[derive(Debug, Clone)]
pub enum Special {
    /// A zone cut at this node.
    Cut(ZoneCut),

    /// A CNAME record at this node.
    Cname(SharedRr),

    /// The marker tested by `ZoneNode::is_nx_domain`.
    NxDomain,
}
/// The child nodes of a node, each keyed by its own label.
#[derive(Default, Debug)]
pub struct NodeChildren {
    children: RwLock<HashMap<OwnedLabel, Arc<ZoneNode>>>,
}
impl NodeChildren {
    /// Calls `op` with the child stored under `label` (or `None`), keeping
    /// the children map read-locked for the duration of the call.
    pub fn with<R>(
        &self,
        label: &Label,
        op: impl FnOnce(Option<&Arc<ZoneNode>>) -> R,
    ) -> R {
        let children = self.children.read();
        op(children.get(label))
    }

    /// Calls `op` with the child under `label`, first inserting a default
    /// child if none exists. The `bool` passed to `op` is `true` exactly when
    /// this call created the child.
    ///
    /// If the child already exists, `op` runs under the upgradable read
    /// guard. Otherwise the guard is upgraded for the insert and then
    /// downgraded to a plain read guard before `op` runs.
    pub fn with_or_default<R>(
        &self,
        label: &Label,
        op: impl FnOnce(&Arc<ZoneNode>, bool) -> R,
    ) -> R {
        let guard = self.children.upgradable_read();
        if let Some(existing) = guard.get(label) {
            return op(existing, false);
        }

        // An upgradable guard excludes other upgraders and writers, so the
        // entry is still absent when upgrading, and present after the
        // insert. That makes the `unwrap` below infallible.
        let mut writer = RwLockUpgradableReadGuard::upgrade(guard);
        writer.insert(label.into(), Default::default());
        let reader = RwLockWriteGuard::downgrade(writer);
        let created = reader.get(label).unwrap();
        op(created, true)
    }

    /// Calls `rollback(version)` on every child while holding the read lock.
    fn rollback(&self, version: Version) {
        for child in self.children.read().values() {
            child.rollback(version);
        }
    }

    /// Calls `remove_all(version)` on every child while holding the read
    /// lock.
    fn remove_all(&self, version: Version) {
        for child in self.children.read().values() {
            child.remove_all(version);
        }
    }

    /// Invokes `op` once per `(label, child)` entry, each time with a fresh
    /// clone of `walk`, while the children map is read-locked.
    pub(super) fn walk(
        &self,
        walk: WalkState,
        op: impl Fn(WalkState, (&OwnedLabel, &Arc<ZoneNode>)),
    ) {
        self.children
            .read()
            .iter()
            .for_each(|entry| op(walk.clone(), entry));
    }
}