use std::{
net::{Ipv4Addr, SocketAddr},
sync::Arc,
time::Duration,
};
use base64::Engine;
use netstack::{CreateSocket, netcore::Channel, netsock::TcpStream};
use tokio::{
io::{AsyncReadExt, AsyncWriteExt},
sync::watch,
time::timeout,
};
use ts_dns_wire::{Rcode, decode_query, encode_response};
use crate::magic_dns::{Decision, DnsView, decide, forward_query};
/// Largest DNS message accepted from a DoH client: caps both the POST
/// `Content-Length` and the decoded size of a GET `dns=` parameter.
const MAX_REQUEST: usize = 8 * 1024;
/// Deadline applied by `forward_doh` to the whole outbound DoH round trip
/// (connect, request write and response read).
const CLIENT_TIMEOUT: Duration = Duration::from_secs(5);
/// Cap used on the client side when reading a DoH response: bounds the bytes
/// buffered while looking for the end of the headers and the accepted body size.
const MAX_CLIENT_RESPONSE: usize = 64 * 1024;
pub(crate) async fn forward_doh(
channel: &Channel,
doh_addr: SocketAddr,
query: &[u8],
nxdomain: Vec<u8>,
) -> Vec<u8> {
match timeout(CLIENT_TIMEOUT, doh_round_trip(channel, doh_addr, query)).await {
Ok(Ok(resp)) if !resp.is_empty() => resp,
Ok(Ok(_)) => {
tracing::warn!(%doh_addr, "peerapi doh client: empty response from exit node");
nxdomain
}
Ok(Err(e)) => {
tracing::warn!(error = %e, %doh_addr, "peerapi doh client: delegation failed");
nxdomain
}
Err(_) => {
tracing::warn!(%doh_addr, "peerapi doh client: delegation timed out");
nxdomain
}
}
}
/// Performs one HTTP/1.1 `POST /dns-query` exchange with `doh_addr`.
///
/// Opens a TCP connection through `channel` from an unspecified local
/// address/port, sends `query` as an `application/dns-message` body with
/// `Connection: close`, and returns the response body read by
/// `read_doh_response`.
///
/// # Errors
/// Connection failures are converted to `std::io::Error` carrying the
/// original error's text; write and read errors are propagated unchanged.
async fn doh_round_trip(
    channel: &Channel,
    doh_addr: SocketAddr,
    query: &[u8],
) -> std::io::Result<Vec<u8>> {
    // 0.0.0.0:0 — no specific source address or port is requested.
    let unbound = SocketAddr::from((Ipv4Addr::UNSPECIFIED, 0));
    let mut stream = match channel.tcp_connect(unbound, doh_addr).await {
        Ok(stream) => stream,
        Err(e) => return Err(std::io::Error::other(e.to_string())),
    };

    let body_len = query.len();
    let head = format!(
        "POST /dns-query HTTP/1.1\r\nHost: {doh_addr}\r\nContent-Type: application/dns-message\r\nAccept: application/dns-message\r\nContent-Length: {body_len}\r\nConnection: close\r\n\r\n"
    );

    stream.write_all(head.as_bytes()).await?;
    stream.write_all(query).await?;
    stream.flush().await?;

    read_doh_response(&mut stream).await
}
/// Reads one HTTP response from `stream` and returns exactly the body
/// announced by its `Content-Length`.
///
/// Bytes are accumulated until the `\r\n\r\n` header terminator is seen, the
/// head is validated by `parse_response_head`, and then the remaining body
/// bytes are read. Anything received beyond the declared length is dropped.
///
/// # Errors
/// * `InvalidData` — headers or body grow past `MAX_CLIENT_RESPONSE`, or the
///   head is rejected by `parse_response_head`.
/// * `UnexpectedEof` — the peer closes before the headers or body complete.
/// * Any error returned by the underlying read.
async fn read_doh_response(stream: &mut TcpStream) -> std::io::Result<Vec<u8>> {
    fn invalid_data(msg: &'static str) -> std::io::Error {
        std::io::Error::new(std::io::ErrorKind::InvalidData, msg)
    }
    fn unexpected_eof(msg: &'static str) -> std::io::Error {
        std::io::Error::new(std::io::ErrorKind::UnexpectedEof, msg)
    }

    let mut received = Vec::with_capacity(1024);
    let mut chunk = [0u8; 1024];

    // Phase 1: accumulate until the blank line that ends the header block.
    let body_start = loop {
        match find_header_end(&received) {
            Some(end) => break end,
            None if received.len() > MAX_CLIENT_RESPONSE => {
                return Err(invalid_data("doh response headers too large"));
            }
            None => {}
        }
        let got = stream.read(&mut chunk).await?;
        if got == 0 {
            return Err(unexpected_eof("eof before doh response headers"));
        }
        received.extend_from_slice(&chunk[..got]);
    };

    let expected = parse_response_head(&received)?;

    // Whatever followed the headers in the bytes already read is the start
    // of the body.
    let mut body = received.split_off(body_start);

    // Phase 2: read until the declared number of body bytes is available.
    while body.len() < expected {
        let got = stream.read(&mut chunk).await?;
        if got == 0 {
            return Err(unexpected_eof("eof before doh response body complete"));
        }
        body.extend_from_slice(&chunk[..got]);
        if body.len() > MAX_CLIENT_RESPONSE {
            return Err(invalid_data("doh response body too large"));
        }
    }

    body.truncate(expected);
    Ok(body)
}
/// Validates a DoH response head and returns its declared body length.
///
/// `buf` must contain a complete status line and header block (extra bytes
/// after the headers are allowed). Only status `200` is accepted, and the
/// first `Content-Length` header must parse as a `usize` no larger than
/// `MAX_CLIENT_RESPONSE`.
///
/// # Errors
/// Returns `InvalidData` when the head is incomplete or malformed, the
/// status is not 200, the length is missing or unparsable, or the length
/// exceeds the cap.
fn parse_response_head(buf: &[u8]) -> std::io::Result<usize> {
    let invalid = |msg: &'static str| std::io::Error::new(std::io::ErrorKind::InvalidData, msg);

    let mut header_slots = [httparse::EMPTY_HEADER; 32];
    let mut response = httparse::Response::new(&mut header_slots);

    // A partial parse is as unusable as a parse error here.
    if !matches!(response.parse(buf), Ok(httparse::Status::Complete(_))) {
        return Err(invalid("malformed doh response headers"));
    }

    match response.code {
        Some(200) => {}
        other => {
            return Err(std::io::Error::new(
                std::io::ErrorKind::InvalidData,
                format!("doh response status {:?}", other),
            ));
        }
    }

    // Only the first Content-Length header is consulted.
    let declared = response
        .headers
        .iter()
        .find(|header| header.name.eq_ignore_ascii_case("content-length"))
        .and_then(|header| std::str::from_utf8(header.value).ok())
        .and_then(|text| text.trim().parse::<usize>().ok());
    let Some(content_length) = declared else {
        return Err(invalid("doh response missing length"));
    };

    if content_length > MAX_CLIENT_RESPONSE {
        return Err(invalid("doh response body too large"));
    }
    Ok(content_length)
}
/// Serves one DoH request on an accepted connection.
///
/// `seed` holds the bytes already read from `stream` and `header_end` the
/// offset just past the request headers within it; both are handed to
/// `read_request`. Rejected requests are answered with the matching HTTP
/// status (413 / 400 / 404). A valid query is resolved against the current
/// `DnsView` snapshot from `view_rx` and answered with a
/// `application/dns-message` response.
///
/// # Errors
/// Propagates I/O errors from reading the request or writing the reply.
pub(crate) async fn handle_conn(
    mut stream: TcpStream,
    seed: Vec<u8>,
    header_end: usize,
    channel: &Channel,
    view_rx: &watch::Receiver<Arc<DnsView>>,
    forward_exit_egress: bool,
) -> std::io::Result<()> {
    let Some(request) = read_request(&mut stream, seed, header_end).await? else {
        return Ok(());
    };

    let status = match request {
        DohRequest::Query(query) => {
            // Snapshot the view in its own statement so the watch guard is
            // released before any await point.
            let view = view_rx.borrow().clone();
            let answer = resolve(&view, &query, channel, forward_exit_egress).await;
            return write_dns_response(&mut stream, &answer).await;
        }
        DohRequest::TooLarge => "413 Payload Too Large",
        DohRequest::BadRequest => "400 Bad Request",
        DohRequest::NotFound => "404 Not Found",
    };
    write_status(&mut stream, status).await
}
/// Produces the DNS response bytes for `query`.
///
/// Asks `server_decide` what to do: a ready-made reply is returned directly,
/// otherwise the query is sent to the chosen upstreams via `forward_query`,
/// which receives the prepared `nxdomain` message alongside it.
async fn resolve(
    view: &DnsView,
    query: &[u8],
    channel: &Channel,
    forward_exit_egress: bool,
) -> Vec<u8> {
    let (upstreams, outbound, nxdomain) = match server_decide(view, query, forward_exit_egress) {
        ServerDecision::Reply(resp) => return resp,
        ServerDecision::Forward {
            upstreams,
            query,
            nxdomain,
        } => (upstreams, query, nxdomain),
    };
    forward_query(channel, &upstreams, &outbound, nxdomain).await
}
/// What the DoH server should do with an incoming query, as decided by
/// `server_decide`.
enum ServerDecision {
/// A complete DNS message to send back to the client as-is.
Reply(Vec<u8>),
/// The query must be forwarded; `resolve` passes these fields to
/// `forward_query`.
Forward {
/// Upstream resolver addresses to forward to.
upstreams: Vec<SocketAddr>,
/// Wire-format query to send upstream.
query: Vec<u8>,
/// Pre-built message handed to `forward_query` together with the query
/// (presumably its fallback answer — confirm in `magic_dns`).
nxdomain: Vec<u8>,
},
}
/// Decides how the DoH server answers `query`, without doing any I/O.
///
/// * Undecodable input → FORMERR, echoing the first two bytes as the message
///   id when present (id 0 otherwise).
/// * Names matched by `view.cfg.exit_node_filters` → REFUSED.
/// * Otherwise defers to `decide`: its direct replies pass through; a
///   `None` becomes FORMERR; a forward decision is honoured only when
///   `forward_exit_egress` is set and is answered with REFUSED otherwise.
fn server_decide(view: &DnsView, query: &[u8], forward_exit_egress: bool) -> ServerDecision {
    let Ok(decoded) = decode_query(query) else {
        // Best effort: reuse the client's transaction id if it sent one.
        let id = match query {
            [hi, lo, ..] => u16::from_be_bytes([*hi, *lo]),
            _ => 0,
        };
        return ServerDecision::Reply(encode_formerr(id));
    };

    let refused = || {
        ServerDecision::Reply(encode_response(
            decoded.id,
            &decoded.question,
            Rcode::Refused,
            &[],
        ))
    };

    let canon = decoded.question.name.to_canon();
    if view.cfg.exit_node_filters(&canon) {
        return refused();
    }

    match decide(view, query) {
        None => ServerDecision::Reply(encode_formerr(decoded.id)),
        Some(Decision::Reply(resp)) => ServerDecision::Reply(resp),
        // Forwarding is gated on the egress flag.
        Some(Decision::Forward { .. }) if !forward_exit_egress => refused(),
        Some(Decision::Forward {
            upstreams,
            query,
            nxdomain,
            recursive: _,
        }) => ServerDecision::Forward {
            upstreams,
            query,
            nxdomain,
        },
    }
}
/// Builds a bare 12-byte DNS header signalling a format error.
///
/// Bytes 0–1 carry `id` in big-endian order, byte 2 is `0x80` (response bit
/// set), byte 3 is `0x01` (rcode 1), and the remaining eight bytes (the four
/// section counts) are zero.
fn encode_formerr(id: u16) -> Vec<u8> {
    let [id_hi, id_lo] = id.to_be_bytes();
    let mut header = Vec::with_capacity(12);
    header.extend_from_slice(&[id_hi, id_lo, 0x80, 0x01]);
    // Pad out the four 16-bit section counts with zeros.
    header.resize(12, 0);
    header
}
/// Result of parsing an incoming DoH HTTP request (see `read_request`).
enum DohRequest {
/// A wire-format DNS query extracted from a POST body or a GET `dns=`
/// parameter.
Query(Vec<u8>),
/// The query exceeds `MAX_REQUEST`; `handle_conn` answers 413.
TooLarge,
/// The request is malformed or unsupported; `handle_conn` answers 400.
BadRequest,
/// The path is not `/dns-query`; `handle_conn` answers 404.
NotFound,
}
/// Parses a DoH request whose head is already buffered and, for POST, reads
/// the rest of the body from `stream`.
///
/// # Parameters
/// * `stream` — connection to read any remaining POST body bytes from.
/// * `buf` — bytes already received; must contain the complete request head
///   and may contain the beginning (or all) of the body.
/// * `header_end` — offset in `buf` just past the header terminator. It is
///   expected to equal the offset `httparse` reports (checked in debug
///   builds) and is where the body is taken to start.
///
/// # Returns
/// A [`DohRequest`] classification. Every path visible here returns
/// `Some(..)`; the `Option` is kept for the caller's existing contract.
/// * path other than `/dns-query` → `NotFound`
/// * `GET` → whatever `parse_get` yields for the query string
/// * `POST` → `BadRequest` for a missing/unparsable `Content-Length`, a
///   wrong `Content-Type`, or EOF before the body completes; `TooLarge` when
///   `Content-Length` exceeds `MAX_REQUEST`; otherwise `Query(body)`
/// * any other method, or an unparsable head → `BadRequest`
///
/// # Errors
/// Propagates I/O errors from reading the body.
async fn read_request(
    stream: &mut TcpStream,
    buf: Vec<u8>,
    header_end: usize,
) -> std::io::Result<Option<DohRequest>> {
    let mut tmp = [0u8; 1024];
    let mut headers = [httparse::EMPTY_HEADER; 32];
    let mut req = httparse::Request::new(&mut headers);
    let parsed = match req.parse(&buf) {
        Ok(httparse::Status::Complete(n)) => n,
        Ok(httparse::Status::Partial) | Err(_) => return Ok(Some(DohRequest::BadRequest)),
    };
    debug_assert_eq!(parsed, header_end);

    let method = req.method.unwrap_or("");
    let path = req.path.unwrap_or("");
    let (raw_path, query_str) = match path.split_once('?') {
        Some((p, q)) => (p, Some(q)),
        None => (path, None),
    };
    if raw_path != "/dns-query" {
        return Ok(Some(DohRequest::NotFound));
    }

    match method {
        "GET" => Ok(Some(parse_get(query_str))),
        "POST" => {
            let content_length =
                header_value(&req, "content-length").and_then(|v| v.trim().parse::<usize>().ok());
            let Some(len) = content_length else {
                return Ok(Some(DohRequest::BadRequest));
            };
            if len > MAX_REQUEST {
                return Ok(Some(DohRequest::TooLarge));
            }
            if !header_value(&req, "content-type")
                .is_some_and(|v| v.trim().eq_ignore_ascii_case("application/dns-message"))
            {
                return Ok(Some(DohRequest::BadRequest));
            }

            let mut body = buf[header_end..].to_vec();
            // `len` is already capped at MAX_REQUEST, so reading only the
            // bytes still owed bounds the body without a separate size
            // guard. The previous guard tested the length of the fixed seed
            // buffer, which never changes inside this loop: it could not
            // track growth, and it rejected in-limit bodies with 413
            // whenever the seed alone exceeded MAX_REQUEST and the body
            // needed another read.
            while body.len() < len {
                let want = (len - body.len()).min(tmp.len());
                let n = stream.read(&mut tmp[..want]).await?;
                if n == 0 {
                    // Peer closed before sending the promised body.
                    return Ok(Some(DohRequest::BadRequest));
                }
                body.extend_from_slice(&tmp[..n]);
            }
            // The seed may have carried bytes past the declared body.
            body.truncate(len);
            Ok(Some(DohRequest::Query(body)))
        }
        _ => Ok(Some(DohRequest::BadRequest)),
    }
}
/// Extracts the DNS query from a GET request's query string.
///
/// Takes the first `dns=` parameter among the `&`-separated pairs and
/// decodes it as unpadded URL-safe base64.
///
/// Returns `BadRequest` when there is no query string, no `dns=` parameter,
/// or the value fails to decode; `TooLarge` when the decoded message exceeds
/// `MAX_REQUEST`; `Query` otherwise.
fn parse_get(query_str: Option<&str>) -> DohRequest {
    let encoded = query_str
        .and_then(|qs| qs.split('&').find_map(|pair| pair.strip_prefix("dns=")));
    let Some(encoded) = encoded else {
        return DohRequest::BadRequest;
    };

    let Ok(message) = base64::engine::general_purpose::URL_SAFE_NO_PAD.decode(encoded) else {
        return DohRequest::BadRequest;
    };

    if message.len() > MAX_REQUEST {
        DohRequest::TooLarge
    } else {
        DohRequest::Query(message)
    }
}
/// Returns the value of the first header named `name` (ASCII
/// case-insensitive) as text, or `None` if the header is absent or its value
/// is not valid UTF-8.
fn header_value<'a>(req: &'a httparse::Request<'_, '_>, name: &str) -> Option<&'a str> {
    let header = req
        .headers
        .iter()
        .find(|candidate| candidate.name.eq_ignore_ascii_case(name))?;
    std::str::from_utf8(header.value).ok()
}
/// Locates the end of an HTTP header block.
///
/// Returns the offset of the first byte after the first `\r\n\r\n` in `buf`,
/// or `None` when no terminator is present yet.
pub(crate) fn find_header_end(buf: &[u8]) -> Option<usize> {
    const TERMINATOR: &[u8] = b"\r\n\r\n";
    buf.windows(TERMINATOR.len())
        .position(|window| window == TERMINATOR)
        .map(|start| start + TERMINATOR.len())
}
/// Sends `dns_msg` as a `200 OK` `application/dns-message` response with
/// `Connection: close`, then flushes the stream.
///
/// # Errors
/// Propagates any write or flush error.
async fn write_dns_response(stream: &mut TcpStream, dns_msg: &[u8]) -> std::io::Result<()> {
    let body_len = dns_msg.len();
    let head = format!(
        "HTTP/1.1 200 OK\r\nContent-Type: application/dns-message\r\nContent-Length: {body_len}\r\nConnection: close\r\n\r\n"
    );

    stream.write_all(head.as_bytes()).await?;
    stream.write_all(dns_msg).await?;
    stream.flush().await?;
    Ok(())
}
/// Sends a body-less HTTP/1.1 response with the given status line text
/// (e.g. `"404 Not Found"`), `Content-Length: 0` and `Connection: close`,
/// then flushes the stream.
///
/// # Errors
/// Propagates any write or flush error.
pub(crate) async fn write_status(stream: &mut TcpStream, status: &str) -> std::io::Result<()> {
    let response = format!("HTTP/1.1 {status}\r\nContent-Length: 0\r\nConnection: close\r\n\r\n");
    stream.write_all(response.as_bytes()).await?;
    stream.flush().await?;
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;
    use ts_control::DnsConfig;

    // ---- helpers ---------------------------------------------------------

    /// Hand-assembles a wire-format query with one question for the name
    /// made of `labels`, using QTYPE 1 and QCLASS 1.
    fn query_for(id: u16, labels: &[&str]) -> Vec<u8> {
        // Header words in order: id, flags, qdcount, ancount, nscount, arcount.
        let header_words: [u16; 6] = [id, 0, 1, 0, 0, 0];
        let mut wire: Vec<u8> = Vec::new();
        for word in header_words {
            wire.extend_from_slice(&word.to_be_bytes());
        }
        for label in labels {
            wire.push(label.len() as u8);
            wire.extend_from_slice(label.as_bytes());
        }
        // Root label terminates the name, then QTYPE and QCLASS.
        wire.push(0);
        wire.extend_from_slice(&1u16.to_be_bytes());
        wire.extend_from_slice(&1u16.to_be_bytes());
        wire
    }

    /// Low four bits of header byte 3: the response code.
    fn rcode(resp: &[u8]) -> u8 {
        resp[3] & 0x0F
    }

    /// A view with MagicDNS on, one search domain, one UDP fallback
    /// resolver, and the given exit-node filter set.
    fn view(filtered: &[&str]) -> DnsView {
        let fallback = ts_control::DnsResolver {
            transport: ts_control::ResolverTransport::Udp("9.9.9.9:53".parse().unwrap()),
            use_with_exit_node: false,
        };
        let cfg = DnsConfig {
            magic_dns: true,
            search_domains: vec!["user.ts.net".to_string()],
            fallback_resolvers: vec![fallback],
            exit_node_filtered_set: filtered.iter().map(|name| name.to_string()).collect(),
            ..Default::default()
        };
        DnsView {
            cfg,
            ..Default::default()
        }
    }

    // ---- HTTP framing ----------------------------------------------------

    #[test]
    fn find_header_end_locates_terminator() {
        let bare = b"GET / HTTP/1.1\r\n\r\n";
        let with_body = b"GET / HTTP/1.1\r\nX: 1\r\n\r\nBODY";
        let unterminated = b"GET / HTTP/1.1\r\n";

        assert_eq!(find_header_end(bare), Some(18));
        assert_eq!(find_header_end(with_body), Some(24));
        assert_eq!(find_header_end(unterminated), None);
    }

    #[test]
    fn parse_get_decodes_base64url_dns_param() {
        let raw = [0xab, 0xcd, 0x01, 0x00];
        let encoded = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(raw);
        let query_string = format!("dns={encoded}");

        let DohRequest::Query(decoded) = parse_get(Some(&query_string)) else {
            panic!("expected Query");
        };
        assert_eq!(decoded, raw);
    }

    #[test]
    fn parse_get_rejects_missing_or_bad_param() {
        for input in [None, Some("foo=bar"), Some("dns=!!!notbase64!!!")] {
            assert!(matches!(parse_get(input), DohRequest::BadRequest));
        }
    }

    #[test]
    fn parse_response_head_returns_content_length_on_200() {
        let head = b"HTTP/1.1 200 OK\r\nContent-Type: application/dns-message\r\nContent-Length: 42\r\nConnection: close\r\n\r\n";
        let length = parse_response_head(head).unwrap();
        assert_eq!(length, 42);
    }

    #[test]
    fn parse_response_head_rejects_non_200() {
        let head = b"HTTP/1.1 404 Not Found\r\nContent-Length: 0\r\n\r\n";
        assert!(parse_response_head(head).is_err());
    }

    #[test]
    fn parse_response_head_rejects_missing_length() {
        let head = b"HTTP/1.1 200 OK\r\nContent-Type: application/dns-message\r\n\r\n";
        assert!(parse_response_head(head).is_err());
    }

    #[test]
    fn parse_response_head_rejects_oversized_body() {
        let too_big = MAX_CLIENT_RESPONSE + 1;
        let head = format!("HTTP/1.1 200 OK\r\nContent-Length: {}\r\n\r\n", too_big);
        assert!(parse_response_head(head.as_bytes()).is_err());
    }

    // ---- DNS decisions ---------------------------------------------------

    #[test]
    fn encode_formerr_sets_response_and_rcode() {
        let msg = encode_formerr(0x1234);
        assert_eq!(&msg[0..2], &[0x12, 0x34]);
        assert_eq!(msg[2] & 0x80, 0x80, "QR response bit set");
        assert_eq!(msg[3] & 0x0F, 0x01, "FORMERR rcode");
    }

    #[test]
    fn filtered_name_is_refused() {
        let dns_view = view(&["blocked.example.com"]);
        let wire = query_for(0x1, &["blocked", "example", "com"]);

        let ServerDecision::Reply(resp) = server_decide(&dns_view, &wire, true) else {
            panic!("filtered name must not forward");
        };
        assert_eq!(rcode(&resp), 5, "REFUSED");
    }

    #[test]
    fn recursive_query_refused_when_egress_disabled() {
        let dns_view = view(&[]);
        let wire = query_for(0x2, &["example", "com"]);

        let ServerDecision::Reply(resp) = server_decide(&dns_view, &wire, false) else {
            panic!("must not forward when egress disabled");
        };
        assert_eq!(rcode(&resp), 5, "REFUSED");
    }

    #[test]
    fn recursive_query_forwards_when_egress_enabled() {
        let dns_view = view(&[]);
        let wire = query_for(0x3, &["example", "com"]);

        let ServerDecision::Forward { upstreams, .. } = server_decide(&dns_view, &wire, true)
        else {
            panic!("expected forward when egress enabled");
        };
        assert_eq!(upstreams, vec!["9.9.9.9:53".parse().unwrap()]);
    }

    #[test]
    fn authoritative_answer_is_not_gated() {
        let dns_view = view(&[]);
        let wire = query_for(0x4, &["host", "user", "ts", "net"]);

        let ServerDecision::Reply(resp) = server_decide(&dns_view, &wire, false) else {
            panic!("tailnet name must not forward");
        };
        assert_eq!(rcode(&resp), 3, "NXDOMAIN, not REFUSED");
    }

    #[test]
    fn unparseable_body_is_formerr() {
        let garbage = [0xAB, 0xCD, 0xFF];

        let ServerDecision::Reply(resp) = server_decide(&view(&[]), &garbage, true) else {
            panic!("garbage must not forward");
        };
        assert_eq!(&resp[0..2], &[0xAB, 0xCD]);
        assert_eq!(rcode(&resp), 1, "FORMERR");
    }
}