use std::collections::HashMap;
use std::net::SocketAddr;
/// State of a cluster node, as stored in `NodeInfo::state`.
///
/// Every variant has a fixed numeric code (see `NodeState::as_u8` /
/// `NodeState::from_u8`). The codes do NOT follow declaration order:
/// `Learner` is declared second but is encoded as `4`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub enum NodeState {
/// Numeric code `0`. `is_voter()` and `receives_log()` are both `false`.
Joining,
/// Numeric code `4`. `receives_log()` is `true`, `is_voter()` is `false`.
Learner,
/// Numeric code `1`. The only state for which `is_voter()` is `true`;
/// `receives_log()` is also `true`.
Active,
/// Numeric code `2`. `is_voter()` and `receives_log()` are both `false`.
Draining,
/// Numeric code `3`. `is_voter()` and `receives_log()` are both `false`.
Decommissioned,
}
impl NodeState {
    /// Encodes this state as its fixed numeric code.
    ///
    /// The mapping is `Joining = 0`, `Active = 1`, `Draining = 2`,
    /// `Decommissioned = 3`, `Learner = 4`; it is the exact inverse of
    /// [`NodeState::from_u8`].
    pub fn as_u8(self) -> u8 {
        // Arms are listed in declaration order; the codes are not sequential
        // in that order, so do not replace this with an `as u8` cast.
        match self {
            Self::Joining => 0,
            Self::Learner => 4,
            Self::Active => 1,
            Self::Draining => 2,
            Self::Decommissioned => 3,
        }
    }

    /// Decodes a numeric code produced by [`NodeState::as_u8`].
    ///
    /// Returns `None` for any value outside `0..=4`.
    pub fn from_u8(v: u8) -> Option<Self> {
        let decoded = match v {
            0 => Self::Joining,
            1 => Self::Active,
            2 => Self::Draining,
            3 => Self::Decommissioned,
            4 => Self::Learner,
            _ => return None,
        };
        Some(decoded)
    }

    /// Returns `true` only for [`NodeState::Active`].
    pub fn is_voter(self) -> bool {
        self == Self::Active
    }

    /// Returns `true` for [`NodeState::Learner`] and [`NodeState::Active`],
    /// `false` for every other state.
    pub fn receives_log(self) -> bool {
        // Exhaustive on purpose: adding a variant forces a decision here.
        match self {
            Self::Learner | Self::Active => true,
            Self::Joining | Self::Draining | Self::Decommissioned => false,
        }
    }
}
/// Serializable record describing a single cluster node.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct NodeInfo {
/// Identifier of the node; `ClusterTopology` uses it as the map key.
pub node_id: u64,
/// Network address kept in string form. `NodeInfo::new` stores
/// `SocketAddr::to_string()` here and `NodeInfo::socket_addr` parses it back.
pub addr: String,
/// Current state of the node.
pub state: NodeState,
/// Presumably the ids of the Raft groups this node takes part in — TODO
/// confirm against callers. `NodeInfo::new` initialises it to an empty list.
pub raft_groups: Vec<u64>,
}
impl NodeInfo {
    /// Builds a record for `node_id` in the given `state`.
    ///
    /// The address is stored as its string rendering and `raft_groups`
    /// starts out empty.
    pub fn new(node_id: u64, addr: SocketAddr, state: NodeState) -> Self {
        let addr = addr.to_string();
        let raft_groups = Vec::new();
        Self {
            node_id,
            addr,
            state,
            raft_groups,
        }
    }

    /// Parses the stored address string back into a [`SocketAddr`].
    ///
    /// Returns `None` when `addr` is not a valid socket address (the field is
    /// public, so it may have been overwritten with arbitrary text).
    pub fn socket_addr(&self) -> Option<SocketAddr> {
        self.addr.parse::<SocketAddr>().ok()
    }
}
/// Versioned, serializable set of cluster nodes keyed by node id.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct ClusterTopology {
// Node records keyed by `NodeInfo::node_id` (see `add_node` / `join_as_learner`).
nodes: HashMap<u64, NodeInfo>,
// Starts at 0 and is incremented by one in `add_node`, `set_state`,
// `join_as_learner`, `promote_to_voter` and a successful `remove_node`.
// Changes made through `get_node_mut` do not touch it.
version: u64,
}
impl ClusterTopology {
pub fn new() -> Self {
Self {
nodes: HashMap::new(),
version: 0,
}
}
pub fn add_node(&mut self, info: NodeInfo) {
self.nodes.insert(info.node_id, info);
self.version += 1;
}
pub fn remove_node(&mut self, node_id: u64) -> Option<NodeInfo> {
let removed = self.nodes.remove(&node_id);
if removed.is_some() {
self.version += 1;
}
removed
}
pub fn get_node(&self, node_id: u64) -> Option<&NodeInfo> {
self.nodes.get(&node_id)
}
pub fn get_node_mut(&mut self, node_id: u64) -> Option<&mut NodeInfo> {
self.nodes.get_mut(&node_id)
}
pub fn set_state(&mut self, node_id: u64, state: NodeState) -> bool {
if let Some(info) = self.nodes.get_mut(&node_id) {
info.state = state;
self.version += 1;
true
} else {
false
}
}
pub fn active_nodes(&self) -> Vec<&NodeInfo> {
self.nodes
.values()
.filter(|n| n.state == NodeState::Active)
.collect()
}
pub fn all_nodes(&self) -> impl Iterator<Item = &NodeInfo> {
self.nodes.values()
}
pub fn node_count(&self) -> usize {
self.nodes.len()
}
pub fn version(&self) -> u64 {
self.version
}
pub fn contains(&self, node_id: u64) -> bool {
self.nodes.contains_key(&node_id)
}
pub fn join_as_learner(&mut self, info: NodeInfo) -> bool {
if self.nodes.contains_key(&info.node_id) {
return false; }
let mut learner = info;
learner.state = NodeState::Learner;
self.nodes.insert(learner.node_id, learner);
self.version += 1;
true
}
pub fn promote_to_voter(&mut self, node_id: u64) -> bool {
if let Some(info) = self.nodes.get_mut(&node_id)
&& info.state == NodeState::Learner
{
info.state = NodeState::Active;
self.version += 1;
return true;
}
false
}
pub fn learner_nodes(&self) -> Vec<&NodeInfo> {
self.nodes
.values()
.filter(|n| n.state == NodeState::Learner)
.collect()
}
}
impl Default for ClusterTopology {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Adding nodes stores them by id and bumps the version once per call.
    #[test]
    fn add_and_lookup() {
        let mut topo = ClusterTopology::new();
        topo.add_node(NodeInfo::new(
            1,
            "127.0.0.1:9400".parse().unwrap(),
            NodeState::Active,
        ));
        topo.add_node(NodeInfo::new(
            2,
            "127.0.0.1:9401".parse().unwrap(),
            NodeState::Joining,
        ));
        assert_eq!(topo.node_count(), 2);
        assert_eq!(topo.version(), 2);
        assert_eq!(topo.active_nodes().len(), 1);
        assert!(topo.contains(1));
        assert!(topo.contains(2));
    }

    /// Removing an existing node returns it and bumps the version.
    #[test]
    fn remove_node() {
        let mut topo = ClusterTopology::new();
        topo.add_node(NodeInfo::new(
            1,
            "127.0.0.1:9400".parse().unwrap(),
            NodeState::Active,
        ));
        let removed = topo.remove_node(1);
        assert!(removed.is_some());
        assert_eq!(topo.node_count(), 0);
        assert_eq!(topo.version(), 2);
    }

    /// Removing an unknown node is a no-op that leaves the version alone.
    #[test]
    fn remove_missing_node_keeps_version() {
        let mut topo = ClusterTopology::new();
        topo.add_node(NodeInfo::new(
            1,
            "127.0.0.1:9400".parse().unwrap(),
            NodeState::Active,
        ));
        assert!(topo.remove_node(42).is_none());
        assert_eq!(topo.node_count(), 1);
        assert_eq!(topo.version(), 1);
    }

    /// `set_state` updates a known node.
    #[test]
    fn set_state() {
        let mut topo = ClusterTopology::new();
        topo.add_node(NodeInfo::new(
            1,
            "127.0.0.1:9400".parse().unwrap(),
            NodeState::Joining,
        ));
        assert!(topo.set_state(1, NodeState::Active));
        assert_eq!(topo.get_node(1).unwrap().state, NodeState::Active);
    }

    /// `set_state` on an unknown node reports failure without a version bump.
    #[test]
    fn set_state_unknown_node() {
        let mut topo = ClusterTopology::new();
        assert!(!topo.set_state(1, NodeState::Active));
        assert_eq!(topo.version(), 0);
    }

    /// Every variant, including `Learner`, survives the `u8` round trip.
    #[test]
    fn node_state_roundtrip() {
        for state in [
            NodeState::Joining,
            NodeState::Learner,
            NodeState::Active,
            NodeState::Draining,
            NodeState::Decommissioned,
        ] {
            assert_eq!(NodeState::from_u8(state.as_u8()), Some(state));
        }
        assert_eq!(NodeState::Learner.as_u8(), 4);
        assert_eq!(NodeState::from_u8(5), None);
        assert_eq!(NodeState::from_u8(255), None);
    }

    /// Only `Active` votes; `Learner` and `Active` receive the log.
    #[test]
    fn node_state_predicates() {
        assert!(NodeState::Active.is_voter());
        assert!(!NodeState::Learner.is_voter());
        assert!(!NodeState::Joining.is_voter());
        assert!(!NodeState::Draining.is_voter());
        assert!(!NodeState::Decommissioned.is_voter());

        assert!(NodeState::Active.receives_log());
        assert!(NodeState::Learner.receives_log());
        assert!(!NodeState::Joining.receives_log());
        assert!(!NodeState::Draining.receives_log());
        assert!(!NodeState::Decommissioned.receives_log());
    }

    /// A learner joins once, can be promoted once, and each successful step
    /// bumps the version exactly once.
    #[test]
    fn learner_join_and_promotion() {
        let mut topo = ClusterTopology::new();

        assert!(topo.join_as_learner(NodeInfo::new(
            7,
            "127.0.0.1:9407".parse().unwrap(),
            NodeState::Joining,
        )));
        assert_eq!(topo.get_node(7).unwrap().state, NodeState::Learner);
        assert_eq!(topo.learner_nodes().len(), 1);
        assert!(topo.active_nodes().is_empty());
        assert_eq!(topo.version(), 1);

        // A second join with the same id is rejected and changes nothing.
        assert!(!topo.join_as_learner(NodeInfo::new(
            7,
            "127.0.0.1:9408".parse().unwrap(),
            NodeState::Active,
        )));
        assert_eq!(topo.get_node(7).unwrap().addr, "127.0.0.1:9407");
        assert_eq!(topo.version(), 1);

        assert!(topo.promote_to_voter(7));
        assert_eq!(topo.get_node(7).unwrap().state, NodeState::Active);
        assert!(topo.learner_nodes().is_empty());
        assert_eq!(topo.version(), 2);

        // Already active, or unknown: no promotion and no version bump.
        assert!(!topo.promote_to_voter(7));
        assert!(!topo.promote_to_voter(99));
        assert_eq!(topo.version(), 2);
    }

    /// The topology survives a MessagePack round trip.
    #[test]
    fn serde_roundtrip() {
        let mut topo = ClusterTopology::new();
        topo.add_node(NodeInfo::new(
            1,
            "127.0.0.1:9400".parse().unwrap(),
            NodeState::Active,
        ));
        topo.add_node(NodeInfo::new(
            2,
            "127.0.0.1:9401".parse().unwrap(),
            NodeState::Active,
        ));
        let bytes = rmp_serde::to_vec(&topo).unwrap();
        let decoded: ClusterTopology = rmp_serde::from_slice(&bytes).unwrap();
        assert_eq!(decoded.node_count(), 2);
        assert_eq!(decoded.version(), 2);
    }
}