1use std::{future::Future, net::IpAddr, sync::Arc, time::Duration};
7
8use hickory_resolver::{
9 Resolver,
10 config::{ConnectionConfig, NameServerConfig, ResolverConfig, ResolverOpts},
11 net::runtime::TokioRuntimeProvider,
12 proto::rr::RecordType,
13};
14use schemars::JsonSchema;
15use serde::{Deserialize, Serialize};
16
17use crate::{
18 control_plane::config::{ValidationEndpointConfig, ValidationTransport},
19 core::dns::{
20 records::RecordData,
21 responses::{AnyRecordData, ListRecordsResponse, ZoneRecord},
22 },
23};
24
/// Serde default for [`ValidationOptions::enabled`]: validation runs unless the
/// caller explicitly turns it off.
fn default_enabled() -> bool {
    true
}
28
/// Result of an endpoint query; failures are reported as a [`ValidationFailureKind`].
pub type DnsEndpointResolverResult<T> = std::result::Result<T, ValidationFailureKind>;
31
/// Asks a single validation endpoint for the records of one name and type.
///
/// [`HickoryDnsEndpointResolver`] performs real queries; tests use a fake that
/// returns a canned result.
pub trait DnsEndpointResolver {
    /// Queries `endpoint` for `record_type` records at `fqdn`.
    ///
    /// `timeout` is the time budget for the query. On success the observed
    /// records are returned; otherwise a [`ValidationFailureKind`] describes
    /// why the query failed.
    fn query_endpoint<'a>(
        &'a self,
        endpoint: &'a ValidationEndpointConfig,
        fqdn: &'a str,
        record_type: &'a str,
        timeout: Duration,
    ) -> impl Future<Output = DnsEndpointResolverResult<Vec<ObservedRecord>>> + Send + 'a;
}
/// Production [`DnsEndpointResolver`] that builds a fresh hickory resolver,
/// pointed only at the target endpoint, for every query.
#[derive(Debug, Clone, Copy, Default)]
pub struct HickoryDnsEndpointResolver;
49
50impl HickoryDnsEndpointResolver {
51 pub fn resolver_for_endpoint(
53 endpoint: &ValidationEndpointConfig,
54 timeout: Duration,
55 ) -> DnsEndpointResolverResult<Resolver<TokioRuntimeProvider>> {
56 let mut opts = ResolverOpts::default();
57 opts.timeout = timeout;
58 opts.attempts = 1;
59
60 Resolver::builder_with_config(resolver_config(endpoint)?, TokioRuntimeProvider::default())
61 .with_options(opts)
62 .build()
63 .map_err(|err| classify_hickory_error(endpoint.transport, &err.to_string()))
64 }
65}
66
67impl DnsEndpointResolver for HickoryDnsEndpointResolver {
68 fn query_endpoint<'a>(
69 &'a self,
70 endpoint: &'a ValidationEndpointConfig,
71 fqdn: &'a str,
72 record_type: &'a str,
73 timeout: Duration,
74 ) -> impl Future<Output = DnsEndpointResolverResult<Vec<ObservedRecord>>> + Send + 'a {
75 async move {
76 let rr_type = record_type
77 .parse::<RecordType>()
78 .map_err(|_| ValidationFailureKind::MalformedResponse)?;
79 let resolver = Self::resolver_for_endpoint(endpoint, timeout)?;
80
81 let lookup = tokio::time::timeout(timeout, resolver.lookup(fqdn, rr_type))
82 .await
83 .map_err(|_| ValidationFailureKind::Timeout)?
84 .map_err(|err| classify_hickory_error(endpoint.transport, &err.to_string()))?;
85
86 Ok(vec![ObservedRecord {
87 name: fqdn.to_string(),
88 record_type: record_type.to_ascii_uppercase(),
89 values: lookup
90 .answers()
91 .iter()
92 .map(|record| record.data.to_string())
93 .collect(),
94 }])
95 }
96 }
97}
98
99fn resolver_config(
100 endpoint: &ValidationEndpointConfig,
101) -> DnsEndpointResolverResult<ResolverConfig> {
102 let name_server = match endpoint.transport {
103 ValidationTransport::Dns => plain_dns_name_server(endpoint)?,
104 ValidationTransport::Dot => dot_name_server(endpoint)?,
105 ValidationTransport::Doh => doh_name_server(endpoint)?,
106 };
107
108 Ok(ResolverConfig::from_parts(
109 None,
110 Vec::new(),
111 vec![name_server],
112 ))
113}
114
115fn plain_dns_name_server(
116 endpoint: &ValidationEndpointConfig,
117) -> DnsEndpointResolverResult<NameServerConfig> {
118 let ip = endpoint_ip(endpoint)?;
119 let port = endpoint.port.unwrap_or(53);
120 let mut udp = ConnectionConfig::udp();
121 udp.port = port;
122 let mut tcp = ConnectionConfig::tcp();
123 tcp.port = port;
124
125 Ok(NameServerConfig::new(ip, true, vec![udp, tcp]))
126}
127
128fn dot_name_server(
129 endpoint: &ValidationEndpointConfig,
130) -> DnsEndpointResolverResult<NameServerConfig> {
131 let ip = endpoint_ip(endpoint)?;
132 let server_name = tls_server_name(endpoint)?.into();
133 let mut tls = ConnectionConfig::tls(server_name);
134 tls.port = endpoint.port.unwrap_or(853);
135
136 Ok(NameServerConfig::new(ip, true, vec![tls]))
137}
138
139fn doh_name_server(
140 endpoint: &ValidationEndpointConfig,
141) -> DnsEndpointResolverResult<NameServerConfig> {
142 let (host, path) = doh_url_parts(endpoint)?;
143 let ip = if endpoint.address.trim().is_empty() {
144 host.parse::<IpAddr>()
145 .map_err(|_| ValidationFailureKind::MalformedResponse)?
146 } else {
147 endpoint_ip(endpoint)?
148 };
149 let server_name = endpoint
150 .tls_server_name
151 .as_deref()
152 .filter(|name| !name.trim().is_empty())
153 .unwrap_or(host)
154 .to_string();
155 let mut https = ConnectionConfig::https(Arc::from(server_name), Some(Arc::from(path)));
156 https.port = endpoint.port.unwrap_or(443);
157
158 Ok(NameServerConfig::new(ip, true, vec![https]))
159}
160
161fn endpoint_ip(endpoint: &ValidationEndpointConfig) -> DnsEndpointResolverResult<IpAddr> {
162 endpoint
163 .address
164 .parse::<IpAddr>()
165 .map_err(|_| ValidationFailureKind::MalformedResponse)
166}
167
168fn tls_server_name(endpoint: &ValidationEndpointConfig) -> DnsEndpointResolverResult<String> {
169 endpoint
170 .tls_server_name
171 .as_deref()
172 .filter(|name| !name.trim().is_empty())
173 .map(str::to_string)
174 .or_else(|| (!endpoint.address.trim().is_empty()).then(|| endpoint.address.clone()))
175 .ok_or(ValidationFailureKind::MalformedResponse)
176}
177
178fn doh_url_parts(endpoint: &ValidationEndpointConfig) -> DnsEndpointResolverResult<(&str, &str)> {
179 let url = endpoint
180 .url
181 .as_deref()
182 .ok_or(ValidationFailureKind::MalformedResponse)?;
183 let without_scheme = url
184 .strip_prefix("https://")
185 .ok_or(ValidationFailureKind::DohHttpFailure)?;
186 let (authority, path) = without_scheme
187 .split_once('/')
188 .unwrap_or((without_scheme, "dns-query"));
189 let host = authority
190 .rsplit_once('@')
191 .map_or(authority, |(_, host_port)| host_port)
192 .split_once(':')
193 .map_or(authority, |(host, _)| host);
194
195 if host.trim().is_empty() {
196 return Err(ValidationFailureKind::MalformedResponse);
197 }
198
199 Ok((
200 host,
201 if path.is_empty() {
202 "/dns-query"
203 } else {
204 &url[url.len() - path.len() - 1..]
205 },
206 ))
207}
208
209fn classify_hickory_error(transport: ValidationTransport, error: &str) -> ValidationFailureKind {
210 let error = error.to_ascii_lowercase();
211
212 if error.contains("timed out") || error.contains("timeout") {
213 ValidationFailureKind::Timeout
214 } else if error.contains("nxdomain") || error.contains("no records found") {
215 ValidationFailureKind::Nxdomain
216 } else if error.contains("servfail") || error.contains("server failure") {
217 ValidationFailureKind::Servfail
218 } else if error.contains("refused") {
219 ValidationFailureKind::Refused
220 } else if matches!(transport, ValidationTransport::Dot) || error.contains("tls") {
221 ValidationFailureKind::TlsFailure
222 } else if matches!(transport, ValidationTransport::Doh) || error.contains("http") {
223 ValidationFailureKind::DohHttpFailure
224 } else {
225 ValidationFailureKind::MalformedResponse
226 }
227}
228
229#[must_use]
231pub fn endpoint_timeout(endpoint: &ValidationEndpointConfig) -> Duration {
232 Duration::from_millis(endpoint.timeout_ms.unwrap_or(5_000))
233}
234
235#[must_use]
237pub fn expected_records_from_response(
238 response: &ListRecordsResponse,
239) -> (Vec<ExpectedRecord>, Vec<SkippedRecord>) {
240 let mut expected = Vec::new();
241 let mut skipped = Vec::new();
242
243 for zone_records in &response.zones {
244 for record in &zone_records.records {
245 match expected_record_from_zone_record(&zone_records.zone.name, record) {
246 Ok(record) => expected.push(record),
247 Err(skip) => skipped.push(skip),
248 }
249 }
250 }
251
252 (expected, skipped)
253}
254
255#[must_use]
257pub fn compare_rrsets(
258 expected: &[ExpectedRecord],
259 observed: &[ObservedRecord],
260) -> Vec<RecordValidationResult> {
261 use std::collections::{BTreeMap, BTreeSet};
262
263 let expected_sets = expected.iter().fold(BTreeMap::new(), |mut acc, record| {
264 let key = normalized_rrset_key(&record.name, &record.record_type);
265 let values = normalize_values(&record.record_type, &record.values);
266 acc.entry(key).or_insert_with(BTreeSet::new).extend(values);
267 acc
268 });
269 let observed_sets = observed.iter().fold(BTreeMap::new(), |mut acc, record| {
270 let key = normalized_rrset_key(&record.name, &record.record_type);
271 let values = normalize_values(&record.record_type, &record.values);
272 acc.entry(key).or_insert_with(BTreeSet::new).extend(values);
273 acc
274 });
275
276 let mut results = Vec::new();
277 for ((name, record_type), expected_values) in &expected_sets {
278 let observed_values = observed_sets
279 .get(&(name.clone(), record_type.clone()))
280 .cloned()
281 .unwrap_or_default();
282
283 if observed_values.is_empty() {
284 results.push(mismatched_result(
285 name,
286 record_type,
287 expected_values,
288 &observed_values,
289 "missing",
290 ));
291 } else if expected_values == &observed_values {
292 results.push(RecordValidationResult {
293 name: name.clone(),
294 record_type: record_type.clone(),
295 status: ValidationStatus::Passed,
296 mismatch: None,
297 failure_kind: None,
298 skip_reason: None,
299 });
300 } else {
301 let mismatch_kind = if !expected_values.is_subset(&observed_values) {
302 "wrong_value"
303 } else {
304 "extra"
305 };
306 results.push(mismatched_result(
307 name,
308 record_type,
309 expected_values,
310 &observed_values,
311 mismatch_kind,
312 ));
313 }
314 }
315
316 for ((name, record_type), observed_values) in observed_sets {
317 if !expected_sets.contains_key(&(name.clone(), record_type.clone())) {
318 results.push(mismatched_result(
319 &name,
320 &record_type,
321 &BTreeSet::new(),
322 &observed_values,
323 "extra",
324 ));
325 }
326 }
327
328 results
329}
330
331fn expected_record_from_zone_record(
332 zone: &str,
333 record: &ZoneRecord,
334) -> std::result::Result<ExpectedRecord, SkippedRecord> {
335 let record_type = record.record_type.to_ascii_uppercase();
336 let name = normalize_domain_name(&fqdn_for_record(&record.name, zone));
337 let values = match record.parsed.as_ref() {
338 Some(AnyRecordData::Writable(data)) => values_from_record_data(data),
339 Some(AnyRecordData::ReadOnly(_)) | None => None,
340 };
341
342 match values {
343 Some(values) => Ok(ExpectedRecord {
344 name,
345 record_type,
346 values,
347 }),
348 None => Err(SkippedRecord {
349 name,
350 record_type,
351 reason: "unsupported_record_type".to_string(),
352 }),
353 }
354}
355
356fn values_from_record_data(record: &RecordData) -> Option<Vec<String>> {
357 match record {
358 RecordData::A { ip } => Some(vec![ip.to_string()]),
359 RecordData::Aaaa { ip } => Some(vec![ip.to_string()]),
360 RecordData::Cname { target } => Some(vec![target.clone()]),
361 RecordData::Txt { text, .. } => Some(vec![text.clone()]),
362 RecordData::Mx {
363 preference,
364 exchange,
365 } => Some(vec![format!("{preference} {exchange}")]),
366 RecordData::Ns { nameserver, .. } => Some(vec![nameserver.clone()]),
367 RecordData::Srv {
368 priority,
369 weight,
370 port,
371 target,
372 } => Some(vec![format!("{priority} {weight} {port} {target}")]),
373 RecordData::Caa { flags, tag, value } => Some(vec![format!("{flags} {tag} {value}")]),
374 _ => None,
375 }
376}
377
378fn mismatched_result(
379 name: &str,
380 record_type: &str,
381 expected: &std::collections::BTreeSet<String>,
382 observed: &std::collections::BTreeSet<String>,
383 mismatch_kind: &str,
384) -> RecordValidationResult {
385 RecordValidationResult {
386 name: name.to_string(),
387 record_type: record_type.to_string(),
388 status: ValidationStatus::Mismatched,
389 mismatch: Some(RecordMismatch {
390 name: name.to_string(),
391 record_type: record_type.to_string(),
392 expected: expected.iter().cloned().collect(),
393 observed: observed.iter().cloned().collect(),
394 mismatch_kind: mismatch_kind.to_string(),
395 }),
396 failure_kind: None,
397 skip_reason: None,
398 }
399}
400
401fn normalized_rrset_key(name: &str, record_type: &str) -> (String, String) {
402 (
403 normalize_domain_name(name),
404 record_type.trim().to_ascii_uppercase(),
405 )
406}
407
408fn normalize_values(record_type: &str, values: &[String]) -> std::collections::BTreeSet<String> {
409 values
410 .iter()
411 .map(|value| normalize_record_value(record_type, value))
412 .collect()
413}
414
415fn normalize_record_value(record_type: &str, value: &str) -> String {
416 let value = value.trim();
417 match record_type.to_ascii_uppercase().as_str() {
418 "CNAME" | "NS" => normalize_domain_name(value),
419 "MX" => normalize_priority_target(value),
420 "SRV" => normalize_srv(value),
421 "TXT" => normalize_txt(value),
422 "CAA" => normalize_caa(value),
423 _ => value.trim_end_matches('.').to_ascii_lowercase(),
424 }
425}
426
/// Canonical form of a domain name: trimmed, without trailing root dots, ASCII-lowercased.
fn normalize_domain_name(value: &str) -> String {
    let without_root = value.trim().trim_end_matches('.');
    without_root.to_ascii_lowercase()
}
430
431fn normalize_priority_target(value: &str) -> String {
432 let mut parts = value.split_whitespace();
433 let preference = parts.next().unwrap_or_default();
434 let target = parts.next().unwrap_or_default();
435 format!("{} {}", preference, normalize_domain_name(target))
436}
437
438fn normalize_srv(value: &str) -> String {
439 let mut parts = value.split_whitespace();
440 let priority = parts.next().unwrap_or_default();
441 let weight = parts.next().unwrap_or_default();
442 let port = parts.next().unwrap_or_default();
443 let target = parts.next().unwrap_or_default();
444 format!(
445 "{} {} {} {}",
446 priority,
447 weight,
448 port,
449 normalize_domain_name(target)
450 )
451}
452
/// Normalizes TXT text.
///
/// Adjacent quoted chunks (`"a" "b"`) are joined, and surrounding double
/// quotes are stripped.
fn normalize_txt(value: &str) -> String {
    let joined = value.trim().replace("\" \"", "");
    joined.trim_matches('"').to_owned()
}
460
/// Normalizes a CAA `"<flags> <tag> <value>"` value.
///
/// * The tag is ASCII-lowercased.
/// * Whitespace inside the value collapses to single spaces.
/// * One pair of surrounding double quotes is stripped from the value.
///
/// Stripping the quotes lets a quoted presentation form (`0 issue "ca.test"`)
/// match the unquoted form built from zone data, mirroring how TXT values have
/// their quotes removed.
fn normalize_caa(value: &str) -> String {
    let mut parts = value.split_whitespace();
    let flags = parts.next().unwrap_or_default();
    let tag = parts.next().unwrap_or_default().to_ascii_lowercase();
    let joined = parts.collect::<Vec<_>>().join(" ");
    let unquoted = joined
        .strip_prefix('"')
        .and_then(|inner| inner.strip_suffix('"'))
        .unwrap_or(&joined);
    format!("{flags} {tag} {unquoted}")
}
468
/// Qualifies a record `name` with its `zone`.
///
/// Trailing dots are stripped from both inputs first. Then:
///
/// * An empty name, `@`, or a name equal to the zone (ASCII case-insensitive)
///   yields the zone.
/// * A name already ending in `.<zone>` (ASCII case-insensitive) is returned
///   unchanged.
/// * Any other name is treated as relative and suffixed with `.<zone>`.
fn fqdn_for_record(name: &str, zone: &str) -> String {
    let name = name.trim_end_matches('.');
    let zone = zone.trim_end_matches('.');

    // An empty owner (e.g. a bare ".") denotes the apex just like "@". Without
    // this check it would be expanded to ".<zone>".
    if name.is_empty() || name == "@" || name.eq_ignore_ascii_case(zone) {
        return zone.to_string();
    }

    let zone_suffix = format!(".{}", zone.to_ascii_lowercase());
    if name.to_ascii_lowercase().ends_with(&zone_suffix) {
        name.to_string()
    } else {
        format!("{name}.{zone}")
    }
}
483
/// Test double for [`DnsEndpointResolver`] that answers every query with the same canned result.
#[cfg(test)]
#[derive(Debug, Clone)]
pub struct FakeDnsEndpointResolver {
    // Outcome cloned and returned from every `query_endpoint` call.
    result: DnsEndpointResolverResult<Vec<ObservedRecord>>,
}
490
#[cfg(test)]
impl FakeDnsEndpointResolver {
    /// Fake whose every query succeeds with a clone of `records`.
    pub fn with_records(records: Vec<ObservedRecord>) -> Self {
        let result = Ok(records);
        Self { result }
    }

    /// Fake whose every query fails with `failure`.
    pub fn with_failure(failure: ValidationFailureKind) -> Self {
        let result = Err(failure);
        Self { result }
    }
}
505
#[cfg(test)]
impl DnsEndpointResolver for FakeDnsEndpointResolver {
    /// Ignores all arguments and immediately resolves to a clone of the canned result.
    fn query_endpoint(
        &self,
        _endpoint: &ValidationEndpointConfig,
        _fqdn: &str,
        _record_type: &str,
        _timeout: Duration,
    ) -> impl Future<Output = DnsEndpointResolverResult<Vec<ObservedRecord>>> + Send + '_ {
        let canned = self.result.clone();
        std::future::ready(canned)
    }
}
518
/// Caller-supplied switches for a validation run.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, PartialEq, Eq)]
#[serde(rename_all = "camelCase")]
pub struct ValidationOptions {
    /// Whether validation runs; `true` when omitted.
    #[serde(default = "default_enabled")]
    pub enabled: bool,
    /// Optional list restricting which endpoints are used; omitted from output when `None`.
    // NOTE(review): presumably matched against endpoint names — confirm with the consumer.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub endpoint_filter: Option<Vec<String>>,
}
528
529impl Default for ValidationOptions {
530 fn default() -> Self {
531 Self {
532 enabled: true,
533 endpoint_filter: None,
534 }
535 }
536}
537
/// Request to validate the records of a zone.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, PartialEq, Eq)]
#[serde(rename_all = "camelCase")]
pub struct ValidationRequest {
    /// Zone being validated.
    pub zone: String,
    /// Optional domain within the zone; omitted from output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub domain: Option<String>,
    /// Records the caller expects to observe; empty when omitted.
    #[serde(default)]
    pub expected_records: Vec<ExpectedRecord>,
    /// Validation switches; defaults apply when omitted.
    #[serde(default)]
    pub options: ValidationOptions,
}
550
/// Records the caller expects to find at one name and type.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, PartialEq, Eq)]
#[serde(rename_all = "camelCase")]
pub struct ExpectedRecord {
    /// Owner name.
    pub name: String,
    /// Record type mnemonic (e.g. `A`, `MX`).
    pub record_type: String,
    /// Textual record values; normalized per type before comparison.
    pub values: Vec<String>,
}
559
/// Records actually returned by a validation endpoint for one name and type.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, PartialEq, Eq)]
#[serde(rename_all = "camelCase")]
pub struct ObservedRecord {
    /// Owner name that was queried.
    pub name: String,
    /// Record type mnemonic that was queried.
    pub record_type: String,
    /// Textual rdata of the answers; normalized per type before comparison.
    pub values: Vec<String>,
}
568
/// Outcome of validating a record, an endpoint, or a whole run.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
pub enum ValidationStatus {
    // Variant notes are plain `//` comments on purpose: doc comments would be
    // picked up by the JsonSchema derive and alter the generated schema.
    // Observed values matched the expectation.
    Passed,
    // Observed values differed (missing, extra, or wrong value).
    Mismatched,
    // Validation was not performed (e.g. disabled or no endpoints).
    Skipped,
    // Validation could not be completed. NOTE(review): not set by code visible here.
    Failed,
}
578
/// Why a validation query could not produce observed records.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum ValidationFailureKind {
    // Variant notes are plain `//` comments on purpose: doc comments would be
    // picked up by the JsonSchema derive and alter the generated schema.
    // The lookup exceeded its timeout, or the error message mentioned a timeout.
    Timeout,
    // Error message mentioned NXDOMAIN or "no records found".
    Nxdomain,
    // Error message mentioned SERVFAIL or "server failure".
    Servfail,
    // Error message mentioned "refused".
    Refused,
    // Otherwise-unclassified error on DoT, or one mentioning TLS.
    TlsFailure,
    // Otherwise-unclassified error on DoH, one mentioning HTTP, or a DoH URL
    // that does not start with `https://`.
    DohHttpFailure,
    // Fallback classification. Also used for unparsable endpoint addresses,
    // missing TLS server names or DoH URLs, and unknown record-type strings.
    MalformedResponse,
    // NOTE(review): not produced by any code visible here — confirm where it is set.
    UnsupportedTransport,
}
592
/// Details of an RRset whose observed values differ from the expected ones.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, PartialEq, Eq)]
#[serde(rename_all = "camelCase")]
pub struct RecordMismatch {
    /// Normalized owner name.
    pub name: String,
    /// Normalized (upper-case) record type.
    pub record_type: String,
    /// Normalized expected values.
    pub expected: Vec<String>,
    /// Normalized observed values.
    pub observed: Vec<String>,
    /// One of `missing`, `extra` or `wrong_value`.
    pub mismatch_kind: String,
}
603
/// A record (or `*` wildcard entry) that was not validated, with the reason.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, PartialEq, Eq)]
#[serde(rename_all = "camelCase")]
pub struct SkippedRecord {
    /// Owner name, or `*` for a run-wide skip.
    pub name: String,
    /// Record type, or `*` for a run-wide skip.
    pub record_type: String,
    /// Machine-readable reason (e.g. `unsupported_record_type`).
    pub reason: String,
}
612
/// Validation outcome for a single RRset.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, PartialEq, Eq)]
#[serde(rename_all = "camelCase")]
pub struct RecordValidationResult {
    /// Normalized owner name.
    pub name: String,
    /// Normalized record type.
    pub record_type: String,
    /// Overall status for this RRset.
    pub status: ValidationStatus,
    /// Mismatch details when `status` is mismatched; omitted when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub mismatch: Option<RecordMismatch>,
    /// Failure classification, if any; omitted when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub failure_kind: Option<ValidationFailureKind>,
    /// Skip reason, if any; omitted when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub skip_reason: Option<String>,
}
627
/// Validation results gathered from one endpoint.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, PartialEq, Eq)]
#[serde(rename_all = "camelCase")]
pub struct EndpointValidationReport {
    /// Configured endpoint name.
    pub endpoint_name: String,
    /// Transport label for the endpoint.
    pub transport: String,
    /// Endpoint address.
    pub address: String,
    /// Aggregate status for this endpoint.
    pub status: ValidationStatus,
    /// Per-RRset results; empty when omitted.
    #[serde(default)]
    pub results: Vec<RecordValidationResult>,
    /// Mismatches found at this endpoint; empty when omitted.
    #[serde(default)]
    pub mismatches: Vec<RecordMismatch>,
    /// Records skipped at this endpoint; empty when omitted.
    #[serde(default)]
    pub skipped: Vec<SkippedRecord>,
    /// Query failures at this endpoint; empty when omitted.
    #[serde(default)]
    pub failures: Vec<ValidationFailureKind>,
}
645
/// Top-level report for a validation run across all endpoints.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, PartialEq, Eq)]
#[serde(rename_all = "camelCase")]
pub struct ValidationReport {
    /// Whether validation was enabled for this run.
    pub enabled: bool,
    /// Aggregate status of the run.
    pub status: ValidationStatus,
    /// Zone that was validated; omitted when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub zone: Option<String>,
    /// Domain that was validated; omitted when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub domain: Option<String>,
    /// Optional label for the phase in which the report was produced; omitted when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub phase: Option<String>,
    /// Per-endpoint reports; empty when omitted.
    #[serde(default)]
    pub endpoints: Vec<EndpointValidationReport>,
    /// Per-RRset results; empty when omitted.
    #[serde(default)]
    pub results: Vec<RecordValidationResult>,
    /// Mismatches; empty when omitted.
    #[serde(default)]
    pub mismatches: Vec<RecordMismatch>,
    /// Skipped records, including run-wide `*` entries; empty when omitted.
    #[serde(default)]
    pub skipped: Vec<SkippedRecord>,
    /// Query failures; empty when omitted.
    #[serde(default)]
    pub failures: Vec<ValidationFailureKind>,
}
670
671impl ValidationReport {
672 #[must_use]
674 pub fn disabled() -> Self {
675 Self {
676 enabled: false,
677 status: ValidationStatus::Skipped,
678 zone: None,
679 domain: None,
680 phase: None,
681 endpoints: Vec::new(),
682 results: Vec::new(),
683 mismatches: Vec::new(),
684 skipped: vec![SkippedRecord {
685 name: "*".to_string(),
686 record_type: "*".to_string(),
687 reason: "validation_disabled".to_string(),
688 }],
689 failures: Vec::new(),
690 }
691 }
692
693 #[must_use]
695 pub fn skipped_no_endpoints() -> Self {
696 Self::skipped("no_validation_endpoints_configured")
697 }
698
699 #[must_use]
701 pub fn skipped(reason: &str) -> Self {
702 Self {
703 enabled: true,
704 status: ValidationStatus::Skipped,
705 zone: None,
706 domain: None,
707 phase: None,
708 endpoints: Vec::new(),
709 results: Vec::new(),
710 mismatches: Vec::new(),
711 skipped: vec![SkippedRecord {
712 name: "*".to_string(),
713 record_type: "*".to_string(),
714 reason: reason.to_string(),
715 }],
716 failures: Vec::new(),
717 }
718 }
719
720 #[must_use]
722 pub const fn overall_status(&self) -> &ValidationStatus {
723 &self.status
724 }
725
726 #[must_use]
728 pub fn is_passed(&self) -> bool {
729 self.status == ValidationStatus::Passed
730 }
731}
732
733#[cfg(test)]
734mod tests {
735 use super::*;
736 use crate::core::dns::responses::{ZoneInfo, ZoneRecords};
737 use rstest::{fixture, rstest};
738 use serde_json::{Value, json};
739 use std::net::{Ipv4Addr, Ipv6Addr};
740
741 #[fixture]
742 fn expected_record() -> ExpectedRecord {
743 ExpectedRecord {
744 name: "www.example.com".to_string(),
745 record_type: "A".to_string(),
746 values: vec!["192.0.2.10".to_string()],
747 }
748 }
749
750 #[fixture]
751 fn mismatch() -> RecordMismatch {
752 RecordMismatch {
753 name: "www.example.com".to_string(),
754 record_type: "A".to_string(),
755 expected: vec!["192.0.2.10".to_string()],
756 observed: vec!["192.0.2.11".to_string()],
757 mismatch_kind: "wrong_value".to_string(),
758 }
759 }
760
761 #[fixture]
762 fn mismatched_result(mismatch: RecordMismatch) -> RecordValidationResult {
763 RecordValidationResult {
764 name: mismatch.name.clone(),
765 record_type: mismatch.record_type.clone(),
766 status: ValidationStatus::Mismatched,
767 mismatch: Some(mismatch),
768 failure_kind: None,
769 skip_reason: None,
770 }
771 }
772
773 #[fixture]
774 fn endpoint_report(
775 mismatch: RecordMismatch,
776 mismatched_result: RecordValidationResult,
777 ) -> EndpointValidationReport {
778 EndpointValidationReport {
779 endpoint_name: "public-doh".to_string(),
780 transport: "doh".to_string(),
781 address: "https://dns.example/dns-query".to_string(),
782 status: ValidationStatus::Mismatched,
783 results: vec![mismatched_result],
784 mismatches: vec![mismatch],
785 skipped: vec![SkippedRecord {
786 name: "dnskey.example.com".to_string(),
787 record_type: "DNSKEY".to_string(),
788 reason: "unsupported record type".to_string(),
789 }],
790 failures: vec![ValidationFailureKind::DohHttpFailure],
791 }
792 }
793
794 fn validation_endpoint(transport: ValidationTransport) -> ValidationEndpointConfig {
795 ValidationEndpointConfig {
796 name: "test-endpoint".to_string(),
797 transport,
798 address: if matches!(transport, ValidationTransport::Doh) {
799 String::new()
800 } else {
801 "127.0.0.1".to_string()
802 },
803 port: None,
804 url: matches!(transport, ValidationTransport::Doh)
805 .then(|| "https://127.0.0.1/dns-query".to_string()),
806 tls_server_name: matches!(transport, ValidationTransport::Dot)
807 .then(|| "dns.example.test".to_string()),
808 enabled: true,
809 timeout_ms: Some(10),
810 }
811 }
812
813 #[fixture]
814 fn validation_report(
815 endpoint_report: EndpointValidationReport,
816 mismatch: RecordMismatch,
817 mismatched_result: RecordValidationResult,
818 ) -> ValidationReport {
819 ValidationReport {
820 enabled: true,
821 status: ValidationStatus::Mismatched,
822 zone: Some("example.com".to_string()),
823 domain: Some("www.example.com".to_string()),
824 phase: Some("transfer_pre".to_string()),
825 endpoints: vec![endpoint_report],
826 results: vec![mismatched_result],
827 mismatches: vec![mismatch],
828 skipped: vec![SkippedRecord {
829 name: "dnskey.example.com".to_string(),
830 record_type: "DNSKEY".to_string(),
831 reason: "unsupported record type".to_string(),
832 }],
833 failures: vec![ValidationFailureKind::DohHttpFailure],
834 }
835 }
836
837 fn zone_info() -> ZoneInfo {
838 ZoneInfo {
839 id: None,
840 name: "example.test".to_string(),
841 zone_type: "Primary".to_string(),
842 disabled: false,
843 dnssec_status: None,
844 }
845 }
846
847 fn zone_record(name: &str, ttl: u32, data: RecordData) -> ZoneRecord {
848 ZoneRecord {
849 name: name.to_string(),
850 record_type: data.type_name().to_string(),
851 ttl,
852 disabled: false,
853 comments: String::new(),
854 expiry_ttl: 0,
855 data: serde_json::to_value(&data).expect("record data serializes"),
856 parsed: Some(AnyRecordData::Writable(data)),
857 }
858 }
859
860 fn list_response(records: Vec<ZoneRecord>) -> ListRecordsResponse {
861 ListRecordsResponse {
862 zones: vec![ZoneRecords {
863 zone: zone_info(),
864 records,
865 }],
866 }
867 }
868
869 #[rstest]
870 fn validation_report_json_shape(validation_report: ValidationReport) {
871 let value = serde_json::to_value(validation_report).expect("report serializes to JSON");
872
873 assert_eq!(value["enabled"], json!(true));
874 assert_eq!(value["status"], json!("mismatched"));
875 assert_eq!(value["phase"], json!("transfer_pre"));
876 assert!(value["endpoints"].is_array());
877 assert!(value["results"].is_array());
878 assert!(value["mismatches"].is_array());
879 assert!(value["skipped"].is_array());
880 assert!(value["failures"].is_array());
881 assert_eq!(value["failures"][0], json!("doh_http_failure"));
882 assert_eq!(value["results"][0]["status"], json!("mismatched"));
883 assert_eq!(value["mismatches"][0]["mismatchKind"], json!("wrong_value"));
884 assert_eq!(value["endpoints"][0]["endpointName"], json!("public-doh"));
885 }
886
887 #[rstest]
888 fn validation_disabled_report_shape() {
889 let report = ValidationReport::disabled();
890 let value = serde_json::to_value(&report).expect("disabled report serializes to JSON");
891
892 assert!(!report.enabled);
893 assert_eq!(report.overall_status(), &ValidationStatus::Skipped);
894 assert_eq!(value["enabled"], json!(false));
895 assert_eq!(value["status"], json!("skipped"));
896 assert_eq!(value["endpoints"], json!([]));
897 assert_eq!(value["results"], json!([]));
898 assert_eq!(value["mismatches"], json!([]));
899 assert_eq!(value["skipped"][0]["reason"], json!("validation_disabled"));
900 assert_eq!(value["failures"], json!([]));
901 }
902
903 #[rstest]
904 fn skipped_no_endpoints_report_shape() {
905 let value = serde_json::to_value(ValidationReport::skipped_no_endpoints())
906 .expect("skipped report serializes to JSON");
907
908 assert_eq!(value["enabled"], json!(true));
909 assert_eq!(value["status"], json!("skipped"));
910 assert_eq!(
911 value["skipped"][0]["reason"],
912 json!("no_validation_endpoints_configured")
913 );
914 }
915
916 #[rstest]
917 fn validation_options_default_is_enabled() {
918 assert_eq!(ValidationOptions::default().enabled, true);
919
920 let parsed: ValidationOptions =
921 serde_json::from_value(json!({})).expect("empty validation options use defaults");
922
923 assert!(parsed.enabled);
924 assert_eq!(parsed.endpoint_filter, None);
925 }
926
927 #[rstest]
928 fn validation_request_defaults_options(expected_record: ExpectedRecord) {
929 let request: ValidationRequest = serde_json::from_value(json!({
930 "zone": "example.com",
931 "expectedRecords": [expected_record]
932 }))
933 .expect("request deserializes with default options");
934
935 assert!(request.options.enabled);
936 assert_eq!(request.domain, None);
937 assert_eq!(request.expected_records.len(), 1);
938 }
939
940 #[tokio::test]
941 async fn validation_resolver_plain_dns_fake() {
942 let endpoint = validation_endpoint(ValidationTransport::Dns);
943 let expected = vec![ObservedRecord {
944 name: "www.example.com".to_string(),
945 record_type: "A".to_string(),
946 values: vec!["192.0.2.10".to_string()],
947 }];
948 let resolver = FakeDnsEndpointResolver::with_records(expected.clone());
949
950 let observed = resolver
951 .query_endpoint(
952 &endpoint,
953 "www.example.com",
954 "A",
955 endpoint_timeout(&endpoint),
956 )
957 .await
958 .expect("fake resolver returns deterministic records");
959
960 assert_eq!(observed, expected);
961 }
962
963 #[tokio::test]
964 async fn validation_resolver_doh_http_500_failure() {
965 let endpoint = validation_endpoint(ValidationTransport::Doh);
966 let resolver = FakeDnsEndpointResolver::with_failure(ValidationFailureKind::DohHttpFailure);
967
968 let failure = resolver
969 .query_endpoint(
970 &endpoint,
971 "www.example.com",
972 "A",
973 endpoint_timeout(&endpoint),
974 )
975 .await
976 .expect_err("fake resolver returns deterministic DoH failure");
977
978 assert_eq!(failure, ValidationFailureKind::DohHttpFailure);
979 }
980
981 #[tokio::test]
982 async fn validation_resolver_dot_tls_failure() {
983 let endpoint = validation_endpoint(ValidationTransport::Dot);
984 let resolver = FakeDnsEndpointResolver::with_failure(ValidationFailureKind::TlsFailure);
985
986 let failure = resolver
987 .query_endpoint(
988 &endpoint,
989 "www.example.com",
990 "A",
991 endpoint_timeout(&endpoint),
992 )
993 .await
994 .expect_err("fake resolver returns deterministic DoT failure");
995
996 assert_eq!(failure, ValidationFailureKind::TlsFailure);
997 }
998
999 #[tokio::test]
1000 async fn validation_resolver_timeout_failure() {
1001 let endpoint = validation_endpoint(ValidationTransport::Dns);
1002 let resolver = FakeDnsEndpointResolver::with_failure(ValidationFailureKind::Timeout);
1003
1004 let failure = resolver
1005 .query_endpoint(
1006 &endpoint,
1007 "www.example.com",
1008 "A",
1009 endpoint_timeout(&endpoint),
1010 )
1011 .await
1012 .expect_err("fake resolver returns deterministic timeout");
1013
1014 assert_eq!(failure, ValidationFailureKind::Timeout);
1015 }
1016
1017 #[rstest]
1018 fn validation_compare_exact_match() {
1019 let response = list_response(vec![
1020 zone_record(
1021 "@",
1022 300,
1023 RecordData::A {
1024 ip: Ipv4Addr::new(192, 0, 2, 10),
1025 },
1026 ),
1027 zone_record(
1028 "@",
1029 300,
1030 RecordData::Aaaa {
1031 ip: Ipv6Addr::new(0x2001, 0x0db8, 0, 0, 0, 0, 0, 0x0010),
1032 },
1033 ),
1034 zone_record(
1035 "www",
1036 300,
1037 RecordData::Cname {
1038 target: "example.test.".to_string(),
1039 },
1040 ),
1041 zone_record(
1042 "@",
1043 300,
1044 RecordData::Mx {
1045 preference: 10,
1046 exchange: "mail.example.test.".to_string(),
1047 },
1048 ),
1049 zone_record(
1050 "@",
1051 300,
1052 RecordData::Txt {
1053 text: "dnsync-validation-test".to_string(),
1054 split_text: false,
1055 },
1056 ),
1057 ]);
1058 let (expected, skipped) = expected_records_from_response(&response);
1059 let observed = expected
1060 .iter()
1061 .map(|record| ObservedRecord {
1062 name: record.name.clone(),
1063 record_type: record.record_type.clone(),
1064 values: record.values.clone(),
1065 })
1066 .collect::<Vec<_>>();
1067
1068 let results = compare_rrsets(&expected, &observed);
1069
1070 assert!(skipped.is_empty());
1071 assert_eq!(results.len(), 5);
1072 assert!(
1073 results
1074 .iter()
1075 .all(|result| result.status == ValidationStatus::Passed)
1076 );
1077 }
1078
1079 #[rstest]
1080 fn validation_compare_missing_extra_wrong_value() {
1081 let expected = vec![
1082 ExpectedRecord {
1083 name: "example.test".to_string(),
1084 record_type: "A".to_string(),
1085 values: vec!["192.0.2.10".to_string()],
1086 },
1087 ExpectedRecord {
1088 name: "www.example.test".to_string(),
1089 record_type: "CNAME".to_string(),
1090 values: vec!["example.test".to_string()],
1091 },
1092 ];
1093 let observed = vec![
1094 ObservedRecord {
1095 name: "example.test".to_string(),
1096 record_type: "A".to_string(),
1097 values: vec!["192.0.2.99".to_string()],
1098 },
1099 ObservedRecord {
1100 name: "extra.example.test".to_string(),
1101 record_type: "AAAA".to_string(),
1102 values: vec!["2001:db8::99".to_string()],
1103 },
1104 ];
1105
1106 let results = compare_rrsets(&expected, &observed);
1107 let kinds = results
1108 .iter()
1109 .filter_map(|result| result.mismatch.as_ref())
1110 .map(|mismatch| mismatch.mismatch_kind.as_str())
1111 .collect::<Vec<_>>();
1112
1113 assert_eq!(results.len(), 3);
1114 assert!(kinds.contains(&"wrong_value"));
1115 assert!(kinds.contains(&"missing"));
1116 assert!(kinds.contains(&"extra"));
1117 }
1118
1119 #[rstest]
1120 fn validation_skips_unsupported_types() {
1121 let response = list_response(vec![zone_record(
1122 "@",
1123 300,
1124 RecordData::Unknown {
1125 rdata: "00ff".to_string(),
1126 },
1127 )]);
1128
1129 let (expected, skipped) = expected_records_from_response(&response);
1130
1131 assert!(expected.is_empty());
1132 assert_eq!(skipped.len(), 1);
1133 assert_eq!(skipped[0].record_type, "UNKNOWN");
1134 assert_eq!(skipped[0].reason, "unsupported_record_type");
1135 }
1136
1137 #[rstest]
1138 fn validation_ignores_ttl_differences() {
1139 let response = list_response(vec![zone_record(
1140 "@",
1141 30,
1142 RecordData::A {
1143 ip: Ipv4Addr::new(192, 0, 2, 10),
1144 },
1145 )]);
1146 let (expected, skipped) = expected_records_from_response(&response);
1147 let observed = vec![ObservedRecord {
1148 name: "example.test.".to_string(),
1149 record_type: "a".to_string(),
1150 values: vec!["192.0.2.10".to_string()],
1151 }];
1152
1153 let results = compare_rrsets(&expected, &observed);
1154
1155 assert!(skipped.is_empty());
1156 assert_eq!(results[0].status, ValidationStatus::Passed);
1157 }
1158
1159 #[rstest]
1160 fn validation_normalizes_txt_mx_srv_cname_ns() {
1161 let response = list_response(vec![
1162 zone_record(
1163 "www",
1164 300,
1165 RecordData::Cname {
1166 target: "Example.TEST.".to_string(),
1167 },
1168 ),
1169 zone_record(
1170 "@",
1171 300,
1172 RecordData::Txt {
1173 text: "dnsync-validation-test".to_string(),
1174 split_text: true,
1175 },
1176 ),
1177 zone_record(
1178 "@",
1179 300,
1180 RecordData::Mx {
1181 preference: 10,
1182 exchange: "Mail.Example.Test.".to_string(),
1183 },
1184 ),
1185 zone_record(
1186 "@",
1187 300,
1188 RecordData::Ns {
1189 nameserver: "NS1.Example.Test.".to_string(),
1190 glue: None,
1191 },
1192 ),
1193 zone_record(
1194 "_sip._tcp",
1195 300,
1196 RecordData::Srv {
1197 priority: 10,
1198 weight: 20,
1199 port: 5060,
1200 target: "Sip.Example.Test.".to_string(),
1201 },
1202 ),
1203 ]);
1204 let (expected, skipped) = expected_records_from_response(&response);
1205 let observed = vec![
1206 ObservedRecord {
1207 name: "WWW.EXAMPLE.TEST.".to_string(),
1208 record_type: "cname".to_string(),
1209 values: vec!["example.test".to_string()],
1210 },
1211 ObservedRecord {
1212 name: "example.test".to_string(),
1213 record_type: "TXT".to_string(),
1214 values: vec!["\"dnsync-\" \"validation-test\"".to_string()],
1215 },
1216 ObservedRecord {
1217 name: "example.test".to_string(),
1218 record_type: "MX".to_string(),
1219 values: vec!["10 mail.example.test".to_string()],
1220 },
1221 ObservedRecord {
1222 name: "example.test".to_string(),
1223 record_type: "NS".to_string(),
1224 values: vec!["ns1.example.test".to_string()],
1225 },
1226 ObservedRecord {
1227 name: "_sip._tcp.example.test".to_string(),
1228 record_type: "SRV".to_string(),
1229 values: vec!["10 20 5060 sip.example.test".to_string()],
1230 },
1231 ];
1232
1233 let results = compare_rrsets(&expected, &observed);
1234
1235 assert!(skipped.is_empty());
1236 assert_eq!(results.len(), 5);
1237 assert!(
1238 results
1239 .iter()
1240 .all(|result| result.status == ValidationStatus::Passed)
1241 );
1242 }
1243
1244 #[rstest]
1245 #[case::passed(ValidationStatus::Passed, "passed")]
1246 #[case::mismatched(ValidationStatus::Mismatched, "mismatched")]
1247 #[case::skipped(ValidationStatus::Skipped, "skipped")]
1248 #[case::failed(ValidationStatus::Failed, "failed")]
1249 fn validation_status_serializes_lowercase(
1250 #[case] status: ValidationStatus,
1251 #[case] expected: &str,
1252 ) {
1253 assert_eq!(
1254 serde_json::to_value(status).expect("status serializes"),
1255 Value::String(expected.to_string())
1256 );
1257 }
1258
1259 #[rstest]
1260 #[case::timeout(ValidationFailureKind::Timeout, "timeout")]
1261 #[case::nxdomain(ValidationFailureKind::Nxdomain, "nxdomain")]
1262 #[case::servfail(ValidationFailureKind::Servfail, "servfail")]
1263 #[case::refused(ValidationFailureKind::Refused, "refused")]
1264 #[case::tls_failure(ValidationFailureKind::TlsFailure, "tls_failure")]
1265 #[case::doh_http_failure(ValidationFailureKind::DohHttpFailure, "doh_http_failure")]
1266 #[case::malformed_response(ValidationFailureKind::MalformedResponse, "malformed_response")]
1267 #[case::unsupported_transport(
1268 ValidationFailureKind::UnsupportedTransport,
1269 "unsupported_transport"
1270 )]
1271 fn validation_failure_kind_serializes_snake_case(
1272 #[case] failure_kind: ValidationFailureKind,
1273 #[case] expected: &str,
1274 ) {
1275 assert_eq!(
1276 serde_json::to_value(failure_kind).expect("failure kind serializes"),
1277 Value::String(expected.to_string())
1278 );
1279 }
1280}