use core::convert::Infallible;
use core::future::{ready, Ready};
use core::marker::PhantomData;
use core::ops::ControlFlow;
use std::vec::Vec;
use octseq::{Octets, OctetsFrom};
use tracing::{error, trace, warn};
use crate::base::iana::Rcode;
use crate::base::message_builder::AdditionalBuilder;
use crate::base::wire::Composer;
use crate::base::{Message, StreamTarget};
use crate::net::server::message::Request;
use crate::net::server::service::{
CallResult, Service, ServiceError, ServiceFeedback, ServiceResult,
};
use crate::net::server::util::mk_builder_for_target;
use crate::rdata::tsig::Time48;
use crate::tsig::{self, KeyStore, ServerSequence, ServerTransaction};
use super::stream::{MiddlewareStream, PostprocessingStream};
use futures_util::stream::{once, Once};
use futures_util::Stream;
/// Middleware service that sits in front of another [`Service`] and handles
/// TSIG for it.
///
/// Incoming requests are checked against the key store. Requests that fail
/// the check are answered directly with the error response built from the
/// validation error; requests that pass are forwarded to the wrapped service
/// with the matching key (or `None` for unsigned requests) as request
/// metadata, and the responses to signed requests are signed on the way out.
///
/// The metadata type of the incoming request (`IgnoredRequestMeta`) is not
/// passed on to the wrapped service.
#[derive(Clone, Debug)]
pub struct TsigMiddlewareSvc<RequestOctets, NextSvc, KS, IgnoredRequestMeta>
where
Infallible: From<<RequestOctets as octseq::OctetsFrom<Vec<u8>>>::Error>,
KS: Clone + KeyStore,
KS::Key: Clone,
NextSvc: Service<RequestOctets, Option<KS::Key>>,
NextSvc::Target: Composer + Default,
RequestOctets: Octets + OctetsFrom<Vec<u8>> + Send + Sync + Unpin,
{
/// The wrapped service that requests are forwarded to.
next_svc: NextSvc,
/// The key store consulted when checking the signature of a request.
key_store: KS,
/// Marker for the type parameters that are not stored in any field.
_phantom: PhantomData<(RequestOctets, IgnoredRequestMeta)>,
}
impl<RequestOctets, NextSvc, KS, IgnoredRequestMeta>
    TsigMiddlewareSvc<RequestOctets, NextSvc, KS, IgnoredRequestMeta>
where
    IgnoredRequestMeta: Default + Clone + Send + Sync + Unpin + 'static,
    Infallible: From<<RequestOctets as octseq::OctetsFrom<Vec<u8>>>::Error>,
    KS: Clone + KeyStore + Unpin + Send + Sync + 'static,
    KS::Key: Clone + Unpin + Send + Sync,
    NextSvc: Service<RequestOctets, Option<KS::Key>>,
    NextSvc::Future: Unpin,
    RequestOctets:
        Octets + OctetsFrom<Vec<u8>> + Send + Sync + 'static + Unpin + Clone,
{
    /// Creates a TSIG middleware in front of `next_svc`.
    ///
    /// `key_store` is the store that request signatures are checked against.
    #[must_use]
    pub fn new(next_svc: NextSvc, key_store: KS) -> Self {
        let _phantom = PhantomData;
        Self {
            next_svc,
            key_store,
            _phantom,
        }
    }
}
impl<RequestOctets, NextSvc, KS, IgnoredRequestMeta>
TsigMiddlewareSvc<RequestOctets, NextSvc, KS, IgnoredRequestMeta>
where
IgnoredRequestMeta: Default + Clone + Send + Sync + Unpin + 'static,
Infallible: From<<RequestOctets as octseq::OctetsFrom<Vec<u8>>>::Error>,
KS: Clone + KeyStore + Unpin + Send + Sync + 'static,
KS::Key: Clone + Unpin + Send + Sync,
NextSvc: Service<RequestOctets, Option<KS::Key>>,
NextSvc::Future: Unpin,
RequestOctets:
Octets + OctetsFrom<Vec<u8>> + Send + Sync + 'static + Unpin + Clone,
{
/// Checks the TSIG signature of `req`, if it has one, against `key_store`.
///
/// The outcome is one of:
///
/// - `Ok(ControlFlow::Continue(None))`: the request is not signed and can be
///   passed on as-is.
/// - `Ok(ControlFlow::Continue(Some((new_req, signer))))`: the request is
///   signed with a key from the store. `new_req` is rebuilt from the
///   validated copy of the message, carries the key as its metadata and has
///   room reserved for the signature; `signer` signs the response(s).
/// - `Ok(ControlFlow::Break(response))`: validation failed and `response` is
///   the error response to send instead of calling the next service.
/// - `Err(ServiceError::InternalError)`: validation failed and the error
///   response could not be built.
// The nested return type mirrors the three possible outcomes; a type alias
// would need all of the impl's generic parameters.
#[allow(clippy::type_complexity)]
fn preprocess(
    req: &Request<RequestOctets, IgnoredRequestMeta>,
    key_store: &KS,
) -> Result<
    ControlFlow<
        AdditionalBuilder<StreamTarget<NextSvc::Target>>,
        Option<(
            Request<RequestOctets, Option<KS::Key>>,
            TsigSigner<KS::Key>,
        )>,
    >,
    ServiceError,
> {
    // Validation is given mutable access to the message, so it works on a
    // private copy of the request bytes rather than on the request itself.
    let request_bytes = req.message().as_slice().to_vec();
    let mut scratch = Message::from_octets(request_bytes).unwrap();

    let verdict = tsig::ServerTransaction::request(
        key_store,
        &mut scratch,
        Time48::now(),
    );

    let tsig = match verdict {
        // Signed with a key we know: handled below.
        Ok(Some(tsig)) => tsig,

        // Not signed at all: nothing for this middleware to do.
        Ok(None) => return Ok(ControlFlow::Continue(None)),

        // Validation failed: answer with the error response derived from
        // the validation error instead of invoking the next service.
        Err(err) => {
            warn!(
                "{} from {} refused: {err}",
                req.message().header().opcode(),
                req.client_addr(),
            );
            let builder = mk_builder_for_target();
            return err
                .build_message(req.message(), builder)
                .map(ControlFlow::Break)
                .map_err(|err| {
                    error!("Unable to build TSIG error response: {err}");
                    ServiceError::InternalError
                });
        }
    };

    trace!(
        "Request is signed with TSIG key '{}'",
        tsig.key().name()
    );

    // Turn the validated copy back into the request's octets type so the
    // next service sees the same type for signed and unsigned requests.
    let validated = RequestOctets::octets_from(scratch.into_octets());
    let new_msg = Message::from_octets(validated).unwrap();

    let mut signed_req = Request::new(
        req.client_addr(),
        req.received_at(),
        new_msg,
        req.transport_ctx().clone(),
        Some(tsig.wrapped_key().clone()),
    );

    // Keep room in the response for what signing will add later.
    signed_req.reserve_bytes(tsig.key().compose_len());

    Ok(ControlFlow::Continue(Some((
        signed_req,
        TsigSigner::Transaction(tsig),
    ))))
}
/// Signs `response` in place with the signer held in `state`.
///
/// Returns `Ok(None)` when the response was signed successfully or when
/// there is no signer (nothing to sign). If signing fails, a replacement
/// response is built by `mk_signed_truncated_response` and returned as
/// `Ok(Some(..))`; the caller is expected to send that one instead. An
/// error is returned only if the replacement could not be produced either.
fn postprocess(
    request: &Request<RequestOctets, IgnoredRequestMeta>,
    response: &mut AdditionalBuilder<StreamTarget<NextSvc::Target>>,
    state: &mut PostprocessingState<KS::Key>,
) -> Result<
    Option<AdditionalBuilder<StreamTarget<NextSvc::Target>>>,
    ServiceError,
> {
    // Lift the push limit on the builder before signing.
    // NOTE(review): presumably the limit stems from the bytes reserved in
    // `preprocess` - confirm.
    response.clear_push_limit();

    // Sign, and at the same time remember what would be needed to sign a
    // replacement response should signing this one fail.
    let (signing_result, truncation_ctx) = match &mut state.signer {
        Some(TsigSigner::Transaction(_)) => {
            // The single-response signer is moved out of the state, so the
            // state holds no signer afterwards.
            let Some(TsigSigner::Transaction(txn)) = state.signer.take()
            else {
                unreachable!()
            };
            trace!(
                "Signing single response with TSIG key '{}'",
                txn.key().name()
            );
            // Copy the key before handing the signer to `answer`, so that
            // it is still at hand if signing fails.
            let key = txn.key().clone();
            let outcome = txn.answer(response, Time48::now());
            (outcome, TruncationContext::NoSignerOnlyTheKey(key))
        }
        Some(TsigSigner::Sequence(seq)) => {
            trace!(
                "Signing response stream with TSIG key '{}'",
                seq.key().name()
            );
            let outcome = seq.answer(response, Time48::now());
            (outcome, TruncationContext::HaveSigner(seq))
        }
        // No signer means there is nothing to sign.
        None => return Ok(None),
    };

    match signing_result {
        Ok(_) => Ok(None),
        Err(_) => {
            Self::mk_signed_truncated_response(request, truncation_ctx)
                .map(Some)
        }
    }
}
fn mk_signed_truncated_response(
request: &Request<RequestOctets, IgnoredRequestMeta>,
truncation_ctx: TruncationContext<'_, KS::Key, tsig::Key>,
) -> Result<AdditionalBuilder<StreamTarget<NextSvc::Target>>, ServiceError>
{
let builder = mk_builder_for_target();
let mut new_response = builder
.start_answer(request.message(), Rcode::NOERROR)
.unwrap();
new_response.header_mut().set_tc(true);
let mut additional = new_response.additional();
match truncation_ctx {
TruncationContext::HaveSigner(signer) => {
if let Err(err) =
signer.answer(&mut additional, Time48::now())
{
error!("Unable to sign truncated TSIG response: {err}");
Err(ServiceError::InternalError)
} else {
Ok(additional)
}
}
TruncationContext::NoSignerOnlyTheKey(key) => {
let octets = request.message().as_slice().to_vec();
let mut mut_msg = Message::from_octets(octets).unwrap();
match ServerTransaction::request(
&key,
&mut mut_msg,
Time48::now(),
) {
Ok(None) => {
error!("Unable to create signer for truncated TSIG response: internal error: request is not signed but was expected to be");
Err(ServiceError::InternalError)
}
Err(err) => {
error!("Unable to create signer for truncated TSIG response: {err}");
Err(ServiceError::InternalError)
}
Ok(Some(signer)) => {
if let Err(err) =
signer.answer(&mut additional, Time48::now())
{
error!("Unable to sign truncated TSIG response: {err}");
Err(ServiceError::InternalError)
} else {
Ok(additional)
}
}
}
}
}
}
/// Post-processes one item of the response stream of the next service.
///
/// Error items are passed through unchanged. For successful items:
///
/// - if the item carries `ServiceFeedback::BeginTransaction`, a
///   single-response signer held in `pp_config` is converted into a
///   sequence signer; any other signer state is left as it is;
/// - if the item carries a response, it is signed via `postprocess`, and
///   replaced when `postprocess` hands back a replacement response.
///
/// # Errors
///
/// Returns the error of the stream item itself, or the error reported by
/// `postprocess`.
fn map_stream_item(
    request: Request<RequestOctets, IgnoredRequestMeta>,
    stream_item: ServiceResult<NextSvc::Target>,
    pp_config: &mut PostprocessingState<KS::Key>,
) -> ServiceResult<NextSvc::Target> {
    let mut call_res = stream_item?;

    if matches!(
        call_res.feedback(),
        Some(ServiceFeedback::BeginTransaction)
    ) {
        // Only a single-response signer is upgraded. Anything else (an
        // existing sequence signer, or no signer) must be put back:
        // discarding it here would leave every later response in the
        // stream unsigned if the feedback is ever seen more than once.
        pp_config.signer = match pp_config.signer.take() {
            Some(TsigSigner::Transaction(tsig_txn)) => Some(
                TsigSigner::Sequence(ServerSequence::from(tsig_txn)),
            ),
            other => other,
        };
    }

    if let Some(response) = call_res.response_mut() {
        if let Some(new_response) =
            Self::postprocess(&request, response, pp_config)?
        {
            *response = new_response;
        }
    }

    Ok(call_res)
}
}
impl<RequestOctets, NextSvc, KS, IgnoredRequestMeta>
    Service<RequestOctets, IgnoredRequestMeta>
    for TsigMiddlewareSvc<RequestOctets, NextSvc, KS, IgnoredRequestMeta>
where
    IgnoredRequestMeta: Default + Clone + Send + Sync + Unpin + 'static,
    Infallible: From<<RequestOctets as octseq::OctetsFrom<Vec<u8>>>::Error>,
    KS: Clone + KeyStore + Unpin + Send + Sync + 'static,
    KS::Key: Clone + Unpin + Send + Sync,
    NextSvc: Service<RequestOctets, Option<KS::Key>>,
    NextSvc::Future: Unpin,
    RequestOctets:
        Octets + OctetsFrom<Vec<u8>> + Send + Sync + 'static + Unpin + Clone,
{
    /// Responses are written into the same target type as the next service.
    type Target = NextSvc::Target;

    /// Either the next service's stream (as-is or with per-item
    /// post-processing), or a one-item stream produced by this middleware.
    type Stream = MiddlewareStream<
        NextSvc::Future,
        NextSvc::Stream,
        PostprocessingStream<
            RequestOctets,
            NextSvc::Future,
            NextSvc::Stream,
            IgnoredRequestMeta,
            PostprocessingState<KS::Key>,
        >,
        Once<Ready<ServiceResult<Self::Target>>>,
        <NextSvc::Stream as Stream>::Item,
    >;

    /// The stream is available immediately.
    type Future = Ready<Self::Stream>;

    /// Pre-processes `request` and dispatches on the outcome:
    ///
    /// - pre-processing failed: a one-item stream holding the error;
    /// - pre-processing produced a response of its own: a one-item stream
    ///   holding that response, without calling the next service;
    /// - unsigned request: the next service is called with `None` as
    ///   metadata and its result is passed through untouched;
    /// - signed request: the next service is called with the rebuilt
    ///   request and every item of its stream goes through
    ///   `map_stream_item`.
    fn call(
        &self,
        request: Request<RequestOctets, IgnoredRequestMeta>,
    ) -> Self::Future {
        let stream: Self::Stream =
            match Self::preprocess(&request, &self.key_store) {
                Err(err) => MiddlewareStream::Result(once(ready(Err(err)))),

                Ok(ControlFlow::Break(additional)) => {
                    MiddlewareStream::Result(once(ready(Ok(
                        CallResult::new(additional),
                    ))))
                }

                Ok(ControlFlow::Continue(None)) => {
                    let unsigned_req = request.with_new_metadata(None);
                    let next_fut = self.next_svc.call(unsigned_req);
                    MiddlewareStream::IdentityFuture(next_fut)
                }

                Ok(ControlFlow::Continue(Some((signed_req, signer)))) => {
                    let state = PostprocessingState::new(signer);
                    let next_fut = self.next_svc.call(signed_req);
                    MiddlewareStream::Map(PostprocessingStream::new(
                        next_fut,
                        request,
                        state,
                        Self::map_stream_item,
                    ))
                }
            };

        ready(stream)
    }
}
/// State carried along while post-processing the responses to one request.
pub struct PostprocessingState<K> {
/// The signer for the responses, if there (still) is one. It becomes `None`
/// once a single-response signer has been taken out to sign a response.
signer: Option<TsigSigner<K>>,
}
impl<K> PostprocessingState<K> {
    /// Creates post-processing state that starts out holding `signer`.
    fn new(signer: TsigSigner<K>) -> Self {
        let signer = Some(signer);
        Self { signer }
    }
}
/// The signer used for the response(s) to a signed request.
#[derive(Clone, Debug)]
enum TsigSigner<K> {
/// Signer for a single response.
Transaction(ServerTransaction<K>),
/// Signer for a stream of responses. Created from a `Transaction` signer
/// when the next service reports `ServiceFeedback::BeginTransaction`.
Sequence(ServerSequence<K>),
}
/// What is available for signing a replacement response after signing the
/// real response failed.
///
/// `KSeq` is the key type of the borrowed sequence signer; `KTxn` is the type
/// of the stand-alone key kept when no signer is left.
enum TruncationContext<'a, KSeq, KTxn> {
/// The sequence signer is still available and is borrowed for signing.
HaveSigner(&'a mut tsig::ServerSequence<KSeq>),
/// The signer is gone; only a copy of its key is left, from which a new
/// signer has to be created.
NoSignerOnlyTheKey(KTxn),
}