Skip to main content

ip_discovery/dns/
mod.rs

1//! DNS protocol implementation for public IP detection
2//!
//! Uses DNS A/AAAA or TXT records from special domains to detect the public IP.
//! Queries are sent over raw UDP sockets rather than through an external DNS library.
5//!
6//! # Security
7//!
8//! Transaction IDs are generated with [`getrandom`] (OS-level CSPRNG),
9//! preventing DNS transaction ID spoofing attacks.
10
11mod protocol;
12pub(crate) mod providers;
13
14pub use providers::{default_providers, provider_names};
15
16use crate::error::ProviderError;
17use crate::provider::Provider;
18use crate::types::{IpVersion, Protocol};
19use protocol::{build_query, parse_response, DnsClass, RecordType};
20use std::future::Future;
21use std::net::{IpAddr, SocketAddr};
22use std::pin::Pin;
23use std::str::FromStr;
24use tokio::net::UdpSocket;
25
/// Which kind of DNS answer carries the public IP address.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DnsRecordType {
    /// Address record: an AAAA query for IPv6 lookups and an A query
    /// otherwise; the answer data is the IP itself.
    Address,
    /// TXT record whose text contains the IP.
    Txt,
}
34
35/// DNS provider configuration
36#[derive(Debug, Clone)]
37pub struct DnsProvider {
38    name: String,
39    query_domain: String,
40    resolver_addr: SocketAddr,
41    resolver_addr_v6: Option<SocketAddr>,
42    record_type: DnsRecordType,
43    dns_class: DnsClass,
44    supports_v4: bool,
45    supports_v6: bool,
46}
47
48impl DnsProvider {
49    /// Create a new DNS provider
50    pub fn new(
51        name: impl Into<String>,
52        query_domain: impl Into<String>,
53        resolver_addr: SocketAddr,
54        record_type: DnsRecordType,
55    ) -> Self {
56        Self {
57            name: name.into(),
58            query_domain: query_domain.into(),
59            resolver_addr,
60            resolver_addr_v6: None,
61            record_type,
62            dns_class: DnsClass::In,
63            supports_v4: true,
64            supports_v6: false,
65        }
66    }
67
68    /// Set DNS class (for special queries like Cloudflare CHAOS)
69    pub fn with_class(mut self, class: DnsClass) -> Self {
70        self.dns_class = class;
71        self
72    }
73
74    /// Set IPv6 support
75    pub fn with_v6(mut self, supports: bool) -> Self {
76        self.supports_v6 = supports;
77        self
78    }
79
80    /// Set IPv6 resolver address
81    ///
82    /// When requesting IPv6, the query is sent to this resolver so the
83    /// DNS server sees the client's IPv6 source address.
84    pub fn with_v6_resolver(mut self, addr: SocketAddr) -> Self {
85        self.resolver_addr_v6 = Some(addr);
86        self.supports_v6 = true;
87        self
88    }
89
90    /// Query for IP address using raw UDP
91    async fn query(&self, version: IpVersion) -> Result<IpAddr, ProviderError> {
92        // Pick resolver: use IPv6 resolver when requesting v6 if available
93        let resolver = match version {
94            IpVersion::V6 => self.resolver_addr_v6.unwrap_or(self.resolver_addr),
95            _ => self.resolver_addr,
96        };
97
98        // Determine record type based on version and configured type
99        let record_type = match self.record_type {
100            DnsRecordType::Address => match version {
101                IpVersion::V6 => RecordType::Aaaa,
102                _ => RecordType::A,
103            },
104            DnsRecordType::Txt => RecordType::Txt,
105        };
106
107        // Build query packet
108        let query = build_query(&self.query_domain, record_type, self.dns_class)
109            .map_err(|e| ProviderError::new(&self.name, e))?;
110
111        // Create UDP socket
112        let bind_addr = if resolver.is_ipv6() {
113            "[::]:0"
114        } else {
115            "0.0.0.0:0"
116        };
117        let socket = UdpSocket::bind(bind_addr)
118            .await
119            .map_err(|e| ProviderError::new(&self.name, e))?;
120
121        // Send query
122        socket
123            .send_to(&query, resolver)
124            .await
125            .map_err(|e| ProviderError::new(&self.name, e))?;
126
127        // Receive response
128        let mut buf = [0u8; 1232]; // DNS Flag Day 2020 safe UDP size (RFC 6891 EDNS0)
129        let len = socket
130            .recv(&mut buf)
131            .await
132            .map_err(|e| ProviderError::new(&self.name, e))?;
133
134        // Parse response
135        let results = parse_response(&buf[..len], record_type)
136            .map_err(|e| ProviderError::message(&self.name, e))?;
137
138        // Extract IP from results
139        for result in results {
140            // Handle potential CIDR notation or prefixed text
141            for part in result.split_whitespace() {
142                let ip_str = part.split('/').next().unwrap_or(part);
143                if let Ok(ip) = IpAddr::from_str(ip_str) {
144                    // Filter by version if needed
145                    match version {
146                        IpVersion::V4 if ip.is_ipv4() => return Ok(ip),
147                        IpVersion::V6 if ip.is_ipv6() => return Ok(ip),
148                        IpVersion::Any => return Ok(ip),
149                        _ => continue,
150                    }
151                }
152            }
153        }
154
155        Err(ProviderError::message(
156            &self.name,
157            "no valid IP in DNS response",
158        ))
159    }
160}
161
162impl Provider for DnsProvider {
163    fn name(&self) -> &str {
164        &self.name
165    }
166
167    fn protocol(&self) -> Protocol {
168        Protocol::Dns
169    }
170
171    fn supports_v4(&self) -> bool {
172        self.supports_v4
173    }
174
175    fn supports_v6(&self) -> bool {
176        self.supports_v6
177    }
178
179    fn get_ip(
180        &self,
181        version: IpVersion,
182    ) -> Pin<Box<dyn Future<Output = Result<IpAddr, ProviderError>> + Send + '_>> {
183        Box::pin(self.query(version))
184    }
185}