#[cfg(test)]
mod tests;
use std::net::ToSocketAddrs;
use std::sync::{Arc, RwLock};
use std::time::Duration;
/// Where a `BackendPool` obtains its list of backend addresses.
pub enum DiscoverySource {
    /// A fixed list of backend addresses.
    Static(Vec<String>),
    /// Environment variables named `{prefix}_0`, `{prefix}_1`, … read in order,
    /// stopping at the first index that is not set.
    EnvPrefix(String),
    /// Path to a text file with one backend per line. Surrounding whitespace is
    /// trimmed; blank lines and lines starting with `#` are skipped.
    File(String),
    /// A hostname resolved via `ToSocketAddrs`; every returned address is
    /// combined with `port`.
    Dns { hostname: String, port: u16 },
}

impl DiscoverySource {
    /// Produces the current list of backend addresses for this source.
    ///
    /// Resolution is best-effort and never fails: an unreadable file or a failed
    /// lookup is reported on stderr and yields an empty list.
    fn resolve(&self) -> Vec<String> {
        match self {
            DiscoverySource::Static(v) => v.clone(),
            DiscoverySource::EnvPrefix(prefix) => {
                let mut backends = Vec::new();
                for i in 0usize.. {
                    let key = format!("{}_{}", prefix, i);
                    match std::env::var(&key) {
                        Ok(val) => backends.push(val),
                        Err(std::env::VarError::NotPresent) => break,
                        Err(std::env::VarError::NotUnicode(_)) => {
                            // Still stop here so the indices stay contiguous, but say why
                            // instead of truncating the list as if the variable were unset.
                            eprintln!(
                                "service_discovery: environment variable {} is not valid Unicode; ignoring it and any later indices",
                                key
                            );
                            break;
                        }
                    }
                }
                backends
            }
            DiscoverySource::File(path) => match std::fs::read_to_string(path) {
                Ok(contents) => contents
                    .lines()
                    .map(str::trim)
                    .filter(|line| !line.is_empty() && !line.starts_with('#'))
                    .map(str::to_string)
                    .collect(),
                Err(e) => {
                    eprintln!("service_discovery: cannot read backend file {:?}: {}", path, e);
                    Vec::new()
                }
            },
            DiscoverySource::Dns { hostname, port } => {
                let addr_str = format!("{}:{}", hostname, port);
                match addr_str.to_socket_addrs() {
                    // SocketAddr's Display brackets IPv6 addresses ("[::1]:8080"), keeping
                    // every entry a valid host:port; "{ip}:{port}" would give "::1:8080".
                    Ok(addrs) => addrs.map(|sa| sa.to_string()).collect(),
                    Err(e) => {
                        eprintln!("service_discovery: DNS lookup for {} failed: {}", addr_str, e);
                        Vec::new()
                    }
                }
            }
        }
    }
}
/// A shared, thread-safe list of backend addresses kept up to date from a
/// `DiscoverySource`.
///
/// Cloning is cheap: clones share the same backend list and source through `Arc`.
#[derive(Clone)]
pub struct BackendPool {
    /// Current backend addresses, replaced wholesale by `refresh` and `update`.
    backends: Arc<RwLock<Vec<String>>>,
    /// Where `refresh` reads backends from.
    source: Arc<DiscoverySource>,
    /// Seconds between background refreshes once `start` has been called.
    poll_interval_secs: u64,
}

impl BackendPool {
    /// Smallest polling interval (in seconds) honored by `start`. A zero interval
    /// would make the refresh thread spin without pausing.
    const MIN_POLL_INTERVAL_SECS: u64 = 1;

    /// Creates a pool with an empty backend list and a 30-second poll interval.
    fn new(source: DiscoverySource) -> Self {
        Self {
            backends: Arc::new(RwLock::new(Vec::new())),
            source: Arc::new(source),
            poll_interval_secs: 30,
        }
    }

    /// Creates a pool over a fixed backend list, populated immediately.
    ///
    /// `start` does nothing for static pools; `update` can still replace the list.
    pub fn r#static(backends: Vec<String>) -> Self {
        let pool = Self::new(DiscoverySource::Static(backends));
        // Static sources are never polled, so fill the list up front.
        pool.refresh();
        pool
    }

    /// Creates a pool reading `{prefix}_0`, `{prefix}_1`, … from the environment.
    /// The list stays empty until `start` or `refresh` is called.
    pub fn env_prefix(prefix: impl Into<String>) -> Self {
        Self::new(DiscoverySource::EnvPrefix(prefix.into()))
    }

    /// Creates a pool reading one backend per line from the file at `path`.
    /// The list stays empty until `start` or `refresh` is called.
    pub fn file(path: impl Into<String>) -> Self {
        Self::new(DiscoverySource::File(path.into()))
    }

    /// Creates a pool that resolves `hostname` and pairs each address with `port`.
    /// The list stays empty until `start` or `refresh` is called.
    pub fn dns(hostname: impl Into<String>, port: u16) -> Self {
        Self::new(DiscoverySource::Dns { hostname: hostname.into(), port })
    }

    /// Sets the number of seconds between background refreshes (builder style).
    ///
    /// Values below one second are raised to one second when `start` spawns the poller.
    pub fn poll_interval_secs(mut self, secs: u64) -> Self {
        self.poll_interval_secs = secs;
        self
    }

    /// Refreshes once synchronously, then spawns a detached thread that refreshes
    /// every poll interval for as long as the process runs.
    ///
    /// Does nothing for static pools. Each call spawns another polling thread.
    ///
    /// # Panics
    ///
    /// Panics if the operating system refuses to create the polling thread.
    pub fn start(&self) {
        if matches!(self.source.as_ref(), DiscoverySource::Static(_)) {
            return;
        }
        self.refresh();
        let pool = self.clone();
        // A zero interval would turn the poller into a busy loop that keeps
        // re-reading the file or re-resolving DNS, so enforce a floor.
        let interval =
            Duration::from_secs(self.poll_interval_secs.max(Self::MIN_POLL_INTERVAL_SECS));
        std::thread::Builder::new()
            .name("service-discovery".to_string())
            .spawn(move || loop {
                std::thread::sleep(interval);
                pool.refresh();
            })
            .expect("service_discovery: failed to spawn refresh thread");
    }

    /// Returns a snapshot of the current backend addresses.
    pub fn backends(&self) -> Vec<String> {
        self.backends.read().unwrap().clone()
    }

    /// Replaces the current backend list. The next `refresh` (manual or from the
    /// polling thread) overwrites it again.
    pub fn update(&self, backends: Vec<String>) {
        *self.backends.write().unwrap() = backends;
    }

    /// Re-resolves the source and replaces the current list with the result.
    ///
    /// Resolution runs before the write lock is taken, so readers are not blocked
    /// during slow file or DNS access.
    // NOTE(review): resolve() returns an empty list on read/lookup failure, so a
    // transient error clears every backend until the next successful refresh;
    // consider keeping the last good list instead.
    pub fn refresh(&self) {
        let resolved = self.source.resolve();
        *self.backends.write().unwrap() = resolved;
    }
}