use trust_dns_proto::{
rr::Name,
serialize::binary::{BinDecodable, BinDecoder},
};
use crate::error::{DecodeError, DecodeResult};
use std::{
array::TryFromSliceError,
convert::TryInto,
ffi::{CStr, CString},
mem,
net::{Ipv4Addr, Ipv6Addr},
str,
};
/// A type that can be parsed out of a byte buffer with a [`Decoder`].
pub trait Decodable: Sized {
    /// Parses one value from `decoder`, advancing it past the bytes consumed.
    fn decode(decoder: &mut Decoder<'_>) -> DecodeResult<Self>;

    /// Parses one value starting at the beginning of `bytes`.
    ///
    /// Convenience wrapper around [`Decodable::decode`] with a fresh decoder; any
    /// bytes left over after decoding are ignored.
    fn from_bytes(bytes: &[u8]) -> DecodeResult<Self> {
        Self::decode(&mut Decoder::new(bytes))
    }
}
/// Cursor over a borrowed byte buffer.
///
/// `read*` methods consume bytes from the front of the buffer; `peek*` methods
/// inspect bytes without consuming them.
#[derive(Debug)]
pub struct Decoder<'buf> {
    /// The bytes that have not been consumed yet.
    buffer: &'buf [u8],
}
impl<'a> Decoder<'a> {
pub fn new(buffer: &'a [u8]) -> Self {
Decoder { buffer }
}
pub fn peek_u8(&self) -> DecodeResult<u8> {
Ok(u8::from_be_bytes(self.peek::<{ mem::size_of::<u8>() }>()?))
}
pub fn read_u8(&mut self) -> DecodeResult<u8> {
Ok(u8::from_be_bytes(self.read::<{ mem::size_of::<u8>() }>()?))
}
pub fn read_u32(&mut self) -> DecodeResult<u32> {
Ok(u32::from_be_bytes(
self.read::<{ mem::size_of::<u32>() }>()?,
))
}
pub fn read_i32(&mut self) -> DecodeResult<i32> {
Ok(i32::from_be_bytes(
self.read::<{ mem::size_of::<i32>() }>()?,
))
}
pub fn read_u16(&mut self) -> DecodeResult<u16> {
Ok(u16::from_be_bytes(
self.read::<{ mem::size_of::<u16>() }>()?,
))
}
pub fn read_u64(&mut self) -> DecodeResult<u64> {
Ok(u64::from_be_bytes(
self.read::<{ mem::size_of::<u64>() }>()?,
))
}
pub fn read<const N: usize>(&mut self) -> DecodeResult<[u8; N]> {
if N > self.buffer.len() {
return Err(DecodeError::NotEnoughBytes);
}
let (slice, remaining) = self.buffer.split_at(N);
self.buffer = remaining;
Ok(slice.try_into().unwrap())
}
pub fn peek<const N: usize>(&self) -> DecodeResult<[u8; N]> {
if N > self.buffer.len() {
return Err(DecodeError::NotEnoughBytes);
}
Ok(self.buffer[..N].try_into().unwrap())
}
pub fn read_cstring<const MAX: usize>(&mut self) -> DecodeResult<Option<CString>> {
let bytes = self.read::<MAX>()?;
let nul_idx = bytes.iter().position(|&b| b == 0);
match nul_idx {
Some(0) => Ok(None),
Some(n) => Ok(Some(CStr::from_bytes_with_nul(&bytes[..=n])?.to_owned())),
None => Ok(None),
}
}
pub fn read_nul_bytes<const MAX: usize>(&mut self) -> DecodeResult<Option<Vec<u8>>> {
let bytes = self.read::<MAX>()?;
let nul_idx = bytes.iter().position(|&b| b == 0);
match nul_idx {
Some(0) => Ok(None),
Some(n) => Ok(Some(bytes[..=n].to_vec())),
None => Ok(None),
}
}
pub fn read_nul_string<const MAX: usize>(&mut self) -> DecodeResult<Option<String>> {
Ok(self
.read_nul_bytes::<MAX>()?
.map(|ref bytes| str::from_utf8(bytes).map(|s| s.to_owned()))
.transpose()?)
}
pub fn read_slice(&mut self, len: usize) -> DecodeResult<&'a [u8]> {
if len > self.buffer.len() {
return Err(DecodeError::NotEnoughBytes);
}
let (slice, remaining) = self.buffer.split_at(len);
self.buffer = remaining;
Ok(slice)
}
pub fn read_string(&mut self, len: usize) -> DecodeResult<String> {
Ok(self.read_str(len)?.to_owned())
}
pub fn read_str(&mut self, len: usize) -> DecodeResult<&str> {
Ok(str::from_utf8(self.read_slice(len)?)?)
}
pub fn read_ipv4(&mut self, length: usize) -> DecodeResult<Ipv4Addr> {
if length != 4 {
return Err(DecodeError::NotEnoughBytes);
}
let bytes = self.read::<4>()?;
Ok(bytes.into())
}
pub fn read_ipv4s(&mut self, length: usize) -> DecodeResult<Vec<Ipv4Addr>> {
if length % 4 != 0 {
return Err(DecodeError::NotEnoughBytes);
}
let ips = self.read_slice(length)?;
Ok(ips
.chunks(4)
.map(|bytes| [bytes[0], bytes[1], bytes[2], bytes[3]].into())
.collect())
}
pub fn read_ipv6s(&mut self, length: usize) -> DecodeResult<Vec<Ipv6Addr>> {
if length % 16 != 0 {
return Err(DecodeError::NotEnoughBytes);
}
let ips = self.read_slice(length)?;
Ok(ips
.chunks(16)
.map(|bytes| Ok::<_, TryFromSliceError>(TryInto::<[u8; 16]>::try_into(bytes)?.into()))
.collect::<Result<Vec<Ipv6Addr>, _>>()?)
}
pub fn read_pair_ipv4s(&mut self, length: usize) -> DecodeResult<Vec<(Ipv4Addr, Ipv4Addr)>> {
if length % 8 != 0 {
return Err(DecodeError::NotEnoughBytes);
}
let ips = self.read_slice(length)?;
Ok(ips
.chunks(8)
.map(|bytes| {
(
[bytes[0], bytes[1], bytes[2], bytes[3]].into(),
[bytes[4], bytes[5], bytes[6], bytes[7]].into(),
)
})
.collect())
}
pub fn read_domains(&mut self, length: usize) -> DecodeResult<Vec<Name>> {
let mut name_decoder = BinDecoder::new(self.read_slice(length)?);
let mut names = Vec::new();
while let Ok(name) = Name::read(&mut name_decoder) {
names.push(name);
}
Ok(names)
}
pub fn read_bool(&mut self) -> DecodeResult<bool> {
Ok(self.read_u8()? == 1)
}
pub fn buffer(&self) -> &[u8] {
self.buffer
}
}