1mod protocol;
7pub(crate) mod providers;
8
9pub use providers::{default_providers, provider_names};
10
11use crate::error::ProviderError;
12use crate::provider::Provider;
13use crate::types::{IpVersion, Protocol};
14use async_trait::async_trait;
15use protocol::{build_query, parse_response, DnsClass, RecordType};
16use std::net::{IpAddr, SocketAddr};
17use std::str::FromStr;
18use tokio::net::UdpSocket;
19use tracing::debug;
20
/// Which kind of DNS record a [`DnsProvider`] requests and reads the address from.
///
/// Besides `Debug`/`Clone`/`Copy`, this derives `PartialEq`, `Eq` and `Hash` so
/// callers can compare configurations or use the value as a map/set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DnsRecordType {
    /// Address record: an `A` query for IPv4 lookups and `AAAA` for IPv6 lookups.
    Address,
    /// `TXT` record; its text is scanned token by token for an IP address
    /// (a trailing `/suffix` on a token is ignored).
    Txt,
}
29
30#[derive(Debug, Clone)]
32pub struct DnsProvider {
33 name: String,
34 query_domain: String,
35 resolver_addr: SocketAddr,
36 resolver_addr_v6: Option<SocketAddr>,
37 record_type: DnsRecordType,
38 dns_class: DnsClass,
39 supports_v4: bool,
40 supports_v6: bool,
41}
42
43impl DnsProvider {
44 pub fn new(
46 name: impl Into<String>,
47 query_domain: impl Into<String>,
48 resolver_addr: SocketAddr,
49 record_type: DnsRecordType,
50 ) -> Self {
51 Self {
52 name: name.into(),
53 query_domain: query_domain.into(),
54 resolver_addr,
55 resolver_addr_v6: None,
56 record_type,
57 dns_class: DnsClass::In,
58 supports_v4: true,
59 supports_v6: false,
60 }
61 }
62
63 pub fn with_class(mut self, class: DnsClass) -> Self {
65 self.dns_class = class;
66 self
67 }
68
69 pub fn with_v6(mut self, supports: bool) -> Self {
71 self.supports_v6 = supports;
72 self
73 }
74
75 pub fn with_v6_resolver(mut self, addr: SocketAddr) -> Self {
80 self.resolver_addr_v6 = Some(addr);
81 self.supports_v6 = true;
82 self
83 }
84
85 async fn query(&self, version: IpVersion) -> Result<IpAddr, ProviderError> {
87 let resolver = match version {
89 IpVersion::V6 => self.resolver_addr_v6.unwrap_or(self.resolver_addr),
90 _ => self.resolver_addr,
91 };
92
93 debug!(
94 provider = %self.name,
95 domain = %self.query_domain,
96 resolver = %resolver,
97 "querying DNS"
98 );
99
100 let record_type = match self.record_type {
102 DnsRecordType::Address => match version {
103 IpVersion::V6 => RecordType::Aaaa,
104 _ => RecordType::A,
105 },
106 DnsRecordType::Txt => RecordType::Txt,
107 };
108
109 let query = build_query(&self.query_domain, record_type, self.dns_class)
111 .map_err(|e| ProviderError::new(&self.name, e))?;
112
113 let bind_addr = if resolver.is_ipv6() {
115 "[::]:0"
116 } else {
117 "0.0.0.0:0"
118 };
119 let socket = UdpSocket::bind(bind_addr)
120 .await
121 .map_err(|e| ProviderError::new(&self.name, e))?;
122
123 socket
125 .send_to(&query, resolver)
126 .await
127 .map_err(|e| ProviderError::new(&self.name, e))?;
128
129 let mut buf = [0u8; 1232]; let len = socket
132 .recv(&mut buf)
133 .await
134 .map_err(|e| ProviderError::new(&self.name, e))?;
135
136 let results = parse_response(&buf[..len], record_type)
138 .map_err(|e| ProviderError::message(&self.name, e))?;
139
140 for result in results {
142 for part in result.split_whitespace() {
144 let ip_str = part.split('/').next().unwrap_or(part);
145 if let Ok(ip) = IpAddr::from_str(ip_str) {
146 match version {
148 IpVersion::V4 if ip.is_ipv4() => return Ok(ip),
149 IpVersion::V6 if ip.is_ipv6() => return Ok(ip),
150 IpVersion::Any => return Ok(ip),
151 _ => continue,
152 }
153 }
154 }
155 }
156
157 Err(ProviderError::message(
158 &self.name,
159 "no valid IP in DNS response",
160 ))
161 }
162}
163
164#[async_trait]
165impl Provider for DnsProvider {
166 fn name(&self) -> &str {
167 &self.name
168 }
169
170 fn protocol(&self) -> Protocol {
171 Protocol::Dns
172 }
173
174 fn supports_v4(&self) -> bool {
175 self.supports_v4
176 }
177
178 fn supports_v6(&self) -> bool {
179 self.supports_v6
180 }
181
182 async fn get_ip(&self, version: IpVersion) -> Result<IpAddr, ProviderError> {
183 self.query(version).await
184 }
185}