//! `email_auth/common/dns.rs` — async DNS resolver abstraction, DNS error
//! classification, and a test-only mock resolver.

use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};

/// Classification of a failed DNS query.
///
/// Callers must be able to tell these three cases apart: SPF void-lookup
/// tracking and error propagation both depend on the distinction.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum DnsError {
    /// NXDOMAIN: the queried name does not exist.
    NxDomain,
    /// The name exists but carries no records of the requested type.
    NoRecords,
    /// A transient failure such as a timeout, SERVFAIL, or network error.
    TempFail,
}

impl std::fmt::Display for DnsError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = match self {
            Self::NxDomain => "NXDOMAIN",
            Self::NoRecords => "no records",
            Self::TempFail => "temporary DNS failure",
        };
        f.write_str(message)
    }
}

/// Uses the default `Error` methods; lets `DnsError` be boxed as `dyn Error`.
impl std::error::Error for DnsError {}

/// One MX answer: a mail exchanger host together with its preference value.
///
/// The values are stored exactly as received; this type neither sorts nor
/// interprets them.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct MxRecord {
    /// Preference value carried in the MX record.
    pub preference: u16,
    /// Host name of the mail exchanger.
    pub exchange: String,
}

34/// Abstract async DNS resolver trait.
35///
36/// DNS caching is the caller's responsibility. Implement this trait
37/// with a caching layer at the resolver level.
38///
39/// Methods return `impl Future + Send` to allow use in `Pin<Box<dyn Future + Send>>` contexts
40/// (required for SPF async recursion via include/redirect).
41pub trait DnsResolver: Send + Sync {
42    fn query_txt(&self, name: &str) -> impl std::future::Future<Output = Result<Vec<String>, DnsError>> + Send;
43    fn query_a(&self, name: &str) -> impl std::future::Future<Output = Result<Vec<Ipv4Addr>, DnsError>> + Send;
44    fn query_aaaa(&self, name: &str) -> impl std::future::Future<Output = Result<Vec<Ipv6Addr>, DnsError>> + Send;
45    fn query_mx(&self, name: &str) -> impl std::future::Future<Output = Result<Vec<MxRecord>, DnsError>> + Send;
46    fn query_ptr(&self, ip: &IpAddr) -> impl std::future::Future<Output = Result<Vec<String>, DnsError>> + Send;
47    fn query_exists(&self, name: &str) -> impl std::future::Future<Output = Result<bool, DnsError>> + Send;
48}
49
50/// Blanket impl: allow passing `&R` where `R: DnsResolver` is expected.
51/// Uses UFCS to avoid infinite recursion.
52impl<R: DnsResolver> DnsResolver for &R {
53    async fn query_txt(&self, name: &str) -> Result<Vec<String>, DnsError> {
54        <R as DnsResolver>::query_txt(self, name).await
55    }
56    async fn query_a(&self, name: &str) -> Result<Vec<Ipv4Addr>, DnsError> {
57        <R as DnsResolver>::query_a(self, name).await
58    }
59    async fn query_aaaa(&self, name: &str) -> Result<Vec<Ipv6Addr>, DnsError> {
60        <R as DnsResolver>::query_aaaa(self, name).await
61    }
62    async fn query_mx(&self, name: &str) -> Result<Vec<MxRecord>, DnsError> {
63        <R as DnsResolver>::query_mx(self, name).await
64    }
65    async fn query_ptr(&self, ip: &IpAddr) -> Result<Vec<String>, DnsError> {
66        <R as DnsResolver>::query_ptr(self, ip).await
67    }
68    async fn query_exists(&self, name: &str) -> Result<bool, DnsError> {
69        <R as DnsResolver>::query_exists(self, name).await
70    }
71}
72
/// Mock DNS resolver for testing. Configure responses per domain.
///
/// Domain names are matched ASCII-case-insensitively, and PTR entries are
/// matched by IP address value. Anything not configured answers
/// [`DnsError::NxDomain`].
#[cfg(test)]
pub mod mock {
    use super::*;
    use std::collections::HashMap;

    /// In-memory [`DnsResolver`] with one answer table per record type.
    ///
    /// Keys of the name-based tables are ASCII-lowercased domain names. Keys
    /// of `ptr` are IP addresses in the textual form produced by `IpAddr`'s
    /// `Display`. The `add_*` helpers normalize keys automatically; entries
    /// inserted directly into the public maps must already use these key
    /// forms to be found.
    #[derive(Debug, Default, Clone)]
    pub struct MockResolver {
        pub txt: HashMap<String, Result<Vec<String>, DnsError>>,
        pub a: HashMap<String, Result<Vec<Ipv4Addr>, DnsError>>,
        pub aaaa: HashMap<String, Result<Vec<Ipv6Addr>, DnsError>>,
        pub mx: HashMap<String, Result<Vec<MxRecord>, DnsError>>,
        pub ptr: HashMap<String, Result<Vec<String>, DnsError>>,
    }

    impl MockResolver {
        /// Creates a resolver with no configured answers.
        pub fn new() -> Self {
            Self::default()
        }

        /// Answers TXT queries for `name` with `records`.
        pub fn add_txt(&mut self, name: &str, records: Vec<String>) {
            self.txt.insert(Self::name_key(name), Ok(records));
        }

        /// Makes TXT queries for `name` fail with `err`.
        pub fn add_txt_err(&mut self, name: &str, err: DnsError) {
            self.txt.insert(Self::name_key(name), Err(err));
        }

        /// Answers A queries for `name` with `addrs`.
        pub fn add_a(&mut self, name: &str, addrs: Vec<Ipv4Addr>) {
            self.a.insert(Self::name_key(name), Ok(addrs));
        }

        /// Makes A queries for `name` fail with `err`.
        pub fn add_a_err(&mut self, name: &str, err: DnsError) {
            self.a.insert(Self::name_key(name), Err(err));
        }

        /// Answers AAAA queries for `name` with `addrs`.
        pub fn add_aaaa(&mut self, name: &str, addrs: Vec<Ipv6Addr>) {
            self.aaaa.insert(Self::name_key(name), Ok(addrs));
        }

        /// Makes AAAA queries for `name` fail with `err`.
        pub fn add_aaaa_err(&mut self, name: &str, err: DnsError) {
            self.aaaa.insert(Self::name_key(name), Err(err));
        }

        /// Answers MX queries for `name` with `records`.
        pub fn add_mx(&mut self, name: &str, records: Vec<MxRecord>) {
            self.mx.insert(Self::name_key(name), Ok(records));
        }

        /// Makes MX queries for `name` fail with `err`.
        pub fn add_mx_err(&mut self, name: &str, err: DnsError) {
            self.mx.insert(Self::name_key(name), Err(err));
        }

        /// Answers PTR queries for the address `ip_str` with `names`.
        ///
        /// `ip_str` is canonicalized, so equivalent spellings such as
        /// `2001:DB8::1` and `2001:db8:0::1` both match a query for that
        /// address. A string that is not a valid IP address is stored verbatim.
        pub fn add_ptr(&mut self, ip_str: &str, names: Vec<String>) {
            self.ptr.insert(Self::ptr_key(ip_str), Ok(names));
        }

        /// Makes PTR queries for the address `ip_str` fail with `err`.
        pub fn add_ptr_err(&mut self, ip_str: &str, err: DnsError) {
            self.ptr.insert(Self::ptr_key(ip_str), Err(err));
        }

        /// Map key for a domain name. DNS compares names case-insensitively
        /// for ASCII letters only (RFC 4343), so Unicode case folding is
        /// deliberately not used.
        fn name_key(name: &str) -> String {
            name.to_ascii_lowercase()
        }

        /// Map key for an IP address string. It is rendered through `IpAddr`
        /// so that it agrees with the `ip.to_string()` key built by `query_ptr`.
        fn ptr_key(ip_str: &str) -> String {
            match ip_str.parse::<IpAddr>() {
                Ok(ip) => ip.to_string(),
                Err(_) => ip_str.to_string(),
            }
        }

        /// Returns a copy of the configured answer for the already-normalized
        /// `key`, or `NxDomain` when nothing is configured.
        fn lookup<T: Clone>(
            map: &HashMap<String, Result<Vec<T>, DnsError>>,
            key: &str,
        ) -> Result<Vec<T>, DnsError> {
            map.get(key).cloned().unwrap_or(Err(DnsError::NxDomain))
        }
    }

    impl DnsResolver for MockResolver {
        async fn query_txt(&self, name: &str) -> Result<Vec<String>, DnsError> {
            Self::lookup(&self.txt, &Self::name_key(name))
        }

        async fn query_a(&self, name: &str) -> Result<Vec<Ipv4Addr>, DnsError> {
            Self::lookup(&self.a, &Self::name_key(name))
        }

        async fn query_aaaa(&self, name: &str) -> Result<Vec<Ipv6Addr>, DnsError> {
            Self::lookup(&self.aaaa, &Self::name_key(name))
        }

        async fn query_mx(&self, name: &str) -> Result<Vec<MxRecord>, DnsError> {
            Self::lookup(&self.mx, &Self::name_key(name))
        }

        async fn query_ptr(&self, ip: &IpAddr) -> Result<Vec<String>, DnsError> {
            Self::lookup(&self.ptr, &ip.to_string())
        }

        /// Existence is decided by an A lookup. Any address means `true`; an
        /// empty answer, `NoRecords`, or `NxDomain` means `false`. Other
        /// errors (e.g. `TempFail`) are propagated.
        async fn query_exists(&self, name: &str) -> Result<bool, DnsError> {
            match self.query_a(name).await {
                Ok(addrs) => Ok(!addrs.is_empty()),
                Err(DnsError::NxDomain | DnsError::NoRecords) => Ok(false),
                Err(e) => Err(e),
            }
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use super::mock::MockResolver;

    // CHK-157: Abstract DNS resolver trait
    // CHK-158: Support async DNS queries
    // CHK-159: Methods needed
    #[tokio::test]
    async fn trait_has_all_required_methods() {
        let resolver = MockResolver::new();
        // Verify all 6 methods exist and return correct types
        let _: Result<Vec<String>, DnsError> = resolver.query_txt("example.com").await;
        let _: Result<Vec<Ipv4Addr>, DnsError> = resolver.query_a("example.com").await;
        let _: Result<Vec<Ipv6Addr>, DnsError> = resolver.query_aaaa("example.com").await;
        let _: Result<Vec<MxRecord>, DnsError> = resolver.query_mx("example.com").await;
        let ip: IpAddr = "1.2.3.4".parse().unwrap();
        let _: Result<Vec<String>, DnsError> = resolver.query_ptr(&ip).await;
        let _: Result<bool, DnsError> = resolver.query_exists("example.com").await;
    }

    // CHK-160: query_txt
    #[tokio::test]
    async fn query_txt_returns_records() {
        let mut resolver = MockResolver::new();
        resolver.add_txt("example.com", vec!["v=spf1 -all".to_string()]);
        let result = resolver.query_txt("example.com").await.unwrap();
        assert_eq!(result, vec!["v=spf1 -all"]);
    }

    // CHK-161: query_a
    #[tokio::test]
    async fn query_a_returns_addresses() {
        let mut resolver = MockResolver::new();
        resolver.add_a("example.com", vec!["1.2.3.4".parse().unwrap()]);
        let result = resolver.query_a("example.com").await.unwrap();
        assert_eq!(result, vec!["1.2.3.4".parse::<Ipv4Addr>().unwrap()]);
    }

    // CHK-162: query_aaaa
    #[tokio::test]
    async fn query_aaaa_returns_addresses() {
        let mut resolver = MockResolver::new();
        resolver.add_aaaa("example.com", vec!["::1".parse().unwrap()]);
        let result = resolver.query_aaaa("example.com").await.unwrap();
        assert_eq!(result, vec!["::1".parse::<Ipv6Addr>().unwrap()]);
    }

    // CHK-163: query_mx
    #[tokio::test]
    async fn query_mx_returns_records() {
        let mut resolver = MockResolver::new();
        resolver.add_mx(
            "example.com",
            vec![MxRecord { preference: 10, exchange: "mail.example.com".into() }],
        );
        let result = resolver.query_mx("example.com").await.unwrap();
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].preference, 10);
        assert_eq!(result[0].exchange, "mail.example.com");
    }

    // CHK-164: query_ptr
    #[tokio::test]
    async fn query_ptr_returns_names() {
        let mut resolver = MockResolver::new();
        resolver.add_ptr("1.2.3.4", vec!["host.example.com".into()]);
        let ip: IpAddr = "1.2.3.4".parse().unwrap();
        let result = resolver.query_ptr(&ip).await.unwrap();
        assert_eq!(result, vec!["host.example.com"]);
    }

    // CHK-165: query_exists
    #[tokio::test]
    async fn query_exists_returns_bool() {
        let mut resolver = MockResolver::new();
        resolver.add_a("example.com", vec!["1.2.3.4".parse().unwrap()]);
        assert!(resolver.query_exists("example.com").await.unwrap());
    }

    #[tokio::test]
    async fn query_exists_false_for_nxdomain() {
        let resolver = MockResolver::new();
        assert!(!resolver.query_exists("nonexistent.example.com").await.unwrap());
    }

    #[tokio::test]
    async fn query_exists_false_for_no_records() {
        let mut resolver = MockResolver::new();
        resolver.add_a_err("empty.example.com", DnsError::NoRecords);
        assert!(!resolver.query_exists("empty.example.com").await.unwrap());
    }

    #[tokio::test]
    async fn query_exists_false_for_empty_answer() {
        let mut resolver = MockResolver::new();
        resolver.add_a("empty.example.com", Vec::new());
        assert!(!resolver.query_exists("empty.example.com").await.unwrap());
    }

    // Transient failures must not be reported as "does not exist".
    #[tokio::test]
    async fn query_exists_propagates_tempfail() {
        let mut resolver = MockResolver::new();
        resolver.add_a_err("fail.example.com", DnsError::TempFail);
        assert_eq!(
            resolver.query_exists("fail.example.com").await.unwrap_err(),
            DnsError::TempFail
        );
    }

    // CHK-166: DnsError distinguishes NxDomain, NoRecords, TempFail
    #[tokio::test]
    async fn dns_error_nxdomain() {
        let resolver = MockResolver::new();
        assert_eq!(
            resolver.query_txt("nope.example.com").await.unwrap_err(),
            DnsError::NxDomain
        );
    }

    #[tokio::test]
    async fn dns_error_tempfail() {
        let mut resolver = MockResolver::new();
        resolver.add_txt_err("fail.example.com", DnsError::TempFail);
        assert_eq!(
            resolver.query_txt("fail.example.com").await.unwrap_err(),
            DnsError::TempFail
        );
    }

    #[tokio::test]
    async fn dns_error_no_records() {
        let mut resolver = MockResolver::new();
        resolver.add_a_err("empty.example.com", DnsError::NoRecords);
        assert_eq!(
            resolver.query_a("empty.example.com").await.unwrap_err(),
            DnsError::NoRecords
        );
    }

    #[test]
    fn dns_error_display() {
        assert_eq!(DnsError::NxDomain.to_string(), "NXDOMAIN");
        assert_eq!(DnsError::NoRecords.to_string(), "no records");
        assert_eq!(DnsError::TempFail.to_string(), "temporary DNS failure");
    }

    // CHK-167: DNS caching is caller responsibility
    // This is a documentation/design constraint — verified by trait lacking cache methods.
    // The DnsResolver trait has no cache-related methods.

    // Blanket impl test: &R implements DnsResolver.
    //
    // Calling `r.query_txt(..)` directly on a `&MockResolver` resolves to
    // `MockResolver`'s own impl, so it would not touch the `&R` impl at all.
    // Going through a generic `R: DnsResolver` with `R = &MockResolver`
    // forces the blanket impl to be used.
    #[tokio::test]
    async fn blanket_impl_ref_resolver() {
        async fn txt_via<R: DnsResolver>(
            resolver: R,
            name: &str,
        ) -> Result<Vec<String>, DnsError> {
            resolver.query_txt(name).await
        }

        let mut resolver = MockResolver::new();
        resolver.add_txt("example.com", vec!["test".into()]);
        let result = txt_via(&resolver, "example.com").await.unwrap();
        assert_eq!(result, vec!["test"]);
    }

    // Mock: case-insensitive lookup
    #[tokio::test]
    async fn mock_case_insensitive() {
        let mut resolver = MockResolver::new();
        resolver.add_txt("EXAMPLE.COM", vec!["data".into()]);
        let result = resolver.query_txt("example.com").await.unwrap();
        assert_eq!(result, vec!["data"]);
    }

    // Mock: case-insensitivity applies to other record types and to the query side too.
    #[tokio::test]
    async fn mock_case_insensitive_query_side() {
        let mut resolver = MockResolver::new();
        resolver.add_a("Example.Com", vec!["1.2.3.4".parse().unwrap()]);
        let result = resolver.query_a("EXAMPLE.com").await.unwrap();
        assert_eq!(result, vec!["1.2.3.4".parse::<Ipv4Addr>().unwrap()]);
    }
}