1use std::collections::HashMap;
4use std::net::SocketAddr;
5
/// Wire-format version stamped onto every [`NodeInfo`] built with [`NodeInfo::new`].
///
/// Records decoded through serde that lack a `wire_version` field do NOT get this
/// value; they fall back to `default_wire_version()` instead.
pub const CLUSTER_WIRE_FORMAT_VERSION: u16 = 4;
17
/// Serde fallback for `NodeInfo::wire_version` when the field is absent from the input.
///
/// Deliberately yields `1` rather than `CLUSTER_WIRE_FORMAT_VERSION`.
/// NOTE(review): presumably records without the field come from peers that predate it,
/// so the oldest version is assumed. Confirm against the protocol history.
fn default_wire_version() -> u16 {
    const ASSUMED_LEGACY_VERSION: u16 = 1;
    ASSUMED_LEGACY_VERSION
}
25
26#[derive(
28 Debug,
29 Clone,
30 Copy,
31 PartialEq,
32 Eq,
33 serde::Serialize,
34 serde::Deserialize,
35 zerompk::ToMessagePack,
36 zerompk::FromMessagePack,
37)]
38#[repr(u8)]
39#[msgpack(c_enum)]
40pub enum NodeState {
41 Joining = 0,
43 Active = 1,
45 Draining = 2,
47 Decommissioned = 3,
49 Learner = 4,
54}
55
56impl NodeState {
57 pub fn as_u8(self) -> u8 {
58 match self {
59 Self::Joining => 0,
60 Self::Active => 1,
61 Self::Draining => 2,
62 Self::Decommissioned => 3,
63 Self::Learner => 4,
64 }
65 }
66
67 pub fn from_u8(v: u8) -> Option<Self> {
68 match v {
69 0 => Some(Self::Joining),
70 1 => Some(Self::Active),
71 2 => Some(Self::Draining),
72 3 => Some(Self::Decommissioned),
73 4 => Some(Self::Learner),
74 _ => None,
75 }
76 }
77
78 pub fn is_voter(self) -> bool {
80 matches!(self, Self::Active)
81 }
82
83 pub fn receives_log(self) -> bool {
85 matches!(self, Self::Learner | Self::Active)
86 }
87}
88
89#[derive(
91 Debug,
92 Clone,
93 serde::Serialize,
94 serde::Deserialize,
95 zerompk::ToMessagePack,
96 zerompk::FromMessagePack,
97)]
98pub struct NodeInfo {
99 pub node_id: u64,
100 pub addr: String,
102 pub state: NodeState,
103 pub raft_groups: Vec<u64>,
105 #[serde(default = "default_wire_version")]
112 pub wire_version: u16,
113}
114
115impl NodeInfo {
116 pub fn new(node_id: u64, addr: SocketAddr, state: NodeState) -> Self {
120 Self {
121 node_id,
122 addr: addr.to_string(),
123 state,
124 raft_groups: Vec::new(),
125 wire_version: CLUSTER_WIRE_FORMAT_VERSION,
126 }
127 }
128
129 pub fn with_wire_version(mut self, wire_version: u16) -> Self {
133 self.wire_version = wire_version;
134 self
135 }
136
137 pub fn socket_addr(&self) -> Option<SocketAddr> {
138 self.addr.parse().ok()
139 }
140}
141
142#[derive(
147 Debug,
148 Clone,
149 serde::Serialize,
150 serde::Deserialize,
151 zerompk::ToMessagePack,
152 zerompk::FromMessagePack,
153)]
154pub struct ClusterTopology {
155 nodes: HashMap<u64, NodeInfo>,
156 version: u64,
158}
159
160impl ClusterTopology {
161 pub fn new() -> Self {
162 Self {
163 nodes: HashMap::new(),
164 version: 0,
165 }
166 }
167
168 pub fn add_node(&mut self, info: NodeInfo) {
170 self.nodes.insert(info.node_id, info);
171 self.version += 1;
172 }
173
174 pub fn remove_node(&mut self, node_id: u64) -> Option<NodeInfo> {
176 let removed = self.nodes.remove(&node_id);
177 if removed.is_some() {
178 self.version += 1;
179 }
180 removed
181 }
182
183 pub fn get_node(&self, node_id: u64) -> Option<&NodeInfo> {
184 self.nodes.get(&node_id)
185 }
186
187 pub fn get_node_mut(&mut self, node_id: u64) -> Option<&mut NodeInfo> {
188 self.nodes.get_mut(&node_id)
189 }
190
191 pub fn set_state(&mut self, node_id: u64, state: NodeState) -> bool {
193 if let Some(info) = self.nodes.get_mut(&node_id) {
194 info.state = state;
195 self.version += 1;
196 true
197 } else {
198 false
199 }
200 }
201
202 pub fn active_nodes(&self) -> Vec<&NodeInfo> {
204 self.nodes
205 .values()
206 .filter(|n| n.state == NodeState::Active)
207 .collect()
208 }
209
210 pub fn all_nodes(&self) -> impl Iterator<Item = &NodeInfo> {
212 self.nodes.values()
213 }
214
215 pub fn node_count(&self) -> usize {
216 self.nodes.len()
217 }
218
219 pub fn version(&self) -> u64 {
220 self.version
221 }
222
223 pub fn contains(&self, node_id: u64) -> bool {
224 self.nodes.contains_key(&node_id)
225 }
226
227 pub fn join_as_learner(&mut self, info: NodeInfo) -> bool {
233 if self.nodes.contains_key(&info.node_id) {
234 return false; }
236 let mut learner = info;
237 learner.state = NodeState::Learner;
238 self.nodes.insert(learner.node_id, learner);
239 self.version += 1;
240 true
241 }
242
243 pub fn promote_to_voter(&mut self, node_id: u64) -> bool {
248 if let Some(info) = self.nodes.get_mut(&node_id)
249 && info.state == NodeState::Learner
250 {
251 info.state = NodeState::Active;
252 self.version += 1;
253 return true;
254 }
255 false
256 }
257
258 pub fn learner_nodes(&self) -> Vec<&NodeInfo> {
260 self.nodes
261 .values()
262 .filter(|n| n.state == NodeState::Learner)
263 .collect()
264 }
265}
266
267impl Default for ClusterTopology {
268 fn default() -> Self {
269 Self::new()
270 }
271}
272
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `NodeInfo` for `id` at `addr` (must be a valid socket address).
    fn node(id: u64, addr: &str, state: NodeState) -> NodeInfo {
        NodeInfo::new(id, addr.parse().expect("valid test socket address"), state)
    }

    #[test]
    fn add_and_lookup() {
        let mut topo = ClusterTopology::new();
        topo.add_node(node(1, "127.0.0.1:9400", NodeState::Active));
        topo.add_node(node(2, "127.0.0.1:9401", NodeState::Joining));

        assert_eq!(topo.node_count(), 2);
        assert_eq!(topo.version(), 2);
        assert_eq!(topo.active_nodes().len(), 1);
        assert!(topo.contains(1));
        assert!(topo.contains(2));
        assert!(!topo.contains(3));
    }

    #[test]
    fn remove_node() {
        let mut topo = ClusterTopology::new();
        topo.add_node(node(1, "127.0.0.1:9400", NodeState::Active));
        let removed = topo.remove_node(1);
        assert!(removed.is_some());
        assert_eq!(topo.node_count(), 0);
        assert_eq!(topo.version(), 2);

        // Removing an unknown node is a no-op and must not bump the version.
        assert!(topo.remove_node(1).is_none());
        assert_eq!(topo.version(), 2);
    }

    #[test]
    fn set_state() {
        let mut topo = ClusterTopology::new();
        topo.add_node(node(1, "127.0.0.1:9400", NodeState::Joining));
        assert!(topo.set_state(1, NodeState::Active));
        assert_eq!(topo.get_node(1).unwrap().state, NodeState::Active);
        assert_eq!(topo.version(), 2);

        // Unknown node: rejected, version untouched.
        assert!(!topo.set_state(42, NodeState::Active));
        assert_eq!(topo.version(), 2);
    }

    #[test]
    fn node_state_roundtrip() {
        // Every variant, including `Learner` (code 4), must survive the u8 mapping.
        for state in [
            NodeState::Joining,
            NodeState::Active,
            NodeState::Draining,
            NodeState::Decommissioned,
            NodeState::Learner,
        ] {
            assert_eq!(NodeState::from_u8(state.as_u8()), Some(state));
        }
        assert_eq!(NodeState::from_u8(5), None);
        assert_eq!(NodeState::from_u8(255), None);
    }

    #[test]
    fn voter_and_log_predicates() {
        assert!(NodeState::Active.is_voter());
        assert!(NodeState::Active.receives_log());
        assert!(!NodeState::Learner.is_voter());
        assert!(NodeState::Learner.receives_log());
        for state in [
            NodeState::Joining,
            NodeState::Draining,
            NodeState::Decommissioned,
        ] {
            assert!(!state.is_voter());
            assert!(!state.receives_log());
        }
    }

    #[test]
    fn learner_join_and_promotion() {
        let mut topo = ClusterTopology::new();

        // The incoming state is overridden to `Learner`.
        assert!(topo.join_as_learner(node(1, "127.0.0.1:9400", NodeState::Active)));
        assert_eq!(topo.get_node(1).unwrap().state, NodeState::Learner);
        assert_eq!(topo.learner_nodes().len(), 1);
        assert!(topo.active_nodes().is_empty());
        assert_eq!(topo.version(), 1);

        // A duplicate join is rejected without touching the record or the version.
        assert!(!topo.join_as_learner(node(1, "127.0.0.1:9400", NodeState::Joining)));
        assert_eq!(topo.get_node(1).unwrap().state, NodeState::Learner);
        assert_eq!(topo.version(), 1);

        assert!(topo.promote_to_voter(1));
        assert_eq!(topo.get_node(1).unwrap().state, NodeState::Active);
        assert_eq!(topo.version(), 2);

        // Only learners can be promoted; unknown ids fail too.
        assert!(!topo.promote_to_voter(1));
        assert!(!topo.promote_to_voter(99));
        assert_eq!(topo.version(), 2);
    }

    #[test]
    fn node_info_defaults_and_address() {
        let info = node(1, "127.0.0.1:9400", NodeState::Joining);
        assert_eq!(info.wire_version, CLUSTER_WIRE_FORMAT_VERSION);
        assert!(info.raft_groups.is_empty());
        assert_eq!(info.socket_addr().map(|a| a.port()), Some(9400));

        let info = info.with_wire_version(1);
        assert_eq!(info.wire_version, 1);

        let garbled = NodeInfo {
            addr: "not-an-address".to_string(),
            ..info
        };
        assert!(garbled.socket_addr().is_none());
    }

    #[test]
    fn serde_roundtrip() {
        let mut topo = ClusterTopology::new();
        topo.add_node(node(1, "127.0.0.1:9400", NodeState::Active));
        assert!(topo.join_as_learner(node(2, "127.0.0.1:9401", NodeState::Active)));
        topo.get_node_mut(1).unwrap().raft_groups = vec![7, 9];
        // Direct edits through `get_node_mut` do not bump the version.
        assert_eq!(topo.version(), 2);

        let bytes = zerompk::to_msgpack_vec(&topo).unwrap();
        let decoded: ClusterTopology = zerompk::from_msgpack(&bytes).unwrap();
        assert_eq!(decoded.node_count(), 2);
        assert_eq!(decoded.version(), 2);

        let first = decoded.get_node(1).expect("node 1 survives the roundtrip");
        assert_eq!(first.addr, "127.0.0.1:9400");
        assert_eq!(first.state, NodeState::Active);
        assert_eq!(first.raft_groups, vec![7, 9]);
        assert_eq!(first.wire_version, CLUSTER_WIRE_FORMAT_VERSION);

        let second = decoded.get_node(2).expect("node 2 survives the roundtrip");
        assert_eq!(second.addr, "127.0.0.1:9401");
        assert_eq!(second.state, NodeState::Learner);
    }
}