use anyhow::{bail, Result};
use bytes::{BufMut, Bytes, BytesMut};
use ip_network_table::IpNetworkTable;
use log::*;
use std::net::{IpAddr, SocketAddr};
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::time::Duration;
use thiserror::Error;
use tokio::net::UdpSocket;
use tokio::time::timeout;
use dnstap_utils::dnstap;
use dnstap_utils::util::dns_message_is_truncated;
use dnstap_utils::util::try_from_u8_slice_for_ipaddr;
use dnstap_utils::util::DnstapHandlerError;
use crate::{Channels, Opts};
/// How long to wait for the DNS server's response to a replayed query before the query is
/// treated as timed out.
const DNS_QUERY_TIMEOUT: Duration = Duration::from_secs(5);
/// Size in bytes of the buffer that DNS responses are received into.
const DNS_RESPONSE_BUFFER_SIZE: usize = 4 * 1024;
/// Replays DNS queries recorded in dnstap payloads against a DNS server over UDP and compares
/// the server's responses with the responses recorded in the payloads.
pub struct DnstapHandler {
/// Copy of the program options this handler was created with.
opts: Opts,
/// Channel endpoints: dnstap payloads arrive on `receiver`; payloads that produced an error or
/// a timeout are forwarded on `error_sender` / `timeout_sender`.
channels: Channels,
/// Table of query networks whose payloads are skipped, or `None` when
/// `opts.ignore_query_net` is empty.
ignore_query_nets: Option<IpNetworkTable<bool>>,
/// Shared flag: responses are only compared while it reads `true`; otherwise the comparison
/// is counted as suppressed.
match_status: Arc<AtomicBool>,
/// UDP socket connected to `opts.dns`, or `None` while no socket is set up.
socket: Option<UdpSocket>,
/// Reusable buffer that DNS responses are received into.
recv_buf: [u8; DNS_RESPONSE_BUFFER_SIZE],
}
/// Errors private to this handler. `process_dnstap` only counts them in a metric; the payload
/// is not forwarded to the error or timeout channel.
#[derive(Error, Debug)]
enum DnstapHandlerInternalError {
/// The dnstap message declared a socket protocol other than UDP.
#[error("Non-UDP dnstap payload was discarded")]
DiscardNonUdp,
}
/// Seconds/nanoseconds timestamp that `add_proxy_payload` serializes into its custom PROXY v2
/// TLV.
#[derive(Debug)]
struct Timespec {
/// Whole seconds; populated from the dnstap message's `response_time_sec()`.
pub seconds: u64,
/// Nanoseconds part; populated from the dnstap message's `response_time_nsec()`.
pub nanoseconds: u32,
}
impl DnstapHandler {
/// Creates a handler and immediately sets up its UDP socket to the DNS server.
///
/// `opts` and `channels` are cloned into the handler. When `opts.ignore_query_net` lists any
/// networks, they are loaded into a lookup table used to skip payloads from those networks.
///
/// # Errors
///
/// Returns an error if the UDP socket cannot be bound, configured or connected.
pub async fn new(
    opts: &Opts,
    channels: &Channels,
    match_status: Arc<AtomicBool>,
) -> Result<Self> {
    // No configured networks means no table at all, so the per-payload lookup is skipped.
    let ignore_query_nets = if opts.ignore_query_net.is_empty() {
        None
    } else {
        let mut nets = IpNetworkTable::new();
        for net in &opts.ignore_query_net {
            nets.insert(*net, true);
        }
        Some(nets)
    };

    let mut this = DnstapHandler {
        opts: opts.clone(),
        channels: channels.clone(),
        ignore_query_nets,
        match_status,
        socket: None,
        recv_buf: [0; DNS_RESPONSE_BUFFER_SIZE],
    };
    this.maybe_setup_socket().await?;

    Ok(this)
}
/// Binds and connects the UDP socket used for DNS queries, unless one already exists.
///
/// The socket is bound to the wildcard address (port 0) of the same address family as
/// `opts.dns`, optionally gets the configured DSCP value applied, and is then connected to
/// `opts.dns`.
///
/// # Errors
///
/// Returns an error if binding, setting the DSCP value, or connecting fails.
async fn maybe_setup_socket(&mut self) -> Result<()> {
    // Nothing to do when a socket is already in place.
    if self.socket.is_some() {
        return Ok(());
    }

    let wildcard = if self.opts.dns.is_ipv4() {
        "0.0.0.0:0"
    } else {
        "[::]:0"
    };
    let bind_addr: SocketAddr = wildcard.parse()?;

    let udp = UdpSocket::bind(bind_addr).await?;
    if let Some(dscp) = self.opts.dscp {
        set_udp_dscp(&udp, dscp)?;
    }
    udp.connect(&self.opts.dns).await?;
    debug!("Connected socket to DNS server: {:?}", &udp);

    self.socket = Some(udp);
    Ok(())
}
/// Discards the current UDP socket (if any) and sets up a fresh one.
///
/// # Errors
///
/// Returns an error if the replacement socket cannot be set up.
async fn restart_socket(&mut self) -> Result<()> {
    // Drop the old socket first so that `maybe_setup_socket` creates a new one.
    drop(self.socket.take());
    self.maybe_setup_socket().await
}
/// Receives dnstap payloads from the channel and processes them one at a time.
///
/// The loop ends, and `Ok(())` is returned, as soon as receiving from the channel fails.
///
/// # Errors
///
/// Returns an error if the socket cannot be set up or if processing a payload returns an
/// error.
pub async fn run(&mut self) -> Result<()> {
    loop {
        let payload = match self.channels.receiver.recv().await {
            Ok(payload) => payload,
            Err(_) => break,
        };
        // Make sure a socket exists before each payload is processed.
        self.maybe_setup_socket().await?;
        self.process_dnstap(payload).await?;
    }
    Ok(())
}
async fn process_dnstap(&mut self, mut d: dnstap::Dnstap) -> Result<()> {
if dnstap::dnstap::Type::from_i32(d.r#type) != Some(dnstap::dnstap::Type::Message) {
return Ok(());
}
let msg = match &d.message {
Some(msg) => msg,
None => return Ok(()),
};
if dnstap::message::Type::from_i32(msg.r#type) != Some(dnstap::message::Type::AuthResponse)
{
return Ok(());
}
match self.process_dnstap_message(msg).await {
Ok(_) => {
crate::metrics::DNSTAP_PAYLOADS.success.inc();
}
Err(e) => {
crate::metrics::DNSTAP_PAYLOADS.error.inc();
if let Some(e) = e.downcast_ref::<DnstapHandlerError>() {
d.extra = Some(e.serialize().to_vec());
match e {
DnstapHandlerError::Mismatch(_, _, _) => {
self.send_error(d);
crate::metrics::DNS_COMPARISONS.mismatched.inc();
}
DnstapHandlerError::Timeout => {
self.send_timeout(d);
crate::metrics::DNS_QUERIES.timeout.inc();
self.restart_socket().await?;
}
DnstapHandlerError::MissingField => {
self.send_error(d);
}
}
} else if let Some(e) = e.downcast_ref::<DnstapHandlerInternalError>() {
match e {
DnstapHandlerInternalError::DiscardNonUdp => {
crate::metrics::DNSTAP_HANDLER_INTERNAL_ERRORS
.discard_non_udp
.inc();
}
}
}
}
}
Ok(())
}
/// Replays the DNS query recorded in `msg` against the DNS server and compares the server's
/// response with the response recorded in `msg`.
///
/// Steps:
/// 1. Require a UDP `socket_protocol` and the `query_message`, `response_message` and
///    `query_address` fields.
/// 2. Skip the payload if `query_address` falls inside an ignored query network.
/// 3. Send the recorded query, prefixed with a PROXY v2 header when `opts.proxy` is set.
/// 4. Wait up to `DNS_QUERY_TIMEOUT` for a response and, if `match_status` is set, compare
///    it byte-for-byte with the recorded response.
///
/// # Errors
///
/// * A plain error if no socket is available, or if sending/receiving on the socket fails.
/// * [`DnstapHandlerInternalError::DiscardNonUdp`] if the payload is not UDP.
/// * [`DnstapHandlerError::MissingField`] if a required field is absent.
/// * [`DnstapHandlerError::Timeout`] if no response arrives within `DNS_QUERY_TIMEOUT`.
/// * [`DnstapHandlerError::Mismatch`] if the responses differ and the difference is not
///   excused by `opts.ignore_tc`.
async fn process_dnstap_message(&mut self, msg: &dnstap::Message) -> Result<()> {
    let socket = match &self.socket {
        Some(socket) => socket,
        None => {
            bail!("No connected socket to send DNS queries");
        }
    };

    // Only payloads recorded over UDP are replayed.
    match &msg.socket_protocol {
        Some(socket_protocol) => {
            if dnstap::SocketProtocol::from_i32(*socket_protocol)
                != Some(dnstap::SocketProtocol::Udp)
            {
                bail!(DnstapHandlerInternalError::DiscardNonUdp);
            }
        }
        None => bail!(DnstapHandlerError::MissingField),
    };

    // A payload without the recorded query or response cannot be replayed or compared.
    // Report it as a missing field, like every other required field, instead of returning
    // `Ok(())`, which the caller would count as a successfully processed payload.
    let query_message = match &msg.query_message {
        Some(msg) => msg,
        None => bail!(DnstapHandlerError::MissingField),
    };
    let response_message = match &msg.response_message {
        Some(msg) => msg,
        None => bail!(DnstapHandlerError::MissingField),
    };

    let query_address = match &msg.query_address {
        Some(addr) => try_from_u8_slice_for_ipaddr(addr)?,
        None => bail!(DnstapHandlerError::MissingField),
    };

    // Skip payloads whose query address is inside one of the ignored networks.
    if let Some(table) = &self.ignore_query_nets {
        if table.longest_match(query_address).is_some() {
            crate::metrics::DNS_COMPARISONS.query_net_ignored.inc();
            return Ok(());
        };
    };

    // Build the datagram: optional PROXY v2 header followed by the recorded query.
    let mut buf = BytesMut::with_capacity(1024);
    if self.opts.proxy {
        let timespec = if self.opts.proxy_timespec {
            Some(Timespec {
                seconds: msg.response_time_sec(),
                nanoseconds: msg.response_time_nsec(),
            })
        } else {
            None
        };
        add_proxy_payload(&mut buf, msg, &query_address, timespec)?;
    }
    buf.put_slice(query_message);
    let buf = buf.freeze();

    trace!("Sending DNS query: {}", hex::encode(&buf));
    socket.send(&buf).await?;

    match timeout(DNS_QUERY_TIMEOUT, socket.recv(&mut self.recv_buf)).await {
        Ok(res) => match res {
            Ok(n_bytes) => {
                crate::metrics::DNS_QUERIES.success.inc();
                let received_message = &self.recv_buf[..n_bytes];
                trace!("Received DNS response: {}", hex::encode(received_message));

                if self.match_status.load(Ordering::Relaxed) {
                    if response_message == received_message {
                        crate::metrics::DNS_COMPARISONS.matched.inc();
                    } else if self.opts.ignore_tc
                        && (dns_message_is_truncated(response_message)
                            || dns_message_is_truncated(received_message))
                    {
                        // Differences involving a truncated message are tolerated when
                        // `ignore_tc` is enabled.
                        crate::metrics::DNS_COMPARISONS.udp_tc_ignored.inc();
                    } else {
                        // NOTE(review): the two hex strings are passed as (received,
                        // recorded). The `Mismatch` variant is defined outside this file;
                        // confirm this matches the order its fields/message expect.
                        bail!(DnstapHandlerError::Mismatch(
                            Bytes::copy_from_slice(received_message),
                            hex::encode(received_message),
                            hex::encode(response_message),
                        ));
                    }
                } else {
                    // Comparison is currently switched off.
                    crate::metrics::DNS_COMPARISONS.suppressed.inc();
                }
            }
            Err(e) => {
                crate::metrics::DNS_QUERIES.error.inc();
                bail!(e);
            }
        },
        Err(_) => {
            bail!(DnstapHandlerError::Timeout);
        }
    }
    Ok(())
}
/// Forwards `d` to the error channel without blocking and counts whether the send succeeded.
///
/// A failed `try_send` is only recorded in the error metric; the payload is dropped.
fn send_error(&self, d: dnstap::Dnstap) {
    if self.channels.error_sender.try_send(d).is_ok() {
        crate::metrics::CHANNEL_ERROR_TX.success.inc();
    } else {
        crate::metrics::CHANNEL_ERROR_TX.error.inc();
    }
}
/// Forwards `d` to the timeout channel without blocking and counts whether the send succeeded.
///
/// A failed `try_send` is only recorded in the error metric; the payload is dropped.
fn send_timeout(&self, d: dnstap::Dnstap) {
    if self.channels.timeout_sender.try_send(d).is_ok() {
        crate::metrics::CHANNEL_TIMEOUT_TX.success.inc();
    } else {
        crate::metrics::CHANNEL_TIMEOUT_TX.error.inc();
    }
}
}
fn add_proxy_payload(
buf: &mut BytesMut,
msg: &dnstap::Message,
query_address: &IpAddr,
timespec: Option<Timespec>,
) -> Result<()> {
const PP2_TYPE_CUSTOM_TIMESPEC: u8 = 0xEA;
const PP2_CUSTOM_TIMESPEC_SIZE: u16 = 3 + 8 + 4;
let needed_tlv_size = if timespec.is_some() {
PP2_CUSTOM_TIMESPEC_SIZE
} else {
0
};
let query_port = match &msg.query_port {
Some(port) => *port as u16,
None => bail!(DnstapHandlerError::MissingField),
};
buf.put(&b"\x0D\x0A\x0D\x0A\x00\x0D\x0A\x51\x55\x49\x54\x0A"[..]);
buf.put_u8(0x21);
match query_address {
IpAddr::V4(addr) => {
buf.put_u8(0x12);
buf.put_u16(12 + needed_tlv_size);
buf.put_slice(&addr.octets());
buf.put_u32(0);
}
IpAddr::V6(addr) => {
buf.put_u8(0x22);
buf.put_u16(36 + needed_tlv_size);
buf.put_slice(&addr.octets());
buf.put_u128(0);
}
};
buf.put_u16(query_port);
buf.put_u16(53);
if let Some(timespec) = timespec {
trace!("Sending PROXY v2 custom TLV: {timespec:?}");
buf.put_u8(PP2_TYPE_CUSTOM_TIMESPEC);
buf.put_u8(0);
buf.put_u8(12);
buf.put_u64_le(timespec.seconds);
buf.put_u32_le(timespec.nanoseconds);
}
Ok(())
}
/// Sets the DSCP value used for packets sent on `s`.
///
/// The DSCP value occupies the upper six bits of the IPv4 TOS / IPv6 traffic class byte, so
/// `dscp` is shifted left by two bits and applied via `IP_TOS` or `IPV6_TCLASS`, depending on
/// the address family of the socket's local address.
///
/// # Errors
///
/// Returns an error if `dscp` is larger than 63, if the socket's local address cannot be
/// determined, or if `setsockopt` fails.
#[cfg(unix)]
fn set_udp_dscp(s: &UdpSocket, dscp: u8) -> Result<()> {
    use std::os::unix::io::AsRawFd;

    // DSCP is a 6-bit field. Larger values would lose their top bits when shifted and would
    // silently configure a different code point, so reject them up front.
    const DSCP_MAX: u8 = 0x3F;
    if dscp > DSCP_MAX {
        bail!(
            "Invalid DSCP value {}: must be in the range 0-{}",
            dscp,
            DSCP_MAX
        );
    }

    let raw_fd = s.as_raw_fd();

    // Widen before shifting so the shift is performed in `c_int`, not in `u8`.
    let optval: libc::c_int = libc::c_int::from(dscp) << 2;
    let optlen = std::mem::size_of_val(&optval) as libc::socklen_t;

    let (level, name) = match s.local_addr()? {
        SocketAddr::V4(_) => (libc::IPPROTO_IP, libc::IP_TOS),
        SocketAddr::V6(_) => (libc::IPPROTO_IPV6, libc::IPV6_TCLASS),
    };

    // SAFETY: `raw_fd` belongs to `s`, which is borrowed for the whole call and therefore
    // still open. The option pointer refers to `optval`, a live `c_int` on this stack frame,
    // and `optlen` is exactly its size.
    let ret = unsafe {
        libc::setsockopt(
            raw_fd,
            level,
            name,
            &optval as *const libc::c_int as *const libc::c_void,
            optlen,
        )
    };

    if ret != 0 {
        // Read errno right away, before any other call can overwrite it.
        let os_error = std::io::Error::last_os_error();
        bail!(
            "Failed to set DSCP value {} on socket fd {}: {}",
            dscp,
            raw_fd,
            os_error
        );
    }
    Ok(())
}
/// Fallback for non-Unix targets, where setting a DSCP value is not implemented.
///
/// # Errors
///
/// Always returns an error.
#[cfg(not(unix))]
fn set_udp_dscp(_s: &UdpSocket, _dscp: u8) -> Result<()> {
    Err(anyhow::anyhow!("Cannot set DSCP values on this platform"))
}