// Path: prav_core/decoder/growth/mod.rs

//! Cluster growth algorithm for Union Find QEC decoding.
//!
//! This module implements the iterative boundary expansion that groups syndrome
//! nodes into connected clusters. The algorithm:
//!
//! 1. Loads syndrome measurements as seed points
//! 2. Iteratively expands cluster boundaries using SWAR bit operations
//! 3. Merges clusters when they meet via Union Find
//! 4. Terminates when no more expansion is possible
//!
//! # Algorithm Overview
//!
//! ```text
//! Initial:     After growth:
//!   .   .        . . .
//!   . X .   =>   . X .   (X = defect, . = occupied by cluster)
//!   .   .        . . .
//! ```
//!
//! # Performance Optimizations
//!
//! - **SWAR spreading**: Syndrome spreading uses SIMD-Within-A-Register operations,
//!   achieving 19-427x speedup over lookup tables.
//! - **Monochromatic fast-path**: When all 64 nodes in a block share the same root,
//!   skip Union Find operations entirely (covers ~95% of blocks).
//! - **Block-level parallelism**: Active blocks tracked via bitmasks for efficient
//!   iteration.

#![allow(unsafe_op_in_unsafe_fn)]

use crate::decoder::state::DecodingState;
use crate::topology::Topology;
33
// =============================================================================
// Submodules
// =============================================================================

/// Inter-block neighbor processing and merging utilities.
pub mod inter_block;

/// Small grid fast-path (single u64 active mask, <=64 blocks).
pub mod small_grid;

/// Stride-32 specific implementation.
pub mod stride32;

/// Unrolled optimizations for stride-32/64.
pub mod unrolled;

/// Kani formal verification proofs (compiled only under `cfg(kani)`).
#[cfg(kani)]
mod kani_proofs;
53
/// Cluster boundary expansion operations for QEC decoding.
///
/// This trait defines the interface for the cluster growth phase of Union Find
/// decoding. Starting from syndrome measurements (defects), clusters expand
/// outward until they either meet other clusters or reach the boundary.
///
/// # Decoding Flow
///
/// ```text
/// load_dense_syndromes() -> grow_clusters() -> peel_forest()
///        |                       |                 |
///        v                       v                 v
///   Initialize seeds      Expand boundaries   Extract corrections
/// ```
///
/// # Implementors
///
/// This trait is implemented by [`DecodingState`] for all topologies.
pub trait ClusterGrowth {
    /// Loads syndrome measurements from a dense bitarray.
    ///
    /// Syndromes indicate which stabilizer measurements detected errors.
    /// Each u64 represents 64 consecutive nodes, where bit `i` being set
    /// means node `(blk_idx * 64 + i)` has a syndrome (defect).
    ///
    /// # Arguments
    ///
    /// * `syndromes` - Dense bitarray where `syndromes[blk_idx]` contains
    ///   syndrome bits for block `blk_idx`. Length should match the number
    ///   of blocks in the decoder; excess entries in either direction are
    ///   ignored (the shorter length wins).
    ///
    /// # Implementation Details
    ///
    /// Uses a two-stage approach for large grids:
    /// 1. **Scanner stage**: Burst-mode collection of non-zero blocks
    /// 2. **Processor stage**: Initialize block state for active blocks
    ///
    /// For small grids (<=64 blocks), uses direct bitmask manipulation.
    fn load_dense_syndromes(&mut self, syndromes: &[u64]);

    /// Expands cluster boundaries until convergence.
    ///
    /// Iteratively calls [`grow_iteration`](Self::grow_iteration) until no
    /// more expansion is possible. The algorithm terminates when:
    ///
    /// - All defects have been connected to the boundary, OR
    /// - All defects have been paired with other defects in the same cluster
    ///
    /// # Termination Guarantee
    ///
    /// The algorithm is guaranteed to terminate within O(max_dimension) iterations,
    /// where `max_dimension` is the largest grid dimension. A safety limit of
    /// `max_dim * 16 + 128` iterations is enforced.
    fn grow_clusters(&mut self);

    /// Performs a single iteration of cluster growth.
    ///
    /// Processes all currently active blocks, expanding their boundaries
    /// and merging clusters as needed.
    ///
    /// # Returns
    ///
    /// * `true` if any expansion occurred (more iterations may be needed).
    /// * `false` if no expansion occurred (algorithm has converged).
    ///
    /// # Active Block Tracking
    ///
    /// Blocks are tracked in an active set. After processing:
    /// - Blocks that expanded are added to the next iteration's set
    /// - Blocks that can't expand further are removed
    fn grow_iteration(&mut self) -> bool;

    /// Processes a single block during cluster growth.
    ///
    /// Expands the boundary within the block and handles connections to
    /// neighboring blocks. This is the core operation called for each
    /// active block during growth.
    ///
    /// # Arguments
    ///
    /// * `blk_idx` - Index of the block to process.
    ///
    /// # Returns
    ///
    /// * `true` if the block's boundary expanded (neighbor blocks may activate).
    /// * `false` if no expansion occurred.
    ///
    /// # Safety
    ///
    /// Caller must ensure `blk_idx` is within bounds of the internal blocks state.
    unsafe fn process_block(&mut self, blk_idx: usize) -> bool;
}
146
impl<'a, T: Topology, const STRIDE_Y: usize> ClusterGrowth for DecodingState<'a, T, STRIDE_Y> {
    fn load_dense_syndromes(&mut self, syndromes: &[u64]) {
        self.ingestion_count = 0;
        self.active_block_mask = 0; // Reset for small grids
        // Clamp so neither the input slice nor the block arrays are over-read.
        let limit = syndromes.len().min(self.blocks_state.len());

        if self.is_small_grid() {
            for blk_idx in 0..limit {
                // Eagerly sync all blocks
                // SAFETY: blk_idx < limit <= syndromes.len().
                let word = unsafe { *syndromes.get_unchecked(blk_idx) };
                if word != 0 {
                    // SAFETY: blk_idx < limit <= blocks_state.len(); defect_mask
                    // is presumably sized in lockstep with blocks_state —
                    // TODO(review) confirm at DecodingState construction.
                    unsafe {
                        self.mark_block_dirty(blk_idx);
                        let block = self.blocks_state.get_unchecked_mut(blk_idx);
                        // Seed both the growth frontier and the occupied set.
                        block.boundary |= word;
                        block.occupied |= word;
                        *self.defect_mask.get_unchecked_mut(blk_idx) |= word;

                        // Directly set bit in active_block_mask
                        // (small grid => <=64 blocks per small_grid docs, so the
                        // shift cannot overflow).
                        self.active_block_mask |= 1 << blk_idx;
                    }
                }
            }
            // Ensure queued_mask is clear before growth starts
            if !self.queued_mask.is_empty() {
                self.queued_mask[0] = 0;
            }
            return;
        }

        let mut blk_idx = 0;

        // Stage 1: Scanner (Burst-Mode Ingestion)
        // Record indices of all non-zero syndrome words; only the ingestion
        // list is touched here, keeping this pass tight.
        while blk_idx < limit {
            // SAFETY: blk_idx < limit <= syndromes.len().
            let word = unsafe { *syndromes.get_unchecked(blk_idx) };
            if word != 0 {
                // SAFETY: at most `limit` entries are recorded; ingestion_list
                // is presumably allocated with capacity for every block —
                // TODO(review) confirm its allocation site.
                unsafe {
                    *self.ingestion_list.get_unchecked_mut(self.ingestion_count) = blk_idx as u32;
                    self.ingestion_count += 1;
                }
            }
            blk_idx += 1;
        }

        // Stage 2: Processor — initialize block state for each recorded block.
        for i in 0..self.ingestion_count {
            // SAFETY: i < ingestion_count, and Stage 1 only stored indices < limit.
            let blk_idx = unsafe { *self.ingestion_list.get_unchecked(i) } as usize;
            let word = unsafe { *syndromes.get_unchecked(blk_idx) };

            // Lazy Reset: Ensure block is ready for this epoch - REMOVED, assumed clean via sparse_reset
            // SAFETY: blk_idx < limit <= blocks_state.len() (see Stage 1).
            unsafe {
                self.mark_block_dirty(blk_idx);
                let block = self.blocks_state.get_unchecked_mut(blk_idx);
                block.boundary |= word;
                block.occupied |= word;
                *self.defect_mask.get_unchecked_mut(blk_idx) |= word;

                // Queue the block for the first growth iteration.
                self.push_next(blk_idx);
            }
        }

        // Prepare for first growth iteration
        if !self.is_small_grid() {
            // Swap active and queued: push_next populated queued_mask above,
            // so after the swap those blocks form the first active set.
            core::mem::swap(&mut self.active_mask, &mut self.queued_mask);
            self.queued_mask.fill(0);
        }
    }

    #[inline(never)]
    fn grow_clusters(&mut self) {
        if self.is_small_grid() {
            // Safety bound on iterations: growth should converge within
            // O(max_dim) passes; the generous limit guards against livelock.
            let max_dim = self.width.max(self.height).max(self.graph.depth);
            let limit = max_dim * 16 + 128;

            for _ in 0..limit {
                self.grow_bitmask_iteration();

                // Converged: no block was re-queued for another pass.
                if self.active_block_mask == 0 {
                    break;
                }
            }
            return;
        }

        // Large-grid path: same iteration bound, but the active set lives in a
        // multi-word bitmask rather than a single u64.
        let max_dim = self.width.max(self.height).max(self.graph.depth);
        let limit = max_dim * 16 + 128;

        for _ in 0..limit {
            // Check if active_mask is empty
            // Use iterator to check all words (SIMD optimized typically)
            if self.active_mask.iter().all(|&w| w == 0) {
                break;
            }

            self.grow_iteration();
        }
    }

    #[inline(always)]
    fn grow_iteration(&mut self) -> bool {
        if self.is_small_grid() {
            let mut any_expansion = false;
            let num_blocks = self.blocks_state.len();

            // Clear the mask as we are switching to flat scan
            if num_blocks > 0 {
                // SAFETY: queued_mask is presumably non-empty whenever blocks
                // exist — TODO(review) confirm this invariant at construction.
                unsafe { *self.queued_mask.get_unchecked_mut(0) = 0 };
            }

            // Flat scan: visit every block instead of walking the active mask.
            for blk_idx in 0..num_blocks {
                // SAFETY: blk_idx < blocks_state.len() by the loop bound.
                unsafe {
                    if self.process_block_silent(blk_idx) {
                        any_expansion = true;
                    }
                }
            }
            return any_expansion;
        }

        let mut any_expansion = false;

        // Queued mask is already cleared at the end of previous iteration (or init)

        // NOTE(review): raw pointer into active_mask is held across calls to
        // process_block; this is sound only if process_block never mutates
        // active_mask (it appears to write queued_mask only) — verify.
        let active_mask_ptr = self.active_mask.as_ptr();
        let active_mask_len = self.active_mask.len();

        for chunk_idx in 0..active_mask_len {
            // SAFETY: chunk_idx < active_mask_len, captured from the same slice.
            let mut w = unsafe { *active_mask_ptr.add(chunk_idx) };
            if w == 0 {
                continue;
            }

            let base_idx = chunk_idx * 64;
            // Iterate over set bits lowest-first; `w &= w - 1` clears the bit
            // just consumed.
            while w != 0 {
                let bit = w.trailing_zeros();
                w &= w - 1;
                let blk_idx = base_idx + bit as usize;

                // SAFETY: active_mask bits are presumably only set for valid
                // block indices (via push_next) — TODO(review) confirm.
                unsafe {
                    if self.process_block(blk_idx) {
                        any_expansion = true;
                    }
                }
            }
        }

        // Blocks queued during this pass become the next pass's active set.
        core::mem::swap(&mut self.active_mask, &mut self.queued_mask);
        self.queued_mask.fill(0);

        any_expansion
    }

    #[inline(always)]
    unsafe fn process_block(&mut self, blk_idx: usize) -> bool {
        // Non-silent variant: the `false` const flag selects the mode of
        // process_block_small_stride used by mask-driven iteration.
        self.process_block_small_stride::<false>(blk_idx)
    }
}
308
309impl<'a, T: Topology, const STRIDE_Y: usize> DecodingState<'a, T, STRIDE_Y> {
    /// Runs one growth iteration for small grids, where the whole active set
    /// fits in the single-u64 `active_block_mask`.
    ///
    /// Consumes the current mask bit by bit, processes each active block, and
    /// refreshes the mask from the queued bits produced by `process_block`.
    fn grow_bitmask_iteration(&mut self) {
        // optimized_32 unrolled path for 16 blocks (1024 nodes)
        if STRIDE_Y == 32 && self.blocks_state.len() == 16 {
            // SAFETY: guarded by the exact shape check above (stride 32,
            // 16 blocks), which is presumably the unrolled routine's
            // precondition — TODO(review) confirm its contract.
            unsafe {
                let (_expanded, next_mask) = self.process_all_blocks_stride_32_unrolled_16();
                self.active_block_mask = next_mask;
            }
            return;
        }

        // Reset queued mask for next iteration
        self.queued_mask[0] = 0;

        // Walk set bits lowest-first; `mask &= mask - 1` clears the consumed bit.
        let mut current_mask = self.active_block_mask;
        while current_mask != 0 {
            let blk_idx = crate::intrinsics::tzcnt(current_mask) as usize;
            current_mask &= current_mask - 1;

            // SAFETY: bits in active_block_mask are set only for indices that
            // were in range when marked (see load_dense_syndromes).
            unsafe {
                self.process_block(blk_idx);
            }
        }

        // Update active mask from the queued mask populated by process_block
        self.active_block_mask = self.queued_mask[0];
    }
336
    /// Variant of `process_block` used by the small-grid flat scan; the `true`
    /// const parameter selects the "silent" mode of
    /// `process_block_small_stride` (semantics defined at that function).
    ///
    /// # Safety
    ///
    /// Caller must ensure `blk_idx` is within bounds of `blocks_state`.
    #[inline(always)]
    unsafe fn process_block_silent(&mut self, blk_idx: usize) -> bool {
        self.process_block_small_stride::<true>(blk_idx)
    }
341}