use std::net::IpAddr;
use std::net::SocketAddr;
use std::sync::Arc;
use std::sync::atomic::AtomicU64;
use std::sync::atomic::Ordering;
use std::time::Duration;
use std::time::SystemTime;
use parking_lot::Mutex;
use tokio::sync::Notify;
use url::Host;
use crate::byte_str::ByteStr;
use crate::client::name_resolution::Address;
use crate::client::name_resolution::ChannelController;
use crate::client::name_resolution::Endpoint;
use crate::client::name_resolution::NopResolver;
use crate::client::name_resolution::Resolver;
use crate::client::name_resolution::ResolverBuilder;
use crate::client::name_resolution::ResolverOptions;
use crate::client::name_resolution::ResolverUpdate;
use crate::client::name_resolution::TCP_IP_NETWORK_TYPE;
use crate::client::name_resolution::Target;
use crate::client::name_resolution::backoff::BackoffConfig;
use crate::client::name_resolution::backoff::DEFAULT_EXPONENTIAL_CONFIG;
use crate::client::name_resolution::backoff::ExponentialBackoff;
use crate::client::name_resolution::global_registry;
use crate::rt::BoxedTaskHandle;
use crate::rt::{self};
#[cfg(test)]
mod test;
/// Port used for an endpoint that does not specify one.
const DEFAULT_PORT: u16 = 443;
/// Port used for a DNS server authority that does not specify one.
const DEFAULT_DNS_PORT: u16 = 53;
/// Process-wide DNS lookup timeout, in milliseconds. Read when a resolver is built.
static RESOLVING_TIMEOUT_MS: AtomicU64 = AtomicU64::new(30_000);
/// Process-wide minimum interval between resolutions, in milliseconds. Read
/// when a resolver is built.
static MIN_RESOLUTION_INTERVAL_MS: AtomicU64 = AtomicU64::new(30_000);

/// Converts `duration` to whole milliseconds for storage in an `AtomicU64`.
///
/// Sub-millisecond remainders are truncated. Durations too large for a `u64`
/// saturate at `u64::MAX` instead of wrapping, which a bare `as u64` cast
/// of the `u128` from `as_millis` would do.
fn duration_to_millis(duration: Duration) -> u64 {
    // Clamp before narrowing so the cast is lossless.
    duration.as_millis().min(u128::from(u64::MAX)) as u64
}

/// Returns the currently configured DNS lookup timeout.
fn get_resolving_timeout() -> Duration {
    Duration::from_millis(RESOLVING_TIMEOUT_MS.load(Ordering::Relaxed))
}

/// Sets the DNS lookup timeout used by resolvers built after this call.
///
/// The value is stored at millisecond precision and saturates at
/// `u64::MAX` milliseconds.
pub(crate) fn set_resolving_timeout(duration: Duration) {
    RESOLVING_TIMEOUT_MS.store(duration_to_millis(duration), Ordering::Relaxed);
}

/// Returns the currently configured minimum interval between resolutions.
fn get_min_resolution_interval() -> Duration {
    Duration::from_millis(MIN_RESOLUTION_INTERVAL_MS.load(Ordering::Relaxed))
}

/// Sets the minimum resolution interval used by resolvers built after this
/// call.
///
/// The value is stored at millisecond precision and saturates at
/// `u64::MAX` milliseconds.
pub(crate) fn set_min_resolution_interval(duration: Duration) {
    MIN_RESOLUTION_INTERVAL_MS.store(duration_to_millis(duration), Ordering::Relaxed);
}
/// Adds the `dns` resolver builder to the global resolver registry.
pub(crate) fn reg() {
    let builder = Box::new(Builder {});
    global_registry().add_builder(builder);
}
/// Stateless builder that the registry uses to create resolvers for `dns:`
/// targets.
struct Builder {}

/// Settings for one resolver instance. They are captured when the resolver
/// is built.
struct DnsOptions {
    /// Minimum time between resolutions after the channel accepts an update.
    min_resolution_interval: Duration,
    /// Upper bound on a single lookup. Exceeding it is reported as an error.
    resolving_timeout: Duration,
    /// Backoff parameters applied after the channel rejects an update.
    backoff_config: BackoffConfig,
    /// Domain name to look up.
    host: String,
    /// Port attached to every resolved IP address.
    port: u16,
}
impl DnsResolver {
fn new(
dns_client: Box<dyn rt::DnsResolver>,
options: ResolverOptions,
dns_opts: DnsOptions,
) -> Self {
let state = Arc::new(Mutex::new(InternalState {
addrs: Ok(Vec::new()),
channel_response: None,
}));
let state_copy = state.clone();
let resolve_now_notify = Arc::new(Notify::new());
let channel_updated_notify = Arc::new(Notify::new());
let channel_updated_rx = channel_updated_notify.clone();
let resolve_now_rx = resolve_now_notify.clone();
let runtime = options.runtime.clone();
let work_scheduler = options.work_scheduler.clone();
let handle = options.runtime.spawn(Box::pin(async move {
let mut backoff = ExponentialBackoff::new(dns_opts.backoff_config.clone())
.expect("default exponential config must be valid");
let state = state_copy;
loop {
let mut lookup_fut = dns_client.lookup_host_name(&dns_opts.host);
let mut timeout_fut = runtime.sleep(dns_opts.resolving_timeout);
let addrs = tokio::select! {
result = &mut lookup_fut => {
match result {
Ok(ips) => {
let addrs = ips
.into_iter()
.map(|ip| SocketAddr::new(ip, dns_opts.port))
.collect();
Ok(addrs)
}
Err(err) => Err(err),
}
}
_ = &mut timeout_fut => {
Err("Timed out waiting for DNS resolution".to_string())
}
};
{
state.lock().addrs = addrs;
}
work_scheduler.schedule_work();
channel_updated_rx.notified().await;
let channel_response = { state.lock().channel_response.take() };
let next_resoltion_time = if channel_response.is_some() {
SystemTime::now()
.checked_add(backoff.backoff_duration())
.unwrap()
} else {
backoff.reset();
let res_time = SystemTime::now()
.checked_add(dns_opts.min_resolution_interval)
.unwrap();
_ = resolve_now_rx.notified().await;
res_time
};
let Ok(duration) = next_resoltion_time.duration_since(SystemTime::now()) else {
continue; };
runtime.sleep(duration).await;
}
}));
Self {
state,
task_handle: handle,
resolve_now_notifier: resolve_now_notify,
channel_update_notifier: channel_updated_notify,
}
}
}
impl ResolverBuilder for Builder {
fn build(&self, target: &Target, options: ResolverOptions) -> Box<dyn Resolver> {
let parsed = match parse_endpoint_and_authority(target) {
Ok(res) => res,
Err(err) => return NopResolver::new_with_err(err.to_string(), options),
};
let endpoint = parsed.endpoint;
let host = match endpoint.host {
Host::Domain(d) => d,
Host::Ipv4(ipv4) => {
return nop_resolver_for_ip(IpAddr::V4(ipv4), endpoint.port, options);
}
Host::Ipv6(ipv6) => {
return nop_resolver_for_ip(IpAddr::V6(ipv6), endpoint.port, options);
}
};
let authority = parsed.authority;
let dns_client = match options.runtime.get_dns_resolver(rt::ResolverOptions {
server_addr: authority,
}) {
Ok(dns) => dns,
Err(err) => return NopResolver::new_with_err(err.to_string(), options),
};
let dns_opts = DnsOptions {
min_resolution_interval: get_min_resolution_interval(),
resolving_timeout: get_resolving_timeout(),
backoff_config: DEFAULT_EXPONENTIAL_CONFIG,
host,
port: endpoint.port,
};
Box::new(DnsResolver::new(dns_client, options, dns_opts))
}
fn scheme(&self) -> &'static str {
"dns"
}
fn is_valid_uri(&self, target: &Target) -> bool {
if let Err(err) = parse_endpoint_and_authority(target) {
eprintln!("{err}");
false
} else {
true
}
}
}
/// Resolver for `dns:` targets.
///
/// Lookups run on a background task. This handle exchanges results with
/// that task through shared state and two `Notify` signals.
struct DnsResolver {
    /// Shared with the background task. Holds the latest lookup result and
    /// the channel's response to it.
    state: Arc<Mutex<InternalState>>,
    /// Handle to the background task. Aborted when the resolver is dropped.
    task_handle: BoxedTaskHandle,
    /// Signalled by `resolve_now` to request another lookup.
    resolve_now_notifier: Arc<Notify>,
    /// Signalled by `work` after each channel update, so the task can read
    /// the response.
    channel_update_notifier: Arc<Notify>,
}
/// State shared between a `DnsResolver` and its background task.
struct InternalState {
    /// Outcome of the most recent lookup. On success it holds the resolved
    /// socket addresses; on failure it holds an error message.
    addrs: Result<Vec<SocketAddr>, String>,
    /// Error returned by the channel for the last update, if any. The
    /// background task takes it to decide whether to back off.
    channel_response: Option<String>,
}
impl Resolver for DnsResolver {
fn resolve_now(&mut self) {
self.resolve_now_notifier.notify_one();
}
fn work(&mut self, channel_controller: &mut dyn ChannelController) {
let mut state = self.state.lock();
let endpoint_result = match &state.addrs {
Ok(addrs) => {
let endpoints: Vec<_> = addrs
.iter()
.map(|a| Endpoint {
addresses: vec![Address {
network_type: TCP_IP_NETWORK_TYPE,
address: ByteStr::from(a.to_string()),
..Default::default()
}],
..Default::default()
})
.collect();
Ok(endpoints)
}
Err(err) => Err(err.to_string()),
};
let update = ResolverUpdate {
endpoints: endpoint_result,
..Default::default()
};
let status = channel_controller.update(update);
state.channel_response = status.err();
self.channel_update_notifier.notify_one();
}
}
impl Drop for DnsResolver {
    /// Aborts the background resolution task, so it does not keep running
    /// after the resolver is gone.
    fn drop(&mut self) {
        let Self { task_handle, .. } = self;
        task_handle.abort();
    }
}
/// A host (domain name or IP literal) paired with a port.
#[derive(Debug, PartialEq, Eq)]
struct HostPort {
    /// Host as parsed by the `url` crate.
    host: Host<String>,
    /// The explicitly given port, or the caller-supplied default.
    port: u16,
}
/// Components extracted from a `dns:` target.
#[derive(Debug, PartialEq, Eq)]
struct ParseResult {
    /// Host and port to resolve, taken from the target's path.
    endpoint: HostPort,
    /// DNS server to query, taken from the target's authority. `None` means
    /// no server was specified.
    authority: Option<SocketAddr>,
}
fn parse_endpoint_and_authority(target: &Target) -> Result<ParseResult, String> {
let endpoint = target.path();
let endpoint = endpoint.strip_prefix("/").unwrap_or(endpoint);
let parse_result = parse_host_port(endpoint, DEFAULT_PORT)
.map_err(|err| format!("Failed to parse target {target}: {err}"))?;
let endpoint = parse_result.ok_or("Received empty endpoint host.".to_string())?;
let authority = target.authority_host_port();
if authority.is_empty() {
return Ok(ParseResult {
endpoint,
authority: None,
});
}
let parse_result = parse_host_port(&authority, DEFAULT_DNS_PORT)
.map_err(|err| format!("Failed to parse DNS authority {target}: {err}"))?;
let Some(authority) = parse_result else {
return Ok(ParseResult {
endpoint,
authority: None,
});
};
let authority = match authority.host {
Host::Ipv4(ipv4) => SocketAddr::new(IpAddr::V4(ipv4), authority.port),
Host::Ipv6(ipv6) => SocketAddr::new(IpAddr::V6(ipv6), authority.port),
_ => {
return Err(format!("Received non-IP DNS authority {}", authority.host));
}
};
Ok(ParseResult {
endpoint,
authority: Some(authority),
})
}
/// Parses `host_and_port` into a host and port. Accepted forms include
/// `example.com`, `example.com:8080`, `1.2.3.4:53` and `[::1]:443`.
///
/// The string is parsed as the authority of an `https://` URL with the `url`
/// crate. This lets IPv4 and IPv6 literals be recognized as such.
/// `default_port` is used only when the string has no explicit port.
///
/// Returns `Ok(None)` if the parsed URL has no host. Returns `Err` with the
/// parser's message if the string is not a valid URL authority.
fn parse_host_port(host_and_port: &str, default_port: u16) -> Result<Option<HostPort>, String> {
    let url = format!("https://{host_and_port}")
        .parse::<url::Url>()
        .map_err(|err| err.to_string())?;
    let Some(host) = url.host() else {
        return Ok(None);
    };
    let host = match host {
        Host::Domain(s) => Host::Domain(s.to_owned()),
        Host::Ipv4(ip) => Host::Ipv4(ip),
        Host::Ipv6(ip) => Host::Ipv6(ip),
    };
    // `Url::port` returns `None` when the port equals the scheme's default
    // (443 for https), so an explicit ":443" would otherwise be replaced by
    // `default_port`. http's default is 80, so re-parsing with http tells
    // "no port" apart from ":443". An explicit ":80" never reaches this
    // fallback, because the https parse already reports it.
    let port = url
        .port()
        .or_else(|| {
            format!("http://{host_and_port}")
                .parse::<url::Url>()
                .ok()
                .and_then(|http_url| http_url.port())
        })
        .unwrap_or(default_port);
    Ok(Some(HostPort { host, port }))
}
/// Builds a resolver that always reports the single TCP address `ip:port`.
/// `Builder::build` uses it when the target names an IP literal, so no DNS
/// lookup is needed.
fn nop_resolver_for_ip(ip: IpAddr, port: u16, options: ResolverOptions) -> Box<dyn Resolver> {
    let socket_addr = SocketAddr::new(ip, port);
    NopResolver::new_with_addr(
        Address {
            network_type: TCP_IP_NETWORK_TYPE,
            address: ByteStr::from(socket_addr.to_string()),
            ..Default::default()
        },
        options,
    )
}