Skip to main content

modo/dns/
verifier.rs

1//! Domain ownership verification via DNS TXT and CNAME record lookups.
2
3use std::sync::Arc;
4use std::time::Duration;
5
6use crate::error::{Error, Result};
7
8use super::config::DnsConfig;
9use super::error::DnsError;
10use super::resolver::{DnsResolver, UdpDnsResolver};
11
/// Result of a domain verification check.
///
/// Returned by [`DomainVerifier::verify_domain`], which runs the TXT and
/// CNAME checks concurrently. Each field is `true` only when the
/// corresponding record was found and matched: the TXT value must equal the
/// token exactly (case-sensitive), while the CNAME target is compared after
/// lowercasing and stripping trailing dots.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct DomainStatus {
    /// Whether a TXT record at `{txt_prefix}.{domain}` matched the
    /// expected token.
    pub txt_verified: bool,
    /// Whether the CNAME record at `domain` pointed to the expected target.
    pub cname_verified: bool,
}
25
/// Shared state behind [`DomainVerifier`]; held in an `Arc` so clones of the
/// verifier reuse one resolver and one prefix.
pub(crate) struct Inner {
    /// Resolver that performs the TXT and CNAME lookups.
    pub(crate) resolver: Arc<dyn DnsResolver>,
    /// Label prepended (as `{txt_prefix}.{domain}`) to form the TXT lookup name.
    pub(crate) txt_prefix: String,
}
30
31/// DNS-based domain ownership verification service.
32///
33/// Checks TXT record ownership and CNAME routing via raw UDP DNS queries.
34/// Construct with [`DomainVerifier::from_config`]. The struct is cheap to
35/// clone because it wraps an `Arc` internally.
36///
37/// # Example
38///
39/// ```rust,no_run
40/// # {
41/// use modo::dns::{DnsConfig, DomainVerifier, generate_verification_token};
42///
43/// let config = DnsConfig::new("8.8.8.8:53");
44/// let verifier = DomainVerifier::from_config(&config).unwrap();
45/// let token = generate_verification_token();
46///
47/// // Ask the user to create: _modo-verify.example.com TXT "<token>"
48/// // Then verify:
49/// // let ok = verifier.check_txt("example.com", &token).await?;
50/// # }
51/// ```
52pub struct DomainVerifier {
53    inner: Arc<Inner>,
54}
55
56impl Clone for DomainVerifier {
57    fn clone(&self) -> Self {
58        Self {
59            inner: Arc::clone(&self.inner),
60        }
61    }
62}
63
64impl DomainVerifier {
65    /// Create a new verifier from [`DnsConfig`].
66    ///
67    /// Parses the nameserver address and builds a UDP resolver.
68    ///
69    /// # Errors
70    ///
71    /// Returns an error if the nameserver string is not a valid IP address.
72    pub fn from_config(config: &DnsConfig) -> Result<Self> {
73        let nameserver = config.parse_nameserver()?;
74        let timeout = Duration::from_millis(config.timeout_ms);
75        let resolver = UdpDnsResolver::new(nameserver, timeout);
76
77        Ok(Self {
78            inner: Arc::new(Inner {
79                resolver: Arc::new(resolver),
80                txt_prefix: config.txt_prefix.clone(),
81            }),
82        })
83    }
84
85    /// Create a verifier with a custom resolver and TXT record prefix.
86    ///
87    /// Used by other in-crate modules to build a `DomainVerifier` backed by a
88    /// mock resolver for testing. Only called from `#[cfg(test)]` blocks in
89    /// other modules, so it has zero callers on the lib target — hence
90    /// `allow(dead_code)`. Cannot use `#[cfg(test)]` here because that would
91    /// make it invisible to other modules' test blocks.
92    #[allow(dead_code)]
93    pub(crate) fn with_resolver(
94        resolver: impl DnsResolver + 'static,
95        txt_prefix: impl Into<String>,
96    ) -> Self {
97        Self {
98            inner: Arc::new(Inner {
99                resolver: Arc::new(resolver),
100                txt_prefix: txt_prefix.into(),
101            }),
102        }
103    }
104
105    /// Check whether a TXT record matches the expected verification token.
106    ///
107    /// Looks up `{txt_prefix}.{domain}` and returns `true` if any TXT record
108    /// value equals `expected_token` exactly (case-sensitive). Returns `false`
109    /// when the record exists but no value matches, or when no TXT records
110    /// exist (NXDOMAIN is treated as an empty record set, not an error).
111    ///
112    /// # Errors
113    ///
114    /// Returns [`crate::Error`] with status 400 when `domain` or
115    /// `expected_token` is empty, or a gateway error on network/DNS failure.
116    pub async fn check_txt(&self, domain: &str, expected_token: &str) -> Result<bool> {
117        if domain.is_empty() {
118            return Err(Error::bad_request("domain must not be empty")
119                .chain(DnsError::InvalidInput)
120                .with_code(DnsError::InvalidInput.code()));
121        }
122        if expected_token.is_empty() {
123            return Err(Error::bad_request("token must not be empty")
124                .chain(DnsError::InvalidInput)
125                .with_code(DnsError::InvalidInput.code()));
126        }
127
128        let lookup_domain = format!("{}.{}", self.inner.txt_prefix, domain);
129        let records = self.inner.resolver.resolve_txt(&lookup_domain).await?;
130
131        Ok(records.iter().any(|r| r == expected_token))
132    }
133
134    /// Check whether a CNAME record points to the expected target.
135    ///
136    /// Normalizes both the resolved target and `expected_target` before
137    /// comparing: both are lowercased and any trailing dot is stripped.
138    /// Returns `false` when no CNAME record is present.
139    ///
140    /// # Errors
141    ///
142    /// Returns [`crate::Error`] with status 400 when `domain` or
143    /// `expected_target` is empty, or a gateway error on network/DNS failure.
144    pub async fn check_cname(&self, domain: &str, expected_target: &str) -> Result<bool> {
145        if domain.is_empty() {
146            return Err(Error::bad_request("domain must not be empty")
147                .chain(DnsError::InvalidInput)
148                .with_code(DnsError::InvalidInput.code()));
149        }
150        if expected_target.is_empty() {
151            return Err(Error::bad_request("target must not be empty")
152                .chain(DnsError::InvalidInput)
153                .with_code(DnsError::InvalidInput.code()));
154        }
155
156        let target = self.inner.resolver.resolve_cname(domain).await?;
157
158        match target {
159            Some(resolved) => {
160                let normalized_resolved = normalize_domain(&resolved);
161                let normalized_expected = normalize_domain(expected_target);
162                Ok(normalized_resolved == normalized_expected)
163            }
164            None => Ok(false),
165        }
166    }
167
168    /// Check both TXT ownership and CNAME routing concurrently.
169    ///
170    /// Runs [`check_txt`](Self::check_txt) and
171    /// [`check_cname`](Self::check_cname) in parallel via `tokio::join!`.
172    /// Returns [`DomainStatus`] with individual results.
173    ///
174    /// # Errors
175    ///
176    /// If either check returns a hard error (e.g. network failure) the error
177    /// is propagated and the other result is discarded.
178    pub async fn verify_domain(
179        &self,
180        domain: &str,
181        expected_token: &str,
182        expected_cname: &str,
183    ) -> Result<DomainStatus> {
184        let (txt_result, cname_result) = tokio::join!(
185            self.check_txt(domain, expected_token),
186            self.check_cname(domain, expected_cname),
187        );
188
189        Ok(DomainStatus {
190            txt_verified: txt_result?,
191            cname_verified: cname_result?,
192        })
193    }
194}
195
/// Normalize a domain name for comparison: strip trailing dots, then
/// lowercase ASCII letters.
///
/// DNS names are case-insensitive only for ASCII letters, so the folding is
/// deliberately ASCII-only. Unicode-aware lowercasing would map characters
/// such as U+212A KELVIN SIGN onto ASCII `k` and let two different names
/// compare equal. Non-ASCII characters are therefore left untouched.
///
/// All trailing dots are removed, so `"example.com."` and `"example.com"`
/// normalize to the same string, and the root name `"."` becomes `""`.
fn normalize_domain(domain: &str) -> String {
    // Trim on the borrowed slice first so only one `String` is allocated.
    domain.trim_end_matches('.').to_ascii_lowercase()
}
200
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;
    use std::future::Future;
    use std::pin::Pin;

    /// In-memory resolver: answers from fixed maps and never fails.
    struct MockResolver {
        txt_records: HashMap<String, Vec<String>>,
        cname_records: HashMap<String, String>,
    }

    impl MockResolver {
        /// Resolver with no records at all.
        fn new() -> Self {
            Self {
                txt_records: HashMap::new(),
                cname_records: HashMap::new(),
            }
        }

        /// Register the TXT values returned for the exact name `domain`.
        fn with_txt(mut self, domain: &str, records: Vec<&str>) -> Self {
            self.txt_records.insert(
                domain.to_owned(),
                records.into_iter().map(|s| s.to_owned()).collect(),
            );
            self
        }

        /// Register the CNAME target returned for the exact name `domain`.
        fn with_cname(mut self, domain: &str, target: &str) -> Self {
            self.cname_records
                .insert(domain.to_owned(), target.to_owned());
            self
        }
    }

    impl DnsResolver for MockResolver {
        fn resolve_txt(
            &self,
            domain: &str,
        ) -> Pin<Box<dyn Future<Output = Result<Vec<String>>> + Send + '_>> {
            // Unknown names resolve to an empty record set, not an error.
            let records = self.txt_records.get(domain).cloned().unwrap_or_default();
            Box::pin(async move { Ok(records) })
        }

        fn resolve_cname(
            &self,
            domain: &str,
        ) -> Pin<Box<dyn Future<Output = Result<Option<String>>> + Send + '_>> {
            let target = self.cname_records.get(domain).cloned();
            Box::pin(async move { Ok(target) })
        }
    }

    /// Verifier backed by `resolver`, using the `_modo-verify` TXT prefix.
    ///
    /// Goes through `DomainVerifier::with_resolver` so the tests do not depend
    /// on how `Inner` is laid out.
    fn verifier_with_mock(resolver: MockResolver) -> DomainVerifier {
        DomainVerifier::with_resolver(resolver, "_modo-verify")
    }

    // -- construction tests --

    #[test]
    fn with_resolver_stores_prefix() {
        let v = DomainVerifier::with_resolver(MockResolver::new(), "_custom-prefix");
        assert_eq!(v.inner.txt_prefix, "_custom-prefix");
    }

    #[test]
    fn clone_shares_inner_state() {
        let v = verifier_with_mock(MockResolver::new());
        let cloned = v.clone();
        assert!(Arc::ptr_eq(&v.inner, &cloned.inner));
    }

    // -- check_txt tests --

    #[tokio::test]
    async fn check_txt_matching_token_returns_true() {
        let mock = MockResolver::new().with_txt("_modo-verify.example.com", vec!["abc123"]);
        let v = verifier_with_mock(mock);
        assert!(v.check_txt("example.com", "abc123").await.unwrap());
    }

    #[tokio::test]
    async fn check_txt_no_match_returns_false() {
        let mock = MockResolver::new().with_txt("_modo-verify.example.com", vec!["wrong"]);
        let v = verifier_with_mock(mock);
        assert!(!v.check_txt("example.com", "abc123").await.unwrap());
    }

    #[tokio::test]
    async fn check_txt_multiple_records_one_matches() {
        let mock = MockResolver::new().with_txt(
            "_modo-verify.example.com",
            vec!["spf-record", "abc123", "other"],
        );
        let v = verifier_with_mock(mock);
        assert!(v.check_txt("example.com", "abc123").await.unwrap());
    }

    #[tokio::test]
    async fn check_txt_no_records_returns_false() {
        let mock = MockResolver::new();
        let v = verifier_with_mock(mock);
        assert!(!v.check_txt("example.com", "abc123").await.unwrap());
    }

    #[tokio::test]
    async fn check_txt_prefix_is_prepended() {
        let mock = MockResolver::new().with_txt("_modo-verify.test.io", vec!["token1"]);
        let v = verifier_with_mock(mock);
        assert!(v.check_txt("test.io", "token1").await.unwrap());
    }

    #[tokio::test]
    async fn check_txt_case_sensitive() {
        let mock = MockResolver::new().with_txt("_modo-verify.example.com", vec!["ABC123"]);
        let v = verifier_with_mock(mock);
        assert!(!v.check_txt("example.com", "abc123").await.unwrap());
    }

    #[tokio::test]
    async fn check_txt_empty_domain_returns_bad_request() {
        let mock = MockResolver::new();
        let v = verifier_with_mock(mock);
        let err = v.check_txt("", "abc123").await.unwrap_err();
        assert_eq!(err.status(), http::StatusCode::BAD_REQUEST);
    }

    #[tokio::test]
    async fn check_txt_empty_token_returns_bad_request() {
        let mock = MockResolver::new();
        let v = verifier_with_mock(mock);
        let err = v.check_txt("example.com", "").await.unwrap_err();
        assert_eq!(err.status(), http::StatusCode::BAD_REQUEST);
    }

    // -- check_cname tests --

    #[tokio::test]
    async fn check_cname_matching_target_returns_true() {
        let mock = MockResolver::new().with_cname("custom.example.com", "app.myservice.com");
        let v = verifier_with_mock(mock);
        assert!(
            v.check_cname("custom.example.com", "app.myservice.com")
                .await
                .unwrap()
        );
    }

    #[tokio::test]
    async fn check_cname_trailing_dot_normalized() {
        let mock = MockResolver::new().with_cname("custom.example.com", "app.myservice.com.");
        let v = verifier_with_mock(mock);
        assert!(
            v.check_cname("custom.example.com", "app.myservice.com")
                .await
                .unwrap()
        );
    }

    #[tokio::test]
    async fn check_cname_case_insensitive() {
        let mock = MockResolver::new().with_cname("custom.example.com", "App.MyService.COM");
        let v = verifier_with_mock(mock);
        assert!(
            v.check_cname("custom.example.com", "app.myservice.com")
                .await
                .unwrap()
        );
    }

    #[tokio::test]
    async fn check_cname_expected_target_is_normalized() {
        // Normalization applies to the caller-supplied target as well.
        let mock = MockResolver::new().with_cname("custom.example.com", "app.myservice.com");
        let v = verifier_with_mock(mock);
        assert!(
            v.check_cname("custom.example.com", "APP.MyService.com.")
                .await
                .unwrap()
        );
    }

    #[tokio::test]
    async fn check_cname_no_record_returns_false() {
        let mock = MockResolver::new();
        let v = verifier_with_mock(mock);
        assert!(
            !v.check_cname("custom.example.com", "app.myservice.com")
                .await
                .unwrap()
        );
    }

    #[tokio::test]
    async fn check_cname_no_match_returns_false() {
        let mock = MockResolver::new().with_cname("custom.example.com", "other.service.com");
        let v = verifier_with_mock(mock);
        assert!(
            !v.check_cname("custom.example.com", "app.myservice.com")
                .await
                .unwrap()
        );
    }

    #[tokio::test]
    async fn check_cname_empty_domain_returns_bad_request() {
        let mock = MockResolver::new();
        let v = verifier_with_mock(mock);
        let err = v.check_cname("", "app.myservice.com").await.unwrap_err();
        assert_eq!(err.status(), http::StatusCode::BAD_REQUEST);
    }

    #[tokio::test]
    async fn check_cname_empty_target_returns_bad_request() {
        let mock = MockResolver::new();
        let v = verifier_with_mock(mock);
        let err = v.check_cname("example.com", "").await.unwrap_err();
        assert_eq!(err.status(), http::StatusCode::BAD_REQUEST);
    }

    // -- verify_domain tests --

    #[tokio::test]
    async fn verify_domain_both_pass() {
        let mock = MockResolver::new()
            .with_txt("_modo-verify.example.com", vec!["token1"])
            .with_cname("example.com", "app.myservice.com");
        let v = verifier_with_mock(mock);
        let status = v
            .verify_domain("example.com", "token1", "app.myservice.com")
            .await
            .unwrap();
        assert!(status.txt_verified);
        assert!(status.cname_verified);
    }

    #[tokio::test]
    async fn verify_domain_txt_pass_cname_fail() {
        let mock = MockResolver::new().with_txt("_modo-verify.example.com", vec!["token1"]);
        let v = verifier_with_mock(mock);
        let status = v
            .verify_domain("example.com", "token1", "app.myservice.com")
            .await
            .unwrap();
        assert!(status.txt_verified);
        assert!(!status.cname_verified);
    }

    #[tokio::test]
    async fn verify_domain_cname_pass_txt_fail() {
        let mock = MockResolver::new().with_cname("example.com", "app.myservice.com");
        let v = verifier_with_mock(mock);
        let status = v
            .verify_domain("example.com", "token1", "app.myservice.com")
            .await
            .unwrap();
        assert!(!status.txt_verified);
        assert!(status.cname_verified);
    }

    #[tokio::test]
    async fn verify_domain_both_fail() {
        let mock = MockResolver::new();
        let v = verifier_with_mock(mock);
        let status = v
            .verify_domain("example.com", "token1", "app.myservice.com")
            .await
            .unwrap();
        assert!(!status.txt_verified);
        assert!(!status.cname_verified);
    }

    #[tokio::test]
    async fn verify_domain_dns_error_propagates() {
        /// Resolver whose TXT lookup always fails while CNAME succeeds with
        /// no record, so the error can only come from the TXT side.
        struct FailingResolver;
        impl DnsResolver for FailingResolver {
            fn resolve_txt(
                &self,
                _domain: &str,
            ) -> Pin<Box<dyn Future<Output = Result<Vec<String>>> + Send + '_>> {
                Box::pin(async {
                    Err(Error::bad_gateway("dns server failure")
                        .chain(DnsError::ServerFailure)
                        .with_code(DnsError::ServerFailure.code()))
                })
            }
            fn resolve_cname(
                &self,
                _domain: &str,
            ) -> Pin<Box<dyn Future<Output = Result<Option<String>>> + Send + '_>> {
                Box::pin(async { Ok(None) })
            }
        }

        let v = DomainVerifier::with_resolver(FailingResolver, "_modo-verify");
        let err = v
            .verify_domain("example.com", "token1", "app.myservice.com")
            .await
            .unwrap_err();
        assert_eq!(err.status(), http::StatusCode::BAD_GATEWAY);
    }

    // -- from_config tests --

    #[test]
    fn from_config_valid() {
        let config = DnsConfig {
            nameserver: "8.8.8.8:53".into(),
            txt_prefix: "_myapp-verify".into(),
            timeout_ms: 3000,
        };
        let v = DomainVerifier::from_config(&config).unwrap();
        assert_eq!(v.inner.txt_prefix, "_myapp-verify");
    }

    #[test]
    fn from_config_invalid_nameserver_fails() {
        let config = DnsConfig {
            nameserver: "not-valid".into(),
            txt_prefix: "_modo-verify".into(),
            timeout_ms: 5000,
        };
        let err = DomainVerifier::from_config(&config).err().unwrap();
        assert_eq!(err.status(), http::StatusCode::INTERNAL_SERVER_ERROR);
    }
}