use async_trait::async_trait;
use crate::error::ClusterError;
use crate::hash_ring::SiloAddress;
/// A source of cluster membership.
///
/// Implementors report which silos currently belong to the cluster. The
/// `Send + Sync + 'static` bounds let a provider be shared freely, for example
/// behind an `Arc` across async tasks.
#[async_trait]
pub trait MembershipProvider: Send + Sync + 'static {
    /// Fetches the current member list.
    ///
    /// # Errors
    ///
    /// Returns a [`ClusterError`] when the underlying membership source cannot
    /// be queried.
    async fn get_members(&self) -> Result<Vec<SiloAddress>, ClusterError>;
}
/// Membership provider backed by a fixed seed list.
///
/// The seeds are parsed once, when the provider is constructed. Every call to
/// `get_members` returns that same list.
#[derive(Debug, Clone)]
pub struct StaticSeedProvider {
    /// Seeds that parsed successfully, in input order.
    seeds: Vec<SiloAddress>,
}
impl StaticSeedProvider {
pub fn new(addresses: Vec<impl AsRef<str>>) -> Self {
let seeds = addresses
.iter()
.filter_map(|addr| {
let s = addr.as_ref();
let Some((host, port_str)) = s.rsplit_once(':') else {
tracing::warn!(
target: "discovery",
addr = s,
"ignoring invalid seed address: missing ':' separator"
);
return None;
};
let port = match port_str.parse::<u16>() {
Ok(p) => p,
Err(e) => {
tracing::warn!(
target: "discovery",
addr = s,
error = %e,
"ignoring invalid seed address: port not a u16"
);
return None;
}
};
Some(SiloAddress {
host: host.to_string(),
port,
silo_id: s.to_string(),
})
})
.collect();
Self { seeds }
}
}
#[async_trait]
impl MembershipProvider for StaticSeedProvider {
    /// Returns a copy of the seed list parsed at construction.
    ///
    /// This implementation never returns an error.
    async fn get_members(&self) -> Result<Vec<SiloAddress>, ClusterError> {
        let snapshot = self.seeds.to_vec();
        Ok(snapshot)
    }
}
/// Membership provider that discovers members through a DNS lookup.
///
/// `hostname` and `port` are combined into `hostname:port` and resolved on
/// every `get_members` call. Each resolved socket address becomes a member.
#[derive(Debug, Clone)]
pub struct DnsMembershipProvider {
    /// Name to resolve.
    hostname: String,
    /// Port combined with `hostname` for the lookup.
    port: u16,
}
impl DnsMembershipProvider {
    /// Creates a provider that will resolve `hostname:port`.
    ///
    /// No lookup happens here. Resolution is deferred to each `get_members`
    /// call.
    pub fn new(hostname: impl Into<String>, port: u16) -> Self {
        let hostname: String = hostname.into();
        Self { hostname, port }
    }
}
#[async_trait]
impl MembershipProvider for DnsMembershipProvider {
    /// Resolves `hostname:port` and returns one member per distinct resolved
    /// socket address.
    ///
    /// Members come back in resolver order. A socket address the resolver
    /// repeats is dropped after its first occurrence. Otherwise every copy
    /// would become a separate, identical member.
    ///
    /// Each member's `host` is the resolved IP and its `silo_id` is
    /// `"{ip}:{port}"`.
    ///
    /// # Errors
    ///
    /// Returns [`ClusterError::Transport`] if the DNS lookup fails.
    async fn get_members(&self) -> Result<Vec<SiloAddress>, ClusterError> {
        let resolved = tokio::net::lookup_host(format!("{}:{}", self.hostname, self.port))
            .await
            .map_err(|e| ClusterError::Transport(format!("DNS lookup failed: {}", e)))?;
        // Dedup on (ip, port): these are exactly the fields the SiloAddress
        // below is built from. `insert` returns false for repeats.
        let mut seen = std::collections::HashSet::new();
        let members = resolved
            .filter(|addr| seen.insert((addr.ip(), addr.port())))
            .map(|addr| SiloAddress {
                host: addr.ip().to_string(),
                port: addr.port(),
                silo_id: format!("{}:{}", addr.ip(), addr.port()),
            })
            .collect();
        Ok(members)
    }
}
/// Membership provider that reads passing service instances from Consul's
/// HTTP health API.
///
/// Only available with the `consul` feature.
#[cfg(feature = "consul")]
#[derive(Debug, Clone)]
pub struct ConsulMembershipProvider {
    /// Base URL of the Consul HTTP API. A trailing `/` is tolerated.
    consul_url: String,
    /// Name of the Consul service whose instances are the cluster members.
    service_name: String,
}
#[cfg(feature = "consul")]
impl ConsulMembershipProvider {
    /// Creates a provider that queries the Consul HTTP API at `consul_url`
    /// for instances of `service_name`.
    ///
    /// No network I/O is performed here. Consul is only contacted from
    /// `get_members`.
    pub fn new(consul_url: impl Into<String>, service_name: impl Into<String>) -> Self {
        let consul_url: String = consul_url.into();
        let service_name: String = service_name.into();
        Self {
            consul_url,
            service_name,
        }
    }
}
#[cfg(feature = "consul")]
#[async_trait]
impl MembershipProvider for ConsulMembershipProvider {
    /// Queries Consul for passing instances of the configured service.
    ///
    /// Sends `GET {consul_url}/v1/health/service/{service_name}?passing=true`.
    /// Each health entry is converted with `silo_from_consul_entry`. Entries
    /// without a usable address or port are skipped.
    ///
    /// # Errors
    ///
    /// Returns [`ClusterError::Transport`] in any of these cases:
    ///
    /// - the request cannot be sent,
    /// - Consul answers with a non-success status, or
    /// - the body cannot be parsed as a JSON array.
    async fn get_members(&self) -> Result<Vec<SiloAddress>, ClusterError> {
        let url = format!(
            "{}/v1/health/service/{}?passing=true",
            self.consul_url.trim_end_matches('/'),
            self.service_name,
        );
        // NOTE(review): this builds a new client (and connection pool) on every
        // call. Consider storing one on the provider.
        let client = reqwest::Client::new();
        let response = client
            .get(&url)
            .send()
            .await
            .map_err(|e| ClusterError::Transport(format!("Consul request failed: {}", e)))?;
        let status = response.status();
        if !status.is_success() {
            return Err(ClusterError::Transport(format!(
                "Consul returned status {}",
                status
            )));
        }
        let entries: Vec<serde_json::Value> = response
            .json()
            .await
            .map_err(|e| ClusterError::Transport(format!("Consul response parse failed: {}", e)))?;
        Ok(entries.iter().filter_map(silo_from_consul_entry).collect())
    }
}

/// Converts one Consul health entry into a [`SiloAddress`].
///
/// The host is `Service.Address`, or `Node.Address` when the service address
/// is missing or empty. Returns `None` in any of these cases:
///
/// - neither address is a non-empty string,
/// - `Service.Port` is absent, or
/// - `Service.Port` does not fit in a `u16`.
///
/// The `silo_id` is `Service.ID` when that is non-empty, otherwise
/// `"{host}:{port}"`.
#[cfg(feature = "consul")]
fn silo_from_consul_entry(entry: &serde_json::Value) -> Option<SiloAddress> {
    fn non_empty_str(value: &serde_json::Value) -> Option<&str> {
        value.as_str().filter(|s| !s.is_empty())
    }

    let address = non_empty_str(&entry["Service"]["Address"])
        .or_else(|| non_empty_str(&entry["Node"]["Address"]))?;
    // Checked conversion: an `as u16` cast would silently wrap out-of-range
    // values (e.g. 70000 -> 4464) into a wrong, possibly live, port.
    let port = u16::try_from(entry["Service"]["Port"].as_u64()?).ok()?;
    let silo_id = match non_empty_str(&entry["Service"]["ID"]) {
        Some(id) => id.to_string(),
        None => format!("{}:{}", address, port),
    };
    Some(SiloAddress {
        host: address.to_string(),
        port,
        silo_id,
    })
}
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn static_seed_provider_parses_addresses() {
        let provider = StaticSeedProvider::new(vec!["127.0.0.1:5001", "10.0.0.2:5002"]);
        let members = provider.get_members().await.unwrap();
        assert_eq!(members.len(), 2);
        assert_eq!(members[0].host, "127.0.0.1");
        assert_eq!(members[0].port, 5001);
        assert_eq!(members[1].host, "10.0.0.2");
        assert_eq!(members[1].port, 5002);
    }

    #[tokio::test]
    async fn static_seed_provider_skips_invalid() {
        let provider = StaticSeedProvider::new(vec!["valid:1234", "no-port", "also:bad"]);
        let members = provider.get_members().await.unwrap();
        assert_eq!(members.len(), 1);
        assert_eq!(members[0].host, "valid");
        assert_eq!(members[0].port, 1234);
    }

    #[tokio::test]
    async fn static_seed_provider_empty() {
        let provider = StaticSeedProvider::new(Vec::<&str>::new());
        let members = provider.get_members().await.unwrap();
        assert!(members.is_empty());
    }

    // The raw seed string doubles as the member's identity.
    #[tokio::test]
    async fn static_seed_provider_uses_entry_as_silo_id() {
        let provider = StaticSeedProvider::new(vec!["10.0.0.7:9000"]);
        let members = provider.get_members().await.unwrap();
        assert_eq!(members.len(), 1);
        assert_eq!(members[0].silo_id, "10.0.0.7:9000");
    }

    // Splitting happens on the last ':', so inner colons stay in the host.
    #[tokio::test]
    async fn static_seed_provider_splits_on_last_colon() {
        let provider = StaticSeedProvider::new(vec!["::1:5000"]);
        let members = provider.get_members().await.unwrap();
        assert_eq!(members.len(), 1);
        assert_eq!(members[0].host, "::1");
        assert_eq!(members[0].port, 5000);
    }

    #[tokio::test]
    async fn static_seed_provider_rejects_out_of_range_port() {
        let provider = StaticSeedProvider::new(vec!["host:70000", "host:65535"]);
        let members = provider.get_members().await.unwrap();
        assert_eq!(members.len(), 1);
        assert_eq!(members[0].port, 65535);
    }
}