Skip to main content

cloudflare_quick_tunnel/
edge.rs

//! Cloudflare edge discovery: DNS SRV
//! (`_v2-origintunneld._tcp.argotunnel.com`) with a DNS-over-TLS
//! fallback through `1.1.1.1:853`. Mirrors the semantics of
//! `cloudflared/edgediscovery/allregions/discovery.go`.
//!
//! The result is a list of `EdgeAddr`s (resolved IPs plus the port
//! advertised by each SRV record, 7844 in practice) that the caller can
//! hand to `quic_dial::dial_any`. `EdgeRegistry` re-orders the list on
//! every call so two adjacent processes don't pin the same edge, and
//! caches it in memory (1h TTL by default) so repeated reconnects don't
//! hammer DNS.
11
12use std::net::{IpAddr, SocketAddr};
13use std::sync::Arc;
14use std::time::{Duration, Instant};
15
16use hickory_resolver::config::{NameServerConfigGroup, ResolverConfig, ResolverOpts};
17use hickory_resolver::TokioAsyncResolver;
18use tokio::sync::RwLock;
19use tracing::{debug, warn};
20
21use crate::error::TunnelError;
22
/// SRV record that advertises the v2 `origintunneld` edge pool.
pub const SRV_NAME: &str = "_v2-origintunneld._tcp.argotunnel.com";

/// TLS server name used when talking to the DNS-over-TLS fallback.
pub const DOT_SERVER_NAME: &str = "cloudflare-dns.com";

/// `ip:port` of the DNS-over-TLS fallback (Cloudflare's public resolver).
pub const DOT_SERVER_ADDR: &str = "1.1.1.1:853";

/// How long `EdgeRegistry::new` keeps a discovery result before
/// re-resolving: one hour.
pub const DEFAULT_CACHE_TTL: Duration = Duration::from_secs(60 * 60);
34
/// Which address families discovery should keep.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum IpVersionFilter {
    /// Keep both IPv4 and IPv6 edges (the default).
    #[default]
    Auto,
    /// Keep IPv4 edges only.
    V4Only,
    /// Keep IPv6 edges only.
    V6Only,
}

/// Address family of a resolved [`EdgeAddr`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum EdgeIpVersion {
    /// IPv4 address.
    V4,
    /// IPv6 address.
    V6,
}
48
49#[derive(Debug, Clone, Copy)]
50pub struct EdgeAddr {
51    pub ip: IpAddr,
52    pub port: u16,
53    pub version: EdgeIpVersion,
54}
55
56impl EdgeAddr {
57    pub fn socket(&self) -> SocketAddr {
58        SocketAddr::new(self.ip, self.port)
59    }
60
61    fn from_ip(ip: IpAddr, port: u16) -> Self {
62        let version = if ip.is_ipv4() {
63            EdgeIpVersion::V4
64        } else {
65            EdgeIpVersion::V6
66        };
67        Self { ip, port, version }
68    }
69
70    fn matches(&self, filter: IpVersionFilter) -> bool {
71        matches!(
72            (filter, self.version),
73            (IpVersionFilter::Auto, _)
74                | (IpVersionFilter::V4Only, EdgeIpVersion::V4)
75                | (IpVersionFilter::V6Only, EdgeIpVersion::V6)
76        )
77    }
78}
79
80/// One-shot discovery without caching. System resolver first; on
81/// failure / empty answer, falls back to DoT via `1.1.1.1`.
82pub async fn discover(filter: IpVersionFilter) -> Result<Vec<EdgeAddr>, TunnelError> {
83    let primary = TokioAsyncResolver::tokio(ResolverConfig::default(), ResolverOpts::default());
84    match resolve_srv(&primary, filter).await {
85        Ok(edges) if !edges.is_empty() => return Ok(edges),
86        Ok(_) => warn!("system resolver returned zero edges; falling back to DoT"),
87        Err(e) => warn!(error = %e, "system resolver SRV failed; falling back to DoT"),
88    }
89
90    let dot = build_dot_resolver()?;
91    let edges = resolve_srv(&dot, filter).await?;
92    if edges.is_empty() {
93        return Err(TunnelError::Discovery(format!(
94            "DoT fallback also returned no edges for {SRV_NAME}"
95        )));
96    }
97    Ok(edges)
98}
99
100/// In-memory cache around `discover`. Re-resolves once the TTL
101/// expires; otherwise hands out a fresh shuffle of the previous
102/// result so callers see a different head edge across calls.
103#[derive(Clone)]
104pub struct EdgeRegistry {
105    inner: Arc<RwLock<Option<Cached>>>,
106    ttl: Duration,
107}
108
109struct Cached {
110    edges: Vec<EdgeAddr>,
111    expires_at: Instant,
112    filter: IpVersionFilter,
113}
114
115impl EdgeRegistry {
116    pub fn new() -> Self {
117        Self::with_ttl(DEFAULT_CACHE_TTL)
118    }
119
120    pub fn with_ttl(ttl: Duration) -> Self {
121        Self {
122            inner: Arc::new(RwLock::new(None)),
123            ttl,
124        }
125    }
126
127    pub async fn get_or_refresh(
128        &self,
129        filter: IpVersionFilter,
130    ) -> Result<Vec<EdgeAddr>, TunnelError> {
131        {
132            let guard = self.inner.read().await;
133            if let Some(c) = guard.as_ref() {
134                if c.filter == filter && c.expires_at > Instant::now() {
135                    debug!(count = c.edges.len(), "edge cache hit");
136                    return Ok(shuffled(&c.edges));
137                }
138            }
139        }
140        let edges = discover(filter).await?;
141        let mut guard = self.inner.write().await;
142        *guard = Some(Cached {
143            edges: edges.clone(),
144            expires_at: Instant::now() + self.ttl,
145            filter,
146        });
147        Ok(shuffled(&edges))
148    }
149}
150
151impl Default for EdgeRegistry {
152    fn default() -> Self {
153        Self::new()
154    }
155}
156
157// ── Internals ─────────────────────────────────────────────────────────────────
158
159fn build_dot_resolver() -> Result<TokioAsyncResolver, TunnelError> {
160    let addr: SocketAddr = DOT_SERVER_ADDR
161        .parse()
162        .map_err(|e| TunnelError::Discovery(format!("DoT addr parse: {e}")))?;
163    let ns = NameServerConfigGroup::from_ips_tls(
164        &[addr.ip()],
165        addr.port(),
166        DOT_SERVER_NAME.into(),
167        true,
168    );
169    let cfg = ResolverConfig::from_parts(None, vec![], ns);
170    let mut opts = ResolverOpts::default();
171    opts.timeout = Duration::from_secs(15);
172    Ok(TokioAsyncResolver::tokio(cfg, opts))
173}
174
175async fn resolve_srv(
176    resolver: &TokioAsyncResolver,
177    filter: IpVersionFilter,
178) -> Result<Vec<EdgeAddr>, TunnelError> {
179    let srv = resolver
180        .srv_lookup(SRV_NAME)
181        .await
182        .map_err(|e| TunnelError::Discovery(format!("SRV {SRV_NAME}: {e}")))?;
183
184    let mut edges = Vec::new();
185    for rec in srv.iter() {
186        let target = rec.target().to_utf8();
187        let target = target.trim_end_matches('.');
188        let port = rec.port();
189        match resolver.lookup_ip(target).await {
190            Ok(ips) => {
191                for ip in ips.iter() {
192                    let edge = EdgeAddr::from_ip(ip, port);
193                    if edge.matches(filter) {
194                        edges.push(edge);
195                    }
196                }
197            }
198            Err(e) => warn!(target, error = %e, "IP resolution failed for SRV target"),
199        }
200    }
201    Ok(edges)
202}
203
/// Returns a copy of `input` in a (near-)uniformly random order, using a
/// Fisher–Yates shuffle.
///
/// The seed comes from a fresh `RandomState`, which the standard library
/// initializes with random keys. Successive calls, and separate processes,
/// therefore get independent orders without an RNG crate. This is not
/// cryptographic; it only spreads load across edges.
fn shuffled<T: Clone>(input: &[T]) -> Vec<T> {
    use std::collections::hash_map::RandomState;
    use std::hash::{BuildHasher, Hasher};

    let mut out = input.to_vec();

    // Derive a 64-bit seed from a randomly keyed SipHash, then expand it
    // with SplitMix64 so each swap gets a well-mixed value.
    let mut seeder = RandomState::new().build_hasher();
    seeder.write_usize(out.len());
    let mut state = seeder.finish();
    let mut next_u64 = move || {
        state = state.wrapping_add(0x9E37_79B9_7F4A_7C15);
        let mut z = state;
        z = (z ^ (z >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9);
        z = (z ^ (z >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB);
        z ^ (z >> 31)
    };

    // Empty and single-element inputs skip the loop entirely.
    for i in (1..out.len()).rev() {
        // `j <= i < len`, so the narrowing cast cannot truncate.
        let j = (next_u64() % (i as u64 + 1)) as usize;
        out.swap(i, j);
    }
    out
}
216
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::{Ipv4Addr, Ipv6Addr};

    /// IPv4 edge in 198.41.192.0/24 on port 7844.
    fn fake(ip: u8) -> EdgeAddr {
        EdgeAddr {
            ip: IpAddr::V4(Ipv4Addr::new(198, 41, 192, ip)),
            port: 7844,
            version: EdgeIpVersion::V4,
        }
    }

    #[test]
    fn filter_matches_auto() {
        let e = fake(1);
        assert!(e.matches(IpVersionFilter::Auto));
        assert!(e.matches(IpVersionFilter::V4Only));
        assert!(!e.matches(IpVersionFilter::V6Only));
    }

    #[test]
    fn from_ip_classifies_ipv6() {
        let ip = IpAddr::V6(Ipv6Addr::LOCALHOST);
        let e = EdgeAddr::from_ip(ip, 7844);
        assert_eq!(e.version, EdgeIpVersion::V6);
        assert!(e.matches(IpVersionFilter::Auto));
        assert!(e.matches(IpVersionFilter::V6Only));
        assert!(!e.matches(IpVersionFilter::V4Only));
        assert_eq!(e.socket(), SocketAddr::new(ip, 7844));
    }

    #[test]
    fn shuffle_preserves_set() {
        let input: Vec<_> = (0..8).map(fake).collect();
        let out = shuffled(&input);
        assert_eq!(out.len(), input.len());
        let mut in_ips: Vec<_> = input.iter().map(|e| e.ip).collect();
        let mut out_ips: Vec<_> = out.iter().map(|e| e.ip).collect();
        in_ips.sort();
        out_ips.sort();
        assert_eq!(in_ips, out_ips);
    }

    #[test]
    fn shuffle_handles_empty() {
        let empty: Vec<EdgeAddr> = Vec::new();
        assert!(shuffled(&empty).is_empty());
    }

    #[tokio::test]
    async fn registry_serves_cached_within_ttl() {
        let reg = EdgeRegistry::with_ttl(Duration::from_secs(60));
        {
            let mut g = reg.inner.write().await;
            *g = Some(Cached {
                edges: vec![fake(7), fake(8)],
                expires_at: Instant::now() + Duration::from_secs(60),
                filter: IpVersionFilter::Auto,
            });
        }
        let got = reg.get_or_refresh(IpVersionFilter::Auto).await.unwrap();
        assert_eq!(got.len(), 2);
    }

    /// Real edge discovery against live DNS. It is double-gated so CI never
    /// pounds DNS: `#[ignore]` keeps it out of a plain `cargo test`, and the
    /// env check makes `--ignored` alone a no-op. To run it, use
    /// `CFQT_LIVE_TESTS=1 cargo test -- --ignored`.
    #[tokio::test]
    #[ignore = "live DNS; run with CFQT_LIVE_TESTS=1 cargo test -- --ignored"]
    async fn live_discover_returns_edges() {
        if std::env::var_os("CFQT_LIVE_TESTS").is_none() {
            eprintln!("skip: set CFQT_LIVE_TESTS=1 to run");
            return;
        }
        let edges = discover(IpVersionFilter::Auto).await.unwrap();
        assert!(!edges.is_empty(), "should resolve at least one edge");
        for e in &edges {
            assert_eq!(e.port, 7844);
        }
    }
}