use mio::net::UdpSocket;
use std::{io, net::{SocketAddr, Ipv4Addr}, fmt, time::{self, Duration, Instant}};
use crate::util::{make_socket_addr, register_all, wouldblock, reregister_all};
// Local bind address: 0.0.0.0 with port 0, i.e. any IPv4 interface and an
// OS-assigned port (assuming `make_socket_addr` simply pairs ip and port).
const ME: SocketAddr = make_socket_addr(Ipv4Addr::UNSPECIFIED, 0);
// Upstream resolver every query is sent to: 8.8.8.8 on the standard DNS port 53.
const DNS: SocketAddr = make_socket_addr(Ipv4Addr::new(8, 8, 8, 8), 53);
/// Non-blocking resolver that multiplexes `A`-record queries over one UDP
/// socket connected to `DNS`, driven by a `mio` poll.
pub(crate) struct DnsClient {
    /// Token under which the socket is registered with the poll.
    pub(crate) token: mio::Token,
    /// Socket used for all queries; `None` until the first `resolve` call and
    /// again after the last outstanding request has been answered.
    socket: Option<UdpSocket>,
    /// Set by `pump` when it handled a writable event without sending anything;
    /// while set, `resolve` reregisters the socket before queuing a query.
    write_outdated: bool,
    /// Queries that have not finished yet, whether already sent or not.
    requests: Vec<InternalRequest>,
    /// Transaction id for the next query (wraps around on overflow).
    next_id: u16,
}
impl DnsClient {
pub(crate) fn new(token: mio::Token) -> Self {
Self {
socket: None,
write_outdated: false,
token,
requests: Vec::new(),
next_id: 0
}
}
pub(crate) fn resolve(&mut self, io: &mio::Poll, host: &str, timeout: Option<Duration>) -> io::Result<DnsId> {
if self.socket.is_none() {
let mut socket = UdpSocket::bind(ME)?;
socket.connect(DNS)?;
register_all(io, &mut socket, self.token)?;
self.socket = Some(socket);
}
if self.write_outdated {
let socket = self.socket.as_mut().expect("no socket");
reregister_all(io, socket, self.token)?;
self.write_outdated = false;
}
let id = self.next_id;
self.next_id = self.next_id.wrapping_add(1);
self.requests.push(InternalRequest {
id,
state: InternalRequestState::Pending,
packet: new_dns_packet(id, host),
time_created: Instant::now(),
timeout,
});
Ok(DnsId { inner: id })
}
pub(crate) fn pump(&mut self, io: &mio::Poll, events: &mio::Events) -> io::Result<Vec<DnsResponse>> {
let mut responses = Vec::new();
let mut index: isize = 0;
while let Some(request) = self.requests.get_mut(index as usize) {
if request.timeout.unwrap_or(Duration::MAX) <= request.time_created.elapsed() {
let id = request.id;
self.requests.swap_remove(index as usize);
index -= 1;
responses.push(DnsResponse {
id: DnsId { inner: id },
outcome: DnsOutcome::TimedOut
})
}
index += 1;
}
for event in events {
if self.token == event.token() {
if let Some(ref mut socket) = self.socket {
if event.is_writable() {
self.write_outdated = true;
for req in self.requests.iter_mut() {
if req.state == InternalRequestState::Pending {
socket.send(&req.packet)?;
req.state = InternalRequestState::Sent;
self.write_outdated = false;
}
}
}
if event.is_readable() {
loop {
let mut buff = [0; 1024];
match socket.recv(&mut buff) {
Err(err) if wouldblock(&err) => break,
Err(other) => return Err(other),
Ok(..) => (),
};
let resp = DnsResponse::parse_from_packet(&buff);
let maybe_idx = self.requests.iter().position(|req| req.id == resp.id.inner);
if let Some(idx) = maybe_idx {
responses.push(resp);
self.requests.swap_remove(idx);
if self.requests.is_empty() {
io.registry().deregister(socket)?;
self.socket = None;
break
}
}
}
}
}
}
}
Ok(responses)
}
}
/// Handle identifying one query issued by `DnsClient::resolve`; the matching
/// `DnsResponse` carries the same id.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub(crate) struct DnsId {
    /// The DNS transaction id used on the wire.
    pub(crate) inner: u16,
}
/// Bookkeeping for one query queued through `DnsClient::resolve`.
struct InternalRequest {
    /// DNS transaction id; replies are matched against it.
    id: u16,
    /// Whether the query datagram has been sent yet.
    state: InternalRequestState,
    /// Encoded query datagram, ready to send.
    packet: Vec<u8>,
    /// Creation time; the timeout is measured from this instant.
    time_created: Instant,
    /// How long to wait before reporting `TimedOut`; `None` waits forever.
    timeout: Option<Duration>,
}
/// Transmission state of an `InternalRequest`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum InternalRequestState {
    /// Queued but not yet written to the socket.
    Pending,
    /// Written to the socket; awaiting a reply.
    Sent,
}
/// Encodes a DNS query for the `A` record of `hostname` (class `IN`) with
/// transaction id `id`.
///
/// # Panics
/// Panics if `dns_parser` reports the query as truncated. That would require a
/// hostname far beyond the legal 253-character limit, so it is treated as a bug.
fn new_dns_packet(id: u16, hostname: &str) -> Vec<u8> {
    // `true` is `Builder::new_query`'s recursion flag: ask the server to resolve fully.
    let mut builder = dns_parser::Builder::new_query(id, true);
    // `false` is presumably the mDNS prefer-unicast bit, which is irrelevant for plain DNS.
    builder.add_question(
        hostname,
        false,
        dns_parser::QueryType::A,
        dns_parser::QueryClass::IN,
    );
    // `build` returns `Err(packet)` when the packet was truncated; sending a
    // truncated query is pointless, so treat it as an invariant violation.
    builder
        .build()
        .expect("single-question DNS query should never be truncated")
}
/// How a DNS query finished.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum DnsOutcome {
    /// Resolved: address and TTL of the first `A` record in the answer section.
    Known { addr: Ipv4Addr, ttl: time::Duration },
    /// The server answered `NXDOMAIN` (the name does not exist).
    Unknown,
    /// Any other failure reported by the server, or a successful reply without an `A` record.
    ProtocolError,
    /// No reply arrived before the request's timeout elapsed.
    TimedOut,
}
/// Final result of one query, as yielded by `DnsClient::pump`.
pub(crate) struct DnsResponse {
    /// Id returned by the `resolve` call that issued the query.
    pub(crate) id: DnsId,
    /// What happened to that query.
    pub(crate) outcome: DnsOutcome,
}
impl DnsResponse {
    /// Decodes a raw DNS reply datagram into a `DnsResponse`.
    ///
    /// The outcome is derived from the response code:
    /// - `NoError` yields `Known` if an `A` record is present, otherwise `ProtocolError`;
    /// - `NameError` yields `Unknown`;
    /// - any other code yields `ProtocolError`.
    ///
    /// A datagram that `dns_parser` cannot decode does not panic. It is reported
    /// as `ProtocolError` under the transaction id read from its first two header
    /// bytes, so the matching request still completes.
    ///
    /// # Panics
    /// Panics if `buff` is shorter than two bytes, since not even the
    /// transaction id can be recovered then.
    fn parse_from_packet(buff: &[u8]) -> Self {
        let packet = match dns_parser::Packet::parse(buff) {
            Ok(packet) => packet,
            Err(_) => {
                // The DNS header starts with the 16-bit id in network byte order.
                let id = match buff {
                    [hi, lo, ..] => u16::from_be_bytes([*hi, *lo]),
                    _ => panic!("DNS datagram too short to carry a transaction id"),
                };
                return Self {
                    id: DnsId { inner: id },
                    outcome: DnsOutcome::ProtocolError,
                };
            }
        };
        let outcome = match packet.header.response_code {
            dns_parser::ResponseCode::NoError => match parse_answer(&packet) {
                Some((addr, ttl)) => DnsOutcome::Known { addr, ttl },
                None => DnsOutcome::ProtocolError,
            },
            dns_parser::ResponseCode::NameError => DnsOutcome::Unknown,
            _ => DnsOutcome::ProtocolError,
        };
        Self {
            id: DnsId { inner: packet.header.id },
            outcome,
        }
    }
}
/// Returns the address and TTL of the first `A` record in the answer section,
/// or `None` if the answer section contains no `A` record.
fn parse_answer(packet: &dns_parser::Packet) -> Option<(Ipv4Addr, time::Duration)> {
    packet.answers.iter().find_map(|record| match &record.data {
        dns_parser::RData::A(a_record) => {
            let ttl = time::Duration::from_secs(record.ttl as u64);
            Some((a_record.0, ttl))
        }
        _ => None,
    })
}
impl fmt::Debug for DnsResponse {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self.outcome {
DnsOutcome::Known { addr, ttl } => write!(f, "{:?}, ttl: {:?}", addr, ttl),
DnsOutcome::Unknown => write!(f, "Unknown"),
DnsOutcome::ProtocolError => write!(f, "Dns Protocol Error"),
DnsOutcome::TimedOut => write!(f, "TimedOut"),
}
}
}