numa 0.17.0

Portable DNS resolver in Rust — .numa local domains, ad blocking, developer overrides, DNS-over-HTTPS
Documentation
use std::collections::HashMap;
use std::net::{IpAddr, SocketAddr};
use std::time::Instant;

use crate::stats::UpstreamTransport;

const INITIAL_SRTT_MS: u64 = 200;
const FAILURE_PENALTY_MS: u64 = 5000;
const DECAY_AFTER_SECS: u64 = 300;
const MAX_ENTRIES: usize = 4096;
const EVICT_BATCH: usize = 64;

/// Failover circuit-breaker threshold: a primary upstream at or above this
/// SRTT is skipped when a fallback is available. Calibrated between
/// `INITIAL_SRTT_MS` and `FAILURE_PENALTY_MS` so a single failure trips it
/// and decay un-trips it within ~5 minutes. With per-`(ip, transport)`
/// keying, success on a sibling transport (e.g. TCP fallback for the same
/// IP) does NOT re-arm probing of the broken transport — so on UDP-hostile
/// networks the UDP timeout is taken at most once per decay window.
pub const PRIMARY_SKIP_SRTT_MS: u64 = 4000;

struct SrttEntry {
    srtt_ms: u64,
    updated_at: Instant,
}

/// Per-(ip, transport) EWMA so a UDP-hostile path (BCP 38) doesn't poison
/// TCP's score on the same IP, and DoT TLS failures stay isolated from
/// plain TCP. DoH/ODoH route through a URL+pool, never key here.
pub struct SrttCache {
    entries: HashMap<(IpAddr, UpstreamTransport), SrttEntry>,
    enabled: bool,
}

impl Default for SrttCache {
    fn default() -> Self {
        Self::new(true)
    }
}

impl SrttCache {
    pub fn new(enabled: bool) -> Self {
        Self {
            entries: HashMap::new(),
            enabled,
        }
    }

    pub fn is_enabled(&self) -> bool {
        self.enabled
    }

    /// Get current SRTT for an (IP, transport), applying decay if stale.
    /// Returns INITIAL for unknown.
    pub fn get(&self, ip: IpAddr, transport: UpstreamTransport) -> u64 {
        match self.entries.get(&(ip, transport)) {
            Some(entry) => Self::decayed_srtt(entry),
            None => INITIAL_SRTT_MS,
        }
    }

    /// Whether we have observed RTT data for this (IP, transport).
    pub fn is_known(&self, ip: IpAddr, transport: UpstreamTransport) -> bool {
        self.entries.contains_key(&(ip, transport))
    }

    /// Apply time-based decay: each DECAY_AFTER_SECS period halves distance to INITIAL.
    fn decayed_srtt(entry: &SrttEntry) -> u64 {
        Self::decay_for_age(entry.srtt_ms, entry.updated_at.elapsed().as_secs())
    }

    fn decay_for_age(srtt_ms: u64, age_secs: u64) -> u64 {
        if age_secs > DECAY_AFTER_SECS {
            let periods = (age_secs / DECAY_AFTER_SECS).min(8);
            let mut srtt = srtt_ms;
            for _ in 0..periods {
                srtt = (srtt + INITIAL_SRTT_MS) / 2;
            }
            srtt
        } else {
            srtt_ms
        }
    }

    /// Record a successful query RTT. No-op when disabled.
    pub fn record_rtt(&mut self, ip: IpAddr, transport: UpstreamTransport, rtt_ms: u64) {
        if !self.enabled {
            return;
        }
        self.maybe_evict();
        let entry = self.entries.entry((ip, transport)).or_insert(SrttEntry {
            srtt_ms: rtt_ms,
            updated_at: Instant::now(),
        });
        // Apply decay before EWMA so recovered servers aren't stuck at stale penalties
        let base = Self::decayed_srtt(entry);
        // BIND EWMA: new = (old * 7 + sample) / 8
        entry.srtt_ms = (base * 7 + rtt_ms) / 8;
        entry.updated_at = Instant::now();
    }

    /// Record a failure (timeout or error). No-op when disabled.
    pub fn record_failure(&mut self, ip: IpAddr, transport: UpstreamTransport) {
        if !self.enabled {
            return;
        }
        self.maybe_evict();
        let entry = self.entries.entry((ip, transport)).or_insert(SrttEntry {
            srtt_ms: FAILURE_PENALTY_MS,
            updated_at: Instant::now(),
        });
        entry.srtt_ms = FAILURE_PENALTY_MS;
        entry.updated_at = Instant::now();
    }

    /// Sort by UDP SRTT ascending (lowest/fastest first). No-op when disabled.
    pub fn sort_by_udp_rtt(&self, addrs: &mut [SocketAddr]) {
        if !self.enabled {
            return;
        }
        addrs.sort_by_key(|a| self.get(a.ip(), UpstreamTransport::Udp));
    }

    pub fn heap_bytes(&self) -> usize {
        let per_slot = std::mem::size_of::<u64>()
            + std::mem::size_of::<(IpAddr, UpstreamTransport)>()
            + std::mem::size_of::<SrttEntry>()
            + 1;
        self.entries.capacity() * per_slot
    }

    pub fn len(&self) -> usize {
        self.entries.len()
    }

    pub fn is_empty(&self) -> bool {
        self.entries.is_empty()
    }

    fn maybe_evict(&mut self) {
        if self.entries.len() < MAX_ENTRIES {
            return;
        }
        // Batch eviction: remove the oldest EVICT_BATCH entries at once
        let mut by_age: Vec<_> = self.entries.keys().copied().collect();
        by_age.sort_by_key(|k| self.entries[k].updated_at);
        for k in by_age.into_iter().take(EVICT_BATCH) {
            self.entries.remove(&k);
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::net::Ipv4Addr;

    const UDP: UpstreamTransport = UpstreamTransport::Udp;
    const TCP: UpstreamTransport = UpstreamTransport::Tcp;

    fn ip(last: u8) -> IpAddr {
        IpAddr::V4(Ipv4Addr::new(192, 0, 2, last))
    }

    fn sock(last: u8) -> SocketAddr {
        SocketAddr::new(ip(last), 53)
    }

    #[test]
    fn unknown_returns_initial() {
        let cache = SrttCache::new(true);
        assert_eq!(cache.get(ip(1), UDP), INITIAL_SRTT_MS);
    }

    #[test]
    fn ewma_converges() {
        let mut cache = SrttCache::new(true);
        for _ in 0..20 {
            cache.record_rtt(ip(1), UDP, 100);
        }
        let srtt = cache.get(ip(1), UDP);
        assert!(srtt >= 98 && srtt <= 102, "srtt={}", srtt);
    }

    #[test]
    fn failure_sets_penalty() {
        let mut cache = SrttCache::new(true);
        cache.record_rtt(ip(1), UDP, 50);
        cache.record_failure(ip(1), UDP);
        assert_eq!(cache.get(ip(1), UDP), FAILURE_PENALTY_MS);
    }

    #[test]
    fn udp_and_tcp_are_independent() {
        // UDP penalty must not bleed into the TCP EWMA on the same IP.
        let mut cache = SrttCache::new(true);
        cache.record_failure(ip(1), UDP);
        for _ in 0..20 {
            cache.record_rtt(ip(1), TCP, 50);
        }
        assert_eq!(cache.get(ip(1), UDP), FAILURE_PENALTY_MS);
        let tcp = cache.get(ip(1), TCP);
        assert!(tcp >= 48 && tcp <= 52, "tcp srtt drifted: {}", tcp);
    }

    #[test]
    fn sort_by_udp_rtt_orders_correctly() {
        let mut cache = SrttCache::new(true);
        for _ in 0..20 {
            cache.record_rtt(ip(1), UDP, 500);
            cache.record_rtt(ip(2), UDP, 100);
            cache.record_rtt(ip(3), UDP, 10);
        }
        let mut addrs = vec![sock(1), sock(2), sock(3)];
        cache.sort_by_udp_rtt(&mut addrs);
        assert_eq!(addrs, vec![sock(3), sock(2), sock(1)]);
    }

    #[test]
    fn sort_by_udp_rtt_ignores_tcp_entries() {
        // TCP scores must not influence UDP-context sorting.
        let mut cache = SrttCache::new(true);
        for _ in 0..20 {
            cache.record_rtt(ip(1), TCP, 10); // would sort first if TCP leaked in
            cache.record_rtt(ip(2), UDP, 100);
        }
        let mut addrs = vec![sock(1), sock(2)];
        cache.sort_by_udp_rtt(&mut addrs);
        // ip(1) has no UDP record → INITIAL (200) > ip(2) UDP (100)
        assert_eq!(addrs, vec![sock(2), sock(1)]);
    }

    #[test]
    fn unknown_servers_sort_equal() {
        let cache = SrttCache::new(true);
        let mut addrs = vec![sock(1), sock(2), sock(3)];
        let original = addrs.clone();
        cache.sort_by_udp_rtt(&mut addrs);
        assert_eq!(addrs, original);
    }

    #[test]
    fn disabled_is_noop() {
        let mut cache = SrttCache::new(false);
        cache.record_rtt(ip(1), UDP, 50);
        cache.record_failure(ip(2), UDP);
        assert_eq!(cache.len(), 0);

        let mut addrs = vec![sock(2), sock(1)];
        let original = addrs.clone();
        cache.sort_by_udp_rtt(&mut addrs);
        assert_eq!(addrs, original);
    }

    #[test]
    fn no_decay_within_threshold() {
        // At exactly DECAY_AFTER_SECS, no decay applied
        let result = SrttCache::decay_for_age(FAILURE_PENALTY_MS, DECAY_AFTER_SECS);
        assert_eq!(result, FAILURE_PENALTY_MS);
    }

    #[test]
    fn one_decay_period() {
        let result = SrttCache::decay_for_age(FAILURE_PENALTY_MS, DECAY_AFTER_SECS + 1);
        let expected = (FAILURE_PENALTY_MS + INITIAL_SRTT_MS) / 2;
        assert_eq!(result, expected);
    }

    #[test]
    fn multiple_decay_periods() {
        let result = SrttCache::decay_for_age(FAILURE_PENALTY_MS, DECAY_AFTER_SECS * 4 + 1);
        let mut expected = FAILURE_PENALTY_MS;
        for _ in 0..4 {
            expected = (expected + INITIAL_SRTT_MS) / 2;
        }
        assert_eq!(result, expected);
    }

    #[test]
    fn decay_caps_at_8_periods() {
        // 9 periods and 100 periods should produce the same result (capped at 8)
        let a = SrttCache::decay_for_age(FAILURE_PENALTY_MS, DECAY_AFTER_SECS * 9 + 1);
        let b = SrttCache::decay_for_age(FAILURE_PENALTY_MS, DECAY_AFTER_SECS * 100);
        assert_eq!(a, b);
    }

    #[test]
    fn decay_converges_toward_initial() {
        let decayed = SrttCache::decay_for_age(FAILURE_PENALTY_MS, DECAY_AFTER_SECS * 100);
        let diff = decayed.abs_diff(INITIAL_SRTT_MS);
        assert!(
            diff < 25,
            "expected near INITIAL_SRTT_MS, got {} (diff={})",
            decayed,
            diff
        );
    }

    #[test]
    fn record_rtt_applies_decay_before_ewma() {
        // Verify decay is applied before EWMA in record_rtt by checking
        // that a saturated penalty + long age + new sample produces a low SRTT
        let decayed = SrttCache::decay_for_age(FAILURE_PENALTY_MS, DECAY_AFTER_SECS * 8);
        // EWMA: (decayed * 7 + 50) / 8
        let after_ewma = (decayed * 7 + 50) / 8;
        assert!(
            after_ewma < 500,
            "expected decay before EWMA, got srtt={}",
            after_ewma
        );
    }

    #[test]
    fn decay_reranks_stale_failures() {
        // After enough decay, a failed server (5000ms) converges toward
        // INITIAL (200ms), which is below a stable server at 300ms
        let decayed = SrttCache::decay_for_age(FAILURE_PENALTY_MS, DECAY_AFTER_SECS * 100);
        assert!(
            decayed < 300,
            "expected decayed penalty ({}) < 300ms",
            decayed
        );
    }

    #[test]
    fn heap_bytes_grows_with_entries() {
        let mut cache = SrttCache::new(true);
        let empty = cache.heap_bytes();
        for i in 1..=10u8 {
            cache.record_rtt(ip(i), UDP, 100);
        }
        assert!(cache.heap_bytes() > empty);
    }

    #[test]
    fn eviction_removes_oldest() {
        let mut cache = SrttCache::new(true);
        for i in 0..MAX_ENTRIES {
            let octets = [
                10,
                ((i >> 16) & 0xFF) as u8,
                ((i >> 8) & 0xFF) as u8,
                (i & 0xFF) as u8,
            ];
            cache.record_rtt(
                IpAddr::V4(Ipv4Addr::new(octets[0], octets[1], octets[2], octets[3])),
                UDP,
                100,
            );
        }
        assert_eq!(cache.len(), MAX_ENTRIES);
        cache.record_rtt(ip(1), UDP, 100);
        // Batch eviction removes EVICT_BATCH entries
        assert!(cache.len() <= MAX_ENTRIES - EVICT_BATCH + 1);
    }
}