1mod protocol;
12pub(crate) mod providers;
13
14pub use providers::{default_providers, provider_names};
15
16use crate::error::ProviderError;
17use crate::provider::Provider;
18use crate::types::{IpVersion, Protocol};
19use protocol::{build_query, parse_response, DnsClass, RecordType};
20use std::future::Future;
21use std::net::{IpAddr, SocketAddr};
22use std::pin::Pin;
23use std::str::FromStr;
24use tokio::net::UdpSocket;
25
/// Which kind of DNS record a [`DnsProvider`] reads the public address from.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DnsRecordType {
    /// Address records: `A` for IPv4 / unspecified lookups, `AAAA` for IPv6 lookups.
    Address,
    /// A `TXT` record whose text contains the address as one of its tokens.
    Txt,
}
34
35#[derive(Debug, Clone)]
37pub struct DnsProvider {
38 name: String,
39 query_domain: String,
40 resolver_addr: SocketAddr,
41 resolver_addr_v6: Option<SocketAddr>,
42 record_type: DnsRecordType,
43 dns_class: DnsClass,
44 supports_v4: bool,
45 supports_v6: bool,
46}
47
48impl DnsProvider {
49 pub fn new(
51 name: impl Into<String>,
52 query_domain: impl Into<String>,
53 resolver_addr: SocketAddr,
54 record_type: DnsRecordType,
55 ) -> Self {
56 Self {
57 name: name.into(),
58 query_domain: query_domain.into(),
59 resolver_addr,
60 resolver_addr_v6: None,
61 record_type,
62 dns_class: DnsClass::In,
63 supports_v4: true,
64 supports_v6: false,
65 }
66 }
67
68 pub fn with_class(mut self, class: DnsClass) -> Self {
70 self.dns_class = class;
71 self
72 }
73
74 pub fn with_v6(mut self, supports: bool) -> Self {
76 self.supports_v6 = supports;
77 self
78 }
79
80 pub fn with_v6_resolver(mut self, addr: SocketAddr) -> Self {
85 self.resolver_addr_v6 = Some(addr);
86 self.supports_v6 = true;
87 self
88 }
89
90 async fn query(&self, version: IpVersion) -> Result<IpAddr, ProviderError> {
92 let resolver = match version {
94 IpVersion::V6 => self.resolver_addr_v6.unwrap_or(self.resolver_addr),
95 _ => self.resolver_addr,
96 };
97
98 let record_type = match self.record_type {
100 DnsRecordType::Address => match version {
101 IpVersion::V6 => RecordType::Aaaa,
102 _ => RecordType::A,
103 },
104 DnsRecordType::Txt => RecordType::Txt,
105 };
106
107 let query = build_query(&self.query_domain, record_type, self.dns_class)
109 .map_err(|e| ProviderError::new(&self.name, e))?;
110
111 let bind_addr = if resolver.is_ipv6() {
113 "[::]:0"
114 } else {
115 "0.0.0.0:0"
116 };
117 let socket = UdpSocket::bind(bind_addr)
118 .await
119 .map_err(|e| ProviderError::new(&self.name, e))?;
120
121 socket
123 .send_to(&query, resolver)
124 .await
125 .map_err(|e| ProviderError::new(&self.name, e))?;
126
127 let mut buf = [0u8; 1232]; let len = socket
130 .recv(&mut buf)
131 .await
132 .map_err(|e| ProviderError::new(&self.name, e))?;
133
134 let results = parse_response(&buf[..len], record_type)
136 .map_err(|e| ProviderError::message(&self.name, e))?;
137
138 for result in results {
140 for part in result.split_whitespace() {
142 let ip_str = part.split('/').next().unwrap_or(part);
143 if let Ok(ip) = IpAddr::from_str(ip_str) {
144 match version {
146 IpVersion::V4 if ip.is_ipv4() => return Ok(ip),
147 IpVersion::V6 if ip.is_ipv6() => return Ok(ip),
148 IpVersion::Any => return Ok(ip),
149 _ => continue,
150 }
151 }
152 }
153 }
154
155 Err(ProviderError::message(
156 &self.name,
157 "no valid IP in DNS response",
158 ))
159 }
160}
161
162impl Provider for DnsProvider {
163 fn name(&self) -> &str {
164 &self.name
165 }
166
167 fn protocol(&self) -> Protocol {
168 Protocol::Dns
169 }
170
171 fn supports_v4(&self) -> bool {
172 self.supports_v4
173 }
174
175 fn supports_v6(&self) -> bool {
176 self.supports_v6
177 }
178
179 fn get_ip(
180 &self,
181 version: IpVersion,
182 ) -> Pin<Box<dyn Future<Output = Result<IpAddr, ProviderError>> + Send + '_>> {
183 Box::pin(self.query(version))
184 }
185}