use tokio::net::UdpSocket;
use tokio::sync::RwLock;
use tokio::time::{Duration, Instant};
use crate::dns::dnspkt;
use crate::dns::parse;
lazy_static::lazy_static! {
// Process-wide timeout used for the first UDP attempt of an out query.
// Starts at 800 ms; `send_udp` reads it at the start of every query and
// rewrites it based on the latency of replies that needed retries.
static ref DNS_TIMEOUT: RwLock<Duration> = RwLock::new(Duration::from_millis(800));
}
/// Lower bound `send_udp` clamps the shared [`DNS_TIMEOUT`] to when it adjusts it.
const MIN_DNS_TIMEOUT: Duration = Duration::from_millis(300);
/// Upper bound `send_udp` clamps the shared [`DNS_TIMEOUT`] to when it adjusts it.
const MAX_DNS_TIMEOUT: Duration = Duration::from_millis(2000);
/// Sending half of the queue through which queries are handed to the task that
/// drives the TCP connection to one nameserver (see `TcpNameserver::start`).
type TcpNameserverChannel = tokio::sync::mpsc::Sender<TcpNameserverMessage>;
lazy_static::lazy_static! {
// Registry of per-nameserver TCP tasks, keyed by nameserver address. Entries
// are created on first use by `TcpNameserver::send_query_to`.
static ref NAMESERVER_INFO: tokio::sync::Mutex<std::collections::HashMap<std::net::SocketAddr,TcpNameserverChannel>> = Default::default();
// Counter of query packets written to the network, labelled by nameserver
// address and transport ("UDP" or "TCP").
static ref DNS_SENT_QUERIES: prometheus::IntCounterVec =
prometheus::register_int_counter_vec!("dns_out_query_packets_sent",
"Number of DNS out queries packets sent",
&["dns_server", "protocol"]
)
.unwrap();
// Histogram of out-query latency, labelled by nameserver address and transport.
static ref OUT_QUERY_LATENCY: prometheus::HistogramVec =
prometheus::register_histogram_vec!("dns_out_query_latency",
"DNS latency for out queries",
&["dns_server", "protocol"])
.unwrap();
// Counter of final query outcomes; the "result" label is produced by
// `increment_result`.
static ref OUT_QUERY_RESULT: prometheus::IntCounterVec =
prometheus::register_int_counter_vec!("dns_out_query_result",
"DNS out query results",
&["dns_server", "result"])
.unwrap();
// Counter of retries; this file uses the reasons "TIMEOUT", "KAMINSKY" and
// "TRUNCATED".
static ref OUT_QUERY_RETRY: prometheus::IntCounterVec =
prometheus::register_int_counter_vec!("dns_out_query_retries",
"DNS out query retry reasons",
&["dns_server", "reason"])
.unwrap();
// Gauge of queries currently in flight per nameserver; maintained by
// `OutQuery::handle_query`.
static ref OUT_QUERY_OUTSTANDING: prometheus::IntGaugeVec =
prometheus::register_int_gauge_vec!("dns_out_query_outstanding",
"Number of out queries currently outstanding",
&["dns_server"])
.unwrap();
// Gauge mirroring the value of `DNS_TIMEOUT` (in milliseconds) as last read by
// `send_udp`.
static ref OUT_QUERY_TIMEOUT: prometheus::IntGauge =
prometheus::register_int_gauge!("dns_out_query_timeout_ms",
"The current dynamic timeout for out queries").unwrap();
}
/// Failures that can occur while sending a query to an upstream nameserver
/// and waiting for its reply.
#[derive(Debug)]
pub enum Error {
    /// No reply arrived before the retry budget was exhausted.
    Timeout,
    /// An I/O error occurred while sending the query.
    FailedToSend(std::io::Error),
    /// Sending failed for a reason described only by a message.
    FailedToSendMsg(String),
    /// An I/O error occurred while receiving the reply.
    FailedToRecv(std::io::Error),
    /// Receiving failed for a reason described only by a message.
    FailedToRecvMsg(String),
    /// The TCP connection went away while a reply was still pending.
    TcpConnection(String),
    /// The reply could not be parsed.
    Parse(String),
    /// An internal failure in the out-query machinery itself.
    Internal(String),
}

impl std::fmt::Display for Error {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        /// Writes `label`, a colon and a space, then `detail`.
        fn labelled(
            f: &mut std::fmt::Formatter,
            label: &str,
            detail: &dyn std::fmt::Display,
        ) -> std::fmt::Result {
            write!(f, "{}: {}", label, detail)
        }

        match self {
            Self::Timeout => f.write_str("Timeout"),
            Self::FailedToSend(io) => labelled(f, "Failed to send out query", io),
            Self::FailedToSendMsg(text) => labelled(f, "Failed to send out query", text),
            Self::FailedToRecv(io) => labelled(f, "Failed to receive out query", io),
            Self::FailedToRecvMsg(text) => labelled(f, "Failed to receive out query", text),
            Self::TcpConnection(text) => {
                labelled(f, "TCP connection error while waiting for result", text)
            }
            Self::Parse(text) => labelled(f, "Failed to parse out reply", text),
            Self::Internal(text) => labelled(f, "Internal error in out query handling", text),
        }
    }
}
/// Local shorthand for the parent module's `Protocol` type.
type Protocol = super::Protocol;
fn increment_result(dns_server: &str, result: &Result<dnspkt::DNSPkt, Error>) {
OUT_QUERY_RESULT
.with_label_values(&[
dns_server,
&match result {
Ok(pkt) => match pkt
.edns
.as_ref()
.and_then(|edns| edns.get_extended_dns_error())
{
Some((code, _msg)) => format!("{} ({})", pkt.rcode, code),
_ => pkt.rcode.to_string(),
},
Err(Error::Timeout) => "TIMEOUT".into(),
Err(Error::FailedToSend(io)) => format!("SEND: {}", io),
Err(Error::FailedToRecv(io)) => format!("RECV: {}", io),
Err(Error::FailedToSendMsg(msg)) => format!("SEND: {}", msg),
Err(Error::FailedToRecvMsg(msg)) => format!("RECV: {}", msg),
Err(Error::Parse(msg)) => format!("PARSE_ERROR: {}", msg),
Err(Error::Internal(msg)) => format!("INTERNAL: {}", msg),
Err(Error::TcpConnection(msg)) => format!("TCP: {}", msg),
},
])
.inc()
}
/// One-shot sender used to hand a query's outcome (value or [`Error`]) back to
/// whoever is awaiting it.
type Responder<T> = tokio::sync::oneshot::Sender<Result<T, Error>>;
/// Request submitted to a `TcpNameserver` task: a query plus the channel on
/// which its reply (or failure) is delivered.
struct TcpNameserverMessage {
/// The query packet to write to the nameserver.
out_query: super::dnspkt::DNSPkt,
/// Where the matching reply, or the error that prevented it, is sent.
out_reply: Responder<super::dnspkt::DNSPkt>,
}
/// State owned by the task that multiplexes queries over a single TCP
/// connection to one nameserver.
struct TcpNameserver {
/// Address of the nameserver this task talks to.
addr: std::net::SocketAddr,
/// The open connection, or `None` while disconnected (`run` connects when the
/// next query arrives; `tcp_teardown` resets it to `None`).
tcp: Option<tokio::net::TcpStream>,
/// Set when the connection is opened and after every write attempt; `run`
/// tears the connection down 120 s after this instant.
tcp_last_send_activity: Instant,
/// Set when the connection is opened and after every completely read reply;
/// `run` tears the connection down 120 s after this instant.
tcp_last_recv_activity: Instant,
/// Responders for queries that have been sent but not yet answered, keyed by
/// query ID.
qid2reply: std::collections::HashMap<u16, Responder<super::dnspkt::DNSPkt>>,
}
impl TcpNameserver {
fn start(addr: std::net::SocketAddr) -> TcpNameserverChannel {
let (tx, rx) = tokio::sync::mpsc::channel(2);
let ret = Box::new(Self {
addr,
tcp: None,
tcp_last_send_activity: Instant::now(),
tcp_last_recv_activity: Instant::now(),
qid2reply: Default::default(),
});
tokio::task::spawn(ret.run(rx));
tx
}
async fn send_query_to(
addr: &std::net::SocketAddr,
out_query: super::dnspkt::DNSPkt,
) -> Result<super::dnspkt::DNSPkt, Error> {
let chan = NAMESERVER_INFO
.lock()
.await
.entry(*addr)
.or_insert_with(|| TcpNameserver::start(*addr))
.clone();
let (tx, rx) = tokio::sync::oneshot::channel();
let _timer = OUT_QUERY_LATENCY
.with_label_values(&[&addr.to_string(), "TCP"])
.start_timer();
chan.send(TcpNameserverMessage {
out_query,
out_reply: tx,
})
.await
.map_err(|err| Error::Internal(format!("Channel send failed: {}", err)))?;
match rx.await {
Ok(ret) => ret,
Err(err) => Err(Error::Internal(format!("Channel recv failed: {}", err))),
}
}
/// Delivers `reply` to whoever is waiting on query `qid` and forgets the
/// query.
///
/// A reply for an unknown `qid` is logged as an error. If the requester has
/// already dropped its receiving end (for example because its future was
/// cancelled) the reply is discarded; this must not panic, since a panic
/// would kill the task serving every other query to this nameserver.
async fn send_tcp_reply(&mut self, qid: u16, reply: Result<super::dnspkt::DNSPkt, Error>) {
    match self.qid2reply.remove(&qid) {
        Some(resp) => {
            if resp.send(reply).is_err() {
                log::debug!(
                    "Requester for query {} went away before its reply arrived",
                    qid
                );
            }
        }
        None => {
            log::error!("Sending reply to unknown request: {:?}", reply);
        }
    }
}
async fn send_tcp_query(&mut self, msg: TcpNameserverMessage) -> Result<(), Error> {
assert!(
self.qid2reply
.insert(msg.out_query.qid, msg.out_reply)
.is_none()
); if let Some(ref mut tcp_sock) = self.tcp {
use tokio::io::AsyncWriteExt as _;
let bytes = msg.out_query.serialise();
let mut buf: Vec<u8> = vec![];
buf.reserve_exact(2 + bytes.len());
buf.extend((bytes.len() as u16).to_be_bytes().iter());
buf.extend(bytes);
DNS_SENT_QUERIES
.with_label_values(&[&self.addr.to_string(), "TCP"])
.inc();
let ret = tcp_sock.write_all(&buf).await.map_err(Error::FailedToSend);
self.tcp_last_send_activity = Instant::now();
ret
} else {
panic!("Write on non-existant tcp socket");
}
}
async fn read_reply(&mut self) -> Result<Vec<u8>, Error> {
if let Some(ref mut tcp_sock) = self.tcp {
use tokio::io::AsyncReadExt as _;
let mut lbuf = [0u8; 2];
tcp_sock
.read_exact(&mut lbuf)
.await
.map_err(Error::FailedToRecv)?;
let l = u16::from_be_bytes(lbuf);
let mut msg_buf = vec![0u8; l as usize];
log::trace!("Reading {} bytes from TCP socket", l);
tcp_sock
.read_exact(&mut msg_buf[..])
.await
.map_err(Error::FailedToRecv)?;
self.tcp_last_recv_activity = Instant::now();
Ok(msg_buf)
} else {
panic!("Read from non existant tcp socket");
}
}
async fn handle_reply(&mut self, buf: &[u8]) {
let pkt = match parse::PktParser::new(buf).get_dns().map_err(Error::Parse) {
Ok(pkt) => pkt,
Err(err) => {
log::error!("{}", err);
return;
}
};
self.send_tcp_reply(pkt.qid, Ok(pkt)).await
}
fn tcp_teardown(&mut self, err: Error) {
self.tcp = None;
log::trace!("Tearing down {} TCP channel: {}", self.addr, err);
for (_qid, chan) in self.qid2reply.drain() {
chan.send(Err(Error::TcpConnection(format!(
"TCP channel closed before reply: {}",
err
))))
.unwrap();
}
}
async fn run(mut self, mut chan: tokio::sync::mpsc::Receiver<TcpNameserverMessage>) {
loop {
let last_send_activity = self.tcp_last_send_activity;
let last_recv_activity = self.tcp_last_recv_activity;
if self.tcp.is_some() {
use futures::FutureExt as _;
futures::select! {
msg = chan.recv().fuse() => if let Some(msg) = msg {
if let Err(e) = self.send_tcp_query(msg).await {
self.tcp_teardown(e);
}
} else {
return;
},
ret = self.read_reply().fuse() => match ret {
Ok(msg) => self.handle_reply(&msg[..]).await,
Err(e) => {
self.tcp_teardown(e);
},
},
() = tokio::time::sleep_until(last_send_activity + std::time::Duration::from_secs(120)).fuse() => {
self.tcp_teardown(Error::TcpConnection("TCP Connection idle".into()));
},
() = tokio::time::sleep_until(last_recv_activity + std::time::Duration::from_secs(120)).fuse() => {
self.tcp_teardown(Error::TcpConnection("Timed out waiting for TCP replies".into()));
},
}
} else if let Some(msg) = chan.recv().await {
log::trace!("Opening new TCP channel to {}", self.addr);
match tokio::net::TcpStream::connect(self.addr).await {
Ok(sock) => self.tcp = Some(sock),
Err(err) => {
msg.out_reply.send(Err(Error::FailedToSend(err))).unwrap();
continue;
}
}
self.tcp_last_recv_activity = Instant::now();
self.tcp_last_send_activity = Instant::now();
if let Err(e) = self.send_tcp_query(msg).await {
self.tcp_teardown(e);
}
} else {
return;
}
}
}
}
/// Builds the packet sent upstream on behalf of `in_query`.
///
/// The result carries the query ID `id`, a copy of the incoming question and
/// the incoming `edns_do` value; every other field is a fixed value chosen
/// here.
fn create_outquery(id: u16, in_query: &dnspkt::DNSPkt) -> dnspkt::DNSPkt {
    let question = in_query.question.clone();
    let edns_do = in_query.edns_do;

    dnspkt::DNSPkt {
        // Identification and codes.
        qid: id,
        opcode: dnspkt::OPCODE_QUERY,
        rcode: dnspkt::NOERROR,
        // Of the boolean flag fields only `rd` is set.
        qr: false,
        aa: false,
        tc: false,
        rd: true,
        ra: false,
        ad: false,
        cd: false,
        // EDNS-related fields.
        bufsize: 4096,
        edns_ver: Some(0),
        edns_do,
        edns: Some(dnspkt::EdnsData::new()),
        // Record sections: only the question is populated.
        question,
        answer: Vec::new(),
        nameserver: Vec::new(),
        additional: Vec::new(),
    }
}
/// Handle for sending queries to upstream nameservers. It has no fields; the
/// state it works with lives in this file's statics.
#[derive(Clone)]
pub struct OutQuery;
impl OutQuery {
pub fn new() -> Self {
OutQuery
}
async fn send_single_udp(
&self,
addr: std::net::SocketAddr,
oq: super::dnspkt::DNSPkt,
) -> Result<(Duration, dnspkt::DNSPkt), Error> {
let start = Instant::now();
let outsock = UdpSocket::bind(match addr {
std::net::SocketAddr::V4(_) => "0.0.0.0:0",
std::net::SocketAddr::V6(_) => "[::]:0",
})
.await
.map_err(Error::FailedToSend)?;
outsock.connect(addr).await.map_err(Error::FailedToSend)?;
log::trace!(
"Sending query {} → {} ({})",
outsock.local_addr().unwrap(),
addr,
oq.qid
);
DNS_SENT_QUERIES
.with_label_values(&[&addr.to_string(), "UDP"])
.inc();
outsock
.send(oq.serialise().as_slice())
.await
.map_err(Error::FailedToSend)?;
let mut buf = [0; 65536]; let l = outsock.recv(&mut buf).await.map_err(Error::FailedToRecv)?;
let pkt = parse::PktParser::new(&buf[0..l])
.get_dns()
.map_err(Error::Parse)?;
let duration = Instant::now() - start;
Ok((duration, pkt))
}
async fn send_udp(
&self,
addr: std::net::SocketAddr,
oq: &super::dnspkt::DNSPkt,
) -> Result<dnspkt::DNSPkt, Error> {
let mut attempts = futures::stream::FuturesUnordered::new();
log::trace!("OutQuery: {:?}", oq);
let initial_timeout: Duration = *DNS_TIMEOUT.read().await;
OUT_QUERY_TIMEOUT.set(initial_timeout.as_millis() as i64);
let mut timeout = initial_timeout;
let _timer = OUT_QUERY_LATENCY
.with_label_values(&[&addr.to_string(), "UDP"])
.start_timer();
loop {
use futures::FutureExt as _;
use futures::StreamExt as _;
attempts.push(self.send_single_udp(addr, oq.clone()));
futures::select! {
ret = attempts.next() =>
return match ret {
None => Err(Error::FailedToRecvMsg("No attempts made".into())),
Some(Err(e)) => Err(e),
Some(Ok((dur, pkt))) => {
if attempts.len() > 1 {
let mut timeout = DNS_TIMEOUT.write().await;
if dur < initial_timeout {
if dur <= *timeout {
const ALPHA : u32 = 10;
const BASE : u32 = 1000;
let new_timeout = (dur * ALPHA + *timeout * ( BASE - ALPHA)) / BASE;
*timeout = std::cmp::max(
std::cmp::min(new_timeout, MAX_DNS_TIMEOUT),
MIN_DNS_TIMEOUT);
}
} else {
const HEADROOM : u32 = 10;
const BASE : u32 = 100;
let new_timeout = dur * (BASE + HEADROOM/BASE);
*timeout = std::cmp::max(
std::cmp::min(std::cmp::max(*timeout, new_timeout), MAX_DNS_TIMEOUT),
MIN_DNS_TIMEOUT);
}
}
Ok(pkt)
}
},
() = tokio::time::sleep(timeout).fuse() => {
if attempts.len() > 3 {
return Err(Error::Timeout);
}
OUT_QUERY_RETRY
.with_label_values(&[&addr.to_string(), "TIMEOUT"])
.inc();
use rand::prelude::*;
let jitter = rand::rng().random_range(std::time::Duration::from_secs(0)..timeout);
timeout += (timeout / 2) + jitter;
},
}
}
}
async fn handle_query_internal(
&self,
msg: &super::DnsMessage,
addr: std::net::SocketAddr,
) -> Result<dnspkt::DNSPkt, Error> {
use rand::TryRng as _;
let id = rand::rngs::SysRng.try_next_u32().unwrap() as u16;
let oq = create_outquery(id, &msg.in_query);
let out_reply;
match msg.protocol {
Protocol::Udp => {
let reply = self.send_udp(addr, &oq).await?;
if reply.qid != id {
OUT_QUERY_RETRY
.with_label_values(&[&addr.to_string(), "KAMINSKY"])
.inc();
out_reply = TcpNameserver::send_query_to(&addr, oq).await?;
} else if reply.tc {
OUT_QUERY_RETRY
.with_label_values(&[&addr.to_string(), "TRUNCATED"])
.inc();
out_reply = TcpNameserver::send_query_to(&addr, oq).await?;
} else {
out_reply = reply;
}
}
Protocol::Tcp => {
out_reply = TcpNameserver::send_query_to(&addr, oq).await?;
}
}
if out_reply.qid != id {
log::warn!("Mismatched ID: {} != {}", out_reply.qid, id);
}
Ok(out_reply)
}
pub async fn handle_query(
&self,
msg: &super::DnsMessage,
addr: std::net::SocketAddr,
) -> Result<dnspkt::DNSPkt, super::Error> {
OUT_QUERY_OUTSTANDING
.with_label_values(&[&addr.to_string()])
.inc();
let ret = self.handle_query_internal(msg, addr).await;
OUT_QUERY_OUTSTANDING
.with_label_values(&[&addr.to_string()])
.dec();
increment_result(&addr.to_string(), &ret);
ret.map_err(super::Error::OutReply)
}
}