Skip to main content

cellos_supervisor/dns_proxy/
upstream.rs

1//! SEC-21 Phase 3h.2 / T2.B Slot **A6** — DoT/DoH/DoQ upstream support.
2//!
3//! Phase 1's now-deleted `super::forward_upstream` only spoke `do53-udp`. This
4//! module generalises that single hot path to the full set of upstream transports the
5//! contract layer (`DnsAuthority::resolvers[].protocol`) admits:
6//!
7//! | Variant       | Wire        | Status (DNS-DOH-1..3 / DNS-DOQ-1..2)              |
8//! |---------------|-------------|----------------------------------------------------|
9//! | `Do53Udp`     | UDP/53      | Default. Byte-for-byte the legacy forward path.    |
10//! | `Do53Tcp`     | TCP/53      | Length-prefixed two-byte framing over plain TCP.   |
11//! | `Dot`         | TLS/853     | Length-prefixed framing over rustls 0.23 TLS.      |
12//! | `Doh`         | HTTPS/443   | reqwest POST `application/dns-message` (RFC 8484). |
13//! | `Doq`         | QUIC/853    | quinn 0.11 bidi stream (RFC 9250).                 |
14//!
15//! ## DoH / DoQ configuration shape
16//!
17//! The original DNS-DOH-1 ticket proposed widening `UpstreamTransport` into a
18//! data-carrying enum (`Doh { url }`, `Doq { server, port }`). That would
19//! ripple through ~10 call sites in `supervisor.rs` and the test surface that
20//! pattern-match the enum as `Copy` unit variants. Instead, DoH/DoQ
21//! configuration rides on [`UpstreamExtras`] alongside the existing DoT
22//! overrides — the enum stays `Copy`, the production composition root keeps
23//! the same shape, and the env-var protocol (`CELLOS_DNS_UPSTREAM_DOH_URL`,
24//! `CELLOS_DNS_UPSTREAM_DOQ_SERVER`, `CELLOS_DNS_UPSTREAM_DOQ_PORT`) wires
25//! through unchanged. The downside is that operators selecting DoH/DoQ with
26//! no extras get a sensible default (Cloudflare `1.1.1.1`) rather than an
27//! eager refusal; the upside is that all transports share one config seam.
28//!
//! ## Why DoT/DoH/DoQ are split into typed variants
//!
//! The supervisor's contract layer already accepts a `protocol` field on each
//! declared resolver — there's no schema migration in this slot, only a
//! widening of the dataplane forward path. Every variant is now dispatched to
//! a real transport by `forward()`; [`UpstreamError::TransportNotEnabled`]
//! remains as the typed SERVFAIL discriminator (D9 in the dispatch ledger) for
//! any transport a build does not implement, so a declared protocol is never
//! silently downgraded to plain UDP. The toggle parses, and the audit event
//! tells the truth.
37//!
//! ## DoT (implemented directly on rustls rather than via hickory)
39//!
40//! Hickory's DoT support exists but is feature-gated and pulls in the full
41//! `dns-over-rustls` stack — for one transport that's heavyweight. We
42//! implement DoT directly on rustls 0.23 + tokio-rustls 0.26:
43//!
44//! 1. TCP connect to the resolver's `host:port` (port defaulted to 853 if the
45//!    operator did not specify one — RFC 7858).
46//! 2. TLS handshake using a `RootCertStore` populated from
47//!    `webpki-roots = "1"` (Mozilla's bundled CA set; same trust set hickory
48//!    uses internally).
49//! 3. RFC 7858 framing: each query is prefixed with its 2-byte big-endian
50//!    length, and the response is parsed in the same shape.
51//!
//! The `forward()` entry point is synchronous (the proxy hot path runs on a
//! blocking thread; see [`super::run_one_shot`]) so we drive the async DoT /
//! DoH / DoQ paths via `tokio::runtime::Handle::block_on` wrapped in
//! `tokio::task::block_in_place`. If no runtime is in scope we return
//! [`UpstreamError::NoTokioRuntime`] and the caller fail-safes to SERVFAIL.
//! Note that `block_in_place` panics when called from a current-thread
//! runtime, so callers that dispatch `forward()` from async code must run on
//! the multi-thread runtime flavor.
57//!
58//! ## DoH (RFC 8484, DNS-DOH-1..3)
59//!
60//! `forward_doh()` POSTs the raw DNS wire-format query to the operator-
61//! supplied URL (default `https://1.1.1.1/dns-query`) with
62//! `Content-Type: application/dns-message` + `Accept: application/dns-message`,
63//! then reads the response body as the raw DNS wire response. reqwest's
64//! client is constructed per call (no pool reuse) — keeps the path stateless
65//! to match DoT and the rest of the proxy.
66//!
67//! ## DoQ (RFC 9250, DNS-DOQ-1..2)
68//!
69//! `forward_doq()` opens a QUIC connection to the operator-supplied
70//! `server:port` (default `1.1.1.1:853`), opens a bidirectional stream, writes
71//! the RFC 9250 2-byte-length-prefixed query, finishes the send side, then
72//! reads the length-prefixed response. ALPN is `doq` per RFC 9250 §4.1.1.
73//! rustls 0.23 + webpki-roots is the trust set (same as DoT).
74
75use std::io;
76use std::net::{SocketAddr, UdpSocket};
77use std::sync::Arc;
78use std::time::Duration;
79
80use rustls::pki_types::ServerName;
81use rustls::{ClientConfig, RootCertStore};
82use tokio::io::{AsyncReadExt, AsyncWriteExt};
83use tokio_rustls::TlsConnector;
84
/// Upper bound on the query size the forward path will accept. Intended to
/// track `super::MAX_UDP_PAYLOAD`, which lets callers hand the forward path a
/// 1500-byte scratch buffer.
const MAX_PAYLOAD: usize = 1500;

/// Well-known DNS-over-TLS port (RFC 7858).
const DEFAULT_DOT_PORT: u16 = 853;

/// DoH endpoint used when DoH is selected but `CELLOS_DNS_UPSTREAM_DOH_URL`
/// is unset. `/dns-query` is the path RFC 8484 §4.1 uses; Cloudflare serves
/// it on both its apex IP and `cloudflare-dns.com`.
const DEFAULT_DOH_URL: &str = "https://1.1.1.1/dns-query";

/// DoQ server used when `CELLOS_DNS_UPSTREAM_DOQ_SERVER` is unset (Cloudflare).
const DEFAULT_DOQ_SERVER: &str = "1.1.1.1";
/// DoQ port used when `CELLOS_DNS_UPSTREAM_DOQ_PORT` is unset (RFC 9250 §4).
const DEFAULT_DOQ_PORT: u16 = 853;
102
/// Wire protocol the proxy uses to reach its upstream resolver.
///
/// Production fills this in from `DnsAuthority::resolvers[].protocol`; dev and
/// CI can override it through the environment (see
/// [`UpstreamTransport::from_env`]).
///
/// The default is `Do53Udp`, so adding a `transport` field to
/// [`super::DnsProxyConfig`] changes nothing for cells that never opt into
/// another value.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum UpstreamTransport {
    /// Plain UDP/53 — the legacy default and the pre-A6 forward path.
    #[default]
    Do53Udp,
    /// Plain TCP/53 using the conventional 2-byte length prefix.
    Do53Tcp,
    /// DNS-over-TLS (RFC 7858): length-prefixed messages inside TLS 1.2/1.3.
    Dot,
    /// DNS-over-HTTPS (RFC 8484): reqwest POST of `application/dns-message`.
    /// The endpoint comes from [`UpstreamExtras::doh_url`] (Cloudflare by
    /// default).
    Doh,
    /// DNS-over-QUIC (RFC 9250): a quinn 0.11 bidi stream carrying the same
    /// 2-byte length prefix as DoT. Server and port come from
    /// [`UpstreamExtras::doq_server`] / [`UpstreamExtras::doq_port`]
    /// (`1.1.1.1:853` by default).
    Doq,
}

impl UpstreamTransport {
    /// Every accepted spelling (compared after trimming and ASCII
    /// lowercasing) paired with the transport it selects. The empty string
    /// deliberately selects the UDP default.
    const SPELLINGS: &'static [(&'static str, Self)] = &[
        ("do53-udp", Self::Do53Udp),
        ("do53udp", Self::Do53Udp),
        ("udp", Self::Do53Udp),
        ("", Self::Do53Udp),
        ("do53-tcp", Self::Do53Tcp),
        ("do53tcp", Self::Do53Tcp),
        ("tcp", Self::Do53Tcp),
        ("dot", Self::Dot),
        ("tls", Self::Dot),
        ("do-tls", Self::Dot),
        ("doh", Self::Doh),
        ("https", Self::Doh),
        ("do-https", Self::Doh),
        ("doq", Self::Doq),
        ("quic", Self::Doq),
        ("do-quic", Self::Doq),
    ];

    /// Interpret a textual transport selector, case-insensitively.
    ///
    /// Both the canonical contract names (`do53-udp`, `do53-tcp`, `dot`,
    /// `doh`, `doq`) and the usual operator aliases (`udp`, `tcp`, `tls`,
    /// `https`, `quic`, …) are understood.
    ///
    /// Unknown input yields `None`. Callers should then refuse to build the
    /// cell rather than guess, so nothing silently falls back to UDP.
    pub fn parse(s: &str) -> Option<Self> {
        let needle = s.trim().to_ascii_lowercase();
        Self::SPELLINGS
            .iter()
            .find(|(spelling, _)| *spelling == needle.as_str())
            .map(|&(_, transport)| transport)
    }

    /// Pick the upstream transport from the process environment.
    ///
    /// Variables are consulted in priority order, and the first one holding
    /// a non-blank value is the one that gets parsed:
    ///
    /// 1. `CELLOS_DNS_UPSTREAM_PROTOCOL` — the preferred name, aligned with
    ///    `DnsAuthority::resolvers[].protocol`.
    /// 2. `CELLOS_DNS_UPSTREAM_TRANSPORT` — the legacy name, kept so existing
    ///    operator scripts and integration tests behave exactly as before.
    ///
    /// If neither is set (or both are blank), this yields `Some(Do53Udp)`. A
    /// set value that fails to parse yields `None`, so the caller can refuse
    /// to start.
    pub fn from_env() -> Option<Self> {
        const VARS: [&str; 2] = [
            "CELLOS_DNS_UPSTREAM_PROTOCOL",
            "CELLOS_DNS_UPSTREAM_TRANSPORT",
        ];
        let chosen = VARS.iter().find_map(|name| {
            std::env::var(*name)
                .ok()
                .filter(|value| !value.trim().is_empty())
        });
        match chosen {
            Some(raw) => Self::parse(&raw),
            None => Some(Self::Do53Udp),
        }
    }

    /// Canonical, stable name for events and logs; also rendered by
    /// [`UpstreamError::TransportNotEnabled`]'s `Display`.
    pub fn label(self) -> &'static str {
        match self {
            Self::Do53Udp => "do53-udp",
            Self::Do53Tcp => "do53-tcp",
            Self::Dot => "dot",
            Self::Doh => "doh",
            Self::Doq => "doq",
        }
    }
}
187
/// Optional per-transport extras the caller may supply alongside
/// [`UpstreamTransport`]. Populated by [`super::DnsProxyConfig::upstream_extras`]
/// in production; tests construct via `Default::default()` and then set only
/// the fields they care about.
///
/// All fields are optional so the `do53-udp` hot path doesn't require
/// any extras to function — unused fields incur zero cost.
#[derive(Debug, Clone, Default)]
pub struct UpstreamExtras {
    /// SNI hostname presented during the DoT TLS handshake. When `None` (or
    /// empty), rustls is given the resolver's IP literal — most public DoT
    /// resolvers (1.1.1.1 / 8.8.8.8) ship a cert that covers both the
    /// hostname and the IP, but operators with a private resolver will need
    /// to set this to the cert's CN/SAN.
    pub dot_sni: Option<String>,
    /// Operator-supplied DoT server, typically from
    /// `CELLOS_DNS_UPSTREAM_DOT_SERVER`. When set, the DoT roundtrip targets
    /// this address instead of the spec's resolver. `None` falls back to the
    /// spec resolver's address.
    ///
    /// The DoT forward path accepts only IP literals here (`ip`, `ip:port`,
    /// `[v6]`, `[v6]:port`) and fails the query on anything else. A hostname
    /// must therefore be resolved to an IP literal by the composition root,
    /// where DNS bootstrap is allowed, before it lands in this field.
    pub dot_server: Option<String>,
    /// Operator-supplied DoT port. When set, it is applied to a bare
    /// [`Self::dot_server`] IP, or to the spec resolver's address when no
    /// server override is given.
    ///
    /// When `None`, a bare `dot_server` IP gets 853 (RFC 7858), while the
    /// spec-resolver fallback keeps that resolver's own port. In both cases
    /// a resulting port of 0 is replaced with 853.
    pub dot_port: Option<u16>,
    /// DNS-DOH-2 — operator-supplied DoH endpoint URL. When `None` the DoH
    /// forward path defaults to `DEFAULT_DOH_URL`
    /// (`https://1.1.1.1/dns-query`). Sourced from
    /// `CELLOS_DNS_UPSTREAM_DOH_URL` in production.
    ///
    /// This must be a full `https://…` URL including the `/dns-query` path.
    /// It is not validated here; a malformed URL only surfaces when the DoH
    /// request is issued, as an [`UpstreamError::Io`].
    pub doh_url: Option<String>,
    /// DNS-DOQ-2 — operator-supplied DoQ server (IP literal or hostname).
    /// `None` → default `1.1.1.1`. Sourced from
    /// `CELLOS_DNS_UPSTREAM_DOQ_SERVER`.
    ///
    /// Hostnames are passed to `tokio::net::lookup_host`, which uses the OS
    /// resolver rather than the supervisor's bootstrap path. Operators should
    /// prefer IP literals, for the same reason the DoT path insists on them.
    pub doq_server: Option<String>,
    /// DNS-DOQ-2 — operator-supplied DoQ port. `None` → default `853`
    /// (RFC 9250). Sourced from `CELLOS_DNS_UPSTREAM_DOQ_PORT`.
    pub doq_port: Option<u16>,
}

impl UpstreamExtras {
    /// Read the DoT, DoH and DoQ operator overrides from the process
    /// environment.
    ///
    /// Recognised env vars (all optional):
    ///   - `CELLOS_DNS_UPSTREAM_DOT_SERVER` → [`Self::dot_server`]
    ///   - `CELLOS_DNS_UPSTREAM_DOT_PORT`   → [`Self::dot_port`]
    ///   - `CELLOS_DNS_UPSTREAM_DOT_SNI`    → [`Self::dot_sni`]
    ///   - `CELLOS_DNS_UPSTREAM_DOH_URL`    → [`Self::doh_url`]
    ///   - `CELLOS_DNS_UPSTREAM_DOQ_SERVER` → [`Self::doq_server`]
    ///   - `CELLOS_DNS_UPSTREAM_DOQ_PORT`   → [`Self::doq_port`]
    ///
    /// Text values are trimmed, because stray whitespace or a trailing
    /// newline would otherwise break SNI validation, IP parsing or the DoH
    /// URL. A blank value counts as unset.
    ///
    /// Port values that don't parse as `u16` are silently treated as unset,
    /// so a typo lands the operator on the default instead of a refused
    /// cell. Strict gating is left to the composition site.
    ///
    /// A DoQ port of `0` is also treated as unset (→ 853). This mirrors the
    /// DoT path, which replaces a zero port with 853; passing 0 through would
    /// make every DoQ query fail to connect.
    pub fn from_env() -> Self {
        Self {
            dot_sni: Self::env_text("CELLOS_DNS_UPSTREAM_DOT_SNI"),
            dot_server: Self::env_text("CELLOS_DNS_UPSTREAM_DOT_SERVER"),
            // Left unfiltered: the DoT path already maps a zero port to 853.
            dot_port: Self::env_port("CELLOS_DNS_UPSTREAM_DOT_PORT"),
            doh_url: Self::env_text("CELLOS_DNS_UPSTREAM_DOH_URL"),
            doq_server: Self::env_text("CELLOS_DNS_UPSTREAM_DOQ_SERVER"),
            doq_port: Self::env_port("CELLOS_DNS_UPSTREAM_DOQ_PORT").filter(|&port| port != 0),
        }
    }

    /// Trimmed value of `var`, or `None` when it is unset, non-UTF-8 or blank.
    fn env_text(var: &str) -> Option<String> {
        let raw = std::env::var(var).ok()?;
        let value = raw.trim();
        if value.is_empty() {
            None
        } else {
            Some(value.to_owned())
        }
    }

    /// `var` parsed as a port number, or `None` when unset or unparseable.
    fn env_port(var: &str) -> Option<u16> {
        std::env::var(var).ok()?.trim().parse::<u16>().ok()
    }
}
285
286/// Typed upstream-forward error.
287///
288/// Mapped at the [`super::run_one_shot`] call site to a SERVFAIL response
289/// for the workload plus a `dns_query` CloudEvent stamped
290/// `reasonCode: upstream_failure`. The variant discriminator is stable so
291/// follow-up slots can refine the event's reason text without re-shaping
292/// the matrix.
293#[derive(Debug)]
294pub enum UpstreamError {
295    /// Upstream did not answer within the configured budget.
296    Timeout,
297    /// Lower-level I/O failure (refused, unreachable, transport reset).
298    Io(io::Error),
299    /// Operator selected a transport this build does not implement —
300    /// typed SERVFAIL discriminator (D9 in the dispatch ledger).
301    TransportNotEnabled(UpstreamTransport),
302    /// rustls handshake failed (cert chain rejected, SNI mismatch, protocol
303    /// error). Distinct from [`Self::Io`] so triage can tell "the resolver
304    /// answered but with a bad cert" from "the resolver didn't answer".
305    TlsHandshake(String),
306    /// Async transport invoked outside a tokio runtime — programming error
307    /// in production (the supervisor always runs inside one) but explicit
308    /// rather than panicking.
309    NoTokioRuntime,
310}
311
312impl std::fmt::Display for UpstreamError {
313    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
314        match self {
315            Self::Timeout => write!(f, "upstream timeout"),
316            Self::Io(e) => write!(f, "upstream io: {e}"),
317            Self::TransportNotEnabled(t) => {
318                write!(
319                    f,
320                    "upstream transport '{}' not enabled in this build",
321                    t.label()
322                )
323            }
324            Self::TlsHandshake(msg) => write!(f, "tls handshake: {msg}"),
325            Self::NoTokioRuntime => write!(f, "no tokio runtime in scope for async upstream"),
326        }
327    }
328}
329
330impl std::error::Error for UpstreamError {
331    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
332        match self {
333            Self::Io(e) => Some(e),
334            _ => None,
335        }
336    }
337}
338
339impl From<io::Error> for UpstreamError {
340    fn from(e: io::Error) -> Self {
341        if matches!(
342            e.kind(),
343            io::ErrorKind::TimedOut | io::ErrorKind::WouldBlock
344        ) {
345            Self::Timeout
346        } else {
347            Self::Io(e)
348        }
349    }
350}
351
352/// Synchronous upstream forward — the single entry point the proxy hot path
353/// calls. Dispatches on [`UpstreamTransport`].
354///
355/// `udp_socket` is the pre-bound UDP socket used for the `Do53Udp` path; for
356/// every other transport it is ignored (we open a fresh transport-specific
357/// connection per query — no connection pooling in this slot). `out_buf` is
358/// filled with the upstream's response and the byte count returned; on error
359/// the buffer is left untouched.
360pub fn forward(
361    transport: UpstreamTransport,
362    udp_socket: &UdpSocket,
363    upstream: SocketAddr,
364    query: &[u8],
365    out_buf: &mut [u8],
366    timeout: Duration,
367    extras: &UpstreamExtras,
368) -> Result<usize, UpstreamError> {
369    if query.len() > MAX_PAYLOAD {
370        return Err(UpstreamError::Io(io::Error::new(
371            io::ErrorKind::InvalidInput,
372            "query exceeds MAX_PAYLOAD",
373        )));
374    }
375    match transport {
376        UpstreamTransport::Do53Udp => forward_udp(udp_socket, upstream, query, out_buf, timeout),
377        UpstreamTransport::Do53Tcp => forward_tcp(upstream, query, out_buf, timeout),
378        UpstreamTransport::Dot => forward_dot(upstream, query, out_buf, timeout, extras),
379        UpstreamTransport::Doh => forward_doh(query, out_buf, timeout, extras),
380        UpstreamTransport::Doq => forward_doq(query, out_buf, timeout, extras),
381    }
382}
383
384fn forward_udp(
385    upstream: &UdpSocket,
386    addr: SocketAddr,
387    query: &[u8],
388    buf: &mut [u8],
389    timeout: Duration,
390) -> Result<usize, UpstreamError> {
391    upstream.send_to(query, addr)?;
392    upstream.set_read_timeout(Some(timeout))?;
393    let deadline = std::time::Instant::now() + timeout;
394    loop {
395        match upstream.recv_from(buf) {
396            Ok((n, _peer)) => return Ok(n),
397            Err(e)
398                if matches!(
399                    e.kind(),
400                    io::ErrorKind::WouldBlock | io::ErrorKind::TimedOut
401                ) =>
402            {
403                if std::time::Instant::now() >= deadline {
404                    return Err(UpstreamError::Timeout);
405                }
406            }
407            Err(e) if matches!(e.kind(), io::ErrorKind::Interrupted) => continue,
408            Err(e) => return Err(UpstreamError::Io(e)),
409        }
410    }
411}
412
413fn forward_tcp(
414    upstream: SocketAddr,
415    query: &[u8],
416    buf: &mut [u8],
417    timeout: Duration,
418) -> Result<usize, UpstreamError> {
419    use std::io::{Read, Write};
420    use std::net::TcpStream;
421
422    let mut stream = TcpStream::connect_timeout(&upstream, timeout)?;
423    stream.set_read_timeout(Some(timeout))?;
424    stream.set_write_timeout(Some(timeout))?;
425
426    let len = query.len() as u16;
427    stream.write_all(&len.to_be_bytes())?;
428    stream.write_all(query)?;
429
430    let mut len_buf = [0u8; 2];
431    stream.read_exact(&mut len_buf)?;
432    let resp_len = u16::from_be_bytes(len_buf) as usize;
433    if resp_len > buf.len() {
434        return Err(UpstreamError::Io(io::Error::new(
435            io::ErrorKind::InvalidData,
436            "tcp response exceeds out buffer",
437        )));
438    }
439    stream.read_exact(&mut buf[..resp_len])?;
440    Ok(resp_len)
441}
442
443fn forward_dot(
444    upstream: SocketAddr,
445    query: &[u8],
446    buf: &mut [u8],
447    timeout: Duration,
448    extras: &UpstreamExtras,
449) -> Result<usize, UpstreamError> {
450    let handle =
451        tokio::runtime::Handle::try_current().map_err(|_| UpstreamError::NoTokioRuntime)?;
452
453    // Resolve target SocketAddr. Precedence:
454    //   1. `extras.dot_server` (operator override, typically from
455    //      `CELLOS_DNS_UPSTREAM_DOT_SERVER`) — interpreted as either a
456    //      bare IP literal or a `host:port` pair. If the parse fails, we
457    //      surface a typed `Io(InvalidInput)` so the caller fail-safes to
458    //      SERVFAIL rather than silently routing to the wrong resolver.
459    //   2. The `upstream` SocketAddr the caller supplied (spec resolver).
460    //   3. Port defaults to 853 (RFC 7858) when zero.
461    let mut target = match extras.dot_server.as_deref() {
462        Some(host) => parse_dot_target(host, extras.dot_port).map_err(|e| {
463            UpstreamError::Io(io::Error::new(
464                io::ErrorKind::InvalidInput,
465                format!("CELLOS_DNS_UPSTREAM_DOT_SERVER='{host}' did not parse: {e}"),
466            ))
467        })?,
468        None => {
469            let mut t = upstream;
470            if let Some(p) = extras.dot_port {
471                t.set_port(p);
472            }
473            t
474        }
475    };
476    if target.port() == 0 {
477        target.set_port(DEFAULT_DOT_PORT);
478    }
479
480    let sni = extras.dot_sni.clone();
481    let query = query.to_vec();
482    let buf_len = buf.len();
483    // `forward` is a sync entry point so it can sit on the proxy's
484    // blocking I/O thread. The DoT path needs an async runtime for the
485    // TLS handshake. `block_in_place` is tokio's official escape hatch:
486    // it tells the multi-thread scheduler "this worker is about to
487    // block; redistribute the other tasks", which keeps `block_on`
488    // from panicking with `Cannot start a runtime from within a runtime`
489    // when callers happen to dispatch us from inside an async context
490    // (notably integration tests like supervisor_dns_proxy_dot.rs).
491    let result: Result<Vec<u8>, UpstreamError> = tokio::task::block_in_place(|| {
492        handle.block_on(async move {
493            match tokio::time::timeout(timeout, dot_roundtrip(target, &query, &sni, buf_len)).await
494            {
495                Ok(inner) => inner,
496                Err(_) => Err(UpstreamError::Timeout),
497            }
498        })
499    });
500    let resp = result?;
501    if resp.len() > buf.len() {
502        return Err(UpstreamError::Io(io::Error::new(
503            io::ErrorKind::InvalidData,
504            "dot response exceeds out buffer",
505        )));
506    }
507    buf[..resp.len()].copy_from_slice(&resp);
508    Ok(resp.len())
509}
510
511/// Parse the operator-supplied DoT server string into a `SocketAddr`.
512///
513/// Accepted shapes:
514///   - bare IPv4 literal (`1.1.1.1`)
515///   - bare IPv6 literal (`2606:4700:4700::1111`)
516///   - `IPv4:port` (`1.1.1.1:853`)
517///   - `[IPv6]:port` (`[2606:4700:4700::1111]:853`)
518///
519/// Hostnames are explicitly rejected: the in-netns dataplane has no DNS
520/// bootstrap path, and accepting a hostname here would either require a
521/// circular DNS lookup or a silent fallback to the system resolver outside
522/// the cell's authority. The composition root must pre-resolve any hostname
523/// to an IP literal before populating `UpstreamExtras::dot_server`.
524fn parse_dot_target(host: &str, port_override: Option<u16>) -> Result<SocketAddr, String> {
525    let trimmed = host.trim();
526    if trimmed.is_empty() {
527        return Err("empty string".to_string());
528    }
529
530    // If the operator wrote `host:port`, prefer that port over the override.
531    // SocketAddr::from_str handles both `IPv4:port` and `[IPv6]:port`.
532    if let Ok(sa) = trimmed.parse::<SocketAddr>() {
533        return Ok(sa);
534    }
535
536    // Otherwise treat the string as a bare IP literal and bolt on the port.
537    let port = port_override.unwrap_or(DEFAULT_DOT_PORT);
538    let ip: std::net::IpAddr = trimmed.parse().map_err(|_| {
539        format!(
540            "'{trimmed}' is not an IP literal (hostnames must be pre-resolved by the supervisor)"
541        )
542    })?;
543    Ok(SocketAddr::new(ip, port))
544}
545
546/// Async DoT roundtrip — TCP connect, TLS handshake, length-prefixed query,
547/// length-prefixed response.
548async fn dot_roundtrip(
549    target: SocketAddr,
550    query: &[u8],
551    sni: &Option<String>,
552    out_cap: usize,
553) -> Result<Vec<u8>, UpstreamError> {
554    let config = build_dot_client_config();
555    let connector = TlsConnector::from(Arc::new(config));
556
557    let tcp = tokio::net::TcpStream::connect(target).await?;
558
559    // Build the SNI ServerName. Operator-supplied SNI wins; otherwise we
560    // fall back to the resolver's IP literal — rustls 0.23's `pki_types`
561    // accepts `IpAddress` server names so the handshake will succeed
562    // against resolvers whose certs cover the IP (1.1.1.1, 9.9.9.9, etc.).
563    let server_name: ServerName<'static> = match sni {
564        Some(host) if !host.is_empty() => ServerName::try_from(host.clone())
565            .map_err(|e| UpstreamError::TlsHandshake(format!("invalid sni '{host}': {e}")))?,
566        _ => ServerName::IpAddress(target.ip().into()),
567    };
568
569    let mut tls = connector
570        .connect(server_name, tcp)
571        .await
572        .map_err(|e| UpstreamError::TlsHandshake(format!("{e}")))?;
573
574    // RFC 7858 framing — 2-byte big-endian length prefix on both sides.
575    let len = query.len() as u16;
576    tls.write_all(&len.to_be_bytes()).await?;
577    tls.write_all(query).await?;
578    tls.flush().await?;
579
580    let mut len_buf = [0u8; 2];
581    tls.read_exact(&mut len_buf).await?;
582    let resp_len = u16::from_be_bytes(len_buf) as usize;
583    if resp_len > out_cap {
584        return Err(UpstreamError::Io(io::Error::new(
585            io::ErrorKind::InvalidData,
586            "dot response exceeds out buffer",
587        )));
588    }
589    let mut resp = vec![0u8; resp_len];
590    tls.read_exact(&mut resp).await?;
591    Ok(resp)
592}
593
594/// DNS-DOH-1 — DoH (RFC 8484) forward path.
595///
596/// Synchronous entry point following the same `block_in_place` +
597/// `Handle::block_on` pattern as [`forward_dot`]: the proxy hot path is a
598/// blocking thread inside a tokio runtime, so we redistribute scheduler
599/// work and drive an async reqwest POST. Returns the response wire bytes
600/// or a typed [`UpstreamError`].
601///
602/// Endpoint resolution precedence:
603///   1. `extras.doh_url` (operator override, typically from
604///      `CELLOS_DNS_UPSTREAM_DOH_URL`).
605///   2. [`DEFAULT_DOH_URL`] = Cloudflare's `https://1.1.1.1/dns-query`.
606fn forward_doh(
607    query: &[u8],
608    buf: &mut [u8],
609    timeout: Duration,
610    extras: &UpstreamExtras,
611) -> Result<usize, UpstreamError> {
612    let handle =
613        tokio::runtime::Handle::try_current().map_err(|_| UpstreamError::NoTokioRuntime)?;
614
615    let url = extras
616        .doh_url
617        .clone()
618        .unwrap_or_else(|| DEFAULT_DOH_URL.to_string());
619    let query = query.to_vec();
620    let buf_len = buf.len();
621
622    let result: Result<Vec<u8>, UpstreamError> = tokio::task::block_in_place(|| {
623        handle.block_on(async move {
624            match tokio::time::timeout(timeout, doh_roundtrip(&url, &query, timeout, buf_len)).await
625            {
626                Ok(inner) => inner,
627                Err(_) => Err(UpstreamError::Timeout),
628            }
629        })
630    });
631    let resp = result?;
632    if resp.len() > buf.len() {
633        return Err(UpstreamError::Io(io::Error::new(
634            io::ErrorKind::InvalidData,
635            "doh response exceeds out buffer",
636        )));
637    }
638    buf[..resp.len()].copy_from_slice(&resp);
639    Ok(resp.len())
640}
641
642/// Async DoH roundtrip — single POST with the RFC 8484 content-type contract.
643async fn doh_roundtrip(
644    url: &str,
645    query: &[u8],
646    timeout: Duration,
647    out_cap: usize,
648) -> Result<Vec<u8>, UpstreamError> {
649    // reqwest client constructed per call. No pool reuse in this slot — the
650    // upstream forward path is already on the slow path (every cache miss
651    // pays the TLS handshake), and connection reuse would require a static
652    // client + careful cross-cell isolation. We keep it stateless to match
653    // DoT.
654    let client = reqwest::Client::builder()
655        .timeout(timeout)
656        .build()
657        .map_err(|e| UpstreamError::Io(io::Error::other(format!("doh client build: {e}"))))?;
658
659    let resp = client
660        .post(url)
661        .header(reqwest::header::CONTENT_TYPE, "application/dns-message")
662        .header(reqwest::header::ACCEPT, "application/dns-message")
663        .body(query.to_vec())
664        .send()
665        .await
666        .map_err(|e| {
667            // reqwest's timeout-error variant maps to our `Timeout`; everything
668            // else (DNS failure on the URL host, TLS handshake, connection
669            // refused) lands on `Io` so triage can tell "the resolver answered"
670            // from "the resolver didn't answer" via the discriminator.
671            if e.is_timeout() {
672                UpstreamError::Timeout
673            } else {
674                UpstreamError::Io(io::Error::other(format!("doh request: {e}")))
675            }
676        })?;
677
678    if !resp.status().is_success() {
679        return Err(UpstreamError::Io(io::Error::other(format!(
680            "doh upstream returned HTTP {}",
681            resp.status()
682        ))));
683    }
684
685    let bytes = resp
686        .bytes()
687        .await
688        .map_err(|e| UpstreamError::Io(io::Error::other(format!("doh body: {e}"))))?;
689    if bytes.len() > out_cap {
690        return Err(UpstreamError::Io(io::Error::new(
691            io::ErrorKind::InvalidData,
692            "doh response exceeds out buffer",
693        )));
694    }
695    Ok(bytes.to_vec())
696}
697
698/// DNS-DOQ-1 — DoQ (RFC 9250) forward path.
699///
700/// Synchronous entry point. QUIC connection + bidi stream + 2-byte
701/// length-prefixed query/response. Defaults to `1.1.1.1:853` if no operator
702/// override is present.
703fn forward_doq(
704    query: &[u8],
705    buf: &mut [u8],
706    timeout: Duration,
707    extras: &UpstreamExtras,
708) -> Result<usize, UpstreamError> {
709    let handle =
710        tokio::runtime::Handle::try_current().map_err(|_| UpstreamError::NoTokioRuntime)?;
711
712    let server = extras
713        .doq_server
714        .clone()
715        .unwrap_or_else(|| DEFAULT_DOQ_SERVER.to_string());
716    let port = extras.doq_port.unwrap_or(DEFAULT_DOQ_PORT);
717    let query = query.to_vec();
718    let buf_len = buf.len();
719
720    let result: Result<Vec<u8>, UpstreamError> = tokio::task::block_in_place(|| {
721        handle.block_on(async move {
722            match tokio::time::timeout(timeout, doq_roundtrip(&server, port, &query, buf_len)).await
723            {
724                Ok(inner) => inner,
725                Err(_) => Err(UpstreamError::Timeout),
726            }
727        })
728    });
729    let resp = result?;
730    if resp.len() > buf.len() {
731        return Err(UpstreamError::Io(io::Error::new(
732            io::ErrorKind::InvalidData,
733            "doq response exceeds out buffer",
734        )));
735    }
736    buf[..resp.len()].copy_from_slice(&resp);
737    Ok(resp.len())
738}
739
740/// Async DoQ roundtrip — quinn 0.11 client endpoint, bidi stream, RFC 9250
741/// length-prefixed framing.
742async fn doq_roundtrip(
743    server: &str,
744    port: u16,
745    query: &[u8],
746    out_cap: usize,
747) -> Result<Vec<u8>, UpstreamError> {
748    use std::net::IpAddr;
749
750    // Resolve `server` to an `IpAddr` + keep the original string for SNI. If
751    // the operator supplied an IP literal, we use that for both the QUIC
752    // target and as a fallback SNI; if it's a hostname, we rely on
753    // `tokio::net::lookup_host` for the address and pass the hostname through
754    // as the rustls SNI (matching standard browser-style DoQ resolvers).
755    let (target_addr, sni): (SocketAddr, ServerName<'static>) =
756        if let Ok(ip) = server.parse::<IpAddr>() {
757            let sa = SocketAddr::new(ip, port);
758            let sni = ServerName::IpAddress(match ip {
759                IpAddr::V4(v4) => rustls::pki_types::IpAddr::V4(v4.into()),
760                IpAddr::V6(v6) => rustls::pki_types::IpAddr::V6(v6.into()),
761            });
762            (sa, sni)
763        } else {
764            // Hostname path — first address wins. This uses the OS resolver
765            // (NOT the supervisor's bootstrap path), so operators in a
766            // sealed netns should prefer IP literals.
767            let mut iter = tokio::net::lookup_host((server, port)).await.map_err(|e| {
768                UpstreamError::Io(io::Error::new(
769                    e.kind(),
770                    format!("doq lookup '{server}': {e}"),
771                ))
772            })?;
773            let sa = iter.next().ok_or_else(|| {
774                UpstreamError::Io(io::Error::new(
775                    io::ErrorKind::AddrNotAvailable,
776                    format!("doq lookup '{server}' returned no addresses"),
777                ))
778            })?;
779            let sni = ServerName::try_from(server.to_string())
780                .map_err(|e| UpstreamError::TlsHandshake(format!("invalid sni '{server}': {e}")))?;
781            (sa, sni)
782        };
783
784    // Bind a local UDP socket for the QUIC client. `0.0.0.0:0` lets the OS
785    // pick an ephemeral port; the QUIC endpoint owns this socket for the
786    // life of the connection.
787    let bind_addr: SocketAddr = match target_addr {
788        SocketAddr::V4(_) => "0.0.0.0:0".parse().unwrap(),
789        SocketAddr::V6(_) => "[::]:0".parse().unwrap(),
790    };
791    let mut endpoint = quinn::Endpoint::client(bind_addr)
792        .map_err(|e| UpstreamError::Io(io::Error::new(e.kind(), format!("doq endpoint: {e}"))))?;
793
794    // rustls config with ALPN = "doq" (RFC 9250 §4.1.1).
795    let mut roots = RootCertStore::empty();
796    roots.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
797    let provider = Arc::new(rustls::crypto::ring::default_provider());
798    let mut crypto = ClientConfig::builder_with_provider(provider)
799        .with_safe_default_protocol_versions()
800        .map_err(|e| UpstreamError::TlsHandshake(format!("doq rustls protocols: {e}")))?
801        .with_root_certificates(roots)
802        .with_no_client_auth();
803    crypto.alpn_protocols = vec![b"doq".to_vec()];
804
805    let quic_crypto = quinn::crypto::rustls::QuicClientConfig::try_from(crypto)
806        .map_err(|e| UpstreamError::TlsHandshake(format!("doq quic crypto: {e}")))?;
807    let client_config = quinn::ClientConfig::new(Arc::new(quic_crypto));
808    endpoint.set_default_client_config(client_config);
809
810    // SNI server-name string — rustls wants `&str`, but quinn's `connect`
811    // takes `&str` separately from the rustls config. We extract it back
812    // from `sni` for the API call.
813    let sni_str: String = match &sni {
814        ServerName::DnsName(d) => d.as_ref().to_string(),
815        ServerName::IpAddress(_) => server.to_string(),
816        _ => server.to_string(),
817    };
818
819    let connecting = endpoint
820        .connect(target_addr, &sni_str)
821        .map_err(|e| UpstreamError::TlsHandshake(format!("doq connect: {e}")))?;
822    let connection = connecting
823        .await
824        .map_err(|e| UpstreamError::TlsHandshake(format!("doq handshake: {e}")))?;
825
826    let (mut send, mut recv) = connection
827        .open_bi()
828        .await
829        .map_err(|e| UpstreamError::Io(io::Error::other(format!("doq open_bi: {e}"))))?;
830
831    // RFC 9250 §4.2.1 — 2-byte big-endian length prefix on both query and
832    // response. The query carries the same wire format as DNS-over-TCP.
833    let len = query.len() as u16;
834    send.write_all(&len.to_be_bytes())
835        .await
836        .map_err(|e| UpstreamError::Io(io::Error::other(format!("doq send len: {e}"))))?;
837    send.write_all(query)
838        .await
839        .map_err(|e| UpstreamError::Io(io::Error::other(format!("doq send body: {e}"))))?;
840    send.finish()
841        .map_err(|e| UpstreamError::Io(io::Error::other(format!("doq finish: {e}"))))?;
842
843    let mut len_buf = [0u8; 2];
844    recv.read_exact(&mut len_buf)
845        .await
846        .map_err(|e| UpstreamError::Io(io::Error::other(format!("doq recv len: {e}"))))?;
847    let resp_len = u16::from_be_bytes(len_buf) as usize;
848    if resp_len > out_cap {
849        // Close the connection gracefully before bailing — keeps the QUIC
850        // peer's state machine clean.
851        connection.close(0u32.into(), b"oversized");
852        endpoint.wait_idle().await;
853        return Err(UpstreamError::Io(io::Error::new(
854            io::ErrorKind::InvalidData,
855            "doq response exceeds out buffer",
856        )));
857    }
858    let mut resp = vec![0u8; resp_len];
859    recv.read_exact(&mut resp)
860        .await
861        .map_err(|e| UpstreamError::Io(io::Error::other(format!("doq recv body: {e}"))))?;
862
863    connection.close(0u32.into(), b"done");
864    endpoint.wait_idle().await;
865    Ok(resp)
866}
867
868/// Build a default rustls 0.23 client config with Mozilla CA roots from
869/// `webpki-roots = "1"`.
870///
871/// Explicitly threads the `ring` crypto provider so we don't depend on a
872/// process-wide default (rustls 0.23's `ClientConfig::builder()` panics if
873/// neither provider has been installed via `install_default()` and the
874/// process picks neither feature unambiguously). `cellos-supervisor` pins
875/// the `ring` feature in its Cargo.toml so this is the intended provider.
876fn build_dot_client_config() -> ClientConfig {
877    let mut roots = RootCertStore::empty();
878    roots.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
879    let provider = Arc::new(rustls::crypto::ring::default_provider());
880    ClientConfig::builder_with_provider(provider)
881        .with_safe_default_protocol_versions()
882        .expect("ring provider supports default rustls protocol versions")
883        .with_root_certificates(roots)
884        .with_no_client_auth()
885}
886
#[cfg(test)]
mod tests {
    use super::*;
    use std::ffi::OsString;
    use std::sync::{Mutex, MutexGuard};

    /// Serialises every test in this module that touches
    /// `CELLOS_DNS_UPSTREAM_*` env vars. cargo runs tests in parallel by
    /// default and the env namespace is process-global. Acquire it through
    /// [`EnvGuard::lock`] rather than directly.
    static ENV_LOCK: Mutex<()> = Mutex::new(());

    /// Holds [`ENV_LOCK`] together with a snapshot of a fixed set of env vars.
    ///
    /// `Drop::drop` restores every snapshotted variable: it re-sets the prior
    /// value, or removes the variable if it was unset. `Drop::drop` runs
    /// before the struct's fields are dropped, so the restore happens while
    /// the lock is still held. It also runs when a failed assertion unwinds
    /// the test, so one failing test cannot leak overrides into others.
    struct EnvGuard {
        saved: Vec<(&'static str, Option<OsString>)>,
        _lock: MutexGuard<'static, ()>,
    }

    impl EnvGuard {
        /// Take the env lock and snapshot `keys`. The lock is recovered if a
        /// panicked test left it poisoned.
        fn lock(keys: &[&'static str]) -> Self {
            let lock = ENV_LOCK.lock().unwrap_or_else(|p| p.into_inner());
            let saved = keys
                .iter()
                .map(|&key| (key, std::env::var_os(key)))
                .collect();
            Self { saved, _lock: lock }
        }

        /// Set `key` for the rest of the test. It must be a snapshotted key
        /// so that it gets restored.
        fn set(&self, key: &'static str, value: &str) {
            self.assert_tracked(key);
            // SAFETY: `ENV_LOCK` is held for the guard's lifetime, so no other
            // test in this module reads or writes these vars concurrently.
            // NOTE(review): tests outside this module do not share the lock.
            unsafe { std::env::set_var(key, value) }
        }

        /// Remove `key` for the rest of the test. It must be a snapshotted
        /// key so that it gets restored.
        fn remove(&self, key: &'static str) {
            self.assert_tracked(key);
            // SAFETY: same reasoning as in `set`.
            unsafe { std::env::remove_var(key) }
        }

        fn assert_tracked(&self, key: &str) {
            assert!(
                self.saved.iter().any(|(k, _)| *k == key),
                "{key} was not snapshotted by EnvGuard::lock and would leak"
            );
        }
    }

    impl Drop for EnvGuard {
        fn drop(&mut self) {
            for (key, prior) in &self.saved {
                // SAFETY: `_lock` is still held here, because fields are
                // dropped only after `Drop::drop` returns.
                unsafe {
                    match prior {
                        Some(value) => std::env::set_var(key, value),
                        None => std::env::remove_var(key),
                    }
                }
            }
        }
    }

    #[test]
    fn parse_canonical_names() {
        assert_eq!(
            UpstreamTransport::parse("do53-udp"),
            Some(UpstreamTransport::Do53Udp)
        );
        assert_eq!(
            UpstreamTransport::parse("do53-tcp"),
            Some(UpstreamTransport::Do53Tcp)
        );
        assert_eq!(
            UpstreamTransport::parse("dot"),
            Some(UpstreamTransport::Dot)
        );
        assert_eq!(
            UpstreamTransport::parse("doh"),
            Some(UpstreamTransport::Doh)
        );
        assert_eq!(
            UpstreamTransport::parse("doq"),
            Some(UpstreamTransport::Doq)
        );
    }

    #[test]
    fn parse_aliases_case_insensitive() {
        assert_eq!(
            UpstreamTransport::parse("UDP"),
            Some(UpstreamTransport::Do53Udp)
        );
        assert_eq!(
            UpstreamTransport::parse("TCP"),
            Some(UpstreamTransport::Do53Tcp)
        );
        assert_eq!(
            UpstreamTransport::parse("Tls"),
            Some(UpstreamTransport::Dot)
        );
        assert_eq!(
            UpstreamTransport::parse("HTTPS"),
            Some(UpstreamTransport::Doh)
        );
        assert_eq!(
            UpstreamTransport::parse("quic"),
            Some(UpstreamTransport::Doq)
        );
    }

    #[test]
    fn parse_rejects_unknown() {
        assert_eq!(UpstreamTransport::parse("dnscrypt"), None);
        assert_eq!(UpstreamTransport::parse("xxx"), None);
    }

    #[test]
    fn default_is_udp() {
        assert_eq!(UpstreamTransport::default(), UpstreamTransport::Do53Udp);
    }

    #[test]
    fn label_round_trips() {
        for t in [
            UpstreamTransport::Do53Udp,
            UpstreamTransport::Do53Tcp,
            UpstreamTransport::Dot,
            UpstreamTransport::Doh,
            UpstreamTransport::Doq,
        ] {
            assert_eq!(UpstreamTransport::parse(t.label()), Some(t));
        }
    }

    #[test]
    fn extras_from_env_reads_doh_url() {
        // DNS-DOH-2 — `CELLOS_DNS_UPSTREAM_DOH_URL` lands on extras.doh_url.
        let env = EnvGuard::lock(&["CELLOS_DNS_UPSTREAM_DOH_URL"]);
        env.set(
            "CELLOS_DNS_UPSTREAM_DOH_URL",
            "https://cloudflare-dns.com/dns-query",
        );
        let extras = UpstreamExtras::from_env();
        assert_eq!(
            extras.doh_url.as_deref(),
            Some("https://cloudflare-dns.com/dns-query")
        );
    }

    #[test]
    fn extras_from_env_reads_doq_server_and_port() {
        // DNS-DOQ-2 — server + port env wiring lands on extras.
        let env = EnvGuard::lock(&[
            "CELLOS_DNS_UPSTREAM_DOQ_SERVER",
            "CELLOS_DNS_UPSTREAM_DOQ_PORT",
        ]);
        env.set("CELLOS_DNS_UPSTREAM_DOQ_SERVER", "9.9.9.9");
        env.set("CELLOS_DNS_UPSTREAM_DOQ_PORT", "8853");
        let extras = UpstreamExtras::from_env();
        assert_eq!(extras.doq_server.as_deref(), Some("9.9.9.9"));
        assert_eq!(extras.doq_port, Some(8853));
    }

    #[test]
    fn parse_dot_target_accepts_bare_ipv4() {
        let sa = parse_dot_target("1.1.1.1", None).expect("bare ipv4 parses");
        assert_eq!(sa, "1.1.1.1:853".parse::<SocketAddr>().unwrap());
    }

    #[test]
    fn parse_dot_target_accepts_bare_ipv4_with_port_override() {
        let sa = parse_dot_target("9.9.9.9", Some(8853)).expect("ipv4 + override parses");
        assert_eq!(sa, "9.9.9.9:8853".parse::<SocketAddr>().unwrap());
    }

    #[test]
    fn parse_dot_target_accepts_ipv4_with_inline_port() {
        // Inline port wins over override (operator wrote it explicitly).
        let sa = parse_dot_target("1.1.1.1:9999", Some(853)).expect("inline port parses");
        assert_eq!(sa, "1.1.1.1:9999".parse::<SocketAddr>().unwrap());
    }

    #[test]
    fn parse_dot_target_accepts_bracketed_ipv6() {
        let sa =
            parse_dot_target("[2606:4700:4700::1111]:853", None).expect("bracketed ipv6 parses");
        assert_eq!(
            sa,
            "[2606:4700:4700::1111]:853".parse::<SocketAddr>().unwrap()
        );
    }

    #[test]
    fn parse_dot_target_rejects_hostname() {
        let err = parse_dot_target("dns.example.com", None)
            .expect_err("hostname must be rejected (no DNS bootstrap in netns)");
        assert!(err.contains("hostnames must be pre-resolved"));
    }

    #[test]
    fn parse_dot_target_rejects_empty() {
        assert!(parse_dot_target("", None).is_err());
        assert!(parse_dot_target("   ", None).is_err());
    }

    #[test]
    fn extras_from_env_reads_dot_server_port_sni() {
        let env = EnvGuard::lock(&[
            "CELLOS_DNS_UPSTREAM_DOT_SERVER",
            "CELLOS_DNS_UPSTREAM_DOT_PORT",
            "CELLOS_DNS_UPSTREAM_DOT_SNI",
        ]);
        env.set("CELLOS_DNS_UPSTREAM_DOT_SERVER", "8.8.8.8");
        env.set("CELLOS_DNS_UPSTREAM_DOT_PORT", "8853");
        env.set("CELLOS_DNS_UPSTREAM_DOT_SNI", "dns.google");

        let extras = UpstreamExtras::from_env();
        assert_eq!(extras.dot_server.as_deref(), Some("8.8.8.8"));
        assert_eq!(extras.dot_port, Some(8853));
        assert_eq!(extras.dot_sni.as_deref(), Some("dns.google"));
    }

    #[test]
    fn extras_from_env_ignores_unparseable_port() {
        // The sibling DoT vars are snapshotted too. This test clears them so
        // leftover values cannot affect the observation, and the guard puts
        // them back afterwards.
        let env = EnvGuard::lock(&[
            "CELLOS_DNS_UPSTREAM_DOT_PORT",
            "CELLOS_DNS_UPSTREAM_DOT_SERVER",
            "CELLOS_DNS_UPSTREAM_DOT_SNI",
        ]);
        env.remove("CELLOS_DNS_UPSTREAM_DOT_SERVER");
        env.remove("CELLOS_DNS_UPSTREAM_DOT_SNI");
        env.set("CELLOS_DNS_UPSTREAM_DOT_PORT", "not-a-number");

        let extras = UpstreamExtras::from_env();
        assert_eq!(extras.dot_port, None);
    }

    #[test]
    fn from_env_prefers_protocol_over_transport() {
        let env = EnvGuard::lock(&[
            "CELLOS_DNS_UPSTREAM_PROTOCOL",
            "CELLOS_DNS_UPSTREAM_TRANSPORT",
        ]);
        env.set("CELLOS_DNS_UPSTREAM_PROTOCOL", "dot");
        env.set("CELLOS_DNS_UPSTREAM_TRANSPORT", "do53-udp");
        // PROTOCOL takes priority — operator's contract-aligned name wins.
        assert_eq!(UpstreamTransport::from_env(), Some(UpstreamTransport::Dot));

        env.remove("CELLOS_DNS_UPSTREAM_PROTOCOL");
        // PROTOCOL unset, TRANSPORT still honoured (back-compat).
        assert_eq!(
            UpstreamTransport::from_env(),
            Some(UpstreamTransport::Do53Udp)
        );
    }

    #[test]
    fn udp_path_round_trips_against_synthetic_upstream() {
        // Spawn a synthetic UDP echo-style upstream that replies with a
        // 13-byte canned packet for any incoming query. Confirms the
        // Do53Udp dispatch goes through `forward_udp` and the result is
        // surfaced byte-for-byte.
        let echo = UdpSocket::bind("127.0.0.1:0").unwrap();
        echo.set_read_timeout(Some(Duration::from_millis(500)))
            .unwrap();
        let echo_addr = echo.local_addr().unwrap();
        std::thread::spawn(move || {
            let mut b = [0u8; 1500];
            if let Ok((_n, peer)) = echo.recv_from(&mut b) {
                let _ = echo.send_to(b"\x00\x00ABCDEFGHIJK", peer);
            }
        });
        let upstream_sock = UdpSocket::bind("127.0.0.1:0").unwrap();
        let mut out = [0u8; 1500];
        let n = forward(
            UpstreamTransport::Do53Udp,
            &upstream_sock,
            echo_addr,
            b"\x00\x00\x01\x00\x00\x01\x00\x00\x00\x00\x00\x00",
            &mut out,
            Duration::from_millis(500),
            &UpstreamExtras::default(),
        )
        .expect("udp round-trip");
        assert_eq!(n, 13);
        assert_eq!(&out[..13], b"\x00\x00ABCDEFGHIJK");
    }
}