Skip to main content

gatel_core/proxy/
dns_upstream.rs

//! DNS-based dynamic upstream resolution.
//!
//! Periodically resolves a DNS name and publishes the resulting addresses
//! through a shared `DynamicBackends` list, from which an `UpstreamPool`-like
//! structure can be built.  Currently supports A/AAAA records via
//! `tokio::net::lookup_host`; SRV record support is planned as a future
//! extension.
7
8use std::sync::Arc;
9use std::sync::atomic::{AtomicBool, AtomicUsize};
10use std::time::Duration;
11
12use arc_swap::ArcSwap;
13use tracing::{debug, info, warn};
14
15use super::upstream::Backend;
16
17// ---------------------------------------------------------------------------
18// Configuration types
19// ---------------------------------------------------------------------------
20
/// The kind of DNS record a dynamic upstream is resolved from.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum DnsRecordType {
    /// Address records (A and AAAA).  These carry no port, so the port comes
    /// from configuration.
    A,
    /// Service (SRV) records, which supply their own port.  Reserved for a
    /// future extension.
    SRV,
}
29
/// Settings describing one DNS-backed dynamic upstream source.
#[derive(Clone, Debug)]
pub struct DynamicUpstreamConfig {
    /// Host name that gets resolved, e.g. `"app.svc.cluster.local"`.
    pub dns_name: String,
    /// Port combined with every resolved IP address.  Applies to A/AAAA
    /// lookups; not used for SRV.
    pub port: u16,
    /// Delay between two consecutive resolutions.
    pub refresh_interval: Duration,
}
40
41// ---------------------------------------------------------------------------
42// Dynamic backend list
43// ---------------------------------------------------------------------------
44
45/// A thread-safe, atomically swappable list of backends populated by DNS
46/// resolution.  Readers get a snapshot via `load()`; the background resolver
47/// updates via `store()`.
48pub struct DynamicBackends {
49    inner: ArcSwap<Vec<Backend>>,
50}
51
52impl Default for DynamicBackends {
53    fn default() -> Self {
54        Self::new()
55    }
56}
57
58impl DynamicBackends {
59    pub fn new() -> Self {
60        Self {
61            inner: ArcSwap::from_pointee(Vec::new()),
62        }
63    }
64
65    /// Get the current list of backends.
66    pub fn load(&self) -> arc_swap::Guard<Arc<Vec<Backend>>> {
67        self.inner.load()
68    }
69
70    /// Replace the backend list atomically.
71    pub fn store(&self, backends: Vec<Backend>) {
72        self.inner.store(Arc::new(backends));
73    }
74}
75
76// ---------------------------------------------------------------------------
77// DNS resolver task
78// ---------------------------------------------------------------------------
79
80/// Handle to the background DNS resolver.  Aborting the task on drop ensures
81/// we don't leak spawned work.
82pub struct DnsResolver {
83    _task: tokio::task::JoinHandle<()>,
84    /// The dynamically-updated backend list.  Shared with whoever needs to
85    /// read the current set of upstreams.
86    pub backends: Arc<DynamicBackends>,
87}
88
89impl DnsResolver {
90    /// Start a background task that periodically resolves `config.dns_name`
91    /// and updates the shared backend list.
92    pub fn start(config: &DynamicUpstreamConfig) -> Self {
93        let backends = Arc::new(DynamicBackends::new());
94        let backends_ref = Arc::clone(&backends);
95        let dns_name = config.dns_name.clone();
96        let port = config.port;
97        let interval = config.refresh_interval;
98
99        let task = tokio::spawn(async move {
100            // Perform an initial resolution immediately.
101            resolve_and_update(&dns_name, port, &backends_ref).await;
102
103            loop {
104                tokio::time::sleep(interval).await;
105                resolve_and_update(&dns_name, port, &backends_ref).await;
106            }
107        });
108
109        Self {
110            _task: task,
111            backends,
112        }
113    }
114}
115
116impl Drop for DnsResolver {
117    fn drop(&mut self) {
118        self._task.abort();
119    }
120}
121
122/// Resolve a DNS name and update the dynamic backend list.
123async fn resolve_and_update(dns_name: &str, port: u16, backends: &DynamicBackends) {
124    let lookup_target = format!("{dns_name}:{port}");
125
126    match tokio::net::lookup_host(&lookup_target).await {
127        Ok(addrs) => {
128            let mut new_backends: Vec<Backend> = addrs
129                .map(|addr| Backend {
130                    addr: addr.to_string(),
131                    weight: 1,
132                })
133                .collect();
134
135            // Sort for deterministic ordering so we can detect changes.
136            new_backends.sort_by(|a, b| a.addr.cmp(&b.addr));
137            // Deduplicate.
138            new_backends.dedup_by(|a, b| a.addr == b.addr);
139
140            let count = new_backends.len();
141            debug!(
142                dns = dns_name,
143                resolved = count,
144                "DNS upstream resolution complete"
145            );
146
147            if new_backends.is_empty() {
148                warn!(
149                    dns = dns_name,
150                    "DNS resolution returned zero addresses; keeping previous list"
151                );
152                return;
153            }
154
155            backends.store(new_backends);
156            info!(
157                dns = dns_name,
158                backends = count,
159                "updated dynamic upstream backends"
160            );
161        }
162        Err(e) => {
163            warn!(
164                dns = dns_name,
165                error = %e,
166                "DNS resolution failed; keeping previous list"
167            );
168        }
169    };
170}
171
/// Create the per-backend tracking state for `count` dynamic backends.
///
/// Returns two vectors of length `count`: health flags, all initialised to
/// `true`, and active-connection counters, all initialised to `0`.  Intended
/// for code that assembles an `UpstreamPool`-like structure from dynamic
/// backends.
pub fn build_tracking_vecs(count: usize) -> (Vec<AtomicBool>, Vec<AtomicUsize>) {
    let mut healthy = Vec::with_capacity(count);
    healthy.resize_with(count, || AtomicBool::new(true));

    let mut active_conns = Vec::with_capacity(count);
    active_conns.resize_with(count, || AtomicUsize::new(0));

    (healthy, active_conns)
}
180
#[cfg(test)]
mod tests {
    use std::sync::atomic::Ordering;

    use super::*;

    /// Build a weight-1 backend for `addr`.
    fn backend(addr: &str) -> Backend {
        Backend {
            addr: addr.into(),
            weight: 1,
        }
    }

    /// A fresh list is empty, and a stored list is returned by `load()` in
    /// the order it was stored.
    #[test]
    fn test_dynamic_backends_store_load() {
        let list = DynamicBackends::new();
        assert!(list.load().is_empty());

        list.store(vec![backend("1.2.3.4:8080"), backend("5.6.7.8:8080")]);

        let snapshot = list.load();
        assert_eq!(snapshot.len(), 2);
        assert_eq!(snapshot[0].addr, "1.2.3.4:8080");
    }

    /// Both tracking vectors have the requested length and backends start
    /// out healthy.
    #[test]
    fn test_build_tracking_vecs() {
        let (healthy, conns) = build_tracking_vecs(3);
        assert_eq!(healthy.len(), 3);
        assert_eq!(conns.len(), 3);
        assert!(healthy[0].load(Ordering::Relaxed));
    }
}