1use std::collections::HashMap;
6use std::net::SocketAddr;
7
8pub use nodedb_types::wire_version::WIRE_FORMAT_VERSION as CLUSTER_WIRE_FORMAT_VERSION;
15
/// Serde fallback for `NodeInfo::spiffe_id`.
///
/// Used when a deserialized record does not carry the field at all; such a
/// record ends up with no SPIFFE identity.
fn default_spiffe_id() -> Option<String> {
    Option::<String>::None
}
19
/// Serde fallback for `NodeInfo::spki_pin`.
///
/// Used when a deserialized record does not carry the field at all; such a
/// record ends up with no 32-byte pin.
fn default_spki_pin() -> Option<[u8; 32]> {
    Option::<[u8; 32]>::None
}
23
/// Serde fallback for `NodeInfo::wire_version`.
///
/// Returns the fixed value `1` when a deserialized record does not carry the
/// field. This is deliberately a literal and not
/// `CLUSTER_WIRE_FORMAT_VERSION`.
// NOTE(review): presumably records without the field were written before the
// field existed and therefore speak version 1 — confirm.
fn default_wire_version() -> u16 {
    const FALLBACK_WIRE_VERSION: u16 = 1;
    FALLBACK_WIRE_VERSION
}
31
/// State of a node within the cluster, stored as a single `u8` discriminant.
///
/// The discriminants are explicit and are mirrored by [`NodeState::as_u8`] and
/// [`NodeState::from_u8`]; all three must be updated together when a variant
/// is added. Per [`NodeState::is_voter`] only `Active` counts as a voter, and
/// per [`NodeState::receives_log`] only `Active` and `Learner` receive the log.
// NOTE(review): `#[msgpack(c_enum)]` presumably makes zerompk encode the
// variant as its integer discriminant — confirm against the zerompk docs.
#[derive(
    Debug,
    Clone,
    Copy,
    PartialEq,
    Eq,
    serde::Serialize,
    serde::Deserialize,
    zerompk::ToMessagePack,
    zerompk::FromMessagePack,
)]
#[repr(u8)]
#[msgpack(c_enum)]
pub enum NodeState {
    /// Discriminant `0`. Neither a voter nor a log recipient.
    Joining = 0,
    /// Discriminant `1`. The only state that is a voter; also receives the log.
    Active = 1,
    /// Discriminant `2`. Neither a voter nor a log recipient.
    Draining = 2,
    /// Discriminant `3`. Neither a voter nor a log recipient.
    Decommissioned = 3,
    /// Discriminant `4`. Receives the log but is not a voter; can be turned
    /// into `Active` via `ClusterTopology::promote_to_voter`.
    Learner = 4,
}
61
62impl NodeState {
63 pub fn as_u8(self) -> u8 {
64 match self {
65 Self::Joining => 0,
66 Self::Active => 1,
67 Self::Draining => 2,
68 Self::Decommissioned => 3,
69 Self::Learner => 4,
70 }
71 }
72
73 pub fn from_u8(v: u8) -> Option<Self> {
74 match v {
75 0 => Some(Self::Joining),
76 1 => Some(Self::Active),
77 2 => Some(Self::Draining),
78 3 => Some(Self::Decommissioned),
79 4 => Some(Self::Learner),
80 _ => None,
81 }
82 }
83
84 pub fn is_voter(self) -> bool {
86 matches!(self, Self::Active)
87 }
88
89 pub fn receives_log(self) -> bool {
91 matches!(self, Self::Learner | Self::Active)
92 }
93}
94
/// Description of a single cluster node as stored in `ClusterTopology`.
///
/// The last three fields carry `#[serde(default = ...)]`, so a serde input
/// that omits them still deserializes.
// NOTE(review): those attributes are serde attributes; whether the zerompk
// derive also tolerates missing fields is not visible here — confirm.
#[derive(
    Debug,
    Clone,
    serde::Serialize,
    serde::Deserialize,
    zerompk::ToMessagePack,
    zerompk::FromMessagePack,
)]
pub struct NodeInfo {
    /// Node identifier; `ClusterTopology` uses it as the key of its node map.
    pub node_id: u64,
    /// Textual address. `NodeInfo::new` fills it with `SocketAddr::to_string()`
    /// and `NodeInfo::socket_addr` parses it back.
    pub addr: String,
    /// Current state of the node.
    pub state: NodeState,
    /// Starts empty in `NodeInfo::new`.
    // NOTE(review): presumably the IDs of the Raft groups this node takes
    // part in — confirm with the code that populates it.
    pub raft_groups: Vec<u64>,
    /// Wire format version of the node. `NodeInfo::new` sets it to
    /// `CLUSTER_WIRE_FORMAT_VERSION`; when absent from serde input it falls
    /// back to `default_wire_version()`, i.e. `1`.
    #[serde(default = "default_wire_version")]
    pub wire_version: u16,
    /// Optional SPIFFE identity string; `None` when absent from serde input.
    #[serde(default = "default_spiffe_id")]
    pub spiffe_id: Option<String>,
    /// Optional 32-byte pin; `None` when absent from serde input.
    // NOTE(review): presumably a SHA-256 digest of the node certificate's
    // SubjectPublicKeyInfo — confirm.
    #[serde(default = "default_spki_pin")]
    pub spki_pin: Option<[u8; 32]>,
}
132
133impl NodeInfo {
134 pub fn new(node_id: u64, addr: SocketAddr, state: NodeState) -> Self {
138 Self {
139 node_id,
140 addr: addr.to_string(),
141 state,
142 raft_groups: Vec::new(),
143 wire_version: CLUSTER_WIRE_FORMAT_VERSION,
144 spiffe_id: None,
145 spki_pin: None,
146 }
147 }
148
149 pub fn with_wire_version(mut self, wire_version: u16) -> Self {
153 self.wire_version = wire_version;
154 self
155 }
156
157 pub fn with_spiffe_id(mut self, spiffe_id: Option<String>) -> Self {
159 self.spiffe_id = spiffe_id;
160 self
161 }
162
163 pub fn with_spki_pin(mut self, spki_pin: Option<[u8; 32]>) -> Self {
165 self.spki_pin = spki_pin;
166 self
167 }
168
169 pub fn socket_addr(&self) -> Option<SocketAddr> {
170 self.addr.parse().ok()
171 }
172}
173
/// The set of nodes known to the cluster plus a change counter.
#[derive(
    Debug,
    Clone,
    serde::Serialize,
    serde::Deserialize,
    zerompk::ToMessagePack,
    zerompk::FromMessagePack,
)]
pub struct ClusterTopology {
    /// Node records keyed by `NodeInfo::node_id`.
    nodes: HashMap<u64, NodeInfo>,
    /// Change counter. Starts at 0 and is incremented by one on every
    /// `add_node`, and on every `remove_node`, `set_state`, `join_as_learner`
    /// or `promote_to_voter` call that actually finds/changes a node. Edits
    /// made through `get_node_mut` do not touch it.
    version: u64,
}
191
192impl ClusterTopology {
193 pub fn new() -> Self {
194 Self {
195 nodes: HashMap::new(),
196 version: 0,
197 }
198 }
199
200 pub fn add_node(&mut self, info: NodeInfo) {
202 self.nodes.insert(info.node_id, info);
203 self.version += 1;
204 }
205
206 pub fn remove_node(&mut self, node_id: u64) -> Option<NodeInfo> {
208 let removed = self.nodes.remove(&node_id);
209 if removed.is_some() {
210 self.version += 1;
211 }
212 removed
213 }
214
215 pub fn get_node(&self, node_id: u64) -> Option<&NodeInfo> {
216 self.nodes.get(&node_id)
217 }
218
219 pub fn get_node_mut(&mut self, node_id: u64) -> Option<&mut NodeInfo> {
220 self.nodes.get_mut(&node_id)
221 }
222
223 pub fn set_state(&mut self, node_id: u64, state: NodeState) -> bool {
225 if let Some(info) = self.nodes.get_mut(&node_id) {
226 info.state = state;
227 self.version += 1;
228 true
229 } else {
230 false
231 }
232 }
233
234 pub fn active_nodes(&self) -> Vec<&NodeInfo> {
236 self.nodes
237 .values()
238 .filter(|n| n.state == NodeState::Active)
239 .collect()
240 }
241
242 pub fn all_nodes(&self) -> impl Iterator<Item = &NodeInfo> {
244 self.nodes.values()
245 }
246
247 pub fn node_count(&self) -> usize {
248 self.nodes.len()
249 }
250
251 pub fn version(&self) -> u64 {
252 self.version
253 }
254
255 pub fn contains(&self, node_id: u64) -> bool {
256 self.nodes.contains_key(&node_id)
257 }
258
259 pub fn join_as_learner(&mut self, info: NodeInfo) -> bool {
265 if self.nodes.contains_key(&info.node_id) {
266 return false; }
268 let mut learner = info;
269 learner.state = NodeState::Learner;
270 self.nodes.insert(learner.node_id, learner);
271 self.version += 1;
272 true
273 }
274
275 pub fn promote_to_voter(&mut self, node_id: u64) -> bool {
280 if let Some(info) = self.nodes.get_mut(&node_id)
281 && info.state == NodeState::Learner
282 {
283 info.state = NodeState::Active;
284 self.version += 1;
285 return true;
286 }
287 false
288 }
289
290 pub fn learner_nodes(&self) -> Vec<&NodeInfo> {
292 self.nodes
293 .values()
294 .filter(|n| n.state == NodeState::Learner)
295 .collect()
296 }
297}
298
299impl Default for ClusterTopology {
300 fn default() -> Self {
301 Self::new()
302 }
303}
304
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a node record from a literal address so each test can focus on
    /// topology behavior instead of construction boilerplate.
    fn node(id: u64, addr: &str, state: NodeState) -> NodeInfo {
        let addr = addr
            .parse()
            .expect("test address must be a valid socket address");
        NodeInfo::new(id, addr, state)
    }

    #[test]
    fn add_and_lookup() {
        let mut topo = ClusterTopology::new();
        topo.add_node(node(1, "127.0.0.1:9400", NodeState::Active));
        topo.add_node(node(2, "127.0.0.1:9401", NodeState::Joining));

        assert_eq!(topo.node_count(), 2);
        assert_eq!(topo.version(), 2);
        assert_eq!(topo.active_nodes().len(), 1);
        assert!(topo.contains(1));
        assert!(topo.contains(2));
        assert!(!topo.contains(3));
    }

    #[test]
    fn remove_node() {
        let mut topo = ClusterTopology::new();
        topo.add_node(node(1, "127.0.0.1:9400", NodeState::Active));

        let removed = topo.remove_node(1);
        assert!(removed.is_some());
        assert_eq!(topo.node_count(), 0);
        // One bump for the add, one for the removal.
        assert_eq!(topo.version(), 2);

        // Removing an unknown node is a no-op and must not bump the version.
        assert!(topo.remove_node(1).is_none());
        assert_eq!(topo.version(), 2);
    }

    #[test]
    fn set_state() {
        let mut topo = ClusterTopology::new();
        topo.add_node(node(1, "127.0.0.1:9400", NodeState::Joining));

        assert!(topo.set_state(1, NodeState::Active));
        assert_eq!(topo.get_node(1).unwrap().state, NodeState::Active);
        assert_eq!(topo.version(), 2);

        // Unknown node: reports failure and leaves the version alone.
        assert!(!topo.set_state(42, NodeState::Active));
        assert_eq!(topo.version(), 2);
    }

    #[test]
    fn node_state_roundtrip() {
        // Every variant, including `Learner`, must survive the u8 round trip.
        for state in [
            NodeState::Joining,
            NodeState::Active,
            NodeState::Draining,
            NodeState::Decommissioned,
            NodeState::Learner,
        ] {
            assert_eq!(NodeState::from_u8(state.as_u8()), Some(state));
        }
        // First value past the last discriminant, and the far end of the range.
        assert_eq!(NodeState::from_u8(5), None);
        assert_eq!(NodeState::from_u8(255), None);
    }

    #[test]
    fn voter_and_log_predicates() {
        assert!(NodeState::Active.is_voter());
        assert!(NodeState::Active.receives_log());

        assert!(!NodeState::Learner.is_voter());
        assert!(NodeState::Learner.receives_log());

        for state in [
            NodeState::Joining,
            NodeState::Draining,
            NodeState::Decommissioned,
        ] {
            assert!(!state.is_voter());
            assert!(!state.receives_log());
        }
    }

    #[test]
    fn learner_join_and_promote() {
        let mut topo = ClusterTopology::new();

        // The incoming state is overridden: a join always lands as a learner.
        assert!(topo.join_as_learner(node(3, "127.0.0.1:9402", NodeState::Active)));
        assert_eq!(topo.get_node(3).unwrap().state, NodeState::Learner);
        assert_eq!(topo.learner_nodes().len(), 1);
        assert!(topo.active_nodes().is_empty());
        assert_eq!(topo.version(), 1);

        // A second join with the same ID is rejected without side effects.
        assert!(!topo.join_as_learner(node(3, "127.0.0.1:9403", NodeState::Joining)));
        assert_eq!(topo.get_node(3).unwrap().addr, "127.0.0.1:9402");
        assert_eq!(topo.version(), 1);

        // Promotion works exactly once, and only for known learners.
        assert!(topo.promote_to_voter(3));
        assert_eq!(topo.get_node(3).unwrap().state, NodeState::Active);
        assert!(topo.learner_nodes().is_empty());
        assert_eq!(topo.version(), 2);

        assert!(!topo.promote_to_voter(3));
        assert!(!topo.promote_to_voter(99));
        assert_eq!(topo.version(), 2);
    }

    #[test]
    fn serde_roundtrip() {
        let mut topo = ClusterTopology::new();
        topo.add_node(node(1, "127.0.0.1:9400", NodeState::Active));
        topo.add_node(node(2, "127.0.0.1:9401", NodeState::Active));

        let bytes = zerompk::to_msgpack_vec(&topo).unwrap();
        let decoded: ClusterTopology = zerompk::from_msgpack(&bytes).unwrap();
        assert_eq!(decoded.node_count(), 2);
        assert_eq!(decoded.version(), 2);
        assert!(decoded.contains(1));
        assert!(decoded.contains(2));
    }
}