#![cfg(feature = "unstable-client-cache")]
use crate::base::iana::{Class, Opcode, OptRcode, Rtype};
use crate::base::name::ToName;
use crate::base::{
Header, Message, MessageBuilder, Name, ParsedName, StaticCompressor, Ttl,
};
use crate::dep::octseq::Octets;
use crate::net::client::request::{
ComposeRequest, Error, GetResponse, SendRequest,
};
use crate::rdata::AllRecordData;
use crate::utils::config::DefMinMax;
use bytes::Bytes;
use moka::future::Cache;
use std::boxed::Box;
use std::cmp::min;
use std::fmt::{Debug, Formatter};
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::time::Duration;
use std::vec::Vec;
use tokio::time::Instant;
/// Maximum number of cache entries (default, minimum, maximum).
const MAX_CACHE_ENTRIES: DefMinMax<u64> =
    DefMinMax::new(1_000, 1, 1_000_000_000);

/// Upper bound on the validity of any cached entry: 7 days by default,
/// at least 1 minute, at most 70 days.
const MAX_VALIDITY: DefMinMax<Duration> = DefMinMax::new(
    Duration::from_secs(7 * 24 * 60 * 60),
    Duration::from_secs(60),
    Duration::from_secs(70 * 24 * 60 * 60),
);

/// How long an error returned by the upstream transport is cached:
/// 30 seconds by default, between 1 second and 5 minutes.
const TRANSPORT_FAILURE_DURATION: DefMinMax<Duration> = DefMinMax::new(
    Duration::from_secs(30),
    Duration::from_secs(1),
    Duration::from_secs(300),
);

/// Validity cap for responses whose rcode is neither NOERROR nor NXDOMAIN:
/// 30 seconds by default, between 1 second and 5 minutes.
const MISC_ERROR_DURATION: DefMinMax<Duration> = DefMinMax::new(
    Duration::from_secs(30),
    Duration::from_secs(1),
    Duration::from_secs(300),
);

/// Validity cap for NXDOMAIN responses: 1 hour by default, between
/// 1 minute and 1 day.
const MAX_NXDOMAIN_VALIDITY: DefMinMax<Duration> = DefMinMax::new(
    Duration::from_secs(60 * 60),
    Duration::from_secs(60),
    Duration::from_secs(24 * 60 * 60),
);

/// Validity cap for NODATA responses: 1 hour by default, between
/// 1 minute and 1 day.
const MAX_NODATA_VALIDITY: DefMinMax<Duration> = DefMinMax::new(
    Duration::from_secs(60 * 60),
    Duration::from_secs(60),
    Duration::from_secs(24 * 60 * 60),
);

/// Validity cap for delegation (referral) responses.
const MAX_DELEGATION_VALIDITY: DefMinMax<Duration> = DefMinMax::new(
    Duration::from_secs(1_000_000),
    Duration::from_secs(60),
    Duration::from_secs(1_000_000_000),
);
/// Configuration of a caching [`Connection`].
///
/// Every setter passes its argument through the `limit` method of the
/// matching `DefMinMax` constant, so values are limited to the range that
/// constant allows.
#[derive(Clone, Debug)]
pub struct Config {
    /// Capacity handed to the underlying cache.
    max_cache_entries: u64,
    /// Upper bound on the validity of any cached response.
    max_validity: Duration,
    /// How long an error from the upstream transport is cached.
    transport_failure_duration: Duration,
    /// Validity cap for rcodes other than NOERROR and NXDOMAIN.
    misc_error_duration: Duration,
    /// Validity cap for NXDOMAIN responses.
    max_nxdomain_validity: Duration,
    /// Validity cap for NOERROR responses without an answer (NODATA).
    max_nodata_validity: Duration,
    /// Validity cap for NOERROR responses that are delegations.
    max_delegation_validity: Duration,
    /// Whether responses with the TC (truncated) bit may be cached.
    cache_truncated: bool,
}

impl Config {
    /// Returns a configuration with every option at its default value.
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the maximum number of cache entries.
    pub fn set_max_cache_entries(&mut self, value: u64) {
        self.max_cache_entries = MAX_CACHE_ENTRIES.limit(value);
    }

    /// Sets the upper bound on the validity of any cached response.
    pub fn set_max_validity(&mut self, value: Duration) {
        self.max_validity = MAX_VALIDITY.limit(value);
    }

    /// Sets how long an upstream transport error is cached.
    pub fn set_transport_failure_duration(&mut self, value: Duration) {
        self.transport_failure_duration =
            TRANSPORT_FAILURE_DURATION.limit(value);
    }

    /// Sets the validity cap for rcodes other than NOERROR and NXDOMAIN.
    pub fn set_misc_error_duration(&mut self, value: Duration) {
        self.misc_error_duration = MISC_ERROR_DURATION.limit(value);
    }

    /// Sets the validity cap for NXDOMAIN responses.
    pub fn set_max_nxdomain_validity(&mut self, value: Duration) {
        self.max_nxdomain_validity = MAX_NXDOMAIN_VALIDITY.limit(value);
    }

    /// Sets the validity cap for NODATA responses.
    pub fn set_max_nodata_validity(&mut self, value: Duration) {
        self.max_nodata_validity = MAX_NODATA_VALIDITY.limit(value);
    }

    /// Sets the validity cap for delegation responses.
    pub fn set_max_delegation_validity(&mut self, value: Duration) {
        self.max_delegation_validity = MAX_DELEGATION_VALIDITY.limit(value);
    }

    /// Sets whether truncated (TC) responses may be cached.
    pub fn set_cache_truncated(&mut self, value: bool) {
        self.cache_truncated = value;
    }
}

impl Default for Config {
    /// Uses the default of each `DefMinMax` constant; truncated responses
    /// are not cached.
    fn default() -> Self {
        Config {
            max_cache_entries: MAX_CACHE_ENTRIES.default(),
            max_validity: MAX_VALIDITY.default(),
            transport_failure_duration: TRANSPORT_FAILURE_DURATION.default(),
            misc_error_duration: MISC_ERROR_DURATION.default(),
            max_nxdomain_validity: MAX_NXDOMAIN_VALIDITY.default(),
            max_nodata_validity: MAX_NODATA_VALIDITY.default(),
            max_delegation_validity: MAX_DELEGATION_VALIDITY.default(),
            cache_truncated: false,
        }
    }
}
/// A transport that answers requests from a cache where possible and
/// forwards them to an upstream [`SendRequest`] transport otherwise.
#[derive(Clone)]
pub struct Connection<Upstream> {
    /// Transport used for requests the cache cannot answer.
    upstream: Upstream,
    /// Cached upstream results, keyed by question and relevant flags.
    cache: Cache<Key, Arc<Value>>,
    /// Cache configuration; cloned into every request.
    config: Config,
}

impl<Upstream> Connection<Upstream> {
    /// Creates a cache in front of `upstream` with the default [`Config`].
    pub fn new(upstream: Upstream) -> Self {
        Self::with_config(upstream, Default::default())
    }

    /// Creates a cache in front of `upstream` using `config`.
    ///
    /// The cache capacity is the configured maximum number of entries.
    pub fn with_config(upstream: Upstream, config: Config) -> Self {
        Self {
            upstream,
            cache: Cache::new(config.max_cache_entries),
            config,
        }
    }
}
impl<CR, Upstream > SendRequest<CR>
for Connection<Upstream >
where
CR: Clone + ComposeRequest + 'static,
Upstream: Clone + SendRequest<CR> + Send + Sync + 'static,
{
fn send_request(
&self,
request_msg: CR,
) -> Box<dyn GetResponse + Send + Sync> {
Box::new(Request::<CR, Upstream >::new(
request_msg,
self.upstream.clone(),
self.cache.clone(),
self.config.clone(),
))
}
}
/// A single request issued through a caching [`Connection`].
///
/// Created by `Connection::send_request`; the work happens when the
/// response is requested through [`GetResponse`].
pub struct Request<CR, Upstream >
where
    CR: Send + Sync,
    Upstream: Send + Sync,
{
    /// Current position in the request state machine.
    state: RequestState,
    /// The request as supplied by the caller.
    request_msg: CR,
    /// Transport used when the cache cannot answer.
    upstream: Upstream,
    /// Cache handle cloned from the connection.
    cache: Cache<Key, Arc<Value >>,
    /// Configuration cloned from the connection.
    config: Config,
}
/// Request processing: key construction, cache lookups (including deriving
/// entries from related keys) and cache insertion.
impl<CR, Upstream > Request<CR, Upstream >
where
    CR: Clone + ComposeRequest + Send + Sync,
    Upstream: SendRequest<CR> + Send + Sync,
{
    /// Creates a request in the `Init` state. Nothing is looked up or sent
    /// until the response is requested.
    fn new(
        request_msg: CR,
        upstream: Upstream,
        cache: Cache<Key, Arc<Value >>,
        config: Config,
    ) -> Request<CR, Upstream > {
        Self {
            state: RequestState::Init,
            request_msg,
            upstream,
            cache,
            config,
        }
    }

    /// Runs the request state machine until a result is available.
    ///
    /// From `Init`, a request is forwarded upstream without using the cache
    /// in any of these cases:
    /// - it does not have exactly one question;
    /// - its opcode is not `QUERY`;
    /// - its class is not `IN`.
    ///
    /// Otherwise a [`Key`] is built and the cache is consulted. A cached
    /// entry that is still valid is returned directly. On a miss, or when
    /// the entry has expired, the request is sent upstream and the result
    /// (response or error) is stored under the key before being returned.
    ///
    /// # Errors
    ///
    /// Returns errors from:
    /// - composing the request or parsing its question;
    /// - deriving cache entries;
    /// - the upstream transport;
    /// - computing the validity of the upstream result.
    async fn get_response_impl(&mut self) -> Result<Message<Bytes>, Error> {
        loop {
            match &mut self.state {
                RequestState::Init => {
                    let msg = self.request_msg.to_message()?;
                    let header = msg.header();
                    let opcode = header.opcode();
                    let mut question_section = msg.question();
                    let question = match question_section.next() {
                        None => {
                            // No question to key on: bypass the cache.
                            let request = self
                                .upstream
                                .send_request(self.request_msg.clone());
                            self.state =
                                RequestState::GetResponseNoCache(request);
                            continue;
                        }
                        Some(question) => question?,
                    };
                    if question_section.next().is_some() {
                        // More than one question: bypass the cache.
                        let request = self
                            .upstream
                            .send_request(self.request_msg.clone());
                        self.state =
                            RequestState::GetResponseNoCache(request);
                        continue;
                    }
                    let qname = question.qname();
                    let qclass = question.qclass();
                    let qtype = question.qtype();
                    if !(opcode == Opcode::QUERY && qclass == Class::IN) {
                        // Only standard queries in class IN are cached.
                        let request = self
                            .upstream
                            .send_request(self.request_msg.clone());
                        self.state =
                            RequestState::GetResponseNoCache(request);
                        continue;
                    }
                    let mut ad = header.ad();
                    let cd = header.cd();
                    let rd = header.rd();
                    // The DO bit comes from the EDNS OPT record, if any.
                    let dnssec_ok =
                        msg.opt().is_some_and(|opt| opt.dnssec_ok());
                    // A DO request is treated as also having AD set
                    // (AdDo::new maps DO to AdDo::Do regardless of AD).
                    if dnssec_ok && !ad {
                        ad = true;
                    }
                    let key =
                        Key::new(qname, qclass, qtype, ad, cd, dnssec_ok, rd);
                    let opt_ce = self.cache_lookup(&key).await?;
                    if let Some(value) = opt_ce {
                        // `Value::get_response` yields `None` for an expired
                        // entry; in that case fall through to the upstream.
                        let opt_response = value.get_response(qname);
                        if let Some(response) = opt_response {
                            return response;
                        }
                    }
                    let request =
                        self.upstream.send_request(self.request_msg.clone());
                    self.state = RequestState::GetResponse(key, request);
                    continue;
                }
                RequestState::GetResponse(key, request) => {
                    let response = request.get_response().await;
                    let key = key.clone();
                    // If the validity of the result cannot be computed, that
                    // error is returned instead of the upstream result.
                    let value = Arc::new(Value::new(
                        response.clone(),
                        &self.config,
                    )?);
                    self.cache_insert(key, value).await;
                    return response;
                }
                RequestState::GetResponseNoCache(request) => {
                    return request.get_response().await;
                }
            }
        }
    }

    /// Looks up `key`, falling back to entries derived from related keys
    /// (different RD, DO or AD bits). See `cache_lookup_rd_do_ad`.
    async fn cache_lookup(
        &self,
        key: &Key,
    ) -> Result<Option<Arc<Value >>, Error> {
        self.cache_lookup_rd_do_ad(key).await
    }

    /// Looks up `key` taking DO/AD fallbacks into account. If nothing is
    /// found and RD is clear in `key`, an entry stored for the same request
    /// with RD set is reused: RD is cleared in its header, the derived
    /// entry is stored under `key`, and it is returned.
    ///
    /// Entries for RD=1 are never derived from RD=0 entries.
    ///
    /// NOTE(review): the entry found under the alternative key may already
    /// be expired; it is still converted and stored, and the caller then
    /// falls back to the upstream.
    async fn cache_lookup_rd_do_ad(
        &self,
        key: &Key,
    ) -> Result<Option<Arc<Value >>, Error> {
        let opt_value = self.cache_lookup_do_ad(key).await?;
        if opt_value.is_some() || key.rd {
            return Ok(opt_value);
        }
        let mut alt_key = key.clone();
        alt_key.rd = true;
        let opt_value = self.cache_lookup_do_ad(&alt_key).await?;
        if let Some(value) = opt_value {
            let value = update_header(
                value,
                &self.config,
                |_hdr| true,
                |hdr| hdr.set_rd(false),
            )?;
            self.cache_insert(key.clone(), value.clone()).await;
            return Ok(Some(value));
        }
        Ok(opt_value)
    }

    /// Looks up `key` taking the AD fallback into account.
    ///
    /// If nothing is found and DO is clear in `key`, an entry stored with
    /// DO set is reused. RRSIG, NSEC and NSEC3 records are stripped from it,
    /// and AD is cleared if `key` does not have AD. The derived entry is
    /// stored under `key` and returned.
    ///
    /// This fallback is skipped when the query type is itself one of those
    /// DNSSEC types, because stripping them would remove the answer.
    async fn cache_lookup_do_ad(
        &self,
        key: &Key,
    ) -> Result<Option<Arc<Value >>, Error> {
        let opt_value = self.cache_lookup_ad(key).await?;
        if opt_value.is_some() || key.addo.dnssec_ok() {
            return Ok(opt_value);
        }
        if is_dnssec(key.qtype) {
            return Ok(None);
        }
        let mut alt_key = key.clone();
        alt_key.addo = AdDo::Do;
        let opt_value = self.cache.get(&alt_key).await;
        if let Some(value) = opt_value {
            let value = update_message(
                value,
                &self.config,
                |_hdr| true,
                |msg| remove_dnssec(msg, key.addo.ad()),
            )?;
            self.cache_insert(key.clone(), value.clone()).await;
            return Ok(Some(value));
        }
        Ok(opt_value)
    }

    /// Looks up `key` directly. If nothing is found and `key` has neither
    /// AD nor DO, an entry stored with AD (but not DO) is reused: the AD
    /// bit is cleared if it is set in the response, the derived entry is
    /// stored under `key`, and it is returned.
    async fn cache_lookup_ad(
        &self,
        key: &Key,
    ) -> Result<Option<Arc<Value >>, Error> {
        let opt_value = self.cache.get(key).await;
        if opt_value.is_some() || key.addo.ad() {
            return Ok(opt_value);
        }
        let mut alt_key = key.clone();
        alt_key.addo = AdDo::Ad;
        let opt_value = self.cache.get(&alt_key).await;
        if let Some(value) = opt_value {
            let value = update_header(
                value,
                &self.config,
                |hdr| hdr.ad(),
                |hdr| hdr.set_ad(false),
            )?;
            self.cache_insert(key.clone(), value.clone()).await;
            return Ok(Some(value));
        }
        Ok(opt_value)
    }

    /// Stores `value` under `key`.
    ///
    /// Values with zero validity are not stored. Before storing, the AA bit
    /// is cleared via `prepare_for_insert`. If that rewrite fails, a value
    /// holding the error (with the original creation time) is stored
    /// instead.
    async fn cache_insert(&self, key: Key, value: Arc<Value >) {
        if value.valid_for.is_zero() {
            return;
        }
        let value = match prepare_for_insert(value.clone(), &self.config) {
            Ok(value) => value,
            Err(e) => {
                Arc::new(
                    Value::new_from_value_and_response(
                        value,
                        Err(e),
                        &self.config,
                    )
                    .expect("value from error does not fail"),
                )
            }
        };
        self.cache.insert(key, value).await
    }
}
impl<CR, Upstream> Debug for Request<CR, Upstream>
where
    CR: Send + Sync,
    Upstream: Send + Sync,
{
    /// Formats the request opaquely; its state and contents are not shown.
    fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), core::fmt::Error> {
        let mut out = f.debug_struct("Request");
        out.field("fut", &format_args!("_"));
        out.finish()
    }
}
impl<CR, Upstream> GetResponse for Request<CR, Upstream>
where
    CR: Clone + ComposeRequest + Debug + Sync,
    Upstream: SendRequest<CR> + Send + Sync + 'static,
{
    /// Returns a boxed future that resolves to the cached or upstream
    /// response for this request.
    fn get_response(
        &mut self,
    ) -> Pin<
        Box<
            dyn Future<Output = Result<Message<Bytes>, Error>>
                + Send
                + Sync
                + '_,
        >,
    > {
        let fut = self.get_response_impl();
        Box::pin(fut)
    }
}
/// States of a [`Request`].
enum RequestState {
    /// The request has not been inspected yet.
    Init,
    /// A cacheable request was sent upstream; its result will be stored
    /// under the key.
    GetResponse(Key, Box<dyn GetResponse + Send + Sync>),
    /// A request that bypasses the cache was sent upstream.
    GetResponseNoCache(Box<dyn GetResponse + Send + Sync>),
}
/// Cache key for a request.
///
/// Besides the question (name, class, type), the key records the header and
/// EDNS bits that can change the response: AD/DO (combined in [`AdDo`]),
/// CD and RD.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
struct Key {
    /// Query name in canonical form.
    qname: Name<Bytes>,
    /// Query class.
    qclass: Class,
    /// Query type.
    qtype: Rtype,
    /// Combined AD header bit and EDNS DO bit.
    addo: AdDo,
    /// CD header bit.
    cd: bool,
    /// RD header bit.
    rd: bool,
}

impl Key {
    /// Builds a key, storing `qname` in canonical form so that names that
    /// differ only in spelling share an entry.
    fn new<TDN>(
        qname: TDN,
        qclass: Class,
        qtype: Rtype,
        ad: bool,
        cd: bool,
        dnssec_ok: bool,
        rd: bool,
    ) -> Key
    where
        TDN: ToName,
    {
        let canonical = qname.to_canonical_name();
        let addo = AdDo::new(ad, dnssec_ok);
        Key {
            qname: canonical,
            qclass,
            qtype,
            addo,
            cd,
            rd,
        }
    }
}
/// Combined state of the AD header bit and the EDNS DO bit in a cache key.
///
/// DO takes precedence: a request with DO set is `Do` whatever AD says, and
/// `Do` reports AD as set.
#[derive(Clone, Debug, Hash, PartialEq, Eq)]
enum AdDo {
    /// DO is set (AD implied).
    Do,
    /// AD is set, DO is not.
    Ad,
    /// Neither AD nor DO is set.
    None,
}

impl AdDo {
    /// Classifies the AD and DO bits of a request.
    fn new(ad: bool, dnssec_ok: bool) -> Self {
        match (dnssec_ok, ad) {
            (true, _) => AdDo::Do,
            (false, true) => AdDo::Ad,
            (false, false) => AdDo::None,
        }
    }

    /// Returns whether AD is (explicitly or implicitly) set.
    fn ad(&self) -> bool {
        matches!(self, AdDo::Do | AdDo::Ad)
    }

    /// Returns whether DO is set.
    fn dnssec_ok(&self) -> bool {
        matches!(self, AdDo::Do)
    }
}
#[derive(Debug)]
struct Value
{
created_at: Instant,
valid_for: Duration,
response: Result<Message<Bytes>, Error>,
}
impl Value
{
fn new(
response: Result<Message<Bytes>, Error>,
config: &Config,
) -> Result<Value , Error> {
Ok(Self {
created_at: Instant::now(),
valid_for: validity(&response, config)?,
response,
})
}
fn new_from_value_and_response(
val: Arc<Value >,
response: Result<Message<Bytes>, Error>,
config: &Config,
) -> Result<Value , Error> {
Ok(Self {
created_at: val.created_at, valid_for: validity(&response, config)?,
response,
})
}
fn get_response<TDN>(
&self,
orig_qname: TDN,
) -> Option<Result<Message<Bytes>, Error>>
where
TDN: ToName + Clone,
{
let elapsed = self.created_at.elapsed();
if elapsed > self.valid_for {
return None;
}
let secs = elapsed.as_secs() as u32;
let response = decrement_ttl(orig_qname, &self.response, secs);
Some(response)
}
}
/// Computes how long `response` may be served from the cache.
///
/// - An upstream error is kept for `transport_failure_duration`.
/// - A truncated response gets zero (is not cached) unless
///   `cache_truncated` is set.
/// - Otherwise the result is the smallest of three values:
///   - `max_validity`;
///   - a cap chosen by response kind (NODATA, delegation, NXDOMAIN, other
///     rcodes);
///   - the TTL of every record except OPT.
///
///   Unclassifiable NOERROR responses get zero.
///
/// # Errors
///
/// Fails if the message cannot be parsed while classifying it or scanning
/// its records.
fn validity(
    response: &Result<Message<Bytes>, Error>,
    config: &Config,
) -> Result<Duration, Error> {
    let msg = match response {
        Ok(msg) => msg,
        Err(_) => return Ok(config.transport_failure_duration),
    };
    if msg.header().tc() && !config.cache_truncated {
        return Ok(Duration::ZERO);
    }

    // Cap that depends on what kind of response this is.
    let kind_cap = match msg.opt_rcode() {
        OptRcode::NOERROR => match classify_no_error(msg)? {
            NoErrorType::Answer => config.max_validity,
            NoErrorType::NoData => config.max_nodata_validity,
            NoErrorType::Delegation => config.max_delegation_validity,
            NoErrorType::NoErrorWeird => Duration::ZERO,
        },
        OptRcode::NXDOMAIN => config.max_nxdomain_validity,
        _ => config.misc_error_duration,
    };
    let mut lowest = min(config.max_validity, kind_cap);

    // No record may be served past its own TTL.
    let mut answer = msg.answer()?;
    for record in &mut answer {
        let record = record?;
        lowest = min(lowest, Duration::from_secs(record.ttl().as_secs().into()));
    }
    let mut authority =
        answer.next_section()?.expect("section should be present");
    for record in &mut authority {
        let record = record?;
        lowest = min(lowest, Duration::from_secs(record.ttl().as_secs().into()));
    }
    let additional =
        authority.next_section()?.expect("section should be present");
    for record in additional {
        let record = record?;
        // The OPT pseudo-record's TTL field does not hold a TTL.
        if record.rtype() == Rtype::OPT {
            continue;
        }
        lowest = min(lowest, Duration::from_secs(record.ttl().as_secs().into()));
    }
    Ok(lowest)
}
/// Returns a copy of a cached `response` with every record TTL (except that
/// of OPT) reduced by `amount` seconds and the question name replaced by
/// `orig_qname`.
///
/// The question name is replaced because cache keys store the canonical
/// form of the name, while the caller should see its own spelling.
///
/// `amount` must not exceed any of the adjusted TTLs. `Value::get_response`
/// ensures this: it only calls this function while the entry's age is
/// within its validity, and the validity is capped by the smallest of these
/// TTLs.
///
/// # Errors
///
/// A cached error is returned as a clone. Parse errors in the cached
/// message are propagated.
fn decrement_ttl<TDN>(
    orig_qname: TDN,
    response: &Result<Message<Bytes>, Error>,
    amount: u32,
) -> Result<Message<Bytes>, Error>
where
    TDN: ToName + Clone,
{
    let msg = match response {
        Err(err) => return Err(err.clone()),
        Ok(msg) => msg,
    };
    let amount = Ttl::from_secs(amount);

    let mut target =
        MessageBuilder::from_target(StaticCompressor::new(Vec::new()))
            .expect("Vec is expected to have enough space");
    *target.header_mut() = msg.header();

    // Question: keep type and class, restore the caller's spelling.
    let source = msg.question();
    let mut target = target.question();
    for question in source {
        let question = question?;
        target
            .push((orig_qname.clone(), question.qtype(), question.qclass()))
            .expect("push failed");
    }

    let mut source = source.answer()?;
    let mut target = target.answer();
    for rr in &mut source {
        let mut rr = rr?
            .into_record::<AllRecordData<_, ParsedName<_>>>()?
            .expect("record expected");
        rr.set_ttl(rr.ttl() - amount);
        target.push(rr).expect("push failed");
    }

    let mut source =
        source.next_section()?.expect("section should be present");
    let mut target = target.authority();
    for rr in &mut source {
        let mut rr = rr?
            .into_record::<AllRecordData<_, ParsedName<_>>>()?
            .expect("record expected");
        rr.set_ttl(rr.ttl() - amount);
        target.push(rr).expect("push failed");
    }

    let source = source.next_section()?.expect("section should be present");
    let mut target = target.additional();
    for rr in source {
        let mut rr = rr?
            .into_record::<AllRecordData<_, ParsedName<_>>>()?
            .expect("record expected");
        // The OPT TTL field carries EDNS data, not a TTL.
        if rr.rtype() != Rtype::OPT {
            rr.set_ttl(rr.ttl() - amount);
        }
        target.push(rr).expect("push failed");
    }

    // Consume the builder instead of cloning it (buffer and compression
    // table) just to finish it.
    let msg = Message::<Bytes>::from_octets(target.finish().into_target().into())
        .expect("Message should be able to parse output from MessageBuilder");
    Ok(msg)
}
/// Returns a copy of `msg` without RRSIG, NSEC and NSEC3 records in the
/// answer, authority and additional sections.
///
/// This makes a response obtained with DO set usable for a request without
/// DO. The header is copied; the AD bit is cleared when `ad` is `false`.
/// The question section is copied unchanged.
///
/// # Errors
///
/// Parse errors in `msg` are propagated.
fn remove_dnssec(
    msg: &Message<Bytes>,
    ad: bool,
) -> Result<Message<Bytes>, Error> {
    let mut target =
        MessageBuilder::from_target(StaticCompressor::new(Vec::new()))
            .expect("Vec is expected to have enough space");
    *target.header_mut() = msg.header();
    if !ad {
        target.header_mut().set_ad(false);
    }

    let source = msg.question();
    let mut target = target.question();
    for question in source {
        target.push(question?).expect("push failed");
    }

    let mut source = source.answer()?;
    let mut target = target.answer();
    for rr in &mut source {
        let rr = rr?
            .into_record::<AllRecordData<_, ParsedName<_>>>()?
            .expect("record expected");
        if is_dnssec(rr.rtype()) {
            continue;
        }
        target.push(rr).expect("push error");
    }

    let mut source =
        source.next_section()?.expect("section should be present");
    let mut target = target.authority();
    for rr in &mut source {
        let rr = rr?
            .into_record::<AllRecordData<_, ParsedName<_>>>()?
            .expect("record expected");
        if is_dnssec(rr.rtype()) {
            continue;
        }
        target.push(rr).expect("push error");
    }

    let source = source.next_section()?.expect("section should be present");
    let mut target = target.additional();
    for rr in source {
        let rr = rr?
            .into_record::<AllRecordData<_, ParsedName<_>>>()?
            .expect("record expected");
        if is_dnssec(rr.rtype()) {
            continue;
        }
        target.push(rr).expect("push error");
    }

    // Consume the builder instead of cloning it just to finish it.
    Ok(
        Message::<Bytes>::from_octets(target.finish().into_target().into())
            .expect(
                "Message should be able to parse output from MessageBuilder",
            ),
    )
}
/// Returns whether `rtype` is RRSIG, NSEC or NSEC3, the record types
/// stripped by `remove_dnssec`.
fn is_dnssec(rtype: Rtype) -> bool {
    [Rtype::RRSIG, Rtype::NSEC, Rtype::NSEC3].contains(&rtype)
}
/// Classification of a NOERROR response, used to choose its validity cap.
enum NoErrorType {
    /// The answer section has a record of the question's type and class.
    Answer,
    /// No matching answer, but the authority section has an SOA record.
    NoData,
    /// No matching answer and no SOA, but the authority section has NS
    /// records.
    Delegation,
    /// None of the above; `validity` gives such responses zero validity, so
    /// they are not cached.
    NoErrorWeird,
}
/// Classifies a NOERROR response.
///
/// The first question's type and class determine what counts as an answer.
/// If none is found, the authority section decides:
/// - an SOA record means NODATA;
/// - otherwise, NS records mean a delegation;
/// - otherwise the response is [`NoErrorType::NoErrorWeird`].
///
/// A response without a question cannot be classified and is reported as
/// `NoErrorWeird` (so it is not cached) instead of panicking. Responses
/// come from the network and must not be able to crash the cache.
///
/// # Errors
///
/// Fails if the question, answer or authority section cannot be parsed.
fn classify_no_error<Octs>(msg: &Message<Octs>) -> Result<NoErrorType, Error>
where
    Octs: Octets,
{
    let mut question_section = msg.question();
    let Some(question) = question_section.next() else {
        return Ok(NoErrorType::NoErrorWeird);
    };
    let question = question?;
    let qtype = question.qtype();
    let qclass = question.qclass();

    let mut answer = msg.answer()?;
    for rr in &mut answer {
        let rr = rr?;
        if rr.rtype() == qtype && rr.class() == qclass {
            return Ok(NoErrorType::Answer);
        }
    }

    let mut found_ns = false;
    let authority = answer.next_section()?.expect("section should be present");
    for rr in authority {
        let rr = rr?;
        if rr.class() != qclass {
            continue;
        }
        if rr.rtype() == Rtype::SOA {
            return Ok(NoErrorType::NoData);
        }
        if rr.rtype() == Rtype::NS {
            found_ns = true;
        }
    }

    if found_ns {
        Ok(NoErrorType::Delegation)
    } else {
        Ok(NoErrorType::NoErrorWeird)
    }
}
fn prepare_for_insert(
value: Arc<Value >,
config: &Config,
) -> Result<Arc<Value >, Error>
{
update_header(value, config, |hdr| hdr.aa(), |hdr| hdr.set_aa(false))
}
fn update_header(
value: Arc<Value >,
config: &Config,
hdrtst: fn(hdr: &Header) -> bool,
fhdr: fn(&mut Header) -> (),
) -> Result<Arc<Value >, Error>
{
update_message(value, config, hdrtst, |msg| {
let mut msg = Message::<Vec<u8>>::from_octets(msg.as_slice().into())?;
let hdr = msg.header_mut();
fhdr(hdr);
Ok(Message::<Bytes>::from_octets(msg.into_octets().into())?)
})
}
fn update_message< FmsgFn>(
value: Arc<Value >,
config: &Config,
hdrtst: fn(hdr: &Header) -> bool,
fmsg: FmsgFn,
) -> Result<Arc<Value >, Error>
where
FmsgFn: Fn(&Message<Bytes>) -> Result<Message<Bytes>, Error>,
{
Ok(match &value.response {
Err(_) => {
value
}
Ok(msg) => {
if hdrtst(&msg.header()) {
let msg = fmsg(msg)?;
Arc::new(Value::new_from_value_and_response(
value.clone(),
Ok(msg),
config,
)?)
} else {
value
}
}
})
}