1use std::sync::Arc;
4use std::time::Duration;
5
6use crate::error::{Error, Result};
7
8use super::config::DnsConfig;
9use super::error::DnsError;
10use super::resolver::{DnsResolver, UdpDnsResolver};
11
/// Result of running both DNS ownership checks for one domain.
///
/// The default value reports both checks as not verified.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct DomainStatus {
    /// `true` when a TXT record at `<txt_prefix>.<domain>` exactly matched the
    /// expected token.
    pub txt_verified: bool,
    /// `true` when the domain's CNAME resolved to the expected target
    /// (compared case-insensitively, ignoring trailing dots).
    pub cname_verified: bool,
}
25
26pub(crate) struct Inner {
27 pub(crate) resolver: Arc<dyn DnsResolver>,
28 pub(crate) txt_prefix: String,
29}
30
31pub struct DomainVerifier {
53 inner: Arc<Inner>,
54}
55
56impl Clone for DomainVerifier {
57 fn clone(&self) -> Self {
58 Self {
59 inner: Arc::clone(&self.inner),
60 }
61 }
62}
63
64impl DomainVerifier {
65 pub fn from_config(config: &DnsConfig) -> Result<Self> {
73 let nameserver = config.parse_nameserver()?;
74 let timeout = Duration::from_millis(config.timeout_ms);
75 let resolver = UdpDnsResolver::new(nameserver, timeout);
76
77 Ok(Self {
78 inner: Arc::new(Inner {
79 resolver: Arc::new(resolver),
80 txt_prefix: config.txt_prefix.clone(),
81 }),
82 })
83 }
84
85 #[allow(dead_code)]
93 pub(crate) fn with_resolver(
94 resolver: impl DnsResolver + 'static,
95 txt_prefix: impl Into<String>,
96 ) -> Self {
97 Self {
98 inner: Arc::new(Inner {
99 resolver: Arc::new(resolver),
100 txt_prefix: txt_prefix.into(),
101 }),
102 }
103 }
104
105 pub async fn check_txt(&self, domain: &str, expected_token: &str) -> Result<bool> {
117 if domain.is_empty() {
118 return Err(Error::bad_request("domain must not be empty")
119 .chain(DnsError::InvalidInput)
120 .with_code(DnsError::InvalidInput.code()));
121 }
122 if expected_token.is_empty() {
123 return Err(Error::bad_request("token must not be empty")
124 .chain(DnsError::InvalidInput)
125 .with_code(DnsError::InvalidInput.code()));
126 }
127
128 let lookup_domain = format!("{}.{}", self.inner.txt_prefix, domain);
129 let records = self.inner.resolver.resolve_txt(&lookup_domain).await?;
130
131 Ok(records.iter().any(|r| r == expected_token))
132 }
133
134 pub async fn check_cname(&self, domain: &str, expected_target: &str) -> Result<bool> {
145 if domain.is_empty() {
146 return Err(Error::bad_request("domain must not be empty")
147 .chain(DnsError::InvalidInput)
148 .with_code(DnsError::InvalidInput.code()));
149 }
150 if expected_target.is_empty() {
151 return Err(Error::bad_request("target must not be empty")
152 .chain(DnsError::InvalidInput)
153 .with_code(DnsError::InvalidInput.code()));
154 }
155
156 let target = self.inner.resolver.resolve_cname(domain).await?;
157
158 match target {
159 Some(resolved) => {
160 let normalized_resolved = normalize_domain(&resolved);
161 let normalized_expected = normalize_domain(expected_target);
162 Ok(normalized_resolved == normalized_expected)
163 }
164 None => Ok(false),
165 }
166 }
167
168 pub async fn verify_domain(
179 &self,
180 domain: &str,
181 expected_token: &str,
182 expected_cname: &str,
183 ) -> Result<DomainStatus> {
184 let (txt_result, cname_result) = tokio::join!(
185 self.check_txt(domain, expected_token),
186 self.check_cname(domain, expected_cname),
187 );
188
189 Ok(DomainStatus {
190 txt_verified: txt_result?,
191 cname_verified: cname_result?,
192 })
193 }
194}
195
/// Canonical form of a DNS name for comparison: trailing dots (the FQDN root
/// marker) are removed and the remainder is lowercased.
///
/// Trimming before lowercasing gives the same result as the reverse order,
/// because lowercasing never creates or removes `.` characters, but it
/// allocates only one `String`.
fn normalize_domain(domain: &str) -> String {
    let without_root = domain.trim_end_matches('.');
    without_root.to_lowercase()
}
200
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;
    use std::future::Future;
    use std::pin::Pin;

    /// TXT prefix used by every verifier built in these tests.
    const PREFIX: &str = "_modo-verify";

    /// In-memory resolver that answers from fixed TXT and CNAME tables.
    ///
    /// Names that were never registered resolve to "no records"
    /// (`Ok(vec![])` / `Ok(None)`), never to an error.
    struct MockResolver {
        txt_records: HashMap<String, Vec<String>>,
        cname_records: HashMap<String, String>,
    }

    impl MockResolver {
        fn new() -> Self {
            Self {
                txt_records: HashMap::new(),
                cname_records: HashMap::new(),
            }
        }

        /// Registers the TXT records returned for `domain`.
        fn with_txt(mut self, domain: &str, records: Vec<&str>) -> Self {
            self.txt_records.insert(
                domain.to_owned(),
                records.into_iter().map(str::to_owned).collect(),
            );
            self
        }

        /// Registers the CNAME target returned for `domain`.
        fn with_cname(mut self, domain: &str, target: &str) -> Self {
            self.cname_records
                .insert(domain.to_owned(), target.to_owned());
            self
        }
    }

    impl DnsResolver for MockResolver {
        fn resolve_txt(
            &self,
            domain: &str,
        ) -> Pin<Box<dyn Future<Output = Result<Vec<String>>> + Send + '_>> {
            let records = self.txt_records.get(domain).cloned().unwrap_or_default();
            Box::pin(async move { Ok(records) })
        }

        fn resolve_cname(
            &self,
            domain: &str,
        ) -> Pin<Box<dyn Future<Output = Result<Option<String>>> + Send + '_>> {
            let target = self.cname_records.get(domain).cloned();
            Box::pin(async move { Ok(target) })
        }
    }

    /// The error a resolver reports when the DNS server fails (SERVFAIL).
    fn server_failure() -> Error {
        Error::bad_gateway("dns server failure")
            .chain(DnsError::ServerFailure)
            .with_code(DnsError::ServerFailure.code())
    }

    /// Resolver that fails the selected lookup kinds with [`server_failure`]
    /// and reports "no records" for the others.
    struct FailingResolver {
        fail_txt: bool,
        fail_cname: bool,
    }

    impl DnsResolver for FailingResolver {
        fn resolve_txt(
            &self,
            _domain: &str,
        ) -> Pin<Box<dyn Future<Output = Result<Vec<String>>> + Send + '_>> {
            let fail = self.fail_txt;
            Box::pin(async move {
                if fail {
                    Err(server_failure())
                } else {
                    Ok(Vec::new())
                }
            })
        }

        fn resolve_cname(
            &self,
            _domain: &str,
        ) -> Pin<Box<dyn Future<Output = Result<Option<String>>> + Send + '_>> {
            let fail = self.fail_cname;
            Box::pin(async move {
                if fail {
                    Err(server_failure())
                } else {
                    Ok(None)
                }
            })
        }
    }

    /// Wraps `resolver` in a verifier using [`PREFIX`], via the crate's own
    /// resolver-injection constructor rather than hand-building `Inner`.
    fn verifier_with_mock(resolver: impl DnsResolver + 'static) -> DomainVerifier {
        DomainVerifier::with_resolver(resolver, PREFIX)
    }

    // -- check_txt ---------------------------------------------------------

    #[tokio::test]
    async fn check_txt_matching_token_returns_true() {
        let mock = MockResolver::new().with_txt("_modo-verify.example.com", vec!["abc123"]);
        let v = verifier_with_mock(mock);
        assert!(v.check_txt("example.com", "abc123").await.unwrap());
    }

    #[tokio::test]
    async fn check_txt_no_match_returns_false() {
        let mock = MockResolver::new().with_txt("_modo-verify.example.com", vec!["wrong"]);
        let v = verifier_with_mock(mock);
        assert!(!v.check_txt("example.com", "abc123").await.unwrap());
    }

    #[tokio::test]
    async fn check_txt_multiple_records_one_matches() {
        let mock = MockResolver::new().with_txt(
            "_modo-verify.example.com",
            vec!["spf-record", "abc123", "other"],
        );
        let v = verifier_with_mock(mock);
        assert!(v.check_txt("example.com", "abc123").await.unwrap());
    }

    #[tokio::test]
    async fn check_txt_no_records_returns_false() {
        let mock = MockResolver::new();
        let v = verifier_with_mock(mock);
        assert!(!v.check_txt("example.com", "abc123").await.unwrap());
    }

    #[tokio::test]
    async fn check_txt_prefix_is_prepended() {
        let mock = MockResolver::new().with_txt("_modo-verify.test.io", vec!["token1"]);
        let v = verifier_with_mock(mock);
        assert!(v.check_txt("test.io", "token1").await.unwrap());
    }

    #[tokio::test]
    async fn check_txt_case_sensitive() {
        let mock = MockResolver::new().with_txt("_modo-verify.example.com", vec!["ABC123"]);
        let v = verifier_with_mock(mock);
        assert!(!v.check_txt("example.com", "abc123").await.unwrap());
    }

    #[tokio::test]
    async fn check_txt_empty_domain_returns_bad_request() {
        let mock = MockResolver::new();
        let v = verifier_with_mock(mock);
        let err = v.check_txt("", "abc123").await.unwrap_err();
        assert_eq!(err.status(), http::StatusCode::BAD_REQUEST);
    }

    #[tokio::test]
    async fn check_txt_empty_token_returns_bad_request() {
        let mock = MockResolver::new();
        let v = verifier_with_mock(mock);
        let err = v.check_txt("example.com", "").await.unwrap_err();
        assert_eq!(err.status(), http::StatusCode::BAD_REQUEST);
    }

    // -- check_cname -------------------------------------------------------

    #[tokio::test]
    async fn check_cname_matching_target_returns_true() {
        let mock = MockResolver::new().with_cname("custom.example.com", "app.myservice.com");
        let v = verifier_with_mock(mock);
        assert!(
            v.check_cname("custom.example.com", "app.myservice.com")
                .await
                .unwrap()
        );
    }

    #[tokio::test]
    async fn check_cname_trailing_dot_normalized() {
        let mock = MockResolver::new().with_cname("custom.example.com", "app.myservice.com.");
        let v = verifier_with_mock(mock);
        assert!(
            v.check_cname("custom.example.com", "app.myservice.com")
                .await
                .unwrap()
        );
    }

    #[tokio::test]
    async fn check_cname_case_insensitive() {
        let mock = MockResolver::new().with_cname("custom.example.com", "App.MyService.COM");
        let v = verifier_with_mock(mock);
        assert!(
            v.check_cname("custom.example.com", "app.myservice.com")
                .await
                .unwrap()
        );
    }

    #[tokio::test]
    async fn check_cname_no_record_returns_false() {
        let mock = MockResolver::new();
        let v = verifier_with_mock(mock);
        assert!(
            !v.check_cname("custom.example.com", "app.myservice.com")
                .await
                .unwrap()
        );
    }

    #[tokio::test]
    async fn check_cname_no_match_returns_false() {
        let mock = MockResolver::new().with_cname("custom.example.com", "other.service.com");
        let v = verifier_with_mock(mock);
        assert!(
            !v.check_cname("custom.example.com", "app.myservice.com")
                .await
                .unwrap()
        );
    }

    #[tokio::test]
    async fn check_cname_empty_domain_returns_bad_request() {
        let mock = MockResolver::new();
        let v = verifier_with_mock(mock);
        let err = v.check_cname("", "app.myservice.com").await.unwrap_err();
        assert_eq!(err.status(), http::StatusCode::BAD_REQUEST);
    }

    #[tokio::test]
    async fn check_cname_empty_target_returns_bad_request() {
        let mock = MockResolver::new();
        let v = verifier_with_mock(mock);
        let err = v.check_cname("example.com", "").await.unwrap_err();
        assert_eq!(err.status(), http::StatusCode::BAD_REQUEST);
    }

    // -- verify_domain -----------------------------------------------------

    #[tokio::test]
    async fn verify_domain_both_pass() {
        let mock = MockResolver::new()
            .with_txt("_modo-verify.example.com", vec!["token1"])
            .with_cname("example.com", "app.myservice.com");
        let v = verifier_with_mock(mock);
        let status = v
            .verify_domain("example.com", "token1", "app.myservice.com")
            .await
            .unwrap();
        assert!(status.txt_verified);
        assert!(status.cname_verified);
    }

    #[tokio::test]
    async fn verify_domain_txt_pass_cname_fail() {
        let mock = MockResolver::new().with_txt("_modo-verify.example.com", vec!["token1"]);
        let v = verifier_with_mock(mock);
        let status = v
            .verify_domain("example.com", "token1", "app.myservice.com")
            .await
            .unwrap();
        assert!(status.txt_verified);
        assert!(!status.cname_verified);
    }

    #[tokio::test]
    async fn verify_domain_both_fail() {
        let mock = MockResolver::new();
        let v = verifier_with_mock(mock);
        let status = v
            .verify_domain("example.com", "token1", "app.myservice.com")
            .await
            .unwrap();
        assert!(!status.txt_verified);
        assert!(!status.cname_verified);
    }

    #[tokio::test]
    async fn verify_domain_dns_error_propagates() {
        let v = verifier_with_mock(FailingResolver {
            fail_txt: true,
            fail_cname: false,
        });
        let err = v
            .verify_domain("example.com", "token1", "app.myservice.com")
            .await
            .unwrap_err();
        assert_eq!(err.status(), http::StatusCode::BAD_GATEWAY);
    }

    #[tokio::test]
    async fn verify_domain_cname_error_propagates() {
        let v = verifier_with_mock(FailingResolver {
            fail_txt: false,
            fail_cname: true,
        });
        let err = v
            .verify_domain("example.com", "token1", "app.myservice.com")
            .await
            .unwrap_err();
        assert_eq!(err.status(), http::StatusCode::BAD_GATEWAY);
    }

    #[tokio::test]
    async fn verify_domain_empty_domain_returns_bad_request() {
        let v = verifier_with_mock(MockResolver::new());
        let err = v
            .verify_domain("", "token1", "app.myservice.com")
            .await
            .unwrap_err();
        assert_eq!(err.status(), http::StatusCode::BAD_REQUEST);
    }

    // -- helpers and construction ------------------------------------------

    #[test]
    fn normalize_domain_lowercases_and_strips_trailing_dots() {
        assert_eq!(normalize_domain("App.Example.COM.."), "app.example.com");
        assert_eq!(normalize_domain("example.com"), "example.com");
        assert_eq!(normalize_domain("."), "");
    }

    #[test]
    fn clone_shares_inner() {
        let v = verifier_with_mock(MockResolver::new());
        let copy = v.clone();
        assert!(Arc::ptr_eq(&v.inner, &copy.inner));
        assert_eq!(copy.inner.txt_prefix, PREFIX);
    }

    #[test]
    fn from_config_valid() {
        let config = DnsConfig {
            nameserver: "8.8.8.8:53".into(),
            txt_prefix: "_myapp-verify".into(),
            timeout_ms: 3000,
        };
        let v = DomainVerifier::from_config(&config).unwrap();
        assert_eq!(v.inner.txt_prefix, "_myapp-verify");
    }

    #[test]
    fn from_config_invalid_nameserver_fails() {
        let config = DnsConfig {
            nameserver: "not-valid".into(),
            txt_prefix: "_modo-verify".into(),
            timeout_ms: 5000,
        };
        // `DomainVerifier` is not `Debug`, so `unwrap_err` is unavailable.
        let err = DomainVerifier::from_config(&config).err().unwrap();
        assert_eq!(err.status(), http::StatusCode::INTERNAL_SERVER_ERROR);
    }
}
497}