use alloc::{borrow::ToOwned, boxed::Box, string::String, vec::Vec};
use thiserror::Error;
use crate::{
rr::{Name, RecordType},
serialize::binary::Restrict,
};
/// A read cursor over a borrowed byte buffer.
///
/// The decoder never copies or modifies the underlying bytes; it only tracks how much of the
/// buffer is still unread.
pub struct BinDecoder<'a> {
    /// The complete buffer the decoder was created over.
    buffer: &'a [u8],
    /// The portion of `buffer` that has not been consumed yet.
    remaining: &'a [u8],
}
impl<'a> BinDecoder<'a> {
    /// Creates a decoder positioned at the start of `buffer`.
    pub fn new(buffer: &'a [u8]) -> Self {
        BinDecoder {
            buffer,
            remaining: buffer,
        }
    }

    /// Consumes and returns the next unread byte.
    ///
    /// # Errors
    ///
    /// Returns [`DecodeError::InsufficientBytes`] when there are no unread bytes left.
    pub fn pop(&mut self) -> Result<Restrict<u8>, DecodeError> {
        let (first, rest) = self
            .remaining
            .split_first()
            .ok_or(DecodeError::InsufficientBytes)?;
        self.remaining = rest;
        Ok(Restrict::new(*first))
    }

    /// Returns the number of bytes that have not been read yet.
    ///
    /// This is the length of the unread tail, not of the whole buffer.
    pub fn len(&self) -> usize {
        self.remaining.len()
    }

    /// Returns `true` when every byte of the buffer has been consumed.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Returns the next unread byte without consuming it, or `None` at the end of the buffer.
    pub fn peek(&self) -> Option<Restrict<u8>> {
        self.remaining.first().copied().map(Restrict::new)
    }

    /// Returns the current read position as an offset from the start of the buffer.
    pub fn index(&self) -> usize {
        // `remaining` is always a tail of `buffer`, so this subtraction cannot underflow.
        self.buffer.len() - self.remaining.len()
    }

    /// Creates a new decoder over the same buffer, positioned at `index_at`.
    ///
    /// The underlying bytes are shared, not copied; only the read position differs. `self` is
    /// left untouched.
    ///
    /// If `index_at` lies beyond the end of the buffer the returned decoder has nothing left to
    /// read, so every subsequent read fails with [`DecodeError::InsufficientBytes`] instead of
    /// this method panicking on an out-of-range offset.
    pub fn clone(&self, index_at: u16) -> Self {
        BinDecoder {
            buffer: self.buffer,
            // `get` instead of indexing: the offset may come from untrusted input.
            remaining: self.buffer.get(usize::from(index_at)..).unwrap_or(&[]),
        }
    }

    /// Reads a length-prefixed byte string: one length byte followed by that many bytes.
    ///
    /// The returned slice borrows from the underlying buffer, not from the decoder.
    ///
    /// # Errors
    ///
    /// Returns [`DecodeError::InsufficientBytes`] when the length byte is missing or fewer
    /// bytes remain than the length byte announces.
    pub fn read_character_data(&mut self) -> Result<Restrict<&'a [u8]>, DecodeError> {
        let length = usize::from(self.pop()?.unverified());
        self.read_slice(length)
    }

    /// Reads `len` bytes and returns them as an owned `Vec<u8>`.
    ///
    /// # Errors
    ///
    /// Returns [`DecodeError::InsufficientBytes`] when fewer than `len` bytes remain.
    pub fn read_vec(&mut self, len: usize) -> Result<Restrict<Vec<u8>>, DecodeError> {
        self.read_slice(len).map(|s| s.map(ToOwned::to_owned))
    }

    /// Reads `len` bytes and returns them as a slice borrowed from the underlying buffer.
    ///
    /// Nothing is consumed when the read fails.
    ///
    /// # Errors
    ///
    /// Returns [`DecodeError::InsufficientBytes`] when fewer than `len` bytes remain.
    pub fn read_slice(&mut self, len: usize) -> Result<Restrict<&'a [u8]>, DecodeError> {
        if len > self.remaining.len() {
            return Err(DecodeError::InsufficientBytes);
        }
        let (read, remaining) = self.remaining.split_at(len);
        self.remaining = remaining;
        Ok(Restrict::new(read))
    }

    /// Returns the already-read bytes from `index` up to the current read position.
    ///
    /// # Errors
    ///
    /// Returns [`DecodeError::InvalidPreviousIndex`] when `index` is greater than the current
    /// position reported by [`BinDecoder::index`].
    pub fn slice_from(&self, index: usize) -> Result<&'a [u8], DecodeError> {
        // `get` yields `None` exactly when `index` is past the current position, because the
        // current position never exceeds the buffer length.
        self.buffer
            .get(index..self.index())
            .ok_or(DecodeError::InvalidPreviousIndex)
    }

    /// Reads a single byte; equivalent to [`BinDecoder::pop`].
    ///
    /// # Errors
    ///
    /// Returns [`DecodeError::InsufficientBytes`] when there are no unread bytes left.
    pub fn read_u8(&mut self) -> Result<Restrict<u8>, DecodeError> {
        self.pop()
    }

    /// Reads two bytes and interprets them as a big-endian `u16`.
    ///
    /// # Errors
    ///
    /// Returns [`DecodeError::InsufficientBytes`] when fewer than two bytes remain.
    pub fn read_u16(&mut self) -> Result<Restrict<u16>, DecodeError> {
        Ok(self
            .read_slice(2)?
            .map(|s| u16::from_be_bytes([s[0], s[1]])))
    }

    /// Reads four bytes and interprets them as a big-endian `i32`.
    ///
    /// # Errors
    ///
    /// Returns [`DecodeError::InsufficientBytes`] when fewer than four bytes remain.
    pub fn read_i32(&mut self) -> Result<Restrict<i32>, DecodeError> {
        Ok(self.read_slice(4)?.map(|s| {
            // Restates the length that `read_slice(4)` already guarantees.
            assert!(s.len() == 4);
            i32::from_be_bytes([s[0], s[1], s[2], s[3]])
        }))
    }

    /// Reads four bytes and interprets them as a big-endian `u32`.
    ///
    /// # Errors
    ///
    /// Returns [`DecodeError::InsufficientBytes`] when fewer than four bytes remain.
    pub fn read_u32(&mut self) -> Result<Restrict<u32>, DecodeError> {
        Ok(self.read_slice(4)?.map(|s| {
            // Restates the length that `read_slice(4)` already guarantees.
            assert!(s.len() == 4);
            u32::from_be_bytes([s[0], s[1], s[2], s[3]])
        }))
    }
}
#[derive(Clone, Debug, Error)]
#[non_exhaustive]
pub enum DecodeError {
#[cfg(feature = "__dnssec")]
#[error("dns key value unknown, must be 3: {0}")]
DnsKeyProtocolNot3(u8),
#[cfg(feature = "__dnssec")]
#[error("KEY flags reserved bits are set: {0:#06x}")]
KeyFlagsReserved(u16),
#[cfg(feature = "__dnssec")]
#[error("extended KEY flags not supported")]
ExtendedKeyFlagsUnsupported(u16),
#[error("edns resource record label must be the root label (.): {0}")]
EdnsNameNotRoot(Box<Name>),
#[non_exhaustive]
#[error("incorrect rdata length read: {read} expected: {len}")]
IncorrectRDataLengthRead {
read: usize,
len: usize,
},
#[error("unexpected end of input reached")]
InsufficientBytes,
#[error("unexpected record with length 0 in non-update message")]
InvalidEmptyRecord,
#[error("the index passed to BinDecoder::slice_from must be greater than the decoder position")]
InvalidPreviousIndex,
#[error("label points to data not prior to idx: {idx} ptr: {ptr}")]
PointerNotPriorToLabel {
idx: usize,
ptr: u16,
},
#[error("label bytes exceed 63: {0}")]
LabelBytesTooLong(usize),
#[error("unrecognized label code: {0:b}")]
UnrecognizedLabelCode(u8),
#[error("name label data exceed 255: {0}")]
DomainNameTooLong(usize),
#[error("overlapping labels name {label} other {other}")]
LabelOverlapsWithOther {
label: usize,
other: usize,
},
#[error("unknown digest algorithm: {0}")]
UnknownDigestAlgorithm(u8),
#[error("dns class string unknown: {0}")]
UnknownDnsClassStr(String),
#[error("dns class value unknown: {0}")]
UnknownDnsClassValue(u16),
#[error("record type string unknown: {0}")]
UnknownRecordTypeStr(String),
#[error("record type value unknown: {0}")]
UnknownRecordTypeValue(u16),
#[error("nsec3 flags should be 0b0000000*: {0:b}")]
UnrecognizedNsec3Flags(u8),
#[error("csync flags should be 0b000000**: {0:b}")]
UnrecognizedCsyncFlags(u16),
#[error("unknown NSEC3 hash algorithm: {0}")]
UnknownNsec3HashAlgorithm(u8),
#[error("record after TSIG or SIG(0)")]
RecordAfterSig,
#[error("record type {0} only allowed in additional section")]
RecordNotInAdditionalSection(RecordType),
#[error("more than one EDNS record")]
DuplicateEdns,
#[error("SvcParams out of order")]
SvcParamsOutOfOrder,
#[error("SvcParam expects at least one value")]
SvcParamMissingValue,
#[error("NSEC bitmap out of bounds")]
NsecBitmapOutOfBounds,
#[error("CAA tag invalid")]
CaaTagInvalid,
#[error("NAPTR flags not in range [a-zA-Z0-9]")]
NaptrFlagsInvalid,
#[error("unknown address family: {0:#x}")]
UnknownAddressFamily(u16),
#[error("invalid UTF-8: {0}")]
Utf8(#[from] alloc::string::FromUtf8Error),
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Sequential `read_slice` calls hand out consecutive chunks and fail once the request
    /// exceeds what is left.
    #[test]
    fn test_read_slice() {
        let mut decoder = BinDecoder::new(b"deadbeef");

        let chunk = decoder.read_slice(4).expect("failed to read dead");
        assert_eq!(chunk.unverified(), &b"dead"[..]);

        let chunk = decoder.read_slice(2).expect("failed to read be");
        assert_eq!(chunk.unverified(), &b"be"[..]);

        // A zero-length read succeeds and yields an empty slice.
        let chunk = decoder.read_slice(0).expect("failed to read nothing");
        assert_eq!(chunk.unverified(), &b""[..]);

        // Only two bytes are left at this point, so asking for three must fail.
        assert!(decoder.read_slice(3).is_err());
    }

    /// `slice_from` returns the bytes between a previous index and the current position and
    /// rejects an index beyond the current position.
    #[test]
    fn test_read_slice_from() {
        let mut decoder = BinDecoder::new(b"deadbeef");

        decoder.read_slice(4).expect("failed to read dead");
        let seen = decoder.slice_from(0).expect("failed to get slice");
        assert_eq!(seen, &b"dead"[..]);

        decoder.read_slice(2).expect("failed to read be");
        let seen = decoder.slice_from(4).expect("failed to get slice");
        assert_eq!(seen, &b"be"[..]);

        // An empty read does not move the position, so the same range comes back.
        decoder.read_slice(0).expect("failed to read nothing");
        let seen = decoder.slice_from(4).expect("failed to get slice");
        assert_eq!(seen, &b"be"[..]);

        assert!(decoder.slice_from(10).is_err());
    }
}