nodedb_cluster/distributed_graph/
wcc.rs1use std::collections::HashMap;
11
12use serde::{Deserialize, Serialize};
13
14#[derive(
16 Debug, Clone, Serialize, Deserialize, zerompk::ToMessagePack, zerompk::FromMessagePack,
17)]
18pub struct ComponentMergeRequest {
19 pub round: u32,
20 pub source_shard: u32,
21 pub merges: Vec<(String, String)>,
23}
24
25#[derive(
27 Debug, Clone, Serialize, Deserialize, zerompk::ToMessagePack, zerompk::FromMessagePack,
28)]
29pub struct WccRoundAck {
30 pub shard_id: u32,
31 pub round: u32,
32 pub labels_changed: usize,
33 pub merges_sent: usize,
34}
35
/// Per-shard state for distributed weakly-connected-components (WCC).
///
/// Local connectivity is resolved once, at construction, with a union-find
/// forest over local vertex ids `0..vertex_count`. Cross-shard connectivity is
/// then resolved over several rounds by exchanging component labels along
/// `boundary_edges`.
#[derive(Debug)]
pub struct ShardWccState {
    /// Number of vertices owned by this shard.
    pub vertex_count: usize,
    /// Union-find parent pointer for each local vertex id.
    parent: Vec<usize>,
    /// Union-by-rank bound for each union-find root.
    rank: Vec<u8>,
    /// Current component label for each local vertex id.
    /// Initially `"{shard_id}:{root_name}"`.
    pub global_labels: Vec<String>,
    /// Id of the shard this state belongs to.
    pub shard_id: u32,
    /// Cross-shard edges: local vertex id -> `(remote vertex name, owning shard)`.
    pub boundary_edges: HashMap<u32, Vec<(String, u32)>>,
    /// Name of each local vertex, indexed by local id.
    node_names: Vec<String>,
}
47
48impl ShardWccState {
49 pub fn init(
51 vertex_count: usize,
52 shard_id: u32,
53 node_names: Vec<String>,
54 local_edges: &dyn Fn(u32) -> Vec<u32>,
55 ghost_edges: &dyn Fn(u32) -> Vec<(String, u32)>,
56 ) -> Self {
57 let parent: Vec<usize> = (0..vertex_count).collect();
58 let rank = vec![0u8; vertex_count];
59
60 let mut state = Self {
61 vertex_count,
62 parent,
63 rank,
64 global_labels: Vec::new(),
65 shard_id,
66 boundary_edges: HashMap::new(),
67 node_names,
68 };
69
70 for u in 0..vertex_count {
72 for v in local_edges(u as u32) {
73 state.union(u, v as usize);
74 }
75 }
76
77 for u in 0..vertex_count {
79 let ghosts = ghost_edges(u as u32);
80 if !ghosts.is_empty() {
81 state.boundary_edges.insert(u as u32, ghosts);
82 }
83 }
84
85 state.global_labels = (0..vertex_count)
87 .map(|i| {
88 let root = state.find(i);
89 format!("{}:{}", shard_id, state.node_names[root])
90 })
91 .collect();
92
93 state
94 }
95
96 fn find(&mut self, mut x: usize) -> usize {
97 while self.parent[x] != x {
98 self.parent[x] = self.parent[self.parent[x]];
99 x = self.parent[x];
100 }
101 x
102 }
103
104 fn union(&mut self, a: usize, b: usize) {
105 let ra = self.find(a);
106 let rb = self.find(b);
107 if ra == rb {
108 return;
109 }
110 match self.rank[ra].cmp(&self.rank[rb]) {
111 std::cmp::Ordering::Less => self.parent[ra] = rb,
112 std::cmp::Ordering::Greater => self.parent[rb] = ra,
113 std::cmp::Ordering::Equal => {
114 self.parent[rb] = ra;
115 self.rank[ra] += 1;
116 }
117 }
118 }
119
120 pub fn round(&self) -> (HashMap<u32, Vec<(String, String)>>, usize) {
122 let mut outbound: HashMap<u32, Vec<(String, String)>> = HashMap::new();
123
124 for (&local_id, ghost_list) in &self.boundary_edges {
125 let root = find_static(&self.parent, local_id as usize);
126 let label = self.global_labels[root].clone();
127 for (remote_name, target_shard) in ghost_list {
128 outbound
129 .entry(*target_shard)
130 .or_default()
131 .push((remote_name.clone(), label.clone()));
132 }
133 }
134
135 let sent: usize = outbound.values().map(|v| v.len()).sum();
136 (outbound, sent)
137 }
138
139 pub fn apply_merges(&mut self, merges: &[(String, String)]) -> usize {
141 let mut changed = 0;
142
143 let name_to_id: HashMap<&str, usize> = self
144 .node_names
145 .iter()
146 .enumerate()
147 .map(|(i, n)| (n.as_str(), i))
148 .collect();
149
150 for (vertex_name, remote_label) in merges {
151 let Some(&local_id) = name_to_id.get(vertex_name.as_str()) else {
152 continue;
153 };
154
155 let root = find_static(&self.parent, local_id);
156 let local_label = &self.global_labels[root];
157
158 if local_label != remote_label && *remote_label < *local_label {
159 self.global_labels[root] = remote_label.clone();
160 changed += 1;
161 }
162 }
163
164 for i in 0..self.vertex_count {
166 let root = find_static(&self.parent, i);
167 if i != root {
168 self.global_labels[i] = self.global_labels[root].clone();
169 }
170 }
171
172 changed
173 }
174
175 pub fn component_labels(&self) -> Vec<(String, String)> {
177 (0..self.vertex_count)
178 .map(|i| {
179 let root = find_static(&self.parent, i);
180 (self.node_names[i].clone(), self.global_labels[root].clone())
181 })
182 .collect()
183 }
184}
185
/// Returns the union-find root of `x` in the forest described by `parent`.
///
/// Parent links are followed without path compression, so this works on a
/// shared borrow (unlike `ShardWccState::find`).
///
/// # Panics
/// Panics if `x`, or any parent link reached from it, is out of bounds.
fn find_static(parent: &[usize], x: usize) -> usize {
    let mut node = x;
    loop {
        let up = parent[node];
        if up == node {
            return node;
        }
        node = up;
    }
}
193
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds owned `(String, String)` pairs from string slices for comparisons.
    fn pairs(items: &[(&str, &str)]) -> Vec<(String, String)> {
        items
            .iter()
            .map(|&(a, b)| (a.to_string(), b.to_string()))
            .collect()
    }

    #[test]
    fn component_merge_request_serde() {
        let req = ComponentMergeRequest {
            round: 2,
            source_shard: 1,
            merges: pairs(&[("alice", "0:root_a")]),
        };
        let bytes = zerompk::to_msgpack_vec(&req).unwrap();
        let decoded: ComponentMergeRequest = zerompk::from_msgpack(&bytes).unwrap();
        // Check every field so a field-order or encoding slip is caught.
        assert_eq!(decoded.round, req.round);
        assert_eq!(decoded.source_shard, req.source_shard);
        assert_eq!(decoded.merges, req.merges);
    }

    #[test]
    fn wcc_round_ack_serde() {
        let ack = WccRoundAck {
            shard_id: 3,
            round: 1,
            labels_changed: 5,
            merges_sent: 12,
        };
        let bytes = zerompk::to_msgpack_vec(&ack).unwrap();
        let decoded: WccRoundAck = zerompk::from_msgpack(&bytes).unwrap();
        assert_eq!(decoded.shard_id, 3);
        assert_eq!(decoded.round, 1);
        assert_eq!(decoded.labels_changed, 5);
        assert_eq!(decoded.merges_sent, 12);
    }

    #[test]
    fn wcc_shard_local_only() {
        let state = ShardWccState::init(
            3,
            0,
            vec!["a".into(), "b".into(), "c".into()],
            &|node| match node {
                0 => vec![1],
                1 => vec![2],
                _ => Vec::new(),
            },
            &|_| Vec::new(),
        );
        assert!(state.boundary_edges.is_empty());
        assert_eq!(
            state.component_labels(),
            pairs(&[("a", "0:a"), ("b", "0:a"), ("c", "0:a")])
        );
    }

    #[test]
    fn wcc_shard_with_boundary_edges() {
        let state = ShardWccState::init(
            2,
            0,
            vec!["a".into(), "b".into()],
            &|node| match node {
                0 => vec![1],
                _ => Vec::new(),
            },
            &|node| {
                if node == 1 {
                    vec![("c".into(), 1)]
                } else {
                    Vec::new()
                }
            },
        );
        assert_eq!(state.boundary_edges.len(), 1);
        let (outbound, sent) = state.round();
        assert_eq!(sent, 1);
        assert_eq!(outbound.len(), 1);
        assert_eq!(outbound[&1], pairs(&[("c", "0:a")]));
    }

    #[test]
    fn wcc_apply_merges_adopts_smaller_label() {
        let mut state = ShardWccState::init(
            2,
            1,
            vec!["c".into(), "d".into()],
            &|node| match node {
                0 => vec![1],
                _ => Vec::new(),
            },
            &|_| Vec::new(),
        );
        let changed = state.apply_merges(&[("c".into(), "0:a".into())]);
        assert_eq!(changed, 1);
        assert_eq!(state.component_labels(), pairs(&[("c", "0:a"), ("d", "0:a")]));
    }

    #[test]
    fn wcc_apply_merges_keeps_smaller_label() {
        let mut state =
            ShardWccState::init(1, 0, vec!["a".into()], &|_| Vec::new(), &|_| Vec::new());
        let changed = state.apply_merges(&[("a".into(), "1:z".into())]);
        assert_eq!(changed, 0);
        assert_eq!(state.component_labels()[0].1, "0:a");
    }

    #[test]
    fn wcc_apply_merges_ignores_unknown_vertex() {
        let mut state =
            ShardWccState::init(1, 0, vec!["a".into()], &|_| Vec::new(), &|_| Vec::new());
        assert_eq!(state.apply_merges(&[("zzz".into(), "0:0".into())]), 0);
        assert_eq!(state.component_labels(), pairs(&[("a", "0:a")]));
    }

    #[test]
    fn wcc_multi_round_convergence() {
        let mut shard0 = ShardWccState::init(
            2,
            0,
            vec!["a".into(), "b".into()],
            &|node| match node {
                0 => vec![1],
                _ => Vec::new(),
            },
            &|node| {
                if node == 1 {
                    vec![("c".into(), 1)]
                } else {
                    Vec::new()
                }
            },
        );

        let mut shard1 = ShardWccState::init(
            2,
            1,
            vec!["c".into(), "d".into()],
            &|node| match node {
                0 => vec![1],
                _ => Vec::new(),
            },
            &|node| {
                if node == 0 {
                    vec![("b".into(), 0)]
                } else {
                    Vec::new()
                }
            },
        );

        let (out0, _) = shard0.round();
        let (out1, _) = shard1.round();
        let c0 = out1.get(&0).map_or(0, |m| shard0.apply_merges(m));
        let c1 = out0.get(&1).map_or(0, |m| shard1.apply_merges(m));
        assert!(c0 + c1 > 0);

        let (out0_r2, _) = shard0.round();
        let (out1_r2, _) = shard1.round();
        let c0_r2 = out1_r2.get(&0).map_or(0, |m| shard0.apply_merges(m));
        let c1_r2 = out0_r2.get(&1).map_or(0, |m| shard1.apply_merges(m));
        assert_eq!(c0_r2 + c1_r2, 0, "should converge");

        // Every vertex on both shards must end up with the global minimum label.
        let l0 = shard0.component_labels();
        let l1 = shard1.component_labels();
        for (name, label) in l0.iter().chain(l1.iter()) {
            assert_eq!(label, "0:a", "vertex {} not merged", name);
        }
    }
}