use core::future::{ready, Future, Ready};
use core::marker::PhantomData;
use core::ops::ControlFlow;
use core::pin::Pin;
use std::boxed::Box;
use std::fmt::Debug;
use std::sync::Arc;
use bytes::Bytes;
use futures_util::stream::{once, Once, Stream};
use octseq::Octets;
use tracing::{error, info, warn};
use crate::base::iana::{Class, Opcode, OptRcode, Rcode};
use crate::base::message::CopyRecordsError;
use crate::base::message_builder::AdditionalBuilder;
use crate::base::name::Name;
use crate::base::net::IpAddr;
use crate::base::{
Message, ParsedName, Question, Rtype, Serial, StreamTarget, ToName,
};
use crate::net::server::message::Request;
use crate::net::server::middleware::stream::MiddlewareStream;
use crate::net::server::service::{CallResult, Service};
use crate::net::server::util::{mk_builder_for_target, mk_error_response};
use crate::rdata::{AllRecordData, ZoneRecordData};
/// A middleware service that intercepts DNS NOTIFY requests.
///
/// Requests whose opcode is NOTIFY and whose first question has QTYPE SOA
/// are reported to the configured [`Notifiable`] target and answered
/// directly by this middleware. All other requests are handed unchanged to
/// the next service in the chain.
#[derive(Clone, Debug)]
pub struct NotifyMiddlewareSvc<RequestOctets, NextSvc, RequestMeta, N>
where
NextSvc: Service<RequestOctets, RequestMeta> + Unpin + Clone,
NextSvc::Future: Sync + Unpin,
N: Notifiable + Clone + Sync + Send + 'static,
RequestOctets: Octets + Send + Sync + 'static + Clone,
RequestMeta: Clone + Default + Sync + Send + 'static,
for<'a> <RequestOctets as octseq::Octets>::Range<'a>: Send + Sync,
{
/// The service that receives every request this middleware does not
/// handle itself.
next_svc: NextSvc,
/// The target that is told about each relevant NOTIFY request.
notify_target: N,
/// Ties the otherwise unused `RequestOctets` and `RequestMeta` type
/// parameters to the struct.
_phantom: PhantomData<(RequestOctets, RequestMeta)>,
}
impl<RequestOctets, NextSvc, RequestMeta, N>
    NotifyMiddlewareSvc<RequestOctets, NextSvc, RequestMeta, N>
where
    NextSvc: Service<RequestOctets, RequestMeta> + Unpin + Clone,
    NextSvc::Future: Sync + Unpin,
    N: Notifiable + Clone + Sync + Send + 'static,
    RequestOctets: Octets + Send + Sync + 'static + Clone,
    RequestMeta: Clone + Default + Sync + Send + 'static,
    for<'a> <RequestOctets as octseq::Octets>::Range<'a>: Send + Sync,
{
    /// Creates a new NOTIFY middleware wrapping `next_svc`.
    ///
    /// `next_svc` is the service that requests are forwarded to when this
    /// middleware does not handle them itself; `notify_target` is the
    /// [`Notifiable`] implementation that is informed about each relevant
    /// NOTIFY request.
    #[must_use]
    pub fn new(next_svc: NextSvc, notify_target: N) -> Self {
        // The marker carries no data; its type is inferred from the field.
        let _phantom = PhantomData;
        Self {
            next_svc,
            notify_target,
            _phantom,
        }
    }
}
impl<RequestOctets, NextSvc, RequestMeta, N>
    NotifyMiddlewareSvc<RequestOctets, NextSvc, RequestMeta, N>
where
    NextSvc: Service<RequestOctets, RequestMeta> + Unpin + Clone,
    NextSvc::Future: Sync + Unpin,
    N: Notifiable + Clone + Sync + Send + 'static,
    RequestOctets: Octets + Send + Sync + 'static + Clone,
    RequestMeta: Clone + Default + Sync + Send + 'static,
    for<'a> <RequestOctets as octseq::Octets>::Range<'a>: Send + Sync,
{
    /// Inspects a request and, if it is a relevant NOTIFY, handles it.
    ///
    /// Returns `ControlFlow::Continue(())` when the request is not a NOTIFY
    /// for an SOA question, in which case the caller should pass the request
    /// on to the next service. Otherwise the `notify_target` is informed and
    /// `ControlFlow::Break` is returned carrying a single-item stream with
    /// the response to send:
    ///
    /// * NOTAUTH if the target reports [`NotifyError::NotAuthForZone`],
    /// * SERVFAIL if the target reports [`NotifyError::Other`],
    /// * FORMERR or SERVFAIL if the request could not be copied into a
    ///   response, or
    /// * a copy of the request with the QR and AA flags set, opcode NOTIFY
    ///   and rcode NOERROR on success.
    async fn preprocess(
        req: &Request<RequestOctets, RequestMeta>,
        notify_target: N,
    ) -> ControlFlow<Once<Ready<<NextSvc::Stream as Stream>::Item>>> {
        let msg = req.message();

        let Some(q) = Self::get_relevant_question(msg) else {
            return ControlFlow::Continue(());
        };

        let class = q.qclass();
        let apex_name = q.qname().to_name();
        let source = req.client_addr().ip();

        // If the sender included an answer section, look at the first
        // record only: when it is an SOA record its serial is passed along
        // as a hint. Any failure to parse it is silently treated as "no
        // serial supplied".
        let mut serial = None;
        if msg.header_counts().ancount() > 0 {
            if let Ok(mut answer) = msg.answer() {
                if let Some(Ok(record)) = answer.next() {
                    if let Ok(Some(record)) = record.to_record() {
                        if let ZoneRecordData::Soa(soa) = record.data() {
                            serial = Some(soa.serial());
                        }
                    }
                }
            }
        }

        info!(
            "NOTIFY received from {} for zone '{}' with serial {:?}",
            req.client_addr(),
            q.qname(),
            serial,
        );

        match notify_target
            .notify_zone_changed(class, &apex_name, serial, source)
            .await
        {
            Err(NotifyError::NotAuthForZone) => {
                warn!("Ignoring NOTIFY from {} for zone '{}': Not authoritative for zone",
                    req.client_addr(),
                    q.qname()
                );
                ControlFlow::Break(once(ready(Ok(CallResult::new(
                    mk_error_response(msg, OptRcode::NOTAUTH),
                )))))
            }

            Err(NotifyError::Other) => {
                error!(
                    "Error while processing NOTIFY from {} for zone '{}'.",
                    req.client_addr(),
                    q.qname()
                );
                ControlFlow::Break(once(ready(Ok(CallResult::new(
                    mk_error_response(msg, OptRcode::SERVFAIL),
                )))))
            }

            Ok(()) => match Self::copy_message(msg) {
                Ok(mut additional) => {
                    // The response mirrors the request, with the header
                    // adjusted to mark it as an authoritative, successful
                    // NOTIFY response.
                    let response_hdr = additional.header_mut();
                    response_hdr.set_opcode(Opcode::NOTIFY);
                    response_hdr.set_rcode(Rcode::NOERROR);
                    response_hdr.set_qr(true);
                    response_hdr.set_aa(true);

                    ControlFlow::Break(once(ready(Ok(CallResult::new(
                        additional,
                    )))))
                }

                Err(err) => {
                    // The request came from the network, so a record that
                    // fails to parse or does not fit into the response
                    // buffer must not bring the server down. Note that the
                    // notify target has already been informed at this point;
                    // only the response is affected.
                    let rcode = match err {
                        CopyRecordsError::Parse(_) => OptRcode::FORMERR,
                        CopyRecordsError::Push(_) => OptRcode::SERVFAIL,
                    };
                    error!(
                        "Unable to build response to NOTIFY from {} for zone '{}': {:?}",
                        req.client_addr(),
                        q.qname(),
                        err
                    );
                    ControlFlow::Break(once(ready(Ok(CallResult::new(
                        mk_error_response(msg, rcode),
                    )))))
                }
            },
        }
    }

    /// Returns the first question of `msg` if this middleware should act
    /// on the message.
    ///
    /// A message is relevant when its opcode is NOTIFY and its first
    /// question has QTYPE SOA. In every other case, including a message
    /// without a parseable first question, `None` is returned.
    ///
    /// NOTE(review): the QR flag is not inspected here; confirm that
    /// messages with QR set are rejected before reaching this middleware.
    fn get_relevant_question(
        msg: &Message<RequestOctets>,
    ) -> Option<Question<ParsedName<RequestOctets::Range<'_>>>> {
        if msg.header().opcode() != Opcode::NOTIFY {
            return None;
        }

        msg.first_question().filter(|q| q.qtype() == Rtype::SOA)
    }

    /// Copies `source` into a freshly created message builder.
    ///
    /// The header is copied verbatim, followed by every entry of the
    /// question, answer and authority sections. Records of the additional
    /// section are copied as well, except for OPT records which are
    /// skipped.
    ///
    /// # Errors
    ///
    /// Returns a [`CopyRecordsError`] if a question or record of `source`
    /// cannot be parsed or cannot be appended to the builder.
    ///
    /// # Panics
    ///
    /// Panics if a section that follows the current one is reported as
    /// absent, or if a record cannot be represented as `AllRecordData`.
    /// NOTE(review): both are presumed impossible in practice because
    /// `AllRecordData` is intended to cover every record type; confirm.
    fn copy_message(
        source: &Message<RequestOctets>,
    ) -> Result<
        AdditionalBuilder<StreamTarget<NextSvc::Target>>,
        CopyRecordsError,
    > {
        let mut builder = mk_builder_for_target();
        *builder.header_mut() = source.header();

        // Question section.
        let source = source.question();
        let mut question = builder.question();
        for rr in source {
            question.push(rr?)?;
        }

        // Answer section. Iterating by `&mut` keeps `source` usable so that
        // it can be advanced to the next section afterwards.
        let mut source = source.answer()?;
        let mut answer = question.answer();
        for rr in &mut source {
            let rr = rr?
                .into_record::<AllRecordData<_, ParsedName<_>>>()?
                .expect("record expected");
            answer.push(rr)?;
        }

        // Authority section.
        let mut source =
            source.next_section()?.expect("section should be present");
        let mut authority = answer.authority();
        for rr in &mut source {
            let rr = rr?
                .into_record::<AllRecordData<_, ParsedName<_>>>()?
                .expect("record expected");
            authority.push(rr)?;
        }

        // Additional section, leaving out any OPT record.
        let source =
            source.next_section()?.expect("section should be present");
        let mut additional = authority.additional();
        for rr in source {
            let rr = rr?;
            if rr.rtype() != Rtype::OPT {
                let rr = rr
                    .into_record::<AllRecordData<_, ParsedName<_>>>()?
                    .expect("record expected");
                additional.push(rr)?;
            }
        }

        Ok(additional)
    }
}
impl<RequestOctets, NextSvc, RequestMeta, N>
    Service<RequestOctets, RequestMeta>
    for NotifyMiddlewareSvc<RequestOctets, NextSvc, RequestMeta, N>
where
    NextSvc: Service<RequestOctets, RequestMeta> + Unpin + Clone,
    NextSvc::Future: Sync + Unpin,
    N: Notifiable + Clone + Sync + Send + 'static,
    RequestOctets: Octets + Send + Sync + 'static + Clone,
    RequestMeta: Clone + Default + Sync + Send + 'static,
    for<'a> <RequestOctets as octseq::Octets>::Range<'a>: Send + Sync,
{
    /// Responses are written into the same target type as the next service.
    type Target = NextSvc::Target;

    /// Either the untouched stream of the next service, or a single-item
    /// stream holding the response produced by this middleware.
    type Stream = MiddlewareStream<
        NextSvc::Future,
        NextSvc::Stream,
        NextSvc::Stream,
        Once<Ready<<NextSvc::Stream as Stream>::Item>>,
        <NextSvc::Stream as Stream>::Item,
    >;

    /// A boxed future resolving to [`Self::Stream`].
    type Future = Pin<Box<dyn Future<Output = Self::Stream> + Send + Sync>>;

    /// Handles a request.
    ///
    /// The request is first offered to the NOTIFY pre-processing step. If
    /// that step produces a response it is returned as the result stream;
    /// otherwise the request is forwarded to the next service and that
    /// service's stream is returned as is.
    fn call(
        &self,
        request: Request<RequestOctets, RequestMeta>,
    ) -> Self::Future {
        // `request` is already owned and is moved into the future below, so
        // only the pieces borrowed from `self` need to be cloned.
        let next_svc = self.next_svc.clone();
        let notify_target = self.notify_target.clone();

        Box::pin(async move {
            match Self::preprocess(&request, notify_target).await {
                ControlFlow::Continue(()) => {
                    let stream = next_svc.call(request).await;
                    MiddlewareStream::IdentityStream(stream)
                }
                ControlFlow::Break(stream) => {
                    MiddlewareStream::Result(stream)
                }
            }
        })
    }
}
/// Errors a [`Notifiable`] target can report when told that a zone changed.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum NotifyError {
/// The target is not authoritative for the zone; the middleware answers
/// the NOTIFY with a NOTAUTH response.
NotAuthForZone,
/// Any other failure; the middleware answers the NOTIFY with a SERVFAIL
/// response.
Other,
}
/// A target for the notifications produced by [`NotifyMiddlewareSvc`].
pub trait Notifiable {
/// Called when a NOTIFY request for a zone has been received.
///
/// As invoked by the middleware, `class` and `apex_name` are the class
/// and name of the first question of the request, `serial` is the serial
/// of an SOA record found as the first answer record (if any), and
/// `source` is the IP address of the client that sent the request.
///
/// Returns a boxed future resolving to `Ok(())` when the notification
/// was accepted, or to a [`NotifyError`] describing why it was not.
fn notify_zone_changed(
&self,
class: Class,
apex_name: &Name<Bytes>,
serial: Option<Serial>,
source: IpAddr,
) -> Pin<
Box<dyn Future<Output = Result<(), NotifyError>> + Sync + Send + '_>,
>;
}
/// Lets a shared, reference-counted target be used wherever a
/// [`Notifiable`] is required by forwarding to the wrapped value.
impl<T: Notifiable> Notifiable for Arc<T> {
    /// Forwards the notification unchanged to the `T` inside the `Arc`.
    fn notify_zone_changed(
        &self,
        class: Class,
        apex_name: &Name<Bytes>,
        serial: Option<Serial>,
        source: IpAddr,
    ) -> Pin<
        Box<dyn Future<Output = Result<(), NotifyError>> + Sync + Send + '_>,
    > {
        // Borrow the wrapped value for as long as `self` is borrowed so the
        // returned future may hold on to it.
        let inner: &T = self;
        inner.notify_zone_changed(class, apex_name, serial, source)
    }
}