Skip to main content

ip_discovery/dns/
mod.rs

//! DNS protocol implementation for public IP detection
//!
//! Uses DNS A/AAAA or TXT records from special domains to detect the public IP.
//! Queries are sent over plain UDP sockets instead of going through an external
//! DNS library.
5
6mod protocol;
7pub(crate) mod providers;
8
9pub use providers::{default_providers, provider_names};
10
11use crate::error::ProviderError;
12use crate::provider::Provider;
13use crate::types::{IpVersion, Protocol};
14use async_trait::async_trait;
15use protocol::{build_query, parse_response, DnsClass, RecordType};
16use std::net::{IpAddr, SocketAddr};
17use std::str::FromStr;
18use tokio::net::UdpSocket;
19use tracing::debug;
20
/// Which kind of DNS record a [`DnsProvider`] asks for.
///
/// Comparable and hashable so callers can match on, deduplicate, or key
/// collections by the configured record type.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DnsRecordType {
    /// A/AAAA record (direct IP); the family follows the requested IP version.
    Address,
    /// TXT record (IP as text).
    Txt,
}
29
/// DNS provider configuration
///
/// Built with [`DnsProvider::new`] and refined with the `with_*` builder
/// methods; queried through its [`Provider`] implementation.
#[derive(Debug, Clone)]
pub struct DnsProvider {
    // Provider identifier; reported by `Provider::name` and attached to errors.
    name: String,
    // Domain name sent in the DNS question.
    query_domain: String,
    // Resolver used for IPv4/any queries, and for IPv6 queries when no
    // dedicated IPv6 resolver is set.
    resolver_addr: SocketAddr,
    // Optional resolver used for IPv6 queries (set by `with_v6_resolver`).
    resolver_addr_v6: Option<SocketAddr>,
    // Address (A/AAAA, chosen per requested version) or TXT lookup.
    record_type: DnsRecordType,
    // Class placed in the question; `new` defaults it to `DnsClass::In`.
    dns_class: DnsClass,
    // Always `true` as constructed by `new`.
    supports_v4: bool,
    // `false` by default; enabled by `with_v6` / `with_v6_resolver`.
    supports_v6: bool,
}
42
43impl DnsProvider {
44    /// Create a new DNS provider
45    pub fn new(
46        name: impl Into<String>,
47        query_domain: impl Into<String>,
48        resolver_addr: SocketAddr,
49        record_type: DnsRecordType,
50    ) -> Self {
51        Self {
52            name: name.into(),
53            query_domain: query_domain.into(),
54            resolver_addr,
55            resolver_addr_v6: None,
56            record_type,
57            dns_class: DnsClass::In,
58            supports_v4: true,
59            supports_v6: false,
60        }
61    }
62
63    /// Set DNS class (for special queries like Cloudflare CHAOS)
64    pub fn with_class(mut self, class: DnsClass) -> Self {
65        self.dns_class = class;
66        self
67    }
68
69    /// Set IPv6 support
70    pub fn with_v6(mut self, supports: bool) -> Self {
71        self.supports_v6 = supports;
72        self
73    }
74
75    /// Set IPv6 resolver address
76    ///
77    /// When requesting IPv6, the query is sent to this resolver so the
78    /// DNS server sees the client's IPv6 source address.
79    pub fn with_v6_resolver(mut self, addr: SocketAddr) -> Self {
80        self.resolver_addr_v6 = Some(addr);
81        self.supports_v6 = true;
82        self
83    }
84
85    /// Query for IP address using raw UDP
86    async fn query(&self, version: IpVersion) -> Result<IpAddr, ProviderError> {
87        // Pick resolver: use IPv6 resolver when requesting v6 if available
88        let resolver = match version {
89            IpVersion::V6 => self.resolver_addr_v6.unwrap_or(self.resolver_addr),
90            _ => self.resolver_addr,
91        };
92
93        debug!(
94            provider = %self.name,
95            domain = %self.query_domain,
96            resolver = %resolver,
97            "querying DNS"
98        );
99
100        // Determine record type based on version and configured type
101        let record_type = match self.record_type {
102            DnsRecordType::Address => match version {
103                IpVersion::V6 => RecordType::Aaaa,
104                _ => RecordType::A,
105            },
106            DnsRecordType::Txt => RecordType::Txt,
107        };
108
109        // Build query packet
110        let query = build_query(&self.query_domain, record_type, self.dns_class)
111            .map_err(|e| ProviderError::new(&self.name, e))?;
112
113        // Create UDP socket
114        let bind_addr = if resolver.is_ipv6() {
115            "[::]:0"
116        } else {
117            "0.0.0.0:0"
118        };
119        let socket = UdpSocket::bind(bind_addr)
120            .await
121            .map_err(|e| ProviderError::new(&self.name, e))?;
122
123        // Send query
124        socket
125            .send_to(&query, resolver)
126            .await
127            .map_err(|e| ProviderError::new(&self.name, e))?;
128
129        // Receive response
130        let mut buf = [0u8; 1232]; // RFC 8020 recommended safe UDP DNS size
131        let len = socket
132            .recv(&mut buf)
133            .await
134            .map_err(|e| ProviderError::new(&self.name, e))?;
135
136        // Parse response
137        let results = parse_response(&buf[..len], record_type)
138            .map_err(|e| ProviderError::message(&self.name, e))?;
139
140        // Extract IP from results
141        for result in results {
142            // Handle potential CIDR notation or prefixed text
143            for part in result.split_whitespace() {
144                let ip_str = part.split('/').next().unwrap_or(part);
145                if let Ok(ip) = IpAddr::from_str(ip_str) {
146                    // Filter by version if needed
147                    match version {
148                        IpVersion::V4 if ip.is_ipv4() => return Ok(ip),
149                        IpVersion::V6 if ip.is_ipv6() => return Ok(ip),
150                        IpVersion::Any => return Ok(ip),
151                        _ => continue,
152                    }
153                }
154            }
155        }
156
157        Err(ProviderError::message(
158            &self.name,
159            "no valid IP in DNS response",
160        ))
161    }
162}
163
164#[async_trait]
165impl Provider for DnsProvider {
166    fn name(&self) -> &str {
167        &self.name
168    }
169
170    fn protocol(&self) -> Protocol {
171        Protocol::Dns
172    }
173
174    fn supports_v4(&self) -> bool {
175        self.supports_v4
176    }
177
178    fn supports_v6(&self) -> bool {
179        self.supports_v6
180    }
181
182    async fn get_ip(&self, version: IpVersion) -> Result<IpAddr, ProviderError> {
183        self.query(version).await
184    }
185}