nodedb_cluster/distributed_graph/
wcc.rs1use std::collections::HashMap;
9
10use serde::{Deserialize, Serialize};
11
12#[derive(
14 Debug, Clone, Serialize, Deserialize, zerompk::ToMessagePack, zerompk::FromMessagePack,
15)]
16pub struct ComponentMergeRequest {
17 pub round: u32,
18 pub source_shard: u16,
19 pub merges: Vec<(String, String)>,
21}
22
23#[derive(
25 Debug, Clone, Serialize, Deserialize, zerompk::ToMessagePack, zerompk::FromMessagePack,
26)]
27pub struct WccRoundAck {
28 pub shard_id: u16,
29 pub round: u32,
30 pub labels_changed: usize,
31 pub merges_sent: usize,
32}
33
/// Per-shard state for a label-exchange weakly-connected-components pass.
///
/// Local connectivity is collapsed once, at construction, with a union-find
/// over local vertex ids. Afterwards every local component carries a string
/// label which is lowered whenever a lexicographically smaller label arrives
/// from another shard via [`ShardWccState::apply_merges`].
///
/// NOTE(review): labels only travel along the recorded `boundary_edges`
/// (from the shard that recorded the edge to the edge's target shard).
/// Presumably every cross-shard edge is recorded on both shards; otherwise
/// the smaller label may never reach the other side — confirm with the caller.
#[derive(Debug)]
pub struct ShardWccState {
    /// Number of local vertices; local ids are `0..vertex_count`.
    pub vertex_count: usize,
    /// Union-find parent pointers, indexed by local vertex id.
    parent: Vec<usize>,
    /// Union-by-rank ranks, indexed by local vertex id.
    rank: Vec<u8>,
    /// Current component label of every local vertex, indexed by local id.
    pub global_labels: Vec<String>,
    /// Identifier of this shard; used as the prefix of the initial labels.
    pub shard_id: u16,
    /// For each local vertex that has cross-shard edges: the
    /// `(remote vertex name, target shard)` pairs returned by `ghost_edges`.
    pub boundary_edges: HashMap<u32, Vec<(String, u16)>>,
    /// Name of every local vertex, indexed by local id.
    node_names: Vec<String>,
}

impl ShardWccState {
    /// Builds the shard state and computes the local components.
    ///
    /// * `vertex_count` – number of local vertices (ids `0..vertex_count`).
    /// * `shard_id` – this shard's id.
    /// * `node_names` – name of each local vertex, indexed by local id.
    /// * `local_edges` – returns the local neighbour ids of a local vertex.
    ///   Ids outside `0..vertex_count` are ignored instead of causing an
    ///   out-of-bounds panic.
    /// * `ghost_edges` – returns `(remote vertex name, target shard)` pairs
    ///   for the cross-shard edges of a local vertex.
    ///
    /// Every vertex starts with the label `"{shard_id}:{name of its root}"`.
    ///
    /// # Panics
    ///
    /// Panics if `node_names` has no entry for the index of a component root
    /// (i.e. it is shorter than the vertex range requires).
    pub fn init(
        vertex_count: usize,
        shard_id: u16,
        node_names: Vec<String>,
        local_edges: &dyn Fn(u32) -> Vec<u32>,
        ghost_edges: &dyn Fn(u32) -> Vec<(String, u16)>,
    ) -> Self {
        let mut state = Self {
            vertex_count,
            parent: (0..vertex_count).collect(),
            rank: vec![0; vertex_count],
            global_labels: Vec::new(),
            shard_id,
            boundary_edges: HashMap::new(),
            node_names,
        };

        // Pass 1: merge everything that is connected inside the shard.
        // The callbacks take `u32` ids, hence the narrowing cast.
        for u in 0..vertex_count {
            for v in local_edges(u as u32) {
                let v = v as usize;
                // A neighbour id outside the local range cannot be indexed
                // into the union-find; skip it rather than panic.
                if v < vertex_count {
                    state.union(u, v);
                }
            }
        }

        // Pass 2: remember which vertices have edges leaving the shard.
        for u in 0..vertex_count {
            let ghosts = ghost_edges(u as u32);
            if !ghosts.is_empty() {
                state.boundary_edges.insert(u as u32, ghosts);
            }
        }

        // Pass 3: every vertex gets the label derived from its root's name.
        let mut labels = Vec::with_capacity(vertex_count);
        for i in 0..vertex_count {
            let root = state.find(i);
            labels.push(format!("{}:{}", shard_id, state.node_names[root]));
        }
        state.global_labels = labels;

        state
    }

    /// Returns the root of `x`, shortening the path with path halving.
    fn find(&mut self, mut x: usize) -> usize {
        while self.parent[x] != x {
            self.parent[x] = self.parent[self.parent[x]];
            x = self.parent[x];
        }
        x
    }

    /// Merges the sets containing `a` and `b` using union by rank.
    fn union(&mut self, a: usize, b: usize) {
        let ra = self.find(a);
        let rb = self.find(b);
        if ra == rb {
            return;
        }
        match self.rank[ra].cmp(&self.rank[rb]) {
            std::cmp::Ordering::Less => self.parent[ra] = rb,
            std::cmp::Ordering::Greater => self.parent[rb] = ra,
            std::cmp::Ordering::Equal => {
                self.parent[rb] = ra;
                self.rank[ra] += 1;
            }
        }
    }

    /// Produces the outbound merge batches for one round.
    ///
    /// For every recorded boundary edge, the pair
    /// `(remote vertex name, current label of the local component)` is queued
    /// for the edge's target shard. Returns the batches keyed by target shard
    /// together with the total number of pairs queued. The order of pairs
    /// inside a batch follows `HashMap` iteration and is not stable.
    pub fn round(&self) -> (HashMap<u16, Vec<(String, String)>>, usize) {
        let mut outbound: HashMap<u16, Vec<(String, String)>> = HashMap::new();
        let mut sent = 0;

        for (&local_id, ghost_list) in &self.boundary_edges {
            let root = find_static(&self.parent, local_id as usize);
            let label = &self.global_labels[root];
            for (remote_name, target_shard) in ghost_list {
                outbound
                    .entry(*target_shard)
                    .or_default()
                    .push((remote_name.clone(), label.clone()));
                sent += 1;
            }
        }

        (outbound, sent)
    }

    /// Applies `(local vertex name, remote label)` pairs received from other
    /// shards.
    ///
    /// A component adopts the remote label when it is lexicographically
    /// smaller than the component's current label. Pairs whose name is not a
    /// local vertex are skipped. Returns how many pairs lowered a component
    /// label; `0` means the shard's labels did not change.
    pub fn apply_merges(&mut self, merges: &[(String, String)]) -> usize {
        if merges.is_empty() {
            // Nothing to look up, so skip building the name index.
            return 0;
        }

        // Only names that map to a real union-find slot are indexed, so a
        // name list longer than the vertex range cannot lead to an
        // out-of-bounds lookup driven by remote input.
        let name_to_id: HashMap<&str, usize> = self
            .node_names
            .iter()
            .take(self.parent.len())
            .enumerate()
            .map(|(i, n)| (n.as_str(), i))
            .collect();

        let mut changed = 0;
        for (vertex_name, remote_label) in merges {
            let Some(&local_id) = name_to_id.get(vertex_name.as_str()) else {
                continue;
            };

            let root = find_static(&self.parent, local_id);
            if *remote_label < self.global_labels[root] {
                self.global_labels[root] = remote_label.clone();
                changed += 1;
            }
        }

        // Non-root labels only need refreshing when some root actually moved.
        if changed > 0 {
            for i in 0..self.vertex_count {
                let root = find_static(&self.parent, i);
                if i != root {
                    self.global_labels[i] = self.global_labels[root].clone();
                }
            }
        }

        changed
    }

    /// Returns `(vertex name, component label)` for every local vertex, in
    /// local-id order.
    pub fn component_labels(&self) -> Vec<(String, String)> {
        (0..self.vertex_count)
            .map(|i| {
                let root = find_static(&self.parent, i);
                (self.node_names[i].clone(), self.global_labels[root].clone())
            })
            .collect()
    }
}

/// Read-only root lookup: follows parent pointers without compressing paths,
/// so it can be used through a shared reference.
fn find_static(parent: &[usize], mut x: usize) -> usize {
    while parent[x] != x {
        x = parent[x];
    }
    x
}
191
#[cfg(test)]
mod tests {
    use super::*;

    /// Turns string literals into the owned name list `init` expects.
    fn names(raw: &[&str]) -> Vec<String> {
        raw.iter().map(|s| (*s).to_string()).collect()
    }

    /// Local adjacency with no edges at all.
    fn no_edges(_: u32) -> Vec<u32> {
        Vec::new()
    }

    /// Ghost adjacency with no cross-shard edges.
    fn no_ghosts(_: u32) -> Vec<(String, u16)> {
        Vec::new()
    }

    /// Local adjacency containing the single edge 0 -> 1.
    fn zero_to_one(vertex: u32) -> Vec<u32> {
        if vertex == 0 {
            vec![1]
        } else {
            Vec::new()
        }
    }

    /// Ghost adjacency where only `owner` has one edge to `remote` on `shard`.
    fn ghost_from(
        owner: u32,
        remote: &'static str,
        shard: u16,
    ) -> impl Fn(u32) -> Vec<(String, u16)> {
        move |vertex| {
            if vertex == owner {
                vec![(remote.to_string(), shard)]
            } else {
                Vec::new()
            }
        }
    }

    /// Runs one exchange between shard 0 and shard 1 and returns the total
    /// number of label changes reported by both.
    fn exchange_round(shard0: &mut ShardWccState, shard1: &mut ShardWccState) -> usize {
        let (from0, _) = shard0.round();
        let (from1, _) = shard1.round();
        let changed0 = from1.get(&0).map_or(0, |batch| shard0.apply_merges(batch));
        let changed1 = from0.get(&1).map_or(0, |batch| shard1.apply_merges(batch));
        changed0 + changed1
    }

    #[test]
    fn component_merge_request_serde() {
        let original = ComponentMergeRequest {
            round: 2,
            source_shard: 1,
            merges: vec![("alice".into(), "0:root_a".into())],
        };
        let encoded = zerompk::to_msgpack_vec(&original).unwrap();
        let roundtrip: ComponentMergeRequest = zerompk::from_msgpack(&encoded).unwrap();
        assert_eq!(roundtrip.round, 2);
    }

    #[test]
    fn wcc_round_ack_serde() {
        let original = WccRoundAck {
            shard_id: 3,
            round: 1,
            labels_changed: 5,
            merges_sent: 12,
        };
        let encoded = zerompk::to_msgpack_vec(&original).unwrap();
        let roundtrip: WccRoundAck = zerompk::from_msgpack(&encoded).unwrap();
        assert_eq!(roundtrip.labels_changed, 5);
    }

    #[test]
    fn wcc_shard_local_only() {
        // Chain a -> b -> c must collapse into a single component.
        let chain = |vertex: u32| -> Vec<u32> {
            match vertex {
                0 => vec![1],
                1 => vec![2],
                _ => Vec::new(),
            }
        };
        let shard = ShardWccState::init(3, 0, names(&["a", "b", "c"]), &chain, &no_ghosts);
        let assigned = shard.component_labels();
        assert_eq!(assigned[0].1, assigned[1].1);
        assert_eq!(assigned[1].1, assigned[2].1);
    }

    #[test]
    fn wcc_shard_with_boundary_edges() {
        let shard = ShardWccState::init(
            2,
            0,
            names(&["a", "b"]),
            &zero_to_one,
            &ghost_from(1, "c", 1),
        );
        assert_eq!(shard.boundary_edges.len(), 1);
        let (batches, total) = shard.round();
        assert!(batches.contains_key(&1));
        assert_eq!(total, 1);
    }

    #[test]
    fn wcc_apply_merges_adopts_smaller_label() {
        let mut shard = ShardWccState::init(2, 1, names(&["c", "d"]), &zero_to_one, &no_ghosts);
        let changed = shard.apply_merges(&[("c".into(), "0:a".into())]);
        assert!(changed > 0);
        let assigned = shard.component_labels();
        assert_eq!(assigned[0].1, "0:a");
        assert_eq!(assigned[1].1, "0:a");
    }

    #[test]
    fn wcc_apply_merges_keeps_smaller_label() {
        let mut shard = ShardWccState::init(1, 0, names(&["a"]), &no_edges, &no_ghosts);
        let changed = shard.apply_merges(&[("a".into(), "1:z".into())]);
        assert_eq!(changed, 0);
        assert_eq!(shard.component_labels()[0].1, "0:a");
    }

    #[test]
    fn wcc_multi_round_convergence() {
        // Shard 0 holds a-b, shard 1 holds c-d, and b <-> c crosses shards.
        let mut shard0 = ShardWccState::init(
            2,
            0,
            names(&["a", "b"]),
            &zero_to_one,
            &ghost_from(1, "c", 1),
        );
        let mut shard1 = ShardWccState::init(
            2,
            1,
            names(&["c", "d"]),
            &zero_to_one,
            &ghost_from(0, "b", 0),
        );

        // First exchange must move at least one label.
        let first = exchange_round(&mut shard0, &mut shard1);
        assert!(first > 0);

        // Second exchange must be a no-op.
        let second = exchange_round(&mut shard0, &mut shard1);
        assert_eq!(second, 0, "should converge");

        let labels0 = shard0.component_labels();
        let labels1 = shard1.component_labels();
        assert_eq!(labels0[0].1, labels1[0].1, "cross-shard merge");
    }
}