idoq 0.1.5

DNS over QUIC (DoQ) client / DNS over QUIC (DoQ) 客户端
Documentation
use std::{
  net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr},
  sync::{Arc, LazyLock},
  time::Duration,
};

use bytes::{BufMut, Bytes, BytesMut};
use idns::Answer;
use quinn::{ClientConfig, Connection, Endpoint, TransportConfig, VarInt};
use rustls::pki_types::ServerName;
use tokio::sync::RwLock;

use crate::{Error, QType, Result, parser};

/// Remote port (853) that `Doq` dials on every server.
pub const DOQ_PORT: u16 = 853;

/// Time limit (7 s) applied to network waits in this module.
const TIMEOUT: Duration = Duration::from_millis(7_000);

/// Wildcard local IPv4 address with an OS-chosen port, used to bind the client socket
/// when the server address is IPv4.
const BIND_V4: SocketAddr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), 0);
/// Wildcard local IPv6 address with an OS-chosen port, used to bind the client socket
/// when the server address is IPv6.
const BIND_V6: SocketAddr =
  SocketAddr::new(IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 0)), 0);

/// Process-wide QUIC client configuration, built lazily on first access.
///
/// TLS setup:
/// - uses the `ring` crypto provider with its safe default protocol versions;
/// - trusts the bundled `webpki_roots` set and sends no client certificate;
/// - advertises the `doq` ALPN token.
///
/// The transport's maximum idle timeout is set to 30 000 ms.
static CONF: LazyLock<ClientConfig> = LazyLock::new(|| {
  let provider = Arc::new(rustls::crypto::ring::default_provider());
  let trust = rustls::RootCertStore::from_iter(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());

  let mut rustls_cfg = rustls::ClientConfig::builder_with_provider(provider)
    .with_safe_default_protocol_versions()
    .expect("TLS version")
    .with_root_certificates(trust)
    .with_no_client_auth();
  rustls_cfg.alpn_protocols = vec![b"doq".to_vec()];

  let quic_crypto =
    quinn::crypto::rustls::QuicClientConfig::try_from(rustls_cfg).expect("QUIC config");

  let idle_ms = VarInt::from_u32(30_000);
  let mut transport = TransportConfig::default();
  transport.max_idle_timeout(Some(idle_ms.into()));

  let mut client = ClientConfig::new(Arc::new(quic_crypto));
  client.transport_config(Arc::new(transport));
  client
});

/// DoQ client that holds a reusable QUIC connection.
pub struct Doq {
  /// Server to query: `host` is used as the TLS server name and `ip` as the dial address.
  pub server: crate::HostIp,
  /// Cached connection, reused across queries for as long as it is still open.
  conn: RwLock<Option<Connection>>,
}

/// Debug output shows only the target server.
///
/// The cached QUIC connection is intentionally left out. `finish_non_exhaustive` renders
/// a trailing `..` so readers can tell that fields were omitted (clippy
/// `missing_fields_in_debug`).
impl std::fmt::Debug for Doq {
  fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
    f.debug_struct("Doq")
      .field("server", &self.server)
      .finish_non_exhaustive()
  }
}

impl Doq {
  pub fn new(server: crate::HostIp) -> Self {
    Self {
      server,
      conn: RwLock::new(None),
    }
  }

  /// 获取或创建连接
  async fn conn(&self) -> Result<Connection> {
    let alive = |c: &Connection| c.close_reason().is_none();

    if let Some(c) = self.conn.read().await.as_ref().filter(|c| alive(c)) {
      return Ok(c.clone());
    }

    let mut guard = self.conn.write().await;
    if let Some(c) = guard.as_ref().filter(|c| alive(c)) {
      return Ok(c.clone());
    }

    let c = self.dial().await?;
    *guard = Some(c.clone());
    Ok(c)
  }

  async fn dial(&self) -> Result<Connection> {
    let bind = if self.server.ip.is_ipv6() { BIND_V6 } else { BIND_V4 };
    let mut ep = Endpoint::client(bind)?;
    ep.set_default_client_config(CONF.clone());

    let name: ServerName<'_> = self
      .server
      .host
      .as_str()
      .try_into()
      .map_err(|_| Error::InvalidAddress(self.server.host.to_string()))?;

    let addr = SocketAddr::new(self.server.ip, DOQ_PORT);
    tokio::time::timeout(TIMEOUT, ep.connect(addr, name.to_str().as_ref())?)
      .await
      .map_err(|_| Error::Timeout)?
      .map_err(Error::Connection)
  }

  /// 执行 DNS 查询
  pub async fn query(&self, domain: &str, qtype: QType) -> Result<Option<Vec<Answer>>> {
    let c = self.conn().await?;
    match Self::send(&c, domain, qtype).await {
      Ok(r) => Ok(r),
      Err(e) => {
        if matches!(e, Error::Connection(_) | Error::Io(_)) {
          *self.conn.write().await = None;
        }
        Err(e)
      }
    }
  }

  async fn send(c: &Connection, domain: &str, qtype: QType) -> Result<Option<Vec<Answer>>> {
    let (mut tx, mut rx) = c.open_bi().await.map_err(Error::Connection)?;

    let msg = parser::build(domain, qtype as u16);
    let mut buf = BytesMut::with_capacity(2 + msg.len());
    buf.put_u16(msg.len() as u16);
    buf.put_slice(&msg);

    tx.write_all(&buf).await?;
    tx.finish()?;

    let resp = tokio::time::timeout(TIMEOUT, rx.read_to_end(4096))
      .await
      .map_err(|_| Error::Timeout)??;

    if resp.len() < 14 {
      return Ok(None);
    }

    let len = u16::from_be_bytes([resp[0], resp[1]]) as usize;
    if len != resp.len() - 2 {
      return Err(Error::LengthMismatch);
    }

    let answers = parser::parse(Bytes::from(resp).slice(2..))?;
    Ok(if answers.is_empty() { None } else { Some(answers) })
  }
}