use alloc::{collections::BTreeSet, string::ToString};
use core::{fmt, str::FromStr};
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
use crate::{
error::*,
rr::{RData, RecordData, RecordDataDecodable, RecordType, RecordTypeSet},
serialize::{binary::*, txt::ParseError},
};
/// Record data for the CSYNC record type (child-to-parent synchronization,
/// described by RFC 7477).
///
/// On the wire the data is a 32-bit SOA serial, a 16-bit flags word and a
/// record type bit map, in that order.
#[cfg_attr(feature = "serde", derive(Deserialize, Serialize))]
#[derive(Debug, PartialEq, Eq, Hash, Clone)]
#[non_exhaustive]
pub struct CSYNC {
/// The SOA serial value carried in the record (first field on the wire).
pub soa_serial: u32,
/// State of bit 0 (`0x0001`) of the flags word.
pub immediate: bool,
/// State of bit 1 (`0x0002`) of the flags word.
pub soa_minimum: bool,
/// Flag bits other than `immediate` and `soa_minimum`; only bits 2-15 of this
/// value are used when the flags word is assembled.
pub reserved_flags: u16,
/// The set of record types listed in the record, encoded after the flags.
pub type_bit_maps: RecordTypeSet,
}
impl CSYNC {
    /// Bit of the flags word that carries `immediate`.
    const IMMEDIATE_FLAG: u16 = 0b0000_0001;
    /// Bit of the flags word that carries `soa_minimum`.
    const SOA_MINIMUM_FLAG: u16 = 0b0000_0010;
    /// All flag bits that have a dedicated boolean field.
    const KNOWN_FLAGS: u16 = Self::IMMEDIATE_FLAG | Self::SOA_MINIMUM_FLAG;

    /// Creates a new CSYNC record data.
    ///
    /// # Arguments
    ///
    /// * `soa_serial` - the SOA serial value to publish
    /// * `immediate` - value of the `immediate` flag (bit 0)
    /// * `soa_minimum` - value of the `soa_minimum` flag (bit 1)
    /// * `type_bit_maps` - the record types to list in the type bit map
    ///
    /// # Return value
    ///
    /// The new CSYNC record data, with `reserved_flags` set to `0`.
    pub fn new(
        soa_serial: u32,
        immediate: bool,
        soa_minimum: bool,
        type_bit_maps: impl IntoIterator<Item = RecordType>,
    ) -> Self {
        Self {
            soa_serial,
            immediate,
            soa_minimum,
            reserved_flags: 0,
            type_bit_maps: RecordTypeSet::new(type_bit_maps),
        }
    }

    /// Parses the presentation form `<soa_serial> <flags> [<type> ...]`.
    ///
    /// `soa_serial` and `flags` are decimal integers; every remaining token is
    /// parsed as a record type. Flag bits other than `immediate` and
    /// `soa_minimum` are kept in `reserved_flags` rather than being discarded,
    /// so a record written by the `Display` impl parses back to an equal value.
    ///
    /// # Errors
    ///
    /// Returns `ParseError::MissingToken` when the serial or the flags token is
    /// absent, and the converted parse error when a number or a record type
    /// cannot be parsed.
    pub(crate) fn from_tokens<'i, I: Iterator<Item = &'i str>>(
        mut tokens: I,
    ) -> Result<Self, ParseError> {
        let soa_serial: u32 = tokens
            .next()
            .ok_or_else(|| ParseError::MissingToken("soa_serial".to_string()))?
            .parse()?;

        let flags: u16 = tokens
            .next()
            .ok_or_else(|| ParseError::MissingToken("flags".to_string()))?
            .parse()?;

        let immediate = flags & Self::IMMEDIATE_FLAG != 0;
        let soa_minimum = flags & Self::SOA_MINIMUM_FLAG != 0;

        // Stops at the first token that is not a valid record type.
        let record_types = tokens
            .map(RecordType::from_str)
            .collect::<Result<BTreeSet<_>, _>>()?;

        Ok(Self {
            // `new` always zeroes the reserved bits; carry over the ones that
            // were actually present in the text form.
            reserved_flags: flags & !Self::KNOWN_FLAGS,
            ..Self::new(soa_serial, immediate, soa_minimum, record_types)
        })
    }

    /// Returns the 16-bit flags word as it is written on the wire.
    ///
    /// Bits 2-15 come from `reserved_flags`; bits 0 and 1 are taken from
    /// `immediate` and `soa_minimum`, whatever `reserved_flags` holds there.
    pub fn flags(&self) -> u16 {
        let mut flags = self.reserved_flags & !Self::KNOWN_FLAGS;
        if self.immediate {
            flags |= Self::IMMEDIATE_FLAG;
        }
        if self.soa_minimum {
            flags |= Self::SOA_MINIMUM_FLAG;
        }
        flags
    }
}
impl BinEncodable for CSYNC {
    /// Writes the record data to `encoder`: the SOA serial as a `u32`, the
    /// flags word as a `u16`, then the record type bit map.
    fn emit(&self, encoder: &mut BinEncoder<'_>) -> ProtoResult<()> {
        let Self {
            soa_serial,
            type_bit_maps,
            ..
        } = self;
        let flags = self.flags();

        encoder.emit_u32(*soa_serial)?;
        encoder.emit_u16(flags)?;
        type_bit_maps.emit(encoder)?;

        Ok(())
    }
}
impl<'r> RecordDataDecodable<'r> for CSYNC {
    /// Reads CSYNC record data of `length` bytes from `decoder`.
    ///
    /// The SOA serial and the flags word are read first; whatever remains of
    /// `length` is handed to `RecordTypeSet::read_data` as the bit map length.
    ///
    /// # Errors
    ///
    /// * `DecodeError::UnrecognizedCsyncFlags` when the flags fail validation
    /// * `DecodeError::IncorrectRDataLengthRead` when the bytes consumed so far
    ///   do not fit in a `u16` or exceed `length`
    /// * any error produced by the underlying reads
    fn read_data(decoder: &mut BinDecoder<'r>, length: Restrict<u16>) -> Result<Self, DecodeError> {
        let rdata_start = decoder.index();

        let soa_serial = decoder.read_u32()?.unverified();

        // NOTE(review): the mask below covers only bits 2-7 of the 16-bit flags
        // word, so bits 8-15 pass validation and end up in `reserved_flags`.
        // Confirm whether the high byte is meant to be accepted.
        let flags: u16 = decoder
            .read_u16()?
            .verify_unwrap(|flags| flags & 0b1111_1100 == 0)
            .map_err(DecodeError::UnrecognizedCsyncFlags)?;

        // Number of bytes taken by the fixed-size fields read above.
        let consumed = decoder.index() - rdata_start;
        let offset =
            u16::try_from(consumed).map_err(|_| DecodeError::IncorrectRDataLengthRead {
                read: consumed,
                len: usize::from(u16::MAX),
            })?;

        let bit_map_len = length.checked_sub(offset).map_err(|len| {
            DecodeError::IncorrectRDataLengthRead {
                read: usize::from(offset),
                len: len as usize,
            }
        })?;

        Ok(Self {
            soa_serial,
            immediate: flags & 0b01 != 0,
            soa_minimum: flags & 0b10 != 0,
            reserved_flags: flags & !0b11,
            type_bit_maps: RecordTypeSet::read_data(decoder, bit_map_len)?,
        })
    }
}
impl RecordData for CSYNC {
    /// Returns the inner CSYNC when `data` is the `RData::CSYNC` variant,
    /// `None` for every other variant.
    fn try_borrow(data: &RData) -> Option<&Self> {
        if let RData::CSYNC(inner) = data {
            Some(inner)
        } else {
            None
        }
    }

    /// Always `RecordType::CSYNC`.
    fn record_type(&self) -> RecordType {
        RecordType::CSYNC
    }

    /// Wraps this value in `RData::CSYNC`.
    fn into_rdata(self) -> RData {
        RData::CSYNC(self)
    }
}
impl fmt::Display for CSYNC {
    /// Formats the record data as `<soa_serial> <flags>` followed by each
    /// record type of the bit map, every item separated by one space.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
        let serial = self.soa_serial;
        let flag_word = self.flags();
        write!(f, "{serial} {flag_word}")?;

        for record_type in self.type_bit_maps.iter() {
            write!(f, " {record_type}")?;
        }

        Ok(())
    }
}
// Unit tests for CSYNC binary encoding/decoding and text parsing.
#[cfg(test)]
mod tests {
#![allow(clippy::dbg_macro, clippy::print_stdout)]
#[cfg(feature = "std")]
use std::println;
use alloc::vec::Vec;
use super::*;
// Encodes a CSYNC to bytes, decodes those bytes again and expects an equal value.
#[test]
fn test() {
let types = [RecordType::A, RecordType::NS, RecordType::AAAA];
let rdata = CSYNC::new(123, true, true, types);
let mut bytes = Vec::new();
let mut encoder: BinEncoder<'_> = BinEncoder::new(&mut bytes);
assert!(rdata.emit(&mut encoder).is_ok());
// Shadows the buffer with the encoder's view of the written bytes.
let bytes = encoder.into_bytes();
#[cfg(feature = "std")]
println!("bytes: {bytes:?}");
let mut decoder: BinDecoder<'_> = BinDecoder::new(bytes);
// The whole buffer is the record data, so its length is the RDATA length.
let restrict = Restrict::new(bytes.len() as u16);
let read_rdata = CSYNC::read_data(&mut decoder, restrict).expect("Decoding error");
assert_eq!(rdata, read_rdata);
}
// Flags value 3 sets both the `immediate` and the `soa_minimum` bit.
#[test]
fn test_parsing() {
assert_eq!(
CSYNC::from_tokens(vec!["123", "3", "NS"].into_iter()).expect("failed to parse CSYNC"),
CSYNC::new(123, true, true, [RecordType::NS]),
);
}
// A non-numeric first token and an empty token list must both be rejected.
#[test]
fn test_parsing_fails() {
assert!(CSYNC::from_tokens(vec!["NS"].into_iter()).is_err());
assert!(CSYNC::from_tokens(vec![].into_iter()).is_err());
}
}