use core::{
pin::Pin,
task::{Context, Poll},
time::Duration,
};
use std::collections::{HashMap, hash_map::Entry};
use futures_channel::mpsc;
use futures_util::{
FutureExt,
future::BoxFuture,
stream::{Stream, StreamExt},
};
use rand::RngExt;
use tracing::debug;
use super::{
BufDnsStreamHandle, DnsClientStream, DnsRequestSender, DnsResponseStream, ignore_send,
};
use crate::proto::op::{DnsRequest, DnsResponse, SerialMessage};
#[cfg(feature = "__dnssec")]
use crate::proto::rr::{TSigVerifier, TSigner};
use crate::{DnsStreamHandle, error::NetError, runtime::Time};
struct ActiveRequest {
completion: mpsc::Sender<Result<DnsResponse, NetError>>,
request_id: u16,
timeout: BoxFuture<'static, ()>,
#[cfg(feature = "__dnssec")]
verifier: Option<TSigVerifier>,
}
impl ActiveRequest {
fn new(
completion: mpsc::Sender<Result<DnsResponse, NetError>>,
request_id: u16,
timeout: BoxFuture<'static, ()>,
#[cfg(feature = "__dnssec")] verifier: Option<TSigVerifier>,
) -> Self {
Self {
completion,
request_id,
timeout,
#[cfg(feature = "__dnssec")]
verifier,
}
}
fn poll_timeout(&mut self, cx: &mut Context<'_>) -> Poll<()> {
self.timeout.poll_unpin(cx)
}
fn is_canceled(&self) -> bool {
self.completion.is_closed()
}
fn request_id(&self) -> u16 {
self.request_id
}
fn complete_with_error(mut self, error: NetError) {
ignore_send(self.completion.try_send(Err(error)));
}
}
/// Multiplexes DNS requests over a single stream, tracking every in-flight
/// request by its message id.
#[must_use = "futures do nothing unless polled"]
pub struct DnsMultiplexer<S> {
    /// Underlying stream that is polled for serialized responses.
    stream: S,
    /// Deadline applied to every request that is sent.
    timeout_duration: Duration,
    /// Handle through which serialized requests are sent.
    stream_handle: BufDnsStreamHandle,
    /// Requests that were sent and have not yet been removed, keyed by message id.
    active_requests: HashMap<u16, ActiveRequest>,
    /// Number of in-flight requests at which new requests are refused.
    max_active_requests: usize,
    /// Optional signer consulted before a request is sent.
    #[cfg(feature = "__dnssec")]
    signer: Option<TSigner>,
    /// Set once the multiplexer has been shut down or the stream has closed.
    is_shutdown: bool,
}

impl<S: DnsClientStream> DnsMultiplexer<S> {
    /// Creates a multiplexer over `stream` that sends requests through `stream_handle`.
    ///
    /// Defaults to a 5 second request timeout, at most 32 in-flight requests
    /// and no signer.
    pub fn new(stream: S, stream_handle: BufDnsStreamHandle) -> Self {
        const DEFAULT_TIMEOUT: Duration = Duration::from_secs(5);
        const DEFAULT_MAX_ACTIVE_REQUESTS: usize = 32;

        Self {
            stream,
            timeout_duration: DEFAULT_TIMEOUT,
            stream_handle,
            active_requests: HashMap::new(),
            max_active_requests: DEFAULT_MAX_ACTIVE_REQUESTS,
            #[cfg(feature = "__dnssec")]
            signer: None,
            is_shutdown: false,
        }
    }

    /// Replaces the per-request timeout.
    pub fn with_timeout(self, timeout: Duration) -> Self {
        Self {
            timeout_duration: timeout,
            ..self
        }
    }

    /// Replaces the limit on concurrently in-flight requests.
    pub fn with_max_active_requests(self, max: usize) -> Self {
        Self {
            max_active_requests: max,
            ..self
        }
    }

    /// Installs a signer that is consulted for every outgoing request.
    #[cfg(feature = "__dnssec")]
    pub fn with_signer(self, signer: TSigner) -> Self {
        Self {
            signer: Some(signer),
            ..self
        }
    }

    /// Removes every request that was canceled by its requestor or whose
    /// deadline has elapsed, completing each one with the matching error.
    ///
    /// Every request's timeout is polled, including canceled ones.
    fn drop_cancelled(&mut self, cx: &mut Context<'_>) {
        let mut finished: Vec<(u16, NetError)> = Vec::new();

        for (&id, pending) in self.active_requests.iter_mut() {
            let canceled = pending.is_canceled();
            let timed_out = pending.poll_timeout(cx).is_ready();

            // A timeout takes precedence over a cancellation observed in the same pass.
            if timed_out {
                debug!("request timed out: {}", id);
                finished.push((id, NetError::Timeout));
            } else if canceled {
                finished.push((id, NetError::from("requestor canceled")));
            }
        }

        for (id, error) in finished {
            if let Some(pending) = self.active_requests.remove(&id) {
                pending.complete_with_error(error);
            }
        }
    }

    /// Draws random ids until one is found that no active request is using.
    ///
    /// # Errors
    ///
    /// Returns an error if 100 consecutive draws all collide with active requests.
    fn next_random_query_id(&self) -> Result<u16, NetError> {
        let mut rng = rand::rng();

        (0..100)
            .map(|_| rng.random::<u16>())
            .find(|candidate| !self.active_requests.contains_key(candidate))
            .ok_or_else(|| NetError::from("id space exhausted, consider filing an issue"))
    }

    /// Completes every active request with a clone of `error` and empties the map.
    fn stream_closed_close_all(&mut self, error: NetError) {
        debug!(%error, addr = %self.stream.name_server_addr());

        self.active_requests
            .drain()
            .for_each(|(_, pending)| pending.complete_with_error(error.clone()));
    }
}
impl<S: DnsClientStream> DnsRequestSender for DnsMultiplexer<S> {
    /// Assigns a fresh id to `request`, optionally signs it, serializes it and
    /// sends it through the stream handle.
    ///
    /// The returned stream carries the responses for this request. If the
    /// request could not be sent, the returned stream is built from the error
    /// instead (`NetError::Busy` when the in-flight limit is reached).
    ///
    /// # Panics
    ///
    /// Panics if called after the multiplexer has been shut down.
    fn send_message(&mut self, request: DnsRequest) -> DnsResponseStream {
        if self.is_shutdown {
            panic!("can not send messages after stream is shutdown")
        }

        if self.active_requests.len() >= self.max_active_requests {
            return NetError::Busy.into();
        }

        let query_id = match self.next_random_query_id() {
            Err(e) => return e.into(),
            Ok(id) => id,
        };

        let (mut message, _) = request.into_parts();
        message.metadata.id = query_id;

        #[cfg(feature = "__dnssec")]
        let verifier = match &self.signer {
            Some(signer) if signer.should_sign_message(&message) => {
                match message.finalize(signer, S::Time::current_time()) {
                    Ok(answer_verifier) => answer_verifier,
                    Err(e) => {
                        debug!("could not sign message: {}", e);
                        return NetError::from(e).into();
                    }
                }
            }
            _ => None,
        };

        // The deadline and the response channel are created before serialization.
        let timeout = S::Time::delay_for(self.timeout_duration);
        let (complete, receiver) = mpsc::channel(QUERY_RESPONSE_BUFFER_SIZE);
        let active_request = ActiveRequest::new(
            complete,
            message.id,
            timeout,
            #[cfg(feature = "__dnssec")]
            verifier,
        );

        let buffer = match message.to_vec() {
            Ok(buffer) => buffer,
            Err(error) => {
                debug!(
                    id = %active_request.request_id(),
                    %error,
                    "error message"
                );
                return NetError::from(error).into();
            }
        };

        debug!(id = %active_request.request_id(), "sending message");
        let serial_message = SerialMessage::new(buffer, self.stream.name_server_addr());
        debug!(
            "final message: {}",
            serial_message
                .to_message()
                .expect("bizarre we just made this message")
        );

        // Only track the request once it has actually been handed to the stream.
        if let Err(err) = self.stream_handle.send(serial_message) {
            return err.into();
        }
        self.active_requests
            .insert(active_request.request_id(), active_request);

        receiver.into()
    }

    /// Marks the multiplexer as shut down.
    fn shutdown(&mut self) {
        self.is_shutdown = true;
    }

    /// Reports whether the multiplexer has been marked as shut down.
    fn is_shutdown(&self) -> bool {
        self.is_shutdown
    }
}
impl<S: DnsClientStream> Stream for DnsMultiplexer<S> {
    type Item = Result<(), NetError>;

    /// Drives the multiplexer: expires canceled/timed-out requests, then reads
    /// up to `QOS_MAX_RECEIVE_MSGS` messages from the underlying stream and
    /// forwards each decoded response to the request with the matching id.
    ///
    /// Returns `Poll::Ready(None)` once the multiplexer is shut down with no
    /// active requests left, or when the underlying stream ends or yields an
    /// error (all active requests are then completed with that error).
    /// Otherwise returns `Poll::Pending`.
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        // Always expire canceled and timed-out requests first.
        self.drop_cancelled(cx);

        if self.is_shutdown && self.active_requests.is_empty() {
            debug!("stream is done: {}", self.stream.name_server_addr());
            return Poll::Ready(None);
        }

        // Bound the number of messages handled per poll so that a busy stream
        // cannot monopolize the task.
        let mut messages_received = 0;
        for _ in 0..QOS_MAX_RECEIVE_MSGS {
            match self.stream.poll_next_unpin(cx) {
                Poll::Ready(Some(Ok(buffer))) => {
                    // Count messages actually received. Recording the loop index
                    // here would top out at `QOS_MAX_RECEIVE_MSGS - 1`, so the
                    // limit check below could never succeed.
                    messages_received += 1;

                    match DnsResponse::from_buffer(buffer.into_parts().0) {
                        Ok(response) => match self.active_requests.entry(response.id) {
                            Entry::Occupied(mut request_entry) => {
                                let active_request = request_entry.get_mut();
                                #[cfg(feature = "__dnssec")]
                                if let Some(verifier) = &mut active_request.verifier {
                                    ignore_send(
                                        active_request.completion.try_send(
                                            verifier
                                                .verify(response.as_buffer())
                                                .map_err(NetError::from),
                                        ),
                                    );
                                } else {
                                    ignore_send(active_request.completion.try_send(Ok(response)));
                                }
                                #[cfg(not(feature = "__dnssec"))]
                                ignore_send(active_request.completion.try_send(Ok(response)));
                            }
                            Entry::Vacant(..) => debug!("unexpected request_id: {}", response.id),
                        },
                        Err(error) => debug!(%error, "error decoding message"),
                    }
                }
                Poll::Ready(err) => {
                    let err = match err {
                        Some(Err(e)) => e,
                        None => NetError::from("stream closed"),
                        // `Some(Ok(_))` is handled by the arm above.
                        _ => unreachable!(),
                    };
                    self.stream_closed_close_all(err);
                    self.is_shutdown = true;
                    return Poll::Ready(None);
                }
                Poll::Pending => break,
            }
        }

        // If the limit was hit, the last poll of the underlying stream returned
        // `Ready`, so it may not have registered this task's waker. Wake
        // explicitly so the remaining messages are processed on the next turn.
        if messages_received == QOS_MAX_RECEIVE_MSGS {
            cx.waker().wake_by_ref();
        }

        // Return pending to keep the driver task alive.
        Poll::Pending
    }
}
/// Upper bound on the number of messages read from the underlying stream in a
/// single `poll_next` call.
const QOS_MAX_RECEIVE_MSGS: usize = 100;
/// Buffer size of the channel created for each request's responses.
const QUERY_RESPONSE_BUFFER_SIZE: usize = 8;
#[cfg(test)]
mod test {
    use core::net::{Ipv4Addr, Ipv6Addr, SocketAddr};

    use futures_util::{
        future::{self, BoxFuture},
        ready,
        stream::TryStreamExt,
    };
    use test_support::subscribe;

    use super::*;
    use crate::proto::op::{DnsRequestOptions, Message, Query};
    use crate::proto::rr::rdata::{NS, SOA};
    use crate::proto::rr::{DNSClass, Name, RData, Record, RecordType};
    use crate::proto::serialize::binary::BinEncodable;
    use crate::xfer::{DnsClientStream, StreamReceiver};

    /// Client stream that replays canned responses, stamping each with the id
    /// of the first request observed on `receiver`.
    struct MockClientStream {
        /// Canned responses, stored reversed so `pop()` yields them in order.
        messages: Vec<Message>,
        addr: SocketAddr,
        /// Id learned from the first outgoing request, once seen.
        id: Option<u16>,
        /// Receiving side of the multiplexer's stream handle.
        receiver: Option<StreamReceiver>,
    }

    impl MockClientStream {
        fn new(
            mut messages: Vec<Message>,
            addr: SocketAddr,
        ) -> BoxFuture<'static, Result<Self, NetError>> {
            // Reverse once so that responses can be popped in their original order.
            messages.reverse();
            let stream = Self {
                messages,
                addr,
                id: None,
                receiver: None,
            };
            Box::pin(future::ok(stream))
        }
    }

    impl Stream for MockClientStream {
        type Item = Result<SerialMessage, NetError>;

        fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
            let id = match self.id {
                Some(known) => known,
                None => {
                    // Wait for the first outgoing request to learn which id to answer with.
                    let serial = ready!(
                        self.receiver
                            .as_mut()
                            .expect("should only be polled after receiver has been set")
                            .poll_next_unpin(cx)
                    );
                    let learned = serial.unwrap().to_message().unwrap().id;
                    self.id = Some(learned);
                    learned
                }
            };

            match self.messages.pop() {
                Some(mut message) => {
                    message.metadata.id = id;
                    let bytes = message.to_bytes().unwrap();
                    Poll::Ready(Some(Ok(SerialMessage::new(bytes, self.addr))))
                }
                None => Poll::Pending,
            }
        }
    }

    impl DnsClientStream for MockClientStream {
        type Time = crate::runtime::TokioTime;

        fn name_server_addr(&self) -> SocketAddr {
            self.addr
        }
    }

    /// Builds a multiplexer over a mock stream that answers with `mock_response`.
    async fn get_mocked_multiplexer(
        mock_response: Vec<Message>,
    ) -> DnsMultiplexer<MockClientStream> {
        let addr = SocketAddr::from(([127, 0, 0, 1], 1234));
        let stream = MockClientStream::new(mock_response, addr).await.unwrap();
        let (handle, receiver) = BufDnsStreamHandle::new(addr);

        let mut multiplexer =
            DnsMultiplexer::new(stream, handle).with_timeout(Duration::from_millis(100));
        multiplexer.stream.receiver = Some(receiver);
        multiplexer
    }

    /// An `A` query for `www.example.com.` together with its single response.
    fn a_query_answer() -> (DnsRequest, Vec<Message>) {
        let name = Name::from_ascii("www.example.com.").unwrap();

        let mut question = Query::query(name.clone(), RecordType::A);
        question.set_query_class(DNSClass::IN);

        let mut request = Message::query();
        request.metadata.recursion_desired = true;
        request.add_query(question);

        let answer = Record::from_rdata(
            name,
            86400,
            RData::A(Ipv4Addr::new(93, 184, 215, 14).into()),
        );
        let mut response = request.clone().into_response();
        response.add_answer(answer);

        (
            DnsRequest::new(request, DnsRequestOptions::default()),
            vec![response],
        )
    }

    /// An `AXFR` query message for `example.com.`.
    fn axfr_query() -> Message {
        let mut question = Query::query(
            Name::from_ascii("example.com.").unwrap(),
            RecordType::AXFR,
        );
        question.set_query_class(DNSClass::IN);

        let mut msg = Message::query();
        msg.metadata.recursion_desired = true;
        msg.add_query(question);
        msg
    }

    /// Builds an `NS` record for `origin` pointing at `host`.
    fn ns_record(origin: &Name, host: &str) -> Record {
        Record::from_rdata(
            origin.clone(),
            86400,
            RData::NS(NS(Name::parse(host, None).unwrap())),
        )
    }

    /// The records of a zone transfer for `example.com.`, bracketed by the SOA.
    fn axfr_response() -> Vec<Record> {
        let origin = Name::from_ascii("example.com.").unwrap();

        let soa = Record::from_rdata(
            origin.clone(),
            3600,
            RData::SOA(SOA::new(
                Name::parse("sns.dns.icann.org.", None).unwrap(),
                Name::parse("noc.dns.icann.org.", None).unwrap(),
                2015082403,
                7200,
                3600,
                1209600,
                3600,
            )),
        );
        let ns_a = ns_record(&origin, "a.iana-servers.net.");
        let ns_b = ns_record(&origin, "b.iana-servers.net.");
        let a = Record::from_rdata(
            origin.clone(),
            86400,
            RData::A(Ipv4Addr::new(93, 184, 215, 14).into()),
        );
        let aaaa = Record::from_rdata(
            origin,
            86400,
            RData::AAAA(
                Ipv6Addr::new(
                    0x2606, 0x2800, 0x21f, 0xcb07, 0x6820, 0x80da, 0xaf6b, 0x8b2c,
                )
                .into(),
            ),
        );

        vec![soa.clone(), ns_a, ns_b, a, aaaa, soa]
    }

    /// An `AXFR` query whose whole answer arrives in one response message.
    fn axfr_query_answer() -> (DnsRequest, Vec<Message>) {
        let query = axfr_query();

        let mut response = query.clone().into_response();
        response.insert_answers(axfr_response());

        (
            DnsRequest::new(query, DnsRequestOptions::default()),
            vec![response],
        )
    }

    /// An `AXFR` query whose answer is split across two response messages.
    fn axfr_query_answer_multi() -> (DnsRequest, Vec<Message>) {
        let query = axfr_query();

        let mut head = axfr_response();
        let tail = head.split_off(3);

        let mut first = query.clone().into_response();
        first.insert_answers(head);
        let mut second = query.clone().into_response();
        second.insert_answers(tail);

        (
            DnsRequest::new(query, DnsRequestOptions::default()),
            vec![first, second],
        )
    }

    /// Sends `query` through a mocked multiplexer that replies with `answers`
    /// and collects every response while driving the multiplexer.
    async fn exchange(query: DnsRequest, answers: Vec<Message>) -> Vec<DnsResponse> {
        let mut multiplexer = get_mocked_multiplexer(answers).await;
        let responses = multiplexer.send_message(query);

        tokio::select! {
            _ = multiplexer.next() => {
                panic!("should never end")
            },
            collected = responses.try_collect::<Vec<_>>() => collected.unwrap(),
        }
    }

    #[tokio::test]
    async fn test_multiplexer_a() {
        subscribe();
        let (query, answer) = a_query_answer();

        let response = exchange(query, answer).await;

        assert_eq!(response.len(), 1);
    }

    #[tokio::test]
    async fn test_multiplexer_axfr() {
        subscribe();
        let (query, answer) = axfr_query_answer();

        let response = exchange(query, answer).await;

        assert_eq!(response.len(), 1);
        assert_eq!(response[0].answers.len(), axfr_response().len());
    }

    #[tokio::test]
    async fn test_multiplexer_axfr_multi() {
        subscribe();
        let (query, answer) = axfr_query_answer_multi();

        let response = exchange(query, answer).await;

        assert_eq!(response.len(), 2);
        let total_answers: usize = response.iter().map(|m| m.answers.len()).sum();
        assert_eq!(total_answers, axfr_response().len());
    }
}