1use std::net::{IpAddr, SocketAddr};
13use std::sync::Arc;
14use std::time::{Duration, Instant};
15
16use hickory_resolver::config::{NameServerConfigGroup, ResolverConfig, ResolverOpts};
17use hickory_resolver::TokioAsyncResolver;
18use tokio::sync::RwLock;
19use tracing::{debug, warn};
20
21use crate::error::TunnelError;
22
/// SRV record name that is looked up to discover tunnel edge endpoints.
pub const SRV_NAME: &str = "_v2-origintunneld._tcp.argotunnel.com";

/// TLS server name handed to the DNS-over-TLS fallback resolver.
pub const DOT_SERVER_NAME: &str = "cloudflare-dns.com";

/// `ip:port` of the DNS-over-TLS server used when the primary resolver fails.
pub const DOT_SERVER_ADDR: &str = "1.1.1.1:853";

/// Lifetime given to cached discovery results by [`EdgeRegistry::new`]: one hour.
pub const DEFAULT_CACHE_TTL: Duration = Duration::from_secs(60 * 60);
34
/// Restricts which IP address families discovery is allowed to return.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub enum IpVersionFilter {
    /// Accept both IPv4 and IPv6 edges.
    #[default]
    Auto,
    /// Accept only IPv4 edges.
    V4Only,
    /// Accept only IPv6 edges.
    V6Only,
}

/// Address family of a discovered edge.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum EdgeIpVersion {
    V4,
    V6,
}

/// One resolved edge endpoint: an IP, the port taken from the SRV record, and
/// the address family of the IP.
///
/// Implements `PartialEq`/`Eq`/`Hash` so edges can be compared directly and
/// collected into sets.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct EdgeAddr {
    pub ip: IpAddr,
    pub port: u16,
    /// Family tag. Constructors in this module derive it from `ip`; because
    /// the field is public it is not guaranteed to agree with `ip` for values
    /// built by hand.
    pub version: EdgeIpVersion,
}

impl EdgeAddr {
    /// Returns the endpoint as a [`SocketAddr`] ready to connect to.
    pub fn socket(&self) -> SocketAddr {
        SocketAddr::new(self.ip, self.port)
    }

    /// Builds an edge from an IP and port, tagging it with the IP's family.
    fn from_ip(ip: IpAddr, port: u16) -> Self {
        let version = match ip {
            IpAddr::V4(_) => EdgeIpVersion::V4,
            IpAddr::V6(_) => EdgeIpVersion::V6,
        };
        Self { ip, port, version }
    }

    /// Returns `true` when this edge's family is allowed by `filter`.
    fn matches(&self, filter: IpVersionFilter) -> bool {
        match filter {
            IpVersionFilter::Auto => true,
            IpVersionFilter::V4Only => self.version == EdgeIpVersion::V4,
            IpVersionFilter::V6Only => self.version == EdgeIpVersion::V6,
        }
    }
}
79
80pub async fn discover(filter: IpVersionFilter) -> Result<Vec<EdgeAddr>, TunnelError> {
83 let primary = TokioAsyncResolver::tokio(ResolverConfig::default(), ResolverOpts::default());
84 match resolve_srv(&primary, filter).await {
85 Ok(edges) if !edges.is_empty() => return Ok(edges),
86 Ok(_) => warn!("system resolver returned zero edges; falling back to DoT"),
87 Err(e) => warn!(error = %e, "system resolver SRV failed; falling back to DoT"),
88 }
89
90 let dot = build_dot_resolver()?;
91 let edges = resolve_srv(&dot, filter).await?;
92 if edges.is_empty() {
93 return Err(TunnelError::Discovery(format!(
94 "DoT fallback also returned no edges for {SRV_NAME}"
95 )));
96 }
97 Ok(edges)
98}
99
100#[derive(Clone)]
104pub struct EdgeRegistry {
105 inner: Arc<RwLock<Option<Cached>>>,
106 ttl: Duration,
107}
108
109struct Cached {
110 edges: Vec<EdgeAddr>,
111 expires_at: Instant,
112 filter: IpVersionFilter,
113}
114
115impl EdgeRegistry {
116 pub fn new() -> Self {
117 Self::with_ttl(DEFAULT_CACHE_TTL)
118 }
119
120 pub fn with_ttl(ttl: Duration) -> Self {
121 Self {
122 inner: Arc::new(RwLock::new(None)),
123 ttl,
124 }
125 }
126
127 pub async fn get_or_refresh(
128 &self,
129 filter: IpVersionFilter,
130 ) -> Result<Vec<EdgeAddr>, TunnelError> {
131 {
132 let guard = self.inner.read().await;
133 if let Some(c) = guard.as_ref() {
134 if c.filter == filter && c.expires_at > Instant::now() {
135 debug!(count = c.edges.len(), "edge cache hit");
136 return Ok(shuffled(&c.edges));
137 }
138 }
139 }
140 let edges = discover(filter).await?;
141 let mut guard = self.inner.write().await;
142 *guard = Some(Cached {
143 edges: edges.clone(),
144 expires_at: Instant::now() + self.ttl,
145 filter,
146 });
147 Ok(shuffled(&edges))
148 }
149}
150
151impl Default for EdgeRegistry {
152 fn default() -> Self {
153 Self::new()
154 }
155}
156
157fn build_dot_resolver() -> Result<TokioAsyncResolver, TunnelError> {
160 let addr: SocketAddr = DOT_SERVER_ADDR
161 .parse()
162 .map_err(|e| TunnelError::Discovery(format!("DoT addr parse: {e}")))?;
163 let ns = NameServerConfigGroup::from_ips_tls(
164 &[addr.ip()],
165 addr.port(),
166 DOT_SERVER_NAME.into(),
167 true,
168 );
169 let cfg = ResolverConfig::from_parts(None, vec![], ns);
170 let mut opts = ResolverOpts::default();
171 opts.timeout = Duration::from_secs(15);
172 Ok(TokioAsyncResolver::tokio(cfg, opts))
173}
174
175async fn resolve_srv(
176 resolver: &TokioAsyncResolver,
177 filter: IpVersionFilter,
178) -> Result<Vec<EdgeAddr>, TunnelError> {
179 let srv = resolver
180 .srv_lookup(SRV_NAME)
181 .await
182 .map_err(|e| TunnelError::Discovery(format!("SRV {SRV_NAME}: {e}")))?;
183
184 let mut edges = Vec::new();
185 for rec in srv.iter() {
186 let target = rec.target().to_utf8();
187 let target = target.trim_end_matches('.');
188 let port = rec.port();
189 match resolver.lookup_ip(target).await {
190 Ok(ips) => {
191 for ip in ips.iter() {
192 let edge = EdgeAddr::from_ip(ip, port);
193 if edge.matches(filter) {
194 edges.push(edge);
195 }
196 }
197 }
198 Err(e) => warn!(target, error = %e, "IP resolution failed for SRV target"),
199 }
200 }
201 Ok(edges)
202}
203
204fn shuffled(input: &[EdgeAddr]) -> Vec<EdgeAddr> {
205 use std::collections::hash_map::DefaultHasher;
206 use std::hash::{Hash, Hasher};
207 let mut h = DefaultHasher::new();
208 Instant::now().elapsed().as_nanos().hash(&mut h);
209 let n = input.len().max(1);
210 let offset = (h.finish() as usize) % n;
211 let mut out = Vec::with_capacity(input.len());
212 out.extend_from_slice(&input[offset..]);
213 out.extend_from_slice(&input[..offset]);
214 out
215}
216
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::Ipv4Addr;

    /// Builds an IPv4 edge at `198.41.192.<ip>` on port 7844.
    fn fake(ip: u8) -> EdgeAddr {
        EdgeAddr::from_ip(IpAddr::V4(Ipv4Addr::from([198, 41, 192, ip])), 7844)
    }

    /// An IPv4 edge passes `Auto` and `V4Only` but not `V6Only`.
    #[test]
    fn filter_matches_auto() {
        let edge = fake(1);
        assert!(edge.matches(IpVersionFilter::Auto));
        assert!(edge.matches(IpVersionFilter::V4Only));
        assert!(!edge.matches(IpVersionFilter::V6Only));
    }

    /// Rotating must neither drop nor duplicate any address.
    #[test]
    fn shuffle_preserves_set() {
        let sorted_ips = |edges: &[EdgeAddr]| {
            let mut ips: Vec<IpAddr> = edges.iter().map(|e| e.ip).collect();
            ips.sort();
            ips
        };

        let input: Vec<EdgeAddr> = (0..8).map(fake).collect();
        let out = shuffled(&input);

        assert_eq!(out.len(), input.len());
        assert_eq!(sorted_ips(&input), sorted_ips(&out));
    }

    /// A pre-seeded, unexpired entry is returned without running discovery.
    #[tokio::test]
    async fn registry_serves_cached_within_ttl() {
        let ttl = Duration::from_secs(60);
        let reg = EdgeRegistry::with_ttl(ttl);
        {
            // Seed the cache directly; the guard is released at the end of this scope.
            let mut slot = reg.inner.write().await;
            *slot = Some(Cached {
                edges: vec![fake(7), fake(8)],
                expires_at: Instant::now() + ttl,
                filter: IpVersionFilter::Auto,
            });
        }

        let got = reg.get_or_refresh(IpVersionFilter::Auto).await.unwrap();
        assert_eq!(got.len(), 2);
    }

    /// Performs real DNS lookups: ignored by default and additionally gated on
    /// the `CFQT_LIVE_TESTS` environment variable.
    #[tokio::test]
    #[ignore]
    async fn live_discover_returns_edges() {
        if std::env::var_os("CFQT_LIVE_TESTS").is_some() {
            let edges = discover(IpVersionFilter::Auto).await.unwrap();
            assert!(!edges.is_empty(), "should resolve at least one edge");
            edges.iter().for_each(|e| assert_eq!(e.port, 7844));
        } else {
            eprintln!("skip: set CFQT_LIVE_TESTS=1 to run");
        }
    }
}