use std::iter::once;
use crate::proto::{
error::*,
op::{
message::{self, EmitAndCount},
Edns, Header, LowerQuery, Message, MessageType, OpCode, ResponseCode,
},
rr::Record,
serialize::binary::{BinDecodable, BinDecoder, BinEncodable, BinEncoder},
};
/// A decoded DNS message made of a header, exactly one query, and its record sections.
///
/// Values are built by the `BinDecodable::read` implementation in this file, which
/// fails unless the message carries exactly one query.
#[derive(Debug, PartialEq)]
pub struct MessageRequest {
/// Header as read from the wire, with the EDNS high response-code bits merged in
/// when an EDNS record was present.
header: Header,
/// The single query of the message, kept together with its original bytes.
query: WireQuery,
/// Records read from the answer section.
answers: Vec<Record>,
/// Records read from the name-server section.
name_servers: Vec<Record>,
/// Records read from the additional section.
additionals: Vec<Record>,
/// SIG0 records returned by `Message::read_records` while reading the additional section.
sig0: Vec<Record>,
/// EDNS data returned by `Message::read_records` while reading the additional section, if any.
edns: Option<Edns>,
}
impl MessageRequest {
pub fn header(&self) -> &Header {
&self.header
}
pub fn id(&self) -> u16 {
self.header.id()
}
pub fn message_type(&self) -> MessageType {
self.header.message_type()
}
pub fn op_code(&self) -> OpCode {
self.header.op_code()
}
pub fn authoritative(&self) -> bool {
self.header.authoritative()
}
pub fn truncated(&self) -> bool {
self.header.truncated()
}
pub fn recursion_desired(&self) -> bool {
self.header.recursion_desired()
}
pub fn recursion_available(&self) -> bool {
self.header.recursion_available()
}
pub fn authentic_data(&self) -> bool {
self.header.authentic_data()
}
pub fn checking_disabled(&self) -> bool {
self.header.checking_disabled()
}
pub fn response_code(&self) -> ResponseCode {
self.header.response_code()
}
pub fn query(&self) -> &LowerQuery {
&self.query.query
}
pub fn answers(&self) -> &[Record] {
&self.answers
}
pub fn name_servers(&self) -> &[Record] {
&self.name_servers
}
pub fn additionals(&self) -> &[Record] {
&self.additionals
}
pub fn edns(&self) -> Option<&Edns> {
self.edns.as_ref()
}
pub fn sig0(&self) -> &[Record] {
&self.sig0
}
pub fn max_payload(&self) -> u16 {
let max_size = self.edns.as_ref().map_or(512, Edns::max_payload);
if max_size < 512 {
512
} else {
max_size
}
}
pub fn version(&self) -> u8 {
self.edns.as_ref().map_or(0, Edns::version)
}
pub(crate) fn raw_query(&self) -> &WireQuery {
&self.query
}
}
impl<'q> BinDecodable<'q> for MessageRequest {
    /// Decodes a `MessageRequest` from `decoder`.
    ///
    /// The header is read first; a failure there is returned as-is. Any failure while
    /// reading the remaining sections is wrapped in `ProtoErrorKind::FormError`
    /// together with the header that was read.
    ///
    /// # Errors
    ///
    /// Fails if the header cannot be read, if the query count is not exactly one,
    /// or if any of the record sections fails to decode.
    fn read(decoder: &mut BinDecoder<'q>) -> ProtoResult<Self> {
        let mut header = Header::read(decoder)?;

        // The closure is `move`, yet `header` is still used below: the closure works
        // on its own copy, so the header reported on failure is the one read above.
        let mut decode_sections = move || {
            // Sections are decoded in wire order: queries, answers, name servers, additionals.
            let query =
                Queries::read(decoder, header.query_count() as usize)?.try_into_query()?;
            let (answers, _, _) =
                Message::read_records(decoder, header.answer_count() as usize, false)?;
            let (name_servers, _, _) =
                Message::read_records(decoder, header.name_server_count() as usize, false)?;
            let (additionals, edns, sig0) =
                Message::read_records(decoder, header.additional_count() as usize, true)?;

            // EDNS carries the high bits of the response code; fold them into the header.
            if let Some(opt) = edns.as_ref() {
                header.merge_response_code(opt.rcode_high());
            }

            Ok(Self {
                header,
                query,
                answers,
                name_servers,
                additionals,
                sig0,
                edns,
            })
        };

        decode_sections().map_err(|cause| {
            ProtoErrorKind::FormError {
                header,
                error: Box::new(cause),
            }
            .into()
        })
    }
}
/// The queries of a message together with the bytes they were decoded from.
#[derive(Debug, PartialEq, Eq)]
pub struct Queries {
/// The decoded queries, in the order they were read.
queries: Vec<LowerQuery>,
/// Bytes obtained from the decoder via `slice_from` at the index where query
/// decoding started (presumably the raw wire form of the queries — confirm
/// against `BinDecoder::slice_from`).
original: Box<[u8]>,
}
impl Queries {
    /// Reads `count` queries from `decoder`, stopping at the first decode error.
    fn read_queries(decoder: &mut BinDecoder<'_>, count: usize) -> ProtoResult<Vec<LowerQuery>> {
        // `count` originates from the message header, i.e. from untrusted input.
        // Cap the up-front allocation so a tiny packet that merely claims a huge
        // query count cannot force a large allocation; the vector still grows on
        // demand if that many queries really are present.
        const MAX_PREALLOCATED_QUERIES: usize = 16;

        let mut queries = Vec::with_capacity(count.min(MAX_PREALLOCATED_QUERIES));
        for _ in 0..count {
            queries.push(LowerQuery::read(decoder)?);
        }
        Ok(queries)
    }

    /// Reads `num_queries` queries from `decoder` and keeps a copy of the bytes
    /// returned by `decoder.slice_from` for the index where reading started.
    ///
    /// # Errors
    ///
    /// Fails if any query cannot be decoded or if the slice cannot be obtained.
    pub fn read(decoder: &mut BinDecoder<'_>, num_queries: usize) -> ProtoResult<Self> {
        let queries_start = decoder.index();
        let queries = Self::read_queries(decoder, num_queries)?;
        let original = decoder
            .slice_from(queries_start)?
            .to_vec()
            .into_boxed_slice();

        Ok(Self { queries, original })
    }

    /// Returns the number of queries.
    pub fn len(&self) -> usize {
        self.queries.len()
    }

    /// Returns `true` if there are no queries.
    pub fn is_empty(&self) -> bool {
        self.queries.is_empty()
    }

    /// Returns the stored original bytes of the queries.
    pub fn as_bytes(&self) -> &[u8] {
        self.original.as_ref()
    }

    /// Builds an emitter that writes the stored original bytes and reports the
    /// number of queries.
    pub(crate) fn as_emit_and_count(&self) -> QueriesEmitAndCount<'_> {
        QueriesEmitAndCount {
            length: self.queries.len(),
            first_query: self.queries.first(),
            cached_serialized: self.original.as_ref(),
        }
    }

    /// Converts into a `WireQuery`, which requires exactly one query.
    ///
    /// # Errors
    ///
    /// Returns `ProtoErrorKind::BadQueryCount` carrying the actual count when the
    /// number of queries is not exactly one.
    pub(crate) fn try_into_query(mut self) -> Result<WireQuery, ProtoError> {
        let count = self.queries.len();
        match self.queries.pop() {
            Some(query) if count == 1 => Ok(WireQuery {
                query,
                original: self.original,
            }),
            _ => Err(ProtoErrorKind::BadQueryCount(count).into()),
        }
    }
}
/// A single query paired with the original bytes it was decoded from.
///
/// Produced by `Queries::try_into_query` when exactly one query is present.
#[derive(Debug, PartialEq)]
pub(crate) struct WireQuery {
/// The decoded query.
query: LowerQuery,
/// The original bytes carried over from `Queries`.
original: Box<[u8]>,
}
impl WireQuery {
    /// Builds an emitter that writes the stored original bytes and reports a
    /// query count of one.
    pub(crate) fn as_emit_and_count(&self) -> QueriesEmitAndCount<'_> {
        let Self { query, original } = self;
        QueriesEmitAndCount {
            length: 1,
            first_query: Some(query),
            cached_serialized: &original[..],
        }
    }
}
/// Emits previously captured query bytes and reports how many queries they hold.
pub(crate) struct QueriesEmitAndCount<'q> {
/// Number of queries reported after emitting.
length: usize,
/// The first query, if any; its original name length is used when storing a label pointer.
first_query: Option<&'q LowerQuery>,
/// The bytes written verbatim to the encoder.
cached_serialized: &'q [u8],
}
impl<'q> EmitAndCount for QueriesEmitAndCount<'q> {
    /// Writes the cached query bytes to `encoder` and returns the query count.
    ///
    /// Unless the encoder uses canonical names, a label pointer is also stored for
    /// the first query, spanning from the offset where the bytes were written to
    /// that offset plus the length of the query's original name.
    fn emit(&mut self, encoder: &mut BinEncoder<'_>) -> ProtoResult<usize> {
        let start = encoder.offset();
        encoder.emit_vec(self.cached_serialized)?;

        // No label pointer is recorded when the encoder uses canonical names.
        if encoder.is_canonical_names() {
            return Ok(self.length);
        }

        if let Some(first) = self.first_query {
            let name_end = start + first.original().name().len();
            encoder.store_label_pointer(start, name_end);
        }

        Ok(self.length)
    }
}
impl BinEncodable for MessageRequest {
    /// Encodes this message through `message::emit_message_parts`.
    ///
    /// The query is emitted from its parsed form, not from the stored original bytes.
    fn emit(&self, encoder: &mut BinEncoder<'_>) -> ProtoResult<()> {
        let mut queries = once(&self.query.query);
        let mut answers = self.answers.iter();
        let mut name_servers = self.name_servers.iter();
        let mut additionals = self.additionals.iter();

        message::emit_message_parts(
            &self.header,
            &mut queries,
            &mut answers,
            &mut name_servers,
            &mut additionals,
            self.edns.as_ref(),
            &self.sig0,
            encoder,
        )?;

        Ok(())
    }
}
/// A view of a message as an update request (presumably DNS UPDATE, RFC 2136 — confirm).
///
/// The implementation for `MessageRequest` in this file maps the zone to the query,
/// the prerequisites to the answer section, and the updates to the name-server section.
pub trait UpdateRequest {
/// Returns the id of the request.
fn id(&self) -> u16;
/// Returns the zone the request applies to.
fn zone(&self) -> &LowerQuery;
/// Returns the prerequisite records of the request.
fn prerequisites(&self) -> &[Record];
/// Returns the update records of the request.
fn updates(&self) -> &[Record];
/// Returns the additional records of the request.
fn additionals(&self) -> &[Record];
/// Returns the SIG0 records of the request.
fn sig0(&self) -> &[Record];
}
impl UpdateRequest for MessageRequest {
    /// Returns the message id stored in the header.
    fn id(&self) -> u16 {
        self.header.id()
    }

    /// The zone is the message's query.
    fn zone(&self) -> &LowerQuery {
        &self.query.query
    }

    /// The prerequisites are the records of the answer section.
    fn prerequisites(&self) -> &[Record] {
        &self.answers
    }

    /// The updates are the records of the name-server section.
    fn updates(&self) -> &[Record] {
        &self.name_servers
    }

    /// Returns the records of the additional section.
    fn additionals(&self) -> &[Record] {
        &self.additionals
    }

    /// Returns the SIG0 records of the message.
    fn sig0(&self) -> &[Record] {
        &self.sig0
    }
}