Skip to main content

e2etest_dns/
lib.rs

1/*
2 * Copyright 2025-present ScyllaDB
3 * SPDX-License-Identifier: MIT OR Apache-2.0
4 */
5
//! This crate provides a DNS server for e2etest tests. The server is driven by an actor whose
//! handle is a `tokio::sync::mpsc::Sender` of `enum Dns` messages. The crate also provides a
//! `trait DnsExt` with helper methods for sending those messages to the actor.
9
10use async_backtrace::frame;
11use async_backtrace::framed;
12use hickory_server::proto::rr::DNSClass;
13use hickory_server::proto::rr::LowerName;
14use hickory_server::proto::rr::Name;
15use hickory_server::proto::rr::RData;
16use hickory_server::proto::rr::Record;
17use hickory_server::proto::rr::RecordType;
18use hickory_server::proto::rr::RrKey;
19use hickory_server::proto::rr::rdata::a::A;
20use hickory_server::proto::rr::rdata::soa::SOA;
21use hickory_server::server::Server;
22use hickory_server::store::in_memory::InMemoryZoneHandler;
23use hickory_server::zone_handler::AxfrPolicy;
24use hickory_server::zone_handler::Catalog;
25use hickory_server::zone_handler::ZoneType;
26use std::net::Ipv4Addr;
27use std::str::FromStr;
28use std::sync::Arc;
29use tokio::net::UdpSocket;
30use tokio::sync::mpsc;
31use tokio::sync::oneshot;
32use tracing::Instrument;
33use tracing::debug;
34use tracing::debug_span;
35
36/// Messages for the DNS actor.
37pub enum Dns {
38    Version { tx: oneshot::Sender<String> },
39    Domain { tx: oneshot::Sender<String> },
40    Remove { name: String },
41    Upsert { name: String, ip: Ipv4Addr },
42}
43
44/// Extension trait for `mpsc::Sender<Dns>` to provide helper methods to send messages to the DNS
45/// actor.
46pub trait DnsExt {
47    /// Returns the version of the DNS server.
48    fn version(&self) -> impl Future<Output = String>;
49
50    /// Returns the domain name of the DNS server.
51    fn domain(&self) -> impl Future<Output = String>;
52
53    /// Remove an A DNS record with the given name.
54    fn remove(&self, name: String) -> impl Future<Output = ()>;
55
56    /// Upserts an A DNS record with the given name and IP address.
57    fn upsert(&self, name: String, ip: Ipv4Addr) -> impl Future<Output = ()>;
58}
59
60impl DnsExt for mpsc::Sender<Dns> {
61    #[framed]
62    async fn version(&self) -> String {
63        let (tx, rx) = oneshot::channel();
64        self.send(Dns::Version { tx })
65            .await
66            .expect("DnsExt::version: internal actor should receive request");
67        rx.await
68            .expect("DnsExt::version: internal actor should send response")
69    }
70
71    #[framed]
72    async fn domain(&self) -> String {
73        let (tx, rx) = oneshot::channel();
74        self.send(Dns::Domain { tx })
75            .await
76            .expect("DnsExt::domain: internal actor should receive request");
77        rx.await
78            .expect("DnsExt::domain: internal actor should send response")
79    }
80
81    #[framed]
82    async fn remove(&self, name: String) {
83        self.send(Dns::Remove { name })
84            .await
85            .expect("DnsExt::remove: internal actor should receive request");
86    }
87
88    #[framed]
89    async fn upsert(&self, name: String, ip: Ipv4Addr) {
90        self.send(Dns::Upsert { name, ip })
91            .await
92            .expect("DnsExt::upsert: internal actor should receive request");
93    }
94}
95#[framed]
96/// Starts the DNS server on the given IP address.
97pub async fn new(ip: Ipv4Addr) -> mpsc::Sender<Dns> {
98    assert!(ip.is_loopback(), "DNS server should listen on a localhost");
99
100    let (tx, mut rx) = mpsc::channel(10);
101
102    let mut state = State::new().await;
103
104    let socket = UdpSocket::bind((ip, 53))
105        .await
106        .expect("dns: failed to bind UDP socket");
107
108    let mut catalog = Catalog::new();
109    catalog.upsert(
110        LowerName::from_str(ZONE).unwrap(),
111        vec![state.authority.clone()],
112    );
113    let mut server = Server::new(catalog);
114    server.register_socket(socket);
115
116    tokio::spawn(
117        frame!(async move {
118            debug!("starting");
119
120            while let Some(msg) = rx.recv().await {
121                process(msg, &mut state).await;
122            }
123
124            server
125                .shutdown_gracefully()
126                .await
127                .expect("stop: failed to shutdown server gracefully");
128
129            debug!("stopped");
130        })
131        .instrument(debug_span!("dns")),
132    );
133
134    tx
135}
136
137struct State {
138    version: String,
139    authority: Arc<InMemoryZoneHandler>,
140    serial: u32,
141}
142
/// Time-to-live, in seconds, used for every record the actor publishes (SOA and A records).
const TTL: u32 = 60;
/// Fully qualified name of the served zone; the trailing dot denotes the DNS root.
const ZONE: &str = "validator.test.";
145
146impl State {
147    #[framed]
148    async fn new() -> Self {
149        let version = format!("hickory-server-{}", hickory_server::version());
150
151        let authority = Arc::new(InMemoryZoneHandler::empty(
152            Name::from_str(ZONE).unwrap(),
153            ZoneType::Primary,
154            AxfrPolicy::Deny,
155        ));
156        let mut soa = Record::from_rdata(
157            Name::from_str(ZONE).unwrap(),
158            TTL,
159            RData::SOA(SOA::new(
160                Name::from_str(ZONE).unwrap(),
161                Name::new(),
162                0,
163                0,
164                0,
165                0,
166                0,
167            )),
168        );
169        soa.dns_class = DNSClass::IN;
170        authority.upsert(soa, 0).await;
171
172        Self {
173            version,
174            authority,
175            serial: 0,
176        }
177    }
178}
179
180#[framed]
181async fn process(msg: Dns, state: &mut State) {
182    match msg {
183        Dns::Version { tx } => {
184            tx.send(state.version.clone())
185                .expect("process Dns::Version: failed to send a response");
186        }
187
188        Dns::Domain { tx } => {
189            tx.send(ZONE[..ZONE.len() - 1].to_string())
190                .expect("process Dns::Domain: failed to send a response");
191        }
192
193        Dns::Remove { name } => {
194            remove(name, state).await;
195        }
196
197        Dns::Upsert { name, ip } => {
198            upsert(name, ip, state).await;
199        }
200    }
201}
202
203#[framed]
204async fn remove(name: String, state: &mut State) {
205    state.authority.records_mut().await.remove(&RrKey::new(
206        LowerName::from_str(&format!("{name}.{ZONE}")).expect("remove: failed to parse name"),
207        RecordType::A,
208    ));
209}
210
211#[framed]
212async fn upsert(name: String, ip: Ipv4Addr, state: &mut State) {
213    let name = Name::from_str(&format!("{name}.{ZONE}")).expect("upsert: failed to parse name");
214
215    let serial = state.serial;
216    state.serial += 1;
217
218    let octets = ip.octets();
219    let mut record = Record::from_rdata(
220        name,
221        TTL,
222        RData::A(A::new(octets[0], octets[1], octets[2], octets[3])),
223    );
224    record.dns_class = DNSClass::IN;
225
226    state.authority.upsert(record, serial).await;
227}