use alloc::vec::Vec;
use core::{
convert::TryFrom,
ops::{Deref, DerefMut},
};
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
use crate::{
error::ProtoError,
op::{Message, MessageType},
rr::{RData, RecordType, rdata::SOA, record::RecordRef},
};
/// A DNS response: a decoded [`Message`] together with its wire-format bytes.
///
/// The checked constructors, [`DnsResponse::from_message`] and [`DnsResponse::from_buffer`],
/// reject any message whose type is not [`MessageType::Response`].
// NOTE(review): with the `serde` feature enabled, `Deserialize` builds this type directly,
// so neither the response-type check nor the consistency between `message` and `buffer`
// is enforced on that path.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde", derive(Deserialize, Serialize))]
pub struct DnsResponse {
    /// The decoded message.
    message: Message,
    /// The wire-format bytes: either produced by `Message::to_vec` in `from_message`,
    /// or the exact buffer that was parsed in `from_buffer`.
    buffer: Vec<u8>,
}
impl DnsResponse {
    /// Builds a `DnsResponse` from an already-decoded [`Message`], serializing it to wire
    /// format so that [`DnsResponse::as_buffer`] is available.
    ///
    /// # Errors
    ///
    /// Returns [`ProtoError::NotAResponse`] if the message type is not
    /// [`MessageType::Response`]. Also returns any error produced while serializing the
    /// message.
    pub fn from_message(message: Message) -> Result<Self, ProtoError> {
        Self::ensure_response(&message)?;
        // Serialize while we still hold `message` by value, then move it into the struct.
        let buffer = message.to_vec()?;
        Ok(Self { message, buffer })
    }

    /// Builds a `DnsResponse` by decoding a wire-format buffer.
    ///
    /// The buffer is kept exactly as given and is returned by [`DnsResponse::as_buffer`].
    ///
    /// # Errors
    ///
    /// Returns any error from decoding the buffer. Returns [`ProtoError::NotAResponse`] if
    /// the decoded message type is not [`MessageType::Response`].
    pub fn from_buffer(buffer: Vec<u8>) -> Result<Self, ProtoError> {
        let message = Message::from_vec(&buffer)?;
        Self::ensure_response(&message)?;
        Ok(Self { message, buffer })
    }

    /// Fails with [`ProtoError::NotAResponse`] unless `message` is typed as a response.
    fn ensure_response(message: &Message) -> Result<(), ProtoError> {
        if message.metadata.message_type == MessageType::Response {
            Ok(())
        } else {
            Err(ProtoError::NotAResponse)
        }
    }

    /// Returns the first record in the authority section that converts into an SOA record,
    /// or `None` if there is no such record.
    pub fn soa(&self) -> Option<RecordRef<'_, SOA>> {
        self.authorities
            .iter()
            .find_map(|record| RecordRef::try_from(record).ok())
    }

    /// Returns the TTL to use when caching this response as a negative answer.
    ///
    /// The value is the lesser of two numbers taken from the first SOA record in the
    /// authority section: the record's own TTL and its MINIMUM field (RFC 2308, section 5).
    /// Returns `None` if the authority section contains no SOA record.
    pub fn negative_ttl(&self) -> Option<u32> {
        // NOTE(review): the SOA owner name is not checked against the queried zone.
        self.authorities.iter().find_map(|record| match &record.data {
            RData::SOA(soa) => Some(record.ttl.min(soa.minimum)),
            _ => None,
        })
    }

    /// Returns `true` if the response appears to answer at least one of its queries.
    ///
    /// How a query counts as answered depends on its type:
    /// - `ANY`: some record in `all_sections()` has the query's name.
    /// - `SOA`: some SOA record in `all_sections()` has an owner name that is a zone of
    ///   the query name.
    /// - Any other type: the answer section is non-empty. Failing that, some record in
    ///   `all_sections()` has both the queried type and the queried name.
    pub fn contains_answer(&self) -> bool {
        self.queries.iter().any(|q| match q.query_type() {
            RecordType::ANY => self.all_sections().any(|r| &r.name == q.name()),
            // For SOA queries, the queried name must fall within the SOA's zone.
            RecordType::SOA => self
                .all_sections()
                .filter(|r| r.record_type().is_soa())
                .any(|r| r.name.zone_of(q.name())),
            q_type => {
                !self.answers.is_empty()
                    || self
                        .all_sections()
                        .filter(|r| r.record_type() == q_type)
                        .any(|r| &r.name == q.name())
            }
        })
    }

    /// Borrows the wire-format bytes of this response.
    ///
    /// These bytes are not re-encoded if the message is later mutated through `DerefMut`.
    pub fn as_buffer(&self) -> &[u8] {
        &self.buffer
    }

    /// Consumes the response and returns its wire-format bytes.
    pub fn into_buffer(self) -> Vec<u8> {
        self.buffer
    }

    /// Consumes the response and returns the decoded message, discarding the bytes.
    pub fn into_message(self) -> Message {
        self.message
    }

    /// Consumes the response and returns both the decoded message and its wire-format bytes.
    pub fn into_parts(self) -> (Message, Vec<u8>) {
        (self.message, self.buffer)
    }
}
/// Exposes the read-only [`Message`] API directly on a [`DnsResponse`].
impl Deref for DnsResponse {
    type Target = Message;

    fn deref(&self) -> &Message {
        let Self { message, .. } = self;
        message
    }
}
impl DerefMut for DnsResponse {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.message
}
}
impl From<DnsResponse> for Message {
fn from(response: DnsResponse) -> Self {
response.message
}
}
#[cfg(all(test, any(feature = "std", feature = "no-std-rand")))]
mod tests {
    use crate::op::{Message, Query, ResponseCode};
    use crate::rr::RData;
    use crate::rr::rdata::{A, NS, SOA};
    use crate::rr::{Name, Record, RecordType};

    use super::*;

    fn xx() -> Name {
        Name::from_ascii("XX.").unwrap()
    }

    fn ns1() -> Name {
        Name::from_ascii("NS1.XX.").unwrap()
    }

    fn hostmaster() -> Name {
        Name::from_ascii("HOSTMASTER.NS1.XX.").unwrap()
    }

    fn example() -> Name {
        Name::from_ascii("EXAMPLE.").unwrap()
    }

    fn an_example() -> Name {
        Name::from_ascii("AN.EXAMPLE.").unwrap()
    }

    fn ns1_record() -> Record {
        Record::from_rdata(xx(), 88640, RData::NS(NS(ns1())))
    }

    fn ns1_a() -> Record {
        Record::from_rdata(xx(), 88640, RData::A(A::new(127, 0, 0, 2)))
    }

    /// SOA for `EXAMPLE.` with record TTL 88640 and MINIMUM field 5.
    fn soa() -> Record {
        Record::from_rdata(
            example(),
            88640,
            RData::SOA(SOA::new(ns1(), hostmaster(), 1, 2, 3, 4, 5)),
        )
    }

    #[test]
    fn test_contains_answer() {
        let mut message = Message::query();
        message.metadata.response_code = ResponseCode::NXDomain;
        message.add_query(Query::query(Name::root(), RecordType::A));
        message.add_answer(Record::from_rdata(
            Name::root(),
            88640,
            RData::A(A::new(127, 0, 0, 2)),
        ));
        let response = DnsResponse::from_message(message.into_response()).unwrap();
        assert!(response.contains_answer())
    }

    #[test]
    fn contains_soa() {
        let mut message = Message::query();
        message.metadata.response_code = ResponseCode::NoError;
        message.add_query(Query::query(an_example(), RecordType::SOA));
        message.add_authority(soa());
        let response = DnsResponse::from_message(message.into_response()).unwrap();
        assert!(response.contains_answer());
    }

    #[test]
    fn contains_any() {
        let mut message = Message::query();
        message.metadata.response_code = ResponseCode::NoError;
        message.add_query(Query::query(xx(), RecordType::ANY));
        message.add_authority(ns1_record());
        message.add_additional(ns1_a());
        let response = DnsResponse::from_message(message.into_response()).unwrap();
        assert!(response.contains_answer());
    }

    #[test]
    fn not_a_response() {
        assert!(matches!(
            DnsResponse::from_message(Message::query()).unwrap_err(),
            ProtoError::NotAResponse
        ));
        assert!(matches!(
            DnsResponse::from_buffer(Message::query().to_vec().unwrap()).unwrap_err(),
            ProtoError::NotAResponse
        ));
    }

    #[test]
    fn negative_ttl_prefers_soa_minimum() {
        let mut message = Message::query();
        message.metadata.response_code = ResponseCode::NXDomain;
        message.add_query(Query::query(an_example(), RecordType::A));
        message.add_authority(soa());
        let response = DnsResponse::from_message(message.into_response()).unwrap();
        // Record TTL is 88640 and MINIMUM is 5, so the MINIMUM field wins.
        assert_eq!(response.negative_ttl(), Some(5));
        assert!(response.soa().is_some());
    }

    #[test]
    fn negative_ttl_prefers_record_ttl_when_lower() {
        let mut message = Message::query();
        message.metadata.response_code = ResponseCode::NXDomain;
        message.add_query(Query::query(an_example(), RecordType::A));
        message.add_authority(Record::from_rdata(
            example(),
            3,
            RData::SOA(SOA::new(ns1(), hostmaster(), 1, 2, 3, 4, 300)),
        ));
        let response = DnsResponse::from_message(message.into_response()).unwrap();
        assert_eq!(response.negative_ttl(), Some(3));
    }

    #[test]
    fn negative_ttl_without_soa() {
        let mut message = Message::query();
        message.add_query(Query::query(xx(), RecordType::A));
        message.add_authority(ns1_record());
        let response = DnsResponse::from_message(message.into_response()).unwrap();
        assert_eq!(response.negative_ttl(), None);
        assert!(response.soa().is_none());
    }

    #[test]
    fn from_buffer_round_trip() {
        let mut message = Message::query();
        message.metadata.response_code = ResponseCode::NXDomain;
        message.add_query(Query::query(an_example(), RecordType::A));
        message.add_authority(soa());
        let response = DnsResponse::from_message(message.into_response()).unwrap();
        let reparsed = DnsResponse::from_buffer(response.as_buffer().to_vec()).unwrap();
        assert_eq!(reparsed.as_buffer(), response.as_buffer());
        assert_eq!(reparsed.negative_ttl(), response.negative_ttl());
    }
}