Skip to main content

nodedb_cluster/distributed_graph/
wcc.rs

1// SPDX-License-Identifier: BUSL-1.1
2
//! Distributed WCC — cross-shard component merging via label propagation.
//!
//! Each shard computes local WCC via union-find, then iteratively exchanges
//! component labels across shard boundaries. For each cross-shard edge, the
//! shard that owns the target vertex adopts the incoming label if it is
//! lexicographically smaller than its current one. The algorithm has converged
//! when no shard changes any label in a round.
9
10use std::collections::HashMap;
11
12use serde::{Deserialize, Serialize};
13
14/// Cross-shard component merge request: shard → target shard.
15#[derive(
16    Debug, Clone, Serialize, Deserialize, zerompk::ToMessagePack, zerompk::FromMessagePack,
17)]
18pub struct ComponentMergeRequest {
19    pub round: u32,
20    pub source_shard: u32,
21    /// (target_vertex_name, source_component_label).
22    pub merges: Vec<(String, String)>,
23}
24
25/// WCC round acknowledgement: shard → coordinator.
26#[derive(
27    Debug, Clone, Serialize, Deserialize, zerompk::ToMessagePack, zerompk::FromMessagePack,
28)]
29pub struct WccRoundAck {
30    pub shard_id: u32,
31    pub round: u32,
32    pub labels_changed: usize,
33    pub merges_sent: usize,
34}
35
/// Per-shard WCC execution state.
///
/// Holds the local union-find forest built from intra-shard edges, the current
/// global label of every local vertex, and the shard's cross-shard (ghost)
/// edges used to exchange labels with other shards.
#[derive(Debug)]
pub struct ShardWccState {
    /// Number of local vertices; local ids are `0..vertex_count`.
    pub vertex_count: usize,
    /// Union-find parent pointers. After `init` every entry points directly at
    /// its component root.
    parent: Vec<usize>,
    /// Union-by-rank bookkeeping; only used while `init` builds the forest.
    rank: Vec<u8>,
    /// Current global label of each local vertex. Seeded as
    /// `"{shard_id}:{root_name}"` and lowered by incoming merges.
    pub global_labels: Vec<String>,
    /// Id of the shard this state belongs to.
    pub shard_id: u32,
    /// Local vertex id → `(remote_vertex_name, remote_shard_id)` ghost edges.
    pub boundary_edges: HashMap<u32, Vec<(String, u32)>>,
    /// Local vertex id → vertex name.
    node_names: Vec<String>,
    /// Vertex name → local vertex id, built once in `init` so `apply_merges`
    /// does not rebuild it on every call.
    name_index: HashMap<String, usize>,
}

impl ShardWccState {
    /// Initialize WCC state for a local CSR partition.
    ///
    /// * `local_edges(u)` returns the local neighbour ids of local vertex `u`.
    /// * `ghost_edges(u)` returns `(remote_vertex_name, remote_shard_id)` pairs
    ///   for the cross-shard edges of `u`.
    ///
    /// # Panics
    /// Panics if `local_edges` yields an id `>= vertex_count`, or if
    /// `node_names` has fewer than `vertex_count` entries.
    pub fn init(
        vertex_count: usize,
        shard_id: u32,
        node_names: Vec<String>,
        local_edges: &dyn Fn(u32) -> Vec<u32>,
        ghost_edges: &dyn Fn(u32) -> Vec<(String, u32)>,
    ) -> Self {
        // Index only names that belong to real local vertices, so a merge
        // addressed to a surplus name can never index past `parent`.
        // With duplicate names, the later index wins (`collect` semantics).
        let name_index: HashMap<String, usize> = node_names
            .iter()
            .take(vertex_count)
            .enumerate()
            .map(|(i, n)| (n.clone(), i))
            .collect();

        let mut state = Self {
            vertex_count,
            parent: (0..vertex_count).collect(),
            rank: vec![0u8; vertex_count],
            global_labels: Vec::new(),
            shard_id,
            boundary_edges: HashMap::new(),
            node_names,
            name_index,
        };

        // Local union-find pass over intra-shard edges.
        for u in 0..vertex_count {
            for v in local_edges(u as u32) {
                state.union(u, v as usize);
            }
        }

        // Record every vertex that has at least one cross-shard edge.
        for u in 0..vertex_count {
            let ghosts = ghost_edges(u as u32);
            if !ghosts.is_empty() {
                state.boundary_edges.insert(u as u32, ghosts);
            }
        }

        // No unions happen after this point, so flatten the forest: every
        // vertex points straight at its root and later lookups are one hop.
        for i in 0..vertex_count {
            let root = state.find(i);
            state.parent[i] = root;
        }

        // Seed each vertex's global label from its local root's name.
        state.global_labels = (0..vertex_count)
            .map(|i| format!("{}:{}", shard_id, state.node_names[state.parent[i]]))
            .collect();

        state
    }

    /// Find the root of `x`, halving the path as it walks.
    fn find(&mut self, mut x: usize) -> usize {
        while self.parent[x] != x {
            self.parent[x] = self.parent[self.parent[x]];
            x = self.parent[x];
        }
        x
    }

    /// Union the components of `a` and `b` by rank.
    fn union(&mut self, a: usize, b: usize) {
        let ra = self.find(a);
        let rb = self.find(b);
        if ra == rb {
            return;
        }
        match self.rank[ra].cmp(&self.rank[rb]) {
            std::cmp::Ordering::Less => self.parent[ra] = rb,
            std::cmp::Ordering::Greater => self.parent[rb] = ra,
            std::cmp::Ordering::Equal => {
                self.parent[rb] = ra;
                self.rank[ra] += 1;
            }
        }
    }

    /// Produce outbound merge requests for boundary edges.
    ///
    /// Returns a map of `target_shard → [(remote_vertex_name, current_label)]`,
    /// plus the total number of entries across all targets.
    pub fn round(&self) -> (HashMap<u32, Vec<(String, String)>>, usize) {
        let mut outbound: HashMap<u32, Vec<(String, String)>> = HashMap::new();

        for (&local_id, ghost_list) in &self.boundary_edges {
            let label = &self.global_labels[find_static(&self.parent, local_id as usize)];
            for (remote_name, target_shard) in ghost_list {
                outbound
                    .entry(*target_shard)
                    .or_default()
                    .push((remote_name.clone(), label.clone()));
            }
        }

        let sent: usize = outbound.values().map(Vec::len).sum();
        (outbound, sent)
    }

    /// Apply incoming `(vertex_name, remote_label)` merges.
    ///
    /// A local component adopts `remote_label` when it is lexicographically
    /// smaller than the component's current label. Merges naming unknown
    /// vertices are ignored.
    ///
    /// Returns the number of merges that lowered a label. If one component is
    /// lowered by several merges in the same call, each lowering is counted.
    pub fn apply_merges(&mut self, merges: &[(String, String)]) -> usize {
        let mut changed = 0;

        for (vertex_name, remote_label) in merges {
            let Some(&local_id) = self.name_index.get(vertex_name.as_str()) else {
                continue;
            };
            let root = find_static(&self.parent, local_id);
            if *remote_label < self.global_labels[root] {
                self.global_labels[root] = remote_label.clone();
                changed += 1;
            }
        }

        // Non-root labels are always kept in sync with their root, so they
        // only need refreshing when some root label actually moved.
        if changed > 0 {
            for i in 0..self.vertex_count {
                let root = find_static(&self.parent, i);
                if i != root {
                    self.global_labels[i] = self.global_labels[root].clone();
                }
            }
        }

        changed
    }

    /// Get current component assignment: `(vertex_name, global_label)` per
    /// local vertex, in local-id order.
    pub fn component_labels(&self) -> Vec<(String, String)> {
        (0..self.vertex_count)
            .map(|i| {
                let root = find_static(&self.parent, i);
                (self.node_names[i].clone(), self.global_labels[root].clone())
            })
            .collect()
    }
}

/// Non-mutating find (no path compression). Borrow-safe.
fn find_static(parent: &[usize], mut x: usize) -> usize {
    while parent[x] != x {
        x = parent[x];
    }
    x
}
193
#[cfg(test)]
mod tests {
    use super::*;

    /// Every field of a merge request must survive a MessagePack round trip.
    #[test]
    fn component_merge_request_serde() {
        let req = ComponentMergeRequest {
            round: 2,
            source_shard: 1,
            merges: vec![("alice".into(), "0:root_a".into())],
        };
        let bytes = zerompk::to_msgpack_vec(&req).unwrap();
        let decoded: ComponentMergeRequest = zerompk::from_msgpack(&bytes).unwrap();
        assert_eq!(decoded.round, req.round);
        assert_eq!(decoded.source_shard, req.source_shard);
        assert_eq!(decoded.merges, req.merges);
    }

    /// Every field of a round ack must survive a MessagePack round trip.
    #[test]
    fn wcc_round_ack_serde() {
        let ack = WccRoundAck {
            shard_id: 3,
            round: 1,
            labels_changed: 5,
            merges_sent: 12,
        };
        let bytes = zerompk::to_msgpack_vec(&ack).unwrap();
        let decoded: WccRoundAck = zerompk::from_msgpack(&bytes).unwrap();
        assert_eq!(decoded.shard_id, 3);
        assert_eq!(decoded.round, 1);
        assert_eq!(decoded.labels_changed, 5);
        assert_eq!(decoded.merges_sent, 12);
    }

    #[test]
    fn wcc_shard_local_only() {
        let state = ShardWccState::init(
            3,
            0,
            vec!["a".into(), "b".into(), "c".into()],
            &|node| match node {
                0 => vec![1],
                1 => vec![2],
                _ => Vec::new(),
            },
            &|_| Vec::new(),
        );
        let labels = state.component_labels();
        assert_eq!(labels[0].1, labels[1].1);
        assert_eq!(labels[1].1, labels[2].1);
    }

    #[test]
    fn wcc_shard_with_boundary_edges() {
        let state = ShardWccState::init(
            2,
            0,
            vec!["a".into(), "b".into()],
            &|node| match node {
                0 => vec![1],
                _ => Vec::new(),
            },
            &|node| {
                if node == 1 {
                    vec![("c".into(), 1)]
                } else {
                    Vec::new()
                }
            },
        );
        assert_eq!(state.boundary_edges.len(), 1);
        let (outbound, sent) = state.round();
        assert!(outbound.contains_key(&1));
        assert_eq!(sent, 1);
    }

    #[test]
    fn wcc_apply_merges_adopts_smaller_label() {
        let mut state = ShardWccState::init(
            2,
            1,
            vec!["c".into(), "d".into()],
            &|node| match node {
                0 => vec![1],
                _ => Vec::new(),
            },
            &|_| Vec::new(),
        );
        let changed = state.apply_merges(&[("c".into(), "0:a".into())]);
        assert!(changed > 0);
        let labels = state.component_labels();
        assert_eq!(labels[0].1, "0:a");
        assert_eq!(labels[1].1, "0:a");
    }

    #[test]
    fn wcc_apply_merges_keeps_smaller_label() {
        let mut state =
            ShardWccState::init(1, 0, vec!["a".into()], &|_| Vec::new(), &|_| Vec::new());
        let changed = state.apply_merges(&[("a".into(), "1:z".into())]);
        assert_eq!(changed, 0);
        assert_eq!(state.component_labels()[0].1, "0:a");
    }

    /// Merges addressed to vertices this shard does not own are skipped.
    #[test]
    fn wcc_apply_merges_ignores_unknown_vertex() {
        let mut state =
            ShardWccState::init(1, 0, vec!["a".into()], &|_| Vec::new(), &|_| Vec::new());
        let changed = state.apply_merges(&[("zz".into(), "0:0".into())]);
        assert_eq!(changed, 0);
        assert_eq!(state.component_labels()[0].1, "0:a");
    }

    #[test]
    fn wcc_multi_round_convergence() {
        let mut shard0 = ShardWccState::init(
            2,
            0,
            vec!["a".into(), "b".into()],
            &|node| match node {
                0 => vec![1],
                _ => Vec::new(),
            },
            &|node| {
                if node == 1 {
                    vec![("c".into(), 1)]
                } else {
                    Vec::new()
                }
            },
        );

        let mut shard1 = ShardWccState::init(
            2,
            1,
            vec!["c".into(), "d".into()],
            &|node| match node {
                0 => vec![1],
                _ => Vec::new(),
            },
            &|node| {
                if node == 0 {
                    vec![("b".into(), 0)]
                } else {
                    Vec::new()
                }
            },
        );

        // Round 1.
        let (out0, _) = shard0.round();
        let (out1, _) = shard1.round();
        let c0 = out1.get(&0).map_or(0, |m| shard0.apply_merges(m));
        let c1 = out0.get(&1).map_or(0, |m| shard1.apply_merges(m));
        assert!(c0 + c1 > 0);

        // Round 2.
        let (out0_r2, _) = shard0.round();
        let (out1_r2, _) = shard1.round();
        let c0_r2 = out1_r2.get(&0).map_or(0, |m| shard0.apply_merges(m));
        let c1_r2 = out0_r2.get(&1).map_or(0, |m| shard1.apply_merges(m));
        assert_eq!(c0_r2 + c1_r2, 0, "should converge");

        // All 4 nodes should share one global label.
        let all: Vec<String> = shard0
            .component_labels()
            .into_iter()
            .chain(shard1.component_labels())
            .map(|(_, label)| label)
            .collect();
        assert_eq!(all.len(), 4);
        assert!(
            all.iter().all(|label| label == &all[0]),
            "cross-shard merge: {all:?}"
        );
    }
}