use std::net::SocketAddr;
use std::sync::Arc;
use nodedb_types::NodeId;
use tokio::sync::watch;
use tokio::task::JoinHandle;
use super::config::SwimConfig;
use super::detector::{FailureDetector, ProbeScheduler, Transport};
use super::dissemination::DisseminationQueue;
use super::error::SwimError;
use super::incarnation::Incarnation;
use super::member::MemberState;
use super::member::record::MemberUpdate;
use super::membership::MembershipList;
use super::subscriber::MembershipSubscriber;
pub struct SwimHandle {
detector: Arc<FailureDetector>,
membership: Arc<MembershipList>,
shutdown_tx: watch::Sender<bool>,
join: JoinHandle<()>,
}
impl SwimHandle {
pub fn detector(&self) -> &Arc<FailureDetector> {
&self.detector
}
pub fn membership(&self) -> &Arc<MembershipList> {
&self.membership
}
pub fn dissemination(&self) -> &Arc<DisseminationQueue> {
self.detector.dissemination()
}
pub async fn shutdown(self) {
let _ = self.shutdown_tx.send(true);
let _ = self.join.await;
}
}
pub async fn spawn(
cfg: SwimConfig,
local_id: NodeId,
local_addr: SocketAddr,
seeds: Vec<SocketAddr>,
transport: Arc<dyn Transport>,
) -> Result<SwimHandle, SwimError> {
spawn_with_subscribers(cfg, local_id, local_addr, seeds, transport, Vec::new()).await
}
pub async fn spawn_with_subscribers(
cfg: SwimConfig,
local_id: NodeId,
local_addr: SocketAddr,
seeds: Vec<SocketAddr>,
transport: Arc<dyn Transport>,
subscribers: Vec<Arc<dyn MembershipSubscriber>>,
) -> Result<SwimHandle, SwimError> {
cfg.validate()?;
let membership = Arc::new(MembershipList::new_local(
local_id.clone(),
local_addr,
cfg.initial_incarnation,
));
for seed_addr in &seeds {
if *seed_addr == local_addr {
continue;
}
membership.apply(&MemberUpdate {
node_id: NodeId::from_validated(format!("seed:{seed_addr}")),
addr: seed_addr.to_string(),
state: MemberState::Alive,
incarnation: Incarnation::ZERO,
});
}
let initial_inc = cfg.initial_incarnation;
let detector = Arc::new(FailureDetector::with_subscribers(
cfg,
Arc::clone(&membership),
transport,
ProbeScheduler::new(),
subscribers,
));
detector.dissemination().enqueue(MemberUpdate {
node_id: local_id.clone(),
addr: local_addr.to_string(),
state: MemberState::Alive,
incarnation: initial_inc,
});
let (shutdown_tx, shutdown_rx) = watch::channel(false);
let join = tokio::spawn({
let detector = Arc::clone(&detector);
async move { detector.run(shutdown_rx).await }
});
Ok(SwimHandle {
detector,
membership,
shutdown_tx,
join,
})
}
#[cfg(test)]
mod tests {
use super::*;
use crate::swim::detector::TransportFabric;
use std::net::{IpAddr, Ipv4Addr};
use std::time::Duration;
fn addr(p: u16) -> SocketAddr {
SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), p)
}
fn cfg() -> SwimConfig {
SwimConfig {
probe_interval: Duration::from_millis(100),
probe_timeout: Duration::from_millis(40),
indirect_probes: 2,
suspicion_mult: 4,
min_suspicion: Duration::from_millis(500),
initial_incarnation: Incarnation::ZERO,
max_piggyback: 6,
fanout_lambda: 3,
}
}
#[tokio::test]
async fn spawn_solo_cluster_has_only_local() {
let fab = TransportFabric::new();
let transport: Arc<dyn Transport> = Arc::new(fab.bind(addr(7100)).await);
let handle = spawn(
cfg(),
NodeId::try_new("a").expect("test fixture"),
addr(7100),
vec![],
transport,
)
.await
.expect("spawn");
assert_eq!(handle.membership().len(), 1);
assert!(handle.membership().is_solo());
handle.shutdown().await;
}
#[tokio::test]
async fn spawn_seeds_populates_membership() {
let fab = TransportFabric::new();
let transport: Arc<dyn Transport> = Arc::new(fab.bind(addr(7110)).await);
let handle = spawn(
cfg(),
NodeId::try_new("a").expect("test fixture"),
addr(7110),
vec![addr(7111), addr(7112)],
transport,
)
.await
.expect("spawn");
assert_eq!(handle.membership().len(), 3);
handle.shutdown().await;
}
#[tokio::test]
async fn spawn_skips_local_addr_in_seeds() {
let fab = TransportFabric::new();
let transport: Arc<dyn Transport> = Arc::new(fab.bind(addr(7120)).await);
let handle = spawn(
cfg(),
NodeId::try_new("a").expect("test fixture"),
addr(7120),
vec![addr(7120), addr(7121)],
transport,
)
.await
.expect("spawn");
assert_eq!(handle.membership().len(), 2);
handle.shutdown().await;
}
#[tokio::test]
async fn invalid_config_rejected_before_task_spawned() {
let fab = TransportFabric::new();
let transport: Arc<dyn Transport> = Arc::new(fab.bind(addr(7130)).await);
let mut bad = cfg();
bad.probe_timeout = bad.probe_interval; let res = spawn(
bad,
NodeId::try_new("a").expect("test fixture"),
addr(7130),
vec![],
transport,
)
.await;
match res {
Err(SwimError::InvalidConfig { .. }) => {}
Err(other) => panic!("expected InvalidConfig, got {other:?}"),
Ok(_) => panic!("expected InvalidConfig error"),
}
}
#[tokio::test]
async fn shutdown_joins_promptly() {
let fab = TransportFabric::new();
let transport: Arc<dyn Transport> = Arc::new(fab.bind(addr(7140)).await);
let handle = spawn(
cfg(),
NodeId::try_new("a").expect("test fixture"),
addr(7140),
vec![],
transport,
)
.await
.expect("spawn");
let start = std::time::Instant::now();
tokio::time::timeout(Duration::from_millis(500), handle.shutdown())
.await
.expect("shutdown did not join within budget");
assert!(start.elapsed() < Duration::from_millis(500));
}
}