use super::framing::{frame, take_frame};
use crate::Client;
use async_channel::{Receiver, Sender, bounded, unbounded};
use async_lock::Mutex;
use futures_lite::{AsyncReadExt, AsyncWriteExt, future};
use std::{
borrow::Cow,
collections::HashMap,
io::{self, ErrorKind},
sync::{
Arc,
atomic::{AtomicU16, Ordering},
},
};
use trillium_server_common::{Connector, Destination, Transport, url::Url};
/// DNS-over-TLS (DoT) upstream that multiplexes queries over one shared connection.
///
/// Clones share the same message-ID counter and the same cached connection slot.
#[derive(Clone, Debug)]
pub(super) struct Dot {
    /// Resolver URL; `Dot::new` fills in port 853 when none is present.
    resolver: Url,
    /// Counter used to stamp a fresh DNS message ID onto each outgoing query
    /// (`fetch_add` on an `AtomicU16` wraps at `u16::MAX`).
    next_id: Arc<AtomicU16>,
    /// Sender feeding the current connection's driver task, or `None` when no
    /// connection is cached (before the first query, or after invalidation).
    conn: Arc<Mutex<Option<Sender<Outbound>>>>,
}
/// One query handed from `Dot::send` to the connection driver.
#[derive(Debug)]
struct Outbound {
    /// DNS message ID already written into the first two bytes of the query;
    /// the driver routes the response carrying this ID back through `resp`.
    id: u16,
    /// The query as produced by `frame`, ready to be written to the transport.
    framed: Vec<u8>,
    /// Channel on which the driver delivers the matching response message.
    resp: Sender<Vec<u8>>,
}
impl Dot {
    /// Creates a DoT upstream for `resolver`, defaulting the port to 853.
    ///
    /// No connection is opened here; the first `exchange` connects lazily.
    ///
    /// NOTE(review): `Url::port` also reports `None` when the port equals the
    /// scheme's known default (e.g. an explicit `:443` on `https`), so such a
    /// port is replaced by 853 too. The `set_port` result is ignored on
    /// purpose: a URL that cannot carry a port is used unchanged.
    pub(super) fn new(mut resolver: Url) -> Self {
        if resolver.port().is_none() {
            let _ = resolver.set_port(Some(853));
        }
        Self {
            resolver,
            next_id: Arc::new(AtomicU16::new(0)),
            conn: Arc::new(Mutex::new(None)),
        }
    }

    /// Host component of the resolver URL, if it has one.
    pub(super) fn host(&self) -> Option<&str> {
        self.resolver.host_str()
    }

    /// The resolver URL, including any port filled in by `new`.
    pub(super) fn resolver(&self) -> &Url {
        &self.resolver
    }

    /// Sends one DNS `query` to the resolver and returns the raw response message.
    ///
    /// The first two bytes of `query` (the DNS message ID) are overwritten with
    /// a fresh ID from the shared counter so the multiplexed connection can
    /// route the reply; the returned message carries that ID, not the caller's.
    ///
    /// If the cached connection proves dead (the driver has stopped, or the
    /// connection closed before the response arrived; `send` reports both as
    /// `BrokenPipe`), it is discarded and the query is retried once on a new
    /// connection. Any other error is returned without touching the connection.
    ///
    /// # Errors
    ///
    /// - `InvalidInput` if `query` is shorter than two bytes (no connection is used).
    /// - Any error from opening a connection or from framing the query.
    /// - `BrokenPipe` if the retry on a fresh connection also loses its connection.
    pub(super) async fn exchange(&self, client: &Client, query: Vec<u8>) -> io::Result<Vec<u8>> {
        log::trace!(
            "DoT exchange to {}: {}-byte query",
            self.resolver,
            query.len()
        );
        // Reject malformed queries before acquiring (or opening) a connection.
        if query.len() < 2 {
            return Err(Self::query_too_short());
        }
        let tx = self.connection(client).await?;
        match self.send(&tx, query.clone()).await {
            Ok(response) => {
                log::trace!(
                    "DoT exchange to {}: {}-byte response",
                    self.resolver,
                    response.len()
                );
                Ok(response)
            }
            // Only connection-level failures justify a reconnect. Errors about the
            // query itself would recur on any connection, and invalidating would
            // tear down a healthy connection shared by other callers.
            // (Assumes `frame` never reports `BrokenPipe` — TODO confirm.)
            Err(e) if e.kind() == ErrorKind::BrokenPipe => {
                log::debug!("DoT connection to {} unusable; reconnecting", self.resolver);
                self.invalidate(&tx).await;
                let tx = self.connection(client).await?;
                self.send(&tx, query).await
            }
            Err(e) => Err(e),
        }
    }

    /// Error for queries too short to hold a DNS message ID.
    fn query_too_short() -> io::Error {
        io::Error::new(ErrorKind::InvalidInput, "DNS query too short")
    }

    /// Stamps a fresh message ID onto `query`, hands it to the driver behind
    /// `tx`, and waits for the matching response.
    ///
    /// Both connection failures (driver gone, response channel closed) are
    /// reported as `BrokenPipe`; `exchange` relies on that to decide whether
    /// to reconnect.
    async fn send(&self, tx: &Sender<Outbound>, mut query: Vec<u8>) -> io::Result<Vec<u8>> {
        let Some(id_bytes) = query.get_mut(..2) else {
            return Err(Self::query_too_short());
        };
        // Drawn only after validation so rejected queries don't consume an ID.
        let id = self.next_id.fetch_add(1, Ordering::Relaxed);
        id_bytes.copy_from_slice(&id.to_be_bytes());
        log::trace!("DoT query id {id} queued for {}", self.resolver);
        // Capacity 1: the driver delivers at most one response per ID.
        let (resp, resp_rx) = bounded(1);
        let framed = frame(&query)?;
        tx.send(Outbound { id, framed, resp })
            .await
            .map_err(|_| io::Error::new(ErrorKind::BrokenPipe, "DoT driver stopped"))?;
        resp_rx
            .recv()
            .await
            .map_err(|_| io::Error::new(ErrorKind::BrokenPipe, "DoT connection closed"))
    }

    /// Returns the sender for the cached connection, connecting first if none
    /// is cached.
    ///
    /// The lock is held across the connect so concurrent callers share a single
    /// connection attempt instead of each opening their own.
    async fn connection(&self, client: &Client) -> io::Result<Sender<Outbound>> {
        let mut guard = self.conn.lock().await;
        if let Some(tx) = guard.as_ref() {
            return Ok(tx.clone());
        }
        let tx = self.connect(client).await?;
        *guard = Some(tx.clone());
        Ok(tx)
    }

    /// Clears the cached connection, but only if it is still `stale`; if another
    /// caller already replaced it, the newer connection is kept.
    async fn invalidate(&self, stale: &Sender<Outbound>) {
        let mut guard = self.conn.lock().await;
        if guard.as_ref().is_some_and(|tx| tx.same_channel(stale)) {
            *guard = None;
        }
    }

    /// Opens a new transport to the resolver (requesting ALPN `dot`), spawns
    /// `drive` on the client's runtime to own it, and returns the unbounded
    /// channel used to submit queries to that driver.
    async fn connect(&self, client: &Client) -> io::Result<Sender<Outbound>> {
        log::debug!("DoT connecting to {} (alpn dot)", self.resolver);
        let destination =
            Destination::from_url(&self.resolver)?.with_alpn([Cow::Borrowed(&b"dot"[..])]);
        let transport = client.connector().connect_to(destination).await?;
        log::debug!(
            "DoT connected to {}; negotiated alpn {:?}",
            self.resolver,
            transport
                .negotiated_alpn()
                .map(|a| String::from_utf8_lossy(&a).into_owned())
        );
        let (tx, rx) = unbounded();
        client.connector().runtime().spawn(drive(transport, rx));
        Ok(tx)
    }
}
/// Owns one DoT transport and multiplexes queries over it.
///
/// Each query arriving on `requests` is written to the transport, and its
/// response channel is remembered under the query's ID. Bytes read from the
/// transport are buffered, split into frames with `take_frame`, and routed by
/// ID (see `route_frames`).
///
/// The loop stops when the request channel closes, when the resolver closes the
/// connection, or on any read/write error. Returning drops every remaining
/// response sender, so callers still waiting see a closed channel.
async fn drive(mut transport: Box<dyn Transport>, requests: Receiver<Outbound>) {
    log::trace!("DoT driver started");
    let mut pending: HashMap<u16, Sender<Vec<u8>>> = HashMap::new();
    let mut buffered = Vec::new();
    let mut scratch = [0u8; 2048];

    let reason = loop {
        // Race the next queued query against the next chunk from the resolver;
        // `future::or` polls the request side first.
        let next = future::or(
            async { Event::Outbound(requests.recv().await.ok()) },
            async { Event::Read(transport.read(&mut scratch).await) },
        )
        .await;

        match next {
            Event::Outbound(Some(query)) => {
                log::trace!("DoT driver writing query id {}", query.id);
                pending.insert(query.id, query.resp);
                if let Err(err) = transport.write_all(&query.framed).await {
                    log::debug!("DoT driver write failed: {err}");
                    break "write error";
                }
            }
            Event::Outbound(None) => break "all handles dropped",
            // Must precede the general `Ok(n)` arm: a zero-length read is EOF.
            Event::Read(Ok(0)) => break "connection closed by resolver",
            Event::Read(Ok(n)) => {
                log::trace!("DoT driver read {n} bytes");
                buffered.extend_from_slice(&scratch[..n]);
                route_frames(&mut buffered, &mut pending);
            }
            Event::Read(Err(err)) => {
                log::debug!("DoT driver read failed: {err}");
                break "read error";
            }
        }
    };

    log::trace!(
        "DoT driver ending ({reason}); {} queries still in flight",
        pending.len()
    );
}

/// Drains every complete frame from `buffered` and delivers each one to the
/// pending query whose ID matches the frame's first two bytes.
///
/// Frames shorter than two bytes, or whose ID has no pending query, are dropped.
fn route_frames(buffered: &mut Vec<u8>, pending: &mut HashMap<u16, Sender<Vec<u8>>>) {
    while let Some(response) = take_frame(buffered) {
        let id = match response[..] {
            [hi, lo, ..] => u16::from_be_bytes([hi, lo]),
            // Too short to carry an ID; nothing to route it to.
            _ => continue,
        };
        match pending.remove(&id) {
            Some(waiter) => {
                log::trace!(
                    "DoT driver routing {}-byte response to id {id}",
                    response.len()
                );
                // Best effort: the caller may already have stopped waiting.
                let _ = waiter.try_send(response);
            }
            None => log::trace!("DoT driver got response for unknown id {id}; dropping"),
        }
    }
}
/// Outcome of one race in the `drive` loop.
enum Event {
    /// Result of reading from the transport into the driver's scratch buffer;
    /// `Ok(0)` means the resolver closed the connection.
    Read(io::Result<usize>),
    /// The next queued query, or `None` once the request channel is closed.
    Outbound(Option<Outbound>),
}
#[cfg(test)]
mod tests {
    use super::*;
    use trillium_testing::{TestTransport, harness, test};

    /// Builds a fake DNS message: the big-endian `id` followed by four `tag` bytes.
    fn message(id: u16, tag: u8) -> Vec<u8> {
        [&id.to_be_bytes()[..], &[tag; 4][..]].concat()
    }

    /// Queues a framed query for `id` on `tx` and returns the receiver on
    /// which the driver will deliver its response.
    async fn enqueue(tx: &Sender<Outbound>, id: u16) -> Receiver<Vec<u8>> {
        let (resp, resp_rx) = bounded(1);
        let framed = frame(&message(id, 0)).unwrap();
        tx.send(Outbound { id, framed, resp }).await.unwrap();
        resp_rx
    }

    /// Queries go out as 1 then 2; the fake resolver answers 2 then 1, and the
    /// driver must still hand each caller its own response.
    #[test(harness)]
    async fn driver_demultiplexes_out_of_order() {
        let (client_side, mut server_side) = TestTransport::new();
        let (tx, rx) = unbounded::<Outbound>();

        let resolver = async move {
            // Consume both length-prefixed queries before answering any.
            for _ in 0..2 {
                let mut len = [0u8; 2];
                server_side.read_exact(&mut len).await.unwrap();
                let mut query = vec![0u8; usize::from(u16::from_be_bytes(len))];
                server_side.read_exact(&mut query).await.unwrap();
            }
            // Reply in reverse order of submission.
            for (id, tag) in [(2, 0x22), (1, 0x11)] {
                server_side.write_all(&frame(&message(id, tag)).unwrap());
            }
        };

        let exchange = async move {
            let rx1 = enqueue(&tx, 1).await;
            let rx2 = enqueue(&tx, 2).await;
            assert_eq!(rx1.recv().await.unwrap(), message(1, 0x11));
            assert_eq!(rx2.recv().await.unwrap(), message(2, 0x22));
        };

        future::zip(future::zip(resolver, exchange), drive(Box::new(client_side), rx)).await;
    }
}