use alloc::vec::Vec;
use core::{
marker::PhantomData,
ops::{Deref, DerefMut},
};
use super::BinEncodable;
use crate::error::{ProtoError, ProtoResult};
mod private {
    use alloc::vec::Vec;

    use crate::{ProtoError, error::ProtoResult};

    /// A byte buffer that refuses to grow past a configurable maximum size.
    ///
    /// It borrows the caller's `Vec<u8>` and writes into it in place. The size limit is checked
    /// on every [`write`](Self::write) and [`reserve`](Self::reserve).
    pub(super) struct MaximalBuf<'a> {
        /// Upper bound, in bytes, on how large `buffer` may become.
        max_size: usize,
        /// The borrowed backing storage.
        buffer: &'a mut Vec<u8>,
    }

    impl<'a> MaximalBuf<'a> {
        /// Wraps `buffer`, limiting its total size to `max_size` bytes.
        pub(super) fn new(max_size: u16, buffer: &'a mut Vec<u8>) -> Self {
            MaximalBuf {
                max_size: max_size as usize,
                buffer,
            }
        }

        /// Changes the maximum size.
        ///
        /// Bytes already in the buffer are kept, even if they now exceed the new limit.
        pub(super) fn set_max_size(&mut self, max: u16) {
            self.max_size = max as usize;
        }

        /// Writes `data` starting at `offset`, overwriting existing bytes and extending the
        /// buffer as needed.
        ///
        /// # Errors
        ///
        /// Returns `ProtoError::MaxBufferSizeExceeded` if `offset + data.len()` would exceed the
        /// maximum size; the buffer is left unchanged in that case.
        pub(super) fn write(&mut self, offset: usize, data: &[u8]) -> ProtoResult<()> {
            debug_assert!(offset <= self.buffer.len());

            // Saturate so an absurd offset is reported as "too big" rather than overflowing.
            // `max_size` never exceeds `u16::MAX`, so a saturated `end` always fails this check.
            let end = offset.saturating_add(data.len());
            if end > self.max_size {
                return Err(ProtoError::MaxBufferSizeExceeded(self.max_size));
            }

            if offset == self.buffer.len() {
                // Fast path: plain append at the tail.
                self.buffer.extend_from_slice(data);
                return Ok(());
            }

            // Overwrite (and possibly extend past) existing bytes; any new bytes are zeroed
            // before being overwritten.
            if end > self.buffer.len() {
                self.buffer.resize(end, 0);
            }
            self.buffer[offset..end].copy_from_slice(data);
            Ok(())
        }

        /// Resizes the buffer to exactly `offset + len` bytes, zero-filling any new bytes.
        ///
        /// Because this uses `Vec::resize`, a buffer currently longer than `offset + len` is
        /// truncated to that length.
        ///
        /// # Errors
        ///
        /// Returns `ProtoError::MaxBufferSizeExceeded` if `offset + len` exceeds the maximum size.
        pub(super) fn reserve(&mut self, offset: usize, len: usize) -> ProtoResult<()> {
            let end = offset.saturating_add(len);
            if end > self.max_size {
                return Err(ProtoError::MaxBufferSizeExceeded(self.max_size));
            }

            self.buffer.resize(end, 0);
            Ok(())
        }

        /// Shortens the buffer to `len` bytes (no-op if it is already shorter).
        pub(super) fn truncate(&mut self, len: usize) {
            self.buffer.truncate(len)
        }

        /// Number of bytes currently in the buffer.
        pub(super) fn len(&self) -> usize {
            self.buffer.len()
        }

        /// Borrows the current buffer contents.
        ///
        /// The returned slice borrows only `self`, not the whole `'a` lifetime.
        pub(super) fn buffer(&self) -> &[u8] {
            self.buffer.as_slice()
        }

        /// Consumes the wrapper, returning the underlying buffer.
        pub(super) fn into_bytes(self) -> &'a Vec<u8> {
            self.buffer
        }
    }
}
/// Encoder that writes binary (wire-format) data into a caller-supplied `Vec<u8>`.
///
/// Tracks a write offset, a maximum output size, the active name encoding, and previously
/// written names that later names may point back to.
pub struct BinEncoder<'a> {
    /// Size-limited view of the caller's output buffer.
    buffer: private::MaximalBuf<'a>,
    /// Position at which the next byte will be written.
    offset: usize,
    /// How names are currently being encoded.
    name_encoding: NameEncoding,
    /// Whether canonical form is enabled (consulted by `with_rdata_behavior`).
    canonical_form: bool,
    /// `(start offset, bytes)` of previously written name data, searched by `get_label_pointer`.
    name_pointers: Vec<(usize, Vec<u8>)>,
}
impl<'a> BinEncoder<'a> {
    /// Creates an encoder that writes into `buf`, starting at offset 0.
    pub fn new(buf: &'a mut Vec<u8>) -> Self {
        Self::with_offset(buf, 0)
    }

    /// Creates an encoder that writes into `buf`, starting at `offset`.
    ///
    /// The buffer is given a capacity of at least 512 bytes. The maximum encoded size defaults
    /// to `u16::MAX`, names are compressed by default, and canonical form is off.
    pub fn with_offset(buf: &'a mut Vec<u8>, offset: u32) -> Self {
        // `Vec::reserve` counts from `len`, not from `capacity`, so asking for `512 - len` more
        // elements is what guarantees a capacity of at least 512. It is a no-op when that
        // already holds.
        buf.reserve(512_usize.saturating_sub(buf.len()));

        BinEncoder {
            offset: offset as usize,
            buffer: private::MaximalBuf::new(u16::MAX, buf),
            name_pointers: Vec::new(),
            canonical_form: false,
            name_encoding: NameEncoding::Compressed,
        }
    }

    /// Sets the maximum size of the output buffer.
    ///
    /// Writes that would grow the buffer beyond it fail with
    /// `ProtoError::MaxBufferSizeExceeded`.
    pub fn set_max_size(&mut self, max: u16) {
        self.buffer.set_max_size(max);
    }

    /// Consumes the encoder and returns a reference to the underlying buffer.
    pub fn into_bytes(self) -> &'a Vec<u8> {
        self.buffer.into_bytes()
    }

    /// Number of bytes currently in the underlying buffer.
    ///
    /// This can differ from [`offset`](Self::offset) after the offset has been moved.
    pub fn len(&self) -> usize {
        self.buffer.len()
    }

    /// Returns `true` if the underlying buffer holds no bytes.
    pub fn is_empty(&self) -> bool {
        self.buffer.buffer().is_empty()
    }

    /// Position at which the next byte will be written.
    pub fn offset(&self) -> usize {
        self.offset
    }

    /// Moves the write position to `offset`.
    ///
    /// This does not truncate the buffer. Later writes overwrite existing bytes from that
    /// point; see [`trim`](Self::trim) to drop trailing bytes.
    pub fn set_offset(&mut self, offset: usize) {
        self.offset = offset;
    }

    /// Enables or disables canonical form.
    pub fn set_canonical_form(&mut self, canonical_form: bool) {
        self.canonical_form = canonical_form;
    }

    /// Returns whether canonical form is enabled.
    pub fn is_canonical_form(&self) -> bool {
        self.canonical_form
    }

    /// Sets how names are encoded from now on.
    pub fn set_name_encoding(&mut self, name_encoding: NameEncoding) {
        self.name_encoding = name_encoding;
    }

    /// Returns the current name encoding.
    pub fn name_encoding(&self) -> NameEncoding {
        self.name_encoding
    }

    /// Temporarily switches to `name_encoding`.
    ///
    /// The previous encoding is restored when the returned [`ModalEncoder`] is dropped.
    pub fn with_name_encoding<'e>(
        &'e mut self,
        name_encoding: NameEncoding,
    ) -> ModalEncoder<'a, 'e> {
        let previous_name_encoding = self.name_encoding();
        self.set_name_encoding(name_encoding);
        ModalEncoder {
            previous_name_encoding,
            inner: self,
        }
    }

    /// Temporarily adjusts the name encoding for encoding RDATA of the given kind.
    ///
    /// | `rdata_encoding`  | canonical form | resulting name encoding  |
    /// |-------------------|----------------|--------------------------|
    /// | `StandardRecord`  | on             | `UncompressedLowercase`  |
    /// | `StandardRecord`  | off            | unchanged                |
    /// | `Canonical`       | on             | `UncompressedLowercase`  |
    /// | `Canonical`       | off            | `Uncompressed`           |
    /// | `Other`           | either         | `Uncompressed`           |
    ///
    /// The previous encoding is restored when the returned [`ModalEncoder`] is dropped.
    pub fn with_rdata_behavior<'e>(
        &'e mut self,
        rdata_encoding: RDataEncoding,
    ) -> ModalEncoder<'a, 'e> {
        let previous_name_encoding = self.name_encoding();
        match (rdata_encoding, self.is_canonical_form()) {
            (RDataEncoding::StandardRecord | RDataEncoding::Canonical, true) => {
                self.set_name_encoding(NameEncoding::UncompressedLowercase)
            }
            (RDataEncoding::StandardRecord, false) => {}
            (RDataEncoding::Canonical, false) | (RDataEncoding::Other, _) => {
                self.set_name_encoding(NameEncoding::Uncompressed)
            }
        }
        ModalEncoder {
            previous_name_encoding,
            inner: self,
        }
    }

    /// Truncates the buffer at the current offset.
    ///
    /// Also forgets any stored label pointers that start at or beyond the offset.
    pub fn trim(&mut self) {
        let offset = self.offset;
        self.buffer.truncate(offset);
        self.name_pointers.retain(|&(start, _)| start < offset);
    }

    /// Returns the written bytes in `start..end`.
    ///
    /// # Panics
    ///
    /// Panics if `start` is not before the current offset, if `end` is past the end of the
    /// buffer, or if `start > end`.
    pub fn slice_of(&self, start: usize, end: usize) -> &[u8] {
        assert!(start < self.offset);
        assert!(end <= self.buffer.len());
        &self.buffer.buffer()[start..end]
    }

    /// Records the bytes in `start..end` as a target that later names may point back to.
    ///
    /// Pointers are only recorded while the current offset is below `0x3FFF`, the largest
    /// 14-bit value. DNS compression pointers carry a 14-bit offset (RFC 1035 §4.1.4).
    ///
    /// # Panics
    ///
    /// Panics if `start` or `end` exceed `u16::MAX`, if `start > end`, or under the same
    /// conditions as [`slice_of`](Self::slice_of).
    pub fn store_label_pointer(&mut self, start: usize, end: usize) {
        assert!(start <= (u16::MAX as usize));
        assert!(end <= (u16::MAX as usize));
        assert!(start <= end);
        if self.offset < 0x3FFF_usize {
            let bytes = self.slice_of(start, end).to_vec();
            self.name_pointers.push((start, bytes));
        }
    }

    /// Looks for previously stored name bytes equal to the bytes in `start..end`.
    ///
    /// Returns the start offset of the first stored match, or `None` if there is no match.
    pub fn get_label_pointer(&self, start: usize, end: usize) -> Option<u16> {
        let search = self.slice_of(start, end);
        self.name_pointers
            .iter()
            .find(|(_, matcher)| matcher.as_slice() == search)
            .map(|&(match_start, _)| {
                assert!(match_start <= (u16::MAX as usize));
                match_start as u16
            })
    }

    /// Writes a single byte at the current offset and advances it.
    ///
    /// # Errors
    ///
    /// Returns `ProtoError::MaxBufferSizeExceeded` if the maximum size would be exceeded.
    pub fn emit(&mut self, b: u8) -> ProtoResult<()> {
        self.write_slice(&[b])
    }

    /// Writes a length-prefixed character string of at most 255 bytes.
    ///
    /// # Errors
    ///
    /// Returns `ProtoError::CharacterDataTooLong` if the data exceeds 255 bytes, or
    /// `ProtoError::MaxBufferSizeExceeded` if the maximum size would be exceeded.
    pub fn emit_character_data<S: AsRef<[u8]>>(&mut self, char_data: S) -> ProtoResult<()> {
        let char_bytes = char_data.as_ref();
        if char_bytes.len() > 255 {
            return Err(ProtoError::CharacterDataTooLong {
                max: 255,
                len: char_bytes.len(),
            });
        }
        self.emit_character_data_unrestricted(char_bytes)
    }

    /// Writes `data` preceded by a one-byte length prefix, without checking the length.
    ///
    /// NOTE(review): the prefix is `data.len() as u8`. For inputs longer than 255 bytes the
    /// prefix wraps and no longer matches the data that follows. Confirm that callers never pass
    /// such input, or that this is intended.
    pub fn emit_character_data_unrestricted<S: AsRef<[u8]>>(&mut self, data: S) -> ProtoResult<()> {
        let data = data.as_ref();
        self.emit(data.len() as u8)?;
        self.write_slice(data)
    }

    /// Writes a single byte; same as [`emit`](Self::emit).
    pub fn emit_u8(&mut self, data: u8) -> ProtoResult<()> {
        self.emit(data)
    }

    /// Writes a `u16` in big-endian byte order.
    pub fn emit_u16(&mut self, data: u16) -> ProtoResult<()> {
        self.write_slice(&data.to_be_bytes())
    }

    /// Writes an `i32` in big-endian byte order.
    pub fn emit_i32(&mut self, data: i32) -> ProtoResult<()> {
        self.write_slice(&data.to_be_bytes())
    }

    /// Writes a `u32` in big-endian byte order.
    pub fn emit_u32(&mut self, data: u32) -> ProtoResult<()> {
        self.write_slice(&data.to_be_bytes())
    }

    /// Writes `data` at the current offset and advances the offset by its length.
    ///
    /// The offset is only advanced if the write succeeded.
    fn write_slice(&mut self, data: &[u8]) -> ProtoResult<()> {
        self.buffer.write(self.offset, data)?;
        self.offset += data.len();
        Ok(())
    }

    /// Writes raw bytes with no length prefix.
    pub fn emit_vec(&mut self, data: &[u8]) -> ProtoResult<()> {
        self.write_slice(data)
    }

    /// Emits every item of `iter`; see [`emit_iter`](Self::emit_iter).
    pub fn emit_all<'e, I: Iterator<Item = &'e E>, E: 'e + BinEncodable>(
        &mut self,
        mut iter: I,
    ) -> ProtoResult<usize> {
        self.emit_iter(&mut iter)
    }

    /// Emits every item of an iterator over double references; see
    /// [`emit_iter`](Self::emit_iter).
    pub fn emit_all_refs<'r, 'e, I, E>(&mut self, iter: I) -> ProtoResult<usize>
    where
        'e: 'r,
        I: Iterator<Item = &'r &'e E>,
        E: 'r + 'e + BinEncodable,
    {
        let mut iter = iter.cloned();
        self.emit_iter(&mut iter)
    }

    /// Emits items from `iter` in order and returns how many were written.
    ///
    /// # Errors
    ///
    /// If an item fails with `ProtoError::MaxBufferSizeExceeded`, its partial write is rolled
    /// back (offset and label pointers). `ProtoError::NotAllRecordsWritten { count }` then
    /// reports how many items were fully written. Any other error is returned unchanged,
    /// without a rollback.
    pub fn emit_iter<'e, I: Iterator<Item = &'e E>, E: 'e + BinEncodable>(
        &mut self,
        iter: &mut I,
    ) -> ProtoResult<usize> {
        let mut count = 0;
        for item in iter {
            let rollback = self.set_rollback();
            match item.emit(self) {
                Ok(_) => count += 1,
                Err(ProtoError::MaxBufferSizeExceeded(_)) => {
                    rollback.rollback(self);
                    return Err(ProtoError::NotAllRecordsWritten { count });
                }
                Err(e) => return Err(e),
            }
        }
        Ok(count)
    }

    /// Reserves `T::LEN` zeroed bytes at the current offset, to be filled in later.
    ///
    /// The buffer is resized to end exactly after the reserved bytes, and the offset is
    /// advanced past them.
    ///
    /// # Errors
    ///
    /// Returns `ProtoError::MaxBufferSizeExceeded` if the reservation does not fit.
    pub fn place<T: EncodedSize>(&mut self) -> ProtoResult<Place<T>> {
        let index = self.offset;
        self.buffer.reserve(self.offset, T::LEN)?;
        self.offset += T::LEN;
        Ok(Place {
            start_index: index,
            phantom: PhantomData,
        })
    }

    /// Number of bytes written since the end of `place`'s reserved slot.
    pub fn len_since_place<T: EncodedSize>(&self, place: &Place<T>) -> usize {
        (self.offset - place.start_index) - T::LEN
    }

    /// Writes `data` into the slot reserved by `place`, then restores the current offset.
    ///
    /// # Errors
    ///
    /// Propagates any error from encoding `data`. The offset is still restored.
    ///
    /// # Panics
    ///
    /// Panics if `place` does not lie before the current offset, or if a successful write of
    /// `data` did not occupy exactly `T::LEN` bytes.
    pub fn emit_at<T: EncodedSize>(&mut self, place: Place<T>, data: T) -> ProtoResult<()> {
        let current_index = self.offset;
        // A `Place` always precedes the current offset; anything else is a programming error.
        assert!(place.start_index < current_index);

        self.offset = place.start_index;
        let emit_result = data.emit(self);

        // Only check the width after a successful write. A failed write may have stopped
        // short, and that error should reach the caller rather than become a panic.
        if emit_result.is_ok() {
            assert!((self.offset - place.start_index) == T::LEN);
        }

        self.offset = current_index;
        emit_result
    }

    /// Snapshots the offset and label-pointer count so a partial write can be undone.
    fn set_rollback(&self) -> Rollback {
        Rollback {
            offset: self.offset(),
            pointers: self.name_pointers.len(),
        }
    }
}
/// Selects how names are written by a `BinEncoder`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum NameEncoding {
    /// Compress names where possible (the default for a new `BinEncoder`).
    Compressed,
    /// Write names without compression.
    Uncompressed,
    /// Write names without compression, lowercased (used for canonical form).
    UncompressedLowercase,
}
/// The kind of RDATA being encoded, used by `BinEncoder::with_rdata_behavior` to pick a
/// name encoding.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum RDataEncoding {
    /// Keeps the current name encoding, except in canonical form, which uses lowercase,
    /// uncompressed names.
    StandardRecord,
    /// Uncompressed names, additionally lowercased in canonical form.
    Canonical,
    /// Always uncompressed names, regardless of canonical form.
    Other,
}
/// Guard returned by `BinEncoder::with_name_encoding` and `BinEncoder::with_rdata_behavior`.
///
/// It dereferences to the wrapped encoder and restores the saved name encoding when dropped.
/// Discarding it immediately undoes the mode change, hence `#[must_use]`.
#[must_use = "the previous name encoding is restored as soon as this guard is dropped"]
pub struct ModalEncoder<'a, 'e> {
    /// Name encoding to restore on drop.
    previous_name_encoding: NameEncoding,
    /// The encoder whose name encoding was temporarily changed.
    inner: &'e mut BinEncoder<'a>,
}
impl<'a> Deref for ModalEncoder<'a, '_> {
    type Target = BinEncoder<'a>;

    /// Shared access to the wrapped encoder.
    fn deref(&self) -> &Self::Target {
        &*self.inner
    }
}
impl DerefMut for ModalEncoder<'_, '_> {
    /// Exclusive access to the wrapped encoder.
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut *self.inner
    }
}
impl Drop for ModalEncoder<'_, '_> {
    /// Puts back the name encoding that was active before this guard was created.
    fn drop(&mut self) {
        let restored = self.previous_name_encoding;
        self.inner.set_name_encoding(restored);
    }
}
/// A `BinEncodable` type whose encoding always occupies the same number of bytes.
///
/// `BinEncoder::place` needs this so it can reserve space before the value is known.
pub trait EncodedSize: BinEncodable {
    /// Exact number of bytes one encoded value occupies.
    const LEN: usize;
}
impl EncodedSize for u16 {
    // Two bytes per encoded u16.
    const LEN: usize = 2;
}
/// A reserved, not-yet-written slot of `T::LEN` bytes in a `BinEncoder`'s buffer.
///
/// It is obtained from `BinEncoder::place` and later filled with `Place::replace`, e.g. for a
/// value that is only known after the data following it has been written.
#[must_use = "data must be written back to the place"]
#[derive(Debug)]
pub struct Place<T: EncodedSize> {
    /// Buffer index of the first reserved byte.
    start_index: usize,
    /// Records which type will be written into the slot.
    phantom: PhantomData<T>,
}
impl<T: EncodedSize> Place<T> {
pub fn replace(self, encoder: &mut BinEncoder<'_>, data: T) -> ProtoResult<()> {
encoder.emit_at(self, data)
}
}
/// Snapshot of a `BinEncoder`'s write position, used to undo a partially written item.
pub(crate) struct Rollback {
    /// Number of stored label pointers to keep.
    pointers: usize,
    /// Offset to return to.
    offset: usize,
}
impl Rollback {
    /// Restores `encoder` to the snapshotted offset and drops label pointers recorded since
    /// the snapshot.
    ///
    /// Buffer contents are not truncated; later writes overwrite them.
    pub(crate) fn rollback(self, encoder: &mut BinEncoder<'_>) {
        encoder.name_pointers.truncate(self.pointers);
        encoder.set_offset(self.offset);
    }
}
// Unit tests for `BinEncoder`: round-tripping, `Place` back-filling, size limits and
// name compression.
#[cfg(test)]
mod tests {
    #[cfg(any(feature = "std", feature = "no-std-rand"))]
    use core::str::FromStr;
    use super::*;
    use crate::{op::Message, serialize::binary::BinDecoder};
    #[cfg(any(feature = "std", feature = "no-std-rand"))]
    use crate::{
        op::Query,
        rr::Name,
        rr::{
            RData, Record, RecordType,
            rdata::{CNAME, SRV},
        },
    };

    // Decodes a fixed message and re-encodes it. Re-encoding must not fail. Presumably this
    // guards against a past label-compression bug (per the test name) — TODO confirm origin.
    #[test]
    fn test_label_compression_regression() {
        let data = vec![
            154, 50, 129, 128, 0, 1, 0, 0, 0, 1, 0, 1, 7, 98, 108, 117, 101, 100, 111, 116, 2, 105,
            115, 8, 97, 117, 116, 111, 110, 97, 118, 105, 3, 99, 111, 109, 3, 103, 100, 115, 10,
            97, 108, 105, 98, 97, 98, 97, 100, 110, 115, 3, 99, 111, 109, 0, 0, 28, 0, 1, 192, 36,
            0, 6, 0, 1, 0, 0, 7, 7, 0, 35, 6, 103, 100, 115, 110, 115, 49, 192, 40, 4, 110, 111,
            110, 101, 0, 120, 27, 176, 162, 0, 0, 7, 8, 0, 0, 2, 88, 0, 0, 14, 16, 0, 0, 1, 104, 0,
            0, 41, 2, 0, 0, 0, 0, 0, 0, 0,
        ];
        let msg = Message::from_vec(&data).unwrap();
        msg.to_bytes().unwrap();
    }

    // Reserves a u16 slot, writes two bytes after it and checks `len_since_place` at each
    // step. It then back-fills the slot with 4. The buffer must end up 4 bytes long, with the
    // back-filled value readable as its first u16.
    #[test]
    fn test_place() {
        let mut buf = vec![];
        {
            let mut encoder = BinEncoder::new(&mut buf);
            let place = encoder.place::<u16>().unwrap();
            assert_eq!(encoder.len_since_place(&place), 0);
            encoder.emit(42_u8).expect("failed 0");
            assert_eq!(encoder.len_since_place(&place), 1);
            encoder.emit(48_u8).expect("failed 1");
            assert_eq!(encoder.len_since_place(&place), 2);
            place
                .replace(&mut encoder, 4_u16)
                .expect("failed to replace");
            drop(encoder);
        }
        assert_eq!(buf.len(), 4);
        let mut decoder = BinDecoder::new(&buf);
        let written = decoder.read_u16().expect("cound not read u16").unverified();
        assert_eq!(written, 4);
    }

    // With a 5-byte limit, five single-byte writes succeed and the sixth must fail with
    // `MaxBufferSizeExceeded`.
    #[test]
    fn test_max_size() {
        let mut buf = vec![];
        let mut encoder = BinEncoder::new(&mut buf);
        encoder.set_max_size(5);
        encoder.emit(0).expect("failed to write");
        encoder.emit(1).expect("failed to write");
        encoder.emit(2).expect("failed to write");
        encoder.emit(3).expect("failed to write");
        encoder.emit(4).expect("failed to write");
        let error = encoder.emit(5).unwrap_err();
        match error {
            ProtoError::MaxBufferSizeExceeded(_) => (),
            _ => panic!(),
        }
    }

    // A zero-byte limit rejects even the very first write.
    #[test]
    fn test_max_size_0() {
        let mut buf = vec![];
        let mut encoder = BinEncoder::new(&mut buf);
        encoder.set_max_size(0);
        let error = encoder.emit(0).unwrap_err();
        match error {
            ProtoError::MaxBufferSizeExceeded(_) => (),
            _ => panic!(),
        }
    }

    // A 2-byte limit allows exactly one u16 `place`. Back-filling it works, and a second
    // `place` must fail with `MaxBufferSizeExceeded`.
    #[test]
    fn test_max_size_place() {
        let mut buf = vec![];
        let mut encoder = BinEncoder::new(&mut buf);
        encoder.set_max_size(2);
        let place = encoder.place::<u16>().expect("place failed");
        place.replace(&mut encoder, 16).expect("placeback failed");
        let error = encoder.place::<u16>().unwrap_err();
        match error {
            ProtoError::MaxBufferSizeExceeded(_) => (),
            _ => panic!(),
        }
    }

    // Builds a message whose records repeat the same owner and SRV target names. The encoded
    // size must be exactly 130 bytes and the bytes must decode again. Presumably 130 is the
    // size with name compression applied to the SRV targets — TODO confirm. The record order
    // below affects which names can be compressed, so it matters for the expected length.
    #[cfg(any(feature = "std", feature = "no-std-rand"))]
    #[test]
    fn test_target_compression() {
        let mut msg = Message::query();
        msg.add_query(Query::query(
            Name::from_str("www.google.com.").unwrap(),
            RecordType::A,
        ))
        .add_answer(Record::from_rdata(
            Name::from_str("www.google.com.").unwrap(),
            0,
            RData::SRV(SRV::new(
                0,
                0,
                0,
                Name::from_str("www.compressme.com.").unwrap(),
            )),
        ))
        .add_additional(Record::from_rdata(
            Name::from_str("www.google.com.").unwrap(),
            0,
            RData::SRV(SRV::new(
                0,
                0,
                0,
                Name::from_str("www.compressme.com.").unwrap(),
            )),
        ))
        .add_answer(Record::from_rdata(
            Name::from_str("www.compressme.com.").unwrap(),
            0,
            RData::CNAME(CNAME(Name::from_str("www.foo.com.").unwrap())),
        ));
        let bytes = msg.to_vec().unwrap();
        assert_eq!(bytes.len(), 130);
        assert!(Message::from_vec(&bytes).is_ok());
    }
}