#![warn(missing_docs)]
#![warn(clippy::missing_docs_in_private_items)]
use crate::base::iana::Rcode;
use crate::base::message::{CopyRecordsError, ShortMessage};
use crate::base::message_builder::{
AdditionalBuilder, MessageBuilder, PushError, StaticCompressor,
};
use crate::base::opt::{ComposeOptData, LongOptData, OptRecord};
use crate::base::wire::{Composer, ParseError};
use crate::base::{Header, Message, ParsedName, Rtype};
use crate::rdata::AllRecordData;
use bytes::Bytes;
use octseq::Octets;
use std::boxed::Box;
use std::fmt::Debug;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::vec::Vec;
use std::{error, fmt};
use tracing::trace;
/// A DNS request that can be adjusted and then composed into a message.
///
/// Implementations hold enough state to produce the request in wire format
/// (see [`append_message`](Self::append_message),
/// [`to_message`](Self::to_message) and [`to_vec`](Self::to_vec)) and allow
/// the header and the EDNS (OPT record) parameters to be changed first.
pub trait ComposeRequest: Debug + Send + Sync {
    /// Appends the complete request message to `target`.
    ///
    /// # Errors
    ///
    /// Returns a [`CopyRecordsError`] if a record could not be parsed or
    /// could not be pushed into `target`.
    fn append_message<Target: Composer>(
        &self,
        target: &mut Target,
    ) -> Result<(), CopyRecordsError>;

    /// Creates a new, owned [`Message`] containing the request.
    ///
    /// # Errors
    ///
    /// Returns an [`Error`] if the request could not be composed.
    fn to_message(&self) -> Result<Message<Vec<u8>>, Error>;

    /// Returns the request in wire format.
    ///
    /// # Errors
    ///
    /// Returns an [`Error`] if the request could not be composed.
    fn to_vec(&self) -> Result<Vec<u8>, Error>;

    /// Returns a mutable reference to the header used for the request.
    fn header_mut(&mut self) -> &mut Header;

    /// Sets the UDP payload size advertised in the request's OPT record.
    fn set_udp_payload_size(&mut self, value: u16);

    /// Sets or clears the DNSSEC OK (DO) flag in the request's OPT record.
    fn set_dnssec_ok(&mut self, value: bool);

    /// Adds an EDNS option to the request's OPT record.
    ///
    /// # Errors
    ///
    /// Returns [`LongOptData`] if the option cannot be added, presumably
    /// because the OPT record data would become too long.
    fn add_opt(
        &mut self,
        opt: &impl ComposeOptData,
    ) -> Result<(), LongOptData>;

    /// Returns whether `answer` is a reply to this request.
    fn is_answer(&self, answer: &Message<[u8]>) -> bool;
}
/// A transport that can send requests of type `CR`.
pub trait SendRequest<CR> {
    /// Submits `request_msg` and returns a handle for obtaining the reply.
    ///
    /// Whether any network activity happens before the returned handle is
    /// polled is up to the implementation.
    fn send_request(
        &self,
        request_msg: CR,
    ) -> Box<dyn GetResponse + Send + Sync>;
}
/// A handle for a submitted request from which the response can be awaited.
pub trait GetResponse: Debug {
    /// Returns a future that resolves to the response message or an
    /// [`Error`].
    ///
    /// The returned future borrows `self` mutably for its whole lifetime.
    fn get_response(
        &mut self,
    ) -> Pin<
        Box<
            dyn Future<Output = Result<Message<Bytes>, Error>>
                + Send
                + Sync
                + '_,
        >,
    >;
}
/// A [`ComposeRequest`] built on top of an existing [`Message`].
///
/// When composing, the records are copied from the wrapped message. The
/// header and the OPT record are kept separately, so they can be changed
/// without rebuilding the wrapped message.
#[derive(Clone, Debug)]
pub struct RequestMessage<Octs: AsRef<[u8]>> {
    /// The wrapped message whose records are copied when composing.
    msg: Message<Octs>,

    /// Header written to the composed message (initially copied from `msg`).
    header: Header,

    /// When `Some`, this OPT record is emitted in place of any OPT record
    /// found in `msg`.
    opt: Option<OptRecord<Vec<u8>>>,
}
impl<Octs: AsRef<[u8]> + Debug + Octets> RequestMessage<Octs> {
    /// Creates a new request from `msg`.
    ///
    /// The header of `msg` is copied so that it can be changed without
    /// modifying the wrapped message. No replacement OPT record is set
    /// initially.
    pub fn new(msg: impl Into<Message<Octs>>) -> Self {
        let msg = msg.into();
        let header = msg.header();
        Self {
            msg,
            header,
            opt: None,
        }
    }

    /// Returns the replacement OPT record, first creating a default one if
    /// none has been set yet.
    ///
    /// NOTE(review): the new record starts from `OptRecord::default()`, not
    /// from an OPT record in the wrapped message. Once this has been called,
    /// options present in the original OPT record are no longer emitted.
    fn opt_mut(&mut self) -> &mut OptRecord<Vec<u8>> {
        self.opt.get_or_insert_with(Default::default)
    }

    /// Composes the request into `target`.
    ///
    /// The header comes from `self.header`. All records of the wrapped
    /// message are copied section by section. The one exception is an OPT
    /// record in the additional section: it is skipped when `self.opt` is
    /// set, and `self.opt` is appended at the end instead.
    ///
    /// Returns the builder positioned at the additional section.
    fn append_message_impl<Target: Composer>(
        &self,
        mut target: MessageBuilder<Target>,
    ) -> Result<AdditionalBuilder<Target>, CopyRecordsError> {
        let source = &self.msg;

        *target.header_mut() = self.header;

        let source = source.question();
        let mut target = target.question();
        for rr in source {
            target.push(rr?)?;
        }

        // The `expect`s below rely on `AllRecordData` accepting every
        // record type, so the conversion never yields `None`.
        let mut source = source.answer()?;
        let mut target = target.answer();
        for rr in &mut source {
            let rr = rr?
                .into_record::<AllRecordData<_, ParsedName<_>>>()?
                .expect("record expected");
            target.push(rr)?;
        }

        let mut source =
            source.next_section()?.expect("section should be present");
        let mut target = target.authority();
        for rr in &mut source {
            let rr = rr?
                .into_record::<AllRecordData<_, ParsedName<_>>>()?
                .expect("record expected");
            target.push(rr)?;
        }

        let source =
            source.next_section()?.expect("section should be present");
        let mut target = target.additional();
        for rr in source {
            let rr = rr?;
            // A message carries at most one OPT record. If a replacement
            // was configured, drop the original and emit `self.opt` below.
            // Otherwise keep the original so its EDNS parameters (payload
            // size, DO bit, options) are not silently lost.
            if rr.rtype() == Rtype::OPT && self.opt.is_some() {
                continue;
            }
            let rr = rr
                .into_record::<AllRecordData<_, ParsedName<_>>>()?
                .expect("record expected");
            target.push(rr)?;
        }

        if let Some(opt) = self.opt.as_ref() {
            target.push(opt.as_record())?;
        }

        Ok(target)
    }

    /// Composes the request into a new, owned [`Message`].
    ///
    /// # Errors
    ///
    /// Returns an [`Error`] if copying the records fails.
    fn to_message_impl(&self) -> Result<Message<Vec<u8>>, Error> {
        let target =
            MessageBuilder::from_target(StaticCompressor::new(Vec::new()))
                .expect("Vec is expected to have enough space");

        let target = self.append_message_impl(target)?;

        // `finish` returns the underlying target with every section left in
        // place, so the builder does not need to be cloned first.
        let octets = target.finish().into_target();
        let msg = Message::from_octets(octets).expect(
            "Message should be able to parse output from MessageBuilder",
        );
        Ok(msg)
    }
}
impl<Octs: AsRef<[u8]> + Clone + Debug + Octets + Send + Sync + 'static>
    ComposeRequest for RequestMessage<Octs>
{
    fn append_message<Target: Composer>(
        &self,
        target: &mut Target,
    ) -> Result<(), CopyRecordsError> {
        // Any failure to start a message in `target` is reported as a
        // short-buffer push error.
        let target = MessageBuilder::from_target(target)
            .map_err(|_| CopyRecordsError::Push(PushError::ShortBuf))?;
        self.append_message_impl(target)?;
        Ok(())
    }

    fn to_vec(&self) -> Result<Vec<u8>, Error> {
        // The composed message is owned, so take its octets instead of
        // cloning them.
        Ok(self.to_message()?.into_octets())
    }

    fn to_message(&self) -> Result<Message<Vec<u8>>, Error> {
        self.to_message_impl()
    }

    fn header_mut(&mut self) -> &mut Header {
        &mut self.header
    }

    fn set_udp_payload_size(&mut self, value: u16) {
        self.opt_mut().set_udp_payload_size(value);
    }

    fn set_dnssec_ok(&mut self, value: bool) {
        self.opt_mut().set_dnssec_ok(value);
    }

    fn add_opt(
        &mut self,
        opt: &impl ComposeOptData,
    ) -> Result<(), LongOptData> {
        self.opt_mut().push(opt).map_err(|e| e.unlimited_buf())
    }

    /// Checks that `answer` is a reply (QR set) with the same ID as this
    /// request. Unless it is a bare error reply, it must also carry the
    /// same question section.
    fn is_answer(&self, answer: &Message<[u8]>) -> bool {
        let answer_header = answer.header();
        let answer_hcounts = answer.header_counts();

        if !answer_header.qr() || answer_header.id() != self.header.id() {
            trace!(
                "Wrong QR or ID: QR={}, answer ID={}, self ID={}",
                answer_header.qr(),
                answer_header.id(),
                self.header.id()
            );
            return false;
        }

        // An error reply with all sections empty is accepted as is: such
        // a reply carries no question to compare against.
        if answer_header.rcode() != Rcode::NOERROR
            && answer_hcounts.qdcount() == 0
            && answer_hcounts.ancount() == 0
            && answer_hcounts.nscount() == 0
            && answer_hcounts.arcount() == 0
        {
            return true;
        }

        if answer_hcounts.qdcount() != self.msg.header_counts().qdcount() {
            trace!("Wrong QD count");
            false
        } else {
            let res = answer.question() == self.msg.for_slice().question();
            if !res {
                trace!("Wrong question");
            }
            res
        }
    }
}
/// Errors reported by the request and response interfaces of this module.
///
/// Wrapped errors are held in an [`Arc`] so that the type can stay `Clone`.
#[derive(Clone, Debug)]
pub enum Error {
    /// The connection was closed.
    ConnectionClosed,

    /// The OPT record is too long.
    OptTooLong,

    /// Pushing data into a `MessageBuilder` failed.
    MessageBuilderPushError,

    /// Parsing a message failed.
    MessageParseError,

    /// The underlying transport was not found in a redundant connection.
    RedundantTransportNotFound,

    /// An octet sequence was too short to be a valid message.
    ShortMessage,

    /// A message was too long for a stream transport.
    StreamLongMessage,

    /// A stream was idle for too long.
    StreamIdleTimeout,

    /// Receiving a reply failed.
    StreamReceiveError,

    /// Reading from a stream failed with the contained I/O error.
    StreamReadError(Arc<std::io::Error>),

    /// Reading from a stream timed out.
    StreamReadTimeout,

    /// Too many queries were outstanding on a stream.
    StreamTooManyOutstandingQueries,

    /// Writing to a stream failed with the contained I/O error.
    StreamWriteError(Arc<std::io::Error>),

    /// Data ended unexpectedly.
    StreamUnexpectedEndOfData,

    /// A reply did not match its query.
    WrongReplyForQuery,

    /// No transport was available.
    NoTransportAvailable,

    /// An error reported by the datagram transport.
    Dgram(Arc<super::dgram::QueryError>),
}
impl From<LongOptData> for Error {
fn from(_: LongOptData) -> Self {
Self::OptTooLong
}
}
impl From<ParseError> for Error {
fn from(_: ParseError) -> Self {
Self::MessageParseError
}
}
impl From<ShortMessage> for Error {
fn from(_: ShortMessage) -> Self {
Self::ShortMessage
}
}
impl From<super::dgram::QueryError> for Error {
fn from(err: super::dgram::QueryError) -> Self {
Self::Dgram(err.into())
}
}
impl fmt::Display for Error {
    /// Writes a short, human-readable description of the error.
    ///
    /// For [`Error::Dgram`] the wrapped error's own description is used.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Error::ConnectionClosed => write!(f, "connection closed"),
            Error::OptTooLong => write!(f, "OPT record is too long"),
            Error::MessageBuilderPushError => {
                write!(f, "PushError from MessageBuilder")
            }
            Error::MessageParseError => write!(f, "ParseError from Message"),
            Error::RedundantTransportNotFound => write!(
                f,
                "Underlying transport not found in redundant connection"
            ),
            Error::ShortMessage => {
                write!(f, "octet sequence too short to be a valid message")
            }
            Error::StreamLongMessage => {
                write!(f, "message too long for stream transport")
            }
            Error::StreamIdleTimeout => {
                write!(f, "stream was idle for too long")
            }
            Error::StreamReceiveError => write!(f, "error receiving a reply"),
            Error::StreamReadError(_) => {
                write!(f, "error reading from stream")
            }
            Error::StreamReadTimeout => {
                write!(f, "timeout reading from stream")
            }
            Error::StreamTooManyOutstandingQueries => {
                write!(f, "too many outstanding queries on stream")
            }
            Error::StreamWriteError(_) => {
                write!(f, "error writing to stream")
            }
            Error::StreamUnexpectedEndOfData => {
                write!(f, "unexpected end of data")
            }
            Error::WrongReplyForQuery => {
                write!(f, "reply does not match query")
            }
            Error::NoTransportAvailable => {
                write!(f, "no transport available")
            }
            Error::Dgram(err) => fmt::Display::fmt(err, f),
        }
    }
}
impl From<CopyRecordsError> for Error {
fn from(err: CopyRecordsError) -> Self {
match err {
CopyRecordsError::Parse(_) => Self::MessageParseError,
CopyRecordsError::Push(_) => Self::MessageBuilderPushError,
}
}
}
impl error::Error for Error {
fn source(&self) -> Option<&(dyn error::Error + 'static)> {
match self {
Error::ConnectionClosed => None,
Error::OptTooLong => None,
Error::MessageBuilderPushError => None,
Error::MessageParseError => None,
Error::RedundantTransportNotFound => None,
Error::ShortMessage => None,
Error::StreamLongMessage => None,
Error::StreamIdleTimeout => None,
Error::StreamReceiveError => None,
Error::StreamReadError(e) => Some(e),
Error::StreamReadTimeout => None,
Error::StreamTooManyOutstandingQueries => None,
Error::StreamWriteError(e) => Some(e),
Error::StreamUnexpectedEndOfData => None,
Error::WrongReplyForQuery => None,
Error::NoTransportAvailable => None,
Error::Dgram(err) => Some(err),
}
}
}