use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::Mutex;
/// Lower bound applied to the lifetime of a cached answer; also used as the
/// fallback TTL (in seconds) when a record carries no usable `TTL` field.
const TTL_MIN: Duration = Duration::from_secs(10);
/// Upper bound applied to the lifetime of a cached answer.
const TTL_MAX: Duration = Duration::from_secs(300);
/// Endpoint used by `DnsResolver::query_google`.
const GOOGLE_DOH_URL: &str = "https://dns.google.com/resolve";
/// Endpoint used by `DnsResolver::query_mozilla`, which requests it with
/// `accept: application/dns-json`.
const MOZILLA_DOH_URL: &str = "https://mozilla.cloudflare-dns.com/dns-query";
/// Shortest padding string `generate_padding` can produce.
const PAD_MIN: usize = 13;
/// Longest padding string `generate_padding` can produce.
const PAD_MAX: usize = 128;
/// One cached lookup result: the resolved values plus the moment they expire.
#[derive(Clone, Debug)]
struct CacheEntry {
    /// The `data` strings of the answer records, in the order they were parsed.
    ips: Vec<String>,
    /// Instant after which this entry is no longer served from the cache.
    expires_at: Instant,
}
/// DNS resolver that queries JSON DNS-over-HTTPS endpoints and caches answers
/// in memory.
///
/// The cache sits behind an `Arc`, so clones of a resolver share one cache.
#[derive(Clone)]
pub struct DnsResolver {
    /// HTTP client used for every DoH request.
    client: reqwest::Client,
    /// Cached answers keyed by `(domain, ipv6)`, where `ipv6 == true` denotes
    /// an AAAA lookup and `false` an A lookup.
    cache: Arc<Mutex<HashMap<(String, bool), CacheEntry>>>,
}
/// Opaque `Debug` output: prints only the type name and elides all fields
/// (the HTTP client and the cache are not shown).
impl std::fmt::Debug for DnsResolver {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("DnsResolver");
        builder.finish_non_exhaustive()
    }
}
impl DnsResolver {
    /// Creates a resolver with a fresh HTTP client and an empty cache.
    ///
    /// The client is built with a 10-second timeout, `no_proxy()`, and a
    /// desktop-Chrome `User-Agent` string.
    ///
    /// # Panics
    ///
    /// Panics if the `reqwest` client cannot be built.
    pub fn new() -> Self {
        let client = reqwest::Client::builder()
            .timeout(Duration::from_secs(10))
            .no_proxy()
            .user_agent(
                "Mozilla/5.0 (Windows NT 10.0; Win64; x64) \
                 AppleWebKit/537.36 (KHTML, like Gecko) \
                 Chrome/124.0.0.0 Safari/537.36",
            )
            .build()
            .expect("DnsResolver: failed to build reqwest client");
        Self {
            client,
            cache: Arc::new(Mutex::new(HashMap::new())),
        }
    }

    /// Resolves `domain` for both address families.
    ///
    /// The A and AAAA lookups run concurrently. The returned list contains the
    /// A results first, followed by the AAAA results; it is empty when neither
    /// lookup produced anything.
    pub async fn resolve(&self, domain: &str) -> Vec<String> {
        let (mut ips, v6) = tokio::join!(
            self.resolve_type(domain, false),
            self.resolve_type(domain, true),
        );
        ips.extend(v6);
        ips
    }

    /// Resolves only the A (IPv4) records of `domain`.
    pub async fn resolve_v4(&self, domain: &str) -> Vec<String> {
        self.resolve_type(domain, false).await
    }

    /// Resolves only the AAAA (IPv6) records of `domain`.
    pub async fn resolve_v6(&self, domain: &str) -> Vec<String> {
        self.resolve_type(domain, true).await
    }

    /// Resolves one record type for `domain`, consulting the cache first.
    ///
    /// On a cache miss both providers are raced and the first response to
    /// arrive is used if it contains records. Otherwise both providers are
    /// queried again (with new padding) and the first non-empty answer, Google
    /// before Mozilla, is taken. Successful answers are cached for the
    /// smallest record TTL, clamped to `TTL_MIN..=TTL_MAX`.
    ///
    /// Returns an empty vector when no provider yields records; failures are
    /// logged and never propagated.
    async fn resolve_type(&self, domain: &str, ipv6: bool) -> Vec<String> {
        let rr_name = if ipv6 { "AAAA" } else { "A" };
        // Built once and reused for both the lookup and the later insert.
        let key = (domain.to_owned(), ipv6);

        {
            let cache = self.cache.lock().await;
            let now = Instant::now();
            if let Some(entry) = cache.get(&key).filter(|entry| entry.expires_at > now) {
                tracing::debug!(
                    "[dns] cache hit: {} {} → {:?}",
                    domain,
                    rr_name,
                    entry.ips
                );
                return entry.ips.clone();
            }
        }

        // DNS record type numbers: 28 = AAAA, 1 = A.
        let rrtype = if ipv6 { 28u8 } else { 1u8 };

        // Fast path: whichever provider answers first wins; the slower request
        // is dropped when `select!` returns.
        let pad = generate_padding();
        let first = tokio::select! {
            r = self.query_google(domain, rrtype, &pad) => r,
            r = self.query_mozilla(domain, rrtype, &pad) => r,
        };

        let entries = match first {
            Ok(found) if !found.is_empty() => found,
            first_outcome => {
                // Record why the fast path was unusable instead of silently
                // discarding the error.
                match &first_outcome {
                    Err(err) => tracing::debug!(
                        "[dns] first DoH response for {} {} failed: {}",
                        domain,
                        rr_name,
                        err
                    ),
                    Ok(_) => tracing::debug!(
                        "[dns] first DoH response for {} {} had no records",
                        domain,
                        rr_name
                    ),
                }

                // Slow path: ask both providers again and wait for both.
                let retry_pad = generate_padding();
                let (google, mozilla) = tokio::join!(
                    self.query_google(domain, rrtype, &retry_pad),
                    self.query_mozilla(domain, rrtype, &retry_pad),
                );
                match (google, mozilla) {
                    (Ok(found), _) if !found.is_empty() => found,
                    (_, Ok(found)) if !found.is_empty() => found,
                    (google_outcome, mozilla_outcome) => {
                        let outcomes = [("google", google_outcome), ("mozilla", mozilla_outcome)];
                        for (provider, outcome) in outcomes {
                            if let Err(err) = outcome {
                                tracing::debug!(
                                    "[dns] {} DoH query for {} {} failed: {}",
                                    provider,
                                    domain,
                                    rr_name,
                                    err
                                );
                            }
                        }
                        tracing::warn!(
                            "[dns] DoH resolution failed for {} {}",
                            domain,
                            rr_name
                        );
                        return vec![];
                    }
                }
            }
        };

        // Cache for the shortest-lived record, bounded on both sides.
        let min_ttl = entries
            .iter()
            .map(|e| e.ttl)
            .min()
            .unwrap_or(TTL_MIN.as_secs());
        let ttl = Duration::from_secs(min_ttl).clamp(TTL_MIN, TTL_MAX);
        let ips: Vec<String> = entries.into_iter().map(|e| e.data).collect();
        tracing::debug!(
            "[dns] resolved {} {} → {:?} (TTL={:?})",
            domain,
            rr_name,
            ips,
            ttl
        );

        {
            let mut cache = self.cache.lock().await;
            let now = Instant::now();
            // Drop entries that can never be served again so the map does not
            // grow without bound as new domains are resolved.
            cache.retain(|_, cached| cached.expires_at > now);
            cache.insert(
                key,
                CacheEntry {
                    ips: ips.clone(),
                    expires_at: now + ttl,
                },
            );
        }
        ips
    }

    /// Queries the Google endpoint for records of type `rrtype`.
    ///
    /// # Errors
    ///
    /// Returns `DohError::Http` for transport failures or non-success status
    /// codes, and whatever `parse_doh_json` reports for a malformed body.
    async fn query_google(
        &self,
        domain: &str,
        rrtype: u8,
        padding: &str,
    ) -> Result<Vec<DnsEntry>, DohError> {
        let url = Self::build_query_url(GOOGLE_DOH_URL, domain, rrtype, padding);
        Self::fetch_answers(self.client.get(&url), rrtype).await
    }

    /// Queries the Mozilla endpoint for records of type `rrtype`, requesting
    /// the JSON representation via `accept: application/dns-json`.
    ///
    /// # Errors
    ///
    /// Same as [`Self::query_google`].
    async fn query_mozilla(
        &self,
        domain: &str,
        rrtype: u8,
        padding: &str,
    ) -> Result<Vec<DnsEntry>, DohError> {
        let url = Self::build_query_url(MOZILLA_DOH_URL, domain, rrtype, padding);
        let request = self.client.get(&url).header("accept", "application/dns-json");
        Self::fetch_answers(request, rrtype).await
    }

    /// Sends `request`, rejects non-success HTTP statuses, and parses the body
    /// as a JSON DoH answer filtered to `rrtype`.
    async fn fetch_answers(
        request: reqwest::RequestBuilder,
        rrtype: u8,
    ) -> Result<Vec<DnsEntry>, DohError> {
        let resp = request
            .send()
            .await
            .map_err(|e| DohError::Http(e.to_string()))?
            // Surface 4xx/5xx as HTTP errors rather than feeding an error page
            // to the JSON parser.
            .error_for_status()
            .map_err(|e| DohError::Http(e.to_string()))?;
        let bytes = resp
            .bytes()
            .await
            .map_err(|e| DohError::Http(e.to_string()))?;
        parse_doh_json(&bytes, Some(u32::from(rrtype)))
    }

    /// Builds the request URL for `endpoint`, percent-encoding `domain` so it
    /// cannot alter the structure of the query string.
    fn build_query_url(endpoint: &str, domain: &str, rrtype: u8, padding: &str) -> String {
        let name = Self::percent_encode_query_value(domain);
        format!("{endpoint}?name={name}&type={rrtype}&random_padding={padding}")
    }

    /// Percent-encodes every byte of `value` except ASCII letters, digits and
    /// `-`, `.`, `_`, `~`. Ordinary host names pass through unchanged.
    fn percent_encode_query_value(value: &str) -> String {
        const HEX: &[u8; 16] = b"0123456789ABCDEF";
        let mut encoded = String::with_capacity(value.len());
        for &byte in value.as_bytes() {
            if byte.is_ascii_alphanumeric() || matches!(byte, b'-' | b'.' | b'_' | b'~') {
                encoded.push(char::from(byte));
            } else {
                encoded.push('%');
                encoded.push(char::from(HEX[usize::from(byte >> 4)]));
                encoded.push(char::from(HEX[usize::from(byte & 0x0F)]));
            }
        }
        encoded
    }
}
impl Default for DnsResolver {
fn default() -> Self {
Self::new()
}
}
/// One record taken from the `Answer` array of a JSON DoH response.
#[derive(Debug, Clone)]
struct DnsEntry {
    /// The record's `data` string (copied verbatim from the response).
    data: String,
    /// The record's `TTL` in seconds, or `TTL_MIN` seconds when the field is
    /// missing or not an unsigned integer.
    ttl: u64,
}
/// Reasons a single DoH query can fail.
#[derive(Debug)]
enum DohError {
    /// The HTTP exchange failed; holds the underlying error rendered as text.
    Http(String),
    /// The response body was empty, was not valid JSON, or had an unexpected
    /// shape; holds a short description.
    Parse(String),
}

/// Human-readable rendering: `HTTP: <detail>` or `parse: <detail>`.
impl std::fmt::Display for DohError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Http(e) => write!(f, "HTTP: {e}"),
            Self::Parse(e) => write!(f, "parse: {e}"),
        }
    }
}

// Implementing `Error` lets the type take part in standard error handling
// (`Box<dyn Error>`, `?` conversions, error-reporting helpers). The variants
// carry only text, so there is no `source` to expose.
impl std::error::Error for DohError {}
/// Extracts the answer records from a JSON DoH response body.
///
/// * `bytes` – raw response body.
/// * `type_filter` – when `Some`, records whose numeric `type` field is present
///   and differs from it are skipped; records without a numeric `type` are kept.
///
/// Records lacking a string `data` field are skipped. A record without an
/// unsigned-integer `TTL` gets `TTL_MIN` seconds. A body with no `Answer` key
/// yields an empty list.
///
/// # Errors
///
/// Returns `DohError::Parse` when the body is empty, is not valid JSON, or
/// `Answer` exists but is not an array.
fn parse_doh_json(bytes: &[u8], type_filter: Option<u32>) -> Result<Vec<DnsEntry>, DohError> {
    if bytes.is_empty() {
        return Err(DohError::Parse("empty response".into()));
    }

    let root: serde_json::Value = match serde_json::from_slice(bytes) {
        Ok(parsed) => parsed,
        Err(err) => return Err(DohError::Parse(format!("JSON: {err}"))),
    };

    let Some(answer) = root.get("Answer") else {
        return Ok(Vec::new());
    };
    let Some(records) = answer.as_array() else {
        return Err(DohError::Parse("Answer not an array".into()));
    };

    let wanted = type_filter.map(u64::from);
    let entries = records
        .iter()
        .filter(|record| {
            let actual = record.get("type").and_then(|v| v.as_u64());
            // Only a known, mismatching type excludes a record.
            match (wanted, actual) {
                (Some(want), Some(got)) => got == want,
                _ => true,
            }
        })
        .filter_map(|record| {
            let data = record.get("data")?.as_str()?.to_owned();
            let ttl = record
                .get("TTL")
                .and_then(|v| v.as_u64())
                .unwrap_or(TTL_MIN.as_secs());
            Some(DnsEntry { data, ttl })
        })
        .collect();
    Ok(entries)
}
/// Produces a random ASCII-alphanumeric string whose length lies in
/// `PAD_MIN..=PAD_MAX`.
///
/// One random byte selects the length; one further random byte per character
/// selects a symbol from the 62-character alphabet.
///
/// # Panics
///
/// Panics if `getrandom` fails to supply random bytes.
fn generate_padding() -> String {
    const CHARSET: &[u8] = b"abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789";
    let span = PAD_MAX - PAD_MIN + 1;

    let mut seed = [0u8; 1];
    getrandom::getrandom(&mut seed).unwrap();
    let len = PAD_MIN + usize::from(seed[0]) % span;

    let mut raw = vec![0u8; len];
    getrandom::getrandom(&mut raw).unwrap();

    raw.into_iter()
        .map(|byte| char::from(CHARSET[usize::from(byte) % CHARSET.len()]))
        .collect()
}
// Unit tests for the pure helpers (`parse_doh_json`, `generate_padding`);
// nothing here touches the network.
#[cfg(test)]
mod tests {
    use super::*;
    // Two A records in the answer are both parsed, in order, with their TTLs.
    #[test]
    fn parse_google_doh_response() {
        let json = r#"{
"Status": 0,
"Answer": [
{ "name": "venus.web.telegram.org", "type": 1, "TTL": 120, "data": "149.154.167.51" },
{ "name": "venus.web.telegram.org", "type": 1, "TTL": 120, "data": "149.154.167.91" }
]
}"#;
        let entries = parse_doh_json(json.as_bytes(), Some(1)).unwrap();
        assert_eq!(entries.len(), 2);
        assert_eq!(entries[0].data, "149.154.167.51");
        assert_eq!(entries[0].ttl, 120);
    }
    // The type filter keeps only records whose `type` matches (1 vs 28).
    #[test]
    fn parse_filters_by_type() {
        let json = r#"{ "Answer": [
{ "type": 1, "TTL": 60, "data": "1.2.3.4" },
{ "type": 28, "TTL": 60, "data": "::1" }
]}"#;
        let v4 = parse_doh_json(json.as_bytes(), Some(1)).unwrap();
        assert_eq!(v4.len(), 1);
        assert_eq!(v4[0].data, "1.2.3.4");
        let v6 = parse_doh_json(json.as_bytes(), Some(28)).unwrap();
        assert_eq!(v6.len(), 1);
        assert_eq!(v6[0].data, "::1");
    }
    // A body without an `Answer` key parses to an empty list, not an error.
    #[test]
    fn parse_empty_answer_ok() {
        let json = r#"{ "Status": 3 }"#; let entries = parse_doh_json(json.as_bytes(), None).unwrap();
        assert!(entries.is_empty());
    }
    // Padding length always stays within the inclusive PAD_MIN..=PAD_MAX range.
    #[test]
    fn padding_length_in_range() {
        for _ in 0..20 {
            let p = generate_padding();
            assert!(
                p.len() >= PAD_MIN && p.len() <= PAD_MAX,
                "bad len {}",
                p.len()
            );
        }
    }
}