prav_core/decoder/state.rs

1//! Core decoder state for Union Find QEC decoding.
2//!
3//! This module contains the main decoder state structure that holds all data
4//! needed for the decoding process: Union Find parent pointers, block states,
5//! active masks, and correction tracking.
6
7#![allow(unsafe_op_in_unsafe_fn)]
8
9use crate::arena::Arena;
10use crate::decoder::graph::StaticGraph;
11use crate::topology::Topology;
12
13// Re-export types for backward compatibility with imports from decoder::state::*
14pub use crate::decoder::types::{BlockStateHot, BoundaryConfig, EdgeCorrection, FLAG_VALID_FULL};
15
/// Main decoder state structure for Union Find QEC decoding.
///
/// This structure contains all the state needed for a single decoding instance.
/// It is parameterized by:
///
/// - `'a` - Lifetime of the backing arena memory
/// - `T: Topology` - The lattice topology (e.g., [`SquareGrid`](crate::SquareGrid))
/// - `STRIDE_Y` - Compile-time Y stride for performance optimization; must equal
///   `max(width, height, depth).next_power_of_two()` (checked at runtime in `new`)
///
/// # Memory Layout
///
/// All slices are allocated from an [`Arena`] and are contiguous.
/// The decoder uses Morton (Z-order) encoding to organize nodes into 64-node
/// blocks for cache efficiency.
///
/// # Usage
///
/// ```ignore
/// use prav_core::{Arena, DecodingState, SquareGrid, EdgeCorrection};
///
/// let mut buffer = [0u8; 1024 * 1024];
/// let mut arena = Arena::new(&mut buffer);
///
/// // Create decoder with STRIDE_Y matching grid dimensions
/// let mut state: DecodingState<SquareGrid, 32> = DecodingState::new(&mut arena, 32, 32, 1);
///
/// // Load syndromes and decode
/// state.load_dense_syndromes(&syndromes);
/// state.grow_clusters();
///
/// let mut corrections = [EdgeCorrection::default(); 1024];
/// let count = state.peel_forest(&mut corrections);
/// ```
///
/// # Thread Safety
///
/// `DecodingState` is not thread-safe. For parallel decoding, create one
/// instance per thread with separate arenas.
pub struct DecodingState<'a, T: Topology, const STRIDE_Y: usize> {
    /// Reference to static graph metadata (dimensions, strides).
    pub graph: &'a StaticGraph,
    /// Grid width in nodes.
    pub width: usize,
    /// Grid height in nodes.
    pub height: usize,
    /// Y stride for coordinate calculations (power of 2 for fast division).
    pub stride_y: usize,
    /// Bitmask identifying nodes at the start of their row within a block.
    /// Only populated when `stride_y < 64` (a block spans multiple rows).
    pub row_start_mask: u64,
    /// Bitmask identifying nodes at the end of their row within a block.
    /// Only populated when `stride_y < 64` (a block spans multiple rows).
    pub row_end_mask: u64,

    // Morton Layout State
    /// Block state for each 64-node block (boundary, occupied, masks, cached root).
    pub blocks_state: &'a mut [BlockStateHot],

    // Node State
    /// Union Find parent pointers. `parents[i] == i` means node i is a root.
    /// The final entry is a virtual boundary node that is its own root.
    pub parents: &'a mut [u32],
    /// Bitmask of defect (syndrome) nodes per block.
    pub defect_mask: &'a mut [u64],
    /// Path marking bitmask used during peeling.
    pub path_mark: &'a mut [u64],

    // Sparse Reset Tracking
    /// Bitmask tracking which blocks have been modified (for sparse reset).
    pub block_dirty_mask: &'a mut [u64],

    // Active Set
    /// Bitmask of currently active blocks for this growth iteration.
    pub active_mask: &'a mut [u64],
    /// Bitmask of blocks queued for the next growth iteration.
    pub queued_mask: &'a mut [u64],
    /// Fast-path active mask for small grids (<=64 data blocks).
    pub active_block_mask: u64,

    // Ingestion Worklist
    /// List of block indices with syndromes to process.
    pub ingestion_list: &'a mut [u32],
    /// Number of valid entries in `ingestion_list`.
    pub ingestion_count: usize,

    // O(1) Edge Compaction State
    /// Bitmask of edges to include in corrections (XOR-accumulated).
    pub edge_bitmap: &'a mut [u64],
    /// List of edge bitmap word indices that have been modified.
    pub edge_dirty_list: &'a mut [u32],
    /// Number of valid entries in `edge_dirty_list`.
    pub edge_dirty_count: usize,
    /// Bitmask tracking which edge_bitmap words are dirty.
    pub edge_dirty_mask: &'a mut [u64],

    /// Bitmask of boundary corrections per block.
    pub boundary_bitmap: &'a mut [u64],
    /// List of block indices with boundary corrections.
    pub boundary_dirty_list: &'a mut [u32],
    /// Number of valid entries in `boundary_dirty_list`.
    pub boundary_dirty_count: usize,
    /// Bitmask tracking which boundary_bitmap entries are dirty.
    pub boundary_dirty_mask: &'a mut [u64],

    /// BFS predecessor array for path tracing.
    pub bfs_pred: &'a mut [u16],
    /// BFS queue for path tracing.
    pub bfs_queue: &'a mut [u16],

    // AVX/Scalar coordination
    /// Flag indicating scalar fallback is needed for some blocks.
    pub needs_scalar_fallback: bool,
    /// Bitmask of blocks requiring scalar processing.
    pub scalar_fallback_mask: u64,

    /// Configuration for boundary matching behavior.
    pub boundary_config: BoundaryConfig,
    /// Offset for parent array (used in some optimizations).
    pub parent_offset: usize,

    /// Phantom marker for the topology type parameter.
    pub _marker: core::marker::PhantomData<T>,
}
136
impl<'a, T: Topology, const STRIDE_Y: usize> DecodingState<'a, T, STRIDE_Y> {
    /// Creates a new decoder state for the given grid dimensions.
    ///
    /// Allocates all necessary data structures from the provided arena.
    /// The decoder is initialized and ready for use after construction.
    ///
    /// # Arguments
    ///
    /// * `arena` - Arena allocator to use for all internal allocations.
    /// * `width` - Grid width in nodes.
    /// * `height` - Grid height in nodes.
    /// * `depth` - Grid depth (1 for 2D codes, >1 for 3D codes).
    ///
    /// # Panics
    ///
    /// Panics if `STRIDE_Y` doesn't match the calculated stride for the given
    /// dimensions. The stride is `max(width, height, depth).next_power_of_two()`.
    /// Also panics (via `unwrap`) if the arena runs out of space.
    ///
    /// # Memory Requirements
    ///
    /// The arena must have sufficient space for:
    /// - `num_nodes * 4` bytes for parent array
    /// - `num_blocks * 64` bytes for block states
    /// - Additional space for masks, bitmaps, and queues
    ///
    /// A safe estimate is `num_nodes * 20` bytes. Use [`required_buffer_size`](crate::required_buffer_size)
    /// for exact calculation.
    #[must_use]
    pub fn new(arena: &mut Arena<'a>, width: usize, height: usize, depth: usize) -> Self {
        let is_3d = depth > 1;
        let max_dim = width.max(height).max(if is_3d { depth } else { 1 });
        // All strides are derived from the next power of two of the largest
        // dimension so that coordinate math reduces to shifts and masks.
        let dim_pow2 = max_dim.next_power_of_two();

        let stride_x = 1;
        let stride_y = dim_pow2;

        // Runtime check to ensure the const generic matches the physical dimensions
        assert_eq!(
            stride_y, STRIDE_Y,
            "STRIDE_Y const generic ({}) must match calculated stride ({})",
            STRIDE_Y, stride_y
        );

        let stride_z = dim_pow2 * dim_pow2;
        let blk_stride_y = stride_y / 64;
        let shift_y = stride_y.trailing_zeros();
        let shift_z = stride_z.trailing_zeros();

        let mut row_end_mask = 0u64;
        let mut row_start_mask = 0u64;

        // Row start/end masks are only meaningful when a single 64-node block
        // spans multiple rows (stride_y < 64). stride_y is a power of two, so
        // it divides 64 exactly and the loop tiles the full 64-bit word.
        if stride_y < 64 {
            let mut i = 0;
            while i < 64 {
                row_start_mask |= 1 << i;
                row_end_mask |= 1 << (i + stride_y - 1);
                i += stride_y;
            }
        }

        // Allocation covers the full power-of-two bounding box, not just
        // width*height*depth, so Morton/stride indexing never goes out of range.
        let alloc_size = if is_3d {
            dim_pow2 * dim_pow2 * dim_pow2
        } else {
            dim_pow2 * dim_pow2
        };
        // +1 for the virtual boundary node stored at index `alloc_size`.
        let alloc_nodes = alloc_size + 1;
        let num_blocks = alloc_nodes.div_ceil(64);
        let num_bitmask_words = num_blocks.div_ceil(64);

        let graph = StaticGraph {
            width,
            height,
            depth,
            stride_x,
            stride_y,
            stride_z,
            blk_stride_y,
            shift_y,
            shift_z,
            row_end_mask,
            row_start_mask,
        };
        let graph_ref = arena.alloc_value(graph).unwrap();

        // NOTE: allocation order below determines the arena memory layout;
        // hot per-block state is allocated first, 64-byte aligned.
        let blocks_state = arena
            .alloc_slice_aligned::<BlockStateHot>(num_blocks, 64)
            .unwrap();

        let parents = arena.alloc_slice_aligned::<u32>(alloc_nodes, 64).unwrap();
        let defect_mask = arena.alloc_slice_aligned::<u64>(num_blocks, 64).unwrap();
        let path_mark = arena.alloc_slice_aligned::<u64>(num_blocks, 64).unwrap();

        let block_dirty_mask = arena
            .alloc_slice_aligned::<u64>(num_blocks.div_ceil(64), 64)
            .unwrap();

        let active_mask = arena
            .alloc_slice_aligned::<u64>(num_bitmask_words, 64)
            .unwrap();
        let queued_mask = arena
            .alloc_slice_aligned::<u64>(num_bitmask_words, 64)
            .unwrap();

        let ingestion_list = arena.alloc_slice::<u32>(num_blocks).unwrap();

        // Up to 3 edges per node (X/Y/Z directions).
        let num_edges = alloc_nodes * 3;
        let num_edge_words = num_edges.div_ceil(64);
        let edge_bitmap = arena
            .alloc_slice_aligned::<u64>(num_edge_words, 64)
            .unwrap();
        // Allocate extra space for dirty lists to handle XOR cancellations causing re-insertion
        let edge_dirty_list = arena.alloc_slice::<u32>(num_edge_words * 8).unwrap();

        let boundary_bitmap = arena.alloc_slice_aligned::<u64>(num_blocks, 64).unwrap();
        let boundary_dirty_list = arena.alloc_slice::<u32>(num_blocks * 8).unwrap();

        let edge_dirty_mask = arena
            .alloc_slice_aligned::<u64>(num_edge_words.div_ceil(64), 64)
            .unwrap();
        let boundary_dirty_mask = arena
            .alloc_slice_aligned::<u64>(num_blocks.div_ceil(64), 64)
            .unwrap();

        let bfs_pred = arena.alloc_slice::<u16>(alloc_nodes).unwrap();
        let bfs_queue = arena.alloc_slice::<u16>(alloc_nodes).unwrap();

        let mut decoder = Self {
            graph: graph_ref,
            width,
            height,
            stride_y,
            row_start_mask,
            row_end_mask,
            blocks_state,
            parents,
            defect_mask,
            path_mark,
            block_dirty_mask,
            active_mask,
            queued_mask,
            active_block_mask: 0,
            ingestion_list,
            ingestion_count: 0,
            edge_bitmap,
            edge_dirty_list,
            edge_dirty_count: 0,
            edge_dirty_mask,
            boundary_bitmap,
            boundary_dirty_list,
            boundary_dirty_count: 0,
            boundary_dirty_mask,
            bfs_pred,
            bfs_queue,
            needs_scalar_fallback: false,
            scalar_fallback_mask: 0,
            boundary_config: BoundaryConfig::default(),
            parent_offset: 0,
            _marker: core::marker::PhantomData,
        };

        // Zero all dynamic state and make every node its own Union Find root.
        decoder.initialize_internal();

        // Boundary sentinel is its own root. NOTE(review): initialize_internal
        // already set parents[i] = i for all entries, so this write is
        // redundant but harmless; kept for explicitness.
        decoder.parents[alloc_size] = alloc_size as u32;

        for block in decoder.blocks_state.iter_mut() {
            *block = BlockStateHot::default();
        }

        // Mark which bits within each block correspond to real grid nodes
        // (the power-of-two bounding box may exceed width/height/depth).
        if is_3d {
            for z in 0..depth {
                for y in 0..height {
                    for x in 0..width {
                        let idx = (z * stride_z) + (y * stride_y) + (x * stride_x);
                        let blk = idx / 64;
                        let bit = idx % 64;
                        if blk < num_blocks {
                            decoder.blocks_state[blk].valid_mask |= 1 << bit;
                        }
                    }
                }
            }
        } else {
            for y in 0..height {
                for x in 0..width {
                    let idx = (y * stride_y) + (x * stride_x);
                    let blk = idx / 64;
                    let bit = idx % 64;
                    if blk < num_blocks {
                        decoder.blocks_state[blk].valid_mask |= 1 << bit;
                    }
                }
            }
        }

        // With no erasures loaded yet, effective_mask == valid_mask.
        // FLAG_VALID_FULL caches "all 64 bits valid" for fast-path checks.
        for block in decoder.blocks_state.iter_mut() {
            let valid = block.valid_mask;
            block.effective_mask = valid;
            if valid == !0 {
                block.flags |= FLAG_VALID_FULL;
            } else {
                block.flags &= !FLAG_VALID_FULL;
            }
        }

        decoder
    }

    /// Resets all internal state for a new decoding cycle.
    ///
    /// This is a full reset that clears all dynamic state while preserving
    /// the grid topology (valid_mask). Call this before loading new syndromes
    /// when you want to completely reset the decoder.
    ///
    /// # Performance Note
    ///
    /// For repeated decoding with sparse syndromes, prefer [`sparse_reset`](Self::sparse_reset)
    /// which only resets modified blocks (O(modified) vs O(n)).
    pub fn initialize_internal(&mut self) {
        for block in self.blocks_state.iter_mut() {
            block.boundary = 0;
            block.occupied = 0;
            block.root = u32::MAX;
            block.root_rank = 0;
            // effective_mask and flags are persistent across resets (topology doesn't change)
            // But if we wanted to fully reset everything we would need to recompute effective_mask.
            // Assuming topology is static, we keep effective_mask and flags.
        }
        self.defect_mask.fill(0);
        self.path_mark.fill(0);
        self.block_dirty_mask.fill(0);

        self.active_mask.fill(0);
        self.queued_mask.fill(0);
        self.active_block_mask = 0;

        self.edge_bitmap.fill(0);
        self.edge_dirty_count = 0;
        self.edge_dirty_mask.fill(0);
        self.boundary_bitmap.fill(0);
        self.boundary_dirty_count = 0;
        self.boundary_dirty_mask.fill(0);

        self.ingestion_count = 0;

        // Every node starts as its own Union Find root.
        for (i, p) in self.parents.iter_mut().enumerate() {
            *p = i as u32;
        }
    }

    /// Loads erasure information indicating which qubits were lost.
    ///
    /// Erasures represent qubits that could not be measured (e.g., photon loss).
    /// The decoder excludes erased qubits from cluster growth.
    ///
    /// # Arguments
    ///
    /// * `erasures` - Dense bitarray where bit `i` in `erasures[blk]` indicates
    ///   node `(blk * 64 + i)` is erased. Entries beyond `blocks_state.len()`
    ///   are ignored; blocks beyond `erasures.len()` are cleared of erasures.
    ///
    /// # Effect
    ///
    /// Updates `effective_mask = valid_mask & !erasure_mask` for each block,
    /// which controls which nodes participate in cluster growth.
    pub fn load_erasures(&mut self, erasures: &[u64]) {
        let len = erasures.len().min(self.blocks_state.len());
        for (i, &val) in erasures.iter().take(len).enumerate() {
            self.blocks_state[i].erasure_mask = val;
            let valid = self.blocks_state[i].valid_mask;
            self.blocks_state[i].effective_mask = valid & !val;
        }
        // Blocks not covered by the input have no erasures.
        for block in self.blocks_state[len..].iter_mut() {
            block.erasure_mask = 0;
            let valid = block.valid_mask;
            block.effective_mask = valid;
        }
    }

    /// Marks a block as modified for sparse reset tracking.
    ///
    /// Called internally when a block's state changes. Marked blocks will be
    /// reset during the next [`sparse_reset`](Self::sparse_reset) call.
    ///
    /// Caller must pass a valid block index; `sparse_reset` relies on every
    /// bit set here mapping to an in-bounds block.
    #[inline(always)]
    pub fn mark_block_dirty(&mut self, blk_idx: usize) {
        let mask_idx = blk_idx >> 6;
        let mask_bit = blk_idx & 63;
        unsafe {
            *self.block_dirty_mask.get_unchecked_mut(mask_idx) |= 1 << mask_bit;
        }
    }

    /// Checks if this grid qualifies for small-grid optimizations.
    ///
    /// Small grids (<=64 blocks, i.e., <=4096 nodes) use a single `u64` bitmask
    /// for active block tracking, enabling faster iteration.
    ///
    /// # Returns
    ///
    /// `true` if the grid has at most 65 blocks (64 data + 1 boundary sentinel).
    #[inline(always)]
    pub fn is_small_grid(&self) -> bool {
        // `active_block_mask` is a single `u64`, so we can only track up to 64 data blocks.
        // `blocks_state` may include one extra sentinel block for the boundary node.
        self.blocks_state.len() <= 65
    }

    /// Queues a block for processing in the next growth iteration.
    ///
    /// Sets the corresponding bit in `queued_mask`. The block will be processed
    /// when `active_mask` and `queued_mask` are swapped at the end of the current
    /// iteration.
    #[inline(always)]
    pub fn push_next(&mut self, blk_idx: usize) {
        let mask_idx = blk_idx >> 6;
        let mask_bit = blk_idx & 63;
        unsafe {
            *self.queued_mask.get_unchecked_mut(mask_idx) |= 1 << mask_bit;
        }
    }

    /// Resets only the blocks that were modified, enabling efficient reuse.
    ///
    /// At typical error rates (p=0.001-0.06), only a small fraction of blocks
    /// are modified during decoding. This method resets only those blocks,
    /// achieving O(modified) complexity instead of O(total).
    ///
    /// # When to Use
    ///
    /// Call this between decoding cycles when:
    /// - Error rate is low (most blocks unmodified)
    /// - You want to minimize reset overhead
    ///
    /// For high error rates or when topology changes, use
    /// [`initialize_internal`](Self::initialize_internal) instead.
    ///
    /// # What Gets Reset
    ///
    /// For each dirty block:
    /// - `boundary`, `occupied`, `root`, `root_rank` in `BlockStateHot`
    /// - `defect_mask` entry
    /// - Parent pointers for all 64 nodes in the block
    ///
    /// Edge/boundary bitmaps and dirty lists are NOT touched here — presumably
    /// they are consumed and cleared during peeling; verify against the peeling
    /// implementation before relying on this.
    pub fn sparse_reset(&mut self) {
        for (word_idx, word_ref) in self.block_dirty_mask.iter_mut().enumerate() {
            let mut w = *word_ref;
            *word_ref = 0;
            // Iterate set bits, clearing the lowest one each pass.
            while w != 0 {
                let bit = w.trailing_zeros();
                w &= w - 1;
                let blk_idx = word_idx * 64 + bit as usize;

                // SAFETY: blk_idx is derived from bits set in block_dirty_mask,
                // which is only modified by mark_block_dirty() with valid block indices.
                // Therefore blk_idx < blocks_state.len() and blk_idx < defect_mask.len().
                unsafe {
                    let block = self.blocks_state.get_unchecked_mut(blk_idx);
                    block.boundary = 0;
                    block.occupied = 0;
                    block.root = u32::MAX;
                    block.root_rank = 0;
                    *self.defect_mask.get_unchecked_mut(blk_idx) = 0;
                }

                let start_node = blk_idx * 64;
                let end_node = (start_node + 64).min(self.parents.len());
                for node in start_node..end_node {
                    // SAFETY: node is bounded by min(blk_idx*64+64, parents.len()),
                    // so node < parents.len() is guaranteed.
                    unsafe {
                        *self.parents.get_unchecked_mut(node) = node as u32;
                    }
                }
            }
        }

        self.queued_mask.fill(0);
        self.active_mask.fill(0);
        self.active_block_mask = 0;

        // Re-root the virtual boundary node (last entry of `parents`).
        let boundary_idx = self.parents.len() - 1;
        self.parents[boundary_idx] = boundary_idx as u32;
    }

    /// Resets state for the next decoding cycle (sparse reset).
    ///
    /// This efficiently resets only the blocks that were modified during
    /// the previous decoding cycle. At typical error rates (p < 0.1),
    /// this is significantly faster than [`full_reset`](Self::full_reset).
    ///
    /// Use this method between consecutive decoding cycles when the
    /// grid topology remains unchanged.
    ///
    /// # Complexity
    ///
    /// O(modified_blocks) where modified_blocks << total_blocks for typical error rates.
    ///
    /// # Example
    ///
    /// ```ignore
    /// for cycle in 0..1000 {
    ///     decoder.reset_for_next_cycle();
    ///     decoder.load_dense_syndromes(&syndromes[cycle]);
    ///     let count = decoder.decode(&mut corrections);
    /// }
    /// ```
    #[inline]
    pub fn reset_for_next_cycle(&mut self) {
        self.sparse_reset();
    }

    /// Fully resets all decoder state.
    ///
    /// This performs a complete reset of all internal data structures,
    /// suitable for when:
    /// - Starting fresh with a completely new problem
    /// - The grid topology has changed
    /// - You want guaranteed clean state
    ///
    /// For repeated decoding with similar syndrome patterns, prefer
    /// [`reset_for_next_cycle`](Self::reset_for_next_cycle) which is faster.
    ///
    /// # Complexity
    ///
    /// O(n) where n is the total number of nodes.
    #[inline]
    pub fn full_reset(&mut self) {
        self.initialize_internal();
    }
}
564
565// Forwarding methods to traits - these provide convenient direct access
566// without needing to import the traits.
567impl<'a, T: Topology, const STRIDE_Y: usize> DecodingState<'a, T, STRIDE_Y> {
568    /// Finds the cluster root for node `i`. See [`UnionFind::find`](crate::UnionFind::find).
569    #[inline(always)]
570    pub fn find(&mut self, i: u32) -> u32 {
571        use super::union_find::UnionFind;
572        UnionFind::find(self, i)
573    }
574
575    /// Merges clusters containing nodes `u` and `v`.
576    ///
577    /// See [`UnionFind::union`](crate::UnionFind::union) for details.
578    ///
579    /// # Safety
580    ///
581    /// Caller must ensure `u` and `v` are valid node indices within the parents array.
582    #[inline(always)]
583    pub unsafe fn union(&mut self, u: u32, v: u32) -> bool {
584        use super::union_find::UnionFind;
585        UnionFind::union(self, u, v)
586    }
587
588    /// Loads syndrome measurements. See [`ClusterGrowth::load_dense_syndromes`](crate::ClusterGrowth::load_dense_syndromes).
589    #[inline(always)]
590    pub fn load_dense_syndromes(&mut self, syndromes: &[u64]) {
591        use super::growth::ClusterGrowth;
592        ClusterGrowth::load_dense_syndromes(self, syndromes)
593    }
594
595    /// Expands clusters until convergence. See [`ClusterGrowth::grow_clusters`](crate::ClusterGrowth::grow_clusters).
596    #[inline(always)]
597    pub fn grow_clusters(&mut self) {
598        use super::growth::ClusterGrowth;
599        ClusterGrowth::grow_clusters(self)
600    }
601
602    /// Processes a single block during growth.
603    ///
604    /// See [`ClusterGrowth::process_block`](crate::ClusterGrowth::process_block) for details.
605    ///
606    /// # Safety
607    ///
608    /// Caller must ensure `blk_idx` is within bounds of the internal blocks state.
609    #[inline(always)]
610    pub unsafe fn process_block(&mut self, blk_idx: usize) -> bool {
611        use super::growth::ClusterGrowth;
612        ClusterGrowth::process_block(self, blk_idx)
613    }
614
615    /// Full decode: growth + peeling. See [`Peeling::decode`](crate::Peeling::decode).
616    #[inline(always)]
617    pub fn decode(&mut self, corrections: &mut [EdgeCorrection]) -> usize {
618        use super::peeling::Peeling;
619        Peeling::decode(self, corrections)
620    }
621
622    /// Extracts corrections from grown clusters. See [`Peeling::peel_forest`](crate::Peeling::peel_forest).
623    #[inline(always)]
624    pub fn peel_forest(&mut self, corrections: &mut [EdgeCorrection]) -> usize {
625        use super::peeling::Peeling;
626        Peeling::peel_forest(self, corrections)
627    }
628}