use std::{
net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr},
sync::{Arc, LazyLock},
time::Duration,
};
use bytes::{BufMut, Bytes, BytesMut};
use idns::Answer;
use quinn::{ClientConfig, Connection, Endpoint, TransportConfig, VarInt};
use rustls::pki_types::ServerName;
use tokio::sync::RwLock;
use crate::{Error, QType, Result, parser};
/// Server port dialed for DoQ connections (RFC 9250 assigns UDP 853).
pub const DOQ_PORT: u16 = 853;
/// Upper bound applied to the QUIC handshake and to waiting for a query response.
const TIMEOUT: Duration = Duration::from_millis(7_000);
/// Wildcard IPv4 local address with an OS-chosen port; used when the server is IPv4.
const BIND_V4: SocketAddr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), 0);
/// Wildcard IPv6 local address with an OS-chosen port; used when the server is IPv6.
const BIND_V6: SocketAddr = SocketAddr::new(IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 0)), 0);
/// Shared QUIC client configuration used for every DoQ connection.
///
/// TLS uses the `ring` crypto provider with rustls' safe default protocol versions.
/// It trusts the bundled webpki root set and advertises the `doq` ALPN token. The
/// transport idle timeout is set from `VarInt` 30 000, which quinn interprets as
/// milliseconds.
static CONF: LazyLock<ClientConfig> = LazyLock::new(|| {
    let provider = Arc::new(rustls::crypto::ring::default_provider());
    let trust = rustls::RootCertStore {
        roots: webpki_roots::TLS_SERVER_ROOTS.to_vec(),
    };
    let mut tls = rustls::ClientConfig::builder_with_provider(provider)
        .with_safe_default_protocol_versions()
        .expect("TLS version")
        .with_root_certificates(trust)
        .with_no_client_auth();
    tls.alpn_protocols = vec![b"doq".to_vec()];

    let quic = quinn::crypto::rustls::QuicClientConfig::try_from(tls).expect("QUIC config");

    let mut transport = TransportConfig::default();
    transport.max_idle_timeout(Some(VarInt::from_u32(30_000).into()));

    let mut config = ClientConfig::new(Arc::new(quic));
    config.transport_config(Arc::new(transport));
    config
});
pub struct Doq {
pub server: crate::HostIp,
conn: RwLock<Option<Connection>>,
}
/// Formats as `Doq { server: .. }`; the cached connection handle is not printed.
impl std::fmt::Debug for Doq {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("Doq");
        out.field("server", &self.server);
        out.finish()
    }
}
impl Doq {
pub fn new(server: crate::HostIp) -> Self {
Self {
server,
conn: RwLock::new(None),
}
}
async fn conn(&self) -> Result<Connection> {
let alive = |c: &Connection| c.close_reason().is_none();
if let Some(c) = self.conn.read().await.as_ref().filter(|c| alive(c)) {
return Ok(c.clone());
}
let mut guard = self.conn.write().await;
if let Some(c) = guard.as_ref().filter(|c| alive(c)) {
return Ok(c.clone());
}
let c = self.dial().await?;
*guard = Some(c.clone());
Ok(c)
}
async fn dial(&self) -> Result<Connection> {
let bind = if self.server.ip.is_ipv6() { BIND_V6 } else { BIND_V4 };
let mut ep = Endpoint::client(bind)?;
ep.set_default_client_config(CONF.clone());
let name: ServerName<'_> = self
.server
.host
.as_str()
.try_into()
.map_err(|_| Error::InvalidAddress(self.server.host.to_string()))?;
let addr = SocketAddr::new(self.server.ip, DOQ_PORT);
tokio::time::timeout(TIMEOUT, ep.connect(addr, name.to_str().as_ref())?)
.await
.map_err(|_| Error::Timeout)?
.map_err(Error::Connection)
}
pub async fn query(&self, domain: &str, qtype: QType) -> Result<Option<Vec<Answer>>> {
let c = self.conn().await?;
match Self::send(&c, domain, qtype).await {
Ok(r) => Ok(r),
Err(e) => {
if matches!(e, Error::Connection(_) | Error::Io(_)) {
*self.conn.write().await = None;
}
Err(e)
}
}
}
async fn send(c: &Connection, domain: &str, qtype: QType) -> Result<Option<Vec<Answer>>> {
let (mut tx, mut rx) = c.open_bi().await.map_err(Error::Connection)?;
let msg = parser::build(domain, qtype as u16);
let mut buf = BytesMut::with_capacity(2 + msg.len());
buf.put_u16(msg.len() as u16);
buf.put_slice(&msg);
tx.write_all(&buf).await?;
tx.finish()?;
let resp = tokio::time::timeout(TIMEOUT, rx.read_to_end(4096))
.await
.map_err(|_| Error::Timeout)??;
if resp.len() < 14 {
return Ok(None);
}
let len = u16::from_be_bytes([resp[0], resp[1]]) as usize;
if len != resp.len() - 2 {
return Err(Error::LengthMismatch);
}
let answers = parser::parse(Bytes::from(resp).slice(2..))?;
Ok(if answers.is_empty() { None } else { Some(answers) })
}
}