use crate::core::MtopError;
use crate::dns::core::{RecordClass, RecordType};
use crate::dns::message::{Flags, Message, MessageId, Question};
use crate::dns::name::Name;
use crate::net::tcp_connect;
use crate::pool::{ClientFactory, ClientPool, ClientPoolConfig};
use crate::timeout::Timeout;
use std::fmt::{self, Formatter};
use std::io::{self, Cursor, Error};
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use std::pin::Pin;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::task::{Context, Poll};
use std::time::Duration;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, BufReader, BufWriter, ReadBuf};
use tokio::net::UdpSocket;
/// Nameserver used by `DnsClientConfig::default()`: 127.0.0.1, port 53.
const DEFAULT_NAMESERVER: SocketAddr = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 53);
/// Buffer size handed to the UDP and TCP clients: the UDP receive buffer size
/// and the initial capacity used when encoding requests.
const DEFAULT_MESSAGE_BUFFER: usize = 512;

/// Settings for a [`DnsClient`].
#[derive(Debug, Clone)]
pub struct DnsClientConfig {
    /// Nameservers to query. Without `rotate`, the server is picked by attempt
    /// number (wrapping around the list); with `rotate`, round-robin.
    pub nameservers: Vec<SocketAddr>,
    /// Timeout applied to each individual UDP or TCP exchange with a nameserver.
    pub timeout: Duration,
    /// Total number of attempts per query; a value of 0 behaves like 1.
    pub attempts: u8,
    /// When true, nameservers are chosen round-robin across all queries
    /// instead of by attempt number.
    pub rotate: bool,
    /// Passed as `max_idle` to both the UDP and the TCP client pools.
    pub pool_max_idle: u64,
}

impl Default for DnsClientConfig {
    /// Query 127.0.0.1:53 with a 5 second timeout, 2 attempts, no rotation and
    /// a `pool_max_idle` of 1.
    fn default() -> Self {
        let nameservers = Vec::from([DEFAULT_NAMESERVER]);
        Self {
            nameservers,
            timeout: Duration::from_millis(5_000),
            attempts: 2,
            rotate: false,
            pool_max_idle: 1,
        }
    }
}
#[derive(Debug)]
pub struct DnsClient {
config: DnsClientConfig,
server_idx: AtomicUsize,
udp_pool: ClientPool<SocketAddr, UdpClient, UdpFactory>,
tcp_pool: ClientPool<SocketAddr, TcpClient, TcpFactory>,
}
impl DnsClient {
pub fn new(config: DnsClientConfig) -> Self {
let udp_config = ClientPoolConfig {
name: "dns-udp".to_owned(),
max_idle: config.pool_max_idle,
};
let tcp_config = ClientPoolConfig {
name: "dns-tcp".to_owned(),
max_idle: config.pool_max_idle,
};
Self {
config,
server_idx: AtomicUsize::new(0),
udp_pool: ClientPool::new(udp_config, UdpFactory),
tcp_pool: ClientPool::new(tcp_config, TcpFactory),
}
}
pub async fn resolve(&self, name: Name, rtype: RecordType, rclass: RecordClass) -> Result<Message, MtopError> {
let full = name.to_fqdn();
let id = MessageId::random();
let flags = Flags::default().set_recursion_desired();
let question = Question::new(full, rtype).set_qclass(rclass);
let message = Message::new(id, flags).add_question(question);
let mut attempt = 0;
loop {
match self.exchange(&message, usize::from(attempt)).await {
Ok(v) => return Ok(v),
Err(e) => {
if attempt + 1 >= self.config.attempts {
return Err(e);
}
tracing::debug!(message = "retrying failed query", attempt = attempt + 1, max_attempts = self.config.attempts, err = %e);
attempt += 1;
}
}
}
}
async fn exchange(&self, msg: &Message, attempt: usize) -> Result<Message, MtopError> {
let server = self.nameserver(attempt);
let res = async {
let mut client = self.udp_pool.get(&server).await?;
let res = client.exchange(msg).await;
if res.is_ok() {
self.udp_pool.put(client).await;
}
res
}
.timeout(self.config.timeout, format!("client.exchange udp://{}", server))
.await?;
if res.flags().is_truncated() {
tracing::debug!(message = "UDP response truncated, retrying with TCP", flags = ?res.flags(), server = %server);
async {
let mut client = self.tcp_pool.get(&server).await?;
let res = client.exchange(msg).await;
if res.is_ok() {
self.tcp_pool.put(client).await;
}
res
}
.timeout(self.config.timeout, format!("client.exchange tcp://{}", server))
.await
} else {
Ok(res)
}
}
fn nameserver(&self, attempt: usize) -> SocketAddr {
let idx = if self.config.rotate {
self.server_idx.fetch_add(1, Ordering::Relaxed)
} else {
attempt
};
self.config.nameservers[idx % self.config.nameservers.len()]
}
}
/// DNS client for a single TCP connection.
///
/// Each message is framed with a two-byte big-endian length prefix (tokio's
/// `write_u16`/`read_u16`). Reads and writes go through `BufReader`/`BufWriter`
/// wrappers around the boxed connection halves.
struct TcpClient {
    read: BufReader<Box<dyn AsyncRead + Send + Sync + Unpin>>,
    write: BufWriter<Box<dyn AsyncWrite + Send + Sync + Unpin>>,
    /// Initial capacity of the buffer used to encode requests.
    size: usize,
}

impl TcpClient {
    /// Wrap the read and write halves of a connection, boxing them and adding buffering.
    fn new<R, W>(read: R, write: W, size: usize) -> Self
    where
        R: AsyncRead + Unpin + Sync + Send + 'static,
        W: AsyncWrite + Unpin + Sync + Send + 'static,
    {
        Self {
            read: BufReader::new(Box::new(read)),
            write: BufWriter::new(Box::new(write)),
            size,
        }
    }

    /// Send `msg` and read back a single length-prefixed response.
    ///
    /// # Errors
    ///
    /// Returns an error if:
    /// - the encoded message is too long for the 16-bit length prefix,
    /// - writing, flushing or reading fails (including EOF mid-message),
    /// - the response cannot be parsed, or
    /// - the response ID does not match the ID of `msg`.
    async fn exchange(&mut self, msg: &Message) -> Result<Message, MtopError> {
        let mut buf = Vec::with_capacity(self.size);
        msg.write_network_bytes(&mut buf)?;

        // The length prefix is only 16 bits wide. Truncating it silently would
        // send a wrong length and desynchronize the stream, so reject instead.
        if buf.len() > usize::from(u16::MAX) {
            return Err(MtopError::runtime(format!(
                "DNS message too large for TCP length prefix; max {}, got {}",
                u16::MAX,
                buf.len()
            )));
        }
        // Cannot truncate: the length was bounded by u16::MAX just above.
        self.write.write_u16(buf.len() as u16).await?;
        self.write.write_all(&buf).await?;
        self.write.flush().await?;

        let sz = self.read.read_u16().await?;
        // Reuse the request buffer to hold exactly `sz` response bytes.
        buf.clear();
        buf.resize(usize::from(sz), 0);
        self.read.read_exact(&mut buf).await?;

        let mut cur = Cursor::new(buf);
        let res = Message::read_network_bytes(&mut cur)?;
        if res.id() != msg.id() {
            Err(MtopError::runtime(format!(
                "unexpected DNS MessageId; expected {}, got {}",
                msg.id(),
                res.id()
            )))
        } else {
            Ok(res)
        }
    }
}

impl fmt::Debug for TcpClient {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        write!(f, "TcpClient {{ read: ..., write: ..., size: {} }}", self.size)
    }
}
/// DNS client for a single connected UDP socket.
///
/// Each request is written as one datagram. Responses are read into a buffer
/// of `size` bytes; with a real UDP socket, a longer datagram would be truncated
/// to that size by the receive call.
struct UdpClient {
    read: Box<dyn AsyncRead + Send + Sync + Unpin>,
    write: Box<dyn AsyncWrite + Send + Sync + Unpin>,
    /// Receive buffer size, also used as the initial request-encoding capacity.
    size: usize,
}

impl UdpClient {
    /// Box the read and write halves of a socket (no extra buffering is added).
    fn new<R, W>(read: R, write: W, size: usize) -> Self
    where
        R: AsyncRead + Unpin + Sync + Send + 'static,
        W: AsyncWrite + Unpin + Sync + Send + 'static,
    {
        let read: Box<dyn AsyncRead + Send + Sync + Unpin> = Box::new(read);
        let write: Box<dyn AsyncWrite + Send + Sync + Unpin> = Box::new(write);
        Self { read, write, size }
    }

    /// Send `msg` as a single datagram and wait for a response with the same ID.
    ///
    /// Responses whose ID differs from `msg` are ignored and the next datagram
    /// is read. A datagram that fails to parse ends the exchange with an error.
    async fn exchange(&mut self, msg: &Message) -> Result<Message, MtopError> {
        let mut bytes = Vec::with_capacity(self.size);
        msg.write_network_bytes(&mut bytes)?;

        let written = self.write.write(&bytes).await?;
        if written != bytes.len() {
            let expected = bytes.len();
            return Err(MtopError::runtime(format!(
                "short write to UDP socket. expected {}, got {}",
                expected, written
            )));
        }
        self.write.flush().await?;

        // Reuse the request allocation as a zero-filled receive buffer.
        bytes.clear();
        bytes.resize(self.size, 0);

        loop {
            let received = self.read.read(&mut bytes).await?;
            let response = Message::read_network_bytes(Cursor::new(&bytes[..received]))?;
            if response.id() == msg.id() {
                break Ok(response);
            }
        }
    }
}

impl fmt::Debug for UdpClient {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        let size = self.size;
        write!(f, "UdpClient {{ read: ..., write: ..., size: {size} }}")
    }
}
/// Exposes a tokio `UdpSocket` through `AsyncRead`/`AsyncWrite`, so it can be
/// split with `tokio::io::split` and used as the reader/writer of a UDP client.
///
/// Reads map to `poll_recv` and writes to `poll_send`, which operate on the
/// peer the socket is connected to.
pub(crate) struct SocketAdapter(UdpSocket);

impl SocketAdapter {
    /// Wrap `sock` without changing it.
    pub fn new(sock: UdpSocket) -> Self {
        SocketAdapter(sock)
    }
}

impl AsyncRead for SocketAdapter {
    fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll<io::Result<()>> {
        let Self(socket) = &*self;
        socket.poll_recv(cx, buf)
    }
}

impl AsyncWrite for SocketAdapter {
    fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize, Error>> {
        let Self(socket) = &*self;
        socket.poll_send(cx, buf)
    }

    /// No-op: this adapter buffers nothing, so there is nothing to flush.
    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
        Poll::Ready(Ok(()))
    }

    /// No-op: there is no stream teardown to perform for a UDP socket.
    fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
        Poll::Ready(Ok(()))
    }
}
/// Pool factory that creates a `UdpClient` connected to a given nameserver.
#[derive(Debug, Clone, Default)]
struct UdpFactory;

impl ClientFactory<SocketAddr, UdpClient> for UdpFactory {
    async fn make(&self, address: &SocketAddr) -> Result<UdpClient, MtopError> {
        // Bind the wildcard address of the nameserver's family; port 0 lets the
        // OS pick an ephemeral port.
        let local = match address {
            SocketAddr::V4(_) => "0.0.0.0:0",
            SocketAddr::V6(_) => "[::]:0",
        };
        let socket = UdpSocket::bind(local).await?;
        // Connecting fixes the peer used by send/recv on this socket.
        socket.connect(address).await?;
        let (reader, writer) = tokio::io::split(SocketAdapter::new(socket));
        Ok(UdpClient::new(reader, writer, DEFAULT_MESSAGE_BUFFER))
    }
}
/// Pool factory that creates a `TcpClient` over a new connection to a given nameserver.
#[derive(Debug, Clone, Default)]
struct TcpFactory;

impl ClientFactory<SocketAddr, TcpClient> for TcpFactory {
    async fn make(&self, address: &SocketAddr) -> Result<TcpClient, MtopError> {
        let (reader, writer) = tcp_connect(address).await?;
        let client = TcpClient::new(reader, writer, DEFAULT_MESSAGE_BUFFER);
        Ok(client)
    }
}
// Unit tests for TcpClient and UdpClient, driven by in-memory readers/writers
// instead of real sockets.
#[cfg(test)]
mod test {
use super::{TcpClient, UdpClient};
use crate::core::ErrorKind;
use crate::dns::core::{RecordClass, RecordType};
use crate::dns::message::{Flags, Message, MessageId, Question, Record};
use crate::dns::name::Name;
use crate::dns::rdata::{RecordData, RecordDataA};
use std::collections::VecDeque;
use std::io::Cursor;
use std::net::Ipv4Addr;
use std::pin::Pin;
use std::str::FromStr;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use std::task::{Context, Poll};
use tokio::io::{AsyncRead, ReadBuf};
/// Build the wire bytes of a canned DNS response with message ID `id`.
///
/// Layout of `body` (everything after the 2-byte ID):
/// - flags 0x8000 (QR bit set), QDCOUNT=1, ANCOUNT=1, NSCOUNT=0, ARCOUNT=0
/// - question: `7 "example" 3 "com" 0`, QTYPE=1 (A), QCLASS=1 (IN)
/// - answer: uncompressed name `example.com.`, TYPE=1 (A), CLASS=1 (IN),
///   TTL=60, RDLENGTH=4, RDATA=127.0.0.100
///
/// When `size_prefix` is true the bytes are preceded by a big-endian u16
/// length covering the ID and body, i.e. the framing TcpClient reads.
#[rustfmt::skip]
fn new_message_bytes(id: u16, size_prefix: bool) -> Vec<u8> {
let body = &[
128, 0, 0, 1, 0, 1, 0, 0, 0, 0, 7, 101, 120, 97, 109, 112, 108, 101, 3, 99, 111, 109, 0, 0, 1, 0, 1, 7, 101, 120, 97, 109, 112, 108, 101, 3, 99, 111, 109, 0, 0, 1, 0, 1, 0, 0, 0, 60, 0, 4, 127, 0, 0, 100, ];
let mut out = Vec::new();
if size_prefix {
// Length counts the 2-byte ID plus the body.
let size = 2 + body.len();
let size_bytes = (size as u16).to_be_bytes();
out.extend_from_slice(&size_bytes);
}
let id_bytes = id.to_be_bytes();
out.extend_from_slice(&id_bytes);
out.extend_from_slice(body);
out
}
// Empty stream: reading the 2-byte length prefix hits EOF -> IO error.
#[tokio::test]
async fn test_tcp_client_eof_reading_length() {
let write = Vec::new();
let read = Cursor::new(Vec::new());
let question = Question::new(Name::from_str("example.com.").unwrap(), RecordType::A);
let message =
Message::new(MessageId::from(123), Flags::default().set_recursion_desired()).add_question(question);
let mut client = TcpClient::new(read, write, 512);
let res = client.exchange(&message).await;
let err = res.unwrap_err();
assert_eq!(ErrorKind::IO, err.kind());
}
// The length prefix announces 200 bytes but nothing follows, so reading the
// message body hits EOF -> IO error.
#[tokio::test]
async fn test_tcp_client_eof_reading_message() {
let write = Vec::new();
let read = Cursor::new(vec![
0, 200, ]);
let question = Question::new(Name::from_str("example.com.").unwrap(), RecordType::A);
let message =
Message::new(MessageId::from(123), Flags::default().set_recursion_desired()).add_question(question);
let mut client = TcpClient::new(read, write, 512);
let res = client.exchange(&message).await;
let err = res.unwrap_err();
assert_eq!(ErrorKind::IO, err.kind());
}
// Response ID 111 does not match request ID 123. Unlike UDP, TCP reports this
// as a Runtime error instead of reading further.
#[tokio::test]
async fn test_tcp_client_id_mismatch() {
let write: Vec<u8> = Vec::new();
let read = Cursor::new(new_message_bytes(111 , true));
let question = Question::new(Name::from_str("example.com.").unwrap(), RecordType::A);
let message =
Message::new(MessageId::from(123), Flags::default().set_recursion_desired()).add_question(question);
let mut client = TcpClient::new(read, write, 512);
let res = client.exchange(&message).await;
let err = res.unwrap_err();
assert_eq!(ErrorKind::Runtime, err.kind());
}
// Matching ID: the canned response is parsed and its question and answer
// are checked.
#[tokio::test]
async fn test_tcp_client_success() {
let write = Vec::new();
let read = Cursor::new(new_message_bytes(123, true));
let question = Question::new(Name::from_str("example.com.").unwrap(), RecordType::A);
let message =
Message::new(MessageId::from(123), Flags::default().set_recursion_desired()).add_question(question);
let mut client = TcpClient::new(read, write, 512);
let res = client.exchange(&message).await.unwrap();
assert_eq!(message.id(), res.id());
assert_eq!(message.questions()[0], res.questions()[0]);
assert_eq!(
Record::new(
Name::from_str("example.com.").unwrap(),
RecordType::A,
RecordClass::INET,
60,
RecordData::A(RecordDataA::new(Ipv4Addr::new(127, 0, 0, 100))),
),
res.answers()[0]
);
}
/// Fake UDP reader: each `poll_read` yields the next queued value as one
/// "datagram" and increments `reads`. It panics if the queue is exhausted, or
/// if a value does not fit in the caller's buffer (`ReadBuf::put_slice`).
struct MockReadSocket {
values: VecDeque<Vec<u8>>,
reads: Arc<AtomicU64>,
}
impl MockReadSocket {
fn new(values: Vec<Vec<u8>>, reads: Arc<AtomicU64>) -> Self {
Self {
values: VecDeque::from(values),
reads,
}
}
}
impl AsyncRead for MockReadSocket {
fn poll_read(
mut self: Pin<&mut Self>,
_cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<std::io::Result<()>> {
self.reads.fetch_add(1, Ordering::AcqRel);
let b = self.values.pop_front().unwrap();
buf.put_slice(&b);
Poll::Ready(Ok(()))
}
}
// First datagram has a mismatched ID (111) and must be skipped. The second
// (ID 123) is returned, so exactly two reads happen.
#[tokio::test]
async fn test_udp_client_one_id_mismatch() {
let write = Vec::new();
let read_count = Arc::new(AtomicU64::new(0));
let read = MockReadSocket::new(
vec![
new_message_bytes(111 , false),
new_message_bytes(123, false),
],
read_count.clone(),
);
let question = Question::new(Name::from_str("example.com.").unwrap(), RecordType::A);
let message =
Message::new(MessageId::from(123), Flags::default().set_recursion_desired()).add_question(question);
let mut client = UdpClient::new(read, write, 512);
let res = client.exchange(&message).await.unwrap();
assert_eq!(message.id(), res.id());
assert_eq!(message.questions()[0], res.questions()[0]);
assert_eq!(
Record::new(
Name::from_str("example.com.").unwrap(),
RecordType::A,
RecordClass::INET,
60,
RecordData::A(RecordDataA::new(Ipv4Addr::new(127, 0, 0, 100))),
),
res.answers()[0]
);
assert_eq!(2, read_count.load(Ordering::Acquire));
}
// Matching ID on the first datagram: exactly one read is performed.
#[tokio::test]
async fn test_udp_client_success() {
let write = Vec::new();
let read_count = Arc::new(AtomicU64::new(0));
let read = MockReadSocket::new(vec![new_message_bytes(123, false)], read_count.clone());
let question = Question::new(Name::from_str("example.com.").unwrap(), RecordType::A);
let message =
Message::new(MessageId::from(123), Flags::default().set_recursion_desired()).add_question(question);
let mut client = UdpClient::new(read, write, 512);
let res = client.exchange(&message).await.unwrap();
assert_eq!(message.id(), res.id());
assert_eq!(message.questions()[0], res.questions()[0]);
assert_eq!(
Record::new(
Name::from_str("example.com.").unwrap(),
RecordType::A,
RecordClass::INET,
60,
RecordData::A(RecordDataA::new(Ipv4Addr::new(127, 0, 0, 100))),
),
res.answers()[0]
);
assert_eq!(1, read_count.load(Ordering::Acquire));
}
}