use crate::{
authority::{
message_request::{MessageRequest, QueriesEmitAndCount},
Queries,
},
proto::{
error::*,
op::{
message::{self, EmitAndCount},
Edns, Header, ResponseCode,
},
rr::Record,
serialize::binary::BinEncoder,
},
server::ResponseInfo,
};
use super::message_request::WireQuery;
/// A DNS response whose record sections are supplied lazily by iterators.
///
/// Nothing is encoded until `destructive_emit` consumes the value.
#[derive(Debug)]
pub struct MessageResponse<
    'q,
    'a,
    Answers: Iterator<Item = &'a Record> + Send + 'a,
    NameServers: Iterator<Item = &'a Record> + Send + 'a,
    Soa: Iterator<Item = &'a Record> + Send + 'a,
    Additionals: Iterator<Item = &'a Record> + Send + 'a,
> {
    /// Header written at the start of the message.
    header: Header,
    /// Query to emit with the response; `None` emits no queries.
    query: Option<&'q WireQuery>,
    /// Records for the answer section.
    answers: Answers,
    /// Records for the name server section, emitted before `soa`.
    name_servers: NameServers,
    /// Records chained after `name_servers` into the same section when emitting.
    soa: Soa,
    /// Records for the additional section.
    additionals: Additionals,
    /// `sig0` records handed to the encoder alongside the sections.
    sig0: Vec<Record>,
    /// Optional EDNS data handed to the encoder.
    edns: Option<Edns>,
}
/// Source for the queries emitted with a response: either nothing, or an emitter
/// over a set of queries.
enum EmptyOrQueries<'q> {
    /// No queries are written; emitting reports a count of zero.
    Empty,
    /// Queries are written through the wrapped emitter.
    Queries(QueriesEmitAndCount<'q>),
}
impl<'q> From<Option<&'q Queries>> for EmptyOrQueries<'q> {
fn from(option: Option<&'q Queries>) -> Self {
option.map_or(EmptyOrQueries::Empty, |q| {
EmptyOrQueries::Queries(q.as_emit_and_count())
})
}
}
impl<'q> From<Option<&'q WireQuery>> for EmptyOrQueries<'q> {
fn from(option: Option<&'q WireQuery>) -> Self {
option.map_or(EmptyOrQueries::Empty, |q| {
EmptyOrQueries::Queries(q.as_emit_and_count())
})
}
}
impl<'q> EmitAndCount for EmptyOrQueries<'q> {
    /// Emits the wrapped queries, if any, and returns the count reported by the emitter.
    ///
    /// The `Empty` variant writes nothing and reports zero.
    fn emit(&mut self, encoder: &mut BinEncoder<'_>) -> ProtoResult<usize> {
        if let Self::Queries(queries) = self {
            queries.emit(encoder)
        } else {
            Ok(0)
        }
    }
}
impl<'q, 'a, A, N, S, D> MessageResponse<'q, 'a, A, N, S, D>
where
    A: Iterator<Item = &'a Record> + Send + 'a,
    N: Iterator<Item = &'a Record> + Send + 'a,
    S: Iterator<Item = &'a Record> + Send + 'a,
    D: Iterator<Item = &'a Record> + Send + 'a,
{
    /// Returns the header that will be written for this response.
    pub fn header(&self) -> &Header {
        &self.header
    }

    /// Returns the header mutably so callers can adjust it before emitting.
    pub fn header_mut(&mut self) -> &mut Header {
        &mut self.header
    }

    /// Sets (replacing any previous value) the EDNS data; returns `self` for chaining.
    pub fn set_edns(&mut self, edns: Edns) -> &mut Self {
        self.edns = Some(edns);
        self
    }

    /// Consumes the response and encodes it into `encoder`.
    ///
    /// The record iterators are drained while encoding, so the response cannot be
    /// emitted again.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `message::emit_message_parts`.
    pub fn destructive_emit(mut self, encoder: &mut BinEncoder<'_>) -> ProtoResult<ResponseInfo> {
        // SOA records are written in the same section as the name servers, after them.
        let mut authority = self.name_servers.chain(self.soa);
        let mut queries = EmptyOrQueries::from(self.query);

        let emitted = message::emit_message_parts(
            &self.header,
            &mut queries,
            &mut self.answers,
            &mut authority,
            &mut self.additionals,
            self.edns.as_ref(),
            &self.sig0,
            encoder,
        )?;
        Ok(emitted.into())
    }
}
/// Builder for [`MessageResponse`].
///
/// Holds the query, `sig0` records and EDNS data that every response built from it shares.
// `Debug` is derivable: every field type is already `Debug`, as `MessageResponse` requires.
#[derive(Debug)]
pub struct MessageResponseBuilder<'q> {
    /// Query to emit with the built response, if any.
    query: Option<&'q WireQuery>,
    /// `sig0` records to attach; `None` becomes an empty list when building.
    sig0: Option<Vec<Record>>,
    /// EDNS data to attach to the built response, if any.
    edns: Option<Edns>,
}
impl<'q> MessageResponseBuilder<'q> {
    /// Creates a builder whose responses will emit `query` (if any).
    ///
    /// The builder starts with no EDNS data and no `sig0` records.
    pub(crate) fn new(query: Option<&'q WireQuery>) -> MessageResponseBuilder<'q> {
        MessageResponseBuilder {
            query,
            sig0: None,
            edns: None,
        }
    }

    /// Creates a builder whose responses emit the raw query of `message`.
    pub fn from_message_request(message: &'q MessageRequest) -> Self {
        Self::new(Some(message.raw_query()))
    }

    /// Sets (replacing any previous value) the EDNS data for built responses.
    pub fn edns(&mut self, edns: Edns) -> &mut Self {
        self.edns = Some(edns);
        self
    }

    /// Builds a response with `header` and the given record sections.
    ///
    /// `soa` records are emitted after `name_servers`, in the same section.
    pub fn build<'a, A, N, S, D>(
        self,
        header: Header,
        answers: A,
        name_servers: N,
        soa: S,
        additionals: D,
    ) -> MessageResponse<'q, 'a, A::IntoIter, N::IntoIter, S::IntoIter, D::IntoIter>
    where
        A: IntoIterator<Item = &'a Record> + Send + 'a,
        A::IntoIter: Send,
        N: IntoIterator<Item = &'a Record> + Send + 'a,
        N::IntoIter: Send,
        S: IntoIterator<Item = &'a Record> + Send + 'a,
        S::IntoIter: Send,
        D: IntoIterator<Item = &'a Record> + Send + 'a,
        D::IntoIter: Send,
    {
        MessageResponse {
            header,
            query: self.query,
            answers: answers.into_iter(),
            name_servers: name_servers.into_iter(),
            soa: soa.into_iter(),
            additionals: additionals.into_iter(),
            sig0: self.sig0.unwrap_or_default(),
            edns: self.edns,
        }
    }

    /// Builds a response with `header` and every record section empty.
    pub fn build_no_records<'a>(
        self,
        header: Header,
    ) -> MessageResponse<
        'q,
        'a,
        impl Iterator<Item = &'a Record> + Send + 'a,
        impl Iterator<Item = &'a Record> + Send + 'a,
        impl Iterator<Item = &'a Record> + Send + 'a,
        impl Iterator<Item = &'a Record> + Send + 'a,
    > {
        // `iter::empty` is zero-sized, so empty sections cost no heap allocation
        // (boxing a `None` iterator allocated once per section).
        MessageResponse {
            header,
            query: self.query,
            answers: std::iter::empty(),
            name_servers: std::iter::empty(),
            soa: std::iter::empty(),
            additionals: std::iter::empty(),
            sig0: self.sig0.unwrap_or_default(),
            edns: self.edns,
        }
    }

    /// Builds a record-less error response to a request.
    ///
    /// The header is derived from `request_header` via `Header::response_from_request`,
    /// then given `response_code`.
    pub fn error_msg<'a>(
        self,
        request_header: &Header,
        response_code: ResponseCode,
    ) -> MessageResponse<
        'q,
        'a,
        impl Iterator<Item = &'a Record> + Send + 'a,
        impl Iterator<Item = &'a Record> + Send + 'a,
        impl Iterator<Item = &'a Record> + Send + 'a,
        impl Iterator<Item = &'a Record> + Send + 'a,
    > {
        let mut header = Header::response_from_request(request_header);
        header.set_response_code(response_code);
        self.build_no_records(header)
    }
}
#[cfg(test)]
mod tests {
    use std::iter;
    use std::net::Ipv4Addr;
    use std::str::FromStr;

    use crate::proto::op::{Header, Message};
    use crate::proto::rr::{DNSClass, Name, RData, Record};
    use crate::proto::serialize::binary::BinEncoder;

    use super::*;

    /// Builds the A record that every section of the test responses repeats.
    fn sample_record() -> Record {
        Record::new()
            .set_name(Name::from_str("www.example.com.").unwrap())
            .set_data(Some(RData::A(Ipv4Addr::new(93, 184, 216, 34))))
            .set_dns_class(DNSClass::NONE)
            .clone()
    }

    /// Emits `message` into a buffer capped at 512 bytes and decodes the result.
    fn emit_and_decode<'a, A, N, S, D>(message: MessageResponse<'_, 'a, A, N, S, D>) -> Message
    where
        A: Iterator<Item = &'a Record> + Send + 'a,
        N: Iterator<Item = &'a Record> + Send + 'a,
        S: Iterator<Item = &'a Record> + Send + 'a,
        D: Iterator<Item = &'a Record> + Send + 'a,
    {
        let mut buf = Vec::with_capacity(512);
        {
            let mut encoder = BinEncoder::new(&mut buf);
            encoder.set_max_size(512);
            message
                .destructive_emit(&mut encoder)
                .expect("failed to encode");
        }
        Message::from_vec(&buf).expect("failed to decode")
    }

    #[test]
    fn test_truncation_ridiculous_number_answers() {
        let record = sample_record();
        let response = emit_and_decode(MessageResponse {
            header: Header::new(),
            query: None,
            answers: iter::repeat(&record),
            name_servers: iter::once(&record),
            soa: iter::once(&record),
            additionals: iter::once(&record),
            sig0: vec![],
            edns: None,
        });

        assert!(response.header().truncated());
        assert!(response.answer_count() > 1);
        assert_eq!(response.name_server_count(), 0);
    }

    #[test]
    fn test_truncation_ridiculous_number_nameservers() {
        let record = sample_record();
        let response = emit_and_decode(MessageResponse {
            header: Header::new(),
            query: None,
            answers: iter::empty(),
            name_servers: iter::repeat(&record),
            soa: iter::repeat(&record),
            additionals: iter::repeat(&record),
            sig0: vec![],
            edns: None,
        });

        assert!(response.header().truncated());
        assert_eq!(response.answer_count(), 0);
        assert!(response.name_server_count() > 1);
    }
}