ugi 0.2.1

Runtime-agnostic Rust request client with HTTP/1.1, HTTP/2, HTTP/3, H2C, WebSocket, SSE, and gRPC support
Documentation
use std::collections::HashMap;
use std::net::{IpAddr, SocketAddr, ToSocketAddrs};
use std::sync::Mutex;
use std::time::{Duration, Instant};

use crate::error::{Error, ErrorKind, Result};

/// Tuning knobs for the built-in DNS resolver and its cache.
///
/// The [`Default`] configuration keeps cached results for 60 seconds and
/// leaves the resolver's address order as-is (`dual_stack` disabled).
#[derive(Clone, Copy, Debug)]
pub struct DnsConfig {
    /// Lifetime of a cached resolution; once it has elapsed the host is
    /// looked up again.
    pub ttl: Duration,
    /// Alternate IPv6 and IPv4 addresses (Happy Eyeballs style) to reduce
    /// connection latency on dual-stack hosts.
    pub dual_stack: bool,
}

impl Default for DnsConfig {
    fn default() -> Self {
        // One minute: short enough to follow DNS changes, long enough to
        // avoid a resolver round-trip on every request.
        const DEFAULT_TTL_SECS: u64 = 60;

        DnsConfig {
            dual_stack: false,
            ttl: Duration::from_secs(DEFAULT_TTL_SECS),
        }
    }
}

/// Thread-safe cache of hostname-to-IP-address resolutions.
///
/// Lookups go through the system resolver (`ToSocketAddrs`) and successful
/// results are cached per hostname for [`DnsConfig::ttl`]. Addresses are
/// stored in the order the resolver returned them; the ordering requested by
/// [`DnsConfig::dual_stack`] is applied on every lookup, so callers sharing
/// one cache with different configs each get the ordering they asked for.
#[derive(Debug, Default)]
pub struct DnsCache {
    entries: Mutex<HashMap<String, DnsEntry>>,
}

/// A cached resolution result.
#[derive(Debug)]
struct DnsEntry {
    /// Addresses in the order the system resolver returned them.
    addrs: Vec<IpAddr>,
    /// Instant after which the entry is stale. `None` means the configured
    /// TTL was too large to represent as an `Instant`, so the entry never
    /// expires.
    expires_at: Option<Instant>,
}

impl DnsEntry {
    /// Returns `true` if the entry may still be served at `now`.
    fn is_fresh(&self, now: Instant) -> bool {
        match self.expires_at {
            Some(deadline) => deadline > now,
            None => true,
        }
    }
}

impl DnsCache {
    /// Creates an empty cache.
    pub fn new() -> Self {
        Self::default()
    }

    /// Resolves `host` to its IP addresses, consulting the cache first.
    ///
    /// IP literals, including the bracketed IPv6 form used in URLs
    /// (`[::1]`), are returned directly without touching the cache or the
    /// resolver. When `config.dual_stack` is set, the result is interleaved
    /// across address families; otherwise resolver order is kept.
    ///
    /// The system resolver may block the calling thread while a lookup is in
    /// progress.
    ///
    /// # Errors
    ///
    /// Returns an [`ErrorKind::Transport`] error if resolution fails or
    /// yields no addresses. Failures are not cached.
    pub fn resolve_host(&self, host: &str, config: DnsConfig) -> Result<Vec<IpAddr>> {
        if let Some(ip) = parse_ip_literal(host) {
            return Ok(vec![ip]);
        }

        let now = Instant::now();
        // The guard is a temporary dropped at the end of this statement, so
        // the lock is not held while resolving below.
        let cached = self
            .lock_entries()
            .get(host)
            .filter(|entry| entry.is_fresh(now))
            .map(|entry| entry.addrs.clone());
        if let Some(addrs) = cached {
            return Ok(order_addrs(addrs, config));
        }

        let addrs = (host, 0)
            .to_socket_addrs()
            .map_err(|err| {
                Error::with_source(
                    ErrorKind::Transport,
                    format!("failed to resolve host {host}"),
                    err,
                )
            })?
            .map(|addr| addr.ip())
            .collect::<Vec<_>>();

        if addrs.is_empty() {
            return Err(Error::new(
                ErrorKind::Transport,
                format!("host resolved to no addresses: {host}"),
            ));
        }

        // A zero TTL means "never serve from cache", so storing is wasted work.
        if !config.ttl.is_zero() {
            let mut entries = self.lock_entries();
            // Evict stale entries so the map does not grow without bound when
            // many distinct hosts are resolved over the cache's lifetime.
            entries.retain(|_, entry| entry.is_fresh(now));
            entries.insert(
                host.to_owned(),
                DnsEntry {
                    // Cache resolver order; per-call ordering is applied on
                    // return so it always honours the caller's config.
                    addrs: addrs.clone(),
                    // `checked_add` avoids the panic `Instant + Duration`
                    // raises on overflow (e.g. `Duration::MAX`).
                    expires_at: now.checked_add(config.ttl),
                },
            );
        }

        Ok(order_addrs(addrs, config))
    }

    /// Resolves `host` like [`DnsCache::resolve_host`] and pairs every
    /// address with `port`.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`DnsCache::resolve_host`].
    pub fn resolve_socket_addrs(
        &self,
        host: &str,
        port: u16,
        config: DnsConfig,
    ) -> Result<Vec<SocketAddr>> {
        let addrs = self.resolve_host(host, config)?;
        Ok(addrs
            .into_iter()
            .map(|ip| SocketAddr::new(ip, port))
            .collect())
    }

    /// Resolves each host in `hosts`, in order, so later lookups can be
    /// served from the cache (subject to `config.ttl`).
    ///
    /// # Errors
    ///
    /// Stops at the first host that fails to resolve and returns its error;
    /// hosts resolved before it remain cached.
    pub fn prefetch<I, S>(&self, hosts: I, config: DnsConfig) -> Result<()>
    where
        I: IntoIterator<Item = S>,
        S: AsRef<str>,
    {
        for host in hosts {
            self.resolve_host(host.as_ref(), config)?;
        }
        Ok(())
    }

    /// Locks the entry map, recovering from a poisoned mutex. The map only
    /// caches resolver results, so the worst outcome of reusing it is a
    /// redundant lookup.
    fn lock_entries(&self) -> std::sync::MutexGuard<'_, HashMap<String, DnsEntry>> {
        self.entries.lock().unwrap_or_else(|err| err.into_inner())
    }
}

/// Parses `host` as an IP literal, accepting bare IPv4/IPv6 addresses as
/// well as the bracketed `[v6]` form used in URL authorities.
fn parse_ip_literal(host: &str) -> Option<IpAddr> {
    if let Ok(ip) = host.parse::<IpAddr>() {
        return Some(ip);
    }
    host.strip_prefix('[')
        .and_then(|rest| rest.strip_suffix(']'))
        .and_then(|inner| inner.parse::<IpAddr>().ok())
        .filter(IpAddr::is_ipv6)
}

/// Applies the address ordering requested by `config` to resolver-ordered
/// `addrs`.
fn order_addrs(addrs: Vec<IpAddr>, config: DnsConfig) -> Vec<IpAddr> {
    if config.dual_stack {
        interleave_dual_stack(addrs)
    } else {
        addrs
    }
}

/// Reorders resolver output so that IPv6 and IPv4 addresses alternate, as
/// recommended by Happy Eyeballs (RFC 8305, section 4).
///
/// The address family of the first element leads, so a resolver preference
/// for IPv4 (e.g. on hosts without working IPv6) is kept rather than being
/// overridden by a hard-coded IPv6-first rule. Relative order within each
/// family is preserved. Once one family is exhausted, the rest of the other
/// family follows unchanged. Empty and single-family inputs are returned in
/// their original order.
fn interleave_dual_stack(addrs: Vec<IpAddr>) -> Vec<IpAddr> {
    if addrs.is_empty() {
        return addrs;
    }
    let lead_is_ipv6 = addrs[0].is_ipv6();
    let total = addrs.len();

    // `partition` is stable, so each side keeps the resolver's order.
    let (lead, trail): (Vec<IpAddr>, Vec<IpAddr>) = addrs
        .into_iter()
        .partition(|addr| addr.is_ipv6() == lead_is_ipv6);
    let mut lead = lead.into_iter();
    let mut trail = trail.into_iter();

    let mut ordered = Vec::with_capacity(total);
    loop {
        let (next_lead, next_trail) = (lead.next(), trail.next());
        if next_lead.is_none() && next_trail.is_none() {
            break;
        }
        ordered.extend(next_lead);
        ordered.extend(next_trail);
    }
    ordered
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Parses a textual IP address used as fixture data.
    fn ip(text: &str) -> IpAddr {
        text.parse::<IpAddr>().unwrap()
    }

    /// Parses a list of textual IP addresses, preserving order.
    fn ips(texts: &[&str]) -> Vec<IpAddr> {
        texts.iter().copied().map(ip).collect()
    }

    #[test]
    fn resolves_and_caches_localhost() {
        let cache = DnsCache::new();
        let lookup = || {
            cache
                .resolve_socket_addrs("localhost", 8080, DnsConfig::default())
                .unwrap()
        };

        let initial = lookup();
        let repeated = lookup();

        assert!(!initial.is_empty());
        assert_eq!(initial, repeated);
    }

    #[test]
    fn dual_stack_interleaves_ipv6_and_ipv4_addresses() {
        let input = ips(&["::1", "::2", "127.0.0.1", "127.0.0.2"]);
        let expected = ips(&["::1", "127.0.0.1", "::2", "127.0.0.2"]);

        assert_eq!(interleave_dual_stack(input), expected);
    }
}