prav_core/decoder/growth/mod.rs
1//! Cluster growth algorithm for Union Find QEC decoding.
2//!
3//! This module implements the iterative boundary expansion that groups syndrome
4//! nodes into connected clusters. The algorithm:
5//!
6//! 1. Loads syndrome measurements as seed points
7//! 2. Iteratively expands cluster boundaries using SWAR bit operations
8//! 3. Merges clusters when they meet via Union Find
9//! 4. Terminates when no more expansion is possible
10//!
11//! # Algorithm Overview
12//!
13//! ```text
//! Initial:          After growth:
//! . . .             . . .
//! . X .      =>     . X .      (X = defect, . = occupied by cluster)
//! . . .             . . .
18//! ```
19//!
20//! # Performance Optimizations
21//!
22//! - **SWAR spreading**: Syndrome spreading uses SIMD-Within-A-Register operations,
23//! achieving 19-427x speedup over lookup tables.
24//! - **Monochromatic fast-path**: When all 64 nodes in a block share the same root,
25//! skip Union Find operations entirely (covers ~95% of blocks).
26//! - **Block-level parallelism**: Active blocks tracked via bitmasks for efficient
27//! iteration.
28
29#![allow(unsafe_op_in_unsafe_fn)]
30
31use crate::decoder::state::DecodingState;
32use crate::topology::Topology;
33
34// =============================================================================
35// Submodules
36// =============================================================================
37
/// Inter-block neighbor processing and merging utilities.
pub mod inter_block;

/// Small grid fast-path (single u64 active mask, <=64 blocks).
pub mod small_grid;

/// Stride-32 specific implementation.
pub mod stride32;

/// Unrolled optimizations for stride-32/64.
pub mod unrolled;

/// Kani formal verification proofs (compiled only when building under Kani).
#[cfg(kani)]
mod kani_proofs;
53
54/// Cluster boundary expansion operations for QEC decoding.
55///
56/// This trait defines the interface for the cluster growth phase of Union Find
57/// decoding. Starting from syndrome measurements (defects), clusters expand
58/// outward until they either meet other clusters or reach the boundary.
59///
60/// # Decoding Flow
61///
62/// ```text
63/// load_dense_syndromes() -> grow_clusters() -> peel_forest()
64/// | | |
65/// v v v
66/// Initialize seeds Expand boundaries Extract corrections
67/// ```
68///
69/// # Implementors
70///
71/// This trait is implemented by [`DecodingState`] for all topologies.
pub trait ClusterGrowth {
    /// Loads syndrome measurements from a dense bitarray.
    ///
    /// Syndromes indicate which stabilizer measurements detected errors.
    /// Each u64 represents 64 consecutive nodes, where bit `i` being set
    /// means node `(blk_idx * 64 + i)` has a syndrome (defect).
    ///
    /// # Arguments
    ///
    /// * `syndromes` - Dense bitarray where `syndromes[blk_idx]` contains
    ///   syndrome bits for block `blk_idx`. Length should match the number
    ///   of blocks in the decoder; extra trailing words are ignored.
    ///
    /// # Implementation Details
    ///
    /// Uses a two-stage approach for large grids:
    /// 1. **Scanner stage**: Burst-mode collection of non-zero blocks
    /// 2. **Processor stage**: Initialize block state for active blocks
    ///
    /// For small grids (<=64 blocks), uses direct bitmask manipulation on a
    /// single u64 active-block word instead.
    fn load_dense_syndromes(&mut self, syndromes: &[u64]);

    /// Expands cluster boundaries until convergence.
    ///
    /// Iteratively runs growth iterations until no more expansion is
    /// possible (small grids use a dedicated bitmask iteration; large grids
    /// use [`grow_iteration`](Self::grow_iteration)). The algorithm
    /// terminates when:
    ///
    /// - All defects have been connected to the boundary, OR
    /// - All defects have been paired with other defects in the same cluster
    ///
    /// # Termination Guarantee
    ///
    /// The algorithm is guaranteed to terminate within O(max_dimension) iterations,
    /// where `max_dimension` is the largest grid dimension. A safety limit of
    /// `max_dim * 16 + 128` iterations is enforced to guard against livelock.
    fn grow_clusters(&mut self);

    /// Performs a single iteration of cluster growth.
    ///
    /// Processes all currently active blocks, expanding their boundaries
    /// and merging clusters as needed. On small grids this is a flat scan
    /// over every block; on large grids only blocks set in the active
    /// bitmask are visited.
    ///
    /// # Returns
    ///
    /// * `true` if any expansion occurred (more iterations may be needed).
    /// * `false` if no expansion occurred (algorithm has converged).
    ///
    /// # Active Block Tracking
    ///
    /// Blocks are tracked in an active set. After processing:
    /// - Blocks that expanded are added to the next iteration's set
    /// - Blocks that can't expand further are removed
    fn grow_iteration(&mut self) -> bool;

    /// Processes a single block during cluster growth.
    ///
    /// Expands the boundary within the block and handles connections to
    /// neighboring blocks. This is the core operation called for each
    /// active block during growth.
    ///
    /// # Arguments
    ///
    /// * `blk_idx` - Index of the block to process.
    ///
    /// # Returns
    ///
    /// * `true` if the block's boundary expanded (neighbor blocks may activate).
    /// * `false` if no expansion occurred.
    ///
    /// # Safety
    ///
    /// Caller must ensure `blk_idx` is within bounds of the internal blocks state.
    unsafe fn process_block(&mut self, blk_idx: usize) -> bool;
}
146
147impl<'a, T: Topology, const STRIDE_Y: usize> ClusterGrowth for DecodingState<'a, T, STRIDE_Y> {
148 fn load_dense_syndromes(&mut self, syndromes: &[u64]) {
149 self.ingestion_count = 0;
150 self.active_block_mask = 0; // Reset for small grids
151 let limit = syndromes.len().min(self.blocks_state.len());
152
153 if self.is_small_grid() {
154 for blk_idx in 0..limit {
155 // Eagerly sync all blocks
156 let word = unsafe { *syndromes.get_unchecked(blk_idx) };
157 if word != 0 {
158 unsafe {
159 self.mark_block_dirty(blk_idx);
160 let block = self.blocks_state.get_unchecked_mut(blk_idx);
161 block.boundary |= word;
162 block.occupied |= word;
163 *self.defect_mask.get_unchecked_mut(blk_idx) |= word;
164
165 // Directly set bit in active_block_mask
166 self.active_block_mask |= 1 << blk_idx;
167 }
168 }
169 }
170 // Ensure queued_mask is clear before growth starts
171 if !self.queued_mask.is_empty() {
172 self.queued_mask[0] = 0;
173 }
174 return;
175 }
176
177 let mut blk_idx = 0;
178
179 // Stage 1: Scanner (Burst-Mode Ingestion)
180
181
182
183 while blk_idx < limit {
184 let word = unsafe { *syndromes.get_unchecked(blk_idx) };
185 if word != 0 {
186 unsafe {
187 *self.ingestion_list.get_unchecked_mut(self.ingestion_count) = blk_idx as u32;
188 self.ingestion_count += 1;
189 }
190 }
191 blk_idx += 1;
192 }
193
194 // Stage 2: Processor
195 for i in 0..self.ingestion_count {
196 let blk_idx = unsafe { *self.ingestion_list.get_unchecked(i) } as usize;
197 let word = unsafe { *syndromes.get_unchecked(blk_idx) };
198
199 // Lazy Reset: Ensure block is ready for this epoch - REMOVED, assumed clean via sparse_reset
200 unsafe {
201 self.mark_block_dirty(blk_idx);
202 let block = self.blocks_state.get_unchecked_mut(blk_idx);
203 block.boundary |= word;
204 block.occupied |= word;
205 *self.defect_mask.get_unchecked_mut(blk_idx) |= word;
206
207 self.push_next(blk_idx);
208 }
209 }
210
211 // Prepare for first growth iteration
212 if !self.is_small_grid() {
213 // Swap active and queued
214 core::mem::swap(&mut self.active_mask, &mut self.queued_mask);
215 self.queued_mask.fill(0);
216 }
217 }
218
219 #[inline(never)]
220 fn grow_clusters(&mut self) {
221 if self.is_small_grid() {
222 let max_dim = self.width.max(self.height).max(self.graph.depth);
223 let limit = max_dim * 16 + 128;
224
225 for _ in 0..limit {
226 self.grow_bitmask_iteration();
227
228 if self.active_block_mask == 0 {
229 break;
230 }
231 }
232 return;
233 }
234
235 let max_dim = self.width.max(self.height).max(self.graph.depth);
236 let limit = max_dim * 16 + 128;
237
238 for _ in 0..limit {
239 // Check if active_mask is empty
240 // Use iterator to check all words (SIMD optimized typically)
241 if self.active_mask.iter().all(|&w| w == 0) {
242 break;
243 }
244
245 self.grow_iteration();
246 }
247 }
248
249 #[inline(always)]
250 fn grow_iteration(&mut self) -> bool {
251 if self.is_small_grid() {
252 let mut any_expansion = false;
253 let num_blocks = self.blocks_state.len();
254
255 // Clear the mask as we are switching to flat scan
256 if num_blocks > 0 {
257 unsafe { *self.queued_mask.get_unchecked_mut(0) = 0 };
258 }
259
260 for blk_idx in 0..num_blocks {
261 unsafe {
262 if self.process_block_silent(blk_idx) {
263 any_expansion = true;
264 }
265 }
266 }
267 return any_expansion;
268 }
269
270 let mut any_expansion = false;
271
272 // Queued mask is already cleared at the end of previous iteration (or init)
273
274 let active_mask_ptr = self.active_mask.as_ptr();
275 let active_mask_len = self.active_mask.len();
276
277 for chunk_idx in 0..active_mask_len {
278 let mut w = unsafe { *active_mask_ptr.add(chunk_idx) };
279 if w == 0 {
280 continue;
281 }
282
283 let base_idx = chunk_idx * 64;
284 while w != 0 {
285 let bit = w.trailing_zeros();
286 w &= w - 1;
287 let blk_idx = base_idx + bit as usize;
288
289 unsafe {
290 if self.process_block(blk_idx) {
291 any_expansion = true;
292 }
293 }
294 }
295 }
296
297 core::mem::swap(&mut self.active_mask, &mut self.queued_mask);
298 self.queued_mask.fill(0);
299
300 any_expansion
301 }
302
303 #[inline(always)]
304 unsafe fn process_block(&mut self, blk_idx: usize) -> bool {
305 self.process_block_small_stride::<false>(blk_idx)
306 }
307}
308
impl<'a, T: Topology, const STRIDE_Y: usize> DecodingState<'a, T, STRIDE_Y> {
    /// Runs one growth iteration on the small-grid path, replacing
    /// `active_block_mask` with the set of blocks to visit next iteration.
    fn grow_bitmask_iteration(&mut self) {
        // optimized_32 unrolled path for 16 blocks (1024 nodes)
        if STRIDE_Y == 32 && self.blocks_state.len() == 16 {
            // SAFETY: guarded by the STRIDE_Y == 32 && len() == 16 check,
            // which matches the fixed shape the unrolled kernel operates on.
            unsafe {
                let (_expanded, next_mask) = self.process_all_blocks_stride_32_unrolled_16();
                self.active_block_mask = next_mask;
            }
            return;
        }

        // Reset queued mask for next iteration
        // NOTE(review): this indexing assumes queued_mask is non-empty on the
        // small-grid path (load_dense_syndromes guards the same access with
        // is_empty()); confirm the invariant holds for every caller.
        self.queued_mask[0] = 0;

        let mut current_mask = self.active_block_mask;
        while current_mask != 0 {
            // Visit the lowest-numbered active block, then clear its bit.
            let blk_idx = crate::intrinsics::tzcnt(current_mask) as usize;
            current_mask &= current_mask - 1;

            // SAFETY: bits in active_block_mask are only set for indices
            // within blocks_state (see load_dense_syndromes).
            unsafe {
                self.process_block(blk_idx);
            }
        }

        // Update active mask from the queued mask populated by process_block
        self.active_block_mask = self.queued_mask[0];
    }

    /// Variant of [`ClusterGrowth::process_block`] used by the small-grid
    /// flat scan; selects the `<true>` mode of `process_block_small_stride`
    /// (presumably "silent", i.e. no queue maintenance, since the flat scan
    /// visits every block anyway — confirm against that function's const
    /// parameter).
    ///
    /// # Safety
    ///
    /// Caller must ensure `blk_idx` is within bounds of the blocks state.
    #[inline(always)]
    unsafe fn process_block_silent(&mut self, blk_idx: usize) -> bool {
        self.process_block_small_stride::<true>(blk_idx)
    }
}
341}