use std::io;
use std::net::{SocketAddr, UdpSocket};
use std::sync::Arc;
use std::time::Duration;
use rustls::pki_types::ServerName;
use rustls::{ClientConfig, RootCertStore};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio_rustls::TlsConnector;
/// Largest DNS query, in bytes, that [`forward`] will accept.
const MAX_PAYLOAD: usize = 1_500;
/// DNS-over-TLS port used when no other port is configured.
const DEFAULT_DOT_PORT: u16 = 853;
/// DNS-over-HTTPS endpoint used when `CELLOS_DNS_UPSTREAM_DOH_URL` is unset.
const DEFAULT_DOH_URL: &str = "https://1.1.1.1/dns-query";
/// DNS-over-QUIC server used when `CELLOS_DNS_UPSTREAM_DOQ_SERVER` is unset.
const DEFAULT_DOQ_SERVER: &str = "1.1.1.1";
/// DNS-over-QUIC port used when `CELLOS_DNS_UPSTREAM_DOQ_PORT` is unset.
const DEFAULT_DOQ_PORT: u16 = 853;
/// Wire protocol used to talk to the upstream resolver.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum UpstreamTransport {
    /// Plain DNS over UDP (the default).
    #[default]
    Do53Udp,
    /// Plain DNS over TCP.
    Do53Tcp,
    /// DNS over TLS.
    Dot,
    /// DNS over HTTPS.
    Doh,
    /// DNS over QUIC.
    Doq,
}

impl UpstreamTransport {
    /// Parses a transport name, ignoring case and surrounding whitespace.
    ///
    /// Besides the canonical labels, short aliases are accepted (`udp`, `tcp`, `tls`,
    /// `https`, `quic`, …). An empty string selects [`UpstreamTransport::Do53Udp`].
    /// Returns `None` for anything unrecognised.
    pub fn parse(s: &str) -> Option<Self> {
        let normalized = s.trim().to_ascii_lowercase();
        let transport = match normalized.as_str() {
            "" | "udp" | "do53udp" | "do53-udp" => Self::Do53Udp,
            "tcp" | "do53tcp" | "do53-tcp" => Self::Do53Tcp,
            "tls" | "dot" | "do-tls" => Self::Dot,
            "https" | "doh" | "do-https" => Self::Doh,
            "quic" | "doq" | "do-quic" => Self::Doq,
            _ => return None,
        };
        Some(transport)
    }

    /// Reads the transport from `CELLOS_DNS_UPSTREAM_PROTOCOL`, then
    /// `CELLOS_DNS_UPSTREAM_TRANSPORT`; the first variable holding a non-blank value wins.
    ///
    /// Returns `Some(Do53Udp)` when neither variable is set (or both are blank), and
    /// `None` when the winning value is not recognised. An unrecognised PROTOCOL value
    /// does not fall through to TRANSPORT.
    pub fn from_env() -> Option<Self> {
        const SOURCES: [&str; 2] = [
            "CELLOS_DNS_UPSTREAM_PROTOCOL",
            "CELLOS_DNS_UPSTREAM_TRANSPORT",
        ];
        SOURCES
            .iter()
            .copied()
            .filter_map(|name| std::env::var(name).ok())
            .find(|value| !value.trim().is_empty())
            .map_or(Some(Self::Do53Udp), |value| Self::parse(&value))
    }

    /// Canonical lowercase name of the transport; [`UpstreamTransport::parse`] accepts it back.
    pub fn label(self) -> &'static str {
        match self {
            Self::Do53Udp => "do53-udp",
            Self::Do53Tcp => "do53-tcp",
            Self::Dot => "dot",
            Self::Doh => "doh",
            Self::Doq => "doq",
        }
    }
}
/// Transport-specific upstream settings, read from `CELLOS_DNS_UPSTREAM_*` variables.
///
/// Every field is optional; `None` leaves the corresponding transport on its fallback.
#[derive(Debug, Clone, Default)]
pub struct UpstreamExtras {
    /// TLS server name for DoT (`CELLOS_DNS_UPSTREAM_DOT_SNI`).
    pub dot_sni: Option<String>,
    /// DoT server as an IP literal, optionally with a port (`CELLOS_DNS_UPSTREAM_DOT_SERVER`).
    pub dot_server: Option<String>,
    /// DoT port override (`CELLOS_DNS_UPSTREAM_DOT_PORT`).
    pub dot_port: Option<u16>,
    /// Full DoH endpoint URL (`CELLOS_DNS_UPSTREAM_DOH_URL`).
    pub doh_url: Option<String>,
    /// DoQ server as an IP literal or hostname (`CELLOS_DNS_UPSTREAM_DOQ_SERVER`).
    pub doq_server: Option<String>,
    /// DoQ port override (`CELLOS_DNS_UPSTREAM_DOQ_PORT`).
    pub doq_port: Option<u16>,
}

impl UpstreamExtras {
    /// Reads every extra from the process environment.
    ///
    /// String values are trimmed before they are stored. Unset, non-UTF-8, or blank
    /// variables become `None`. Port variables that do not parse as `u16` after trimming
    /// are ignored (`None`) rather than reported as errors.
    pub fn from_env() -> Self {
        Self {
            dot_sni: Self::env_string("CELLOS_DNS_UPSTREAM_DOT_SNI"),
            dot_server: Self::env_string("CELLOS_DNS_UPSTREAM_DOT_SERVER"),
            dot_port: Self::env_port("CELLOS_DNS_UPSTREAM_DOT_PORT"),
            doh_url: Self::env_string("CELLOS_DNS_UPSTREAM_DOH_URL"),
            doq_server: Self::env_string("CELLOS_DNS_UPSTREAM_DOQ_SERVER"),
            doq_port: Self::env_port("CELLOS_DNS_UPSTREAM_DOQ_PORT"),
        }
    }

    /// Trimmed value of `name`, or `None` when it is unset, non-UTF-8, or blank.
    ///
    /// Trimming matters because these values go straight into URL, IP, and SNI parsers,
    /// which reject surrounding whitespace.
    fn env_string(name: &str) -> Option<String> {
        let value = std::env::var(name).ok()?;
        let trimmed = value.trim();
        (!trimmed.is_empty()).then(|| trimmed.to_owned())
    }

    /// `name` parsed as a port number, or `None` when it is unset or not a valid `u16`.
    fn env_port(name: &str) -> Option<u16> {
        std::env::var(name).ok()?.trim().parse().ok()
    }
}
/// Ways an upstream DNS exchange can fail.
#[derive(Debug)]
pub enum UpstreamError {
    /// No answer arrived in time. `TimedOut`/`WouldBlock` I/O errors are converted to this.
    Timeout,
    /// Any other I/O failure, including size and framing violations.
    Io(io::Error),
    /// The requested transport is not enabled in this build.
    TransportNotEnabled(UpstreamTransport),
    /// TLS/QUIC handshake or server-name failure, with a human-readable reason.
    TlsHandshake(String),
    /// An async transport was selected, but no Tokio runtime handle was in scope.
    NoTokioRuntime,
}

impl std::fmt::Display for UpstreamError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Timeout => f.write_str("upstream timeout"),
            Self::Io(err) => write!(f, "upstream io: {err}"),
            Self::TransportNotEnabled(transport) => write!(
                f,
                "upstream transport '{}' not enabled in this build",
                transport.label()
            ),
            Self::TlsHandshake(reason) => write!(f, "tls handshake: {reason}"),
            Self::NoTokioRuntime => f.write_str("no tokio runtime in scope for async upstream"),
        }
    }
}

impl std::error::Error for UpstreamError {
    /// Only the [`UpstreamError::Io`] variant wraps an underlying error.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            Self::Io(inner) => Some(inner),
            Self::Timeout
            | Self::TransportNotEnabled(_)
            | Self::TlsHandshake(_)
            | Self::NoTokioRuntime => None,
        }
    }
}

impl From<io::Error> for UpstreamError {
    /// Turns socket-timeout kinds (`TimedOut`, and `WouldBlock`, which some platforms report
    /// for an expired read timeout) into [`UpstreamError::Timeout`]. Everything else becomes
    /// [`UpstreamError::Io`].
    fn from(err: io::Error) -> Self {
        match err.kind() {
            io::ErrorKind::TimedOut | io::ErrorKind::WouldBlock => Self::Timeout,
            _ => Self::Io(err),
        }
    }
}
/// Sends one DNS `query` to the upstream resolver over `transport` and writes the answer
/// into `out_buf`. Returns the number of bytes of `out_buf` that were filled.
///
/// * `udp_socket` is used only by [`UpstreamTransport::Do53Udp`].
/// * `upstream` is the resolver address for Do53 over UDP and TCP. DoT falls back to it when
///   `extras.dot_server` is unset. DoH and DoQ ignore it and take their endpoint from `extras`.
/// * `timeout` bounds the exchange; each transport applies it in its own way.
///
/// # Errors
/// A query longer than [`MAX_PAYLOAD`] is rejected with an `InvalidInput` I/O error before
/// any network activity. Otherwise the selected transport's error is returned unchanged.
pub fn forward(
    transport: UpstreamTransport,
    udp_socket: &UdpSocket,
    upstream: SocketAddr,
    query: &[u8],
    out_buf: &mut [u8],
    timeout: Duration,
    extras: &UpstreamExtras,
) -> Result<usize, UpstreamError> {
    let too_large = query.len() > MAX_PAYLOAD;
    if too_large {
        let err = io::Error::new(io::ErrorKind::InvalidInput, "query exceeds MAX_PAYLOAD");
        return Err(UpstreamError::Io(err));
    }

    match transport {
        // Plain DNS reuses the caller's socket (UDP) or dials `upstream` (TCP).
        UpstreamTransport::Do53Udp => forward_udp(udp_socket, upstream, query, out_buf, timeout),
        UpstreamTransport::Do53Tcp => forward_tcp(upstream, query, out_buf, timeout),
        // Encrypted transports run async code on the ambient Tokio runtime.
        UpstreamTransport::Dot => forward_dot(upstream, query, out_buf, timeout, extras),
        UpstreamTransport::Doh => forward_doh(query, out_buf, timeout, extras),
        UpstreamTransport::Doq => forward_doq(query, out_buf, timeout, extras),
    }
}
/// Sends `query` to `addr` over the shared UDP socket `upstream` and waits for the answer,
/// which is written into `buf`. Returns the answer length.
///
/// The whole wait is bounded by `timeout`. Before every receive, the socket's read timeout
/// is re-armed with only the time left until the deadline, so retries (interrupted reads,
/// discarded datagrams) never extend the overall wait.
///
/// A datagram is accepted only if it comes from `addr` (compared after IPv4-mapped IPv6
/// normalisation) and, when the query has at least two bytes, echoes the query's 16-bit
/// DNS Message ID. Anything else is dropped and the wait continues. That covers off-path
/// spoofing attempts and late answers to an earlier, timed-out query on the same socket.
///
/// Side effect: leaves a read timeout set on `upstream`.
///
/// NOTE(review): if several threads share `upstream` concurrently, they can still consume
/// each other's datagrams. Confirm the caller serialises use of the socket.
///
/// # Errors
/// [`UpstreamError::Timeout`] when no acceptable answer arrives before the deadline.
/// [`UpstreamError::Io`] for any other socket error.
fn forward_udp(
    upstream: &UdpSocket,
    addr: SocketAddr,
    query: &[u8],
    buf: &mut [u8],
    timeout: Duration,
) -> Result<usize, UpstreamError> {
    use std::time::Instant;

    upstream.send_to(query, addr)?;
    let deadline = Instant::now() + timeout;
    // The first two header bytes are the DNS Message ID; the resolver echoes them back.
    let expected_id = query.get(..2);
    loop {
        let remaining = deadline.saturating_duration_since(Instant::now());
        if remaining.is_zero() {
            return Err(UpstreamError::Timeout);
        }
        upstream.set_read_timeout(Some(remaining))?;
        match upstream.recv_from(buf) {
            Ok((n, peer)) => {
                let from_upstream = peer.port() == addr.port()
                    && peer.ip().to_canonical() == addr.ip().to_canonical();
                let id_matches = match expected_id {
                    Some(id) => n >= 2 && buf[..2] == *id,
                    None => true,
                };
                if from_upstream && id_matches {
                    return Ok(n);
                }
                // Unsolicited or stale datagram: ignore it and keep waiting.
            }
            Err(e)
                if matches!(
                    e.kind(),
                    io::ErrorKind::WouldBlock
                        | io::ErrorKind::TimedOut
                        | io::ErrorKind::Interrupted
                ) =>
            {
                // The deadline check at the top of the loop decides whether to retry.
                continue;
            }
            Err(e) => return Err(UpstreamError::Io(e)),
        }
    }
}
/// Forwards `query` over DNS-over-TCP on a fresh connection and copies the answer into `buf`.
///
/// Messages are framed with a 2-byte big-endian length prefix. `timeout` is applied
/// separately to the connect and to each read and write; it is not an overall deadline.
///
/// # Errors
/// * `InvalidInput` if `query` is too long for the 16-bit length prefix.
/// * `InvalidData` if the announced answer is larger than `buf`.
/// * Socket timeouts become [`UpstreamError::Timeout`] through `From<io::Error>`; other
///   failures become [`UpstreamError::Io`].
fn forward_tcp(
    upstream: SocketAddr,
    query: &[u8],
    buf: &mut [u8],
    timeout: Duration,
) -> Result<usize, UpstreamError> {
    use std::io::{Read, Write};
    use std::net::TcpStream;

    // Guard the `as u16` cast below instead of silently truncating the length prefix.
    if query.len() > usize::from(u16::MAX) {
        return Err(UpstreamError::Io(io::Error::new(
            io::ErrorKind::InvalidInput,
            "tcp query exceeds 65535-byte frame",
        )));
    }
    // RFC 7766 §8: hand the length prefix and the message to TCP in a single write. Two
    // small writes can trip Nagle plus delayed-ACK and stall the query.
    let mut frame = Vec::with_capacity(2 + query.len());
    frame.extend_from_slice(&(query.len() as u16).to_be_bytes());
    frame.extend_from_slice(query);

    let mut stream = TcpStream::connect_timeout(&upstream, timeout)?;
    stream.set_read_timeout(Some(timeout))?;
    stream.set_write_timeout(Some(timeout))?;
    stream.write_all(&frame)?;

    let mut len_buf = [0u8; 2];
    stream.read_exact(&mut len_buf)?;
    let resp_len = usize::from(u16::from_be_bytes(len_buf));
    if resp_len > buf.len() {
        return Err(UpstreamError::Io(io::Error::new(
            io::ErrorKind::InvalidData,
            "tcp response exceeds out buffer",
        )));
    }
    stream.read_exact(&mut buf[..resp_len])?;
    Ok(resp_len)
}
/// Forwards `query` over DNS-over-TLS and copies the answer into `buf`. Returns its length.
///
/// Target selection:
/// * If `extras.dot_server` is set, it is parsed by [`parse_dot_target`]: an IP literal with
///   an optional inline port, where `extras.dot_port` applies when no inline port is given.
/// * Otherwise `upstream` is reused. It is the address shared with the Do53 transports, so
///   its port is normally 53 (plain DNS), where a TLS handshake cannot succeed. The port is
///   therefore replaced by `extras.dot_port` if set, or by [`DEFAULT_DOT_PORT`] when it is 53.
/// * A resulting port of 0 is also replaced by [`DEFAULT_DOT_PORT`].
///
/// The async exchange runs on the ambient Tokio runtime via `block_in_place` plus
/// `block_on` and is bounded by `timeout`.
///
/// NOTE(review): `tokio::task::block_in_place` panics on a current-thread runtime. Confirm
/// callers only reach this from a multi-thread runtime.
///
/// # Errors
/// * [`UpstreamError::NoTokioRuntime`] when no runtime handle is in scope.
/// * `InvalidInput` when `dot_server` does not parse.
/// * [`UpstreamError::Timeout`] when `timeout` elapses.
/// * `InvalidData` when the answer does not fit in `buf`.
/// * Any error from the TLS exchange itself.
fn forward_dot(
    upstream: SocketAddr,
    query: &[u8],
    buf: &mut [u8],
    timeout: Duration,
    extras: &UpstreamExtras,
) -> Result<usize, UpstreamError> {
    // Well-known port of plain (unencrypted) DNS.
    const DO53_PORT: u16 = 53;

    let handle =
        tokio::runtime::Handle::try_current().map_err(|_| UpstreamError::NoTokioRuntime)?;
    let mut target = match extras.dot_server.as_deref() {
        Some(host) => parse_dot_target(host, extras.dot_port).map_err(|e| {
            UpstreamError::Io(io::Error::new(
                io::ErrorKind::InvalidInput,
                format!("CELLOS_DNS_UPSTREAM_DOT_SERVER='{host}' did not parse: {e}"),
            ))
        })?,
        None => {
            // `set_port` (rather than rebuilding the address) keeps any IPv6 scope id.
            let mut t = upstream;
            match extras.dot_port {
                Some(p) => t.set_port(p),
                // TLS against the plain-DNS port can never succeed; use the DoT port.
                None if t.port() == DO53_PORT => t.set_port(DEFAULT_DOT_PORT),
                None => {}
            }
            t
        }
    };
    if target.port() == 0 {
        target.set_port(DEFAULT_DOT_PORT);
    }
    // Owned copies so the async block does not borrow from this stack frame.
    let sni = extras.dot_sni.clone();
    let query = query.to_vec();
    let buf_len = buf.len();
    let result: Result<Vec<u8>, UpstreamError> = tokio::task::block_in_place(|| {
        handle.block_on(async move {
            match tokio::time::timeout(timeout, dot_roundtrip(target, &query, &sni, buf_len)).await
            {
                Ok(inner) => inner,
                Err(_) => Err(UpstreamError::Timeout),
            }
        })
    });
    let resp = result?;
    if resp.len() > buf.len() {
        return Err(UpstreamError::Io(io::Error::new(
            io::ErrorKind::InvalidData,
            "dot response exceeds out buffer",
        )));
    }
    buf[..resp.len()].copy_from_slice(&resp);
    Ok(resp.len())
}
/// Parses a DoT server string (e.g. `CELLOS_DNS_UPSTREAM_DOT_SERVER`) into a socket address.
///
/// Surrounding whitespace is ignored. Accepted forms:
/// * `ip:port` or `[ipv6]:port`: the inline port is used and `port_override` is ignored.
/// * Bare `ipv4`, bare `ipv6`, or bracketed `[ipv6]` without a port: the port is
///   `port_override`, falling back to [`DEFAULT_DOT_PORT`].
///
/// Hostnames are rejected. No name resolution happens here, so the caller must supply an
/// IP literal.
///
/// # Errors
/// Returns a human-readable message for empty input or for anything that is not an IP literal.
fn parse_dot_target(host: &str, port_override: Option<u16>) -> Result<SocketAddr, String> {
    use std::net::{IpAddr, Ipv6Addr};

    let trimmed = host.trim();
    if trimmed.is_empty() {
        return Err("empty string".to_string());
    }
    if let Ok(sa) = trimmed.parse::<SocketAddr>() {
        return Ok(sa);
    }
    // `[v6]` without a port is the URI-style way to write an IPv6 literal. Brackets are only
    // valid around IPv6, so the bracketed form must parse as IPv6.
    let parsed = match trimmed
        .strip_prefix('[')
        .and_then(|rest| rest.strip_suffix(']'))
    {
        Some(inner) => inner.parse::<Ipv6Addr>().map(IpAddr::V6),
        None => trimmed.parse::<IpAddr>(),
    };
    let ip = parsed.map_err(|_| {
        format!(
            "'{trimmed}' is not an IP literal (hostnames must be pre-resolved by the supervisor)"
        )
    })?;
    Ok(SocketAddr::new(ip, port_override.unwrap_or(DEFAULT_DOT_PORT)))
}
async fn dot_roundtrip(
target: SocketAddr,
query: &[u8],
sni: &Option<String>,
out_cap: usize,
) -> Result<Vec<u8>, UpstreamError> {
let config = build_dot_client_config();
let connector = TlsConnector::from(Arc::new(config));
let tcp = tokio::net::TcpStream::connect(target).await?;
let server_name: ServerName<'static> = match sni {
Some(host) if !host.is_empty() => ServerName::try_from(host.clone())
.map_err(|e| UpstreamError::TlsHandshake(format!("invalid sni '{host}': {e}")))?,
_ => ServerName::IpAddress(target.ip().into()),
};
let mut tls = connector
.connect(server_name, tcp)
.await
.map_err(|e| UpstreamError::TlsHandshake(format!("{e}")))?;
let len = query.len() as u16;
tls.write_all(&len.to_be_bytes()).await?;
tls.write_all(query).await?;
tls.flush().await?;
let mut len_buf = [0u8; 2];
tls.read_exact(&mut len_buf).await?;
let resp_len = u16::from_be_bytes(len_buf) as usize;
if resp_len > out_cap {
return Err(UpstreamError::Io(io::Error::new(
io::ErrorKind::InvalidData,
"dot response exceeds out buffer",
)));
}
let mut resp = vec![0u8; resp_len];
tls.read_exact(&mut resp).await?;
Ok(resp)
}
/// Forwards `query` over DNS-over-HTTPS and copies the answer into `buf`. Returns its length.
///
/// The endpoint is `extras.doh_url`, or [`DEFAULT_DOH_URL`] when that is unset. The async
/// exchange (`doh_roundtrip`) runs on the ambient Tokio runtime via `block_in_place` plus
/// `block_on`, and `timeout` bounds it as a whole.
///
/// NOTE(review): `tokio::task::block_in_place` panics on a current-thread runtime. Confirm
/// callers only reach this from a multi-thread runtime.
///
/// # Errors
/// * [`UpstreamError::NoTokioRuntime`] when no runtime handle is in scope.
/// * [`UpstreamError::Timeout`] when `timeout` elapses.
/// * `InvalidData` when the answer does not fit in `buf`.
/// * Any error from the HTTP exchange itself.
fn forward_doh(
    query: &[u8],
    buf: &mut [u8],
    timeout: Duration,
    extras: &UpstreamExtras,
) -> Result<usize, UpstreamError> {
    let Ok(handle) = tokio::runtime::Handle::try_current() else {
        return Err(UpstreamError::NoTokioRuntime);
    };
    let endpoint = extras
        .doh_url
        .as_deref()
        .unwrap_or(DEFAULT_DOH_URL)
        .to_owned();
    // Owned copies so the async block does not borrow from this stack frame.
    let payload = query.to_vec();
    let capacity = buf.len();

    let answer = tokio::task::block_in_place(move || {
        handle.block_on(async move {
            let exchange = doh_roundtrip(&endpoint, &payload, timeout, capacity);
            tokio::time::timeout(timeout, exchange)
                .await
                .unwrap_or_else(|_elapsed| Err(UpstreamError::Timeout))
        })
    })?;

    let Some(dst) = buf.get_mut(..answer.len()) else {
        return Err(UpstreamError::Io(io::Error::new(
            io::ErrorKind::InvalidData,
            "doh response exceeds out buffer",
        )));
    };
    dst.copy_from_slice(&answer);
    Ok(answer.len())
}
async fn doh_roundtrip(
url: &str,
query: &[u8],
timeout: Duration,
out_cap: usize,
) -> Result<Vec<u8>, UpstreamError> {
let client = reqwest::Client::builder()
.timeout(timeout)
.build()
.map_err(|e| UpstreamError::Io(io::Error::other(format!("doh client build: {e}"))))?;
let resp = client
.post(url)
.header(reqwest::header::CONTENT_TYPE, "application/dns-message")
.header(reqwest::header::ACCEPT, "application/dns-message")
.body(query.to_vec())
.send()
.await
.map_err(|e| {
if e.is_timeout() {
UpstreamError::Timeout
} else {
UpstreamError::Io(io::Error::other(format!("doh request: {e}")))
}
})?;
if !resp.status().is_success() {
return Err(UpstreamError::Io(io::Error::other(format!(
"doh upstream returned HTTP {}",
resp.status()
))));
}
let bytes = resp
.bytes()
.await
.map_err(|e| UpstreamError::Io(io::Error::other(format!("doh body: {e}"))))?;
if bytes.len() > out_cap {
return Err(UpstreamError::Io(io::Error::new(
io::ErrorKind::InvalidData,
"doh response exceeds out buffer",
)));
}
Ok(bytes.to_vec())
}
/// Forwards `query` over DNS-over-QUIC and copies the answer into `buf`. Returns its length.
///
/// The server and port come from `extras.doq_server` and `extras.doq_port`, falling back
/// to [`DEFAULT_DOQ_SERVER`] and [`DEFAULT_DOQ_PORT`]. The async exchange (`doq_roundtrip`)
/// runs on the ambient Tokio runtime via `block_in_place` plus `block_on`, and `timeout`
/// bounds it as a whole.
///
/// NOTE(review): `tokio::task::block_in_place` panics on a current-thread runtime. Confirm
/// callers only reach this from a multi-thread runtime.
///
/// # Errors
/// * [`UpstreamError::NoTokioRuntime`] when no runtime handle is in scope.
/// * [`UpstreamError::Timeout`] when `timeout` elapses.
/// * `InvalidData` when the answer does not fit in `buf`.
/// * Any error from the QUIC exchange itself.
fn forward_doq(
    query: &[u8],
    buf: &mut [u8],
    timeout: Duration,
    extras: &UpstreamExtras,
) -> Result<usize, UpstreamError> {
    let Ok(handle) = tokio::runtime::Handle::try_current() else {
        return Err(UpstreamError::NoTokioRuntime);
    };
    let host = extras
        .doq_server
        .as_deref()
        .unwrap_or(DEFAULT_DOQ_SERVER)
        .to_owned();
    let port = extras.doq_port.unwrap_or(DEFAULT_DOQ_PORT);
    // Owned copies so the async block does not borrow from this stack frame.
    let payload = query.to_vec();
    let capacity = buf.len();

    let answer = tokio::task::block_in_place(move || {
        handle.block_on(async move {
            let exchange = doq_roundtrip(&host, port, &payload, capacity);
            tokio::time::timeout(timeout, exchange)
                .await
                .unwrap_or_else(|_elapsed| Err(UpstreamError::Timeout))
        })
    })?;

    let Some(dst) = buf.get_mut(..answer.len()) else {
        return Err(UpstreamError::Io(io::Error::new(
            io::ErrorKind::InvalidData,
            "doq response exceeds out buffer",
        )));
    };
    dst.copy_from_slice(&answer);
    Ok(answer.len())
}
async fn doq_roundtrip(
server: &str,
port: u16,
query: &[u8],
out_cap: usize,
) -> Result<Vec<u8>, UpstreamError> {
use std::net::IpAddr;
let (target_addr, sni): (SocketAddr, ServerName<'static>) =
if let Ok(ip) = server.parse::<IpAddr>() {
let sa = SocketAddr::new(ip, port);
let sni = ServerName::IpAddress(match ip {
IpAddr::V4(v4) => rustls::pki_types::IpAddr::V4(v4.into()),
IpAddr::V6(v6) => rustls::pki_types::IpAddr::V6(v6.into()),
});
(sa, sni)
} else {
let mut iter = tokio::net::lookup_host((server, port)).await.map_err(|e| {
UpstreamError::Io(io::Error::new(
e.kind(),
format!("doq lookup '{server}': {e}"),
))
})?;
let sa = iter.next().ok_or_else(|| {
UpstreamError::Io(io::Error::new(
io::ErrorKind::AddrNotAvailable,
format!("doq lookup '{server}' returned no addresses"),
))
})?;
let sni = ServerName::try_from(server.to_string())
.map_err(|e| UpstreamError::TlsHandshake(format!("invalid sni '{server}': {e}")))?;
(sa, sni)
};
let bind_addr: SocketAddr = match target_addr {
SocketAddr::V4(_) => "0.0.0.0:0".parse().unwrap(),
SocketAddr::V6(_) => "[::]:0".parse().unwrap(),
};
let mut endpoint = quinn::Endpoint::client(bind_addr)
.map_err(|e| UpstreamError::Io(io::Error::new(e.kind(), format!("doq endpoint: {e}"))))?;
let mut roots = RootCertStore::empty();
roots.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
let provider = Arc::new(rustls::crypto::ring::default_provider());
let mut crypto = ClientConfig::builder_with_provider(provider)
.with_safe_default_protocol_versions()
.map_err(|e| UpstreamError::TlsHandshake(format!("doq rustls protocols: {e}")))?
.with_root_certificates(roots)
.with_no_client_auth();
crypto.alpn_protocols = vec![b"doq".to_vec()];
let quic_crypto = quinn::crypto::rustls::QuicClientConfig::try_from(crypto)
.map_err(|e| UpstreamError::TlsHandshake(format!("doq quic crypto: {e}")))?;
let client_config = quinn::ClientConfig::new(Arc::new(quic_crypto));
endpoint.set_default_client_config(client_config);
let sni_str: String = match &sni {
ServerName::DnsName(d) => d.as_ref().to_string(),
ServerName::IpAddress(_) => server.to_string(),
_ => server.to_string(),
};
let connecting = endpoint
.connect(target_addr, &sni_str)
.map_err(|e| UpstreamError::TlsHandshake(format!("doq connect: {e}")))?;
let connection = connecting
.await
.map_err(|e| UpstreamError::TlsHandshake(format!("doq handshake: {e}")))?;
let (mut send, mut recv) = connection
.open_bi()
.await
.map_err(|e| UpstreamError::Io(io::Error::other(format!("doq open_bi: {e}"))))?;
let len = query.len() as u16;
send.write_all(&len.to_be_bytes())
.await
.map_err(|e| UpstreamError::Io(io::Error::other(format!("doq send len: {e}"))))?;
send.write_all(query)
.await
.map_err(|e| UpstreamError::Io(io::Error::other(format!("doq send body: {e}"))))?;
send.finish()
.map_err(|e| UpstreamError::Io(io::Error::other(format!("doq finish: {e}"))))?;
let mut len_buf = [0u8; 2];
recv.read_exact(&mut len_buf)
.await
.map_err(|e| UpstreamError::Io(io::Error::other(format!("doq recv len: {e}"))))?;
let resp_len = u16::from_be_bytes(len_buf) as usize;
if resp_len > out_cap {
connection.close(0u32.into(), b"oversized");
endpoint.wait_idle().await;
return Err(UpstreamError::Io(io::Error::new(
io::ErrorKind::InvalidData,
"doq response exceeds out buffer",
)));
}
let mut resp = vec![0u8; resp_len];
recv.read_exact(&mut resp)
.await
.map_err(|e| UpstreamError::Io(io::Error::other(format!("doq recv body: {e}"))))?;
connection.close(0u32.into(), b"done");
endpoint.wait_idle().await;
Ok(resp)
}
/// Builds the rustls client configuration used for DNS-over-TLS.
///
/// Uses the `ring` crypto provider, rustls' safe default protocol versions, the bundled
/// `webpki_roots` trust anchors, and no client certificate.
///
/// # Panics
/// Panics if the `ring` provider rejects the default protocol versions, which would be a
/// build or configuration bug rather than a runtime condition.
fn build_dot_client_config() -> ClientConfig {
    let crypto_provider = Arc::new(rustls::crypto::ring::default_provider());

    let trust_anchors = webpki_roots::TLS_SERVER_ROOTS.iter().cloned();
    let mut root_store = RootCertStore::empty();
    root_store.extend(trust_anchors);

    let versioned = ClientConfig::builder_with_provider(crypto_provider)
        .with_safe_default_protocol_versions()
        .expect("ring provider supports default rustls protocol versions");
    versioned
        .with_root_certificates(root_store)
        .with_no_client_auth()
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::ffi::{OsStr, OsString};
    use std::sync::Mutex;

    /// Serialises the tests in this module that read or mutate process environment variables.
    static ENV_LOCK: Mutex<()> = Mutex::new(());

    /// Snapshot of environment variables that is restored on drop, including when an
    /// assertion panics, so a test never leaks configuration into later tests.
    ///
    /// Create it *after* taking `ENV_LOCK`. Locals drop in reverse order, so the snapshot
    /// then restores while the lock is still held.
    struct EnvSnapshot {
        saved: Vec<(&'static str, Option<OsString>)>,
    }

    impl EnvSnapshot {
        /// Records the current value (or absence) of every variable in `names`.
        fn capture(names: &[&'static str]) -> Self {
            let saved = names
                .iter()
                .map(|&name| (name, std::env::var_os(name)))
                .collect();
            Self { saved }
        }
    }

    impl Drop for EnvSnapshot {
        fn drop(&mut self) {
            for (name, value) in self.saved.drain(..) {
                match value {
                    Some(v) => set_env(name, v),
                    None => remove_env(name),
                }
            }
        }
    }

    /// Sets `name` to `value`. Callers must hold `ENV_LOCK`.
    fn set_env(name: &str, value: impl AsRef<OsStr>) {
        // SAFETY: env mutation in this module is serialised by ENV_LOCK, which the caller
        // holds; no other test here touches the environment concurrently.
        unsafe { std::env::set_var(name, value) }
    }

    /// Removes `name`. Callers must hold `ENV_LOCK`.
    fn remove_env(name: &str) {
        // SAFETY: same reasoning as `set_env`.
        unsafe { std::env::remove_var(name) }
    }

    #[test]
    fn parse_canonical_names() {
        let cases = [
            ("do53-udp", UpstreamTransport::Do53Udp),
            ("do53-tcp", UpstreamTransport::Do53Tcp),
            ("dot", UpstreamTransport::Dot),
            ("doh", UpstreamTransport::Doh),
            ("doq", UpstreamTransport::Doq),
        ];
        for (input, expected) in cases {
            assert_eq!(UpstreamTransport::parse(input), Some(expected), "input {input:?}");
        }
    }

    #[test]
    fn parse_aliases_case_insensitive() {
        let cases = [
            ("UDP", UpstreamTransport::Do53Udp),
            ("TCP", UpstreamTransport::Do53Tcp),
            ("Tls", UpstreamTransport::Dot),
            ("HTTPS", UpstreamTransport::Doh),
            ("quic", UpstreamTransport::Doq),
        ];
        for (input, expected) in cases {
            assert_eq!(UpstreamTransport::parse(input), Some(expected), "input {input:?}");
        }
    }

    #[test]
    fn parse_rejects_unknown() {
        assert_eq!(UpstreamTransport::parse("dnscrypt"), None);
        assert_eq!(UpstreamTransport::parse("xxx"), None);
    }

    #[test]
    fn default_is_udp() {
        assert_eq!(UpstreamTransport::default(), UpstreamTransport::Do53Udp);
    }

    #[test]
    fn label_round_trips() {
        for t in [
            UpstreamTransport::Do53Udp,
            UpstreamTransport::Do53Tcp,
            UpstreamTransport::Dot,
            UpstreamTransport::Doh,
            UpstreamTransport::Doq,
        ] {
            assert_eq!(UpstreamTransport::parse(t.label()), Some(t));
        }
    }

    #[test]
    fn extras_from_env_reads_doh_url() {
        let _lock = ENV_LOCK.lock().unwrap_or_else(|p| p.into_inner());
        let _env = EnvSnapshot::capture(&["CELLOS_DNS_UPSTREAM_DOH_URL"]);
        set_env(
            "CELLOS_DNS_UPSTREAM_DOH_URL",
            "https://cloudflare-dns.com/dns-query",
        );
        let extras = UpstreamExtras::from_env();
        assert_eq!(
            extras.doh_url.as_deref(),
            Some("https://cloudflare-dns.com/dns-query")
        );
    }

    #[test]
    fn extras_from_env_reads_doq_server_and_port() {
        let _lock = ENV_LOCK.lock().unwrap_or_else(|p| p.into_inner());
        let _env = EnvSnapshot::capture(&[
            "CELLOS_DNS_UPSTREAM_DOQ_SERVER",
            "CELLOS_DNS_UPSTREAM_DOQ_PORT",
        ]);
        set_env("CELLOS_DNS_UPSTREAM_DOQ_SERVER", "9.9.9.9");
        set_env("CELLOS_DNS_UPSTREAM_DOQ_PORT", "8853");
        let extras = UpstreamExtras::from_env();
        assert_eq!(extras.doq_server.as_deref(), Some("9.9.9.9"));
        assert_eq!(extras.doq_port, Some(8853));
    }

    #[test]
    fn parse_dot_target_accepts_bare_ipv4() {
        let sa = parse_dot_target("1.1.1.1", None).expect("bare ipv4 parses");
        assert_eq!(sa, "1.1.1.1:853".parse::<SocketAddr>().unwrap());
    }

    #[test]
    fn parse_dot_target_accepts_bare_ipv4_with_port_override() {
        let sa = parse_dot_target("9.9.9.9", Some(8853)).expect("ipv4 + override parses");
        assert_eq!(sa, "9.9.9.9:8853".parse::<SocketAddr>().unwrap());
    }

    #[test]
    fn parse_dot_target_accepts_ipv4_with_inline_port() {
        let sa = parse_dot_target("1.1.1.1:9999", Some(853)).expect("inline port parses");
        assert_eq!(sa, "1.1.1.1:9999".parse::<SocketAddr>().unwrap());
    }

    #[test]
    fn parse_dot_target_accepts_bracketed_ipv6() {
        let sa =
            parse_dot_target("[2606:4700:4700::1111]:853", None).expect("bracketed ipv6 parses");
        assert_eq!(
            sa,
            "[2606:4700:4700::1111]:853".parse::<SocketAddr>().unwrap()
        );
    }

    #[test]
    fn parse_dot_target_rejects_hostname() {
        let err = parse_dot_target("dns.example.com", None)
            .expect_err("hostname must be rejected (no DNS bootstrap in netns)");
        assert!(err.contains("hostnames must be pre-resolved"));
    }

    #[test]
    fn parse_dot_target_rejects_empty() {
        assert!(parse_dot_target("", None).is_err());
        assert!(parse_dot_target(" ", None).is_err());
    }

    #[test]
    fn extras_from_env_reads_dot_server_port_sni() {
        let _lock = ENV_LOCK.lock().unwrap_or_else(|p| p.into_inner());
        let _env = EnvSnapshot::capture(&[
            "CELLOS_DNS_UPSTREAM_DOT_SERVER",
            "CELLOS_DNS_UPSTREAM_DOT_PORT",
            "CELLOS_DNS_UPSTREAM_DOT_SNI",
        ]);
        set_env("CELLOS_DNS_UPSTREAM_DOT_SERVER", "8.8.8.8");
        set_env("CELLOS_DNS_UPSTREAM_DOT_PORT", "8853");
        set_env("CELLOS_DNS_UPSTREAM_DOT_SNI", "dns.google");
        let extras = UpstreamExtras::from_env();
        assert_eq!(extras.dot_server.as_deref(), Some("8.8.8.8"));
        assert_eq!(extras.dot_port, Some(8853));
        assert_eq!(extras.dot_sni.as_deref(), Some("dns.google"));
    }

    #[test]
    fn extras_from_env_ignores_unparseable_port() {
        let _lock = ENV_LOCK.lock().unwrap_or_else(|p| p.into_inner());
        // Snapshot every variable this test touches, including the ones it only removes,
        // so pre-existing values are restored afterwards.
        let _env = EnvSnapshot::capture(&[
            "CELLOS_DNS_UPSTREAM_DOT_SERVER",
            "CELLOS_DNS_UPSTREAM_DOT_SNI",
            "CELLOS_DNS_UPSTREAM_DOT_PORT",
        ]);
        remove_env("CELLOS_DNS_UPSTREAM_DOT_SERVER");
        remove_env("CELLOS_DNS_UPSTREAM_DOT_SNI");
        set_env("CELLOS_DNS_UPSTREAM_DOT_PORT", "not-a-number");
        let extras = UpstreamExtras::from_env();
        assert_eq!(extras.dot_port, None);
    }

    #[test]
    fn from_env_prefers_protocol_over_transport() {
        let _lock = ENV_LOCK.lock().unwrap_or_else(|p| p.into_inner());
        let _env = EnvSnapshot::capture(&[
            "CELLOS_DNS_UPSTREAM_PROTOCOL",
            "CELLOS_DNS_UPSTREAM_TRANSPORT",
        ]);
        set_env("CELLOS_DNS_UPSTREAM_PROTOCOL", "dot");
        set_env("CELLOS_DNS_UPSTREAM_TRANSPORT", "do53-udp");
        assert_eq!(UpstreamTransport::from_env(), Some(UpstreamTransport::Dot));
        remove_env("CELLOS_DNS_UPSTREAM_PROTOCOL");
        assert_eq!(
            UpstreamTransport::from_env(),
            Some(UpstreamTransport::Do53Udp)
        );
    }

    #[test]
    fn udp_path_round_trips_against_synthetic_upstream() {
        let echo = UdpSocket::bind("127.0.0.1:0").unwrap();
        echo.set_read_timeout(Some(Duration::from_millis(500)))
            .unwrap();
        let echo_addr = echo.local_addr().unwrap();
        std::thread::spawn(move || {
            let mut b = [0u8; 1500];
            if let Ok((_n, peer)) = echo.recv_from(&mut b) {
                // The reply echoes the query's Message ID (0x0000) in its first two bytes.
                let _ = echo.send_to(b"\x00\x00ABCDEFGHIJK", peer);
            }
        });
        let upstream_sock = UdpSocket::bind("127.0.0.1:0").unwrap();
        let mut out = [0u8; 1500];
        let n = forward(
            UpstreamTransport::Do53Udp,
            &upstream_sock,
            echo_addr,
            b"\x00\x00\x01\x00\x00\x01\x00\x00\x00\x00\x00\x00",
            &mut out,
            Duration::from_millis(500),
            &UpstreamExtras::default(),
        )
        .expect("udp round-trip");
        assert_eq!(n, 13);
        assert_eq!(&out[..13], b"\x00\x00ABCDEFGHIJK");
    }
}