use crate::proto::{
ProtoError, ProtoErrorKind,
op::{
Edns, Header, LowerQuery, Message, MessageType, OpCode, ResponseCode,
message::{self, EmitAndCount},
},
rr::Record,
serialize::binary::{BinDecodable, BinDecoder, BinEncodable, BinEncoder},
};
/// A DNS message in the form this module decodes incoming requests into.
///
/// The `BinDecodable` implementation in this file fills the fields in wire order:
/// header, queries, answers, name servers, then the additional section (from which
/// `edns` and `sig0` are also obtained).
#[derive(Debug, PartialEq, Clone)]
pub struct MessageRequest {
/// The message header. When decoded via `read` and an EDNS record is present, the
/// EDNS high response-code bits have been merged into it.
pub header: Header,
/// The decoded queries together with the raw bytes they were decoded from.
pub queries: Queries,
/// Records read for the header's answer count.
pub answers: Vec<Record>,
/// Records read for the header's name-server count.
pub name_servers: Vec<Record>,
/// Records read for the header's additional count.
pub additionals: Vec<Record>,
/// Signature records returned alongside the additional section by `Message::read_records`.
pub sig0: Vec<Record>,
/// EDNS data returned alongside the additional section, if any.
pub edns: Option<Edns>,
}
impl MessageRequest {
pub fn header(&self) -> &Header {
&self.header
}
pub fn id(&self) -> u16 {
self.header.id()
}
pub fn message_type(&self) -> MessageType {
self.header.message_type()
}
pub fn op_code(&self) -> OpCode {
self.header.op_code()
}
pub fn authoritative(&self) -> bool {
self.header.authoritative()
}
pub fn truncated(&self) -> bool {
self.header.truncated()
}
pub fn recursion_desired(&self) -> bool {
self.header.recursion_desired()
}
pub fn recursion_available(&self) -> bool {
self.header.recursion_available()
}
pub fn authentic_data(&self) -> bool {
self.header.authentic_data()
}
pub fn checking_disabled(&self) -> bool {
self.header.checking_disabled()
}
pub fn response_code(&self) -> ResponseCode {
self.header.response_code()
}
pub fn queries(&self) -> &[LowerQuery] {
&self.queries.queries
}
pub fn answers(&self) -> &[Record] {
&self.answers
}
pub fn name_servers(&self) -> &[Record] {
&self.name_servers
}
pub fn additionals(&self) -> &[Record] {
&self.additionals
}
pub fn edns(&self) -> Option<&Edns> {
self.edns.as_ref()
}
pub fn sig0(&self) -> &[Record] {
&self.sig0
}
pub fn max_payload(&self) -> u16 {
let max_size = self.edns.as_ref().map_or(512, Edns::max_payload);
if max_size < 512 { 512 } else { max_size }
}
pub fn version(&self) -> u8 {
self.edns.as_ref().map_or(0, Edns::version)
}
pub(crate) fn raw_queries(&self) -> &Queries {
&self.queries
}
}
impl<'q> BinDecodable<'q> for MessageRequest {
fn read(decoder: &mut BinDecoder<'q>) -> Result<Self, ProtoError> {
let mut header = Header::read(decoder)?;
let mut try_parse_rest = move || {
let query_count = header.query_count() as usize;
let answer_count = header.answer_count() as usize;
let name_server_count = header.name_server_count() as usize;
let additional_count = header.additional_count() as usize;
let queries = Queries::read(decoder, query_count)?;
let (answers, _, _) = Message::read_records(decoder, answer_count, false)?;
let (name_servers, _, _) = Message::read_records(decoder, name_server_count, false)?;
let (additionals, edns, sig0) = Message::read_records(decoder, additional_count, true)?;
if let Some(edns) = &edns {
let high_response_code = edns.rcode_high();
header.merge_response_code(high_response_code);
}
Ok(Self {
header,
queries,
answers,
name_servers,
additionals,
sig0,
edns,
})
};
match try_parse_rest() {
Ok(message) => Ok(message),
Err(e) => Err(ProtoErrorKind::FormError {
header,
error: Box::new(e),
}
.into()),
}
}
}
/// The queries of a request, kept both in decoded form and as the bytes they were
/// decoded from.
#[derive(Debug, PartialEq, Eq, Clone)]
pub struct Queries {
/// The decoded queries.
pub queries: Vec<LowerQuery>,
/// A copy of the bytes the decoder consumed while reading `queries` (empty for
/// `Queries::empty()`).
pub original: Box<[u8]>,
}
impl Queries {
    /// Decodes `count` queries from `decoder`, stopping at the first decode error.
    fn read_queries(
        decoder: &mut BinDecoder<'_>,
        count: usize,
    ) -> Result<Vec<LowerQuery>, ProtoError> {
        // `count` is supplied by the caller (for `MessageRequest::read` it comes straight
        // from the wire header), so it must not decide the allocation size on its own:
        // `Vec::with_capacity` panics on capacity overflow, and a tiny packet announcing
        // many queries would otherwise force a large up-front allocation. Pre-allocate a
        // small bounded amount and let the vector grow if more queries really decode.
        const MAX_PREALLOCATED_QUERIES: usize = 16;

        let mut queries = Vec::with_capacity(count.min(MAX_PREALLOCATED_QUERIES));
        for _ in 0..count {
            queries.push(LowerQuery::read(decoder)?);
        }
        Ok(queries)
    }

    /// Reads `num_queries` queries from `decoder` and keeps a copy of the bytes that
    /// were consumed while doing so.
    ///
    /// # Errors
    ///
    /// Returns the decoder's error if any query fails to decode or if the consumed
    /// range cannot be sliced out of the decoder.
    pub fn read(decoder: &mut BinDecoder<'_>, num_queries: usize) -> Result<Self, ProtoError> {
        let queries_start = decoder.index();
        let queries = Self::read_queries(decoder, num_queries)?;
        let consumed = decoder.slice_from(queries_start)?;
        let original = consumed.to_vec().into_boxed_slice();

        Ok(Self { queries, original })
    }

    /// The number of decoded queries.
    pub fn len(&self) -> usize {
        self.queries.len()
    }

    /// Whether there are no decoded queries.
    pub fn is_empty(&self) -> bool {
        self.queries.is_empty()
    }

    /// The decoded queries.
    pub fn queries(&self) -> &[LowerQuery] {
        &self.queries
    }

    /// The bytes the queries were decoded from.
    pub fn as_bytes(&self) -> &[u8] {
        &self.original
    }

    /// Builds an emitter that writes the cached bytes and reports the query count.
    pub(crate) fn as_emit_and_count(&self) -> QueriesEmitAndCount<'_> {
        QueriesEmitAndCount {
            length: self.queries.len(),
            first_query: self.queries.first(),
            cached_serialized: &self.original,
        }
    }

    /// Returns the query if there is exactly one.
    ///
    /// # Errors
    ///
    /// Returns `ProtoErrorKind::BadQueryCount` carrying the actual number of queries
    /// when that number is not 1.
    pub(crate) fn try_as_query(&self) -> Result<&LowerQuery, ProtoError> {
        match self.queries.as_slice() {
            [query] => Ok(query),
            queries => Err(ProtoErrorKind::BadQueryCount(queries.len()).into()),
        }
    }

    /// Creates a `Queries` with no queries and no cached bytes.
    pub(crate) fn empty() -> Self {
        Self {
            queries: Vec::new(),
            original: (*b"").into(),
        }
    }
}
/// Borrowed view over a `Queries` value used to emit its cached bytes; built by
/// `Queries::as_emit_and_count`.
pub(crate) struct QueriesEmitAndCount<'q> {
/// Number of queries; this is the value `emit` returns.
length: usize,
/// The first query, if any; `emit` only checks whether it is present.
first_query: Option<&'q LowerQuery>,
/// The bytes the queries were originally decoded from, written verbatim by `emit`.
cached_serialized: &'q [u8],
}
impl EmitAndCount for QueriesEmitAndCount<'_> {
    /// Writes the cached query bytes to `encoder` and returns the number of queries.
    ///
    /// Unless the encoder is in canonical-names mode, and provided there is at least
    /// one query, the byte range just written is registered with the encoder through
    /// `store_label_pointer`.
    // NOTE(review): presumably this lets later names be compressed against the query
    // name; confirm against `BinEncoder::store_label_pointer`.
    fn emit(&mut self, encoder: &mut BinEncoder<'_>) -> Result<usize, ProtoError> {
        let start = encoder.offset();
        let bytes = self.cached_serialized;
        encoder.emit_vec(bytes)?;

        // Nothing to register in canonical mode or when no query was decoded.
        if encoder.is_canonical_names() || self.first_query.is_none() {
            return Ok(self.length);
        }

        encoder.store_label_pointer(start, start + bytes.len());
        Ok(self.length)
    }
}
impl BinEncodable for MessageRequest {
    /// Encodes this request by handing the header, each section, the EDNS data and
    /// the signature records to `message::emit_message_parts`.
    ///
    /// The queries are emitted from their decoded form, not from the cached bytes.
    fn emit(&self, encoder: &mut BinEncoder<'_>) -> Result<(), ProtoError> {
        let mut queries = self.queries.queries.iter();
        let mut answers = self.answers.iter();
        let mut name_servers = self.name_servers.iter();
        let mut additionals = self.additionals.iter();

        message::emit_message_parts(
            &self.header,
            &mut queries,
            &mut answers,
            &mut name_servers,
            &mut additionals,
            self.edns.as_ref(),
            &self.sig0,
            encoder,
        )?;

        Ok(())
    }
}
/// Accessors that expose a message under update-request terminology.
///
/// The `MessageRequest` implementation in this file maps `zone` to the single query,
/// `prerequisites` to the answers and `updates` to the name servers.
// NOTE(review): the naming presumably follows DNS UPDATE (RFC 2136) — confirm.
pub trait UpdateRequest {
/// The identifier of the request.
fn id(&self) -> u16;
/// The zone the request applies to, or an error if it cannot be determined.
fn zone(&self) -> Result<&LowerQuery, ProtoError>;
/// The prerequisite records of the request.
fn prerequisites(&self) -> &[Record];
/// The update records of the request.
fn updates(&self) -> &[Record];
/// The additional records of the request.
fn additionals(&self) -> &[Record];
/// The signature records of the request.
fn sig0(&self) -> &[Record];
}
impl UpdateRequest for MessageRequest {
    /// The message identifier from the header.
    fn id(&self) -> u16 {
        self.header.id()
    }

    /// The single query of the message; fails unless there is exactly one query.
    fn zone(&self) -> Result<&LowerQuery, ProtoError> {
        self.queries.try_as_query()
    }

    /// The prerequisites are served from the answers of the message.
    fn prerequisites(&self) -> &[Record] {
        &self.answers
    }

    /// The updates are served from the name servers of the message.
    fn updates(&self) -> &[Record] {
        &self.name_servers
    }

    /// The additional records of the message.
    fn additionals(&self) -> &[Record] {
        &self.additionals
    }

    /// The signature records of the message.
    fn sig0(&self) -> &[Record] {
        &self.sig0
    }
}