1use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
2
/// Failure modes of a DNS query made through [`DnsResolver`].
///
/// `Hash` is derived alongside `Eq` so errors can be used as map/set keys
/// (for example when caching negative answers).
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum DnsError {
    /// The queried name does not exist (`NXDOMAIN`).
    NxDomain,
    /// The lookup returned no records of the requested type.
    NoRecords,
    /// A temporary DNS failure; the same query may succeed if retried.
    TempFail,
}
14
15impl std::fmt::Display for DnsError {
16 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
17 match self {
18 DnsError::NxDomain => write!(f, "NXDOMAIN"),
19 DnsError::NoRecords => write!(f, "no records"),
20 DnsError::TempFail => write!(f, "temporary DNS failure"),
21 }
22 }
23}
24
25impl std::error::Error for DnsError {}
26
/// A single mail-exchanger (MX) record.
///
/// The derived ordering compares `preference` first and `exchange` second.
/// Sorting a `Vec<MxRecord>` therefore puts the most-preferred (lowest
/// value) exchanger first. Keep the field order if you rely on that.
#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct MxRecord {
    /// MX preference value; lower values are preferred (RFC 5321 §5.1).
    pub preference: u16,
    /// Host name of the mail exchanger.
    pub exchange: String,
}
33
34pub trait DnsResolver: Send + Sync {
42 fn query_txt(&self, name: &str) -> impl std::future::Future<Output = Result<Vec<String>, DnsError>> + Send;
43 fn query_a(&self, name: &str) -> impl std::future::Future<Output = Result<Vec<Ipv4Addr>, DnsError>> + Send;
44 fn query_aaaa(&self, name: &str) -> impl std::future::Future<Output = Result<Vec<Ipv6Addr>, DnsError>> + Send;
45 fn query_mx(&self, name: &str) -> impl std::future::Future<Output = Result<Vec<MxRecord>, DnsError>> + Send;
46 fn query_ptr(&self, ip: &IpAddr) -> impl std::future::Future<Output = Result<Vec<String>, DnsError>> + Send;
47 fn query_exists(&self, name: &str) -> impl std::future::Future<Output = Result<bool, DnsError>> + Send;
48}
49
50impl<R: DnsResolver> DnsResolver for &R {
53 async fn query_txt(&self, name: &str) -> Result<Vec<String>, DnsError> {
54 <R as DnsResolver>::query_txt(self, name).await
55 }
56 async fn query_a(&self, name: &str) -> Result<Vec<Ipv4Addr>, DnsError> {
57 <R as DnsResolver>::query_a(self, name).await
58 }
59 async fn query_aaaa(&self, name: &str) -> Result<Vec<Ipv6Addr>, DnsError> {
60 <R as DnsResolver>::query_aaaa(self, name).await
61 }
62 async fn query_mx(&self, name: &str) -> Result<Vec<MxRecord>, DnsError> {
63 <R as DnsResolver>::query_mx(self, name).await
64 }
65 async fn query_ptr(&self, ip: &IpAddr) -> Result<Vec<String>, DnsError> {
66 <R as DnsResolver>::query_ptr(self, ip).await
67 }
68 async fn query_exists(&self, name: &str) -> Result<bool, DnsError> {
69 <R as DnsResolver>::query_exists(self, name).await
70 }
71}
72
#[cfg(test)]
pub mod mock {
    //! In-memory [`DnsResolver`] with canned answers, for unit tests.

    use super::*;
    use std::collections::HashMap;

    /// Resolver that answers each query from a per-record-type map.
    ///
    /// Name keys are stored lower-cased, so lookups are case-insensitive.
    /// PTR keys are stored in the canonical text form produced by
    /// `IpAddr::to_string()`, which is exactly what `query_ptr` looks up.
    /// A key with no entry answers `Err(DnsError::NxDomain)`.
    #[derive(Debug, Default, Clone)]
    pub struct MockResolver {
        /// TXT answers, keyed by lower-cased name.
        pub txt: HashMap<String, Result<Vec<String>, DnsError>>,
        /// A answers, keyed by lower-cased name.
        pub a: HashMap<String, Result<Vec<Ipv4Addr>, DnsError>>,
        /// AAAA answers, keyed by lower-cased name.
        pub aaaa: HashMap<String, Result<Vec<Ipv6Addr>, DnsError>>,
        /// MX answers, keyed by lower-cased name.
        pub mx: HashMap<String, Result<Vec<MxRecord>, DnsError>>,
        /// PTR answers, keyed by canonical IP text (see `add_ptr`).
        pub ptr: HashMap<String, Result<Vec<String>, DnsError>>,
    }

    impl MockResolver {
        /// Creates a resolver with no canned answers.
        pub fn new() -> Self {
            Self::default()
        }

        /// Makes `query_txt(name)` return `records`.
        pub fn add_txt(&mut self, name: &str, records: Vec<String>) {
            self.txt.insert(name.to_lowercase(), Ok(records));
        }

        /// Makes `query_txt(name)` fail with `err`.
        pub fn add_txt_err(&mut self, name: &str, err: DnsError) {
            self.txt.insert(name.to_lowercase(), Err(err));
        }

        /// Makes `query_a(name)` return `addrs`.
        pub fn add_a(&mut self, name: &str, addrs: Vec<Ipv4Addr>) {
            self.a.insert(name.to_lowercase(), Ok(addrs));
        }

        /// Makes `query_a(name)` fail with `err`.
        pub fn add_a_err(&mut self, name: &str, err: DnsError) {
            self.a.insert(name.to_lowercase(), Err(err));
        }

        /// Makes `query_aaaa(name)` return `addrs`.
        pub fn add_aaaa(&mut self, name: &str, addrs: Vec<Ipv6Addr>) {
            self.aaaa.insert(name.to_lowercase(), Ok(addrs));
        }

        /// Makes `query_aaaa(name)` fail with `err`.
        pub fn add_aaaa_err(&mut self, name: &str, err: DnsError) {
            self.aaaa.insert(name.to_lowercase(), Err(err));
        }

        /// Makes `query_mx(name)` return `records`.
        pub fn add_mx(&mut self, name: &str, records: Vec<MxRecord>) {
            self.mx.insert(name.to_lowercase(), Ok(records));
        }

        /// Makes `query_mx(name)` fail with `err`.
        pub fn add_mx_err(&mut self, name: &str, err: DnsError) {
            self.mx.insert(name.to_lowercase(), Err(err));
        }

        /// Makes `query_ptr` for the address written as `ip_str` return `names`.
        ///
        /// Any textual form of the address is accepted (e.g. `"2001:DB8:0:0::1"`
        /// matches a query for `2001:db8::1`).
        pub fn add_ptr(&mut self, ip_str: &str, names: Vec<String>) {
            self.ptr.insert(Self::ptr_key(ip_str), Ok(names));
        }

        /// Makes `query_ptr` for the address written as `ip_str` fail with `err`.
        pub fn add_ptr_err(&mut self, ip_str: &str, err: DnsError) {
            self.ptr.insert(Self::ptr_key(ip_str), Err(err));
        }

        /// Normalises an address string to the key `query_ptr` looks up.
        ///
        /// `query_ptr` keys on `IpAddr::to_string()`, which is canonical
        /// (lower-case hex, `::` compression). Parseable addresses are
        /// therefore round-tripped through `IpAddr`. Anything else is
        /// lower-cased, matching the name maps.
        fn ptr_key(ip_str: &str) -> String {
            match ip_str.parse::<IpAddr>() {
                Ok(ip) => ip.to_string(),
                Err(_) => ip_str.to_lowercase(),
            }
        }

        /// Returns a clone of the canned answer for `key` (compared
        /// lower-cased), or `Err(DnsError::NxDomain)` when there is none.
        fn lookup<T: Clone>(
            map: &HashMap<String, Result<Vec<T>, DnsError>>,
            key: &str,
        ) -> Result<Vec<T>, DnsError> {
            match map.get(&key.to_lowercase()) {
                Some(answer) => answer.clone(),
                None => Err(DnsError::NxDomain),
            }
        }
    }

    impl DnsResolver for MockResolver {
        async fn query_txt(&self, name: &str) -> Result<Vec<String>, DnsError> {
            Self::lookup(&self.txt, name)
        }

        async fn query_a(&self, name: &str) -> Result<Vec<Ipv4Addr>, DnsError> {
            Self::lookup(&self.a, name)
        }

        async fn query_aaaa(&self, name: &str) -> Result<Vec<Ipv6Addr>, DnsError> {
            Self::lookup(&self.aaaa, name)
        }

        async fn query_mx(&self, name: &str) -> Result<Vec<MxRecord>, DnsError> {
            Self::lookup(&self.mx, name)
        }

        async fn query_ptr(&self, ip: &IpAddr) -> Result<Vec<String>, DnsError> {
            // `IpAddr`'s Display output is already canonical and lower-case,
            // matching the keys written by `ptr_key`.
            Self::lookup(&self.ptr, &ip.to_string())
        }

        async fn query_exists(&self, name: &str) -> Result<bool, DnsError> {
            // An empty answer, NXDOMAIN and "no records" all mean "does not
            // exist". Only other failures (e.g. TempFail) propagate.
            match self.query_a(name).await {
                Ok(addrs) => Ok(!addrs.is_empty()),
                Err(DnsError::NxDomain | DnsError::NoRecords) => Ok(false),
                Err(e) => Err(e),
            }
        }
    }
}
181
#[cfg(test)]
mod tests {
    use super::mock::MockResolver;
    use super::*;

    #[tokio::test]
    async fn trait_has_all_required_methods() {
        let resolver = MockResolver::new();
        let _: Result<Vec<String>, DnsError> = resolver.query_txt("example.com").await;
        let _: Result<Vec<Ipv4Addr>, DnsError> = resolver.query_a("example.com").await;
        let _: Result<Vec<Ipv6Addr>, DnsError> = resolver.query_aaaa("example.com").await;
        let _: Result<Vec<MxRecord>, DnsError> = resolver.query_mx("example.com").await;
        let ip: IpAddr = "1.2.3.4".parse().unwrap();
        let _: Result<Vec<String>, DnsError> = resolver.query_ptr(&ip).await;
        let _: Result<bool, DnsError> = resolver.query_exists("example.com").await;
    }

    #[tokio::test]
    async fn query_txt_returns_records() {
        let mut resolver = MockResolver::new();
        resolver.add_txt("example.com", vec!["v=spf1 -all".to_string()]);
        let result = resolver.query_txt("example.com").await.unwrap();
        assert_eq!(result, vec!["v=spf1 -all"]);
    }

    #[tokio::test]
    async fn query_a_returns_addresses() {
        let mut resolver = MockResolver::new();
        resolver.add_a("example.com", vec!["1.2.3.4".parse().unwrap()]);
        let result = resolver.query_a("example.com").await.unwrap();
        assert_eq!(result, vec!["1.2.3.4".parse::<Ipv4Addr>().unwrap()]);
    }

    #[tokio::test]
    async fn query_aaaa_returns_addresses() {
        let mut resolver = MockResolver::new();
        resolver.add_aaaa("example.com", vec!["::1".parse().unwrap()]);
        let result = resolver.query_aaaa("example.com").await.unwrap();
        assert_eq!(result, vec!["::1".parse::<Ipv6Addr>().unwrap()]);
    }

    #[tokio::test]
    async fn query_mx_returns_records() {
        let mut resolver = MockResolver::new();
        resolver.add_mx(
            "example.com",
            vec![MxRecord { preference: 10, exchange: "mail.example.com".into() }],
        );
        let result = resolver.query_mx("example.com").await.unwrap();
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].preference, 10);
        assert_eq!(result[0].exchange, "mail.example.com");
    }

    #[tokio::test]
    async fn query_ptr_returns_names() {
        let mut resolver = MockResolver::new();
        resolver.add_ptr("1.2.3.4", vec!["host.example.com".into()]);
        let ip: IpAddr = "1.2.3.4".parse().unwrap();
        let result = resolver.query_ptr(&ip).await.unwrap();
        assert_eq!(result, vec!["host.example.com"]);
    }

    #[tokio::test]
    async fn query_ptr_returns_names_for_ipv6() {
        let mut resolver = MockResolver::new();
        resolver.add_ptr("2001:db8::1", vec!["v6.example.com".into()]);
        let ip: IpAddr = "2001:db8::1".parse().unwrap();
        let result = resolver.query_ptr(&ip).await.unwrap();
        assert_eq!(result, vec!["v6.example.com"]);
    }

    #[tokio::test]
    async fn query_ptr_nxdomain_for_unknown_ip() {
        let resolver = MockResolver::new();
        let ip: IpAddr = "192.0.2.99".parse().unwrap();
        assert_eq!(resolver.query_ptr(&ip).await.unwrap_err(), DnsError::NxDomain);
    }

    #[tokio::test]
    async fn query_exists_returns_bool() {
        let mut resolver = MockResolver::new();
        resolver.add_a("example.com", vec!["1.2.3.4".parse().unwrap()]);
        assert!(resolver.query_exists("example.com").await.unwrap());
    }

    #[tokio::test]
    async fn query_exists_false_for_nxdomain() {
        let resolver = MockResolver::new();
        assert!(!resolver.query_exists("nonexistent.example.com").await.unwrap());
    }

    #[tokio::test]
    async fn query_exists_false_for_empty_answer() {
        let mut resolver = MockResolver::new();
        resolver.add_a("empty.example.com", vec![]);
        assert!(!resolver.query_exists("empty.example.com").await.unwrap());
    }

    #[tokio::test]
    async fn query_exists_false_for_no_records() {
        let mut resolver = MockResolver::new();
        resolver.add_a_err("norecords.example.com", DnsError::NoRecords);
        assert!(!resolver.query_exists("norecords.example.com").await.unwrap());
    }

    #[tokio::test]
    async fn query_exists_propagates_tempfail() {
        let mut resolver = MockResolver::new();
        resolver.add_a_err("fail.example.com", DnsError::TempFail);
        assert_eq!(
            resolver.query_exists("fail.example.com").await,
            Err(DnsError::TempFail)
        );
    }

    #[tokio::test]
    async fn dns_error_nxdomain() {
        let resolver = MockResolver::new();
        assert_eq!(
            resolver.query_txt("nope.example.com").await.unwrap_err(),
            DnsError::NxDomain
        );
    }

    #[tokio::test]
    async fn dns_error_tempfail() {
        let mut resolver = MockResolver::new();
        resolver.add_txt_err("fail.example.com", DnsError::TempFail);
        assert_eq!(
            resolver.query_txt("fail.example.com").await.unwrap_err(),
            DnsError::TempFail
        );
    }

    #[tokio::test]
    async fn dns_error_no_records() {
        let mut resolver = MockResolver::new();
        resolver.add_a_err("empty.example.com", DnsError::NoRecords);
        assert_eq!(
            resolver.query_a("empty.example.com").await.unwrap_err(),
            DnsError::NoRecords
        );
    }

    #[test]
    fn dns_error_display() {
        assert_eq!(DnsError::NxDomain.to_string(), "NXDOMAIN");
        assert_eq!(DnsError::NoRecords.to_string(), "no records");
        assert_eq!(DnsError::TempFail.to_string(), "temporary DNS failure");
    }

    #[tokio::test]
    async fn blanket_impl_ref_resolver() {
        let mut resolver = MockResolver::new();
        resolver.add_txt("example.com", vec!["test".into()]);
        resolver.add_a("example.com", vec![Ipv4Addr::new(1, 2, 3, 4)]);
        resolver.add_aaaa("example.com", vec![Ipv6Addr::LOCALHOST]);
        resolver.add_mx(
            "example.com",
            vec![MxRecord { preference: 5, exchange: "mx.example.com".into() }],
        );
        resolver.add_ptr("1.2.3.4", vec!["host.example.com".into()]);
        let r: &MockResolver = &resolver;
        let ip = IpAddr::V4(Ipv4Addr::new(1, 2, 3, 4));

        // A plain `r.query_txt(..)` on a `&MockResolver` auto-refs straight to
        // the `MockResolver` impl. The fully-qualified path forces dispatch
        // through the `impl DnsResolver for &R` blanket impl under test.
        let txt = <&MockResolver as DnsResolver>::query_txt(&r, "example.com").await;
        assert_eq!(txt.unwrap(), vec!["test"]);
        let a = <&MockResolver as DnsResolver>::query_a(&r, "example.com").await;
        assert_eq!(a.unwrap(), vec![Ipv4Addr::new(1, 2, 3, 4)]);
        let aaaa = <&MockResolver as DnsResolver>::query_aaaa(&r, "example.com").await;
        assert_eq!(aaaa.unwrap(), vec![Ipv6Addr::LOCALHOST]);
        let mx = <&MockResolver as DnsResolver>::query_mx(&r, "example.com").await;
        assert_eq!(mx.unwrap()[0].exchange, "mx.example.com");
        let ptr = <&MockResolver as DnsResolver>::query_ptr(&r, &ip).await;
        assert_eq!(ptr.unwrap(), vec!["host.example.com"]);
        let exists = <&MockResolver as DnsResolver>::query_exists(&r, "example.com").await;
        assert!(exists.unwrap());
    }

    #[tokio::test]
    async fn mock_case_insensitive() {
        let mut resolver = MockResolver::new();
        resolver.add_txt("EXAMPLE.COM", vec!["data".into()]);
        let result = resolver.query_txt("example.com").await.unwrap();
        assert_eq!(result, vec!["data"]);
        let upper = resolver.query_txt("Example.Com").await.unwrap();
        assert_eq!(upper, vec!["data"]);
    }
}