Skip to main content

dns_mock_server/
lib.rs

//! Implementation of a DNS server intended for use in tests.
//!
//! This allows you to run a proper DNS server while setting up records to be mapped to specific IP
//! addresses. Your test code can then target the locally bound server and make normal DNS
//! requests.
6
7use std::collections::HashMap;
8use std::net::IpAddr;
9use std::str::FromStr;
10
11use async_trait::async_trait;
12use hickory_proto::rr::LowerName;
13use hickory_proto::ProtoError;
14use hickory_server::authority::MessageResponseBuilder;
15use hickory_server::proto::op::Header;
16use hickory_server::proto::op::ResponseCode;
17use hickory_server::proto::rr::rdata::{A, AAAA};
18use hickory_server::proto::rr::{RData, Record};
19use hickory_server::server::{
20    Request, RequestHandler, ResponseHandler, ResponseInfo, ServerFuture,
21};
22use tokio::net::UdpSocket;
23
24/// A simple mock server for DNS requests.
25///
26/// The intended usage is to create a new instance using [`Server::default()`] and add some record
27/// mappings to it. You can then bind a [`UdpSocket`] and start the server with [`Server::start()`]
28/// in a background task before making requests on the main thread.
29#[derive(Clone, Debug, Default)]
30pub struct Server {
31    store: HashMap<LowerName, Vec<IpAddr>>,
32}
33
34impl Server {
35    /// Adds a mapping from a DNS record to some IP addresses.
36    ///
37    /// # Example
38    ///
39    /// ```
40    /// # use std::net::{IpAddr, Ipv4Addr};
41    /// # use dns_mock_server::Server;
42    /// let mut server = Server::default();
43    /// let records = vec![IpAddr::V4(Ipv4Addr::LOCALHOST)];
44    ///
45    /// server.add_records("example.com", records).expect("Invalid hostname");
46    /// ```
47    pub fn add_records(&mut self, name: &str, records: Vec<IpAddr>) -> Result<(), ProtoError> {
48        let name = LowerName::from_str(name)?;
49
50        self.store.insert(name, records);
51
52        Ok(())
53    }
54
55    /// Starts the mock server on the given [`UdpSocket`].
56    ///
57    /// This should be run in a background task using a method such as [`tokio::spawn`].
58    pub async fn start(self, socket: UdpSocket) -> Result<(), ProtoError> {
59        let mut server = ServerFuture::new(self);
60
61        server.register_socket(socket);
62        server.block_until_done().await?;
63
64        Ok(())
65    }
66}
67
68#[async_trait]
69impl RequestHandler for Server {
70    async fn handle_request<R: ResponseHandler>(
71        &self,
72        request: &Request,
73        mut response_handler: R,
74    ) -> ResponseInfo {
75        let builder = MessageResponseBuilder::from_message_request(request);
76
77        let mut header = Header::response_from_request(request.header());
78        header.set_authoritative(true);
79
80        let name = request.queries()[0].name();
81
82        if let Some(entries) = self.store.get(name) {
83            let records: Vec<_> = entries
84                .iter()
85                .map(|entry| match entry {
86                    IpAddr::V4(ipv4) => RData::A(A::from(*ipv4)),
87                    IpAddr::V6(ipv6) => RData::AAAA(AAAA::from(*ipv6)),
88                })
89                .map(|rdata| Record::from_rdata(name.into(), 60, rdata))
90                .collect();
91
92            let response = builder.build(header, records.iter(), &[], &[], &[]);
93            response_handler.send_response(response).await.unwrap()
94        } else {
95            header.set_response_code(ResponseCode::ServFail);
96
97            let response = builder.build_no_records(header);
98            response_handler.send_response(response).await.unwrap()
99        }
100    }
101}