1use rabia_core::NodeId;
11use serde::{Deserialize, Serialize};
12use std::collections::HashSet;
13use tracing::{debug, info};
14
/// Deterministically selects a cluster leader from a membership view.
///
/// The leader is always the smallest `NodeId` (by its `Ord` ordering) in the
/// current view. Any two selectors holding the same membership set therefore
/// choose the same leader, regardless of the order in which members were
/// supplied.
#[derive(Debug, Clone)]
pub struct LeaderSelector {
    /// Known cluster members, kept sorted in ascending order and free of duplicates.
    cluster_view: Vec<NodeId>,
    /// Cached leader derived from `cluster_view`; `None` when the view is empty.
    current_leader: Option<NodeId>,
}
23
/// Serializable snapshot of a [`LeaderSelector`]'s state.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct LeadershipInfo {
    /// Leader at the time of the snapshot, or `None` if the cluster view was empty.
    pub leader: Option<NodeId>,
    /// Cluster members; when produced by `LeaderSelector::get_leadership_info`
    /// they are in ascending order.
    pub cluster_nodes: Vec<NodeId>,
    /// Number of members; when produced by `LeaderSelector::get_leadership_info`
    /// this equals `cluster_nodes.len()`.
    pub cluster_size: usize,
}
34
35impl LeaderSelector {
36 pub fn new() -> Self {
38 Self {
39 cluster_view: Vec::new(),
40 current_leader: None,
41 }
42 }
43
44 pub fn with_cluster(nodes: HashSet<NodeId>) -> Self {
46 let mut selector = Self::new();
47 selector.update_cluster_view(nodes);
48 selector
49 }
50
51 pub fn determine_leader(&self) -> Option<NodeId> {
55 self.cluster_view.first().copied()
56 }
57
58 pub fn update_cluster_view(&mut self, nodes: HashSet<NodeId>) -> Option<NodeId> {
62 let mut new_cluster_view: Vec<NodeId> = nodes.into_iter().collect();
64 new_cluster_view.sort();
65
66 let old_leader = self.current_leader;
67
68 self.cluster_view = new_cluster_view;
70 self.current_leader = self.determine_leader();
71
72 debug!(
73 "Cluster view updated: {:?}, leader: {:?}",
74 self.cluster_view, self.current_leader
75 );
76
77 if self.current_leader != old_leader {
79 info!(
80 "Leader changed from {:?} to {:?}",
81 old_leader, self.current_leader
82 );
83 self.current_leader
84 } else {
85 None
86 }
87 }
88
89 pub fn add_node(&mut self, node_id: NodeId) -> Option<NodeId> {
93 let mut nodes: HashSet<NodeId> = self.cluster_view.iter().copied().collect();
94 nodes.insert(node_id);
95 self.update_cluster_view(nodes)
96 }
97
98 pub fn remove_node(&mut self, node_id: NodeId) -> Option<NodeId> {
102 let mut nodes: HashSet<NodeId> = self.cluster_view.iter().copied().collect();
103 nodes.remove(&node_id);
104 self.update_cluster_view(nodes)
105 }
106
107 pub fn get_leader(&self) -> Option<NodeId> {
109 self.current_leader
110 }
111
112 pub fn is_leader(&self, node_id: NodeId) -> bool {
114 self.current_leader == Some(node_id)
115 }
116
117 pub fn get_cluster_view(&self) -> &[NodeId] {
119 &self.cluster_view
120 }
121
122 pub fn get_leadership_info(&self) -> LeadershipInfo {
124 LeadershipInfo {
125 leader: self.current_leader,
126 cluster_nodes: self.cluster_view.clone(),
127 cluster_size: self.cluster_view.len(),
128 }
129 }
130
131 pub fn has_nodes(&self) -> bool {
133 !self.cluster_view.is_empty()
134 }
135
136 pub fn cluster_size(&self) -> usize {
138 self.cluster_view.len()
139 }
140}
141
142impl Default for LeaderSelector {
143 fn default() -> Self {
144 Self::new()
145 }
146}
147
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;

    #[test]
    fn test_empty_cluster_has_no_leader() {
        let selector = LeaderSelector::new();
        assert_eq!(selector.get_leader(), None);
        assert_eq!(selector.cluster_size(), 0);
        assert!(!selector.has_nodes());
    }

    #[test]
    fn test_single_node_becomes_leader() {
        let node1 = NodeId::from(1);

        let selector = LeaderSelector::with_cluster(HashSet::from([node1]));
        assert_eq!(selector.get_leader(), Some(node1));
        assert!(selector.is_leader(node1));
        assert_eq!(selector.cluster_size(), 1);
    }

    #[test]
    fn test_multiple_nodes_first_becomes_leader() {
        let node1 = NodeId::from(1);
        let node2 = NodeId::from(2);
        let node3 = NodeId::from(3);

        // Insertion order must not influence which node is chosen.
        let selector = LeaderSelector::with_cluster(HashSet::from([node3, node1, node2]));

        assert_eq!(selector.get_leader(), Some(node1));
        assert!(selector.is_leader(node1));
        assert!(!selector.is_leader(node2));
        assert!(!selector.is_leader(node3));
        assert_eq!(selector.cluster_size(), 3);
    }

    #[test]
    fn test_leader_changes_when_cluster_changes() {
        let node1 = NodeId::from(1);
        let node2 = NodeId::from(2);
        let node3 = NodeId::from(3);

        let mut selector = LeaderSelector::new();

        let new_leader = selector.update_cluster_view(HashSet::from([node2, node3]));
        assert_eq!(new_leader, Some(node2));
        assert_eq!(selector.get_leader(), Some(node2));

        // A smaller id joining takes over leadership.
        let new_leader = selector.add_node(node1);
        assert_eq!(new_leader, Some(node1));
        assert_eq!(selector.get_leader(), Some(node1));

        // When the leader leaves, the next smallest id takes over.
        let new_leader = selector.remove_node(node1);
        assert_eq!(new_leader, Some(node2));
        assert_eq!(selector.get_leader(), Some(node2));
    }

    #[test]
    fn test_no_leader_change_when_non_leader_leaves() {
        let node1 = NodeId::from(1);
        let node2 = NodeId::from(2);
        let node3 = NodeId::from(3);

        let mut selector = LeaderSelector::with_cluster(HashSet::from([node1, node2, node3]));
        assert_eq!(selector.get_leader(), Some(node1));

        let new_leader = selector.remove_node(node3);
        assert_eq!(new_leader, None);
        assert_eq!(selector.get_leader(), Some(node1));
    }

    #[test]
    fn test_removing_last_node_clears_leader() {
        let node1 = NodeId::from(1);
        let mut selector = LeaderSelector::with_cluster(HashSet::from([node1]));

        // Leadership is lost, but the return type cannot express a change to
        // `None`, so it reports `None` just like "no change".
        assert_eq!(selector.remove_node(node1), None);
        assert_eq!(selector.get_leader(), None);
        assert!(!selector.has_nodes());
    }

    #[test]
    fn test_adding_existing_node_is_noop() {
        let node1 = NodeId::from(1);
        let node2 = NodeId::from(2);
        let mut selector = LeaderSelector::with_cluster(HashSet::from([node1, node2]));

        assert_eq!(selector.add_node(node2), None);
        assert_eq!(selector.get_leader(), Some(node1));
        assert_eq!(selector.get_cluster_view(), [node1, node2].as_slice());
    }

    #[test]
    fn test_removing_unknown_node_is_noop() {
        let node1 = NodeId::from(1);
        let node2 = NodeId::from(2);
        let unknown = NodeId::from(9);
        let mut selector = LeaderSelector::with_cluster(HashSet::from([node1, node2]));

        assert_eq!(selector.remove_node(unknown), None);
        assert_eq!(selector.get_leader(), Some(node1));
        assert_eq!(selector.cluster_size(), 2);
    }

    #[test]
    fn test_incremental_adds_keep_view_sorted() {
        let node1 = NodeId::from(1);
        let node2 = NodeId::from(2);
        let node3 = NodeId::from(3);
        let mut selector = LeaderSelector::new();

        assert_eq!(selector.add_node(node3), Some(node3));
        assert_eq!(selector.add_node(node1), Some(node1));
        assert_eq!(selector.add_node(node2), None);
        assert_eq!(selector.get_cluster_view(), [node1, node2, node3].as_slice());
    }

    #[test]
    fn test_leadership_info() {
        let node1 = NodeId::from(1);
        let node2 = NodeId::from(2);

        let selector = LeaderSelector::with_cluster(HashSet::from([node2, node1]));
        let info = selector.get_leadership_info();

        assert_eq!(info.leader, Some(node1));
        assert_eq!(info.cluster_size, 2);
        // Members are reported in ascending order.
        assert_eq!(info.cluster_nodes, vec![node1, node2]);
    }

    #[test]
    fn test_deterministic_ordering() {
        let node5 = NodeId::from(5);
        let node1 = NodeId::from(1);
        let node3 = NodeId::from(3);

        let selector1 = LeaderSelector::with_cluster(HashSet::from([node5, node1, node3]));
        let selector2 = LeaderSelector::with_cluster(HashSet::from([node3, node5, node1]));

        assert_eq!(selector1.get_leader(), Some(node1));
        assert_eq!(selector1.get_leader(), selector2.get_leader());
        assert_eq!(selector1.get_cluster_view(), selector2.get_cluster_view());
    }
}