#[cfg(feature = "std")]
use alloc::boxed::Box;
use alloc::vec::Vec;
use core::{
convert::TryFrom,
ops::{Deref, DerefMut},
};
#[cfg(feature = "std")]
use core::{
future::Future,
pin::Pin,
task::{Context, Poll},
};
#[cfg(feature = "std")]
use std::io;
#[cfg(feature = "std")]
use futures_channel::mpsc;
#[cfg(feature = "std")]
use futures_util::{ready, stream::Stream};
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
#[cfg(feature = "std")]
use crate::{ProtoErrorKind, error::ProtoResult};
use crate::{
error::ProtoError,
op::{Message, ResponseCode},
rr::{RecordType, rdata::SOA, resource::RecordRef},
};
/// A stream of DNS responses drawn from one of several underlying sources.
///
/// Values are built through the `From` impls: a timeout-guarded future, an `mpsc` receiver,
/// an immediate `ProtoError`, or a boxed future. Errors of kind `ProtoErrorKind::Timeout` are
/// reported as end-of-stream (`None`) rather than as an `Err` item.
#[cfg(feature = "std")]
pub struct DnsResponseStream {
    // The source that responses are pulled from.
    inner: DnsResponseStreamInner,
    // Set once a single-shot source (timeout future, error, boxed future) has produced its
    // item; every later poll then returns `Poll::Ready(None)`.
    done: bool,
}
#[cfg(feature = "std")]
impl DnsResponseStream {
    /// Wraps `inner` in a stream that has not produced anything yet.
    fn new(inner: DnsResponseStreamInner) -> Self {
        let done = false;
        Self { done, inner }
    }
}
#[cfg(feature = "std")]
impl Stream for DnsResponseStream {
    type Item = Result<DnsResponse, ProtoError>;

    /// Polls the underlying source for the next response.
    ///
    /// * `Timeout` and `Boxed` sources yield a single item, after which `done` is set.
    /// * `Error` yields its stored error once, after which `done` is set.
    /// * `Receiver` forwards items until the receiver itself reports end-of-stream.
    ///
    /// An item that is an error of kind `ProtoErrorKind::Timeout` is turned into
    /// `Poll::Ready(None)` instead of being yielded.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        use DnsResponseStreamInner::*;

        // A single-shot source already produced its item; never poll it again.
        if self.done {
            return Poll::Ready(None);
        }

        // Split the mutable borrow so `inner` and `done` can be used side by side
        // (`Pin::get_mut` requires `Self: Unpin`).
        let Self { inner, done } = self.get_mut();

        let result = match inner {
            Timeout(fut) => {
                // The outer `Err` is the future's `io::Error`; convert it into a `ProtoError`.
                let x = match ready!(fut.as_mut().poll(cx)) {
                    Ok(x) => x,
                    Err(e) => Err(e.into()),
                };
                *done = true;
                x
            }
            // NOTE(review): unlike the other arms this one never sets `done`, so when a
            // Timeout item is mapped to `None` below, a later poll still reaches the receiver.
            Receiver(fut) => match ready!(Pin::new(fut).poll_next(cx)) {
                Some(x) => x,
                None => return Poll::Ready(None),
            },
            Error(err) => {
                // `done` guarantees the `take()` below is only ever reached once.
                *done = true;
                Err(err.take().expect("cannot poll after complete"))
            }
            Boxed(fut) => {
                let x = ready!(fut.as_mut().poll(cx));
                *done = true;
                x
            }
        };

        // Timeouts end the stream quietly instead of surfacing as an error item.
        match result {
            Err(e) if matches!(e.kind(), ProtoErrorKind::Timeout) => Poll::Ready(None),
            r => Poll::Ready(Some(r)),
        }
    }
}
#[cfg(feature = "std")]
impl From<TimeoutFuture> for DnsResponseStream {
    /// Wraps a timeout-guarded future that resolves to a single response.
    fn from(f: TimeoutFuture) -> Self {
        let inner = DnsResponseStreamInner::Timeout(f);
        Self::new(inner)
    }
}
#[cfg(feature = "std")]
impl From<mpsc::Receiver<ProtoResult<DnsResponse>>> for DnsResponseStream {
    /// Forwards every response sent over `receiver` until it reports end-of-stream.
    fn from(receiver: mpsc::Receiver<ProtoResult<DnsResponse>>) -> Self {
        let inner = DnsResponseStreamInner::Receiver(receiver);
        Self::new(inner)
    }
}
#[cfg(feature = "std")]
impl From<ProtoError> for DnsResponseStream {
    /// Builds a stream that yields `e` once and then ends.
    ///
    /// An error of kind `ProtoErrorKind::Timeout` ends the stream without being yielded.
    fn from(e: ProtoError) -> Self {
        let pending = Some(e);
        Self::new(DnsResponseStreamInner::Error(pending))
    }
}
#[cfg(feature = "std")]
impl<F> From<Pin<Box<F>>> for DnsResponseStream
where
    F: Future<Output = Result<DnsResponse, ProtoError>> + Send + 'static,
{
    /// Erases the concrete future type and stores it as a single-shot source.
    fn from(f: Pin<Box<F>>) -> Self {
        // The annotated binding performs the unsized coercion to a trait object.
        let erased: Pin<Box<dyn Future<Output = Result<DnsResponse, ProtoError>> + Send>> = f;
        Self::new(DnsResponseStreamInner::Boxed(erased))
    }
}
/// The concrete source backing a `DnsResponseStream`.
#[cfg(feature = "std")]
enum DnsResponseStreamInner {
    // One response; the future's outer `io::Error` is converted into a `ProtoError` when polled.
    Timeout(TimeoutFuture),
    // Any number of responses, forwarded until the receiver reports end-of-stream.
    Receiver(mpsc::Receiver<ProtoResult<DnsResponse>>),
    // An error yielded exactly once; the `Option` is `take`n on the first poll.
    Error(Option<ProtoError>),
    // One response from a type-erased future.
    Boxed(Pin<Box<dyn Future<Output = Result<DnsResponse, ProtoError>> + Send>>),
}
/// A boxed, sendable future that resolves to a response wrapped in an extra `io::Result`.
///
/// `DnsResponseStream` converts an outer `Err(io::Error)` into a `ProtoError` when it polls
/// this future. NOTE(review): presumably the outer error reports the timeout itself elapsing —
/// confirm at the site that constructs this future.
#[cfg(feature = "std")]
type TimeoutFuture = Pin<
    Box<dyn Future<Output = Result<Result<DnsResponse, ProtoError>, io::Error>> + Send + 'static>,
>;
/// A DNS response: the parsed `Message` together with its wire-format bytes.
///
/// `Deref`/`DerefMut` expose the `Message` directly. The two fields are only made consistent
/// at construction; mutating the message afterwards does not re-encode `buffer`.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde", derive(Deserialize, Serialize))]
pub struct DnsResponse {
    // The decoded message.
    message: Message,
    // The wire-format bytes the message was decoded from (or encoded into).
    buffer: Vec<u8>,
}
impl DnsResponse {
    /// Creates a response from an already-parsed `message`, encoding it to obtain the
    /// wire-format buffer.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `Message::to_vec` while encoding `message`.
    pub fn from_message(message: Message) -> Result<Self, ProtoError> {
        Ok(Self {
            buffer: message.to_vec()?,
            message,
        })
    }

    /// Creates a response by decoding the wire-format `buffer`, keeping the original bytes
    /// alongside the parsed message.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `Message::from_vec` while decoding `buffer`.
    pub fn from_buffer(buffer: Vec<u8>) -> Result<Self, ProtoError> {
        let message = Message::from_vec(&buffer)?;
        Ok(Self { message, buffer })
    }

    /// Returns the first record in the authority (name servers) section that converts into
    /// an SOA [`RecordRef`], or `None` if there is no such record.
    pub fn soa(&self) -> Option<RecordRef<'_, SOA>> {
        self.name_servers()
            .iter()
            .find_map(|record| RecordRef::try_from(record).ok())
    }

    /// Returns the TTL to use when negatively caching this response.
    ///
    /// Takes the first SOA record in the authority section and returns the smaller of that
    /// record's TTL and the SOA `minimum` field, or `None` when the authority section holds
    /// no SOA record.
    ///
    /// NOTE(review): the SOA's owner name is not checked against the queried zone.
    pub fn negative_ttl(&self) -> Option<u32> {
        // `find_map` stops at the first SOA, the same lookup strategy `soa()` uses.
        self.name_servers().iter().find_map(|record| {
            record
                .data()
                .as_soa()
                .map(|soa| record.ttl().min(soa.minimum()))
        })
    }

    /// Returns `true` if this response appears to answer at least one of its queries.
    ///
    /// The check depends on the query type, looking at every record from `all_sections()`:
    /// * `ANY`: some record is owned by the query name.
    /// * `SOA`: some SOA record's owner name is a `zone_of` the query name.
    /// * any other type: the answer section is non-empty, or some record of the queried type
    ///   is owned by the query name.
    ///
    /// A response without queries returns `false`.
    pub fn contains_answer(&self) -> bool {
        self.queries().iter().any(|q| match q.query_type() {
            RecordType::ANY => self.all_sections().any(|r| r.name() == q.name()),
            // for SOA the queried name only has to lie within the SOA's zone
            RecordType::SOA => self
                .all_sections()
                .filter(|r| r.record_type().is_soa())
                .any(|r| r.name().zone_of(q.name())),
            q_type => {
                !self.answers().is_empty()
                    || self
                        .all_sections()
                        .filter(|r| r.record_type() == q_type)
                        .any(|r| r.name() == q.name())
            }
        })
    }

    /// Classifies this response as one of the negative-response forms in [`NegativeType`].
    ///
    /// The classification looks only at:
    /// * the response code (`NXDomain` or `NoError`; any other code yields `None`),
    /// * whether the authority section holds an SOA record and/or NS records,
    /// * whether the answer section holds CNAME records and/or non-CNAME records,
    /// * whether the header reports a non-zero additional count.
    ///
    /// A response carrying any non-CNAME answer is never negative and yields `None`.
    pub fn negative_type(&self) -> Option<NegativeType> {
        let response_code = self.response_code();
        let has_soa = self.negative_ttl().is_some();
        let has_ns_records = self.name_servers().iter().any(|r| r.record_type().is_ns());
        let has_cname = self.answers().iter().any(|r| r.record_type().is_cname());
        let has_non_cname = self.answers().iter().any(|r| !r.record_type().is_cname());
        let has_additionals = self.additional_count() > 0;

        match (
            response_code,
            has_soa,
            has_ns_records,
            has_cname,
            has_non_cname,
            has_additionals,
        ) {
            (ResponseCode::NXDomain, true, true, _, false, _) => Some(NegativeType::NameErrorType1),
            (ResponseCode::NXDomain, true, false, _, false, _) => {
                Some(NegativeType::NameErrorType2)
            }
            (ResponseCode::NXDomain, false, false, true, false, _) => {
                Some(NegativeType::NameErrorType3)
            }
            (ResponseCode::NXDomain, false, true, _, false, _) => {
                Some(NegativeType::NameErrorType4)
            }
            (ResponseCode::NoError, true, true, false, false, _) => Some(NegativeType::NoDataType1),
            (ResponseCode::NoError, true, false, false, false, _) => {
                Some(NegativeType::NoDataType2)
            }
            (ResponseCode::NoError, false, false, false, false, false) => {
                Some(NegativeType::NoDataType3)
            }
            (ResponseCode::NoError, false, true, _, false, _) => Some(NegativeType::Referral),
            _ => None,
        }
    }

    /// Borrows the wire-format bytes of this response.
    pub fn as_buffer(&self) -> &[u8] {
        &self.buffer
    }

    /// Consumes the response, returning its wire-format bytes.
    pub fn into_buffer(self) -> Vec<u8> {
        self.buffer
    }

    /// Consumes the response, returning the parsed message.
    pub fn into_message(self) -> Message {
        self.message
    }

    /// Consumes the response, returning the parsed message and its wire-format bytes.
    pub fn into_parts(self) -> (Message, Vec<u8>) {
        (self.message, self.buffer)
    }
}
impl Deref for DnsResponse {
    type Target = Message;

    /// Exposes the parsed message so `Message` accessors can be called on a response directly.
    fn deref(&self) -> &Message {
        let Self { message, .. } = self;
        message
    }
}
impl DerefMut for DnsResponse {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.message
}
}
impl From<DnsResponse> for Message {
fn from(response: DnsResponse) -> Self {
response.message
}
}
/// The forms a negative DNS response can take, as classified by `DnsResponse::negative_type`.
///
/// The variant names follow RFC 2308 §2.1 (NXDOMAIN "name error") and §2.2 ("no data").
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum NegativeType {
    /// NXDOMAIN; authority holds both an SOA record and NS records.
    NameErrorType1,
    /// NXDOMAIN; authority holds an SOA record but no NS records.
    NameErrorType2,
    /// NXDOMAIN; no SOA or NS records in authority, and the answer holds only CNAMEs.
    NameErrorType3,
    /// NXDOMAIN; authority holds NS records but no SOA record.
    NameErrorType4,
    /// NOERROR with an empty answer; authority holds both an SOA record and NS records.
    NoDataType1,
    /// NOERROR with an empty answer; authority holds an SOA record but no NS records.
    NoDataType2,
    /// NOERROR with empty answer, no SOA or NS in authority, and no additional records.
    NoDataType3,
    /// NOERROR; authority holds NS records but no SOA, and the answer holds only CNAMEs.
    Referral,
}

impl NegativeType {
    /// Returns `true` for the variants that are only produced when the authority section
    /// carries an SOA record.
    pub fn is_authoritative(&self) -> bool {
        // Exhaustive on purpose: adding a variant forces a decision here.
        match self {
            Self::NameErrorType1
            | Self::NameErrorType2
            | Self::NoDataType1
            | Self::NoDataType2 => true,
            Self::NameErrorType3 | Self::NameErrorType4 | Self::NoDataType3 | Self::Referral => {
                false
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use crate::op::{Message, Query, ResponseCode};
    use crate::rr::RData;
    use crate::rr::rdata::{A, CNAME, NS, SOA};
    use crate::rr::{Name, Record, RecordType};

    use super::*;

    // Fixtures modelled on the example zone used by RFC 2308 §2.

    fn xx() -> Name {
        Name::from_ascii("XX.").unwrap()
    }

    fn ns1() -> Name {
        Name::from_ascii("NS1.XX.").unwrap()
    }

    // Previously a copy of `ns1()`, which made the two NS fixtures identical.
    fn ns2() -> Name {
        Name::from_ascii("NS2.XX.").unwrap()
    }

    fn hostmaster() -> Name {
        Name::from_ascii("HOSTMASTER.NS1.XX.").unwrap()
    }

    fn tripple_xx() -> Name {
        Name::from_ascii("TRIPPLE.XX.").unwrap()
    }

    fn example() -> Name {
        Name::from_ascii("EXAMPLE.").unwrap()
    }

    fn an_example() -> Name {
        Name::from_ascii("AN.EXAMPLE.").unwrap()
    }

    fn another_example() -> Name {
        Name::from_ascii("ANOTHER.EXAMPLE.").unwrap()
    }

    fn an_cname_record() -> Record {
        Record::from_rdata(an_example(), 88640, RData::CNAME(CNAME(tripple_xx())))
    }

    fn ns1_record() -> Record {
        Record::from_rdata(xx(), 88640, RData::NS(NS(ns1())))
    }

    fn ns2_record() -> Record {
        Record::from_rdata(xx(), 88640, RData::NS(NS(ns2())))
    }

    // Glue A record for the first name server host.
    fn ns1_a() -> Record {
        Record::from_rdata(ns1(), 88640, RData::A(A::new(127, 0, 0, 2)))
    }

    // Glue A record for the second name server host.
    fn ns2_a() -> Record {
        Record::from_rdata(ns2(), 88640, RData::A(A::new(127, 0, 0, 3)))
    }

    // SOA with record TTL 88640 and a MINIMUM field of 5.
    fn soa() -> Record {
        Record::from_rdata(
            example(),
            88640,
            RData::SOA(SOA::new(ns1(), hostmaster(), 1, 2, 3, 4, 5)),
        )
    }

    fn an_query() -> Query {
        Query::query(an_example(), RecordType::A)
    }

    fn another_query() -> Query {
        Query::query(another_example(), RecordType::A)
    }

    #[test]
    fn test_contains_answer() {
        let mut message = Message::default();
        message.set_response_code(ResponseCode::NXDomain);
        message.add_query(Query::query(Name::root(), RecordType::A));
        message.add_answer(Record::from_rdata(
            Name::root(),
            88640,
            RData::A(A::new(127, 0, 0, 2)),
        ));

        let response = DnsResponse::from_message(message).unwrap();

        assert!(response.contains_answer())
    }

    #[test]
    fn test_nx_type1() {
        let mut message = Message::default();
        message.set_response_code(ResponseCode::NXDomain);
        message.add_query(an_query());
        message.add_answer(an_cname_record());
        message.add_name_server(soa());
        message.add_name_server(ns1_record());
        message.add_name_server(ns2_record());
        message.add_additional(ns1_a());
        message.add_additional(ns2_a());

        let response = DnsResponse::from_message(message).unwrap();
        let ty = response.negative_type();

        assert!(response.contains_answer());
        assert_eq!(ty.unwrap(), NegativeType::NameErrorType1);
    }

    #[test]
    fn test_nx_type2() {
        let mut message = Message::default();
        message.set_response_code(ResponseCode::NXDomain);
        message.add_query(an_query());
        message.add_answer(an_cname_record());
        message.add_name_server(soa());

        let response = DnsResponse::from_message(message).unwrap();
        let ty = response.negative_type();

        assert!(response.contains_answer());
        assert_eq!(ty.unwrap(), NegativeType::NameErrorType2);
    }

    #[test]
    fn test_nx_type3() {
        let mut message = Message::default();
        message.set_response_code(ResponseCode::NXDomain);
        message.add_query(an_query());
        message.add_answer(an_cname_record());

        let response = DnsResponse::from_message(message).unwrap();
        let ty = response.negative_type();

        assert!(response.contains_answer());
        assert_eq!(ty.unwrap(), NegativeType::NameErrorType3);
    }

    #[test]
    fn test_nx_type4() {
        let mut message = Message::default();
        message.set_response_code(ResponseCode::NXDomain);
        message.add_query(an_query());
        message.add_answer(an_cname_record());
        message.add_name_server(ns1_record());
        message.add_name_server(ns2_record());
        message.add_additional(ns1_a());
        message.add_additional(ns2_a());

        let response = DnsResponse::from_message(message).unwrap();
        let ty = response.negative_type();

        assert!(response.contains_answer());
        assert_eq!(ty.unwrap(), NegativeType::NameErrorType4);
    }

    #[test]
    fn test_no_data_type1() {
        let mut message = Message::default();
        message.set_response_code(ResponseCode::NoError);
        message.add_query(another_query());
        message.add_name_server(soa());
        message.add_name_server(ns1_record());
        message.add_name_server(ns2_record());
        message.add_additional(ns1_a());
        message.add_additional(ns2_a());

        let response = DnsResponse::from_message(message).unwrap();
        let ty = response.negative_type();

        assert!(!response.contains_answer());
        assert_eq!(ty.unwrap(), NegativeType::NoDataType1);
    }

    #[test]
    fn test_no_data_type2() {
        let mut message = Message::default();
        message.set_response_code(ResponseCode::NoError);
        message.add_query(another_query());
        message.add_name_server(soa());

        let response = DnsResponse::from_message(message).unwrap();
        let ty = response.negative_type();

        assert!(!response.contains_answer());
        assert_eq!(ty.unwrap(), NegativeType::NoDataType2);
    }

    #[test]
    fn test_no_data_type3() {
        let mut message = Message::default();
        message.set_response_code(ResponseCode::NoError);
        message.add_query(another_query());

        let response = DnsResponse::from_message(message).unwrap();
        let ty = response.negative_type();

        assert!(!response.contains_answer());
        assert_eq!(ty.unwrap(), NegativeType::NoDataType3);
    }

    #[test]
    fn referral() {
        // referral carrying a CNAME answer
        let mut message = Message::default();
        message.set_response_code(ResponseCode::NoError);
        message.add_query(an_query());
        message.add_answer(an_cname_record());
        message.add_name_server(ns1_record());
        message.add_name_server(ns2_record());
        message.add_additional(ns1_a());
        message.add_additional(ns2_a());

        let response = DnsResponse::from_message(message).unwrap();
        let ty = response.negative_type();

        assert!(response.contains_answer());
        assert_eq!(ty.unwrap(), NegativeType::Referral);

        // referral with an empty answer section
        let mut message = Message::default();
        message.set_response_code(ResponseCode::NoError);
        message.add_query(another_query());
        message.add_name_server(ns1_record());
        message.add_name_server(ns2_record());
        message.add_additional(ns1_a());
        message.add_additional(ns2_a());

        let response = DnsResponse::from_message(message).unwrap();
        let ty = response.negative_type();

        assert!(!response.contains_answer());
        assert_eq!(ty.unwrap(), NegativeType::Referral);
    }

    #[test]
    fn contains_soa() {
        let mut message = Message::default();
        message.set_response_code(ResponseCode::NoError);
        message.add_query(Query::query(an_example(), RecordType::SOA));
        message.add_name_server(soa());

        let response = DnsResponse::from_message(message).unwrap();

        assert!(response.contains_answer());
    }

    #[test]
    fn contains_any() {
        let mut message = Message::default();
        message.set_response_code(ResponseCode::NoError);
        message.add_query(Query::query(xx(), RecordType::ANY));
        message.add_name_server(ns1_record());
        message.add_additional(ns1_a());

        let response = DnsResponse::from_message(message).unwrap();

        assert!(response.contains_answer());
    }

    #[test]
    fn negative_ttl_from_soa() {
        // the SOA record's TTL is 88640 but its MINIMUM field is 5; the smaller one wins
        let mut message = Message::default();
        message.set_response_code(ResponseCode::NoError);
        message.add_query(another_query());
        message.add_name_server(soa());

        let response = DnsResponse::from_message(message).unwrap();

        assert!(response.soa().is_some());
        assert_eq!(response.negative_ttl(), Some(5));

        // without an SOA in the authority section there is no negative TTL
        let mut message = Message::default();
        message.set_response_code(ResponseCode::NoError);
        message.add_query(another_query());
        message.add_name_server(ns1_record());

        let response = DnsResponse::from_message(message).unwrap();

        assert!(response.soa().is_none());
        assert_eq!(response.negative_ttl(), None);
    }
}