#![cfg(all(feature = "tsig", feature = "unstable-client-transport"))]
#![warn(missing_docs)]
#![warn(clippy::missing_docs_in_private_items)]
use core::ops::DerefMut;
use std::boxed::Box;
use std::fmt::{Debug, Formatter};
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::vec::Vec;
use bytes::Bytes;
use octseq::Octets;
use tracing::trace;
use crate::base::message::CopyRecordsError;
use crate::base::message_builder::AdditionalBuilder;
use crate::base::wire::Composer;
use crate::base::Message;
use crate::base::StaticCompressor;
use crate::net::client::request::{
ComposeRequest, ComposeRequestMulti, Error, GetResponse,
GetResponseMulti, SendRequest, SendRequestMulti,
};
use crate::rdata::tsig::Time48;
use crate::tsig::{ClientSequence, ClientTransaction, Key};
#[derive(Clone, Debug)]
enum TsigClient<K> {
Transaction(ClientTransaction<K>),
Sequence(ClientSequence<K>),
}
impl<K> TsigClient<K>
where
K: AsRef<Key>,
{
pub fn answer<Octs>(
&mut self,
message: &mut Message<Octs>,
now: Time48,
) -> Result<(), Error>
where
Octs: Octets + AsMut<[u8]> + ?Sized,
{
match self {
TsigClient::Transaction(client) => client.answer(message, now),
TsigClient::Sequence(client) => client.answer(message, now),
}
.map_err(Error::Authentication)
}
fn done(self) -> Result<(), Error> {
match self {
TsigClient::Transaction(_) => {
Ok(())
}
TsigClient::Sequence(client) => {
client.done().map_err(Error::Authentication)
}
}
}
}
/// A client transport that TSIG-signs requests before handing them to an
/// upstream transport and validates the signatures of the responses.
#[derive(Clone)]
pub struct Connection<Upstream, K> {
    /// The transport the signed requests are sent over.
    upstream: Arc<Upstream>,

    /// The TSIG key used for signing and validation.
    key: K,
}

impl<Upstream, K> Connection<Upstream, K> {
    /// Creates a connection that signs with `key` and forwards requests to
    /// `upstream`.
    pub fn new(key: K, upstream: Upstream) -> Self {
        let upstream = Arc::new(upstream);
        Self { key, upstream }
    }
}
impl<CR, Upstream, K> SendRequest<CR> for Connection<Upstream, K>
where
CR: ComposeRequest + 'static,
Upstream: SendRequest<RequestMessage<CR, K>> + Send + Sync + 'static,
K: Clone + AsRef<Key> + Send + Sync + 'static,
{
fn send_request(
&self,
request_msg: CR,
) -> Box<dyn GetResponse + Send + Sync> {
Box::new(Request::<CR, Upstream, K>::new(
request_msg,
self.key.clone(),
self.upstream.clone(),
))
}
}
impl<CR, Upstream, K> SendRequestMulti<CR> for Connection<Upstream, K>
where
CR: ComposeRequestMulti + 'static,
Upstream: SendRequestMulti<RequestMessage<CR, K>> + Send + Sync + 'static,
K: Clone + AsRef<Key> + Send + Sync + 'static,
{
fn send_request(
&self,
request_msg: CR,
) -> Box<dyn GetResponseMulti + Send + Sync> {
Box::new(Request::<CR, Upstream, K>::new_multi(
request_msg,
self.key.clone(),
self.upstream.clone(),
))
}
}
/// A function that hands a wrapped request to the upstream transport and
/// returns the resulting in-flight state.
///
/// Passing one of these into `Request::get_response_impl` lets the same
/// state machine drive both single- and multi-response upstreams.
type Forwarder<Upstream, CR, K> = fn(
    &Upstream,
    RequestMessage<CR, K>,
    Arc<std::sync::Mutex<Option<TsigClient<K>>>>,
) -> RequestState<K>;

/// Sends `msg` to an upstream that answers with exactly one response.
///
/// `tsig_client` is the slot the TSIG client is stored in once the request
/// is signed; it is kept alongside the pending response.
fn forwarder<CR, K, Upstream>(
    upstream: &Upstream,
    msg: RequestMessage<CR, K>,
    tsig_client: Arc<std::sync::Mutex<Option<TsigClient<K>>>>,
) -> RequestState<K>
where
    CR: ComposeRequest,
    Upstream: SendRequest<RequestMessage<CR, K>> + Send + Sync,
{
    let in_flight = upstream.send_request(msg);
    RequestState::GetResponse(in_flight, tsig_client)
}

/// Sends `msg` to an upstream that may answer with a stream of responses.
///
/// `tsig_client` is the slot the TSIG client is stored in once the request
/// is signed; it is kept alongside the pending response stream.
fn forwarder_multi<CR, K, Upstream>(
    upstream: &Upstream,
    msg: RequestMessage<CR, K>,
    tsig_client: Arc<std::sync::Mutex<Option<TsigClient<K>>>>,
) -> RequestState<K>
where
    CR: ComposeRequestMulti,
    Upstream: SendRequestMulti<RequestMessage<CR, K>> + Send + Sync,
{
    let in_flight = upstream.send_request(msg);
    RequestState::GetResponseMulti(in_flight, tsig_client)
}
/// An in-progress TSIG-protected request, answered either by one response
/// or by a stream of responses.
struct Request<CR, Upstream, K> {
    /// Where the request currently is in its lifecycle.
    state: RequestState<K>,

    /// The unsigned request; taken out when it is sent upstream.
    request_msg: Option<CR>,

    /// The key the request is signed with.
    key: K,

    /// The transport the request is sent over.
    upstream: Arc<Upstream>,
}

impl<CR, Upstream, K> Request<CR, Upstream, K>
where
    CR: ComposeRequest,
    Upstream: SendRequest<RequestMessage<CR, K>> + Send + Sync,
    K: Clone + AsRef<Key>,
    Self: GetResponse,
{
    /// Creates a not-yet-sent request that expects a single response.
    fn new(request_msg: CR, key: K, upstream: Arc<Upstream>) -> Self {
        Self {
            upstream,
            key,
            request_msg: Some(request_msg),
            state: RequestState::Init,
        }
    }
}
impl<CR, Upstream, K> Request<CR, Upstream, K>
where
    CR: Sync + Send,
    K: Clone + AsRef<Key>,
{
    /// Creates a not-yet-sent request that may be answered by a stream of
    /// responses.
    fn new_multi(request_msg: CR, key: K, upstream: Arc<Upstream>) -> Self {
        Self {
            state: RequestState::Init,
            request_msg: Some(request_msg),
            key,
            upstream,
        }
    }

    /// Drives the request state machine and returns the next validated
    /// response.
    ///
    /// On the first call the request is wrapped in a [`RequestMessage`]
    /// (which signs it when the upstream composes it) and passed to
    /// `upstream_sender`. Each call then waits for the next upstream
    /// response and checks its TSIG signature.
    ///
    /// Returns `Ok(None)` once the upstream reports the end of the response
    /// stream. From then on, every call fails with
    /// [`Error::StreamReceiveError`], even if the end-of-stream checks
    /// failed.
    ///
    /// # Errors
    ///
    /// Upstream errors are passed through. TSIG validation failures are
    /// reported as [`Error::Authentication`].
    async fn get_response_impl(
        &mut self,
        upstream_sender: Forwarder<Upstream, CR, K>,
    ) -> Result<Option<Message<Bytes>>, Error> {
        let (response, tsig_client) = loop {
            match &mut self.state {
                RequestState::Init => {
                    let tsig_client = Arc::new(std::sync::Mutex::new(None));
                    let request_msg = self.request_msg.take().expect(
                        "request message is present until the request is sent",
                    );
                    let msg = RequestMessage::new(
                        request_msg,
                        self.key.clone(),
                        tsig_client.clone(),
                    );
                    trace!("Sending request upstream...");
                    self.state =
                        upstream_sender(&self.upstream, msg, tsig_client);
                    continue;
                }
                RequestState::GetResponse(request, tsig_client) => {
                    let response = request.get_response().await?;
                    break (Some(response), tsig_client);
                }
                RequestState::GetResponseMulti(request, tsig_client) => {
                    let response = request.get_response().await?;
                    break (response, tsig_client);
                }
                RequestState::Complete => {
                    return Err(Error::StreamReceiveError);
                }
            }
        };

        // The end of the stream must be final even when the closing checks
        // fail. Those checks consume the TSIG client, so polling the
        // upstream again afterwards would otherwise find no client left.
        let end_of_stream = response.is_none();
        let res = Self::validate_response(response, tsig_client);
        if end_of_stream {
            self.state = RequestState::Complete;
        }
        res
    }

    /// Validates one upstream response against the shared TSIG client.
    ///
    /// For `Some(msg)`, the message is copied into an owned buffer so it
    /// can be modified. It is checked with the TSIG client, if one was
    /// stored when the request was composed, and returned as a new
    /// `Message<Bytes>`.
    ///
    /// For `None` (end of stream), the TSIG client is taken out of the
    /// shared slot and its end-of-exchange checks are run.
    fn validate_response(
        response: Option<Message<Bytes>>,
        tsig_client: &mut Arc<std::sync::Mutex<Option<TsigClient<K>>>>,
    ) -> Result<Option<Message<Bytes>>, Error> {
        match response {
            None => {
                // Take the client in its own statement so the lock is
                // released before the final checks run.
                let client = tsig_client.lock().unwrap().take();
                // NOTE(review): no client means the request was never
                // composed, and so never signed. As in the `Some` branch,
                // this is not treated as an error. Confirm that this
                // fail-open behaviour is intended.
                if let Some(client) = client {
                    client.done()?;
                }
                Ok(None)
            }
            Some(msg) => {
                // `TsigClient::answer` needs mutable octets, so validate an
                // owned copy of the wire data.
                let mut modifiable_msg =
                    Message::from_octets(msg.as_slice().to_vec())?;
                if let Some(client) = tsig_client.lock().unwrap().deref_mut()
                {
                    trace!("Validating TSIG for sequence reply");
                    client.answer(&mut modifiable_msg, Time48::now())?;
                }
                let out_bytes = Bytes::from(modifiable_msg.into_octets());
                let out_msg = Message::<Bytes>::from_octets(out_bytes)?;
                Ok(Some(out_msg))
            }
        }
    }
}
impl<CR, Upstream, K> Debug for Request<CR, Upstream, K> {
    /// Prints just the type name, so none of the generic parameters need
    /// to implement `Debug`.
    fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), core::fmt::Error> {
        let mut builder = f.debug_struct("Request");
        builder.finish()
    }
}
impl<CR, Upstream, K> GetResponse for Request<CR, Upstream, K>
where
    CR: ComposeRequest,
    Upstream: SendRequest<RequestMessage<CR, K>> + Send + Sync,
    K: Clone + AsRef<Key> + Send + Sync,
{
    /// Sends the signed request (on first use) and returns its validated
    /// response.
    fn get_response(
        &mut self,
    ) -> Pin<
        Box<
            dyn Future<Output = Result<Message<Bytes>, Error>>
                + Send
                + Sync
                + '_,
        >,
    > {
        let fut = async move {
            let next = self.get_response_impl(forwarder).await;
            // With a single-response upstream, `get_response_impl` only
            // returns `Ok(Some(_))` or an error.
            next.map(|maybe_msg| maybe_msg.unwrap())
        };
        Box::pin(fut)
    }
}
impl<CR, Upstream, K> GetResponseMulti for Request<CR, Upstream, K>
where
    CR: ComposeRequestMulti,
    Upstream: SendRequestMulti<RequestMessage<CR, K>> + Send + Sync,
    K: Clone + AsRef<Key> + Send + Sync,
{
    /// Sends the signed request (on first use) and returns the next
    /// validated response of the stream, or `None` once it has ended.
    fn get_response(
        &mut self,
    ) -> Pin<
        Box<
            dyn Future<Output = Result<Option<Message<Bytes>>, Error>>
                + Send
                + Sync
                + '_,
        >,
    > {
        let next = self.get_response_impl(forwarder_multi);
        Box::pin(next)
    }
}
/// The lifecycle of a `Request`.
enum RequestState<K> {
    /// The request has not been sent upstream yet.
    Init,

    /// Waiting for the single response to a request.
    ///
    /// The shared slot holds the TSIG client stored when the request was
    /// signed.
    GetResponse(
        Box<dyn GetResponse + Send + Sync>,
        Arc<std::sync::Mutex<Option<TsigClient<K>>>>,
    ),

    /// Waiting for the next response of a response stream.
    ///
    /// The shared slot holds the TSIG client stored when the request was
    /// signed.
    GetResponseMulti(
        Box<dyn GetResponseMulti + Send + Sync>,
        Arc<std::sync::Mutex<Option<TsigClient<K>>>>,
    ),

    /// The response stream has ended.
    Complete,
}
/// A request wrapper that TSIG-signs the inner request each time it is
/// composed.
///
/// The TSIG client created during signing is stored in the shared `signer`
/// slot, so the owning request can later validate the responses.
#[derive(Clone, Debug)]
pub struct RequestMessage<CR, K>
where
    CR: Send + Sync,
{
    /// The wrapped, unsigned request.
    request: CR,

    /// The key used to sign the request.
    key: K,

    /// The slot that receives the TSIG client each time the request is
    /// signed.
    signer: Arc<std::sync::Mutex<Option<TsigClient<K>>>>,
}

impl<CR, K> RequestMessage<CR, K>
where
    CR: Send + Sync,
{
    /// Wraps `request` so it is signed with `key`. The resulting TSIG
    /// client is stored in `signer`.
    fn new(
        request: CR,
        key: K,
        signer: Arc<std::sync::Mutex<Option<TsigClient<K>>>>,
    ) -> Self
    where
        CR: Sync + Send,
        K: Clone + AsRef<Key>,
    {
        Self { signer, key, request }
    }
}
impl<CR, K> ComposeRequest for RequestMessage<CR, K>
where
    CR: ComposeRequest,
    K: Clone + Debug + Send + Sync + AsRef<Key>,
{
    /// Composes the wrapped request into `target` and appends a TSIG
    /// record that signs it as a single transaction.
    ///
    /// The new TSIG client replaces whatever the shared `signer` slot held
    /// before, so the most recently composed request is the one whose
    /// response will be validated.
    ///
    /// # Errors
    ///
    /// Fails if composing the inner request fails, or if the TSIG record
    /// cannot be pushed into `target` (for example, because it is full).
    fn append_message<Target: Composer>(
        &self,
        target: Target,
    ) -> Result<AdditionalBuilder<Target>, CopyRecordsError> {
        let mut target = self.request.append_message(target)?;
        trace!(
            "Signing single request transaction with key '{}'",
            self.key.as_ref().name()
        );
        // Signing can fail when there is no room left for the TSIG record.
        // Report that to the caller instead of panicking.
        let transaction = ClientTransaction::request(
            self.key.clone(),
            &mut target,
            Time48::now(),
        )?;
        *self.signer.lock().unwrap() =
            Some(TsigClient::Transaction(transaction));
        Ok(target)
    }

    /// Returns the signed request in wire format.
    fn to_vec(&self) -> Result<Vec<u8>, Error> {
        // Take ownership of the composed buffer rather than cloning it.
        Ok(self.to_message()?.into_octets())
    }

    /// Composes and signs the request into a newly allocated message.
    fn to_message(&self) -> Result<Message<Vec<u8>>, Error> {
        let mut target = StaticCompressor::new(Vec::new());
        self.append_message(&mut target)?;
        let msg = Message::from_octets(target.into_target()).expect(
            "Message should be able to parse output from MessageBuilder",
        );
        Ok(msg)
    }

    /// Returns the header of the wrapped request.
    fn header(&self) -> &crate::base::Header {
        self.request.header()
    }

    /// Returns a mutable reference to the header of the wrapped request.
    fn header_mut(&mut self) -> &mut crate::base::Header {
        self.request.header_mut()
    }

    /// Sets the UDP payload size on the wrapped request.
    fn set_udp_payload_size(&mut self, value: u16) {
        self.request.set_udp_payload_size(value)
    }

    /// Sets the DNSSEC OK flag on the wrapped request.
    fn set_dnssec_ok(&mut self, value: bool) {
        self.request.set_dnssec_ok(value)
    }

    /// Adds an EDNS option to the wrapped request.
    fn add_opt(
        &mut self,
        opt: &impl crate::base::opt::ComposeOptData,
    ) -> Result<(), crate::base::opt::LongOptData> {
        self.request.add_opt(opt)
    }

    /// Delegates the answer check to the wrapped request.
    fn is_answer(&self, answer: &Message<[u8]>) -> bool {
        self.request.is_answer(answer)
    }

    /// Returns the DNSSEC OK flag of the wrapped request.
    fn dnssec_ok(&self) -> bool {
        self.request.dnssec_ok()
    }
}
impl<CR, K> ComposeRequestMulti for RequestMessage<CR, K>
where
    CR: ComposeRequestMulti,
    K: Clone + Debug + Send + Sync + AsRef<Key>,
{
    /// Composes the wrapped request into `target` and appends a TSIG
    /// record that starts a signed response sequence.
    ///
    /// The new TSIG client replaces whatever the shared `signer` slot held
    /// before.
    ///
    /// # Errors
    ///
    /// Fails if composing the inner request fails, or if the TSIG record
    /// cannot be pushed into `target` (for example, because it is full).
    fn append_message<Target: Composer>(
        &self,
        target: Target,
    ) -> Result<AdditionalBuilder<Target>, CopyRecordsError> {
        let mut target = self.request.append_message(target)?;
        trace!(
            "Signing streaming request sequence with key '{}'",
            self.key.as_ref().name()
        );
        // Signing can fail when there is no room left for the TSIG record.
        // Report that to the caller instead of panicking.
        let sequence = ClientSequence::request(
            self.key.clone(),
            &mut target,
            Time48::now(),
        )?;
        *self.signer.lock().unwrap() = Some(TsigClient::Sequence(sequence));
        Ok(target)
    }

    /// Composes and signs the request into a newly allocated message.
    fn to_message(&self) -> Result<Message<Vec<u8>>, Error> {
        let mut target = StaticCompressor::new(Vec::new());
        self.append_message(&mut target)?;
        let msg = Message::from_octets(target.into_target()).expect(
            "Message should be able to parse output from MessageBuilder",
        );
        Ok(msg)
    }

    /// Returns the header of the wrapped request.
    fn header(&self) -> &crate::base::Header {
        self.request.header()
    }

    /// Returns a mutable reference to the header of the wrapped request.
    fn header_mut(&mut self) -> &mut crate::base::Header {
        self.request.header_mut()
    }

    /// Sets the UDP payload size on the wrapped request.
    fn set_udp_payload_size(&mut self, value: u16) {
        self.request.set_udp_payload_size(value)
    }

    /// Sets the DNSSEC OK flag on the wrapped request.
    fn set_dnssec_ok(&mut self, value: bool) {
        self.request.set_dnssec_ok(value)
    }

    /// Adds an EDNS option to the wrapped request.
    fn add_opt(
        &mut self,
        opt: &impl crate::base::opt::ComposeOptData,
    ) -> Result<(), crate::base::opt::LongOptData> {
        self.request.add_opt(opt)
    }

    /// Delegates the answer check to the wrapped request.
    fn is_answer(&self, answer: &Message<[u8]>) -> bool {
        self.request.is_answer(answer)
    }

    /// Returns the DNSSEC OK flag of the wrapped request.
    fn dnssec_ok(&self) -> bool {
        self.request.dnssec_ok()
    }
}
// Unit tests for the TSIG client transport.
//
// The mock upstreams below play the server side of TSIG. They check the
// signed request with a server-side TSIG transaction or sequence and sign
// their responses with the same key. When asked to, they change the
// response after signing so that validation must fail.
#[cfg(test)]
mod tests {
use super::*;
use crate::base::iana::Rcode;
use crate::base::message_builder::QuestionBuilder;
use crate::base::{MessageBuilder, Name, Rtype};
use crate::tsig::{
Algorithm, KeyName, KeyStore, ServerSequence, ServerTransaction,
ValidationError,
};
use core::future::ready;
use core::str::FromStr;
/// A correctly signed single response is accepted.
#[tokio::test]
async fn single_signed_valid_response() {
do_single_response(false).await;
}
/// A single response modified after signing is rejected.
#[tokio::test]
async fn single_signed_invalid_response() {
do_single_response(true).await;
}
/// Sends one signed request through a `Request` to `MockUpstream`.
///
/// When `invalidate_signature` is set, the mock changes the response after
/// signing it, so validation must fail. Otherwise, the response must be
/// accepted with its TSIG RR removed.
async fn do_single_response(invalidate_signature: bool) {
let msg = mk_request_msg(Rtype::A);
let req =
crate::net::client::request::RequestMessage::new(msg).unwrap();
let key = mk_tsig_key();
let upstream =
Arc::new(MockUpstream::new(key.clone(), invalidate_signature));
let mut req = Request::new(req, key, upstream);
let res = req.get_response().await;
assert_eq!(res.is_err(), invalidate_signature);
if let Ok(res) = res {
assert_eq!(res.header_counts().arcount(), 0, "TSIG RR should have been removed from the additional section during response processing");
}
}
/// Three correctly signed responses are all accepted.
#[tokio::test]
async fn multiple_signed_valid_responses() {
do_multiple_responses(false, false).await
}
/// A tampered second response is rejected with `BadSig`.
#[tokio::test]
async fn multiple_signed_responses_with_one_invalid() {
do_multiple_responses(true, false).await
}
/// An unsigned final response is only rejected at the end of the stream.
#[tokio::test]
async fn multiple_signed_valid_responses_and_a_final_unsigned_response() {
do_multiple_responses(false, true).await
}
/// Requests a three-response stream from `MockUpstreamMulti`.
///
/// - `invalidate_signature`: the second response is changed after
///   signing.
/// - `dont_sign_last_response`: the third response carries no TSIG RR.
async fn do_multiple_responses(
invalidate_signature: bool,
dont_sign_last_response: bool,
) {
let msg = mk_request_msg(Rtype::AXFR);
let req = crate::net::client::request::RequestMessageMulti::new(msg)
.unwrap();
let key = mk_tsig_key();
let upstream = Arc::new(MockUpstreamMulti::new(
key.clone(),
invalidate_signature,
dont_sign_last_response,
));
let mut req = Request::new_multi(req, key, upstream);
// First response: always signed and never tampered with.
let res = req
.get_response()
.await
.unwrap()
.expect("First response is missing");
assert_eq!(res.header_counts().arcount(), 0, "TSIG RR should have been removed from the additional section during response processing");
// Second response: tampered with when `invalidate_signature` is set.
let res = req.get_response().await;
if invalidate_signature {
assert!(
matches!(
res,
Err(Error::Authentication(ValidationError::BadSig))
),
"Expected error BadSig but the result was: {res:?}"
);
} else {
assert!(res.is_ok(), "Unexpected error message: {res:?}");
}
if let Ok(res) = res {
let res = res.expect("Second response is missing");
assert_eq!(res.header_counts().arcount(), 0, "TSIG RR should have been removed from the additional section during response processing");
// Third response: unsigned when `dont_sign_last_response` is set.
let res = req
.get_response()
.await
.unwrap()
.expect("Third response is missing");
if dont_sign_last_response {
assert_eq!(res.header_counts().arcount(), 0, "TSIG RR should never have been added to the additional section during response generation");
} else {
assert_eq!(res.header_counts().arcount(), 0, "TSIG RR should have been removed from the additional section during response processing");
}
// End of stream: the end-of-stream check fails if the last
// response was unsigned.
if dont_sign_last_response {
assert!(
matches!(req.get_response().await, Err(Error::Authentication(ValidationError::TooManyUnsigned))),
"Receiving another response should have failed because the last response should have lacked a signature"
);
} else {
assert!(
req.get_response().await.unwrap().is_none(),
"There should not be a fourth response"
);
}
}
}
/// Builds a query for `example.com` with the given record type and with
/// the RD and AD header bits set.
fn mk_request_msg(rtype: Rtype) -> QuestionBuilder<Vec<u8>> {
let mut msg = MessageBuilder::new_vec();
msg.header_mut().set_rd(true);
msg.header_mut().set_ad(true);
let mut msg = msg.question();
msg.push((Name::vec_from_str("example.com").unwrap(), rtype))
.unwrap();
msg
}
/// Creates the `demo-key` SHA-256 TSIG key shared by the client under test
/// and the mock servers.
fn mk_tsig_key() -> Arc<Key> {
let key_name = KeyName::from_str("demo-key").unwrap();
let secret = crate::utils::base64::decode::<Vec<u8>>(
"zlCZbVJPIhobIs1gJNQfrsS3xCxxsR9pMUrGwG8OgG8=",
)
.unwrap();
Arc::new(
Key::new(Algorithm::Sha256, &secret, key_name, None, None)
.unwrap(),
)
}
/// Mock single-response handle.
///
/// It checks the signed request with a `ServerTransaction` and answers with
/// a signed NOERROR response, which is optionally tampered with.
#[derive(Debug)]
struct MockGetResponse<CR, KS> {
// The signed request, as handed to the upstream.
request_msg: CR,
// The keys the mock server accepts.
key_store: KS,
// If set, the response rcode is changed after signing.
invalidate_signature: bool,
}
impl<CR, KS> MockGetResponse<CR, KS> {
/// Creates the mock handle for one request.
fn new(
request_msg: CR,
key_store: KS,
invalidate_signature: bool,
) -> Self {
Self {
request_msg,
key_store,
invalidate_signature,
}
}
}
impl<CR: ComposeRequest + Debug, KS: Debug + KeyStore> GetResponse
for MockGetResponse<CR, KS>
{
/// Checks the request's TSIG and returns a signed response.
fn get_response(
&mut self,
) -> Pin<
Box<
dyn Future<Output = Result<Message<Bytes>, Error>>
+ Send
+ Sync
+ '_,
>,
> {
// Composing the request here is what makes the client sign it.
let mut req = self.request_msg.to_message().unwrap();
let tsig = ServerTransaction::request(
&self.key_store,
&mut req,
Time48::now(),
)
.unwrap()
.unwrap();
let builder = MessageBuilder::new_bytes();
let builder = builder.start_answer(&req, Rcode::NOERROR).unwrap();
let mut builder = builder.additional();
tsig.answer(&mut builder, Time48::now()).unwrap();
// Changing the header after signing makes the MAC no longer match.
if self.invalidate_signature {
builder.header_mut().set_rcode(Rcode::SERVFAIL);
}
let res = builder.into_message();
assert_eq!(res.header_counts().arcount(), 1, "Constructed response lacks a TSIG RR in the additional section");
Box::pin(ready(Ok(res)))
}
}
/// Mock multi-response handle.
///
/// It returns three responses signed with a `ServerSequence` and then
/// reports the end of the stream.
#[derive(Debug)]
struct MockGetResponseMulti<CR, KS> {
// The signed request, as handed to the upstream.
request_msg: CR,
// The keys the mock server accepts.
key_store: KS,
// The request as composed on the first call, reused to build answers.
sent_request: Option<Message<Vec<u8>>>,
// How many responses have been produced so far (at most 3).
num_responses_generated: usize,
// The server-side TSIG sequence, created on the first call.
signer: Option<ServerSequence<KS>>,
// If set, the second response is changed after signing.
invalidate_signature: bool,
// If set, the third response is not signed.
dont_sign_last_response: bool,
}
impl<CR, KS> MockGetResponseMulti<CR, KS> {
/// Creates the mock handle for one streaming request.
fn new(
request_msg: CR,
key_store: KS,
invalidate_signature: bool,
dont_sign_last_response: bool,
) -> Self {
Self {
request_msg,
key_store,
sent_request: None,
num_responses_generated: 0,
signer: None,
invalidate_signature,
dont_sign_last_response,
}
}
}
impl<CR, KS> GetResponseMulti for MockGetResponseMulti<CR, KS>
where
CR: ComposeRequestMulti + Debug,
KS: Debug + KeyStore<Key = KS> + AsRef<Key>,
{
/// Returns the next of three responses, then `None`.
fn get_response(
&mut self,
) -> Pin<
Box<
dyn Future<Output = Result<Option<Message<Bytes>>, Error>>
+ Send
+ Sync
+ '_,
>,
> {
// After three responses, report the end of the stream.
if self.num_responses_generated == 3 {
return Box::pin(ready(Ok(None)));
}
self.num_responses_generated += 1;
// Reuse the server sequence, or create it from the request (which
// the client signs when it is composed) on the first call.
let mut tsig = match self.signer.take() {
Some(tsig) => tsig,
None => {
let mut req = self.request_msg.to_message().unwrap();
let tsig = ServerSequence::request(
&self.key_store,
&mut req,
Time48::now(),
)
.unwrap()
.unwrap();
self.sent_request = Some(req);
tsig
}
};
let req = self.sent_request.as_ref().unwrap();
let builder = MessageBuilder::new_bytes();
let builder = builder.start_answer(req, Rcode::NOERROR).unwrap();
let mut builder = builder.additional();
// Decide per response whether to sign it and whether to tamper
// with it afterwards.
let (sign, invalidate) = match self.num_responses_generated {
1 => (true, false),
2 => (true, self.invalidate_signature),
3 => (!self.dont_sign_last_response, false),
_ => unreachable!(),
};
eprintln!(
"Response {}: sign={}, invalidate={}",
self.num_responses_generated, sign, invalidate
);
if sign {
tsig.answer(&mut builder, Time48::now()).unwrap();
}
self.signer = Some(tsig);
// Changing the header after signing makes the MAC no longer match.
if invalidate {
builder.header_mut().set_rcode(Rcode::SERVFAIL);
}
let res = builder.into_message();
if sign {
assert_eq!(res.header_counts().arcount(), 1, "Constructed response lacks a TSIG RR in the additional section");
let rec = res.additional().unwrap().next().unwrap().unwrap();
assert_eq!(rec.rtype(), Rtype::TSIG);
}
Box::pin(ready(Ok(Some(res))))
}
}
/// Mock single-response upstream that hands out `MockGetResponse`s.
struct MockUpstream {
// The key the mock server accepts.
key: Arc<Key>,
// Passed on to each `MockGetResponse`.
invalidate_signature: bool,
}
impl MockUpstream {
/// Creates the mock upstream.
fn new(key: Arc<Key>, invalidate_signature: bool) -> Self {
Self {
key,
invalidate_signature,
}
}
}
impl<CR: ComposeRequest + Debug + Send + Sync + 'static> SendRequest<CR>
for MockUpstream
{
/// Returns a mock handle that answers `request_msg`.
fn send_request(
&self,
request_msg: CR,
) -> Box<dyn GetResponse + Send + Sync> {
Box::new(MockGetResponse::new(
request_msg,
self.key.clone(),
self.invalidate_signature,
))
}
}
/// Mock multi-response upstream that hands out `MockGetResponseMulti`s.
struct MockUpstreamMulti {
// The key the mock server accepts.
key: Arc<Key>,
// Passed on to each `MockGetResponseMulti`.
invalidate_signature: bool,
// Passed on to each `MockGetResponseMulti`.
dont_sign_last_response: bool,
}
impl MockUpstreamMulti {
/// Creates the mock upstream.
fn new(
key: Arc<Key>,
invalidate_signature: bool,
dont_sign_last_response: bool,
) -> Self {
Self {
key,
invalidate_signature,
dont_sign_last_response,
}
}
}
impl<CR> SendRequestMulti<CR> for MockUpstreamMulti
where
CR: ComposeRequestMulti + Debug + Send + Sync + 'static,
{
/// Returns a mock handle that streams responses to `request_msg`.
fn send_request(
&self,
request_msg: CR,
) -> Box<dyn GetResponseMulti + Send + Sync> {
Box::new(MockGetResponseMulti::new(
request_msg,
self.key.clone(),
self.invalidate_signature,
self.dont_sign_last_response,
))
}
}
}