1use std::collections::HashMap;
4use std::net::SocketAddr;
5
6#[derive(
8 Debug,
9 Clone,
10 Copy,
11 PartialEq,
12 Eq,
13 serde::Serialize,
14 serde::Deserialize,
15 zerompk::ToMessagePack,
16 zerompk::FromMessagePack,
17)]
18#[repr(u8)]
19#[msgpack(c_enum)]
20pub enum NodeState {
21 Joining = 0,
23 Active = 1,
25 Draining = 2,
27 Decommissioned = 3,
29 Learner = 4,
34}
35
36impl NodeState {
37 pub fn as_u8(self) -> u8 {
38 match self {
39 Self::Joining => 0,
40 Self::Active => 1,
41 Self::Draining => 2,
42 Self::Decommissioned => 3,
43 Self::Learner => 4,
44 }
45 }
46
47 pub fn from_u8(v: u8) -> Option<Self> {
48 match v {
49 0 => Some(Self::Joining),
50 1 => Some(Self::Active),
51 2 => Some(Self::Draining),
52 3 => Some(Self::Decommissioned),
53 4 => Some(Self::Learner),
54 _ => None,
55 }
56 }
57
58 pub fn is_voter(self) -> bool {
60 matches!(self, Self::Active)
61 }
62
63 pub fn receives_log(self) -> bool {
65 matches!(self, Self::Learner | Self::Active)
66 }
67}
68
69#[derive(
71 Debug,
72 Clone,
73 serde::Serialize,
74 serde::Deserialize,
75 zerompk::ToMessagePack,
76 zerompk::FromMessagePack,
77)]
78pub struct NodeInfo {
79 pub node_id: u64,
80 pub addr: String,
82 pub state: NodeState,
83 pub raft_groups: Vec<u64>,
85}
86
87impl NodeInfo {
88 pub fn new(node_id: u64, addr: SocketAddr, state: NodeState) -> Self {
89 Self {
90 node_id,
91 addr: addr.to_string(),
92 state,
93 raft_groups: Vec::new(),
94 }
95 }
96
97 pub fn socket_addr(&self) -> Option<SocketAddr> {
98 self.addr.parse().ok()
99 }
100}
101
102#[derive(
107 Debug,
108 Clone,
109 serde::Serialize,
110 serde::Deserialize,
111 zerompk::ToMessagePack,
112 zerompk::FromMessagePack,
113)]
114pub struct ClusterTopology {
115 nodes: HashMap<u64, NodeInfo>,
116 version: u64,
118}
119
120impl ClusterTopology {
121 pub fn new() -> Self {
122 Self {
123 nodes: HashMap::new(),
124 version: 0,
125 }
126 }
127
128 pub fn add_node(&mut self, info: NodeInfo) {
130 self.nodes.insert(info.node_id, info);
131 self.version += 1;
132 }
133
134 pub fn remove_node(&mut self, node_id: u64) -> Option<NodeInfo> {
136 let removed = self.nodes.remove(&node_id);
137 if removed.is_some() {
138 self.version += 1;
139 }
140 removed
141 }
142
143 pub fn get_node(&self, node_id: u64) -> Option<&NodeInfo> {
144 self.nodes.get(&node_id)
145 }
146
147 pub fn get_node_mut(&mut self, node_id: u64) -> Option<&mut NodeInfo> {
148 self.nodes.get_mut(&node_id)
149 }
150
151 pub fn set_state(&mut self, node_id: u64, state: NodeState) -> bool {
153 if let Some(info) = self.nodes.get_mut(&node_id) {
154 info.state = state;
155 self.version += 1;
156 true
157 } else {
158 false
159 }
160 }
161
162 pub fn active_nodes(&self) -> Vec<&NodeInfo> {
164 self.nodes
165 .values()
166 .filter(|n| n.state == NodeState::Active)
167 .collect()
168 }
169
170 pub fn all_nodes(&self) -> impl Iterator<Item = &NodeInfo> {
172 self.nodes.values()
173 }
174
175 pub fn node_count(&self) -> usize {
176 self.nodes.len()
177 }
178
179 pub fn version(&self) -> u64 {
180 self.version
181 }
182
183 pub fn contains(&self, node_id: u64) -> bool {
184 self.nodes.contains_key(&node_id)
185 }
186
187 pub fn join_as_learner(&mut self, info: NodeInfo) -> bool {
193 if self.nodes.contains_key(&info.node_id) {
194 return false; }
196 let mut learner = info;
197 learner.state = NodeState::Learner;
198 self.nodes.insert(learner.node_id, learner);
199 self.version += 1;
200 true
201 }
202
203 pub fn promote_to_voter(&mut self, node_id: u64) -> bool {
208 if let Some(info) = self.nodes.get_mut(&node_id)
209 && info.state == NodeState::Learner
210 {
211 info.state = NodeState::Active;
212 self.version += 1;
213 return true;
214 }
215 false
216 }
217
218 pub fn learner_nodes(&self) -> Vec<&NodeInfo> {
220 self.nodes
221 .values()
222 .filter(|n| n.state == NodeState::Learner)
223 .collect()
224 }
225}
226
227impl Default for ClusterTopology {
228 fn default() -> Self {
229 Self::new()
230 }
231}
232
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn add_and_lookup() {
        let mut topo = ClusterTopology::new();
        topo.add_node(NodeInfo::new(
            1,
            "127.0.0.1:9400".parse().unwrap(),
            NodeState::Active,
        ));
        topo.add_node(NodeInfo::new(
            2,
            "127.0.0.1:9401".parse().unwrap(),
            NodeState::Joining,
        ));

        assert_eq!(topo.node_count(), 2);
        assert_eq!(topo.version(), 2);
        assert_eq!(topo.active_nodes().len(), 1);
        assert!(topo.contains(1));
        assert!(topo.contains(2));
    }

    #[test]
    fn remove_node() {
        let mut topo = ClusterTopology::new();
        topo.add_node(NodeInfo::new(
            1,
            "127.0.0.1:9400".parse().unwrap(),
            NodeState::Active,
        ));
        let removed = topo.remove_node(1);
        assert!(removed.is_some());
        assert_eq!(topo.node_count(), 0);
        assert_eq!(topo.version(), 2);

        // Removing an unknown id is a no-op and must not bump the version.
        assert!(topo.remove_node(1).is_none());
        assert_eq!(topo.version(), 2);
    }

    #[test]
    fn set_state() {
        let mut topo = ClusterTopology::new();
        topo.add_node(NodeInfo::new(
            1,
            "127.0.0.1:9400".parse().unwrap(),
            NodeState::Joining,
        ));
        assert!(topo.set_state(1, NodeState::Active));
        assert_eq!(topo.get_node(1).unwrap().state, NodeState::Active);

        // Unknown ids are rejected without touching the version.
        let before = topo.version();
        assert!(!topo.set_state(99, NodeState::Active));
        assert_eq!(topo.version(), before);
    }

    #[test]
    fn node_state_roundtrip() {
        // Every variant, including `Learner`, must survive the u8 round trip.
        for state in [
            NodeState::Joining,
            NodeState::Active,
            NodeState::Draining,
            NodeState::Decommissioned,
            NodeState::Learner,
        ] {
            assert_eq!(NodeState::from_u8(state.as_u8()), Some(state));
        }
        assert_eq!(NodeState::from_u8(5), None);
        assert_eq!(NodeState::from_u8(255), None);
    }

    #[test]
    fn voter_and_log_membership() {
        assert!(NodeState::Active.is_voter());
        for state in [
            NodeState::Joining,
            NodeState::Draining,
            NodeState::Decommissioned,
            NodeState::Learner,
        ] {
            assert!(!state.is_voter(), "{state:?} must not vote");
        }

        assert!(NodeState::Active.receives_log());
        assert!(NodeState::Learner.receives_log());
        for state in [
            NodeState::Joining,
            NodeState::Draining,
            NodeState::Decommissioned,
        ] {
            assert!(!state.receives_log(), "{state:?} must not receive the log");
        }
    }

    #[test]
    fn learner_join_and_promotion() {
        let mut topo = ClusterTopology::new();

        // The requested state is overridden: joiners always start as learners.
        assert!(topo.join_as_learner(NodeInfo::new(
            1,
            "127.0.0.1:9400".parse().unwrap(),
            NodeState::Active,
        )));
        assert_eq!(topo.get_node(1).unwrap().state, NodeState::Learner);
        assert_eq!(topo.learner_nodes().len(), 1);
        assert!(topo.active_nodes().is_empty());
        assert_eq!(topo.version(), 1);

        // Re-joining an existing id is rejected and leaves the record intact.
        assert!(!topo.join_as_learner(NodeInfo::new(
            1,
            "127.0.0.1:9401".parse().unwrap(),
            NodeState::Learner,
        )));
        assert_eq!(topo.version(), 1);
        assert_eq!(topo.get_node(1).unwrap().addr, "127.0.0.1:9400");

        assert!(topo.promote_to_voter(1));
        assert_eq!(topo.get_node(1).unwrap().state, NodeState::Active);
        assert_eq!(topo.version(), 2);

        // Only learners can be promoted; unknown ids are rejected.
        assert!(!topo.promote_to_voter(1));
        assert!(!topo.promote_to_voter(99));
        assert_eq!(topo.version(), 2);
    }

    #[test]
    fn serde_roundtrip() {
        let mut topo = ClusterTopology::new();
        topo.add_node(NodeInfo::new(
            1,
            "127.0.0.1:9400".parse().unwrap(),
            NodeState::Active,
        ));
        topo.add_node(NodeInfo::new(
            2,
            "127.0.0.1:9401".parse().unwrap(),
            NodeState::Active,
        ));
        // Include a learner so the highest discriminant is also encoded.
        assert!(topo.join_as_learner(NodeInfo::new(
            3,
            "127.0.0.1:9402".parse().unwrap(),
            NodeState::Joining,
        )));

        let bytes = zerompk::to_msgpack_vec(&topo).unwrap();
        let decoded: ClusterTopology = zerompk::from_msgpack(&bytes).unwrap();
        assert_eq!(decoded.node_count(), 3);
        assert_eq!(decoded.version(), 3);
        assert_eq!(decoded.get_node(1).map(|n| n.state), Some(NodeState::Active));
        assert_eq!(decoded.get_node(3).map(|n| n.state), Some(NodeState::Learner));
    }
}