use std::net::SocketAddr;
use crossfire::MAsyncRx;
use hickory_client::{
client::{Client, ClientHandle},
proto::{
rr::{DNSClass, Name, RecordType},
runtime::TokioRuntimeProvider,
udp::UdpClientStream,
xfer::DnsResponse,
},
};
use tokio::{sync::oneshot, time::sleep};
use tracing::debug;
use crate::{BlastDNSConfig, error::BlastDNSError};
/// A single DNS question: which host to look up and which record type to ask for.
#[derive(Debug)]
pub(crate) struct QuerySpec {
/// Host name to resolve; it is parsed into a DNS `Name` when the query is executed.
pub(crate) host: String,
/// Record type requested from the resolver.
pub(crate) record_type: RecordType,
}
/// A queued query together with the one-shot channel used to report its outcome.
pub(crate) struct WorkItem {
/// The DNS question to ask.
pub(crate) query: QuerySpec,
/// Sending half of the one-shot channel on which the response, or the error, is delivered.
pub(crate) responder: oneshot::Sender<Result<DnsResponse, BlastDNSError>>,
}
impl WorkItem {
pub(crate) fn new(
query: QuerySpec,
responder: oneshot::Sender<Result<DnsResponse, BlastDNSError>>,
) -> Self {
Self { query, responder }
}
pub(crate) fn respond(self, result: Result<DnsResponse, BlastDNSError>) {
let _ = self.responder.send(result);
}
}
/// Background worker that executes queued DNS queries against one upstream resolver.
pub(crate) struct ResolverWorker {
/// Address of the resolver this worker sends its queries to.
resolver: SocketAddr,
/// Settings read by the worker: purgatory threshold and sentence, and the request timeout.
config: BlastDNSConfig,
/// Receiving half of the work queue from which `WorkItem`s are pulled.
work_rx: MAsyncRx<WorkItem>,
/// DNS client, created lazily when the first work item arrives; `None` until then.
client: Option<Client>,
}
impl ResolverWorker {
    /// Starts a detached Tokio task that serves queries against `resolver`.
    ///
    /// The task pulls [`WorkItem`]s from `work_rx` until receiving fails, then logs that it is
    /// shutting down. If the worker stops with an error instead, the error is printed to
    /// stderr. `worker_idx` is only used to label those messages.
    ///
    /// # Panics
    ///
    /// `tokio::spawn` panics when called outside of a Tokio runtime.
    pub fn spawn(
        resolver: SocketAddr,
        work_rx: MAsyncRx<WorkItem>,
        config: BlastDNSConfig,
        worker_idx: usize,
    ) {
        tokio::spawn(async move {
            // `SocketAddr` is `Copy`, so `resolver` stays usable for logging below.
            let worker = Self {
                resolver,
                config,
                work_rx,
                client: None,
            };
            match worker.run().await {
                Ok(()) => debug!("resolver worker {resolver} (#{worker_idx}) shutting down"),
                Err(err) => {
                    eprintln!("resolver worker {resolver} (#{worker_idx}) exited: {err:?}")
                }
            }
        });
    }

    /// Main loop: receive a work item, run its query, send the outcome back.
    ///
    /// A counter tracks recent resolver failures. It goes up by one for every failed request
    /// to the resolver and down by one for every successful response. When
    /// `purgatory_threshold` is non-zero and the counter reaches it, the worker sleeps for
    /// `purgatory_sentence` (if non-zero) before taking more work.
    ///
    /// Errors that are not the resolver's fault (an unparsable host name) are reported to the
    /// requester but leave the counter untouched, so bad input cannot push a healthy resolver
    /// into purgatory.
    ///
    /// Returns `Ok(())` once the work queue can no longer be received from.
    ///
    /// # Errors
    ///
    /// Returns the error from [`Self::init_client`] if the client cannot be created. In that
    /// case the work item that triggered the initialization is dropped without a reply, so its
    /// requester sees a closed channel rather than an error value.
    async fn run(mut self) -> Result<(), BlastDNSError> {
        let mut consecutive_errors = 0usize;
        loop {
            if self.config.purgatory_threshold > 0
                && consecutive_errors >= self.config.purgatory_threshold
            {
                let sentence = self.config.purgatory_sentence;
                if !sentence.is_zero() {
                    debug!(
                        resolver = %self.resolver,
                        sentence = ?sentence,
                        consecutive_errors,
                        "entering purgatory"
                    );
                    sleep(sentence).await;
                }
                // Step back below the threshold so the worker resumes; one further resolver
                // failure will send it straight back into purgatory.
                consecutive_errors = consecutive_errors.saturating_sub(1);
            }

            // A receive error ends the loop and lets the worker shut down cleanly.
            let Ok(work_item) = self.work_rx.recv().await else {
                break;
            };

            // The client is created lazily, on the first piece of work.
            if self.client.is_none() {
                self.client = Some(self.init_client().await?);
            }

            let WorkItem { query, responder } = work_item;
            match self.handle_query(query).await {
                Ok(response) => {
                    consecutive_errors = consecutive_errors.saturating_sub(1);
                    // Best-effort delivery: the requester may have stopped waiting.
                    let _ = responder.send(Ok(response));
                }
                Err(err) => {
                    // Only failures of the request itself say anything about the resolver's
                    // health; an invalid host name never reached it.
                    if matches!(err, BlastDNSError::ResolverRequestFailed { .. }) {
                        consecutive_errors = consecutive_errors.saturating_add(1);
                    }
                    let _ = responder.send(Err(err));
                }
            }
        }
        Ok(())
    }

    /// Creates a UDP DNS client for `self.resolver`.
    ///
    /// The stream is built with `request_timeout` from the config. The background future
    /// returned by `Client::connect` is spawned as its own Tokio task; if that task ends with
    /// an error, the error is printed to stderr.
    ///
    /// # Errors
    ///
    /// Returns [`BlastDNSError::ResolverSetupFailed`] if `Client::connect` fails.
    async fn init_client(&self) -> Result<Client, BlastDNSError> {
        let resolver = self.resolver;
        let provider = TokioRuntimeProvider::new();
        let stream = UdpClientStream::builder(resolver, provider)
            .with_timeout(Some(self.config.request_timeout))
            .build();
        let (client, bg) = Client::connect(stream)
            .await
            .map_err(|source| BlastDNSError::ResolverSetupFailed { resolver, source })?;
        tokio::spawn(async move {
            if let Err(err) = bg.await {
                eprintln!("resolver {resolver} background task exited: {err}");
            }
        });
        Ok(client)
    }

    /// Sends one query to the resolver and returns its response.
    ///
    /// # Errors
    ///
    /// * [`BlastDNSError::InvalidHostname`] if `query.host` cannot be parsed as a DNS name;
    ///   no request is sent in that case.
    /// * [`BlastDNSError::ResolverRequestFailed`] if the request to the resolver fails.
    ///
    /// # Panics
    ///
    /// Panics if called before the client has been created. `run` always creates it first.
    async fn handle_query(&mut self, query: QuerySpec) -> Result<DnsResponse, BlastDNSError> {
        let QuerySpec { host, record_type } = query;
        debug!(
            resolver = %self.resolver,
            host,
            %record_type,
            "querying DNS resolver"
        );
        let name = Name::from_ascii(&host)
            .map_err(|source| BlastDNSError::InvalidHostname { name: host, source })?;
        let resolver = self.resolver;
        let client = self
            .client
            .as_mut()
            .expect("run() initializes the client before dispatching any query");
        client
            .query(name, DNSClass::IN, record_type)
            .await
            .map_err(|source| BlastDNSError::ResolverRequestFailed { resolver, source })
    }
}