prav_core/decoder/
union_find.rs

//! Union Find (Disjoint Set Forest) implementation for cluster tracking.
//!
//! This module provides the data structure that tracks which nodes belong to
//! the same cluster during QEC decoding. It uses several optimizations:
//!
//! - **Fast path for self-rooted nodes**: At typical error rates (p=0.001),
//!   ~95% of nodes are self-rooted (their own cluster). A direct check avoids traversal.
//! - **Path halving compression**: When traversing, each visited node is redirected to its
//!   grandparent, halving the path length per query. Because linking is by index rather
//!   than by rank/size, the amortized bound is O(log n) per operation, not the O(α(n))
//!   bound of union-by-rank.
//! - **Deterministic union**: The smaller index becomes a child of the larger, providing
//!   reproducible results without rank tracking overhead.
12
13#![allow(unsafe_op_in_unsafe_fn)]
14use super::state::DecodingState;
15use crate::topology::Topology;
16
/// Disjoint set forest operations for tracking connected clusters.
///
/// In Union Find-based QEC decoding, each syndrome node starts as its own cluster.
/// As cluster growth proceeds, neighboring nodes are merged using `union`. The
/// `find` operation determines which cluster a node belongs to.
///
/// # Cluster Representation
///
/// Each cluster is identified by its root node - the representative element.
/// The root is the node where `parents[root] == root`. All other nodes in the
/// cluster have a parent pointer forming a tree structure leading to the root.
///
/// Merging links one root directly under the other and leaves both existing
/// subtrees intact. With index-based linking and `A < B`, `A` becomes a child of `B`:
///
/// ```text
/// Before union:       After union(A, B), A < B:
///   A    B                B (root)
///  / \   |               / \
/// 1   2  3              A   3
///                      / \
///                     1   2
/// ```
///
/// # Performance Characteristics
///
/// | Operation | Time Complexity | Notes |
/// |-----------|-----------------|-------|
/// | `find` | O(1) if self-rooted, else amortized O(log n) | Path halving with index-based (not rank/size) linking, so the O(α(n)) bound of union-by-rank does not apply |
/// | `union` | Same as `find` | Two finds + O(1) merge |
/// | `union_roots` | O(1) | Direct merge of known roots |
pub trait UnionFind {
    /// Finds the root (cluster representative) of the node `i`.
    ///
    /// This is the fundamental query operation. Two nodes are in the same cluster
    /// if and only if they have the same root.
    ///
    /// # Arguments
    ///
    /// * `i` - Node index to find the root of. Must be a valid node index.
    ///   This method is safe to call, but implementations may use unchecked
    ///   indexing (the `DecodingState` implementation does). Passing an
    ///   out-of-range `i` is therefore a caller bug.
    ///
    /// # Returns
    ///
    /// The root node index of the cluster containing `i`.
    ///
    /// # Fast Path Optimization
    ///
    /// At typical QEC error rates, most nodes are isolated (self-rooted).
    /// The implementation checks `parents[i] == i` first and returns
    /// immediately in that case (cited as ~95% of calls at p=0.001), without
    /// any traversal.
    ///
    /// # Path Compression
    ///
    /// During traversal, path halving is applied: each visited node is redirected
    /// to its grandparent. This flattens the tree over time, keeping paths short.
    fn find(&mut self, i: u32) -> u32;

    /// Merges two clusters given their root nodes.
    ///
    /// This is the low-level merge operation used when roots are already known.
    /// Use [`union`](Self::union) for the general case where roots must be found.
    ///
    /// # Arguments
    ///
    /// * `root_u` - Root of the first cluster.
    /// * `root_v` - Root of the second cluster.
    ///
    /// # Returns
    ///
    /// * `true` if the clusters were merged (they were different).
    /// * `false` if they were already the same cluster.
    ///
    /// # Union Strategy
    ///
    /// Uses index-based union: the smaller index becomes a child of the larger.
    /// This provides deterministic behavior without maintaining rank information.
    ///
    /// # Safety
    ///
    /// Caller must ensure `root_u` and `root_v` are valid root node indices
    /// (i.e., `parents[root_u] == root_u` and `parents[root_v] == root_v`).
    unsafe fn union_roots(&mut self, root_u: u32, root_v: u32) -> bool;

    /// Merges the clusters containing nodes `u` and `v`.
    ///
    /// Finds the roots of both nodes and merges them if they differ.
    ///
    /// # Arguments
    ///
    /// * `u` - First node index.
    /// * `v` - Second node index.
    ///
    /// # Returns
    ///
    /// * `true` if the clusters were merged.
    /// * `false` if `u` and `v` were already in the same cluster.
    ///
    /// # Safety
    ///
    /// Caller must ensure `u` and `v` are valid node indices within bounds.
    unsafe fn union(&mut self, u: u32, v: u32) -> bool;
}
116
117impl<'a, T: Topology, const STRIDE_Y: usize> UnionFind for DecodingState<'a, T, STRIDE_Y> {
118    // Optimized find with O(1) fast path for self-rooted nodes
119    // At p=0.001, ~95% of nodes are self-rooted, so this check pays off
120    #[inline(always)]
121    fn find(&mut self, i: u32) -> u32 {
122        // SAFETY: Callers must ensure `i < parents.len()`. This is an internal
123        // method called only from growth and peeling code that iterates over
124        // valid node indices. The unchecked access eliminates bounds checking
125        // in the hot path.
126        unsafe {
127            let p = *self.parents.get_unchecked(i as usize);
128            if p == i {
129                return i; // Fast path: self-rooted (most common case)
130            }
131            self.find_slow(i, p)
132        }
133    }
134
135    #[inline(always)]
136    unsafe fn union_roots(&mut self, root_u: u32, root_v: u32) -> bool {
137        if root_u == root_v {
138            return false;
139        }
140
141        // Simple index-based union: smaller index becomes child of larger
142        // This provides deterministic behavior without rank tracking overhead
143        let (child, parent) = if root_u < root_v {
144            (root_u, root_v)
145        } else {
146            (root_v, root_u)
147        };
148
149        // SAFETY: Caller guarantees root_u and root_v are valid root indices
150        // (i.e., `parents[root] == root`). Since they're roots, they're valid
151        // node indices by construction.
152        *self.parents.get_unchecked_mut(child as usize) = parent;
153
154        // Invalidate cached root for the child's block
155        let blk_child = (child as usize) >> 6;
156        if blk_child < self.blocks_state.len() {
157            // SAFETY: Bounds check performed above.
158            self.blocks_state.get_unchecked_mut(blk_child).root = u32::MAX;
159        }
160        self.mark_block_dirty(blk_child);
161
162        true
163    }
164
165    #[inline(always)]
166    unsafe fn union(&mut self, u: u32, v: u32) -> bool {
167        // SAFETY: Caller guarantees u and v are valid node indices.
168        // find() performs unchecked access but is safe given valid indices.
169        let root_u = self.find(u);
170        let root_v = self.find(v);
171
172        if root_u != root_v {
173            // SAFETY: root_u and root_v are valid roots from find().
174            self.union_roots(root_u, root_v)
175        } else {
176            false
177        }
178    }
179}
180
181impl<'a, T: Topology, const STRIDE_Y: usize> DecodingState<'a, T, STRIDE_Y> {
182    // Cold path: path halving compression
183    // Each node on the path points to its grandparent, halving path length per traversal.
184    #[inline(never)]
185    #[cold]
186    fn find_slow(&mut self, mut i: u32, mut p: u32) -> u32 {
187        // SAFETY: This function is only called from find() with valid indices.
188        // The parent pointers form a well-formed tree structure where every
189        // node either points to itself (root) or to a valid parent index.
190        // The loop terminates when we reach a root (p == grandparent).
191        unsafe {
192            loop {
193                let grandparent = *self.parents.get_unchecked(p as usize);
194                if p == grandparent {
195                    return p; // Found root
196                }
197                // Path halving: point i to grandparent
198                *self.parents.get_unchecked_mut(i as usize) = grandparent;
199                self.mark_block_dirty(i as usize >> 6);
200                i = grandparent;
201                p = *self.parents.get_unchecked(i as usize);
202            }
203        }
204    }
205}