use core::fmt::{self, Display};
use core::net::SocketAddr;
use core::pin::Pin;
use core::task::{Context, Poll};
use core::time::Duration;
use std::collections::HashSet;
use std::sync::Arc;
use futures_util::{FutureExt, Stream, StreamExt, pin_mut, stream::FuturesUnordered};
use tracing::{debug, trace, warn};
use crate::error::NetError;
use crate::proto::op::{DEFAULT_RETRY_FLOOR, DnsRequest, DnsResponse, Message, SerialMessage};
#[cfg(feature = "__dnssec")]
use crate::proto::rr::TSigner;
use crate::runtime::{DnsUdpSocket, RuntimeProvider, Spawn, Time};
use crate::udp::MAX_RECEIVE_BUFFER_SIZE;
use crate::udp::udp_stream::NextRandomUdpSocket;
use crate::xfer::{DnsExchange, DnsRequestSender, DnsResponseStream};
/// A DNS-over-UDP request sender bound to a single name server.
///
/// Each `send_message` call sends the request from a socket obtained through
/// `NextRandomUdpSocket`, launching additional overlapping attempts on a retry interval
/// (up to `max_retries` attempts in total), with the whole exchange bounded by `timeout`.
/// Construct it with [`UdpClientStream::builder`].
#[must_use = "futures do nothing unless polled"]
pub struct UdpClientStream<P> {
    /// Address every request is sent to; datagrams from any other address are ignored.
    name_server: SocketAddr,
    /// Overall deadline wrapped around all attempts of a single request.
    timeout: Duration,
    /// Set by `shutdown`; sending afterwards panics and the `Stream` impl yields `None`.
    is_shutdown: bool,
    /// Optional transaction signer, applied only to messages it says it should sign.
    #[cfg(feature = "__dnssec")]
    signer: Option<TSigner>,
    /// Optional local bind address forwarded to `NextRandomUdpSocket`.
    bind_addr: Option<SocketAddr>,
    /// Local ports forwarded to `NextRandomUdpSocket` as ports to avoid.
    avoid_local_ports: Arc<HashSet<u16>>,
    /// Flag forwarded to `NextRandomUdpSocket` (presumably lets the OS choose the
    /// source port — confirm against `NextRandomUdpSocket`).
    os_port_selection: bool,
    /// Runtime provider, cloned into each request for socket creation.
    provider: P,
    /// Maximum number of attempts launched per request, counting the first one.
    max_retries: u8,
    /// Lower bound applied to the request's own retry interval.
    retry_interval_floor: Duration,
}
impl<P: RuntimeProvider> UdpClientStream<P> {
pub fn builder(name_server: SocketAddr, provider: P) -> UdpClientStreamBuilder<P> {
UdpClientStreamBuilder {
name_server,
timeout: None,
#[cfg(feature = "__dnssec")]
signer: None,
bind_addr: None,
avoid_local_ports: Arc::default(),
os_port_selection: false,
provider,
max_retries: 3,
retry_interval_floor: DEFAULT_RETRY_FLOOR,
}
}
}
impl<P> Display for UdpClientStream<P> {
    /// Renders the stream as `UDP(<name server socket address>)`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_fmt(format_args!("UDP({})", self.name_server))
    }
}
impl<P: RuntimeProvider> DnsRequestSender for UdpClientStream<P> {
    /// Sends `request` to the configured name server.
    ///
    /// Attempts are raced by `retry`: a new attempt is launched each retry interval
    /// (the request's own interval, raised to at least `retry_interval_floor`), up to
    /// `max_retries` attempts in total, and the whole exchange is bounded by `timeout`.
    ///
    /// # Panics
    ///
    /// Panics if the stream has already been shut down.
    fn send_message(&mut self, request: DnsRequest) -> DnsResponseStream {
        assert!(
            !self.is_shutdown,
            "can not send messages after stream is shutdown"
        );

        // Never retry more aggressively than the configured floor allows.
        let retry_interval = request
            .options()
            .retry_interval
            .max(self.retry_interval_floor);
        let attempts = usize::from(self.max_retries);
        let udp_request = UdpRequest::new(request, self);

        P::Timer::timeout(
            self.timeout,
            retry::<P>(udp_request, retry_interval, attempts),
        )
        .into()
    }

    /// Marks the stream as shut down; any later `send_message` call panics.
    fn shutdown(&mut self) {
        self.is_shutdown = true;
    }

    /// Reports whether `shutdown` has been called.
    fn is_shutdown(&self) -> bool {
        self.is_shutdown
    }
}
impl<P> Stream for UdpClientStream<P> {
    type Item = Result<(), NetError>;

    /// Never pending: yields `Ok(())` on every poll until the stream is shut down,
    /// after which it reports end-of-stream.
    fn poll_next(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let tick = (!self.is_shutdown).then_some(Ok(()));
        Poll::Ready(tick)
    }
}
/// Everything a single UDP attempt needs, captured from the stream and the request so
/// that `retry` can run the attempt repeatedly.
struct UdpRequest<P> {
    /// Local ports forwarded to `NextRandomUdpSocket` as ports to avoid.
    avoid_local_ports: Arc<HashSet<u16>>,
    /// Destination address; responses from other addresses are ignored.
    name_server: SocketAddr,
    /// The request to send; each attempt works on its own clone of it.
    request: DnsRequest,
    /// Runtime provider, cloned for every socket that is opened.
    provider: P,
    /// Present only when the stream has a signer that elects to sign this request.
    #[cfg(feature = "__dnssec")]
    signer: Option<TSigner>,
    /// Value of `P::Timer::current_time()` at construction, passed to the signer
    /// (units not visible here — presumably seconds; confirm).
    #[cfg(feature = "__dnssec")]
    now: u64,
    /// Optional local bind address forwarded to `NextRandomUdpSocket`.
    bind_addr: Option<SocketAddr>,
    /// Flag forwarded to `NextRandomUdpSocket`.
    os_port_selection: bool,
    /// Taken from the request options; enables the strict question-case check and
    /// the restoration of the original query in responses.
    case_randomization: bool,
    /// Receive buffer length: the request's max payload, capped at
    /// `MAX_RECEIVE_BUFFER_SIZE`.
    recv_buf_size: usize,
}
impl<P: RuntimeProvider> UdpRequest<P> {
fn new(request: DnsRequest, stream: &UdpClientStream<P>) -> Self {
Self {
avoid_local_ports: stream.avoid_local_ports.clone(),
recv_buf_size: MAX_RECEIVE_BUFFER_SIZE.min(request.max_payload() as usize),
case_randomization: request.options().case_randomization,
name_server: stream.name_server,
#[cfg(feature = "__dnssec")]
signer: match &stream.signer {
Some(signer) if signer.should_sign_message(&request) => stream.signer.clone(),
_ => None,
},
request,
provider: stream.provider.clone(),
#[cfg(feature = "__dnssec")]
now: P::Timer::current_time(),
bind_addr: stream.bind_addr,
os_port_selection: stream.os_port_selection,
}
}
}
impl<P: RuntimeProvider> Request for UdpRequest<P> {
    /// Performs one complete UDP exchange for this request.
    ///
    /// With the `__dnssec` feature and a selected signer, a clone of the request is
    /// signed first. The request is then serialized and sent from a socket obtained via
    /// `NextRandomUdpSocket`. Up to `MAX_RECEIVE_ATTEMPTS` datagrams are read looking
    /// for an acceptable response. Datagrams are logged and skipped when they come from
    /// another address, carry a different message id, or echo a question section that
    /// does not match the one sent.
    ///
    /// # Errors
    ///
    /// - signing, serialization, socket creation, send or receive failures;
    /// - a short send (not every byte of the message was written);
    /// - a datagram from the name server that cannot be decoded as a DNS response;
    /// - `NetError::QueryCaseMismatch` when case randomization is on and the response
    ///   echoes the question with different letter case;
    /// - "udp receive attempts exceeded" when no acceptable response arrives within
    ///   `MAX_RECEIVE_ATTEMPTS` datagrams.
    async fn send(&self) -> Result<DnsResponse, NetError> {
        /// Number of datagrams read from the socket before this attempt gives up.
        const MAX_RECEIVE_ATTEMPTS: usize = 3;

        let original_query = self.request.original_query();
        #[cfg_attr(not(feature = "__dnssec"), expect(unused_mut))]
        let mut request = self.request.clone();

        // Sign this attempt's copy; the resulting verifier (if any) checks the response.
        #[cfg(feature = "__dnssec")]
        let mut verifier = None;
        #[cfg(feature = "__dnssec")]
        if let Some(signer) = &self.signer {
            match request.finalize(signer, self.now) {
                Ok(answer_verifier) => verifier = answer_verifier,
                Err(e) => {
                    debug!("could not sign message: {}", e);
                    return Err(e.into());
                }
            }
        }

        let request_bytes = request.to_vec()?;
        let msg_id = request.id;
        let msg = SerialMessage::new(request_bytes, self.name_server);
        let addr = msg.addr();
        // Decode the exact bytes that go on the wire once. The result is logged here,
        // and its question section is what responses are validated against below.
        let final_message = msg.to_message()?;
        debug!(%final_message, "final message");

        let socket = NextRandomUdpSocket::new(
            addr,
            self.bind_addr,
            self.avoid_local_ports.clone(),
            self.os_port_selection,
            self.provider.clone(),
        )
        .await?;

        let bytes = msg.bytes();
        let len_sent: usize = socket.send_to(bytes, addr).await?;
        if bytes.len() != len_sent {
            return Err(NetError::from(format!(
                "Not all bytes of message sent, {} of {}",
                len_sent,
                bytes.len()
            )));
        }

        trace!(
            recv_buf_size = self.recv_buf_size,
            "creating UDP receive buffer"
        );
        let mut recv_buf = vec![0; self.recv_buf_size];
        for _ in 0..MAX_RECEIVE_ATTEMPTS {
            let (len, src) = socket.recv_from(&mut recv_buf).await?;
            let response_bytes = &recv_buf[..len];

            // Drop datagrams that did not come from the name server we queried, before
            // paying for an allocation or a decode.
            if src.ip().to_canonical() != addr.ip().to_canonical() || src.port() != addr.port() {
                warn!(
                    "ignoring response from {}:{} because it does not match name_server: {}:{}.",
                    src.ip().to_canonical(),
                    src.port(),
                    addr.ip().to_canonical(),
                    addr.port(),
                );
                continue;
            }

            // NOTE(review): an undecodable datagram from the right address aborts this
            // attempt instead of being skipped like the mismatches below.
            let mut response = DnsResponse::from_buffer(response_bytes.to_vec())?;
            if msg_id != response.id {
                warn!(
                    "expected message id: {} got: {}, dropped",
                    msg_id, response.id
                );
                continue;
            }

            // Validate the echoed question section against what was actually sent.
            let request_queries = &final_message.queries;
            let response_queries = &mut response.queries;
            let question_matches = response_queries
                .iter()
                .all(|elem| request_queries.contains(elem));
            if self.case_randomization
                && question_matches
                && !response_queries.iter().all(|elem| {
                    request_queries
                        .iter()
                        .any(|req_q| req_q == elem && req_q.name().eq_case(elem.name()))
                })
            {
                warn!(
                    "case of question section did not match: we expected '{request_queries:?}', but received '{response_queries:?}' from server {src}"
                );
                return Err(NetError::QueryCaseMismatch);
            }
            if !question_matches {
                warn!(
                    "detected forged question section: we expected '{request_queries:?}', but received '{response_queries:?}' from server {src}"
                );
                continue;
            }

            // With case randomization on, replace matching questions with
            // `original_query` (presumably the pre-randomization spelling).
            if self.case_randomization {
                if let Some(original_query) = original_query {
                    for response_query in response_queries.iter_mut() {
                        if response_query == original_query {
                            *response_query = original_query.clone();
                        }
                    }
                }
            }

            debug!("received message id: {}", response.id);
            #[cfg(feature = "__dnssec")]
            if let Some(mut verifier) = verifier {
                return Ok(verifier.verify(response_bytes)?);
            }
            return Ok(response);
        }

        Err(NetError::from("udp receive attempts exceeded"))
    }
}
/// Configuration collected before building a [`UdpClientStream`].
pub struct UdpClientStreamBuilder<P> {
    /// Address requests are sent to.
    name_server: SocketAddr,
    /// Per-request timeout; `None` becomes 5 seconds in `build`.
    timeout: Option<Duration>,
    /// Optional transaction signer.
    #[cfg(feature = "__dnssec")]
    signer: Option<TSigner>,
    /// Optional local bind address forwarded to the socket factory.
    bind_addr: Option<SocketAddr>,
    /// Local ports forwarded to the socket factory as ports to avoid.
    avoid_local_ports: Arc<HashSet<u16>>,
    /// Flag forwarded to the socket factory (presumably OS-chosen source ports; confirm).
    os_port_selection: bool,
    /// Runtime provider used for sockets and for spawning the exchange's background task.
    provider: P,
    /// Maximum number of attempts launched per request, counting the first one.
    max_retries: u8,
    /// Lower bound on the interval between attempts.
    retry_interval_floor: Duration,
}
impl<P: RuntimeProvider> UdpClientStreamBuilder<P> {
    /// Sets the overall per-request timeout.
    ///
    /// `None` (the default) makes [`build`](Self::build) fall back to 5 seconds.
    pub fn with_timeout(mut self, timeout: Option<Duration>) -> Self {
        self.timeout = timeout;
        self
    }

    /// Sets the transaction signer; it is only applied to messages it elects to sign.
    #[cfg(feature = "__dnssec")]
    pub fn with_signer(mut self, signer: Option<TSigner>) -> Self {
        self.signer = signer;
        self
    }

    /// Sets the optional local bind address forwarded to the socket factory.
    pub fn with_bind_addr(mut self, bind_addr: Option<SocketAddr>) -> Self {
        self.bind_addr = bind_addr;
        self
    }

    /// Sets the local ports the socket factory is told to avoid.
    pub fn avoid_local_ports(mut self, avoid_local_ports: Arc<HashSet<u16>>) -> Self {
        self.avoid_local_ports = avoid_local_ports;
        self
    }

    /// Sets the `os_port_selection` flag forwarded to the socket factory.
    pub fn with_os_port_selection(mut self, os_port_selection: bool) -> Self {
        self.os_port_selection = os_port_selection;
        self
    }

    /// Sets the maximum number of attempts launched per request, counting the first.
    ///
    /// A value of 0 behaves like 1, because the first attempt is always sent.
    pub fn with_max_retries(mut self, max_retries: u8) -> Self {
        self.max_retries = max_retries;
        self
    }

    /// Sets the minimum interval between attempts, in milliseconds.
    pub fn with_retry_interval_floor(mut self, floor: u64) -> Self {
        self.retry_interval_floor = Duration::from_millis(floor);
        self
    }

    /// Builds the stream, wraps it in a [`DnsExchange`], and spawns the exchange's
    /// background task on a handle created from the provider.
    pub fn exchange(self) -> DnsExchange<P> {
        let mut handle = self.provider.create_handle();
        let stream = self.build();
        let (exchange, bg) = DnsExchange::from_stream(stream);
        handle.spawn_bg(bg);
        exchange
    }

    /// Finalizes the configuration into a [`UdpClientStream`].
    ///
    /// A missing timeout defaults to 5 seconds.
    pub fn build(self) -> UdpClientStream<P> {
        UdpClientStream {
            name_server: self.name_server,
            timeout: self.timeout.unwrap_or(Duration::from_secs(5)),
            is_shutdown: false,
            #[cfg(feature = "__dnssec")]
            signer: self.signer,
            bind_addr: self.bind_addr,
            // The builder is consumed, so the Arc can be moved rather than cloned.
            avoid_local_ports: self.avoid_local_ports,
            os_port_selection: self.os_port_selection,
            provider: self.provider,
            max_retries: self.max_retries,
            retry_interval_floor: self.retry_interval_floor,
        }
    }
}
/// Races overlapping attempts of `request` and returns whichever finishes first.
///
/// The first attempt starts immediately. Each time `retry_interval_time` elapses without
/// a result, another attempt is launched, until `max_tasks` attempts have been started
/// (a `max_tasks` of 0 behaves like 1). The first completed attempt decides the outcome,
/// whether it succeeded or failed; the attempts still in flight are dropped on return.
async fn retry<Provider: RuntimeProvider>(
    request: impl Request,
    retry_interval_time: Duration,
    max_tasks: usize,
) -> Result<DnsResponse, NetError> {
    let mut in_flight = FuturesUnordered::new();
    let backoff = Provider::Timer::delay_for(retry_interval_time).fuse();
    pin_mut!(backoff);

    in_flight.push(request.send());
    let mut launched = 1;

    loop {
        futures_util::select! {
            finished = in_flight.next() => {
                return finished.unwrap_or_else(|| Err(NetError::from("no tasks successful")));
            }
            _ = &mut backoff => {
                // Once the cap is reached the fused timer stays terminated and is no
                // longer polled; we just wait for an in-flight attempt.
                if launched < max_tasks {
                    launched += 1;
                    in_flight.push(request.send());
                    backoff.set(Provider::Timer::delay_for(retry_interval_time).fuse());
                }
            }
        }
    }
}
/// A single, repeatable attempt at a DNS exchange; `retry` calls `send` several times
/// on the same value to launch overlapping attempts.
trait Request {
    /// Performs one attempt and resolves to its response or error.
    async fn send(&self) -> Result<DnsResponse, NetError>;
}
#[cfg(all(test, feature = "tokio"))]
mod tests {
    #![allow(clippy::dbg_macro, clippy::print_stdout)]
    // NOTE(review): no `dbg!`/`println!` is visible in this module; the allow above may
    // be stale — confirm before removing.
    use core::{
        net::{IpAddr, Ipv4Addr, Ipv6Addr},
        sync::atomic::{AtomicU8, Ordering},
    };
    use std::io;
    use test_support::subscribe;
    use tokio::time::sleep;
    use super::*;
    use crate::{
        proto::op::ResponseCode,
        runtime::{TokioRuntimeProvider, TokioTime},
        udp::tests::{
            udp_client_stream_bad_id_test, udp_client_stream_response_limit_test,
            udp_client_stream_test,
        },
    };

    // The next six tests run the shared UDP client scenarios from `crate::udp::tests`
    // against IPv4 and IPv6 loopback with the Tokio runtime provider.

    #[tokio::test]
    async fn test_udp_client_stream_ipv4() {
        subscribe();
        udp_client_stream_test(IpAddr::V4(Ipv4Addr::LOCALHOST), TokioRuntimeProvider::new()).await;
    }

    #[tokio::test]
    async fn test_udp_client_stream_ipv4_bad_id() {
        subscribe();
        udp_client_stream_bad_id_test(IpAddr::V4(Ipv4Addr::LOCALHOST), TokioRuntimeProvider::new())
            .await;
    }

    #[tokio::test]
    async fn test_udp_client_stream_ipv4_resp_limit() {
        subscribe();
        udp_client_stream_response_limit_test(
            IpAddr::V4(Ipv4Addr::LOCALHOST),
            TokioRuntimeProvider::new(),
        )
        .await;
    }

    #[tokio::test]
    async fn test_udp_client_stream_ipv6() {
        subscribe();
        udp_client_stream_test(
            IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1)),
            TokioRuntimeProvider::new(),
        )
        .await;
    }

    #[tokio::test]
    async fn test_udp_client_stream_ipv6_bad_id() {
        subscribe();
        udp_client_stream_bad_id_test(
            IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1)),
            TokioRuntimeProvider::new(),
        )
        .await;
    }

    #[tokio::test]
    async fn test_udp_client_stream_ipv6_resp_limit() {
        subscribe();
        udp_client_stream_response_limit_test(
            IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1)),
            TokioRuntimeProvider::new(),
        )
        .await;
    }

    // Exercises `retry` with Tokio's clock paused (`start_paused = true`), so the
    // attempt counts below follow deterministically from the configured delays.
    #[tokio::test(start_paused = true)]
    async fn retry_handler_test() -> Result<(), NetError> {
        let mut message = Message::query().into_response();
        message.metadata.response_code = ResponseCode::NoError;

        // An immediately-ready request: its first attempt's response is returned.
        let ret = retry::<TokioRuntimeProvider>(
            FixedResponse {
                response: DnsResponse::from_message(message.clone())?,
            },
            Duration::from_millis(200),
            5,
        )
        .await?;
        assert_eq!(ret.response_code, ResponseCode::NoError);

        // Response after 100ms, retry interval 200ms: the first attempt completes before
        // the retry timer fires, so exactly one attempt is made.
        let (req, tries) = DelayedResponse::new(
            DnsResponse::from_message(message.clone()).unwrap(),
            Duration::from_millis(100),
            Arc::new(AtomicU8::new(0)),
        );
        retry::<TokioRuntimeProvider>(req, Duration::from_millis(200), 5).await?;
        assert_eq!(tries.load(Ordering::Relaxed), 1);

        // Response after 1500ms, retry interval 200ms: attempts start at 0, 200, 400,
        // 600 and 800ms, reaching the cap of 5 before any attempt completes.
        let (req, tries) = DelayedResponse::new(
            DnsResponse::from_message(message.clone()).unwrap(),
            Duration::from_millis(1500),
            Arc::new(AtomicU8::new(0)),
        );
        retry::<TokioRuntimeProvider>(req, Duration::from_millis(200), 5).await?;
        assert_eq!(tries.load(Ordering::Relaxed), 5);

        // Response after 1000ms, but an outer 500ms timeout: attempts start at 0, 200
        // and 400ms before the timeout fires, so 3 attempts are recorded.
        let (req, tries) = DelayedResponse::new(
            DnsResponse::from_message(message.clone()).unwrap(),
            Duration::from_millis(1000),
            Arc::new(AtomicU8::new(0)),
        );
        let timer_ret = TokioTime::timeout(
            Duration::from_millis(500),
            retry::<TokioRuntimeProvider>(req, Duration::from_millis(200), 5),
        )
        .await;
        if let Err(e) = timer_ret {
            assert_eq!(e.kind(), io::ErrorKind::TimedOut);
        } else {
            panic!("timer did not timeout");
        }
        assert_eq!(tries.load(Ordering::Relaxed), 3);
        Ok(())
    }

    // Test `Request` that immediately returns a clone of a fixed response.
    struct FixedResponse {
        response: DnsResponse,
    }

    impl Request for FixedResponse {
        async fn send(&self) -> Result<DnsResponse, NetError> {
            Ok(self.response.clone())
        }
    }

    // Test `Request` that counts each attempt in `counter`, then sleeps for `delay`
    // before returning a clone of `response`.
    struct DelayedResponse {
        response: DnsResponse,
        delay: Duration,
        counter: Arc<AtomicU8>,
    }

    impl DelayedResponse {
        // Returns the request together with a handle to the same attempt counter.
        fn new(
            response: DnsResponse,
            delay: Duration,
            counter: Arc<AtomicU8>,
        ) -> (Self, Arc<AtomicU8>) {
            (
                Self {
                    response,
                    delay,
                    counter: counter.clone(),
                },
                counter,
            )
        }
    }

    impl Request for DelayedResponse {
        async fn send(&self) -> Result<DnsResponse, NetError> {
            // Count the attempt before sleeping, so in-flight attempts are counted too.
            let _ = self.counter.fetch_add(1, Ordering::Relaxed);
            sleep(self.delay).await;
            Ok(self.response.clone())
        }
    }
}