use std::net::{IpAddr, SocketAddr};
use std::sync::Arc;
use std::time::Duration;
use reqwest::dns::{Addrs, Name, Resolve, Resolving};
use serde::Deserialize;
/// DNS resolver that uses the system resolver (`tokio::net::lookup_host`) but
/// only hands IPv4 addresses to the connector.
#[derive(Debug, Default)]
struct Ipv4OnlyResolver;
impl Resolve for Ipv4OnlyResolver {
    /// Resolves `name` with the system resolver and keeps only the IPv4 results.
    ///
    /// # Errors
    ///
    /// Propagates system lookup failures. Also fails with
    /// [`std::io::ErrorKind::NotFound`] when the host resolves but has no IPv4
    /// address. That way the caller sees the real cause instead of a generic
    /// connect failure caused by an empty address list.
    fn resolve(&self, name: Name) -> Resolving {
        let host = name.as_str().to_owned();
        Box::pin(async move {
            // Port 0 is only a placeholder: reqwest connects to the port from the URL.
            let addrs = tokio::net::lookup_host((host.as_str(), 0)).await?;
            let v4: Vec<SocketAddr> = addrs.filter(SocketAddr::is_ipv4).collect();
            if v4.is_empty() {
                return Err(std::io::Error::new(
                    std::io::ErrorKind::NotFound,
                    format!("no IPv4 address found for {host}"),
                )
                .into());
            }
            Ok(Box::new(v4.into_iter()) as Addrs)
        })
    }
}
/// DNS-over-HTTPS resolver that sends JSON (`application/dns-json`) queries to
/// a DoH endpoint for A records and, unless `ipv4_only` is set, AAAA records.
#[derive(Debug)]
struct DohResolver {
    /// Client used to send the DoH queries themselves. In `build_client` this
    /// is a plain client without a custom resolver.
    client: reqwest::Client,
    /// Full DoH query URL, e.g. `https://1.1.1.1/dns-query`.
    endpoint: String,
    /// Address of the DoH server; used in the debug output.
    resolver: IpAddr,
    /// When true, only A records are requested.
    ipv4_only: bool,
}
/// The part of a JSON DoH response body that the resolver reads.
#[derive(Deserialize)]
struct DohResponse {
    /// Answer records. The `Answer` key may be missing (e.g. when there are
    /// no records); `default` turns that into an empty list instead of an error.
    #[serde(rename = "Answer")]
    #[serde(default)]
    answer: Vec<DohAnswer>,
}
/// A single answer record. `data` is the record value as text. Values that are
/// not IP addresses (such as CNAME targets) are skipped by `DohResolver::query`.
#[derive(Deserialize)]
struct DohAnswer {
    data: String,
}
impl DohResolver {
    /// Sends one JSON DoH GET request asking for `rtype` records (e.g. `"A"` or
    /// `"AAAA"`) of `host`. Returns every answer whose `data` parses as an IP
    /// address, in response order.
    ///
    /// # Errors
    ///
    /// Fails on transport errors, non-success HTTP status, or a body that does
    /// not deserialize as [`DohResponse`].
    async fn query(
        client: reqwest::Client,
        endpoint: String,
        host: String,
        rtype: &'static str,
    ) -> Result<Vec<IpAddr>, Box<dyn std::error::Error + Send + Sync>> {
        let params = [("name", host.as_str()), ("type", rtype)];
        let response = client
            .get(&endpoint)
            .query(&params)
            .header("accept", "application/dns-json")
            .send()
            .await?;
        let body: DohResponse = response.error_for_status()?.json().await?;
        let mut ips = Vec::with_capacity(body.answer.len());
        for record in body.answer {
            // Non-address records (e.g. CNAME targets) simply fail to parse.
            if let Ok(ip) = record.data.parse::<IpAddr>() {
                ips.push(ip);
            }
        }
        Ok(ips)
    }
}
impl Resolve for DohResolver {
    /// Resolves `name` by asking the DoH endpoint for A records and, unless
    /// `ipv4_only` is set, AAAA records. IPv4 results come first.
    ///
    /// Setting `MJRS_DNS_DEBUG` in the environment prints each lookup to stderr.
    ///
    /// # Errors
    ///
    /// - Fails if the A query fails.
    /// - Fails with [`std::io::ErrorKind::NotFound`] if no usable address was
    ///   returned at all, so callers get a clear cause instead of a generic
    ///   connect error from an empty address list.
    /// - A failing AAAA query is tolerated (best-effort).
    fn resolve(&self, name: Name) -> Resolving {
        let client = self.client.clone();
        let endpoint = self.endpoint.clone();
        let resolver = self.resolver;
        let ipv4_only = self.ipv4_only;
        let host = name.as_str().to_owned();
        Box::pin(async move {
            let debug = std::env::var_os("MJRS_DNS_DEBUG").is_some();
            let mut ips =
                match DohResolver::query(client.clone(), endpoint.clone(), host.clone(), "A").await
                {
                    Ok(v) => v,
                    Err(e) => {
                        if debug {
                            eprintln!("[dns] DoH via {resolver} → {host} = ERROR ({e})");
                        }
                        return Err(e);
                    }
                };
            if !ipv4_only {
                match DohResolver::query(client, endpoint, host.clone(), "AAAA").await {
                    Ok(v6) => ips.extend(v6),
                    // AAAA stays best-effort, but the failure is no longer invisible.
                    Err(e) if debug => {
                        eprintln!("[dns] DoH via {resolver} → {host} AAAA = ERROR ({e})");
                    }
                    Err(_) => {}
                }
            }
            if debug {
                eprintln!("[dns] DoH via {resolver} → {host} = {ips:?}");
            }
            if ips.is_empty() {
                return Err(std::io::Error::new(
                    std::io::ErrorKind::NotFound,
                    format!("DoH via {resolver} returned no addresses for {host}"),
                )
                .into());
            }
            let addrs: Vec<SocketAddr> = ips.into_iter().map(|ip| SocketAddr::new(ip, 0)).collect();
            Ok(Box::new(addrs.into_iter()) as Addrs)
        })
    }
}
/// Builds the HTTP client used for all requests.
///
/// * `timeout_secs` – per-request timeout, applied to the main client and to
///   the DoH bootstrap client.
/// * `force_ipv4` – only connect over IPv4 (and only request A records via DoH).
/// * `dns` – when set, resolve hostnames with DNS-over-HTTPS at
///   `https://<dns>/dns-query` instead of the system resolver.
///
/// # Errors
///
/// Returns any error reported by `reqwest` while building either client.
pub fn build_client(
    timeout_secs: u64,
    force_ipv4: bool,
    dns: Option<IpAddr>,
) -> reqwest::Result<reqwest::Client> {
    let timeout = Duration::from_secs(timeout_secs);
    let builder = reqwest::Client::builder().timeout(timeout);
    let builder = if let Some(ip) = dns {
        // The bootstrap client has no custom resolver and addresses the DoH
        // server by IP literal, so it does not depend on the resolver it serves.
        let bootstrap = reqwest::Client::builder().timeout(timeout).build()?;
        // IPv6 literals must be bracketed inside a URL authority.
        let authority = match ip {
            IpAddr::V4(v4) => v4.to_string(),
            IpAddr::V6(v6) => format!("[{v6}]"),
        };
        builder.dns_resolver(Arc::new(DohResolver {
            client: bootstrap,
            endpoint: format!("https://{authority}/dns-query"),
            resolver: ip,
            ipv4_only: force_ipv4,
        }))
    } else if force_ipv4 {
        builder.dns_resolver(Arc::new(Ipv4OnlyResolver))
    } else {
        builder
    };
    builder.build()
}
/// Renders a [`reqwest::Error`] as a short, human-readable message.
///
/// HTTP status errors become `HTTP <code> <reason>`. Every other error is
/// classified (timeout, connect, request, body, decode, builder, or generic).
/// The message of the innermost error in its `source()` chain is appended in
/// parentheses, because that is usually the most specific cause.
pub fn describe_reqwest_error(err: &reqwest::Error) -> String {
    if let Some(status) = err.status() {
        // `StatusCode`'s `Display` already appends the canonical reason, so use
        // the bare numeric code to avoid output like "HTTP 404 Not Found Not Found".
        let reason = status.canonical_reason().unwrap_or("unknown");
        return format!("HTTP {} {reason}", status.as_u16());
    }
    let kind = if err.is_timeout() {
        "connection timed out"
    } else if err.is_connect() {
        "could not establish connection"
    } else if err.is_request() {
        "request could not be sent"
    } else if err.is_body() {
        "response body could not be read"
    } else if err.is_decode() {
        "response could not be decoded"
    } else if err.is_builder() {
        "invalid request"
    } else {
        "network error"
    };
    // Deepest source in the chain, borrowed, so no String is built per level.
    let root = std::iter::successors(std::error::Error::source(err), |e| e.source()).last();
    match root {
        Some(cause) => format!("{kind} ({cause})"),
        None => kind.to_string(),
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    /// `build_client` sends no requests, so every combination of `force_ipv4`
    /// and resolver kind (system, IPv4 DoH, IPv6 DoH) must build successfully.
    #[test]
    fn build_client_succeeds_in_all_modes() {
        let v4_resolver: IpAddr = "1.1.1.1".parse().unwrap();
        // An IPv6 resolver exercises the bracketed `https://[..]/dns-query` branch.
        let v6_resolver: IpAddr = "2606:4700:4700::1111".parse().unwrap();
        for force_ipv4 in [false, true] {
            for dns in [None, Some(v4_resolver), Some(v6_resolver)] {
                assert!(
                    build_client(10, force_ipv4, dns).is_ok(),
                    "build_client failed for force_ipv4={force_ipv4}, dns={dns:?}"
                );
            }
        }
    }
    /// Live check against Cloudflare's DoH endpoint. `MJRS_DNS_DEBUG` is set so
    /// the resolver prints its lookups to stderr.
    #[tokio::test]
    #[ignore = "requires internet: queries Cloudflare DoH at 1.1.1.1"]
    async fn doh_resolver_resolves_real_host() {
        std::env::set_var("MJRS_DNS_DEBUG", "1");
        let resolver = DohResolver {
            client: reqwest::Client::builder().build().unwrap(),
            endpoint: "https://1.1.1.1/dns-query".to_owned(),
            resolver: "1.1.1.1".parse().unwrap(),
            ipv4_only: true,
        };
        let name: Name = "resources.download.minecraft.net".parse().unwrap();
        let addrs: Vec<SocketAddr> = resolver.resolve(name).await.unwrap().collect();
        assert!(!addrs.is_empty(), "DoH returned no addresses");
        assert!(addrs.iter().all(SocketAddr::is_ipv4), "ipv4_only honoured");
        std::env::remove_var("MJRS_DNS_DEBUG");
    }
}