use std::time::Duration;
use bytes::Bytes;
use hickory_net::proto::op::{DnsRequest, DnsRequestOptions, Message, Query, ResponseCode};
use hickory_net::proto::rr::{DNSClass, Name, RecordType};
use hickory_net::xfer::{DnsHandle, FirstAnswer as _};
use crate::codec::message::{Qclass, Qtype, Question};
use super::{Error, Result, UpstreamClient};
/// Two-second bound for a single upstream query; presumably what callers pass as the
/// `timeout` argument of `UpstreamClient::forward` when they have no better value.
pub const DEFAULT_QUERY_TIMEOUT: Duration = Duration::from_millis(2_000);
/// Outcome of forwarding one question upstream.
///
/// Carries the upstream reply verbatim plus the negative-caching facts derived from it.
/// `PartialEq`/`Eq` are derived so whole results can be compared directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ForwardResult {
    /// The upstream reply exactly as received, in DNS wire format.
    pub bytes: Bytes,
    /// TTL to use when caching a negative reply. It is taken from the SOA record in the
    /// reply's authority section, and is `None` when no SOA is present.
    pub negative_ttl: Option<u32>,
    /// `true` for NXDOMAIN, or for NOERROR with no answer for the question (NODATA).
    pub is_negative: bool,
}
impl UpstreamClient {
pub async fn forward(&self, question: &Question, timeout: Duration) -> Result<ForwardResult> {
let name = Name::from_ascii(question.name.to_string()).map_err(|e| {
Error::Transport(format!(
"invalid question name {:?}: {e}",
question.name.to_string()
))
})?;
let record_type = match question.qtype {
Qtype::A => RecordType::A,
Qtype::Aaaa => RecordType::AAAA,
Qtype::Other(v) => RecordType::from(v),
};
let dns_class = match question.qclass {
Qclass::In => DNSClass::IN,
Qclass::Other(v) => DNSClass::from(v),
};
let mut query = Query::query(name, record_type);
query.set_query_class(dns_class);
let mut message = Message::query();
message.add_query(query);
message.metadata.recursion_desired = true;
let mut options = DnsRequestOptions::default();
options.recursion_desired = true;
options.edns_set_dnssec_ok = false;
let request = DnsRequest::new(message, options);
let response = tokio::time::timeout(timeout, self.handle().send(request).first_answer())
.await
.map_err(|_| Error::Timeout {
transport: self.transport(),
})?
.map_err(|source| Error::Exchange {
transport: self.transport(),
source,
})?;
let negative_ttl = response.negative_ttl();
let rcode = response.metadata.response_code;
let is_negative = rcode == ResponseCode::NXDomain
|| (rcode == ResponseCode::NoError && !response.contains_answer());
let bytes = Bytes::from(response.into_buffer());
Ok(ForwardResult {
bytes,
negative_ttl,
is_negative,
})
}
}
#[cfg(test)]
mod tests {
    use std::net::SocketAddr;
    use hickory_net::proto::op::{Message, MessageType, ResponseCode};
    use hickory_net::proto::rr::rdata::{A, SOA};
    use hickory_net::proto::rr::{Name, RData, Record};
    use tokio::net::UdpSocket;
    use super::*;
    use crate::codec::message::{Qclass, Qtype, Question};
    use crate::resolver::upstream::{UpstreamConfig, UpstreamTransport};

    /// Starts a loopback UDP "DNS server" driven by `handler` and returns its address.
    ///
    /// Each datagram that parses as a DNS [`Message`] is passed to `handler`. A `Some` reply is
    /// serialized and sent back to the peer. `None` means stay silent. Unparseable datagrams
    /// are skipped, and the loop ends on the first receive error.
    async fn spawn_mock_udp<F>(mut handler: F) -> SocketAddr
    where
        F: FnMut(Message) -> Option<Message> + Send + 'static,
    {
        let sock = UdpSocket::bind("127.0.0.1:0").await.unwrap();
        let addr = sock.local_addr().unwrap();
        tokio::spawn(async move {
            let mut buf = vec![0u8; 512];
            loop {
                let Ok((len, peer)) = sock.recv_from(&mut buf).await else {
                    break;
                };
                let Ok(req) = Message::from_vec(&buf[..len]) else {
                    continue;
                };
                if let Some(resp) = handler(req)
                    && let Ok(resp_bytes) = resp.to_vec()
                {
                    let _ = sock.send_to(&resp_bytes, peer).await;
                }
            }
        });
        addr
    }

    /// Turns the received request into a reply with `rcode`.
    ///
    /// Reusing the request keeps its ID and question section, so the client accepts the reply.
    fn reply_to(mut req: Message, rcode: ResponseCode) -> Message {
        req.metadata.message_type = MessageType::Response;
        req.metadata.response_code = rcode;
        req
    }

    /// The `example.com. IN A` question shared by all tests.
    fn stock_question() -> Question {
        Question {
            name: "example.com".parse().unwrap(),
            qtype: Qtype::A,
            qclass: Qclass::In,
        }
    }

    /// Connects a plain-UDP upstream client to `addr` and spawns its background task.
    async fn udp_client(addr: SocketAddr) -> UpstreamClient {
        let cfg = UpstreamConfig {
            addr,
            transport: UpstreamTransport::Udp,
            tls_server_name: None,
            http_endpoint: None,
        };
        let (client, bg) = UpstreamClient::connect(&cfg).await.unwrap();
        tokio::spawn(bg);
        client
    }

    #[tokio::test]
    async fn positive_a_answer() {
        let addr = spawn_mock_udp(|req| {
            let mut resp = reply_to(req, ResponseCode::NoError);
            let name = Name::from_ascii("example.com.").unwrap();
            let rdata = RData::A(A::new(93, 184, 216, 34));
            resp.add_answer(Record::from_rdata(name, 300, rdata));
            Some(resp)
        })
        .await;
        let client = udp_client(addr).await;
        let result = client
            .forward(&stock_question(), Duration::from_secs(5))
            .await
            .expect("forward must succeed");
        assert!(!result.is_negative, "A response must not be negative");
        assert_eq!(result.negative_ttl, None, "positive answer has no SOA");
        let scan = crate::codec::ttl::TtlScan::scan(&result.bytes)
            .expect("TtlScan must succeed on returned bytes");
        assert_eq!(scan.min_ttl, Some(300), "TTL from bytes must be 300");
        let query = crate::codec::message::Query::try_from(result.bytes.clone())
            .expect("Query::try_from must succeed on returned bytes");
        assert_eq!(
            query.question().name.to_string(),
            "example.com.",
            "question name in parsed bytes must match queried name"
        );
    }

    #[tokio::test]
    async fn nxdomain_with_soa() {
        let addr = spawn_mock_udp(|req| {
            let mut resp = reply_to(req, ResponseCode::NXDomain);
            let zone = Name::from_ascii("example.com.").unwrap();
            let mname = Name::from_ascii("ns1.example.com.").unwrap();
            let rname = Name::from_ascii("hostmaster.example.com.").unwrap();
            let soa = SOA::new(mname, rname, 1, 3600, 900, 604800, 60);
            resp.add_authority(Record::from_rdata(zone, 120, RData::SOA(soa)));
            Some(resp)
        })
        .await;
        let client = udp_client(addr).await;
        let result = client
            .forward(&stock_question(), Duration::from_secs(5))
            .await
            .expect("forward must succeed");
        assert!(result.is_negative, "NXDOMAIN must be negative");
        assert_eq!(
            result.negative_ttl,
            Some(60),
            "negative_ttl must be min(soa_ttl=120, soa_minimum=60) = 60"
        );
    }

    #[tokio::test]
    async fn nxdomain_without_soa() {
        let addr = spawn_mock_udp(|req| Some(reply_to(req, ResponseCode::NXDomain))).await;
        let client = udp_client(addr).await;
        let result = client
            .forward(&stock_question(), Duration::from_secs(5))
            .await
            .expect("forward must succeed");
        assert!(result.is_negative, "NXDOMAIN must be negative");
        assert_eq!(
            result.negative_ttl, None,
            "NXDOMAIN without SOA must have no negative_ttl"
        );
    }

    #[tokio::test]
    async fn nodata_noerror_no_answer() {
        let addr = spawn_mock_udp(|req| Some(reply_to(req, ResponseCode::NoError))).await;
        let client = udp_client(addr).await;
        let result = client
            .forward(&stock_question(), Duration::from_secs(5))
            .await
            .expect("forward must succeed");
        assert!(
            result.is_negative,
            "NODATA (NOERROR, no answers) must be negative"
        );
        assert_eq!(
            result.negative_ttl, None,
            "NODATA without SOA must have no negative_ttl"
        );
    }

    #[tokio::test]
    async fn timeout_when_upstream_silent() {
        // The mock never replies, so only the 150 ms forward timeout can end the query.
        let addr = spawn_mock_udp(|_req| None).await;
        let client = udp_client(addr).await;
        let result = tokio::time::timeout(
            Duration::from_secs(5),
            client.forward(&stock_question(), Duration::from_millis(150)),
        )
        .await
        .expect("safety timeout: test took too long");
        assert!(
            matches!(result, Err(Error::Timeout { .. })),
            "expected Error::Timeout, got: {result:?}"
        );
    }
}