use std::sync::{Arc, RwLock};
use nodedb_types::NodeId;
use tracing::debug;
use crate::routing::RoutingTable;
use crate::swim::MemberState;
use crate::swim::subscriber::MembershipSubscriber;
/// Translates a membership [`NodeId`] into the numeric node id stored in the
/// [`RoutingTable`]; yields `None` when the id has no known numeric mapping.
pub type NodeIdResolver = Arc<dyn Fn(&NodeId) -> Option<u64> + Sync + Send>;
/// Membership subscriber that resets per-group leader hints in a shared
/// [`RoutingTable`] when the node they point at stops being healthy.
pub struct RoutingLivenessHook {
    /// Shared routing table whose leader hints this hook may reset.
    routing: Arc<RwLock<RoutingTable>>,
    /// Maps membership node ids onto the numeric ids used by `routing`.
    resolver: NodeIdResolver,
}

impl RoutingLivenessHook {
    /// Creates a hook that edits `routing`, using `resolver` to map the node
    /// ids reported by membership onto routing-table node ids.
    pub fn new(routing: Arc<RwLock<RoutingTable>>, resolver: NodeIdResolver) -> Self {
        RoutingLivenessHook {
            resolver,
            routing,
        }
    }
}
impl MembershipSubscriber for RoutingLivenessHook {
    /// Clears leader hints that point at a node which just became unhealthy.
    ///
    /// Only transitions into `Suspect`, `Dead`, or `Left` trigger any work;
    /// every other new state is ignored. The membership `node_id` is mapped to
    /// a numeric id through the configured resolver, and ids it cannot map are
    /// ignored. Every group whose cached leader equals that numeric id has its
    /// leader reset to `0`, the value this hook writes to mean "no leader hint".
    fn on_state_change(&self, node_id: &NodeId, _old: Option<MemberState>, new: MemberState) {
        if !matches!(
            new,
            MemberState::Suspect | MemberState::Dead | MemberState::Left
        ) {
            return;
        }
        let Some(numeric_id) = (self.resolver)(node_id) else {
            // Nothing to invalidate for ids the resolver cannot map.
            return;
        };
        // `0` is the "no leader" value written below. A node resolving to `0`
        // would otherwise match, rewrite, and miscount every leaderless group.
        if numeric_id == 0 {
            return;
        }

        // Scope the write guard so it is released before logging.
        let invalidated = {
            // A poisoned lock is recovered rather than propagated: clearing
            // stale hints is best-effort and must not panic the caller.
            let mut rt = self.routing.write().unwrap_or_else(|p| p.into_inner());
            // Collect first: `group_members()` borrows `rt` immutably, while
            // `set_leader` needs it mutably.
            let affected: Vec<u64> = rt
                .group_members()
                .iter()
                .filter(|(_, info)| info.leader == numeric_id)
                .map(|(gid, _)| *gid)
                .collect();
            for &gid in &affected {
                rt.set_leader(gid, 0);
            }
            affected.len()
        };

        if invalidated > 0 {
            debug!(
                ?node_id,
                ?new,
                numeric_id,
                groups_invalidated = invalidated,
                "routing liveness hook cleared leader hints"
            );
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a shared routing table with one group per `(group_id, leader)`
    /// pair and installs that leader hint on the group.
    fn rt_with_leaders(pairs: &[(u64, u64)], rf: usize) -> Arc<RwLock<RoutingTable>> {
        let leader_ids: Vec<u64> = pairs.iter().map(|&(_, leader)| leader).collect();
        let group_count = pairs.len() as u64;
        let mut table = RoutingTable::uniform(group_count, &leader_ids, rf);
        for &(gid, leader) in pairs {
            table.set_leader(gid, leader);
        }
        Arc::new(RwLock::new(table))
    }

    /// Resolver backed by a static `(name, numeric_id)` lookup list.
    fn resolver_for(map: &'static [(&'static str, u64)]) -> NodeIdResolver {
        Arc::new(move |nid: &NodeId| {
            let wanted = nid.as_str();
            map.iter()
                .find_map(|&(name, id)| (name == wanted).then_some(id))
        })
    }

    /// Reads the current leader hint of group `gid`.
    fn leader_of(rt: &RwLock<RoutingTable>, gid: u64) -> u64 {
        rt.read().unwrap().group_info(gid).unwrap().leader
    }

    #[test]
    fn dead_transition_clears_leader_for_owned_groups() {
        let table = rt_with_leaders(&[(0, 1), (1, 2), (2, 1), (3, 3)], 1);
        let resolver = resolver_for(&[("a", 1), ("b", 2), ("c", 3)]);
        let hook = RoutingLivenessHook::new(Arc::clone(&table), resolver);

        hook.on_state_change(
            &NodeId::new("a"),
            Some(MemberState::Alive),
            MemberState::Dead,
        );

        // Groups 0 and 2 were led by node 1 ("a"); the others keep their hint.
        let leaders: Vec<u64> = (0..4).map(|gid| leader_of(&table, gid)).collect();
        assert_eq!(leaders, vec![0, 2, 0, 3]);
    }

    #[test]
    fn suspect_transition_also_invalidates() {
        let table = rt_with_leaders(&[(0, 7)], 1);
        let hook = RoutingLivenessHook::new(Arc::clone(&table), resolver_for(&[("x", 7)]));

        hook.on_state_change(
            &NodeId::new("x"),
            Some(MemberState::Alive),
            MemberState::Suspect,
        );

        assert_eq!(leader_of(&table, 0), 0);
    }

    #[test]
    fn alive_transition_is_noop() {
        let table = rt_with_leaders(&[(0, 5)], 1);
        let hook = RoutingLivenessHook::new(Arc::clone(&table), resolver_for(&[("q", 5)]));

        hook.on_state_change(&NodeId::new("q"), None, MemberState::Alive);

        assert_eq!(leader_of(&table, 0), 5);
    }

    #[test]
    fn unresolved_node_id_is_ignored() {
        let table = rt_with_leaders(&[(0, 1)], 1);
        let hook = RoutingLivenessHook::new(Arc::clone(&table), resolver_for(&[("a", 1)]));

        // The resolver only knows "a", so this id maps to nothing.
        hook.on_state_change(
            &NodeId::new("seed:127.0.0.1:9000"),
            Some(MemberState::Alive),
            MemberState::Dead,
        );

        assert_eq!(leader_of(&table, 0), 1);
    }

    #[test]
    fn left_is_also_invalidating() {
        let table = rt_with_leaders(&[(0, 2)], 1);
        let hook = RoutingLivenessHook::new(Arc::clone(&table), resolver_for(&[("b", 2)]));

        hook.on_state_change(
            &NodeId::new("b"),
            Some(MemberState::Alive),
            MemberState::Left,
        );

        assert_eq!(leader_of(&table, 0), 0);
    }
}