use core::marker::PhantomData;
use core::ops::Range;
use byteorder::{ByteOrder, NetworkEndian};
use crate::core::{Class, Opcode, Rcode, Type};
use crate::emit::extension::{ExtensionBuilder, ExtensionError};
use crate::emit::question::{QuestionBuilder, QuestionError};
use crate::emit::record::{RecordBuilder, RecordData, RecordError, RecordName};
use crate::emit::{Buffer, Builder, ChildBuilder, GrowError, PushBuilder, Sink};
use crate::view::{Name, Question};
// Project macro invoked with `MessageError` and its four error-wrapping variants (`Grow`,
// `Extension`, `Record`, `Question`); presumably it generates conversions / error impls
// for those wrapped errors — confirm against the `error!` definition.
error!(MessageError, Grow, Extension, Record, Question);
// Errors produced while building a DNS message with `MessageBuilder`.
//
// NOTE(review): `displaydoc::Display` derives each variant's `Display` text from its `///`
// doc comment, and `#[prefix_enum_doc_attributes]` prefixes that text with the enum's own
// doc comment. No such doc comments are visible here, so only plain `//` comments are used
// (they do not change the rendered messages) — confirm the intended `///` docs exist.
#[derive(Debug, displaydoc::Display)]
#[prefix_enum_doc_attributes]
pub enum MessageError {
    // Reserving space in the underlying buffer (`grow_range`) failed.
    Grow(GrowError),
    // QDCOUNT would overflow its 16-bit header field.
    QdTooManyQuestions,
    // ANCOUNT would overflow its 16-bit header field.
    AnTooManyRecords,
    // NSCOUNT would overflow its 16-bit header field.
    NsTooManyRecords,
    // ARCOUNT would overflow its 16-bit header field.
    ArTooManyRecords,
    // Error from the `ExtensionBuilder` started by `extension`.
    Extension(ExtensionError),
    // Error from a `RecordBuilder` started by `record` / `record_with_name`.
    Record(RecordError),
    // Error from a `QuestionBuilder` started by `question` and friends.
    Question(QuestionError),
    // `finish` was called without an OPT record while the RCODE has a nonzero extended
    // part; carries the full RCODE value (`Rcode::value`).
    ExtensionRequired(u16),
}
/// Builder for a DNS message, generic over its parent builder `P` and over the section
/// marker `Q` (`QdSection`, `AnSection`, `NsSection`, `ArSection` or `ArWithOpt`) that is
/// currently being written.
///
/// The marker decides which operations are available, and the section transitions only
/// move forward (question → answer → authority → additional).
#[must_use]
pub struct MessageBuilder<'b, P, Q> {
    // Ties the builder to the `'b` buffer lifetime without storing a reference.
    buffer: PhantomData<&'b mut dyn Buffer>,
    // Builder being written into; handed back by `finish`.
    parent: P,
    // Unit section marker: only its type matters, so the value is never read.
    #[allow(dead_code)]
    section: Q,
    // Byte range of the 12-byte header reserved when the message was pushed.
    header: Range<usize>,
    // RCODE written into the header at `finish` (only its low four bits fit there).
    rcode: Rcode,
}
/// Section marker: writing the question section (counted in QDCOUNT).
pub struct QdSection;
/// Section marker: writing the answer section (counted in ANCOUNT).
pub struct AnSection;
/// Section marker: writing the authority section (counted in NSCOUNT).
pub struct NsSection;
/// Section marker: writing the additional section (counted in ARCOUNT).
pub struct ArSection;
/// Section marker: additional section after the transition from `ArSection`, presumably
/// taken once an OPT record has been written — confirm at the transition's call site.
pub struct ArWithOpt;
/// Token required by the hidden `Step` / `StepWithoutOpt` functions. Its field is private,
/// so only this module (and its submodules) can construct one and call them.
#[doc(hidden)]
pub struct Passkey(());
/// Sections into which another entry can be pushed by bumping that section's count in
/// the message header.
pub trait Step
where
    Self: Sized,
{
    /// Increments this section's header count, returning the section's "too many" error
    /// if the 16-bit count cannot be incremented. Callable only with a [`Passkey`].
    #[doc(hidden)]
    fn increment<'b, P: Builder<'b>>(
        message: MessageBuilder<'b, P, Self>,
        _: Passkey,
    ) -> Result<MessageBuilder<'b, P, Self>, MessageError>;
}
/// Marker for sections that accept questions (in this file, only `QdSection`).
pub trait QuestionStep {}
/// Marker for sections that accept resource records.
pub trait RecordStep {}
/// Sections from which the builder can still move to `ArSection`, i.e. before the
/// `ArWithOpt` state. These are the sections offering `rcode`, `extension` and the
/// RCODE-checking `finish`.
pub trait StepWithoutOpt
where
    Self: Sized,
{
    /// Moves the builder into `ArSection` (identity for `ArSection` itself). Callable
    /// only with a [`Passkey`].
    #[doc(hidden)]
    fn into_ar<'b, P: Builder<'b>>(
        message: MessageBuilder<'b, P, Self>,
        _: Passkey,
    ) -> MessageBuilder<'b, P, ArSection>;
}
// Questions are only accepted in the question section; records are accepted in every
// later section, including the additional section after the `ArWithOpt` transition.
impl QuestionStep for QdSection {}
impl RecordStep for AnSection {}
impl RecordStep for NsSection {}
impl RecordStep for ArSection {}
impl RecordStep for ArWithOpt {}
impl StepWithoutOpt for QdSection {
fn into_ar<'b, P: Builder<'b>>(
message: MessageBuilder<'b, P, Self>,
_: Passkey,
) -> MessageBuilder<'b, P, ArSection> {
message.into_ar()
}
}
impl StepWithoutOpt for AnSection {
fn into_ar<'b, P: Builder<'b>>(
message: MessageBuilder<'b, P, Self>,
_: Passkey,
) -> MessageBuilder<'b, P, ArSection> {
message.into_ar()
}
}
impl StepWithoutOpt for NsSection {
fn into_ar<'b, P: Builder<'b>>(
message: MessageBuilder<'b, P, Self>,
_: Passkey,
) -> MessageBuilder<'b, P, ArSection> {
message.into_ar()
}
}
impl StepWithoutOpt for ArSection {
fn into_ar<'b, P: Builder<'b>>(
message: MessageBuilder<'b, P, Self>,
_: Passkey,
) -> MessageBuilder<'b, P, ArSection> {
message
}
}
impl Step for QdSection {
fn increment<'b, P: Builder<'b>>(
message: MessageBuilder<'b, P, Self>,
_: Passkey,
) -> Result<MessageBuilder<'b, P, Self>, MessageError> {
message.qd_increment()
}
}
impl Step for AnSection {
fn increment<'b, P: Builder<'b>>(
message: MessageBuilder<'b, P, Self>,
_: Passkey,
) -> Result<MessageBuilder<'b, P, Self>, MessageError> {
message.an_increment()
}
}
impl Step for NsSection {
fn increment<'b, P: Builder<'b>>(
message: MessageBuilder<'b, P, Self>,
_: Passkey,
) -> Result<MessageBuilder<'b, P, Self>, MessageError> {
message.ns_increment()
}
}
impl Step for ArSection {
fn increment<'b, P: Builder<'b>>(
message: MessageBuilder<'b, P, Self>,
_: Passkey,
) -> Result<MessageBuilder<'b, P, Self>, MessageError> {
message.ar_increment()
}
}
impl Step for ArWithOpt {
fn increment<'b, P: Builder<'b>>(
message: MessageBuilder<'b, P, Self>,
_: Passkey,
) -> Result<MessageBuilder<'b, P, Self>, MessageError> {
message.ar_increment()
}
}
impl<'b, P: Builder<'b>, Q> ChildBuilder<'b, P> for MessageBuilder<'b, P, Q> {
    /// Gives mutable access to the builder this message is being written into.
    fn parent(&mut self) -> &mut P {
        let Self { parent, .. } = self;
        parent
    }
}
impl<'b> PushBuilder<'b, Sink<'b>> for MessageBuilder<'b, Sink<'b>, QdSection> {
    type Error = MessageError;

    /// Starts a new message in `parent` by reserving its 12-byte header with
    /// `grow_range`. The builder starts in the question section with RCODE `NoError`.
    ///
    /// NOTE(review): the header setters that only OR bits in (`u16_set`) assume the
    /// reserved bytes start out zeroed — confirm `grow_range` guarantees that.
    ///
    /// # Errors
    ///
    /// Returns [`MessageError::Grow`] if the header bytes cannot be reserved.
    fn push(mut parent: Sink<'b>) -> Result<Self, MessageError> {
        // Fixed size of a DNS message header (RFC 1035 §4.1.1).
        const HEADER_LEN: usize = 12;

        let reserved = parent.sink().grow_range(HEADER_LEN);
        let header = reserved.map_err(MessageError::Grow)?;
        let message = Self {
            buffer: PhantomData,
            parent,
            section: QdSection,
            header,
            rcode: Rcode::NoError,
        };
        Ok(message)
    }
}
builder! {
    <'b, P> MessageBuilder {
        Builder [Q];
        // Header accessors. Offsets are relative to `self.header.start`. The 12-byte
        // header (RFC 1035 §4.1.1) holds ID at 0..2 and the flags word at 2..4, followed
        // by QDCOUNT, ANCOUNT, NSCOUNT and ARCOUNT as big-endian u16s at 4, 6, 8 and 10.
        @ <P> Q [Q]:
        // Sets the 16-bit message ID.
        pub fn id(mut self, value: u16) -> Self = {
            let offset = self.header.start;
            NetworkEndian::write_u16(&mut self.sink().inner_mut()[offset..], value);
            self
        }
        // Sets QR (bit 7 of header byte 2).
        pub fn qr(mut self, value: bool) -> Self = {
            let offset = self.header.start + 2;
            self.set(offset, 7, value)
        }
        // Sets the 4-bit OPCODE (bits 14..11 of the flags word, i.e. bits 6..3 of byte 2).
        // The field is cleared first and only the low four bits of `value` are kept, so a
        // repeated call replaces the earlier opcode (instead of OR-ing into it) and an
        // out-of-range value cannot spill into QR.
        pub fn opcode(mut self, value: Opcode) -> Self = {
            let offset = self.header.start + 2;
            self.u16_set_field(offset, 11, 4, value.value().into())
        }
        // Sets AA (bit 2 of header byte 2).
        pub fn aa(mut self, value: bool) -> Self = {
            let offset = self.header.start + 2;
            self.set(offset, 2, value)
        }
        // Sets TC (bit 1 of header byte 2).
        pub fn tc(mut self, value: bool) -> Self = {
            let offset = self.header.start + 2;
            self.set(offset, 1, value)
        }
        // Sets RD (bit 0 of header byte 2).
        pub fn rd(mut self, value: bool) -> Self = {
            let offset = self.header.start + 2;
            self.set(offset, 0, value)
        }
        // Sets RA (bit 7 of header byte 3).
        pub fn ra(mut self, value: bool) -> Self = {
            let offset = self.header.start + 3;
            self.set(offset, 7, value)
        }
        // Bumps QDCOUNT; fails if it is already u16::MAX.
        fn qd_increment(mut self) -> Result<Self, MessageError> = {
            let offset = self.header.start + 4;
            self.u16_increment(offset)
                .map_err(|()| MessageError::QdTooManyQuestions)
        }
        // Bumps ANCOUNT; fails if it is already u16::MAX.
        fn an_increment(mut self) -> Result<Self, MessageError> = {
            let offset = self.header.start + 6;
            self.u16_increment(offset)
                .map_err(|()| MessageError::AnTooManyRecords)
        }
        // Bumps NSCOUNT; fails if it is already u16::MAX.
        fn ns_increment(mut self) -> Result<Self, MessageError> = {
            let offset = self.header.start + 8;
            self.u16_increment(offset)
                .map_err(|()| MessageError::NsTooManyRecords)
        }
        // Bumps ARCOUNT; fails if it is already u16::MAX.
        fn ar_increment(mut self) -> Result<Self, MessageError> = {
            let offset = self.header.start + 10;
            self.u16_increment(offset)
                .map_err(|()| MessageError::ArTooManyRecords)
        }
        // Sets or clears the single bit `shift` of the byte at `offset`.
        fn set(mut self, offset: usize, shift: usize, value: bool) -> Self = {
            let old = self.sink().inner()[offset];
            let new = (old & !(1 << shift)) | (u8::from(value) << shift);
            self.sink().inner_mut()[offset] = new;
            self
        }
        // ORs `value << shift` into the big-endian word at `offset`. Existing bits are not
        // cleared, so this is only correct for a field known to be zero; use
        // `u16_set_field` to overwrite a field.
        fn u16_set(mut self, offset: usize, shift: usize, value: u16) -> Self = {
            let x = NetworkEndian::read_u16(&self.sink().inner()[offset..]);
            let x = x | value << shift;
            NetworkEndian::write_u16(&mut self.sink().inner_mut()[offset..], x);
            self
        }
        // Overwrites the `width`-bit field starting at bit `shift` of the big-endian word at
        // `offset` with the low `width` bits of `value`, leaving all other bits untouched.
        fn u16_set_field(mut self, offset: usize, shift: usize, width: usize, value: u16) -> Self = {
            debug_assert!((1..=16).contains(&width) && shift + width <= 16);
            let mask = u16::MAX >> (16 - width);
            let word = NetworkEndian::read_u16(&self.sink().inner()[offset..]);
            let word = (word & !(mask << shift)) | ((value & mask) << shift);
            NetworkEndian::write_u16(&mut self.sink().inner_mut()[offset..], word);
            self
        }
        // Adds one to the big-endian word at `offset`; `Err(())` (word left unchanged)
        // if it would overflow.
        fn u16_increment(mut self, offset: usize) -> Result<Self, ()> = {
            let x = NetworkEndian::read_u16(&self.sink().inner()[offset..]);
            let x = x.checked_add(1).ok_or(())?;
            NetworkEndian::write_u16(&mut self.sink().inner_mut()[offset..], x);
            Ok(self)
        }
        @ <P> Q [Q: QuestionStep + Step]:
        // Writes a copy of `question` (name, type and class) as a new question.
        pub fn copy_question(mut self, question: &Question) -> Result<Self, MessageError> = {
            let mut qname = self.question()?.qname().map_err(MessageError::Question)?;
            for label in question.qname().labels_not_null().flat_map(|x| x.value()) {
                qname = qname.label(label).map_err(QuestionError::Name).map_err(MessageError::Question)?;
            }
            qname
                .finish_question(question.qtype(), question.qclass())
                .map_err(QuestionError::Name)
                .map_err(MessageError::Question)
        }
        // Writes a new question for `name` with the given type and class.
        pub fn question_with_name(mut self, name: &Name, qtype: Type, qclass: Class) -> Result<Self, MessageError> = {
            let mut qname = self.question()?.qname().map_err(MessageError::Question)?;
            for label in name.labels_not_null().flat_map(|x| x.value()) {
                qname = qname.label(label).map_err(QuestionError::Name).map_err(MessageError::Question)?;
            }
            qname
                .finish_question(qtype, qclass)
                .map_err(QuestionError::Name)
                .map_err(MessageError::Question)
        }
    }
}
builder! {
    <'b, P> MessageBuilder {
        // The RCODE can only be chosen while no OPT record exists. It is stored here and
        // written to the header by `finish`.
        @ <P> Q [Q: StepWithoutOpt]:
        pub fn rcode(mut self, value: Rcode) -> Self = {
            Self { rcode: value, ..self }
        }
    }
}
impl<'b, P: Builder<'b>, Q> MessageBuilder<'b, P, Q> {
    /// Returns the RCODE chosen so far, for other builders inside `crate::emit`.
    pub(in crate::emit) fn get_rcode(&self) -> &Rcode {
        let Self { rcode, .. } = self;
        rcode
    }
}
// NOTE(review): empty inherent impl that declares no items; its purpose is not visible
// here (possibly a separator between `builder!` invocations) — confirm before removing.
impl<__> MessageBuilder<'_, __, __> {}
builder! {
    <'b, P> MessageBuilder {
        // `question` and `record` run `Step::increment` (bumping the section's header
        // count) before pushing the child builder, and map its errors with the listed
        // `MessageError` variant — presumed semantics of the `[push ... | ...]` form.
        @ <P> Q [Q: QuestionStep + Step]:
        pub fn question(mut self) = [push QuestionBuilder | MessageError::Question] { Q::increment(self, Passkey(()))? }
        @ <P> Q [Q: RecordStep + Step]:
        pub fn record(mut self) = [push RecordBuilder<RecordName> | MessageError::Record] { Q::increment(self, Passkey(()))? }
        // Moves to the additional section and starts an extension (presumably EDNS OPT)
        // record there.
        @ <P> Q [Q: StepWithoutOpt]:
        pub fn extension(mut self) -> Result<ExtensionBuilder<'b, MessageBuilder<'b, P, ArSection>>, MessageError> = {
            let additional = Q::into_ar(self, Passkey(()));
            ExtensionBuilder::push(additional).map_err(MessageError::Extension)
        }
        // Starts a record owned by `name` with the given type and class. The returned
        // builder is positioned at the RDATA.
        @ <P> Q [Q: RecordStep + Step]:
        pub fn record_with_name(mut self, name: &Name, r#type: Type, class: Class) -> Result<RecordBuilder<'b, MessageBuilder<'b, P, Q>, RecordData>, MessageError> = {
            let name_builder = self.record()?.name().map_err(MessageError::Record)?;
            let named = name
                .labels_not_null()
                .flat_map(|x| x.value())
                .try_fold(name_builder, |partial, label| {
                    partial
                        .label(label)
                        .map_err(RecordError::Name)
                        .map_err(MessageError::Record)
                })?;
            let rdata = named
                .try_into_rdata()
                .map_err(RecordError::Name)
                .map_err(MessageError::Record)?;
            Ok(rdata.r#type(r#type).class(class))
        }
    }
}
// NOTE(review): empty inherent impl that declares no items; its purpose is not visible
// here (possibly a separator between `builder!` invocations) — confirm before removing.
impl<__> MessageBuilder<'_, __, __> {}
// Explicit section changes. Any later section can be reached directly (e.g. from the
// question section straight to the additional section), but never an earlier one.
// `[into X]` presumably uses the conversions generated by `transition!` — confirm.
builder! {
    <'b, P> MessageBuilder {
        @ <P> QdSection:
        pub fn into_an(mut self) = [into AnSection] { self }
        pub fn into_ns(mut self) = [into NsSection] { self }
        pub fn into_ar(mut self) = [into ArSection] { self }
        @ <P> AnSection:
        pub fn into_ns(mut self) = [into NsSection] { self }
        pub fn into_ar(mut self) = [into ArSection] { self }
        @ <P> NsSection:
        pub fn into_ar(mut self) = [into ArSection] { self }
    }
}
// NOTE(review): empty inherent impl that declares no items; its purpose is not visible
// here (possibly a separator between `builder!` invocations) — confirm before removing.
impl<__> MessageBuilder<'_, __, __> {}
builder! {
    <'b, P> MessageBuilder {
        // Without an OPT record only the low four RCODE bits can be written, so a nonzero
        // extended part is rejected rather than silently dropped.
        @ <P> Q [Q: StepWithoutOpt]:
        pub fn finish(mut self) -> Result<P, MessageError> = {
            let needs_opt = self.rcode.extended_part() > 0;
            if needs_opt {
                Err(MessageError::ExtensionRequired(self.rcode.value()))
            } else {
                self.finish0()
            }
        }
        // After the `ArWithOpt` transition no check is made here; presumably the extended
        // RCODE bits were already written into the OPT record — confirm.
        @ <P> ArWithOpt:
        pub fn finish(mut self) -> Result<P, MessageError> = {
            self.finish0()
        }
        // Writes the low four RCODE bits (bits 3..0 of header byte 3) and hands the parent
        // builder back.
        @ <P> Q [Q]:
        fn finish0(mut self) -> Result<P, MessageError> = {
            let flags_offset = self.header.start + 2;
            let low_bits = self.rcode.basic_part();
            let finished = self.u16_set(flags_offset, 0, low_bits.into());
            Ok(finished.parent)
        }
    }
}
// Allowed changes of the `section` marker. Each entry presumably generates a conversion
// that moves `buffer`, `parent`, `header` and `rcode` unchanged into a builder with the
// new marker — confirm against `transition!`. All transitions go forward; `ArSection ->
// ArWithOpt` is the state change used once an OPT record exists (its caller is not visible
// here).
transition! {
    MessageBuilder.section {
        (buffer, parent, header, rcode) QdSection -> AnSection;
        (buffer, parent, header, rcode) QdSection -> NsSection;
        (buffer, parent, header, rcode) QdSection -> ArSection;
        (buffer, parent, header, rcode) AnSection -> NsSection;
        (buffer, parent, header, rcode) AnSection -> ArSection;
        (buffer, parent, header, rcode) NsSection -> ArSection;
        (buffer, parent, header, rcode) ArSection -> ArWithOpt;
    }
}