#![forbid(
unsafe_code,
missing_docs,
missing_debug_implementations,
rust_2018_idioms,
future_incompatible
)]
#![no_std]
#[cfg(feature = "std")]
extern crate std;
use core::convert::{TryFrom, TryInto};
use core::fmt;
use core::iter;
use core::mem;
use core::num::NonZeroUsize;
use core::str;
use core::error::Error as StdError;
mod ser;
pub use ser::{Cursor, Deserialize, Label, LabelSegment, Serialize};
// Declare a plain struct and implement `Serialize`/`Deserialize` for it by
// handling each field in declaration order.
//
// - `serialized_len` is the sum of every field's `serialized_len`.
// - `serialize` writes the fields back-to-back and returns the total number
//   of bytes written.
// - `deserialize` reads the fields in the same order, threading the cursor
//   from one field to the next.
//
// Every field type must itself implement both traits. The optional lifetime
// parameter must be spelled `'a`, because the generated impls are written as
// `impl<'a> ... for $name<$lt>` and would otherwise name an undeclared lifetime.
macro_rules! serialize {
(
$(#[$outer:meta])*
pub struct $name:ident $(<$lt: lifetime>)? {
$(
$(#[$inner:meta])*
$vis: vis $field:ident: $ty:ty,
)*
}
) => {
$(#[$outer])*
pub struct $name $(<$lt>)? {
$(
$(#[$inner])*
$vis $field: $ty,
)*
}
impl<'a> Serialize<'a> for $name $(<$lt>)? {
fn serialized_len(&self) -> usize {
let mut len = 0;
$(
len += self.$field.serialized_len();
)*
len
}
fn serialize(&self, cursor: &mut [u8]) -> Result<usize, Error> {
// Running offset: each field writes into the unused tail of `cursor`.
let mut index = 0;
$(
index += self.$field.serialize(&mut cursor[index..])?;
)*
Ok(index)
}
}
impl<'a> Deserialize<'a> for $name $(<$lt>)? {
fn deserialize(&mut self, mut cursor: Cursor<'a>) -> Result<Cursor<'a>, Error> {
// Each field consumes its bytes and hands back the advanced cursor.
$(
cursor = self.$field.deserialize(cursor)?;
)*
Ok(cursor)
}
}
};
}
// Declare a fieldless `#[repr(u16)]` enum with explicit discriminants and
// generate the conversions the wire format needs:
// - `TryFrom<u16>`, failing with `InvalidCode` for values with no variant;
// - `From<$name> for u16`;
// - `Serialize`/`Deserialize` as a single `u16`. Deserializing a value with
//   no variant yields `Error::InvalidCode` via the `From<InvalidCode>` impl.
//
// The generated enum is `#[non_exhaustive]`, so matches outside this crate
// need a wildcard arm.
macro_rules! num_enum {
(
$(#[$outer:meta])*
pub enum $name:ident {
$(
$(#[$inner:meta])*
$variant:ident = $value:expr,
)*
}
) => {
$(#[$outer])*
#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[repr(u16)]
#[non_exhaustive]
pub enum $name {
$(
$(#[$inner])*
$variant = $value,
)*
}
impl TryFrom<u16> for $name {
type Error = InvalidCode;
fn try_from(value: u16) -> Result<Self, Self::Error> {
// The discriminant expressions are reused as match patterns, so they
// must be plain integer literals.
match value {
$(
$value => Ok($name::$variant),
)*
_ => Err(InvalidCode(value)),
}
}
}
impl From<$name> for u16 {
fn from(value: $name) -> Self {
value as u16
}
}
impl<'a> Serialize<'a> for $name {
fn serialized_len(&self) -> usize {
mem::size_of::<u16>()
}
fn serialize(&self, cursor: &mut [u8]) -> Result<usize, Error> {
let value: u16 = (*self).into();
value.serialize(cursor)
}
}
impl<'a> Deserialize<'a> for $name {
fn deserialize(&mut self, cursor: Cursor<'a>) -> Result<Cursor<'a>, Error> {
// `value` is inferred as `u16`: the only integer type `$name` can be
// converted from below.
let mut value = 0;
let cursor = value.deserialize(cursor)?;
*self = value.try_into()?;
Ok(cursor)
}
}
};
}
/// Identifies one of the four sections of a DNS message.
///
/// Used in [`Error::NotEnoughWriteSpace`] to name the caller-provided array
/// that was too small. The `Display` output is the lowercase section name
/// (for example `question`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, derive_more::IsVariant, derive_more::Display)]
pub enum BufferType {
/// The question section.
#[display("question")]
Question,
/// The answer section.
#[display("answer")]
Answer,
/// The authority section.
#[display("authority")]
Authority,
/// The additional section.
#[display("additional")]
Additional,
}
/// Errors produced while reading or writing DNS messages.
#[derive(Debug)]
#[non_exhaustive]
pub enum Error {
/// A section of an incoming message declares more entries than the
/// caller-provided array for that section can hold.
NotEnoughWriteSpace {
/// Number of entries the message declares for the section.
tried_to_write: NonZeroUsize,
/// Length of the array that was provided for the section.
available: usize,
/// Which section's array was too small.
buffer_type: BufferType,
},
/// The input ended before a value could be read completely.
NotEnoughReadBytes {
/// Number of bytes the value needed.
tried_to_read: NonZeroUsize,
/// Number of bytes that were left in the input.
available: usize,
},
/// A value could not be parsed.
Parse {
/// Kind of value that failed to parse, as shown in the message
/// "could not parse a {name}".
name: &'static str,
},
/// A name was too long; holds its length in bytes.
NameTooLong(usize),
/// Text that was expected to be UTF-8 was not.
InvalidUtf8(simdutf8::compat::Utf8Error),
/// A numeric code did not match any variant of the enum it was decoded into.
InvalidCode(InvalidCode),
/// A name had too many segments; holds the segment count.
TooManyUrlSegments(usize),
}
impl fmt::Display for Error {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Error::NotEnoughWriteSpace {
tried_to_write,
available,
buffer_type,
} => {
write!(
f,
"not enough write space: tried to write {} entries to {} buffer, but only {} were available",
tried_to_write, buffer_type, available
)
}
Error::NotEnoughReadBytes {
tried_to_read,
available,
} => {
write!(
f,
"not enough read bytes: tried to read {} bytes, but only {} were available",
tried_to_read, available
)
}
Error::Parse { name } => {
write!(f, "parse error: could not parse a {}", name)
}
Error::NameTooLong(len) => {
write!(f, "name too long: name was {} bytes long", len)
}
Error::InvalidUtf8(err) => {
write!(f, "invalid UTF-8: {}", err)
}
Error::TooManyUrlSegments(segments) => {
write!(f, "too many URL segments: {} segments", segments)
}
Error::InvalidCode(err) => {
write!(f, "{}", err)
}
}
}
}
impl StdError for Error {
fn source(&self) -> Option<&(dyn StdError + 'static)> {
match self {
Error::InvalidCode(err) => Some(err),
_ => None,
}
}
}
impl From<simdutf8::compat::Utf8Error> for Error {
    /// Wrap a UTF-8 validation failure as [`Error::InvalidUtf8`].
    fn from(utf8_error: simdutf8::compat::Utf8Error) -> Self {
        Self::InvalidUtf8(utf8_error)
    }
}
impl From<InvalidCode> for Error {
    /// Wrap an unrecognized numeric code as [`Error::InvalidCode`].
    fn from(bad_code: InvalidCode) -> Self {
        Self::InvalidCode(bad_code)
    }
}
/// A DNS message whose sections live in caller-provided arrays.
///
/// `'arrays` is the lifetime of the mutable section arrays. `'innards` is the
/// lifetime of the data the entries borrow (names and resource data), such as
/// the buffer passed to [`Message::read`].
///
/// The header's section counts say how many leading entries of each array
/// belong to the message, and the accessor methods expose only those entries.
/// The derived comparison and hashing traits, however, compare the whole
/// backing arrays, including entries beyond those counts.
#[derive(Default, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct Message<'arrays, 'innards> {
/// Header, including how many entries of each array below are in use.
header: Header,
/// Backing storage for the question section.
questions: &'arrays mut [Question<'innards>],
/// Backing storage for the answer section.
answers: &'arrays mut [ResourceRecord<'innards>],
/// Backing storage for the authority section.
authorities: &'arrays mut [ResourceRecord<'innards>],
/// Backing storage for the additional section.
additional: &'arrays mut [ResourceRecord<'innards>],
}
impl fmt::Debug for Message<'_, '_> {
    /// Show the header and only the counted entries of each section.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut out = f.debug_struct("Message");
        out.field("header", &self.header);
        out.field("questions", &self.questions());
        out.field("answers", &self.answers());
        out.field("authorities", &self.authorities());
        out.field("additional", &self.additional());
        out.finish()
    }
}
impl<'arrays, 'innards> Message<'arrays, 'innards> {
    /// Build a message from an id, flags and the four section arrays.
    ///
    /// The header's section counts are taken from the lengths of the slices,
    /// so every element of every slice is part of the message.
    ///
    /// # Panics
    ///
    /// Panics if any slice holds more than `u16::MAX` entries, because the
    /// header stores each section count as a `u16`.
    pub fn new(
        id: u16,
        flags: Flags,
        questions: &'arrays mut [Question<'innards>],
        answers: &'arrays mut [ResourceRecord<'innards>],
        authorities: &'arrays mut [ResourceRecord<'innards>],
        additional: &'arrays mut [ResourceRecord<'innards>],
    ) -> Self {
        /// Convert a section length into the `u16` count stored in the header.
        fn section_count(len: usize) -> u16 {
            u16::try_from(len)
                .expect("a message section cannot hold more than u16::MAX entries")
        }

        Self {
            header: Header {
                id,
                flags,
                question_count: section_count(questions.len()),
                answer_count: section_count(answers.len()),
                authority_count: section_count(authorities.len()),
                additional_count: section_count(additional.len()),
            },
            questions,
            answers,
            authorities,
            additional,
        }
    }

    /// The 16-bit message identifier from the header.
    pub fn id(&self) -> u16 {
        self.header.id
    }

    /// Mutable access to the message identifier.
    pub fn id_mut(&mut self) -> &mut u16 {
        &mut self.header.id
    }

    /// A copy of the message header.
    pub fn header(&self) -> Header {
        self.header
    }

    /// The header flags.
    pub fn flags(&self) -> Flags {
        self.header.flags
    }

    /// Mutable access to the header flags.
    pub fn flags_mut(&mut self) -> &mut Flags {
        &mut self.header.flags
    }

    /// The questions in this message: the first `question_count` entries of
    /// the backing array.
    pub fn questions(&self) -> &[Question<'innards>] {
        &self.questions[..self.header.question_count as usize]
    }

    /// Mutable access to the questions in this message.
    pub fn questions_mut(&mut self) -> &mut [Question<'innards>] {
        &mut self.questions[..self.header.question_count as usize]
    }

    /// The answer records: the first `answer_count` entries of the backing array.
    pub fn answers(&self) -> &[ResourceRecord<'innards>] {
        &self.answers[..self.header.answer_count as usize]
    }

    /// Mutable access to the answer records.
    pub fn answers_mut(&mut self) -> &mut [ResourceRecord<'innards>] {
        &mut self.answers[..self.header.answer_count as usize]
    }

    /// The authority records: the first `authority_count` entries of the
    /// backing array.
    pub fn authorities(&self) -> &[ResourceRecord<'innards>] {
        &self.authorities[..self.header.authority_count as usize]
    }

    /// Mutable access to the authority records.
    pub fn authorities_mut(&mut self) -> &mut [ResourceRecord<'innards>] {
        &mut self.authorities[..self.header.authority_count as usize]
    }

    /// The additional records: the first `additional_count` entries of the
    /// backing array.
    pub fn additional(&self) -> &[ResourceRecord<'innards>] {
        &self.additional[..self.header.additional_count as usize]
    }

    /// Mutable access to the additional records.
    pub fn additional_mut(&mut self) -> &mut [ResourceRecord<'innards>] {
        &mut self.additional[..self.header.additional_count as usize]
    }

    /// Number of bytes needed to serialize this message with [`Message::write`].
    pub fn space_needed(&self) -> usize {
        self.serialized_len()
    }

    /// Serialize the message into `buffer` and return the number of bytes written.
    ///
    /// # Errors
    ///
    /// Propagates any error reported by the field serializers.
    ///
    /// # Panics
    ///
    /// The resource-data serializer panics instead of returning an error when
    /// `buffer` is too short, so size `buffer` with [`Message::space_needed`].
    pub fn write(&self, buffer: &mut [u8]) -> Result<usize, Error> {
        self.serialize(buffer)
    }

    /// Parse a message from `buffer`, storing its sections in the given arrays.
    ///
    /// Entries borrow from `buffer`. Bytes after the end of the message are
    /// ignored. If the message has the truncated (TC) flag set, only the header
    /// is parsed and every section is reported as empty.
    ///
    /// # Errors
    ///
    /// Returns [`Error::NotEnoughWriteSpace`] if a section has more entries
    /// than its array can hold, and propagates any parse error.
    ///
    /// # Panics
    ///
    /// Panics if any array holds more than `u16::MAX` entries (see [`Message::new`]).
    pub fn read(
        buffer: &'innards [u8],
        questions: &'arrays mut [Question<'innards>],
        answers: &'arrays mut [ResourceRecord<'innards>],
        authorities: &'arrays mut [ResourceRecord<'innards>],
        additional: &'arrays mut [ResourceRecord<'innards>],
    ) -> Result<Message<'arrays, 'innards>, Error> {
        let mut message = Message::new(
            0,
            Flags::default(),
            questions,
            answers,
            authorities,
            additional,
        );
        let cursor = Cursor::new(buffer);
        // The returned cursor (pointing past the message) is intentionally dropped.
        message.deserialize(cursor)?;
        Ok(message)
    }
}
impl<'innards> Serialize<'innards> for Message<'_, 'innards> {
    /// Total encoded size of the header plus the counted section entries,
    /// saturating at `usize::MAX`.
    fn serialized_len(&self) -> usize {
        iter::once(self.header.serialized_len())
            .chain(self.questions().iter().map(Serialize::serialized_len))
            .chain(self.answers().iter().map(Serialize::serialized_len))
            .chain(self.authorities().iter().map(Serialize::serialized_len))
            .chain(self.additional().iter().map(Serialize::serialized_len))
            .fold(0, |a, b| a.saturating_add(b))
    }

    /// Write the header followed by each section's entries, in wire order.
    ///
    /// Only the entries counted by the header are written, not the whole
    /// backing arrays. This keeps the output consistent with the header's
    /// counts and with [`Serialize::serialized_len`], even when the arrays are
    /// larger than the message (for example after [`Message::read`] into
    /// oversized arrays).
    fn serialize(&self, bytes: &mut [u8]) -> Result<usize, Error> {
        let mut offset = self.header.serialize(bytes)?;
        for question in self.questions() {
            offset += question.serialize(&mut bytes[offset..])?;
        }
        let records = self
            .answers()
            .iter()
            .chain(self.authorities())
            .chain(self.additional());
        for record in records {
            offset += record.serialize(&mut bytes[offset..])?;
        }
        Ok(offset)
    }
}
impl<'innards> Deserialize<'innards> for Message<'_, 'innards> {
    /// Parse the header and then each section into the message's arrays.
    ///
    /// If the header has the truncated (TC) flag set, the sections are not
    /// parsed and every section count is reset to zero.
    ///
    /// # Errors
    ///
    /// Returns [`Error::NotEnoughWriteSpace`] when a section declares more
    /// entries than its backing array can hold. Errors from parsing the header
    /// or an entry are propagated. On any error every section count is reset
    /// to zero, so the section accessors cannot slice past the backing arrays
    /// if the message is used after a failed parse.
    fn deserialize(&mut self, cursor: Cursor<'innards>) -> Result<Cursor<'innards>, Error> {
        /// Parse `count` entries into the front of `items`, failing if `items` is too short.
        fn try_read_set<'a, T: Deserialize<'a>>(
            mut cursor: Cursor<'a>,
            count: usize,
            items: &mut [T],
            ty: BufferType,
        ) -> Result<Cursor<'a>, Error> {
            let len = items.len();
            if count == 0 {
                return Ok(cursor);
            }
            for i in 0..count {
                cursor = items
                    .get_mut(i)
                    .ok_or_else(|| Error::NotEnoughWriteSpace {
                        // Cannot fail: the zero case returned early above.
                        tried_to_write: NonZeroUsize::new(count).unwrap(),
                        available: len,
                        buffer_type: ty,
                    })?
                    .deserialize(cursor)?;
            }
            Ok(cursor)
        }

        /// Parse the header and all four sections in wire order.
        fn read_message<'a>(
            message: &mut Message<'_, 'a>,
            cursor: Cursor<'a>,
        ) -> Result<Cursor<'a>, Error> {
            let cursor = message.header.deserialize(cursor)?;
            if message.header.flags.truncated() {
                message.header.clear();
                return Ok(cursor);
            }
            let header = message.header;
            let cursor = try_read_set(
                cursor,
                header.question_count as usize,
                message.questions,
                BufferType::Question,
            )?;
            let cursor = try_read_set(
                cursor,
                header.answer_count as usize,
                message.answers,
                BufferType::Answer,
            )?;
            let cursor = try_read_set(
                cursor,
                header.authority_count as usize,
                message.authorities,
                BufferType::Authority,
            )?;
            try_read_set(
                cursor,
                header.additional_count as usize,
                message.additional,
                BufferType::Additional,
            )
        }

        let result = read_message(self, cursor);
        if result.is_err() {
            // A fully or partially parsed header may claim more entries than
            // the arrays hold; zero the counts so accessors stay in bounds.
            self.header.clear();
        }
        result
    }
}
serialize! {
/// The fixed-size header at the start of a DNS message.
///
/// Fields are (de)serialized in declaration order: id, flags, then the four
/// section counts, each encoded as a `u16`.
#[derive(Debug, Default, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct Header {
/// Message identifier.
id: u16,
/// Header flags word.
flags: Flags,
/// Number of entries in the question section.
question_count: u16,
/// Number of entries in the answer section.
answer_count: u16,
/// Number of entries in the authority section.
authority_count: u16,
/// Number of entries in the additional section.
additional_count: u16,
}
}
impl Header {
fn clear(&mut self) {
self.question_count = 0;
self.answer_count = 0;
self.authority_count = 0;
self.additional_count = 0;
}
}
serialize! {
/// An entry in the question section: the name, record type and class asked for.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct Question<'a> {
/// The domain name being queried.
name: Label<'a>,
/// The record type being queried.
ty: ResourceType,
/// The raw class code (for example `1` for `IN`).
class: u16,
}
}
impl<'a> Question<'a> {
pub fn new(label: impl Into<Label<'a>>, ty: ResourceType, class: u16) -> Self {
Self {
name: label.into(),
ty,
class,
}
}
pub fn name(&self) -> Label<'a> {
self.name
}
pub fn ty(&self) -> ResourceType {
self.ty
}
pub fn class(&self) -> u16 {
self.class
}
}
serialize! {
/// A resource record as found in the answer, authority and additional sections.
///
/// Fields are (de)serialized in declaration order. The record data is
/// preceded on the wire by a `u16` length prefix (see `ResourceData`).
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct ResourceRecord<'a> {
/// The owner name of the record.
name: Label<'a>,
/// The record type.
ty: ResourceType,
/// The raw class code.
class: u16,
/// Time-to-live of the record, in seconds.
ttl: u32,
/// The raw record data.
data: ResourceData<'a>,
}
}
impl<'a> ResourceRecord<'a> {
pub fn new(
name: impl Into<Label<'a>>,
ty: ResourceType,
class: u16,
ttl: u32,
data: &'a [u8],
) -> Self {
Self {
name: name.into(),
ty,
class,
ttl,
data: data.into(),
}
}
pub fn name(&self) -> Label<'a> {
self.name
}
pub fn ty(&self) -> ResourceType {
self.ty
}
pub fn class(&self) -> u16 {
self.class
}
pub fn ttl(&self) -> u32 {
self.ttl
}
pub fn data(&self) -> &'a [u8] {
self.data.0
}
}
/// Raw record data of a resource record, borrowed rather than copied.
///
/// On the wire it is preceded by a big-endian `u16` length prefix; the
/// wrapped slice holds only the data itself.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub(crate) struct ResourceData<'a>(&'a [u8]);

impl<'a> From<&'a [u8]> for ResourceData<'a> {
    /// Wrap `data` without copying it.
    fn from(data: &'a [u8]) -> Self {
        Self(data)
    }
}
impl<'a> Serialize<'a> for ResourceData<'a> {
fn serialized_len(&self) -> usize {
2 + self.0.len()
}
fn serialize(&self, bytes: &mut [u8]) -> Result<usize, Error> {
let len = self.serialized_len();
if bytes.len() < len {
panic!("not enough bytes to serialize resource data");
}
let [b1, b2] = (self.0.len() as u16).to_be_bytes();
bytes[0] = b1;
bytes[1] = b2;
bytes[2..len].copy_from_slice(self.0);
Ok(len)
}
}
impl<'a> Deserialize<'a> for ResourceData<'a> {
    /// Read a big-endian `u16` length prefix, then borrow that many bytes.
    ///
    /// # Errors
    ///
    /// Returns [`Error::NotEnoughReadBytes`] when fewer bytes remain than the
    /// prefix announces. Errors from reading the prefix itself are propagated.
    fn deserialize(&mut self, cursor: Cursor<'a>) -> Result<Cursor<'a>, Error> {
        let mut declared = 0u16;
        let cursor = declared.deserialize(cursor)?;
        let declared = usize::from(declared);

        // An empty payload needs no bounds check and consumes nothing.
        let Some(needed) = NonZeroUsize::new(declared) else {
            self.0 = &[];
            return Ok(cursor);
        };

        let available = cursor.len();
        if available < declared {
            return Err(Error::NotEnoughReadBytes {
                tried_to_read: needed,
                available,
            });
        }

        self.0 = &cursor.remaining()[..declared];
        cursor.advance(declared)
    }
}
/// The 16-bit flags word of a DNS message header.
///
/// The wrapped value is kept exactly as it appears on the wire. The accessor
/// methods extract and update its fields:
///
/// | bits  | field  |
/// |-------|--------|
/// | 15    | QR     |
/// | 11-14 | opcode |
/// | 10    | AA     |
/// | 9     | TC     |
/// | 8     | RD     |
/// | 7     | RA     |
/// | 0-3   | RCODE  |
///
/// Bits 4-6 are not interpreted by this type and are never modified by its setters.
#[derive(Default, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[repr(transparent)]
pub struct Flags(u16);

impl Flags {
    const RAW_QR: u16 = 1 << 15;
    const RAW_OPCODE_SHIFT: u16 = 11;
    const RAW_OPCODE_MASK: u16 = 0b1111;
    const RAW_AA: u16 = 1 << 10;
    const RAW_TC: u16 = 1 << 9;
    const RAW_RD: u16 = 1 << 8;
    const RAW_RA: u16 = 1 << 7;
    const RAW_RCODE_SHIFT: u16 = 0;
    const RAW_RCODE_MASK: u16 = 0b1111;

    /// All bits clear: a query with opcode `Query` and response code `NoError`.
    pub const fn new() -> Self {
        Self(0)
    }

    /// Flags for a standard query: only the "recursion desired" (RD) bit is set.
    pub const fn standard_query() -> Self {
        Self(0x0100)
    }

    /// Set (`on == true`) or clear a single-bit flag.
    fn set_bit(&mut self, bit: u16, on: bool) -> &mut Self {
        if on {
            self.0 |= bit;
        } else {
            self.0 &= !bit;
        }
        self
    }

    /// Overwrite the multi-bit field at `shift`/`mask` with `value`.
    ///
    /// The field is cleared first, and `value` is masked to the field's width,
    /// so neither a previous value nor an oversized one leaks into other bits.
    fn replace_field(&mut self, shift: u16, mask: u16, value: u16) {
        self.0 = (self.0 & !(mask << shift)) | ((value & mask) << shift);
    }

    /// Whether this message is a query or a reply (the QR bit).
    pub fn qr(&self) -> MessageType {
        if self.0 & Self::RAW_QR != 0 {
            MessageType::Reply
        } else {
            MessageType::Query
        }
    }

    /// Set the QR bit from `qr`.
    pub fn set_qr(&mut self, qr: MessageType) -> &mut Self {
        self.set_bit(Self::RAW_QR, qr == MessageType::Reply)
    }

    /// The 4-bit opcode field.
    ///
    /// # Panics
    ///
    /// Panics if the field holds a value with no [`Opcode`] variant, which
    /// can happen for arbitrary wire data.
    pub fn opcode(&self) -> Opcode {
        let raw = (self.0 >> Self::RAW_OPCODE_SHIFT) & Self::RAW_OPCODE_MASK;
        raw.try_into()
            .unwrap_or_else(|_| panic!("invalid opcode: {}", raw))
    }

    /// Replace the opcode field with `opcode`.
    pub fn set_opcode(&mut self, opcode: Opcode) {
        self.replace_field(Self::RAW_OPCODE_SHIFT, Self::RAW_OPCODE_MASK, opcode as u16);
    }

    /// Whether the authoritative answer (AA) bit is set.
    pub fn authoritative(&self) -> bool {
        self.0 & Self::RAW_AA != 0
    }

    /// Set or clear the authoritative answer (AA) bit.
    pub fn set_authoritative(&mut self, authoritative: bool) -> &mut Self {
        self.set_bit(Self::RAW_AA, authoritative)
    }

    /// Whether the truncation (TC) bit is set.
    pub fn truncated(&self) -> bool {
        self.0 & Self::RAW_TC != 0
    }

    /// Set or clear the truncation (TC) bit.
    pub fn set_truncated(&mut self, truncated: bool) -> &mut Self {
        self.set_bit(Self::RAW_TC, truncated)
    }

    /// Whether the recursion desired (RD) bit is set.
    pub fn recursive(&self) -> bool {
        self.0 & Self::RAW_RD != 0
    }

    /// Set or clear the recursion desired (RD) bit.
    pub fn set_recursive(&mut self, recursive: bool) -> &mut Self {
        self.set_bit(Self::RAW_RD, recursive)
    }

    /// Whether the recursion available (RA) bit is set.
    pub fn recursion_available(&self) -> bool {
        self.0 & Self::RAW_RA != 0
    }

    /// Set or clear the recursion available (RA) bit.
    pub fn set_recursion_available(&mut self, recursion_available: bool) -> &mut Self {
        self.set_bit(Self::RAW_RA, recursion_available)
    }

    /// The 4-bit response code field.
    ///
    /// # Panics
    ///
    /// Panics if the field holds a value with no [`ResponseCode`] variant,
    /// which can happen for arbitrary wire data.
    pub fn response_code(&self) -> ResponseCode {
        let raw = (self.0 >> Self::RAW_RCODE_SHIFT) & Self::RAW_RCODE_MASK;
        raw.try_into()
            .unwrap_or_else(|_| panic!("invalid response code: {}", raw))
    }

    /// Replace the 4-bit response code field with `response_code`.
    ///
    /// Only the low four bits of the code fit in the header. Codes above 15
    /// (for example `BadVers`) are reduced to those bits here instead of
    /// spilling into neighbouring flag bits. In DNS, their upper bits are
    /// carried in an EDNS `OPT` record, which this type does not handle.
    pub fn set_response_code(&mut self, response_code: ResponseCode) -> &mut Self {
        self.replace_field(
            Self::RAW_RCODE_SHIFT,
            Self::RAW_RCODE_MASK,
            response_code as u16,
        );
        self
    }

    /// The raw 16-bit flags word, as it appears on the wire.
    pub fn raw(self) -> u16 {
        self.0
    }
}
impl fmt::Debug for Flags {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let mut list = f.debug_list();
list.entry(&self.qr());
list.entry(&self.opcode());
if self.authoritative() {
list.entry(&"authoritative");
}
if self.truncated() {
list.entry(&"truncated");
}
if self.recursive() {
list.entry(&"recursive");
}
if self.recursion_available() {
list.entry(&"recursion available");
}
list.entry(&self.response_code());
list.finish()
}
}
impl Serialize<'_> for Flags {
fn serialized_len(&self) -> usize {
2
}
fn serialize(&self, buf: &mut [u8]) -> Result<usize, Error> {
self.0.serialize(buf)
}
}
impl<'a> Deserialize<'a> for Flags {
fn deserialize(&mut self, bytes: Cursor<'a>) -> Result<Cursor<'a>, Error> {
u16::deserialize(&mut self.0, bytes)
}
}
/// Whether a message is a query or a reply, as encoded in the header's QR bit.
#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum MessageType {
/// QR bit clear.
Query,
/// QR bit set.
Reply,
}
num_enum! {
/// The operation a DNS message requests, stored in the header's 4-bit opcode field.
///
/// Value 3 and values 7-15 have no variant; [`Flags::opcode`] panics if
/// the header holds one of them.
pub enum Opcode {
/// A standard query.
Query = 0,
/// An inverse query (obsolete).
IQuery = 1,
/// A server status request.
Status = 2,
/// A zone change notification.
Notify = 4,
/// A dynamic update.
Update = 5,
/// A DNS Stateful Operations (DSO) message.
Dso = 6,
}
}
num_enum! {
/// The outcome reported by a DNS server.
///
/// The header's response code field is only four bits wide, so
/// [`Flags::response_code`] only ever reads values 0-15. The codes from 16
/// upwards are extended codes that cannot be represented in the header alone.
pub enum ResponseCode {
/// No error.
NoError = 0,
/// The server could not interpret the query.
FormatError = 1,
/// The server failed to process the query.
ServerFailure = 2,
/// The queried domain name does not exist.
NameError = 3,
/// The server does not support the requested kind of query.
NotImplemented = 4,
/// The server refused to perform the operation.
Refused = 5,
/// A name exists that should not.
YxDomain = 6,
/// A record set exists that should not.
YxRrSet = 7,
/// A record set that should exist does not.
NxRrSet = 8,
/// The server is not authoritative for the zone, or the request is not authorized.
NotAuth = 9,
/// A name is not contained in the zone.
NotZone = 10,
/// The DSO-TYPE is not implemented.
DsoTypeNi = 11,
/// Bad OPT version (also used for a bad TSIG signature).
BadVers = 16,
/// The key is not recognized.
BadKey = 17,
/// The signature is outside its time window.
BadTime = 18,
/// Bad TKEY mode.
BadMode = 19,
/// Duplicate key name.
BadName = 20,
/// The algorithm is not supported.
BadAlg = 21,
/// Bad truncation.
BadTrunc = 22,
/// Bad or missing server cookie.
BadCookie = 23,
}
}
num_enum! {
/// Resource record type codes.
///
/// Discriminants follow the IANA DNS "Resource Record (RR) TYPEs" registry.
/// Some entries (such as `Axfr` or `Wildcard`) are meaningful only as
/// query types.
pub enum ResourceType {
/// IPv4 host address.
A = 1,
/// Authoritative name server.
NS = 2,
/// Mail destination (obsolete).
MD = 3,
/// Mail forwarder (obsolete).
MF = 4,
/// Canonical name for an alias.
CName = 5,
/// Start of a zone of authority.
Soa = 6,
/// Mailbox domain name (experimental).
MB = 7,
/// Mail group member (experimental).
MG = 8,
/// Mail rename domain name (experimental).
MR = 9,
/// Null record (experimental).
Null = 10,
/// Well-known service description.
Wks = 11,
/// Domain name pointer.
Ptr = 12,
/// Host information.
HInfo = 13,
/// Mailbox or mail list information.
MInfo = 14,
/// Mail exchange.
MX = 15,
/// Text strings.
Txt = 16,
/// Responsible person.
RP = 17,
/// AFS database location.
AfsDb = 18,
/// X.25 PSDN address.
X25 = 19,
/// ISDN address.
Isdn = 20,
/// Route through.
Rt = 21,
/// NSAP address.
NSap = 22,
/// Domain name pointer, NSAP style.
NSapPtr = 23,
/// Security signature.
Sig = 24,
/// Security key.
Key = 25,
/// X.400 mail mapping information.
Px = 26,
/// Geographical position.
GPos = 27,
/// IPv6 host address.
AAAA = 28,
/// Location information.
Loc = 29,
/// Next domain (obsolete).
Nxt = 30,
/// Endpoint identifier.
EId = 31,
/// Nimrod locator.
NimLoc = 32,
/// Server selection.
Srv = 33,
/// ATM address.
AtmA = 34,
/// Naming authority pointer.
NAPtr = 35,
/// Key exchanger.
Kx = 36,
/// Certificate.
Cert = 37,
/// IPv6 address (historic).
A6 = 38,
/// Delegation name.
DName = 39,
/// Kitchen sink.
Sink = 40,
/// EDNS option pseudo-record.
Opt = 41,
/// Address prefix list.
ApL = 42,
/// Delegation signer.
DS = 43,
/// SSH key fingerprint.
SshFp = 44,
/// IPsec keying material.
IpSecKey = 45,
/// DNSSEC signature.
RRSig = 46,
/// Next secure record.
NSEC = 47,
/// DNSSEC public key.
DNSKey = 48,
/// DHCP identifier.
DHCID = 49,
/// Hashed next secure record.
NSEC3 = 50,
/// NSEC3 parameters.
NSEC3Param = 51,
/// TLS certificate association.
TLSA = 52,
/// S/MIME certificate association.
SMimeA = 53,
/// Host Identity Protocol.
HIP = 55,
/// Zone status information.
NInfo = 56,
/// Resource key.
RKey = 57,
/// Trust anchor link.
TALink = 58,
/// Child copy of a DS record.
CDS = 59,
/// Child copy of a DNSKEY record.
CDNSKey = 60,
/// OpenPGP public key.
OpenPGPKey = 61,
/// Child-to-parent synchronization.
CSync = 62,
/// Message digest for a zone.
ZoneMD = 63,
/// General-purpose service binding.
Svcb = 64,
/// Service binding for HTTPS.
Https = 65,
/// Sender Policy Framework.
Spf = 99,
/// Reserved by IANA.
UInfo = 100,
/// Reserved by IANA.
UID = 101,
/// Reserved by IANA.
GID = 102,
/// Reserved by IANA.
Unspec = 103,
/// ILNP node identifier.
NID = 104,
/// ILNP 32-bit locator.
L32 = 105,
/// ILNP 64-bit locator.
L64 = 106,
/// ILNP locator pointer.
LP = 107,
/// EUI-48 address.
EUI48 = 108,
/// EUI-64 address.
EUI64 = 109,
/// Transaction key.
TKey = 249,
/// Transaction signature.
TSig = 250,
/// Incremental zone transfer (query type).
Ixfr = 251,
/// Full zone transfer (query type).
Axfr = 252,
/// Mailbox-related records (query type).
MailB = 253,
/// Mail agent records (obsolete query type).
MailA = 254,
/// Request for all records, `*` (query type).
Wildcard = 255,
/// Uniform resource identifier.
Uri = 256,
/// Certification authority authorization.
Caa = 257,
/// Application visibility and control.
Avc = 258,
/// Digital Object Architecture.
Doa = 259,
/// Automatic multicast tunneling relay.
Amtrelay = 260,
/// DNSSEC trust authorities.
TA = 32768,
/// DNSSEC lookaside validation (historic).
DLV = 32769,
}
}
impl Default for ResourceType {
fn default() -> Self {
Self::A
}
}
/// Error returned when a numeric code has no matching variant in one of the
/// crate's code enums (for example when converting a `u16` with `TryFrom`).
#[derive(Debug, Clone)]
pub struct InvalidCode(u16);

impl InvalidCode {
    /// The unrecognized numeric value.
    pub fn code(&self) -> u16 {
        let InvalidCode(raw) = *self;
        raw
    }
}

impl fmt::Display for InvalidCode {
    /// Render as `invalid code: <value>`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_fmt(format_args!("invalid code: {}", self.code()))
    }
}

impl StdError for InvalidCode {}
#[cfg(test)]
mod tests {
    use super::*;

    /// Resource data is written as a big-endian length prefix followed by the bytes.
    #[test]
    fn resource_data_serialization() {
        let payload: [u8; 5] = [0x1f, 0xfe, 0x02, 0x24, 0x75];
        let record = ResourceData(&payload);
        let mut out = [0u8; 7];

        let written = record
            .serialize(&mut out)
            .expect("serialized into provided buffer");

        assert_eq!(written, 7);
        assert_eq!(out, [0x00, 0x05, 0x1f, 0xfe, 0x02, 0x24, 0x75]);
    }
}