Skip to main content

nodedb_cluster/distributed_graph/
wcc.rs

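//! Distributed WCC — cross-shard component merging via label propagation.
//!
//! Each shard computes local WCC via union-find, then iteratively exchanges
//! component labels across shard boundaries. For each cross-shard edge,
//! the target shard adopts the lexicographically smaller label. Converges
//! when no shard changes any labels in a round.
//!
//! Labels have the form `"{shard_id}:{root_vertex_name}"` and are compared as
//! plain strings (so `"10:a" < "2:a"`). That order is arbitrary but total, which
//! is all convergence requires. Labels only flow in the direction a shard reports
//! a ghost edge, so weak connectivity needs boundary edges reported by both endpoints.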
7
8use std::collections::HashMap;
9
10use serde::{Deserialize, Serialize};
11
12/// Cross-shard component merge request: shard → target shard.
13#[derive(
14    Debug, Clone, Serialize, Deserialize, zerompk::ToMessagePack, zerompk::FromMessagePack,
15)]
16pub struct ComponentMergeRequest {
17    pub round: u32,
18    pub source_shard: u16,
19    /// (target_vertex_name, source_component_label).
20    pub merges: Vec<(String, String)>,
21}
22
23/// WCC round acknowledgement: shard → coordinator.
24#[derive(
25    Debug, Clone, Serialize, Deserialize, zerompk::ToMessagePack, zerompk::FromMessagePack,
26)]
27pub struct WccRoundAck {
28    pub shard_id: u16,
29    pub round: u32,
30    pub labels_changed: usize,
31    pub merges_sent: usize,
32}
33
/// Per-shard WCC execution state.
///
/// Holds a union-find over this shard's local vertices plus one global component
/// label per vertex. Labels start as `"{shard_id}:{root_name}"`, and
/// [`ShardWccState::apply_merges`] only ever replaces a label with a smaller one.
#[derive(Debug)]
pub struct ShardWccState {
    /// Number of local vertices; valid local ids are `0..vertex_count`.
    pub vertex_count: usize,
    /// Union-find parent pointers, indexed by local vertex id.
    parent: Vec<usize>,
    /// Union-by-rank ranks, indexed by local vertex id.
    rank: Vec<u8>,
    /// Global component label per local vertex. The entry at a union-find root is
    /// the one read by `round` and `component_labels`; other entries mirror it.
    pub global_labels: Vec<String>,
    /// Shard owning this partition; prefixes every initial label.
    pub shard_id: u16,
    /// Cross-shard edges: local vertex id -> `(remote_vertex_name, remote_shard_id)`.
    pub boundary_edges: HashMap<u32, Vec<(String, u16)>>,
    /// Vertex names, indexed by local vertex id.
    node_names: Vec<String>,
}

impl ShardWccState {
    /// Initialize WCC state for a local CSR partition.
    ///
    /// * `vertex_count` — number of local vertices.
    /// * `shard_id` — this shard's id, used as the label prefix.
    /// * `node_names` — vertex names indexed by local id (at least `vertex_count` entries).
    /// * `local_edges` — local neighbour ids of a vertex.
    /// * `ghost_edges` — `(remote_name, remote_shard)` for each cross-shard edge of a vertex.
    ///
    /// # Panics
    ///
    /// Panics if `node_names` has fewer than `vertex_count` entries, or if
    /// `local_edges` yields an id outside `0..vertex_count`.
    pub fn init(
        vertex_count: usize,
        shard_id: u16,
        node_names: Vec<String>,
        local_edges: &dyn Fn(u32) -> Vec<u32>,
        ghost_edges: &dyn Fn(u32) -> Vec<(String, u16)>,
    ) -> Self {
        // Fail fast: every local vertex needs a name for its label and for lookups.
        assert!(
            node_names.len() >= vertex_count,
            "node_names has {} entries but vertex_count is {}",
            node_names.len(),
            vertex_count
        );

        let mut state = Self {
            vertex_count,
            parent: (0..vertex_count).collect(),
            rank: vec![0; vertex_count],
            global_labels: Vec::new(),
            shard_id,
            boundary_edges: HashMap::new(),
            node_names,
        };

        // Local union-find pass.
        for u in 0..vertex_count {
            for v in local_edges(u as u32) {
                state.union(u, v as usize);
            }
        }

        // Record cross-shard edges; vertices without any are left out of the map.
        for u in 0..vertex_count {
            let ghosts = ghost_edges(u as u32);
            if !ghosts.is_empty() {
                state.boundary_edges.insert(u as u32, ghosts);
            }
        }

        // Initial global label of each vertex is derived from its local root.
        state.global_labels = (0..vertex_count)
            .map(|i| {
                let root = state.find(i);
                format!("{}:{}", shard_id, state.node_names[root])
            })
            .collect();

        state
    }

    /// Find the root of `x`, halving the path as it walks
    /// (each visited node is re-pointed to its grandparent).
    fn find(&mut self, mut x: usize) -> usize {
        while self.parent[x] != x {
            self.parent[x] = self.parent[self.parent[x]];
            x = self.parent[x];
        }
        x
    }

    /// Merge the sets containing `a` and `b` using union by rank.
    fn union(&mut self, a: usize, b: usize) {
        let ra = self.find(a);
        let rb = self.find(b);
        if ra == rb {
            return;
        }
        match self.rank[ra].cmp(&self.rank[rb]) {
            std::cmp::Ordering::Less => self.parent[ra] = rb,
            std::cmp::Ordering::Greater => self.parent[rb] = ra,
            std::cmp::Ordering::Equal => {
                self.parent[rb] = ra;
                self.rank[ra] += 1;
            }
        }
    }

    /// Produce outbound merge requests for boundary edges.
    ///
    /// Returns, per target shard, `(remote_vertex_name, current_label_of_local_endpoint)`
    /// pairs, together with the total number of pairs across all shards.
    pub fn round(&self) -> (HashMap<u16, Vec<(String, String)>>, usize) {
        let mut outbound: HashMap<u16, Vec<(String, String)>> = HashMap::new();

        for (&local_id, ghost_list) in &self.boundary_edges {
            let root = find_static(&self.parent, local_id as usize);
            let label = &self.global_labels[root];
            for (remote_name, target_shard) in ghost_list {
                outbound
                    .entry(*target_shard)
                    .or_default()
                    .push((remote_name.clone(), label.clone()));
            }
        }

        let sent: usize = outbound.values().map(Vec::len).sum();
        (outbound, sent)
    }

    /// Apply incoming merges. Returns number of labels changed.
    ///
    /// Each `(vertex_name, remote_label)` replaces the label of the vertex's local
    /// component when `remote_label` is strictly smaller (string order). Names that
    /// are not local vertices are ignored. The count is incremented once per merge
    /// that lowered a root label, so the same component may be counted more than once.
    pub fn apply_merges(&mut self, merges: &[(String, String)]) -> usize {
        // Only the first `vertex_count` names are real vertices; mapping any extra
        // name would produce an id that indexes past `parent` below.
        let name_to_id: HashMap<&str, usize> = self
            .node_names
            .iter()
            .take(self.vertex_count)
            .enumerate()
            .map(|(i, n)| (n.as_str(), i))
            .collect();

        let mut changed = 0;
        for (vertex_name, remote_label) in merges {
            let Some(&local_id) = name_to_id.get(vertex_name.as_str()) else {
                continue;
            };

            let root = find_static(&self.parent, local_id);
            // Strictly smaller already implies "different"; no separate equality test needed.
            if *remote_label < self.global_labels[root] {
                self.global_labels[root].clone_from(remote_label);
                changed += 1;
            }
        }

        // Propagate root labels to member vertices, copying only entries that differ
        // so unchanged components do not reallocate their strings.
        for i in 0..self.vertex_count {
            let root = find_static(&self.parent, i);
            if i != root && self.global_labels[i] != self.global_labels[root] {
                let label = self.global_labels[root].clone();
                self.global_labels[i] = label;
            }
        }

        changed
    }

    /// Get current component assignment: (vertex_name, global_label).
    pub fn component_labels(&self) -> Vec<(String, String)> {
        (0..self.vertex_count)
            .map(|i| {
                let root = find_static(&self.parent, i);
                (self.node_names[i].clone(), self.global_labels[root].clone())
            })
            .collect()
    }
}

/// Non-mutating find (no path compression). Borrow-safe.
fn find_static(parent: &[usize], mut x: usize) -> usize {
    while parent[x] != x {
        x = parent[x];
    }
    x
}
191
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn component_merge_request_serde() {
        let req = ComponentMergeRequest {
            round: 2,
            source_shard: 1,
            merges: vec![("alice".into(), "0:root_a".into())],
        };
        let bytes = zerompk::to_msgpack_vec(&req).unwrap();
        let decoded: ComponentMergeRequest = zerompk::from_msgpack(&bytes).unwrap();
        // Check every field so a field-order or encoding regression is caught.
        assert_eq!(decoded.round, 2);
        assert_eq!(decoded.source_shard, 1);
        assert_eq!(decoded.merges, req.merges);
    }

    #[test]
    fn wcc_round_ack_serde() {
        let ack = WccRoundAck {
            shard_id: 3,
            round: 1,
            labels_changed: 5,
            merges_sent: 12,
        };
        let bytes = zerompk::to_msgpack_vec(&ack).unwrap();
        let decoded: WccRoundAck = zerompk::from_msgpack(&bytes).unwrap();
        assert_eq!(decoded.shard_id, 3);
        assert_eq!(decoded.round, 1);
        assert_eq!(decoded.labels_changed, 5);
        assert_eq!(decoded.merges_sent, 12);
    }

    #[test]
    fn wcc_shard_local_only() {
        let state = ShardWccState::init(
            3,
            0,
            vec!["a".into(), "b".into(), "c".into()],
            &|node| match node {
                0 => vec![1],
                1 => vec![2],
                _ => Vec::new(),
            },
            &|_| Vec::new(),
        );
        let labels = state.component_labels();
        assert_eq!(labels[0].1, labels[1].1);
        assert_eq!(labels[1].1, labels[2].1);
    }

    #[test]
    fn wcc_shard_with_boundary_edges() {
        let state = ShardWccState::init(
            2,
            0,
            vec!["a".into(), "b".into()],
            &|node| match node {
                0 => vec![1],
                _ => Vec::new(),
            },
            &|node| {
                if node == 1 {
                    vec![("c".into(), 1)]
                } else {
                    Vec::new()
                }
            },
        );
        assert_eq!(state.boundary_edges.len(), 1);
        let (outbound, sent) = state.round();
        assert_eq!(sent, 1);
        // The remote endpoint receives the local component's label.
        assert_eq!(
            outbound.get(&1),
            Some(&vec![("c".to_string(), "0:a".to_string())])
        );
    }

    #[test]
    fn wcc_apply_merges_adopts_smaller_label() {
        let mut state = ShardWccState::init(
            2,
            1,
            vec!["c".into(), "d".into()],
            &|node| match node {
                0 => vec![1],
                _ => Vec::new(),
            },
            &|_| Vec::new(),
        );
        let changed = state.apply_merges(&[("c".into(), "0:a".into())]);
        assert!(changed > 0);
        let labels = state.component_labels();
        assert_eq!(labels[0].1, "0:a");
        assert_eq!(labels[1].1, "0:a");
    }

    #[test]
    fn wcc_apply_merges_keeps_smaller_label() {
        let mut state =
            ShardWccState::init(1, 0, vec!["a".into()], &|_| Vec::new(), &|_| Vec::new());
        let changed = state.apply_merges(&[("a".into(), "1:z".into())]);
        assert_eq!(changed, 0);
        assert_eq!(state.component_labels()[0].1, "0:a");
    }

    #[test]
    fn wcc_multi_round_convergence() {
        let mut shard0 = ShardWccState::init(
            2,
            0,
            vec!["a".into(), "b".into()],
            &|node| match node {
                0 => vec![1],
                _ => Vec::new(),
            },
            &|node| {
                if node == 1 {
                    vec![("c".into(), 1)]
                } else {
                    Vec::new()
                }
            },
        );

        let mut shard1 = ShardWccState::init(
            2,
            1,
            vec!["c".into(), "d".into()],
            &|node| match node {
                0 => vec![1],
                _ => Vec::new(),
            },
            &|node| {
                if node == 0 {
                    vec![("b".into(), 0)]
                } else {
                    Vec::new()
                }
            },
        );

        // Round 1: shard1 adopts shard0's smaller label "0:a"; shard0 keeps its own.
        let (out0, _) = shard0.round();
        let (out1, _) = shard1.round();
        let c0 = out1.get(&0).map_or(0, |m| shard0.apply_merges(m));
        let c1 = out0.get(&1).map_or(0, |m| shard1.apply_merges(m));
        assert_eq!(c0, 0);
        assert_eq!(c1, 1);

        // Round 2.
        let (out0_r2, _) = shard0.round();
        let (out1_r2, _) = shard1.round();
        let c0_r2 = out1_r2.get(&0).map_or(0, |m| shard0.apply_merges(m));
        let c1_r2 = out0_r2.get(&1).map_or(0, |m| shard1.apply_merges(m));
        assert_eq!(c0_r2 + c1_r2, 0, "should converge");

        // All 4 nodes should share one global label.
        for (name, label) in shard0
            .component_labels()
            .into_iter()
            .chain(shard1.component_labels())
        {
            assert_eq!(label, "0:a", "vertex {name} was not merged");
        }
    }
}