use core::future::{ready, Ready};
use core::marker::PhantomData;
use std::string::{String, ToString};
use futures_util::stream::Once;
use octseq::{Octets, OctetsBuilder};
use tracing::warn;
use crate::base::iana::OptRcode;
use crate::base::message_builder::{
AdditionalBuilder, OptBuilder, PushError,
};
use crate::base::wire::Composer;
use crate::base::Message;
use crate::base::{MessageBuilder, ParsedName, Rtype, StreamTarget};
use crate::rdata::AllRecordData;
use crate::utils::base16;
use super::message::Request;
use super::service::{Service, ServiceResult};
/// Creates an empty [`MessageBuilder`] whose target is a default-constructed
/// `Target` wrapped in a [`StreamTarget`].
///
/// # Panics
///
/// Panics (treated as an internal error) if the default `Target` cannot be
/// wrapped in a [`StreamTarget`], or if the resulting stream target cannot be
/// turned into a [`MessageBuilder`].
pub fn mk_builder_for_target<Target>() -> MessageBuilder<StreamTarget<Target>>
where
    Target: Composer + Default,
{
    let empty_target = Target::default();

    // The error type is discarded (mapped to `()`) so no `Debug` bound is
    // needed on it for `expect`.
    let stream_target = StreamTarget::new(empty_target)
        .map_err(drop)
        .expect("Internal error: Unable to create new target.");

    let builder = MessageBuilder::from_target(stream_target);
    builder.expect(
        "Internal error: Unable to convert target to message builder.",
    )
}
/// Wraps a plain function or closure as a [`ServiceFn`].
///
/// For every request, the returned service calls `request_handler` with the
/// request and a clone of `metadata`. It yields the handler's
/// [`ServiceResult`] as a single-item stream.
pub fn service_fn<RequestOctets, Target, T, RequestMeta, Metadata>(
    request_handler: T,
    metadata: Metadata,
) -> ServiceFn<Target, T, Metadata>
where
    T: Fn(
            Request<RequestOctets, RequestMeta>,
            Metadata,
        ) -> ServiceResult<Target>
        + Clone,
    Target: Composer + Default,
    Metadata: Clone,
    RequestMeta: Clone + Default,
    RequestOctets: AsRef<[u8]> + Send + Sync + Unpin,
{
    ServiceFn {
        metadata,
        request_handler,
        // `Target` only appears in the handler's return type; remember it
        // at the type level.
        _phantom: PhantomData::<Target>,
    }
}
/// A [`Service`] implementation backed by a plain function or closure.
///
/// Instances are created with [`service_fn`]. On each call the wrapped
/// `request_handler` is invoked with the request and a clone of `metadata`.
#[derive(Clone, Debug)]
pub struct ServiceFn<Target, T, Metadata> {
    /// The function invoked for every request.
    request_handler: T,
    /// Extra data cloned and passed to `request_handler` on every call.
    metadata: Metadata,
    /// Ties the response target type to this struct, since `Target` is only
    /// mentioned indirectly through the handler's return type.
    _phantom: PhantomData<Target>,
}
/// Implements [`Service`] for [`ServiceFn`] by calling the wrapped handler
/// synchronously.
///
/// The result is wrapped in an already-completed future that resolves to a
/// stream yielding exactly one item.
impl<RequestOctets, Target, RequestMeta, T, Metadata>
    Service<RequestOctets, RequestMeta> for ServiceFn<Target, T, Metadata>
where
    RequestOctets: AsRef<[u8]> + Send + Sync + Unpin,
    RequestMeta: Default + Clone,
    Metadata: Clone,
    Target: Composer + Default + Send + Sync,
    T: Fn(
            Request<RequestOctets, RequestMeta>,
            Metadata,
        ) -> ServiceResult<Target>
        + Clone,
    Self: Clone + Send + Sync + 'static,
{
    type Target = Target;
    type Stream = Once<Ready<ServiceResult<Self::Target>>>;
    type Future = Ready<Self::Stream>;

    fn call(
        &self,
        request: Request<RequestOctets, RequestMeta>,
    ) -> Self::Future {
        // Each invocation gets its own copy of the metadata.
        let per_call_metadata = self.metadata.clone();
        let handler = &self.request_handler;
        let outcome = handler(request, per_call_metadata);

        // A single-item stream over an already-resolved future, itself
        // wrapped in an already-resolved future.
        let single_item = futures_util::stream::once(ready(outcome));
        ready(single_item)
    }
}
/// Formats the first `num_bytes` bytes of `bytes` as a single line of
/// pcap-style hex text.
///
/// The output starts with the offset `000000`, followed by each byte as a
/// space-separated pair of hex digits produced by [`base16`].
///
/// If `num_bytes` is larger than the input, only the available bytes are
/// formatted. Previously this case panicked on an out-of-range slice.
pub(crate) fn to_pcap_text<T: AsRef<[u8]>>(
    bytes: T,
    num_bytes: usize,
) -> String {
    let bytes = bytes.as_ref();
    // Clamp instead of slicing blindly so an over-long `num_bytes` cannot
    // panic.
    let bytes = &bytes[..num_bytes.min(bytes.len())];

    let hex_encoded = base16::encode_string(bytes);

    // "000000" prefix plus " XX" per byte.
    let mut formatted = String::with_capacity(6 + bytes.len() * 3);
    formatted.push_str("000000");

    // Hex encoding yields two ASCII digits per input byte; emit them as
    // space-separated pairs.
    for pair in hex_encoded.as_bytes().chunks(2) {
        formatted.push(' ');
        formatted.extend(pair.iter().copied().map(char::from));
    }

    formatted
}
/// Builds an error response to `msg` with the given (possibly extended)
/// response code.
///
/// The header RCODE is set from `rcode.rcode()`. The full `rcode` is then
/// applied through the response's OPT record via [`add_edns_options`].
/// A failure at that step is only logged as a warning, and the builder is
/// returned positioned at the additional section either way.
pub fn mk_error_response<RequestOctets, Target>(
    msg: &Message<RequestOctets>,
    rcode: OptRcode,
) -> AdditionalBuilder<StreamTarget<Target>>
where
    RequestOctets: Octets,
    Target: Composer + Default,
{
    let builder = mk_builder_for_target::<Target>();
    let error_builder = builder.start_error(msg, rcode.rcode());
    let mut response = error_builder.additional();

    let outcome = add_edns_options(&mut response, |opt| {
        opt.set_rcode(rcode);
        Ok(())
    });

    match outcome {
        Ok(()) => {}
        Err(err) => {
            warn!("Failed to set (extended) error '{rcode}' in response: {err}");
        }
    }

    response
}
/// Adds EDNS options to the OPT record of `response`, creating the OPT
/// record if the response does not have one yet.
///
/// `op` is called with an [`OptBuilder`] for the response's OPT record.
///
/// If the response already contains an OPT record, the additional section
/// is rebuilt in three steps:
/// 1. It is rewound.
/// 2. Every non-OPT record is pushed again.
/// 3. A new OPT record is started, cloned from the previous one, before `op`
///    runs.
///
/// As a result, options added by `op` are appended to the existing ones
/// rather than replacing them.
///
/// # Errors
///
/// Returns a [`PushError`] if a record or the OPT record cannot be pushed
/// into the response.
///
/// If the current response cannot be re-parsed as a message, a warning is
/// logged and `Ok(())` is returned without calling `op`.
pub fn add_edns_options<F, Target>(
    response: &mut AdditionalBuilder<StreamTarget<Target>>,
    op: F,
) -> Result<(), PushError>
where
    F: FnOnce(
        &mut OptBuilder<'_, StreamTarget<Target>>,
    ) -> Result<
        (),
        <StreamTarget<Target> as OctetsBuilder>::AppendError,
    >,
    Target: Composer,
{
    if response.counts().arcount() > 0
        && response.as_message().opt().is_some()
    {
        // The existing OPT record cannot be extended in place. Snapshot the
        // message so the additional section can be rebuilt from the copy.
        let copied_response = response.as_slice().to_vec();
        let Ok(copied_response) = Message::from_octets(&copied_response)
        else {
            warn!("Internal error: Unable to create message from octets while adding EDNS option");
            return Ok(());
        };

        if let Some(current_opt) = copied_response.opt() {
            // Remove all records from the additional section...
            response.rewind();

            // ...restore every record except the OPT record. Records that
            // fail to convert are skipped...
            if let Ok(current_additional) = copied_response.additional() {
                for rr in current_additional.flatten() {
                    if rr.rtype() != Rtype::OPT {
                        if let Ok(Some(rr)) = rr
                            .into_record::<AllRecordData<_, ParsedName<_>>>()
                        {
                            response.push(rr)?;
                        }
                    }
                }
            }

            // ...and finally write a new OPT record seeded with the old
            // one's contents before letting `op` add to it.
            return response.opt(|builder| {
                builder.clone_from(&current_opt)?;
                op(builder)
            });
        }
    }

    response.opt(|builder| op(builder))
}
pub fn remove_edns_opt_record<Target>(
response: &mut AdditionalBuilder<StreamTarget<Target>>,
) -> Result<(), PushError>
where
Target: Composer,
{
if response.counts().arcount() > 0
&& response.as_message().opt().is_some()
{
let copied_response = response.as_slice().to_vec();
let Ok(copied_response) = Message::from_octets(&copied_response)
else {
warn!("Internal error: Unable to create message from octets while adding EDNS option");
return Ok(());
};
if copied_response.opt().is_some() {
response.rewind();
if let Ok(current_additional) = copied_response.additional() {
for rr in current_additional.flatten() {
if rr.rtype() != Rtype::OPT {
if let Ok(Some(rr)) = rr
.into_record::<AllRecordData<_, ParsedName<_>>>()
{
response.push(rr)?;
}
}
}
}
}
}
Ok(())
}
// Unit tests for the EDNS OPT record helpers in this module.
#[cfg(test)]
mod tests {
    use bytes::Bytes;
    use tokio::time::Instant;
    use crate::base::{Message, MessageBuilder, Name, Rtype, StreamTarget};
    use crate::net::server::message::{Request, UdpTransportContext};
    use crate::base::iana::{OptRcode, Rcode};
    use crate::base::message_builder::AdditionalBuilder;
    use crate::base::opt::UnknownOptData;
    use crate::base::wire::Composer;
    use crate::net::server::util::{
        add_edns_options, mk_builder_for_target, remove_edns_opt_record,
    };
    use std::vec::Vec;

    // Checks that `add_edns_options` appends options to an existing OPT
    // record and preserves the extended rcode already set on it.
    #[test]
    fn test_add_edns_option() {
        // Build a minimal query (root, A) and wrap it in a server request.
        let query = MessageBuilder::new_vec();
        let mut query = query.question();
        query.push((Name::<Bytes>::root(), Rtype::A)).unwrap();
        let msg = query.into_message();
        let client_ip = "127.0.0.1:12345".parse().unwrap();
        let sent_at = Instant::now();
        let ctx = UdpTransportContext::default();
        let request = Request::new(client_ip, sent_at, msg, ctx.into(), ());

        // A fresh answer starts with an empty additional section.
        let reply = mk_builder_for_target::<Vec<u8>>()
            .start_answer(request.message(), Rcode::NOERROR)
            .unwrap();
        assert_eq!(reply.counts().arcount(), 0);
        assert_eq!(reply.header().rcode(), Rcode::NOERROR);

        // Add an OPT record carrying the extended BADCOOKIE rcode.
        let mut reply = reply.additional();
        reply
            .opt(|builder| {
                builder.set_rcode(OptRcode::BADCOOKIE);
                builder.set_udp_payload_size(123);
                Ok(())
            })
            .unwrap();
        assert_eq!(reply.counts().arcount(), 1);

        // The header only holds the low four bits of the extended rcode,
        // which the test expects to be 0b0111 for BADCOOKIE.
        let expected_rcode = Rcode::checked_from_int(0b0111).unwrap();
        assert_eq!(reply.header().rcode(), expected_rcode);
        let response = assert_opt(
            reply.clone(),
            expected_rcode,
            Some(OptRcode::BADCOOKIE),
        );
        let opt = response.opt().unwrap();
        let options = opt.opt();
        assert_eq!(options.len(), 0);

        // First addition: the OPT record now carries one option and the
        // extended rcode is unchanged.
        add_edns_options(&mut reply, |builder| builder.padding(123)).unwrap();
        let response = assert_opt(
            reply.clone(),
            expected_rcode,
            Some(OptRcode::BADCOOKIE),
        );
        let opt = response.opt().unwrap();
        let options = opt.opt();
        assert_eq!(options.iter::<UnknownOptData<_>>().count(), 1);

        // Second addition: options accumulate rather than being replaced.
        add_edns_options(&mut reply, |builder| builder.padding(123)).unwrap();
        let response = assert_opt(
            reply.clone(),
            expected_rcode,
            Some(OptRcode::BADCOOKIE),
        );
        let opt = response.opt().unwrap();
        let options = opt.opt();
        assert_eq!(options.iter::<UnknownOptData<_>>().count(), 2);
    }

    // Checks that `remove_edns_opt_record` removes the only additional
    // record (the OPT record) from a response.
    #[test]
    fn test_remove_edns_opt_record() {
        // Build a minimal query (root, A) and wrap it in a server request.
        let query = MessageBuilder::new_vec();
        let mut query = query.question();
        query.push((Name::<Bytes>::root(), Rtype::A)).unwrap();
        let msg = query.into_message();
        let client_ip = "127.0.0.1:12345".parse().unwrap();
        let sent_at = Instant::now();
        let ctx = UdpTransportContext::default();
        let request = Request::new(client_ip, sent_at, msg, ctx.into(), ());

        let reply = mk_builder_for_target::<Vec<u8>>()
            .start_answer(request.message(), Rcode::NOERROR)
            .unwrap();
        assert_eq!(reply.counts().arcount(), 0);

        // Add an OPT record with a padding option.
        let mut reply = reply.additional();
        reply.opt(|builder| builder.padding(32)).unwrap();
        assert_eq!(reply.counts().arcount(), 1);
        assert_opt(reply.clone(), Rcode::NOERROR, Some(OptRcode::NOERROR));

        // After removal, no OPT record (and no additional records) remain.
        remove_edns_opt_record(&mut reply).unwrap();
        assert_opt(reply.clone(), Rcode::NOERROR, None);
    }

    // Finishes `reply`, re-parses it, and checks the header rcode plus the
    // presence or absence of an OPT record. When `expected_opt_rcode` is
    // `Some`, it also checks that the additional section holds exactly one
    // record and compares the extended rcode. Returns the parsed message for
    // further inspection.
    fn assert_opt<Target: Composer>(
        reply: AdditionalBuilder<StreamTarget<Target>>,
        expected_rcode: Rcode,
        expected_opt_rcode: Option<OptRcode>,
    ) -> Message<Vec<u8>> {
        let response = reply.finish();
        let response_bytes = response.as_dgram_slice().to_vec();
        let response = Message::from_octets(response_bytes).unwrap();
        assert_eq!(response.header().rcode(), expected_rcode);
        match expected_opt_rcode {
            Some(opt_rcode) => {
                assert_eq!(response.header_counts().arcount(), 1);
                assert!(response.opt().is_some());
                assert_eq!(response.opt_rcode(), opt_rcode);
            }
            None => {
                assert_eq!(response.header_counts().arcount(), 0);
                assert!(response.opt().is_none());
            }
        }
        response
    }
}