use std::sync::Arc;
use std::sync::atomic::{AtomicBool, AtomicUsize};
use std::time::Duration;
use arc_swap::ArcSwap;
use tracing::{debug, info, warn};
use super::upstream::Backend;
/// Kind of DNS record an upstream could be discovered through.
///
/// NOTE(review): `resolve_and_update` in this module always goes through
/// `tokio::net::lookup_host` and never consults this type — confirm which
/// caller selects between the variants.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum DnsRecordType {
    /// Address record.
    A,
    /// Service record.
    SRV,
}
/// Settings for an upstream whose backend list is discovered via DNS.
#[derive(Clone, Debug)]
pub struct DynamicUpstreamConfig {
    /// Name handed to the resolver (hostname or IP literal).
    pub dns_name: String,
    /// Port attached to every resolved address.
    pub port: u16,
    /// Pause between consecutive resolutions.
    pub refresh_interval: std::time::Duration,
}
/// Backend list that can be replaced wholesale while readers hold snapshots.
///
/// Readers get an immutable `Vec<Backend>` snapshot through `load`; writers
/// publish a complete new list through `store`.
pub struct DynamicBackends {
    /// Currently published list, swapped atomically as a whole.
    inner: ArcSwap<Vec<Backend>>,
}
impl Default for DynamicBackends {
fn default() -> Self {
Self::new()
}
}
impl DynamicBackends {
    /// Creates a set that initially publishes an empty backend list.
    pub fn new() -> Self {
        let empty: Arc<Vec<Backend>> = Arc::new(Vec::new());
        Self {
            inner: ArcSwap::new(empty),
        }
    }

    /// Returns a guard over the currently published list.
    ///
    /// A later `store` does not affect a guard that was already obtained.
    pub fn load(&self) -> arc_swap::Guard<Arc<Vec<Backend>>> {
        let current = &self.inner;
        current.load()
    }

    /// Replaces the published list with `backends`.
    pub fn store(&self, backends: Vec<Backend>) {
        let snapshot = Arc::new(backends);
        self.inner.store(snapshot);
    }
}
/// Handle to a running DNS refresh task and the backend list it maintains.
///
/// Dropping the handle aborts the task (see the `Drop` impl).
pub struct DnsResolver {
    /// Background refresh loop; only held so it can be aborted on drop.
    _task: tokio::task::JoinHandle<()>,
    /// Backend list written by the refresh loop; shared with readers.
    pub backends: Arc<DynamicBackends>,
}
impl DnsResolver {
pub fn start(config: &DynamicUpstreamConfig) -> Self {
let backends = Arc::new(DynamicBackends::new());
let backends_ref = Arc::clone(&backends);
let dns_name = config.dns_name.clone();
let port = config.port;
let interval = config.refresh_interval;
let task = tokio::spawn(async move {
resolve_and_update(&dns_name, port, &backends_ref).await;
loop {
tokio::time::sleep(interval).await;
resolve_and_update(&dns_name, port, &backends_ref).await;
}
});
Self {
_task: task,
backends,
}
}
}
impl Drop for DnsResolver {
    /// Aborts the background refresh loop.
    ///
    /// Merely dropping a Tokio `JoinHandle` detaches the task. Without the
    /// explicit abort, the loop would keep resolving after the resolver is gone.
    fn drop(&mut self) {
        let refresh_task = &self._task;
        refresh_task.abort();
    }
}
/// Resolves `dns_name` once and publishes the result into `backends`.
///
/// Each resolved socket address becomes a [`Backend`] with weight 1. The list
/// is sorted by address string and deduplicated. The previously published list
/// is kept when the lookup fails or yields no addresses.
///
/// The list is also left in place (no swap, no `info!` log) when the new result
/// is identical to it. This way a stable DNS answer does not trigger a pointer
/// swap and an info-level log line on every refresh.
///
/// `dns_name` may be a hostname, an IPv4 literal, or an IPv6 literal with or
/// without surrounding brackets.
async fn resolve_and_update(dns_name: &str, port: u16, backends: &DynamicBackends) {
    // Pass (host, port) instead of formatting "host:port". The formatted form
    // cannot be split correctly when the host is a bare IPv6 literal like "::1".
    // Brackets are stripped so "[::1]", which the old form accepted, still works.
    let host = dns_name
        .strip_prefix('[')
        .and_then(|inner| inner.strip_suffix(']'))
        .unwrap_or(dns_name);
    match tokio::net::lookup_host((host, port)).await {
        Ok(addrs) => {
            let mut new_backends: Vec<Backend> = addrs
                .map(|addr| Backend {
                    addr: addr.to_string(),
                    weight: 1,
                })
                .collect();
            // Sorting makes duplicates adjacent for `dedup_by` and gives a
            // deterministic order to compare with the published list.
            new_backends.sort_unstable_by(|a, b| a.addr.cmp(&b.addr));
            new_backends.dedup_by(|a, b| a.addr == b.addr);
            let count = new_backends.len();
            debug!(
                dns = dns_name,
                resolved = count,
                "DNS upstream resolution complete"
            );
            if new_backends.is_empty() {
                warn!(
                    dns = dns_name,
                    "DNS resolution returned zero addresses; keeping previous list"
                );
                return;
            }
            // Scope the guard so it is released before any `store`.
            let unchanged = {
                let current = backends.load();
                current.len() == new_backends.len()
                    && current
                        .iter()
                        .zip(&new_backends)
                        .all(|(old, new)| old.addr == new.addr && old.weight == new.weight)
            };
            if unchanged {
                return;
            }
            backends.store(new_backends);
            info!(
                dns = dns_name,
                backends = count,
                "updated dynamic upstream backends"
            );
        }
        Err(e) => {
            warn!(
                dns = dns_name,
                error = %e,
                "DNS resolution failed; keeping previous list"
            );
        }
    }
}
/// Allocates per-backend tracking slots for `count` backends.
///
/// Returns `(healthy, active_conns)`, both of length `count`. Every `healthy`
/// flag starts as `true` and every `active_conns` counter starts at `0`.
/// Presumably index `i` corresponds to backend `i` of a loaded backend list —
/// confirm against callers.
pub fn build_tracking_vecs(count: usize) -> (Vec<AtomicBool>, Vec<AtomicUsize>) {
    let mut health_flags = Vec::with_capacity(count);
    let mut conn_counters = Vec::with_capacity(count);
    for _ in 0..count {
        health_flags.push(AtomicBool::new(true));
        conn_counters.push(AtomicUsize::new(0));
    }
    (health_flags, conn_counters)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a weight-1 backend for `addr`.
    fn backend(addr: &str) -> Backend {
        Backend {
            addr: addr.into(),
            weight: 1,
        }
    }

    #[test]
    fn test_dynamic_backends_store_load() {
        let set = DynamicBackends::new();
        assert!(set.load().is_empty());

        let fresh: Vec<Backend> = ["1.2.3.4:8080", "5.6.7.8:8080"]
            .iter()
            .copied()
            .map(backend)
            .collect();
        set.store(fresh);

        let snapshot = set.load();
        assert_eq!(snapshot.len(), 2);
        assert_eq!(snapshot[0].addr, "1.2.3.4:8080");
    }

    #[test]
    fn test_build_tracking_vecs() {
        let (flags, counters) = build_tracking_vecs(3);
        assert_eq!(flags.len(), 3);
        assert_eq!(counters.len(), 3);
        let first_healthy = flags[0].load(std::sync::atomic::Ordering::Relaxed);
        assert!(first_healthy);
    }
}