// prav_core/decoder/state.rs
1//! Core decoder state for Union Find QEC decoding.
2//!
3//! This module contains the main decoder state structure that holds all data
4//! needed for the decoding process: Union Find parent pointers, block states,
5//! active masks, and correction tracking.
6
7#![allow(unsafe_op_in_unsafe_fn)]
8
9use crate::arena::Arena;
10use crate::decoder::graph::StaticGraph;
11use crate::topology::Topology;
12
13// Re-export types for backward compatibility with imports from decoder::state::*
14pub use crate::decoder::types::{BlockStateHot, BoundaryConfig, EdgeCorrection, FLAG_VALID_FULL};
15
/// Main decoder state structure for Union Find QEC decoding.
///
/// This structure contains all the state needed for a single decoding instance.
/// It is parameterized by:
///
/// - `'a` - Lifetime of the backing arena memory
/// - `T: Topology` - The lattice topology (e.g., [`SquareGrid`](crate::SquareGrid))
/// - `STRIDE_Y` - Compile-time Y stride for performance optimization
///
/// # Memory Layout
///
/// All slices are allocated from an [`Arena`] and are contiguous.
/// The decoder uses Morton (Z-order) encoding to organize nodes into 64-node
/// blocks for cache efficiency.
///
/// # Usage
///
/// ```ignore
/// use prav_core::{Arena, DecodingState, SquareGrid, EdgeCorrection};
///
/// let mut buffer = [0u8; 1024 * 1024];
/// let mut arena = Arena::new(&mut buffer);
///
/// // Create decoder with STRIDE_Y matching grid dimensions
/// let mut state: DecodingState<SquareGrid, 32> = DecodingState::new(&mut arena, 32, 32, 1);
///
/// // Load syndromes and decode
/// state.load_dense_syndromes(&syndromes);
/// state.grow_clusters();
///
/// let mut corrections = [EdgeCorrection::default(); 1024];
/// let count = state.peel_forest(&mut corrections);
/// ```
///
/// # Thread Safety
///
/// `DecodingState` is not thread-safe. For parallel decoding, create one
/// instance per thread with separate arenas.
pub struct DecodingState<'a, T: Topology, const STRIDE_Y: usize> {
    /// Reference to static graph metadata (dimensions, strides).
    pub graph: &'a StaticGraph,
    /// Grid width in nodes.
    pub width: usize,
    /// Grid height in nodes.
    pub height: usize,
    /// Y stride for coordinate calculations (power of 2 for fast division).
    pub stride_y: usize,
    /// Bitmask identifying nodes at the start of their row within a block.
    /// Only populated by `new` when `stride_y < 64`; left as 0 otherwise.
    pub row_start_mask: u64,
    /// Bitmask identifying nodes at the end of their row within a block.
    /// Only populated by `new` when `stride_y < 64`; left as 0 otherwise.
    pub row_end_mask: u64,

    // Morton Layout State
    /// Block state for each 64-node block (boundary, occupied, masks, cached root).
    pub blocks_state: &'a mut [BlockStateHot],

    // Node State
    /// Union Find parent pointers. `parents[i] == i` means node i is a root.
    /// The last entry is a virtual boundary node sentinel.
    pub parents: &'a mut [u32],
    /// Bitmask of defect (syndrome) nodes per block.
    pub defect_mask: &'a mut [u64],
    /// Path marking bitmask used during peeling.
    pub path_mark: &'a mut [u64],

    // Sparse Reset Tracking
    /// Bitmask tracking which blocks have been modified (for sparse reset).
    pub block_dirty_mask: &'a mut [u64],

    // Active Set
    /// Bitmask of currently active blocks for this growth iteration.
    pub active_mask: &'a mut [u64],
    /// Bitmask of blocks queued for the next growth iteration.
    pub queued_mask: &'a mut [u64],
    /// Fast-path active mask for small grids (<=64 blocks).
    pub active_block_mask: u64,

    // Ingestion Worklist
    /// List of block indices with syndromes to process.
    pub ingestion_list: &'a mut [u32],
    /// Number of valid entries in `ingestion_list`.
    pub ingestion_count: usize,

    // O(1) Edge Compaction State
    /// Bitmask of edges to include in corrections (XOR-accumulated).
    pub edge_bitmap: &'a mut [u64],
    /// List of edge bitmap word indices that have been modified.
    /// Oversized (8x word count) to tolerate re-insertion after XOR cancellations.
    pub edge_dirty_list: &'a mut [u32],
    /// Number of valid entries in `edge_dirty_list`.
    pub edge_dirty_count: usize,
    /// Bitmask tracking which edge_bitmap words are dirty.
    pub edge_dirty_mask: &'a mut [u64],

    /// Bitmask of boundary corrections per block.
    pub boundary_bitmap: &'a mut [u64],
    /// List of block indices with boundary corrections.
    pub boundary_dirty_list: &'a mut [u32],
    /// Number of valid entries in `boundary_dirty_list`.
    pub boundary_dirty_count: usize,
    /// Bitmask tracking which boundary_bitmap entries are dirty.
    pub boundary_dirty_mask: &'a mut [u64],

    /// BFS predecessor array for path tracing.
    /// NOTE(review): `u16` entries imply node indices must fit in 16 bits
    /// (grids up to 65536 nodes) — confirm against supported grid sizes.
    pub bfs_pred: &'a mut [u16],
    /// BFS queue for path tracing.
    pub bfs_queue: &'a mut [u16],

    // AVX/Scalar coordination
    /// Flag indicating scalar fallback is needed for some blocks.
    pub needs_scalar_fallback: bool,
    /// Bitmask of blocks requiring scalar processing.
    pub scalar_fallback_mask: u64,

    /// Configuration for boundary matching behavior.
    pub boundary_config: BoundaryConfig,
    /// Offset for parent array (used in some optimizations).
    /// NOTE(review): always 0 in this module; semantics defined by consumers — verify.
    pub parent_offset: usize,

    /// Phantom marker for the topology type parameter.
    pub _marker: core::marker::PhantomData<T>,
}
136
137impl<'a, T: Topology, const STRIDE_Y: usize> DecodingState<'a, T, STRIDE_Y> {
138 /// Creates a new decoder state for the given grid dimensions.
139 ///
140 /// Allocates all necessary data structures from the provided arena.
141 /// The decoder is initialized and ready for use after construction.
142 ///
143 /// # Arguments
144 ///
145 /// * `arena` - Arena allocator to use for all internal allocations.
146 /// * `width` - Grid width in nodes.
147 /// * `height` - Grid height in nodes.
148 /// * `depth` - Grid depth (1 for 2D codes, >1 for 3D codes).
149 ///
150 /// # Panics
151 ///
152 /// Panics if `STRIDE_Y` doesn't match the calculated stride for the given
153 /// dimensions. The stride is `max(width, height, depth).next_power_of_two()`.
154 ///
155 /// # Memory Requirements
156 ///
157 /// The arena must have sufficient space for:
158 /// - `num_nodes * 4` bytes for parent array
159 /// - `num_blocks * 64` bytes for block states
160 /// - Additional space for masks, bitmaps, and queues
161 ///
162 /// A safe estimate is `num_nodes * 20` bytes. Use [`required_buffer_size`](crate::required_buffer_size)
163 /// for exact calculation.
164 #[must_use]
165 pub fn new(arena: &mut Arena<'a>, width: usize, height: usize, depth: usize) -> Self {
166 let is_3d = depth > 1;
167 let max_dim = width.max(height).max(if is_3d { depth } else { 1 });
168 let dim_pow2 = max_dim.next_power_of_two();
169
170 let stride_x = 1;
171 let stride_y = dim_pow2;
172
173 // Runtime check to ensure the const generic matches the physical dimensions
174 assert_eq!(
175 stride_y, STRIDE_Y,
176 "STRIDE_Y const generic ({}) must match calculated stride ({})",
177 STRIDE_Y, stride_y
178 );
179
180 let stride_z = dim_pow2 * dim_pow2;
181 let blk_stride_y = stride_y / 64;
182 let shift_y = stride_y.trailing_zeros();
183 let shift_z = stride_z.trailing_zeros();
184
185 let mut row_end_mask = 0u64;
186 let mut row_start_mask = 0u64;
187
188 if stride_y < 64 {
189 let mut i = 0;
190 while i < 64 {
191 row_start_mask |= 1 << i;
192 row_end_mask |= 1 << (i + stride_y - 1);
193 i += stride_y;
194 }
195 }
196
197 let alloc_size = if is_3d {
198 dim_pow2 * dim_pow2 * dim_pow2
199 } else {
200 dim_pow2 * dim_pow2
201 };
202 let alloc_nodes = alloc_size + 1;
203 let num_blocks = alloc_nodes.div_ceil(64);
204 let num_bitmask_words = num_blocks.div_ceil(64);
205
206 let graph = StaticGraph {
207 width,
208 height,
209 depth,
210 stride_x,
211 stride_y,
212 stride_z,
213 blk_stride_y,
214 shift_y,
215 shift_z,
216 row_end_mask,
217 row_start_mask,
218 };
219 let graph_ref = arena.alloc_value(graph).unwrap();
220
221 let blocks_state = arena
222 .alloc_slice_aligned::<BlockStateHot>(num_blocks, 64)
223 .unwrap();
224
225 let parents = arena.alloc_slice_aligned::<u32>(alloc_nodes, 64).unwrap();
226 let defect_mask = arena.alloc_slice_aligned::<u64>(num_blocks, 64).unwrap();
227 let path_mark = arena.alloc_slice_aligned::<u64>(num_blocks, 64).unwrap();
228
229 let block_dirty_mask = arena
230 .alloc_slice_aligned::<u64>(num_blocks.div_ceil(64), 64)
231 .unwrap();
232
233 let active_mask = arena
234 .alloc_slice_aligned::<u64>(num_bitmask_words, 64)
235 .unwrap();
236 let queued_mask = arena
237 .alloc_slice_aligned::<u64>(num_bitmask_words, 64)
238 .unwrap();
239
240 let ingestion_list = arena.alloc_slice::<u32>(num_blocks).unwrap();
241
242 let num_edges = alloc_nodes * 3;
243 let num_edge_words = num_edges.div_ceil(64);
244 let edge_bitmap = arena
245 .alloc_slice_aligned::<u64>(num_edge_words, 64)
246 .unwrap();
247 // Allocate extra space for dirty lists to handle XOR cancellations causing re-insertion
248 let edge_dirty_list = arena.alloc_slice::<u32>(num_edge_words * 8).unwrap();
249
250 let boundary_bitmap = arena.alloc_slice_aligned::<u64>(num_blocks, 64).unwrap();
251 let boundary_dirty_list = arena.alloc_slice::<u32>(num_blocks * 8).unwrap();
252
253 let edge_dirty_mask = arena
254 .alloc_slice_aligned::<u64>(num_edge_words.div_ceil(64), 64)
255 .unwrap();
256 let boundary_dirty_mask = arena
257 .alloc_slice_aligned::<u64>(num_blocks.div_ceil(64), 64)
258 .unwrap();
259
260 let bfs_pred = arena.alloc_slice::<u16>(alloc_nodes).unwrap();
261 let bfs_queue = arena.alloc_slice::<u16>(alloc_nodes).unwrap();
262
263 let mut decoder = Self {
264 graph: graph_ref,
265 width,
266 height,
267 stride_y,
268 row_start_mask,
269 row_end_mask,
270 blocks_state,
271 parents,
272 defect_mask,
273 path_mark,
274 block_dirty_mask,
275 active_mask,
276 queued_mask,
277 active_block_mask: 0,
278 ingestion_list,
279 ingestion_count: 0,
280 edge_bitmap,
281 edge_dirty_list,
282 edge_dirty_count: 0,
283 edge_dirty_mask,
284 boundary_bitmap,
285 boundary_dirty_list,
286 boundary_dirty_count: 0,
287 boundary_dirty_mask,
288 bfs_pred,
289 bfs_queue,
290 needs_scalar_fallback: false,
291 scalar_fallback_mask: 0,
292 boundary_config: BoundaryConfig::default(),
293 parent_offset: 0,
294 _marker: core::marker::PhantomData,
295 };
296
297 decoder.initialize_internal();
298
299 decoder.parents[alloc_size] = alloc_size as u32;
300
301 for block in decoder.blocks_state.iter_mut() {
302 *block = BlockStateHot::default();
303 }
304
305 if is_3d {
306 for z in 0..depth {
307 for y in 0..height {
308 for x in 0..width {
309 let idx = (z * stride_z) + (y * stride_y) + (x * stride_x);
310 let blk = idx / 64;
311 let bit = idx % 64;
312 if blk < num_blocks {
313 decoder.blocks_state[blk].valid_mask |= 1 << bit;
314 }
315 }
316 }
317 }
318 } else {
319 for y in 0..height {
320 for x in 0..width {
321 let idx = (y * stride_y) + (x * stride_x);
322 let blk = idx / 64;
323 let bit = idx % 64;
324 if blk < num_blocks {
325 decoder.blocks_state[blk].valid_mask |= 1 << bit;
326 }
327 }
328 }
329 }
330
331 for block in decoder.blocks_state.iter_mut() {
332 let valid = block.valid_mask;
333 block.effective_mask = valid;
334 if valid == !0 {
335 block.flags |= FLAG_VALID_FULL;
336 } else {
337 block.flags &= !FLAG_VALID_FULL;
338 }
339 }
340
341 decoder
342 }
343
344 /// Resets all internal state for a new decoding cycle.
345 ///
346 /// This is a full reset that clears all dynamic state while preserving
347 /// the grid topology (valid_mask). Call this before loading new syndromes
348 /// when you want to completely reset the decoder.
349 ///
350 /// # Performance Note
351 ///
352 /// For repeated decoding with sparse syndromes, prefer [`sparse_reset`](Self::sparse_reset)
353 /// which only resets modified blocks (O(modified) vs O(n)).
354 pub fn initialize_internal(&mut self) {
355 for block in self.blocks_state.iter_mut() {
356 block.boundary = 0;
357 block.occupied = 0;
358 block.root = u32::MAX;
359 block.root_rank = 0;
360 // effective_mask and flags are persistent across resets (topology doesn't change)
361 // But if we wanted to fully reset everything we would need to recompute effective_mask.
362 // Assuming topology is static, we keep effective_mask and flags.
363 }
364 self.defect_mask.fill(0);
365 self.path_mark.fill(0);
366 self.block_dirty_mask.fill(0);
367
368 self.active_mask.fill(0);
369 self.queued_mask.fill(0);
370 self.active_block_mask = 0;
371
372 self.edge_bitmap.fill(0);
373 self.edge_dirty_count = 0;
374 self.edge_dirty_mask.fill(0);
375 self.boundary_bitmap.fill(0);
376 self.boundary_dirty_count = 0;
377 self.boundary_dirty_mask.fill(0);
378
379 self.ingestion_count = 0;
380
381 for (i, p) in self.parents.iter_mut().enumerate() {
382 *p = i as u32;
383 }
384 }
385
386 /// Loads erasure information indicating which qubits were lost.
387 ///
388 /// Erasures represent qubits that could not be measured (e.g., photon loss).
389 /// The decoder excludes erased qubits from cluster growth.
390 ///
391 /// # Arguments
392 ///
393 /// * `erasures` - Dense bitarray where bit `i` in `erasures[blk]` indicates
394 /// node `(blk * 64 + i)` is erased.
395 ///
396 /// # Effect
397 ///
398 /// Updates `effective_mask = valid_mask & !erasure_mask` for each block,
399 /// which controls which nodes participate in cluster growth.
400 pub fn load_erasures(&mut self, erasures: &[u64]) {
401 let len = erasures.len().min(self.blocks_state.len());
402 for (i, &val) in erasures.iter().take(len).enumerate() {
403 self.blocks_state[i].erasure_mask = val;
404 let valid = self.blocks_state[i].valid_mask;
405 self.blocks_state[i].effective_mask = valid & !val;
406 }
407 for block in self.blocks_state[len..].iter_mut() {
408 block.erasure_mask = 0;
409 let valid = block.valid_mask;
410 block.effective_mask = valid;
411 }
412 }
413
414 /// Marks a block as modified for sparse reset tracking.
415 ///
416 /// Called internally when a block's state changes. Marked blocks will be
417 /// reset during the next [`sparse_reset`](Self::sparse_reset) call.
418 #[inline(always)]
419 pub fn mark_block_dirty(&mut self, blk_idx: usize) {
420 let mask_idx = blk_idx >> 6;
421 let mask_bit = blk_idx & 63;
422 unsafe {
423 *self.block_dirty_mask.get_unchecked_mut(mask_idx) |= 1 << mask_bit;
424 }
425 }
426
427 /// Checks if this grid qualifies for small-grid optimizations.
428 ///
429 /// Small grids (<=64 blocks, i.e., <=4096 nodes) use a single `u64` bitmask
430 /// for active block tracking, enabling faster iteration.
431 ///
432 /// # Returns
433 ///
434 /// `true` if the grid has at most 65 blocks (64 data + 1 boundary sentinel).
435 #[inline(always)]
436 pub fn is_small_grid(&self) -> bool {
437 // `active_block_mask` is a single `u64`, so we can only track up to 64 data blocks.
438 // `blocks_state` may include one extra sentinel block for the boundary node.
439 self.blocks_state.len() <= 65
440 }
441
442 /// Queues a block for processing in the next growth iteration.
443 ///
444 /// Sets the corresponding bit in `queued_mask`. The block will be processed
445 /// when `active_mask` and `queued_mask` are swapped at the end of the current
446 /// iteration.
447 #[inline(always)]
448 pub fn push_next(&mut self, blk_idx: usize) {
449 let mask_idx = blk_idx >> 6;
450 let mask_bit = blk_idx & 63;
451 unsafe {
452 *self.queued_mask.get_unchecked_mut(mask_idx) |= 1 << mask_bit;
453 }
454 }
455
456 /// Resets only the blocks that were modified, enabling efficient reuse.
457 ///
458 /// At typical error rates (p=0.001-0.06), only a small fraction of blocks
459 /// are modified during decoding. This method resets only those blocks,
460 /// achieving O(modified) complexity instead of O(total).
461 ///
462 /// # When to Use
463 ///
464 /// Call this between decoding cycles when:
465 /// - Error rate is low (most blocks unmodified)
466 /// - You want to minimize reset overhead
467 ///
468 /// For high error rates or when topology changes, use
469 /// [`initialize_internal`](Self::initialize_internal) instead.
470 ///
471 /// # What Gets Reset
472 ///
473 /// For each dirty block:
474 /// - `boundary`, `occupied`, `root`, `root_rank` in `BlockStateHot`
475 /// - `defect_mask` entry
476 /// - Parent pointers for all 64 nodes in the block
477 pub fn sparse_reset(&mut self) {
478 for (word_idx, word_ref) in self.block_dirty_mask.iter_mut().enumerate() {
479 let mut w = *word_ref;
480 *word_ref = 0;
481 while w != 0 {
482 let bit = w.trailing_zeros();
483 w &= w - 1;
484 let blk_idx = word_idx * 64 + bit as usize;
485
486 // SAFETY: blk_idx is derived from bits set in block_dirty_mask,
487 // which is only modified by mark_block_dirty() with valid block indices.
488 // Therefore blk_idx < blocks_state.len() and blk_idx < defect_mask.len().
489 unsafe {
490 let block = self.blocks_state.get_unchecked_mut(blk_idx);
491 block.boundary = 0;
492 block.occupied = 0;
493 block.root = u32::MAX;
494 block.root_rank = 0;
495 *self.defect_mask.get_unchecked_mut(blk_idx) = 0;
496 }
497
498 let start_node = blk_idx * 64;
499 let end_node = (start_node + 64).min(self.parents.len());
500 for node in start_node..end_node {
501 // SAFETY: node is bounded by min(blk_idx*64+64, parents.len()),
502 // so node < parents.len() is guaranteed.
503 unsafe {
504 *self.parents.get_unchecked_mut(node) = node as u32;
505 }
506 }
507 }
508 }
509
510 self.queued_mask.fill(0);
511 self.active_mask.fill(0);
512 self.active_block_mask = 0;
513
514 let boundary_idx = self.parents.len() - 1;
515 self.parents[boundary_idx] = boundary_idx as u32;
516 }
517
518 /// Resets state for the next decoding cycle (sparse reset).
519 ///
520 /// This efficiently resets only the blocks that were modified during
521 /// the previous decoding cycle. At typical error rates (p < 0.1),
522 /// this is significantly faster than [`full_reset`](Self::full_reset).
523 ///
524 /// Use this method between consecutive decoding cycles when the
525 /// grid topology remains unchanged.
526 ///
527 /// # Complexity
528 ///
529 /// O(modified_blocks) where modified_blocks << total_blocks for typical error rates.
530 ///
531 /// # Example
532 ///
533 /// ```ignore
534 /// for cycle in 0..1000 {
535 /// decoder.reset_for_next_cycle();
536 /// decoder.load_dense_syndromes(&syndromes[cycle]);
537 /// let count = decoder.decode(&mut corrections);
538 /// }
539 /// ```
540 #[inline]
541 pub fn reset_for_next_cycle(&mut self) {
542 self.sparse_reset();
543 }
544
545 /// Fully resets all decoder state.
546 ///
547 /// This performs a complete reset of all internal data structures,
548 /// suitable for when:
549 /// - Starting fresh with a completely new problem
550 /// - The grid topology has changed
551 /// - You want guaranteed clean state
552 ///
553 /// For repeated decoding with similar syndrome patterns, prefer
554 /// [`reset_for_next_cycle`](Self::reset_for_next_cycle) which is faster.
555 ///
556 /// # Complexity
557 ///
558 /// O(n) where n is the total number of nodes.
559 #[inline]
560 pub fn full_reset(&mut self) {
561 self.initialize_internal();
562 }
563}
564
565// Forwarding methods to traits - these provide convenient direct access
566// without needing to import the traits.
567impl<'a, T: Topology, const STRIDE_Y: usize> DecodingState<'a, T, STRIDE_Y> {
568 /// Finds the cluster root for node `i`. See [`UnionFind::find`](crate::UnionFind::find).
569 #[inline(always)]
570 pub fn find(&mut self, i: u32) -> u32 {
571 use super::union_find::UnionFind;
572 UnionFind::find(self, i)
573 }
574
575 /// Merges clusters containing nodes `u` and `v`.
576 ///
577 /// See [`UnionFind::union`](crate::UnionFind::union) for details.
578 ///
579 /// # Safety
580 ///
581 /// Caller must ensure `u` and `v` are valid node indices within the parents array.
582 #[inline(always)]
583 pub unsafe fn union(&mut self, u: u32, v: u32) -> bool {
584 use super::union_find::UnionFind;
585 UnionFind::union(self, u, v)
586 }
587
588 /// Loads syndrome measurements. See [`ClusterGrowth::load_dense_syndromes`](crate::ClusterGrowth::load_dense_syndromes).
589 #[inline(always)]
590 pub fn load_dense_syndromes(&mut self, syndromes: &[u64]) {
591 use super::growth::ClusterGrowth;
592 ClusterGrowth::load_dense_syndromes(self, syndromes)
593 }
594
595 /// Expands clusters until convergence. See [`ClusterGrowth::grow_clusters`](crate::ClusterGrowth::grow_clusters).
596 #[inline(always)]
597 pub fn grow_clusters(&mut self) {
598 use super::growth::ClusterGrowth;
599 ClusterGrowth::grow_clusters(self)
600 }
601
602 /// Processes a single block during growth.
603 ///
604 /// See [`ClusterGrowth::process_block`](crate::ClusterGrowth::process_block) for details.
605 ///
606 /// # Safety
607 ///
608 /// Caller must ensure `blk_idx` is within bounds of the internal blocks state.
609 #[inline(always)]
610 pub unsafe fn process_block(&mut self, blk_idx: usize) -> bool {
611 use super::growth::ClusterGrowth;
612 ClusterGrowth::process_block(self, blk_idx)
613 }
614
615 /// Full decode: growth + peeling. See [`Peeling::decode`](crate::Peeling::decode).
616 #[inline(always)]
617 pub fn decode(&mut self, corrections: &mut [EdgeCorrection]) -> usize {
618 use super::peeling::Peeling;
619 Peeling::decode(self, corrections)
620 }
621
622 /// Extracts corrections from grown clusters. See [`Peeling::peel_forest`](crate::Peeling::peel_forest).
623 #[inline(always)]
624 pub fn peel_forest(&mut self, corrections: &mut [EdgeCorrection]) -> usize {
625 use super::peeling::Peeling;
626 Peeling::peel_forest(self, corrections)
627 }
628}