use core::fmt;
use crate::{
new::base::{
wire::{ParseBytesZC, U16},
Header, HeaderFlags, Message, MessageItem, Question, Record,
SectionCounts,
},
new::edns::EdnsRecord,
};
use super::{BuildBytes, BuildInMessage, NameCompressor, TruncationError};
/// A builder for a DNS message, writing into a caller-provided buffer.
///
/// The message header occupies the start of the buffer; items pushed through
/// the builder are appended to the message contents that follow it.
pub struct MessageBuilder<'b, 'c> {
    /// The message being built, backed by the caller's buffer.
    ///
    /// Its `contents` may extend past `offset`; only the first `offset`
    /// bytes have been written so far.
    message: &'b mut Message,

    /// The name compressor passed to every item built into the message.
    compressor: &'c mut NameCompressor,

    /// The number of bytes of `message.contents` written so far.
    offset: usize,
}
impl<'b, 'c> MessageBuilder<'b, 'c> {
    /// Start building a new message in `buffer`.
    ///
    /// The header is overwritten with the given `id` and `flags` and with
    /// default section counts; the message contents start out empty.
    ///
    /// # Panics
    ///
    /// Panics if `buffer` cannot be parsed as a message, i.e. (per the
    /// `expect` message below) if it is smaller than a 12-byte header.
    #[must_use]
    pub fn new(
        buffer: &'b mut [u8],
        compressor: &'c mut NameCompressor,
        id: U16,
        flags: HeaderFlags,
    ) -> Self {
        let header = Header {
            id,
            flags,
            counts: SectionCounts::default(),
        };

        let message = Message::parse_bytes_in(buffer)
            .expect("The caller's buffer is at least 12 bytes big");
        message.header = header;

        Self {
            compressor,
            message,
            offset: 0,
        }
    }
}
impl MessageBuilder<'_, '_> {
    /// The header of the message being built.
    #[must_use]
    pub fn header(&self) -> &Header {
        let Self { message, .. } = self;
        &message.header
    }

    /// Mutable access to the header of the message being built.
    #[must_use]
    pub fn header_mut(&mut self) -> &mut Header {
        let Self { message, .. } = self;
        &mut message.header
    }

    /// The message built so far, with its contents cut off at the bytes
    /// written so far.
    #[must_use]
    pub fn message(&self) -> &Message {
        let Self {
            message, offset, ..
        } = self;
        message.truncate(*offset)
    }

    /// Mutable access to the message built so far, with its contents cut
    /// off at the bytes written so far.
    #[must_use]
    pub fn message_mut(&mut self) -> &mut Message {
        let Self {
            message, offset, ..
        } = self;
        message.truncate_mut(*offset)
    }

    /// The name compressor used by this builder.
    #[must_use]
    pub fn compressor(&self) -> &NameCompressor {
        let Self { compressor, .. } = self;
        compressor
    }
}
impl<'b> MessageBuilder<'b, '_> {
    /// Finish building, returning the message that was built.
    ///
    /// The message's contents are cut off at the bytes written so far, and
    /// the returned reference keeps the full lifetime `'b` of the buffer.
    #[must_use]
    pub fn finish(self) -> &'b mut Message {
        self.message.truncate_mut(self.offset)
    }

    /// Reborrow this builder for a shorter lifetime.
    ///
    /// The returned builder writes into the same buffer, starting at the
    /// current offset, and uses the same name compressor. Items it pushes
    /// update the shared header (e.g. the section counts), but its offset is
    /// a separate copy: the offset of `self` is not advanced by them.
    #[must_use]
    pub fn reborrow(&mut self) -> MessageBuilder<'_, '_> {
        MessageBuilder {
            message: self.message,
            offset: self.offset,
            compressor: self.compressor,
        }
    }

    /// Limit the total size of the message (header included) to `size`
    /// bytes.
    ///
    /// The limit only ever shrinks the available space: if `size` exceeds
    /// the space currently available, the current limit is kept.
    ///
    /// # Errors
    ///
    /// Fails if the header plus the contents written so far do not fit in
    /// `size` bytes; the builder is left unchanged in that case.
    pub fn limit_to(&mut self, size: usize) -> Result<(), TruncationError> {
        // The size of the fixed message header.
        const HEADER_LEN: usize = 12;

        if HEADER_LEN + self.offset > size {
            return Err(TruncationError);
        }

        let size = (size - HEADER_LEN).min(self.message.contents.len());

        // Move the `&'b mut Message` out of `self` so that the truncated
        // reference keeps the full lifetime `'b`; a reborrow through
        // `&mut self` could only live as long as `self` is borrowed.
        //
        // SAFETY: `ptr::read` duplicates the mutable reference, but the
        // original left in `self.message` is not used again until it is
        // overwritten by `ptr::write` below, so the two are never used while
        // aliasing. `&mut` references have no destructor, so nothing can be
        // dropped twice.
        let message = unsafe { core::ptr::read(&self.message) };
        // `size` was clamped to `contents.len()` above, so this is not
        // expected to panic.
        let message = message.truncate_mut(size);
        // SAFETY: `&mut self.message` is valid for writes, and the old value
        // being overwritten is a reference, which needs no drop.
        unsafe { core::ptr::write(&mut self.message, message) };
        Ok(())
    }

    /// Discard all message contents and mark the message as truncated.
    ///
    /// This sets the TC flag, resets every section count, and resets the
    /// name compressor, so that building can start over in the same buffer.
    pub fn truncate(&mut self) {
        self.message.header.flags.set_tc(true);
        // The contents are being discarded, so the header must not keep
        // claiming the records that were in them. Stale counts would also
        // make `push()` reject items for earlier sections as misplaced.
        self.message.header.counts = SectionCounts::default();
        // Any names recorded by the compressor refer to the discarded
        // contents; start from a fresh compressor so later names are not
        // compressed against them.
        *self.compressor = NameCompressor::default();
        self.offset = 0;
    }

    /// Append a message item to the message.
    ///
    /// Items must be pushed in section order: questions, then answers, then
    /// authority records, then additional records. EDNS records are counted
    /// in the additional section. On success, the count of the item's
    /// section in the header is incremented.
    ///
    /// # Errors
    ///
    /// - [`MessageBuildError::Misplaced`] if a later section already has
    ///   items.
    /// - [`MessageBuildError::Truncated`] if the item does not fit in the
    ///   remaining space. The offset and section counts are left unchanged,
    ///   although bytes past the offset may have been overwritten.
    ///
    /// NOTE(review): whether `build_in_message` leaves the compressor
    /// unchanged when it fails is not visible here. Confirm it does not
    /// retain names written past the offset.
    pub fn push<N, RD, ED>(
        &mut self,
        item: &MessageItem<N, RD, ED>,
    ) -> Result<(), MessageBuildError>
    where
        N: BuildInMessage,
        RD: BuildInMessage,
        ED: BuildBytes,
    {
        // Index into the header's section counts.
        let section = match item {
            MessageItem::Question(_) => 0,
            MessageItem::Answer(_) => 1,
            MessageItem::Authority(_) => 2,
            MessageItem::Additional(_) => 3,
            MessageItem::Edns(_) => 3,
        };

        let counts = self.message.header.counts.as_array_mut();
        if counts[section + 1..].iter().any(|c| c.get() != 0) {
            return Err(MessageBuildError::Misplaced);
        }

        // On failure, `?` returns before the offset or the count is updated.
        self.offset = item.build_in_message(
            &mut self.message.contents,
            self.offset,
            self.compressor,
        )?;
        counts[section] += 1;
        Ok(())
    }

    /// Append a question to the question section.
    ///
    /// # Errors
    ///
    /// See [`Self::push()`].
    pub fn push_question<N: BuildInMessage>(
        &mut self,
        question: &Question<N>,
    ) -> Result<(), MessageBuildError> {
        let question = question.transform_ref(|n| n);
        self.push(&MessageItem::<&N, (), ()>::Question(question))
    }

    /// Append a record to the answer section.
    ///
    /// # Errors
    ///
    /// See [`Self::push()`].
    pub fn push_answer<N: BuildInMessage, D: BuildInMessage>(
        &mut self,
        answer: &Record<N, D>,
    ) -> Result<(), MessageBuildError> {
        let answer = answer.transform_ref(|n| n, |d| d);
        self.push(&MessageItem::<&N, &D, ()>::Answer(answer))
    }

    /// Append a record to the authority section.
    ///
    /// # Errors
    ///
    /// See [`Self::push()`].
    pub fn push_authority<N: BuildInMessage, D: BuildInMessage>(
        &mut self,
        authority: &Record<N, D>,
    ) -> Result<(), MessageBuildError> {
        let authority = authority.transform_ref(|n| n, |d| d);
        self.push(&MessageItem::<&N, &D, ()>::Authority(authority))
    }

    /// Append a record to the additional section.
    ///
    /// # Errors
    ///
    /// Fails if the record does not fit in the remaining space. The
    /// additional section is the last one, so the record is never
    /// misplaced.
    pub fn push_additional<N: BuildInMessage, D: BuildInMessage>(
        &mut self,
        additional: &Record<N, D>,
    ) -> Result<(), TruncationError> {
        let additional = additional.transform_ref(|n| n, |d| d);
        self.push_to_last_section(&MessageItem::<&N, &D, ()>::Additional(
            additional,
        ))
    }

    /// Append an EDNS record to the additional section.
    ///
    /// # Errors
    ///
    /// Fails if the record does not fit in the remaining space. The
    /// additional section is the last one, so the record is never
    /// misplaced.
    pub fn push_edns<D: ?Sized + BuildBytes>(
        &mut self,
        edns: &EdnsRecord<D>,
    ) -> Result<(), TruncationError> {
        let edns = edns.transform_ref(|d| d);
        self.push_to_last_section(&MessageItem::<(), (), &D>::Edns(edns))
    }

    /// Push an item into the additional (last) section, reporting only
    /// truncation errors.
    ///
    /// Callers must pass an `Additional` or `Edns` item. These map to the
    /// last section, so `push()` can never report them as misplaced.
    fn push_to_last_section<N, RD, ED>(
        &mut self,
        item: &MessageItem<N, RD, ED>,
    ) -> Result<(), TruncationError>
    where
        N: BuildInMessage,
        RD: BuildInMessage,
        ED: BuildBytes,
    {
        self.push(item).map_err(|err| match err {
            MessageBuildError::Misplaced => {
                unreachable!("An additional record is never misplaced")
            }
            MessageBuildError::Truncated(err) => err,
        })
    }
}
/// An error in building a DNS message.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub enum MessageBuildError {
    /// An item was pushed into a section that comes before a section which
    /// already has items.
    Misplaced,

    /// An item did not fit in the space left in the message.
    Truncated(TruncationError),
}

#[cfg(feature = "std")]
impl std::error::Error for MessageBuildError {}

impl From<TruncationError> for MessageBuildError {
    fn from(value: TruncationError) -> Self {
        MessageBuildError::Truncated(value)
    }
}

impl fmt::Display for MessageBuildError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let description = match *self {
            MessageBuildError::Misplaced => {
                "a DNS message item was placed in the wrong order"
            }
            MessageBuildError::Truncated(_) => {
                "a DNS message item was too large to fit"
            }
        };
        f.write_str(description)
    }
}
#[cfg(test)]
mod test {
    use crate::new::base::name::RevNameBuf;
    use crate::new::base::wire::U16;
    use crate::new::base::{
        HeaderFlags, QClass, QType, Question, RClass, RType, Record, TTL,
    };
    use crate::new::rdata::{RecordData, A};

    use super::{MessageBuilder, NameCompressor};

    /// Create a builder over `buffer` with a zero ID and default flags.
    fn make_builder<'b, 'c>(
        buffer: &'b mut [u8],
        compressor: &'c mut NameCompressor,
    ) -> MessageBuilder<'b, 'c> {
        MessageBuilder::new(
            buffer,
            compressor,
            U16::new(0),
            HeaderFlags::default(),
        )
    }

    #[test]
    fn new() {
        // A header-only buffer: a fresh builder has no contents.
        let mut storage = [0u8; 12];
        let mut names = NameCompressor::default();
        let mut builder = make_builder(&mut storage, &mut names);

        let empty: &[u8] = &[];
        assert_eq!(&builder.message().contents, empty);
        assert_eq!(&builder.message_mut().contents, empty);
    }

    #[test]
    fn build_question() {
        // 12 header bytes + 17 name bytes + 2 (qtype) + 2 (qclass).
        let mut storage = [0u8; 33];
        let mut names = NameCompressor::default();
        let mut builder = make_builder(&mut storage, &mut names);

        let question = Question::<RevNameBuf> {
            qname: "www.example.org".parse().unwrap(),
            qtype: QType::A,
            qclass: QClass::IN,
        };
        builder.push_question(&question).unwrap();

        let expected = b"\x03www\x07example\x03org\x00\x00\x01\x00\x01";
        assert_eq!(&builder.message().contents, expected.as_slice());
    }

    #[test]
    fn build_record() {
        // 12 header bytes + 17 name bytes + 2 (type) + 2 (class) + 4 (TTL)
        // + 2 (rdlength) + 4 (A record data).
        let mut storage = [0u8; 43];
        let mut names = NameCompressor::default();
        let mut builder = make_builder(&mut storage, &mut names);

        let localhost = RecordData::<()>::A(A {
            octets: [127, 0, 0, 1],
        });
        let record = Record::<RevNameBuf, _> {
            rname: "www.example.org".parse().unwrap(),
            rtype: RType::A,
            rclass: RClass::IN,
            ttl: TTL::from(42),
            rdata: localhost,
        };
        builder.push_answer(&record).unwrap();

        assert_eq!(builder.message().header.counts.answers.get(), 1);
        let expected = b"\x03www\x07example\x03org\x00\x00\x01\x00\x01\x00\x00\x00\x2A\x00\x04\x7F\x00\x00\x01";
        assert_eq!(&builder.message().contents, expected.as_slice());
    }
}