use core::{
future::Future,
pin::Pin,
task::{Context, Poll},
time::Duration,
};
use futures_util::{
ready,
stream::{Stream, StreamExt},
};
use tracing::debug;
use crate::{
error::NetError,
proto::{
ProtoError,
op::{
DEFAULT_MAX_PAYLOAD_LEN, DnsRequest, DnsRequestOptions, DnsResponse, Edns, Message,
OpCode, Query, update_message,
},
rr::{DNSClass, Name, RData, Record, RecordSet, RecordType, rdata::SOA},
},
runtime::RuntimeProvider,
xfer::{
BufDnsStreamHandle, DnsClientStream, DnsExchange, DnsExchangeBackground, DnsExchangeSend,
DnsHandle, DnsMultiplexer, DnsRequestSender,
},
};
#[cfg(all(feature = "__dnssec", feature = "tokio"))]
pub(crate) mod dnssec_client;
#[cfg(all(feature = "__dnssec", feature = "tokio"))]
pub use dnssec_client::{AsyncSecureClientBuilder, DnssecClient};
mod memoize_client_handle;
pub use memoize_client_handle::MemoizeClientHandle;
mod rc_stream;
#[cfg(test)]
mod tests;
/// DNS client that forwards every request to a [`DnsExchange`].
///
/// `Client` implements [`DnsHandle`], so the request-building helpers of
/// `ClientHandle` operate on it through that trait.
#[derive(Clone)]
pub struct Client<P> {
    /// Exchange that `DnsHandle::send` hands each request to.
    exchange: DnsExchange<P>,
    /// EDNS preference reported by `is_using_edns`. `from_sender` initialises it
    /// to `true`; `enable_edns` / `disable_edns` toggle it.
    use_edns: bool,
}
impl<P: RuntimeProvider> Client<P> {
    /// Creates a client over `stream` with a five second request timeout.
    ///
    /// Shorthand for [`Client::with_timeout`] with `Duration::from_secs(5)`.
    pub fn new<S: DnsClientStream>(
        stream: S,
        stream_handle: BufDnsStreamHandle,
    ) -> (Self, DnsExchangeBackground<DnsMultiplexer<S>, P::Timer>) {
        let default_timeout = Duration::from_secs(5);
        Self::with_timeout(stream, stream_handle, default_timeout)
    }

    /// Creates a client over `stream`, multiplexing requests through a
    /// [`DnsMultiplexer`] configured with `timeout_duration`.
    pub fn with_timeout<S: DnsClientStream>(
        stream: S,
        stream_handle: BufDnsStreamHandle,
        timeout_duration: Duration,
    ) -> (Self, DnsExchangeBackground<DnsMultiplexer<S>, P::Timer>) {
        let multiplexer = DnsMultiplexer::new(stream, stream_handle);
        Self::from_sender(multiplexer.with_timeout(timeout_duration))
    }

    /// Creates a client from any [`DnsRequestSender`], with EDNS enabled.
    ///
    /// Returns the client together with the background value produced by
    /// [`DnsExchange::from_stream`].
    /// NOTE(review): the background value presumably has to be spawned or polled
    /// for requests to complete; confirm against `DnsExchangeBackground`.
    pub fn from_sender<S: DnsRequestSender>(
        sender: S,
    ) -> (Self, DnsExchangeBackground<S, P::Timer>) {
        let (exchange, background) = DnsExchange::from_stream(sender);
        let client = Self {
            exchange,
            use_edns: true,
        };
        (client, background)
    }

    /// Makes `is_using_edns` report `true`, so requests built afterwards use EDNS.
    pub fn enable_edns(&mut self) {
        self.use_edns = true;
    }

    /// Makes `is_using_edns` report `false`, so requests built afterwards omit EDNS.
    pub fn disable_edns(&mut self) {
        self.use_edns = false;
    }
}
impl<P: RuntimeProvider> DnsHandle for Client<P> {
    type Runtime = P;
    type Response = DnsExchangeSend<P>;

    /// Returns the EDNS flag stored on this client.
    fn is_using_edns(&self) -> bool {
        self.use_edns
    }

    /// Hands `request` straight to the underlying exchange.
    fn send(&self, request: DnsRequest) -> Self::Response {
        self.exchange.send(request)
    }
}
// Every `DnsHandle` gets the `ClientHandle` helper methods for free.
impl<H: DnsHandle> ClientHandle for H {}
/// Request-building helpers available on every [`DnsHandle`].
///
/// Each method assembles a request and submits it through the handle. All
/// methods except `zone_transfer` pass along the handle's EDNS preference, as
/// reported by [`DnsHandle::is_using_edns`].
pub trait ClientHandle: 'static + Clone + DnsHandle + Send {
    /// Queries `name` for `query_type` records in `query_class`.
    ///
    /// The query goes through `lookup`, with `use_edns` taken from the handle.
    fn query(
        &mut self,
        name: Name,
        query_class: DNSClass,
        query_type: RecordType,
    ) -> ClientResponse<<Self as DnsHandle>::Response> {
        let options = {
            let mut opts = DnsRequestOptions::default();
            opts.use_edns = self.is_using_edns();
            opts
        };
        let mut query = Query::query(name, query_type);
        query.set_query_class(query_class);
        ClientResponse(self.lookup(query, options))
    }

    /// Sends a NOTIFY (RFC 1996) for `name` / `query_type` in `query_class`.
    ///
    /// When `rrset` is supplied, its records are placed in the answer section.
    /// When EDNS is enabled, the message carries EDNS with a maximum payload of
    /// `DEFAULT_MAX_PAYLOAD_LEN` and version 0.
    fn notify<R>(
        &mut self,
        name: Name,
        query_class: DNSClass,
        query_type: RecordType,
        rrset: Option<R>,
    ) -> ClientResponse<<Self as DnsHandle>::Response>
    where
        R: Into<RecordSet>,
    {
        debug!("notifying: {} {:?}", name, query_type);

        // The single question entry of the NOTIFY.
        let mut query = Query::new();
        query
            .set_name(name)
            .set_query_class(query_class)
            .set_query_type(query_type);

        let mut message = Message::query();
        message.metadata.op_code = OpCode::Notify;
        if self.is_using_edns() {
            let edns = message.edns.get_or_insert_with(Edns::new);
            edns.set_max_payload(DEFAULT_MAX_PAYLOAD_LEN).set_version(0);
        }
        message.add_query(query);

        if let Some(records) = rrset {
            message.add_answers(records.into());
        }
        ClientResponse(self.send(DnsRequest::from(message)))
    }

    /// Sends the update built by `update_message::create` for `rrset` in
    /// `zone_origin`.
    fn create<R>(
        &mut self,
        rrset: R,
        zone_origin: Name,
    ) -> ClientResponse<<Self as DnsHandle>::Response>
    where
        R: Into<RecordSet>,
    {
        let message = update_message::create(rrset.into(), zone_origin, self.is_using_edns());
        ClientResponse(self.send(DnsRequest::from(message)))
    }

    /// Sends the update built by `update_message::append` for `rrset` in
    /// `zone_origin`. `must_exist` is forwarded unchanged.
    fn append<R>(
        &mut self,
        rrset: R,
        zone_origin: Name,
        must_exist: bool,
    ) -> ClientResponse<<Self as DnsHandle>::Response>
    where
        R: Into<RecordSet>,
    {
        let message = update_message::append(
            rrset.into(),
            zone_origin,
            must_exist,
            self.is_using_edns(),
        );
        ClientResponse(self.send(DnsRequest::from(message)))
    }

    /// Sends the update built by `update_message::compare_and_swap`, replacing
    /// `current` with `new` in `zone_origin`.
    fn compare_and_swap<C, N>(
        &mut self,
        current: C,
        new: N,
        zone_origin: Name,
    ) -> ClientResponse<<Self as DnsHandle>::Response>
    where
        C: Into<RecordSet>,
        N: Into<RecordSet>,
    {
        let message = update_message::compare_and_swap(
            current.into(),
            new.into(),
            zone_origin,
            self.is_using_edns(),
        );
        ClientResponse(self.send(DnsRequest::from(message)))
    }

    /// Sends the update built by `update_message::delete_by_rdata` for `rrset`
    /// in `zone_origin`.
    fn delete_by_rdata<R>(
        &mut self,
        rrset: R,
        zone_origin: Name,
    ) -> ClientResponse<<Self as DnsHandle>::Response>
    where
        R: Into<RecordSet>,
    {
        let message =
            update_message::delete_by_rdata(rrset.into(), zone_origin, self.is_using_edns());
        ClientResponse(self.send(DnsRequest::from(message)))
    }

    /// Sends the update built by `update_message::delete_rrset` for `record`
    /// in `zone_origin`.
    ///
    /// # Panics
    ///
    /// Panics if `zone_origin.zone_of(&record.name)` is false.
    fn delete_rrset(
        &mut self,
        record: Record,
        zone_origin: Name,
    ) -> ClientResponse<<Self as DnsHandle>::Response> {
        assert!(zone_origin.zone_of(&record.name));
        let message = update_message::delete_rrset(record, zone_origin, self.is_using_edns());
        ClientResponse(self.send(DnsRequest::from(message)))
    }

    /// Sends the update built by `update_message::delete_all` for
    /// `name_of_records` (class `dns_class`) in `zone_origin`.
    ///
    /// # Panics
    ///
    /// Panics if `zone_origin.zone_of(&name_of_records)` is false.
    fn delete_all(
        &mut self,
        name_of_records: Name,
        zone_origin: Name,
        dns_class: DNSClass,
    ) -> ClientResponse<<Self as DnsHandle>::Response> {
        assert!(zone_origin.zone_of(&name_of_records));
        let edns = self.is_using_edns();
        let message = update_message::delete_all(name_of_records, zone_origin, dns_class, edns);
        ClientResponse(self.send(DnsRequest::from(message)))
    }

    /// Requests a transfer of `zone_origin`.
    ///
    /// Supplying `last_soa` allows an incremental reply. The returned stream
    /// stops after the response that completes (or invalidates) the transfer.
    /// The handle's EDNS preference is not consulted here.
    fn zone_transfer(
        &mut self,
        zone_origin: Name,
        last_soa: Option<SOA>,
    ) -> ClientStreamXfr<<Self as DnsHandle>::Response> {
        let incremental = last_soa.is_some();
        let request = DnsRequest::from(update_message::zone_transfer(zone_origin, last_soa));
        ClientStreamXfr::new(self.send(request), incremental)
    }
}
/// Thin stream wrapper that passes through every item of the inner response
/// stream unchanged.
#[must_use = "stream do nothing unless polled"]
pub struct ClientStreamingResponse<R>(pub(crate) R)
where
    R: Stream<Item = Result<DnsResponse, ProtoError>> + Send + Unpin + 'static;

impl<R> Stream for ClientStreamingResponse<R>
where
    R: Stream<Item = Result<DnsResponse, ProtoError>> + Send + Unpin + 'static,
{
    type Item = Result<DnsResponse, ProtoError>;

    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        // `R: Unpin`, so the inner stream can be polled through a plain `&mut`;
        // both `Pending` and `Ready` results are forwarded as-is.
        let inner = &mut self.0;
        inner.poll_next_unpin(cx)
    }
}
/// Future that resolves to the first item of the inner response stream.
///
/// If the stream finishes without yielding anything, the future resolves to
/// `Err(NetError::Timeout)`.
#[must_use = "futures do nothing unless polled"]
pub struct ClientResponse<R>(pub(crate) R)
where
    R: Stream<Item = Result<DnsResponse, NetError>> + Send + Unpin + 'static;

impl<R> Future for ClientResponse<R>
where
    R: Stream<Item = Result<DnsResponse, NetError>> + Send + Unpin + 'static,
{
    type Output = Result<DnsResponse, NetError>;

    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let first = ready!(self.0.poll_next_unpin(cx));
        // An exhausted stream without any response is surfaced as a timeout.
        Poll::Ready(first.unwrap_or_else(|| Err(NetError::Timeout)))
    }
}
/// Stream of zone-transfer responses that stops on its own once the transfer
/// is complete.
///
/// Each response's answer records are fed to an internal state machine, which
/// detects the closing SOA (or an invalid transfer) and ends the stream.
#[must_use = "stream do nothing unless polled"]
pub struct ClientStreamXfr<R>
where
    R: Stream<Item = Result<DnsResponse, NetError>> + Send + Unpin + 'static,
{
    state: ClientStreamXfrState<R>,
}

impl<R> ClientStreamXfr<R>
where
    R: Stream<Item = Result<DnsResponse, NetError>> + Send + Unpin + 'static,
{
    /// Wraps `inner`. `maybe_incr` says whether an incremental (IXFR-style)
    /// reply is acceptable.
    fn new(inner: R, maybe_incr: bool) -> Self {
        let state = ClientStreamXfrState::Start { inner, maybe_incr };
        Self { state }
    }
}
/// Progress of a zone transfer tracked by `ClientStreamXfr`.
///
/// `process` moves between these states while inspecting the answer records of
/// each response, in order.
#[derive(Debug)]
enum ClientStreamXfrState<R> {
    /// No record seen yet.
    Start {
        /// Underlying response stream.
        inner: R,
        /// Whether an incremental reply is acceptable (`zone_transfer` sets this
        /// when a previous SOA was supplied).
        maybe_incr: bool,
    },
    /// Opening SOA seen; the next record decides how the transfer continues.
    Second {
        inner: R,
        /// Serial of the opening SOA.
        expected_serial: u32,
        maybe_incr: bool,
    },
    /// Non-incremental transfer in progress, waiting for a closing SOA.
    Axfr {
        inner: R,
        /// Serial of the opening SOA. NOTE(review): carried along but not compared
        /// against the closing SOA by `process`.
        expected_serial: u32,
    },
    /// Incremental transfer in progress.
    Ixfr {
        inner: R,
        /// Parity flag toggled by every SOA record; starts `true` on entry, and
        /// the transfer can only finish while it is `true`.
        even: bool,
        /// Serial of the opening SOA; an SOA with this serial can end the transfer.
        expected_serial: u32,
    },
    /// Transfer finished or rejected; `poll_next` yields `None` from now on.
    Ended,
    /// Placeholder stored while `process` has moved the previous state out.
    Invalid,
}
impl<R> ClientStreamXfrState<R> {
    /// Returns the wrapped response stream.
    ///
    /// # Panics
    ///
    /// Panics (`unreachable!`) in `Ended` and `Invalid`, which hold no stream.
    /// `ClientStreamXfr::poll_next` returns early on `Ended` before calling this.
    fn inner(&mut self) -> &mut R {
        use ClientStreamXfrState::*;
        match self {
            Start { inner, .. } => inner,
            Second { inner, .. } => inner,
            Axfr { inner, .. } => inner,
            Ixfr { inner, .. } => inner,
            Ended | Invalid => unreachable!(),
        }
    }
    /// Advances the state machine over the answer records of one response.
    ///
    /// `poll_next` calls this once per successful response, and the state
    /// persists between calls. An empty slice leaves the state unchanged.
    ///
    /// # Errors
    ///
    /// Returns an error when the records cannot form a valid transfer: records
    /// trailing the closing SOA, or an incremental-looking reply when
    /// `maybe_incr` is false. Every error path leaves the state as `Ended`.
    ///
    /// # Panics
    ///
    /// Panics (`unreachable!`) if called while in `Ended` or `Invalid`.
    fn process(&mut self, answers: &[Record]) -> Result<(), NetError> {
        use ClientStreamXfrState::*;
        // Serial of an SOA record, or `None` for any other record type.
        fn get_serial(r: &Record) -> Option<u32> {
            match &r.data {
                RData::SOA(soa) => Some(soa.serial),
                _ => None,
            }
        }
        if answers.is_empty() {
            return Ok(());
        }
        // Take the current state by value, leaving `Invalid` behind. Every arm
        // below writes the next state back into `*self` before returning.
        match core::mem::replace(self, Invalid) {
            Start { inner, maybe_incr } => {
                if let Some(expected_serial) = get_serial(&answers[0]) {
                    // Opening SOA: remember its serial, then handle the rest.
                    *self = Second {
                        inner,
                        maybe_incr,
                        expected_serial,
                    };
                    self.process(&answers[1..])
                } else {
                    // First record is not an SOA: stop tracking, without error.
                    *self = Ended;
                    Ok(())
                }
            }
            Second {
                inner,
                maybe_incr,
                expected_serial,
            } => {
                if let Some(serial) = get_serial(&answers[0]) {
                    if serial == expected_serial {
                        // The opening SOA is followed directly by an SOA with the
                        // same serial: nothing in between, transfer complete.
                        *self = Ended;
                        if answers.len() == 1 {
                            Ok(())
                        } else {
                            Err("invalid zone transfer, contains trailing records".into())
                        }
                    } else if maybe_incr {
                        // A second SOA with a different serial starts the
                        // incremental difference sequences.
                        *self = Ixfr {
                            inner,
                            expected_serial,
                            even: true,
                        };
                        self.process(&answers[1..])
                    } else {
                        *self = Ended;
                        Err("invalid zone transfer, expected AXFR, got IXFR".into())
                    }
                } else {
                    // Second record is not an SOA: treat it as a full transfer.
                    // The record itself is skipped; it cannot affect the SOA
                    // count used in the `Axfr` state.
                    *self = Axfr {
                        inner,
                        expected_serial,
                    };
                    self.process(&answers[1..])
                }
                // NOTE(review): a response containing only the opening SOA stays
                // in `Second`, so `poll_next` keeps polling the inner stream.
                // Confirm this is intended for single-SOA "up to date" IXFR
                // replies (RFC 1995).
            }
            Axfr {
                inner,
                expected_serial,
            } => {
                // The closing SOA is the only SOA allowed in the remainder.
                let soa_count = answers
                    .iter()
                    .filter(|a| a.record_type() == RecordType::SOA)
                    .count();
                match soa_count {
                    0 => {
                        // Still mid-transfer.
                        *self = Axfr {
                            inner,
                            expected_serial,
                        };
                        Ok(())
                    }
                    1 => {
                        // Exactly one SOA: it must be the last record.
                        // NOTE(review): its serial is not checked against
                        // `expected_serial`; RFC 5936 expects the closing SOA to
                        // match the opening one. Confirm whether this is intended.
                        *self = Ended;
                        match answers.last().map(|r| r.record_type()) {
                            Some(RecordType::SOA) => Ok(()),
                            _ => Err("invalid zone transfer, contains trailing records".into()),
                        }
                    }
                    _ => {
                        *self = Ended;
                        Err("invalid zone transfer, contains trailing records".into())
                    }
                }
            }
            Ixfr {
                inner,
                even,
                expected_serial,
            } => {
                // Toggle the parity once per SOA record in this chunk (IXFR
                // difference sequences are delimited by SOA records).
                let even = answers
                    .iter()
                    .fold(even, |even, a| even ^ (a.record_type() == RecordType::SOA));
                // Finished only when the parity is even and this chunk's last
                // record is an SOA carrying the opening serial. `answers` is
                // non-empty here because of the early return above.
                if even {
                    if let Some(serial) = get_serial(answers.last().unwrap()) {
                        if serial == expected_serial {
                            *self = Ended;
                            return Ok(());
                        }
                    }
                }
                *self = Ixfr {
                    inner,
                    even,
                    expected_serial,
                };
                Ok(())
            }
            Ended | Invalid => {
                // `poll_next` never calls `process` after `Ended`, and `Invalid`
                // exists only transiently inside this function.
                unreachable!();
            }
        }
    }
}
impl<R> Stream for ClientStreamXfr<R>
where
R: Stream<Item = Result<DnsResponse, NetError>> + Send + Unpin + 'static,
{
type Item = Result<DnsResponse, NetError>;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
use ClientStreamXfrState::*;
if matches!(self.state, Ended) {
return Poll::Ready(None);
}
let message = ready!(self.state.inner().poll_next_unpin(cx)).map(|response| {
let ok = response?;
self.state.process(&ok.answers)?;
Ok(ok)
});
Poll::Ready(message)
}
}