use core::fmt;
use crate::new::edns::EdnsRecord;
use crate::utils::dst::UnsizedCopy;
use super::build::{BuildInMessage, NameCompressor};
use super::parse::MessageParser;
use super::wire::{
AsBytes, BuildBytes, ParseBytes, ParseBytesZC, SplitBytes, SplitBytesZC,
TruncationError, U16,
};
use super::{Question, Record};
/// A DNS message, as laid out on the wire.
///
/// The struct is `repr(C, packed)`: a fixed-size [`Header`] is followed
/// directly by the raw `contents` bytes, with no padding in between.
#[derive(AsBytes, BuildBytes, ParseBytesZC, UnsizedCopy)]
#[repr(C, packed)]
pub struct Message {
    /// The message header (ID, flags, and section counts).
    pub header: Header,
    /// The raw bytes following the header: the encoded message sections
    /// (questions, answers, authorities, and additionals).
    pub contents: [u8],
}
impl Message {
    /// Return the whole message (header followed by contents) as a mutable
    /// byte slice.
    pub fn as_bytes_mut(&mut self) -> &mut [u8] {
        // For this slice-tailed type, the size covers the header plus all of
        // `contents`.
        let len = core::mem::size_of_val(self);
        let start = (self as *mut Self).cast::<u8>();
        // SAFETY: `start` points to the `len` bytes of `*self`, which stay
        // uniquely borrowed for the lifetime of the returned slice.
        // `Message` is `repr(C, packed)`, so none of those bytes is padding.
        // NOTE(review): this relies on every field being plain byte data, so
        // that any byte pattern written through the slice keeps `*self` valid.
        unsafe { core::slice::from_raw_parts_mut(start, len) }
    }
}
impl Message {
    /// Construct a [`MessageParser`] over this message.
    ///
    /// This is a thin wrapper around [`MessageParser::for_message()`].
    pub const fn parse(&self) -> MessageParser<'_> {
        let message: &Message = self;
        MessageParser::for_message(message)
    }
}
impl Message {
    /// Truncate the contents of this message to the given size.
    ///
    /// The returned message starts with the same header and has a `contents`
    /// field of exactly `size` bytes: a prefix of the original contents.
    /// The header's section counts are not adjusted.
    ///
    /// # Panics
    ///
    /// Panics if `size` is larger than the current length of `contents`.
    pub fn truncate(&self, size: usize) -> &Self {
        // The header is always 12 bytes.  `saturating_add` stops a huge
        // `size` from wrapping around in release builds.  A wrapped end
        // offset could be shorter than a header, which would make the
        // unchecked parse below undefined behaviour.  A saturated offset
        // makes the slice indexing panic instead.
        let end = 12usize.saturating_add(size);
        let bytes = &self.as_bytes()[..end];
        // SAFETY: `bytes` holds the complete 12-byte header.  This relies on
        // parsing a `Message` succeeding for any input of at least 12 bytes.
        unsafe { Self::parse_bytes_by_ref(bytes).unwrap_unchecked() }
    }

    /// Truncate the contents of this message to the given size, mutably.
    ///
    /// The returned message starts with the same header and has a `contents`
    /// field of exactly `size` bytes: a prefix of the original contents.
    /// The header's section counts are not adjusted.
    ///
    /// # Panics
    ///
    /// Panics if `size` is larger than the current length of `contents`.
    pub fn truncate_mut(&mut self, size: usize) -> &mut Self {
        // See `truncate()`: saturating keeps a huge `size` from wrapping to
        // an offset inside the header, which would make the parse below UB.
        let end = 12usize.saturating_add(size);
        let bytes = &mut self.as_bytes_mut()[..end];
        // SAFETY: `bytes` holds the complete 12-byte header.  This relies on
        // parsing a `Message` succeeding for any input of at least 12 bytes.
        unsafe { Self::parse_bytes_in(bytes).unwrap_unchecked() }
    }

    /// Truncate the contents of a message to the given size, by raw pointer.
    ///
    /// The returned pointer has the same address as `this`.  Its slice
    /// metadata describes a `contents` field of `size` bytes.
    ///
    /// # Safety
    ///
    /// `this` must be non-null: a zero-sized `&[()]` reference is created
    /// from it in order to read its length metadata.
    ///
    /// `size` should not exceed the length of the original `contents`.  This
    /// is only checked in debug builds.  Otherwise the returned pointer
    /// describes memory past the end of the original message.
    pub unsafe fn truncate_ptr(this: *mut Message, size: usize) -> *mut Self {
        // Extract the slice metadata (the length of `contents`) from `this`.
        //
        // SAFETY: `[()]` is zero-sized, so a reference to it is valid for
        // any non-null pointer, which the caller guarantees.
        let len = unsafe { &*(this as *mut [()]) }.len();
        debug_assert!(size <= len);
        // The metadata of `*mut Message` is the length of `contents`, so a
        // byte-slice pointer of length `size` casts to a message whose
        // `contents` is `size` bytes long.
        core::ptr::slice_from_raw_parts_mut(this.cast::<u8>(), size)
            as *mut Self
    }
}
#[cfg(feature = "alloc")]
impl Clone for alloc::boxed::Box<Message> {
    /// Copy the boxed message into a new `Box`.
    fn clone(&self) -> Self {
        // Dereference through the `Box` explicitly so the copy is made from
        // the `Message` itself.
        (**self).unsized_copy_into()
    }
}
/// The fixed 12-byte header of a DNS message.
#[derive(
    Copy,
    Clone,
    Debug,
    Hash,
    AsBytes,
    BuildBytes,
    ParseBytes,
    ParseBytesZC,
    SplitBytes,
    SplitBytesZC,
    UnsizedCopy,
)]
#[repr(C)]
pub struct Header {
    /// The message ID (displayed in hexadecimal).
    pub id: U16,
    /// The header flags (QR, opcode, AA, TC, RD, RA, AD, CD, and rcode).
    pub flags: HeaderFlags,
    /// The number of entries in each message section.
    pub counts: SectionCounts,
}
impl fmt::Display for Header {
    /// Render the flags, the ID (as four hex digits), and the section
    /// counts.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self { id, flags, counts } = self;
        write!(f, "{} of ID {:04X} ({})", flags, id.get(), counts)
    }
}
/// The 16-bit flags field of a DNS message header.
///
/// Individual flags are read and written through the getter and setter
/// methods; [`HeaderFlags::bits()`] returns the raw value.
#[derive(
    Copy,
    Clone,
    Default,
    Hash,
    AsBytes,
    BuildBytes,
    ParseBytes,
    ParseBytesZC,
    SplitBytes,
    SplitBytesZC,
    UnsizedCopy,
)]
#[repr(transparent)]
pub struct HeaderFlags {
    /// The raw flag bits, as stored in the message.
    inner: U16,
}
impl HeaderFlags {
    /// Whether the bit at position `pos` is set (0 = least significant).
    const fn get_flag(&self, pos: u32) -> bool {
        self.inner.get() & (1 << pos) != 0
    }

    /// Set the bit at position `pos` to `value` (0 = least significant).
    fn set_flag(&mut self, pos: u32, value: bool) -> &mut Self {
        self.inner &= !(1 << pos);
        self.inner |= (value as u16) << pos;
        self
    }

    /// The raw 16-bit value of the flags field.
    pub const fn bits(&self) -> u16 {
        self.inner.get()
    }

    /// The QR bit (bit 15): `false` for a query, `true` for a response.
    pub const fn qr(&self) -> bool {
        self.get_flag(15)
    }

    /// Set the QR bit (bit 15).
    pub fn set_qr(&mut self, value: bool) -> &mut Self {
        self.set_flag(15, value)
    }

    /// The 4-bit opcode (bits 11 through 14).
    pub const fn opcode(&self) -> u8 {
        (self.inner.get() >> 11) as u8 & 0xF
    }

    /// Set the 4-bit opcode (bits 11 through 14).
    ///
    /// Only the low four bits of `value` are stored.  In debug builds, a
    /// value of 16 or more triggers an assertion failure.
    pub fn set_opcode(&mut self, value: u8) -> &mut Self {
        debug_assert!(value < 16);
        self.inner &= !(0xF << 11);
        // Mask so an out-of-range value cannot overwrite the QR bit in
        // release builds, where the assertion above is compiled out.
        self.inner |= ((value & 0xF) as u16) << 11;
        self
    }

    /// The AA ("authoritative answer") bit (bit 10).
    pub const fn aa(&self) -> bool {
        self.get_flag(10)
    }

    /// Set the AA bit (bit 10).
    pub fn set_aa(&mut self, value: bool) -> &mut Self {
        self.set_flag(10, value)
    }

    /// The TC ("truncated") bit (bit 9).
    pub const fn tc(&self) -> bool {
        self.get_flag(9)
    }

    /// Set the TC bit (bit 9).
    pub fn set_tc(&mut self, value: bool) -> &mut Self {
        self.set_flag(9, value)
    }

    /// The RD ("recursion desired") bit (bit 8).
    pub const fn rd(&self) -> bool {
        self.get_flag(8)
    }

    /// Set the RD bit (bit 8).
    pub fn set_rd(&mut self, value: bool) -> &mut Self {
        self.set_flag(8, value)
    }

    /// The RA ("recursion available") bit (bit 7).
    pub const fn ra(&self) -> bool {
        self.get_flag(7)
    }

    /// Set the RA bit (bit 7).
    pub fn set_ra(&mut self, value: bool) -> &mut Self {
        self.set_flag(7, value)
    }

    /// The AD ("authentic data") bit (bit 5).
    pub const fn ad(&self) -> bool {
        self.get_flag(5)
    }

    /// Set the AD bit (bit 5).
    pub fn set_ad(&mut self, value: bool) -> &mut Self {
        self.set_flag(5, value)
    }

    /// The CD ("checking disabled") bit (bit 4).
    pub const fn cd(&self) -> bool {
        self.get_flag(4)
    }

    /// Set the CD bit (bit 4).
    pub fn set_cd(&mut self, value: bool) -> &mut Self {
        self.set_flag(4, value)
    }

    /// The 4-bit response code (bits 0 through 3).
    pub const fn rcode(&self) -> u8 {
        self.inner.get() as u8 & 0xF
    }

    /// Set the 4-bit response code (bits 0 through 3).
    ///
    /// Only the low four bits of `value` are stored.  In debug builds, a
    /// value of 16 or more triggers an assertion failure.
    pub fn set_rcode(&mut self, value: u8) -> &mut Self {
        debug_assert!(value < 16);
        self.inner &= !0xF;
        // Mask so an out-of-range value cannot overwrite the CD/AD/RA bits
        // in release builds, where the assertion above is compiled out.
        self.inner |= (value & 0xF) as u16;
        self
    }
}
impl fmt::Debug for HeaderFlags {
    /// List every individual flag in bit order (most significant first),
    /// followed by the raw bits.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("HeaderFlags")
            .field("qr", &self.qr())
            .field("opcode", &self.opcode())
            .field("aa", &self.aa())
            .field("tc", &self.tc())
            .field("rd", &self.rd())
            .field("ra", &self.ra())
            // AD and CD were previously missing from this output even though
            // the type exposes them.
            .field("ad", &self.ad())
            .field("cd", &self.cd())
            .field("rcode", &self.rcode())
            .field("bits", &self.bits())
            .finish()
    }
}
impl fmt::Display for HeaderFlags {
    /// Render a short human-readable summary, such as
    /// `recursive query (opcode 0)` or `authoritative response (rcode 0)`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        if self.qr() {
            // Qualifiers printed before "response", in this order.
            let qualifiers = [
                (self.ad(), "authentic "),
                (self.aa(), "authoritative "),
                (self.rd() && self.ra(), "recursive "),
            ];
            for (enabled, text) in qualifiers {
                if enabled {
                    f.write_str(text)?;
                }
            }
            write!(f, "response (rcode {})", self.rcode())?;
        } else {
            if self.rd() {
                f.write_str("recursive ")?;
            }
            write!(f, "query (opcode {})", self.opcode())?;
            if self.cd() {
                f.write_str(" (checking disabled)")?;
            }
        }

        if self.tc() {
            f.write_str(" (message truncated)")
        } else {
            Ok(())
        }
    }
}
/// The number of entries in each section of a DNS message.
#[derive(
    Copy,
    Clone,
    Debug,
    Default,
    PartialEq,
    Eq,
    Hash,
    AsBytes,
    BuildBytes,
    ParseBytes,
    ParseBytesZC,
    SplitBytes,
    SplitBytesZC,
    UnsizedCopy,
)]
#[repr(C)]
pub struct SectionCounts {
    /// The number of questions.
    pub questions: U16,
    /// The number of answer records.
    pub answers: U16,
    /// The number of authority records.
    pub authorities: U16,
    /// The number of additional records.
    pub additionals: U16,
}
impl SectionCounts {
    /// View the four counts as an array, in field order: questions, answers,
    /// authorities, additionals.
    pub fn as_array(&self) -> &[U16; 4] {
        // SAFETY: `SectionCounts` is `#[repr(C)]` with exactly four fields of
        // type `U16`, so there is no padding between them.  It therefore has
        // the same size, alignment, and layout as `[U16; 4]`.  The explicit
        // pointer cast (instead of an inferred `transmute`) pins both types.
        unsafe { &*(self as *const Self).cast::<[U16; 4]>() }
    }

    /// Mutably view the four counts as an array, in field order: questions,
    /// answers, authorities, additionals.
    pub fn as_array_mut(&mut self) -> &mut [U16; 4] {
        // SAFETY: Same layout argument as `as_array()`.  Uniqueness comes
        // from the `&mut self` borrow, which outlives the returned reference.
        unsafe { &mut *(self as *mut Self).cast::<[U16; 4]>() }
    }
}
impl fmt::Display for SectionCounts {
    /// Render the non-zero counts, e.g. `1 question, 2 answers`, or `empty`
    /// when every count is zero.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let sections = [
            (self.questions.get(), "question", "questions"),
            (self.answers.get(), "answer", "answers"),
            (self.authorities.get(), "authority", "authorities"),
            (self.additionals.get(), "additional", "additionals"),
        ];

        // Tracks whether a non-empty section has been written yet.  It is
        // used both to place separators and to detect the all-empty case.
        let mut wrote_any = false;
        for (n, single, many) in sections {
            if n == 0 {
                continue;
            }
            if wrote_any {
                f.write_str(", ")?;
            }
            wrote_any = true;
            if n == 1 {
                write!(f, "1 {single}")?;
            } else {
                write!(f, "{n} {many}")?;
            }
        }

        if wrote_any {
            Ok(())
        } else {
            f.write_str("empty")
        }
    }
}
/// A single item of a DNS message: a question, a record from one of the
/// three record sections, or an EDNS record.
///
/// `N` is the domain name type, `RD` the record data type, and `ED` the EDNS
/// data type (see the closures taken by [`MessageItem::transform()`]).
#[derive(Clone, Debug)]
pub enum MessageItem<N, RD, ED> {
    /// An entry of the question section.
    Question(Question<N>),
    /// A record from the answer section.
    Answer(Record<N, RD>),
    /// A record from the authority section.
    Authority(Record<N, RD>),
    /// A record from the additional section.
    Additional(Record<N, RD>),
    /// An EDNS record.
    Edns(EdnsRecord<ED>),
}
impl<N, RD, ED> MessageItem<N, RD, ED> {
    /// Map this item, by value, onto new name, record data, and EDNS data
    /// types.
    ///
    /// The variant is preserved.  `name_map` is handed to questions and
    /// records, `rdata_map` to records, and `edata_map` to EDNS records.
    pub fn transform<NN, NRD, NED>(
        self,
        name_map: impl FnOnce(N) -> NN,
        rdata_map: impl FnOnce(RD) -> NRD,
        edata_map: impl FnOnce(ED) -> NED,
    ) -> MessageItem<NN, NRD, NED> {
        match self {
            Self::Question(question) => {
                let mapped = question.transform(name_map);
                MessageItem::Question(mapped)
            }
            Self::Answer(record) => {
                let mapped = record.transform(name_map, rdata_map);
                MessageItem::Answer(mapped)
            }
            Self::Authority(record) => {
                let mapped = record.transform(name_map, rdata_map);
                MessageItem::Authority(mapped)
            }
            Self::Additional(record) => {
                let mapped = record.transform(name_map, rdata_map);
                MessageItem::Additional(mapped)
            }
            Self::Edns(edns) => {
                let mapped = edns.transform(edata_map);
                MessageItem::Edns(mapped)
            }
        }
    }

    /// Map this item, by reference, onto new name, record data, and EDNS
    /// data types.
    ///
    /// The variant is preserved.  The closures receive references that
    /// borrow from `self`, and are handed out as in [`Self::transform()`].
    pub fn transform_ref<'a, NN, NRD, NED>(
        &'a self,
        name_map: impl FnOnce(&'a N) -> NN,
        rdata_map: impl FnOnce(&'a RD) -> NRD,
        edata_map: impl FnOnce(&'a ED) -> NED,
    ) -> MessageItem<NN, NRD, NED> {
        match self {
            Self::Question(question) => {
                let mapped = question.transform_ref(name_map);
                MessageItem::Question(mapped)
            }
            Self::Answer(record) => {
                let mapped = record.transform_ref(name_map, rdata_map);
                MessageItem::Answer(mapped)
            }
            Self::Authority(record) => {
                let mapped = record.transform_ref(name_map, rdata_map);
                MessageItem::Authority(mapped)
            }
            Self::Additional(record) => {
                let mapped = record.transform_ref(name_map, rdata_map);
                MessageItem::Additional(mapped)
            }
            Self::Edns(edns) => {
                let mapped = edns.transform_ref(edata_map);
                MessageItem::Edns(mapped)
            }
        }
    }
}
impl<N, RD, LED, RED> PartialEq<MessageItem<N, RD, RED>>
    for MessageItem<N, RD, LED>
where
    N: PartialEq,
    RD: PartialEq,
    LED: PartialEq<RED>,
{
    /// Two items are equal when they are the same variant and their
    /// payloads compare equal.
    fn eq(&self, other: &MessageItem<N, RD, RED>) -> bool {
        // Match on `self` exhaustively, so that adding a variant forces this
        // impl to be revisited.  `other` only needs a same-variant check.
        match self {
            MessageItem::Question(lhs) => {
                matches!(other, MessageItem::Question(rhs) if lhs == rhs)
            }
            MessageItem::Answer(lhs) => {
                matches!(other, MessageItem::Answer(rhs) if lhs == rhs)
            }
            MessageItem::Authority(lhs) => {
                matches!(other, MessageItem::Authority(rhs) if lhs == rhs)
            }
            MessageItem::Additional(lhs) => {
                matches!(other, MessageItem::Additional(rhs) if lhs == rhs)
            }
            MessageItem::Edns(lhs) => {
                matches!(other, MessageItem::Edns(rhs) if lhs == rhs)
            }
        }
    }
}
/// Equality on `MessageItem` is total whenever all of its parameters are.
impl<N, RD, ED> Eq for MessageItem<N, RD, ED>
where
    N: Eq,
    RD: Eq,
    ED: Eq,
{
}
impl<N, RD, ED> BuildInMessage for MessageItem<N, RD, ED>
where
    N: BuildInMessage,
    RD: BuildInMessage,
    ED: BuildBytes,
{
    /// Delegate to the `build_in_message()` of whichever question or record
    /// this item holds, passing all arguments through unchanged.
    fn build_in_message(
        &self,
        contents: &mut [u8],
        start: usize,
        compressor: &mut NameCompressor,
    ) -> Result<usize, TruncationError> {
        match self {
            Self::Question(question) => {
                question.build_in_message(contents, start, compressor)
            }
            // All three record sections hold the same record type, so they
            // are built identically.
            Self::Answer(record)
            | Self::Authority(record)
            | Self::Additional(record) => {
                record.build_in_message(contents, start, compressor)
            }
            Self::Edns(edns) => {
                edns.build_in_message(contents, start, compressor)
            }
        }
    }
}
impl<N, RD, ED> BuildBytes for MessageItem<N, RD, ED>
where
    N: BuildBytes,
    RD: BuildBytes,
    ED: BuildBytes,
{
    /// Delegate to the `build_bytes()` of whichever question or record this
    /// item holds.
    fn build_bytes<'b>(
        &self,
        bytes: &'b mut [u8],
    ) -> Result<&'b mut [u8], TruncationError> {
        match self {
            Self::Question(question) => question.build_bytes(bytes),
            Self::Answer(record)
            | Self::Authority(record)
            | Self::Additional(record) => record.build_bytes(bytes),
            Self::Edns(edns) => edns.build_bytes(bytes),
        }
    }

    /// Delegate to the `built_bytes_size()` of whichever question or record
    /// this item holds.
    fn built_bytes_size(&self) -> usize {
        match self {
            Self::Question(question) => question.built_bytes_size(),
            Self::Answer(record)
            | Self::Authority(record)
            | Self::Additional(record) => record.built_bytes_size(),
            Self::Edns(edns) => edns.built_bytes_size(),
        }
    }
}