prav_core/decoder/
tiled.rs

1//! Tiled decoder for large grids.
2//!
3//! This module provides [`TiledDecodingState`], which breaks large grids into
4//! 32x32 tiles for improved cache locality and SWAR efficiency. Each tile is
5//! processed using the optimized Stride-32 path, then tiles are stitched together.
6
7#![allow(unsafe_op_in_unsafe_fn)]
8
9use crate::arena::Arena;
10use crate::decoder::state::DecodingState;
11use crate::decoder::types::{BlockStateHot, BoundaryConfig, EdgeCorrection, FLAG_VALID_FULL};
12use crate::intrinsics::{morton_decode_2d, morton_encode_2d, tzcnt};
13use crate::topology::Topology;
14
15/// A tiled decoder that manages large grids by breaking them into 32x32 tiles.
16///
17/// For grids larger than ~4096 nodes, the standard decoder's cache utilization
18/// degrades. `TiledDecodingState` addresses this by:
19///
20/// 1. Dividing the grid into 32x32 tiles (1024 nodes each)
21/// 2. Running the optimized Stride-32 decoder on each tile
22/// 3. Stitching tiles together at boundaries using Union Find
23///
24/// # When to Use
25///
26/// Use `TiledDecodingState` for grids larger than 64x64 nodes. For smaller grids,
27/// the standard [`DecodingState`] is more efficient.
28///
29/// # Tile Layout
30///
31/// ```text
32/// Global grid (96x64):
33/// +--------+--------+--------+
34/// | Tile 0 | Tile 1 | Tile 2 |  (each 32x32)
35/// +--------+--------+--------+
36/// | Tile 3 | Tile 4 | Tile 5 |
37/// +--------+--------+--------+
38/// ```
39///
40/// Each tile contains 16 blocks of 64 nodes, arranged in Stride-32 format.
pub struct TiledDecodingState<'a, T: Topology> {
    /// Total grid width in nodes.
    pub width: usize,
    /// Total grid height in nodes.
    pub height: usize,

    // Tiling configuration
    /// Number of tiles in X direction.
    pub tiles_x: usize,
    /// Number of tiles in Y direction.
    pub tiles_y: usize,

    // We hold the raw memory, but organized in tiles.
    // Each tile has 16 blocks (1024 nodes).
    // Total blocks = tiles_x * tiles_y * 16.
    /// Block state for all tiles (16 blocks per tile).
    ///
    /// Indexed as `tile_idx * 16 + local_idx / 64`, where
    /// `tile_idx = ty * tiles_x + tx` and `local_idx = ly * 32 + lx`.
    pub blocks_state: &'a mut [BlockStateHot],
    /// Union Find parent pointers for all nodes.
    ///
    /// Laid out tile-major: tile `t` owns nodes `t * 1024 .. (t + 1) * 1024`.
    /// The final entry is the virtual global boundary node.
    pub parents: &'a mut [u32],

    // Auxiliary state needed for DecodingState
    /// Defect mask per block (one bit per node, 64 nodes per block).
    pub defect_mask: &'a mut [u64],
    /// Path marking for peeling (XOR-accumulated per node during `peel_forest`).
    pub path_mark: &'a mut [u64],
    /// Dirty block tracking for sparse reset (one bit per block).
    pub block_dirty_mask: &'a mut [u64],
    /// Currently active blocks (one bit per block).
    pub active_mask: &'a mut [u64],
    /// Blocks queued for next iteration (swapped with `active_mask` each cycle).
    pub queued_mask: &'a mut [u64],
    /// Syndrome ingestion worklist.
    pub ingestion_list: &'a mut [u32],
    /// Edge correction bitmap.
    pub edge_bitmap: &'a mut [u64],
    /// Dirty edge word list.
    pub edge_dirty_list: &'a mut [u32],
    /// Dirty edge mask.
    pub edge_dirty_mask: &'a mut [u64],
    /// Boundary correction bitmap.
    pub boundary_bitmap: &'a mut [u64],
    /// Dirty boundary block list.
    pub boundary_dirty_list: &'a mut [u32],
    /// Dirty boundary mask.
    pub boundary_dirty_mask: &'a mut [u64],
    /// BFS predecessor array for path tracing.
    pub bfs_pred: &'a mut [u16],
    /// BFS queue for path tracing.
    pub bfs_queue: &'a mut [u16],

    // Static graph for tiles (shared because all tiles are 32x32 Stride 32)
    /// Static graph metadata for 32x32 tiles (shared by all tiles).
    pub tile_graph: &'a crate::decoder::graph::StaticGraph,

    /// Syndrome ingestion count.
    pub ingestion_count: usize,
    /// Active block mask (for small-grid compatibility).
    pub active_block_mask: u64,
    /// Edge dirty count.
    pub edge_dirty_count: usize,
    /// Boundary dirty count.
    pub boundary_dirty_count: usize,

    /// Phantom marker for topology type.
    pub _marker: core::marker::PhantomData<T>,
}
107
108impl<'a, T: Topology> TiledDecodingState<'a, T> {
    /// Creates a new tiled decoder for the given grid dimensions.
    ///
    /// # Arguments
    ///
    /// * `arena` - Arena allocator for all internal allocations.
    /// * `width` - Total grid width in nodes.
    /// * `height` - Total grid height in nodes.
    ///
    /// # Tile Calculation
    ///
    /// The grid is divided into ceil(width/32) x ceil(height/32) tiles.
    /// Each tile contains 1024 nodes (32x32) in 16 blocks.
    ///
    /// # Panics
    ///
    /// Panics if the arena cannot satisfy all allocations.
    pub fn new(arena: &mut Arena<'a>, width: usize, height: usize) -> Self {
        let tiles_x = width.div_ceil(32);
        let tiles_y = height.div_ceil(32);
        let num_tiles = tiles_x * tiles_y;

        let blocks_per_tile = 16;
        let nodes_per_tile = 1024;

        let total_blocks = num_tiles * blocks_per_tile;
        let total_nodes = num_tiles * nodes_per_tile + 1; // +1 for the virtual boundary node

        // Hot per-block state and Union Find parents, cache-line aligned.
        let blocks_state = arena
            .alloc_slice_aligned::<BlockStateHot>(total_blocks, 64)
            .unwrap();
        let parents = arena.alloc_slice_aligned::<u32>(total_nodes, 64).unwrap();

        let defect_mask = arena.alloc_slice_aligned::<u64>(total_blocks, 64).unwrap();
        let path_mark = arena.alloc_slice_aligned::<u64>(total_blocks, 64).unwrap();

        // One bit per block in the dirty/active/queued bitmaps.
        let block_dirty_mask = arena
            .alloc_slice_aligned::<u64>(total_blocks.div_ceil(64), 64)
            .unwrap();

        let num_bitmask_words = total_blocks.div_ceil(64);
        let active_mask = arena
            .alloc_slice_aligned::<u64>(num_bitmask_words, 64)
            .unwrap();
        let queued_mask = arena
            .alloc_slice_aligned::<u64>(num_bitmask_words, 64)
            .unwrap();

        let ingestion_list = arena.alloc_slice::<u32>(total_blocks).unwrap();

        // Edge/Boundary tracking, sized for total capacity (up to 3 edges
        // per node).
        let num_edges = total_nodes * 3;
        let num_edge_words = num_edges.div_ceil(64);
        let edge_bitmap = arena
            .alloc_slice_aligned::<u64>(num_edge_words, 64)
            .unwrap();
        // NOTE(review): the x8 over-allocation factor matches the boundary
        // list below — presumably headroom for duplicate dirty entries;
        // confirm against the dirty-list producers.
        let edge_dirty_list = arena.alloc_slice::<u32>(num_edge_words * 8).unwrap();
        let edge_dirty_mask = arena
            .alloc_slice_aligned::<u64>(num_edge_words.div_ceil(64), 64)
            .unwrap();

        let boundary_bitmap = arena.alloc_slice_aligned::<u64>(total_blocks, 64).unwrap();
        let boundary_dirty_list = arena.alloc_slice::<u32>(total_blocks * 8).unwrap();
        let boundary_dirty_mask = arena
            .alloc_slice_aligned::<u64>(total_blocks.div_ceil(64), 64)
            .unwrap();

        let bfs_pred = arena.alloc_slice::<u16>(total_nodes).unwrap();
        let bfs_queue = arena.alloc_slice::<u16>(total_nodes).unwrap();

        // Static graph for a single 32x32 tile in Stride-32 layout; one
        // shared descriptor serves every tile. Neighbor traversal uses SWAR
        // bit operations (spread_syndrome_*) rather than lookup tables.
        let tile_graph = crate::decoder::graph::StaticGraph {
            width: 32,
            height: 32,
            depth: 1,
            stride_x: 1,
            stride_y: 32,
            stride_z: 1024,
            blk_stride_y: 0, // Not used for Stride 32
            shift_y: 5,
            shift_z: 10,
            // A 64-bit block packs two 32-wide rows: bits 31/63 are the row
            // ends, bits 0/32 the row starts.
            row_end_mask: 0x8000000080000000,
            row_start_mask: 0x0000000100000001,
        };
        let tile_graph_ref = arena.alloc_value(tile_graph).unwrap();

        let mut state = Self {
            width,
            height,
            tiles_x,
            tiles_y,
            blocks_state,
            parents,
            defect_mask,
            path_mark,
            block_dirty_mask,
            active_mask,
            queued_mask,
            ingestion_list,
            edge_bitmap,
            edge_dirty_list,
            edge_dirty_mask,
            boundary_bitmap,
            boundary_dirty_list,
            boundary_dirty_mask,
            bfs_pred,
            bfs_queue,
            tile_graph: tile_graph_ref,
            ingestion_count: 0,
            active_block_mask: 0,
            edge_dirty_count: 0,
            boundary_dirty_count: 0,
            _marker: core::marker::PhantomData,
        };

        // Derive per-block valid masks from the real grid dimensions.
        state.initialize();
        state
    }
224
225    /// Initializes or reinitializes the decoder state.
226    ///
227    /// Sets up valid masks for all tiles based on actual grid dimensions
228    /// and resets all dynamic state.
229    pub fn initialize(&mut self) {
230        for block in self.blocks_state.iter_mut() {
231            *block = BlockStateHot::default();
232        }
233
234        // Initialize parents
235        for (i, p) in self.parents.iter_mut().enumerate() {
236            *p = i as u32;
237        }
238
239        // Initialize valid masks based on actual width/height
240        let _total_blocks = self.tiles_x * self.tiles_y * 16;
241
242        for ty in 0..self.tiles_y {
243            for tx in 0..self.tiles_x {
244                let tile_idx = ty * self.tiles_x + tx;
245                let block_offset = tile_idx * 16;
246
247                // Determine valid region for this tile
248                let tile_base_x = tx * 32;
249                let tile_base_y = ty * 32;
250
251                for i in 0..1024 {
252                    let lx = i % 32;
253                    let ly = i / 32;
254
255                    let gx = tile_base_x + lx;
256                    let gy = tile_base_y + ly;
257
258                    if gx < self.width && gy < self.height {
259                        let blk = block_offset + (i / 64);
260                        let bit = i % 64;
261                        self.blocks_state[blk].valid_mask |= 1 << bit;
262                    }
263                }
264            }
265        }
266
267        for block in self.blocks_state.iter_mut() {
268            let valid = block.valid_mask;
269            block.effective_mask = valid;
270            if valid == !0 {
271                block.flags |= FLAG_VALID_FULL;
272            }
273        }
274    }
275
276    /// Resets only modified blocks for efficient reuse.
277    ///
278    /// Similar to [`DecodingState::sparse_reset`], this only resets blocks
279    /// that were touched during the previous decoding cycle.
280    pub fn sparse_reset(&mut self) {
281        for (word_idx, word_ref) in self.block_dirty_mask.iter_mut().enumerate() {
282            let mut w = *word_ref;
283            *word_ref = 0;
284            while w != 0 {
285                let bit = w.trailing_zeros();
286                w &= w - 1;
287                let blk_idx = word_idx * 64 + bit as usize;
288
289                unsafe {
290                    let block = self.blocks_state.get_unchecked_mut(blk_idx);
291                    block.boundary = 0;
292                    block.occupied = 0;
293                    block.root = u32::MAX;
294                    *self.defect_mask.get_unchecked_mut(blk_idx) = 0;
295                }
296
297                // Reset parents for this block
298                // But parents are indexed globally?
299                // In TiledDecodingState, parents are laid out in Tile Major order.
300                // So Block 0 corresponds to parents 0..63.
301                // It matches simply.
302                let start_node = blk_idx * 64;
303                let end_node = (start_node + 64).min(self.parents.len());
304                for node in start_node..end_node {
305                    unsafe {
306                        *self.parents.get_unchecked_mut(node) = node as u32;
307                    }
308                }
309            }
310        }
311
312        self.queued_mask.fill(0);
313        self.active_mask.fill(0);
314
315        let boundary_idx = self.parents.len() - 1;
316        self.parents[boundary_idx] = boundary_idx as u32;
317    }
318
319    /// Loads syndrome measurements from a dense row-major bitarray.
320    ///
321    /// Converts from row-major input format to the internal tiled format.
322    ///
323    /// # Arguments
324    ///
325    /// * `syndromes` - Dense bitarray in row-major order with power-of-2 stride.
326    pub fn load_dense_syndromes(&mut self, syndromes: &[u64]) {
327        // Input `syndromes` is Row-Major (Stride = Width.next_power_of_two()).
328        // We need to map to Tiled Layout.
329
330        let max_dim = self.width.max(self.height);
331        let stride_y = max_dim.next_power_of_two();
332        let stride_shift = stride_y.trailing_zeros();
333        let stride_mask = stride_y - 1;
334        
335        let _blk_stride = stride_y / 64; // Blocks per row in input
336
337        // This is slow (scalar), but it's just loading.
338        // Iterate over set bits in syndromes and map to tiled.
339
340        for (blk_idx, &word) in syndromes.iter().enumerate() {
341            if word == 0 {
342                continue;
343            }
344
345            // Input Block covers 64 bits. Stride is `stride_y`.
346            // Wait, input format depends on `generate_defects`.
347            // In growth_bench, stride is power of 2.
348            // Let's assume input matches that.
349
350            // Map each bit
351            let mut w = word;
352            while w != 0 {
353                let bit = w.trailing_zeros();
354                w &= w - 1;
355
356                let global_idx = blk_idx * 64 + bit as usize;
357                let gy = global_idx >> stride_shift;
358                let gx = global_idx & stride_mask;
359
360                if gx >= self.width || gy >= self.height {
361                    continue;
362                }
363
364                let tx = gx / 32;
365                let ty = gy / 32;
366                let lx = gx % 32;
367                let ly = gy % 32;
368
369                let tile_idx = ty * self.tiles_x + tx;
370                let local_idx = ly * 32 + lx;
371
372                let target_blk = tile_idx * 16 + (local_idx / 64);
373                let target_bit = local_idx % 64;
374
375                // Mark dirty
376                let mask_idx = target_blk >> 6;
377                let mask_bit = target_blk & 63;
378                unsafe {
379                    *self.block_dirty_mask.get_unchecked_mut(mask_idx) |= 1 << mask_bit;
380                    let block = self.blocks_state.get_unchecked_mut(target_blk);
381                    block.boundary |= 1 << target_bit;
382                    block.occupied |= 1 << target_bit;
383                    *self.defect_mask.get_unchecked_mut(target_blk) |= 1 << target_bit;
384
385                    // Mark active
386                    let active_word = target_blk >> 6;
387                    let active_bit = target_blk & 63;
388                    *self.active_mask.get_unchecked_mut(active_word) |= 1 << active_bit;
389                }
390            }
391        }
392    }
393
394    /// Expands cluster boundaries until convergence.
395    ///
396    /// Processes tiles in two phases:
397    /// 1. **Intra-tile growth**: Run Stride-32 decoder within each tile
398    /// 2. **Inter-tile stitching**: Connect clusters across tile boundaries
399    pub fn grow_clusters(&mut self) {
400        let max_cycles = self.width.max(self.height) * 16;
401
402        // Extract raw pointers to circumvent borrow checker in the loop.
403        // Safety: We manage disjoint access patterns manually.
404        let active_mask_ptr = self.active_mask.as_mut_ptr();
405
406        for _ in 0..max_cycles {
407            if self.active_mask.iter().all(|&w| w == 0) {
408                break;
409            }
410
411            // 1. Intra-Tile Growth (Phase 1)
412            // Iterate over all tiles
413            for ty in 0..self.tiles_y {
414                for tx in 0..self.tiles_x {
415                    let tile_idx = ty * self.tiles_x + tx;
416
417                    // Check if tile has active blocks
418                    let start_blk = tile_idx * 16;
419                    let end_blk = start_blk + 16;
420
421                    // Quick check if tile is active using raw pointer to avoid full borrow
422                    let mut tile_active = false;
423                    for blk in start_blk..end_blk {
424                        let word_idx = blk >> 6;
425                        let bit = blk & 63;
426                        unsafe {
427                            if (*active_mask_ptr.add(word_idx) & (1 << bit)) != 0 {
428                                tile_active = true;
429                                break;
430                            }
431                        }
432                    }
433
434                    if !tile_active {
435                        continue;
436                    }
437
438                    unsafe {
439                        self.process_tile_unsafe(tile_idx, tx, ty);
440                    }
441                }
442            }
443
444            // 2. Inter-Tile Stitching (Phase 2)
445            unsafe {
446                self.stitch_tiles();
447            }
448
449            // Swap queues
450            core::mem::swap(&mut self.active_mask, &mut self.queued_mask);
451            self.queued_mask.fill(0);
452        }
453    }
454
    /// Runs the Stride-32 decoder on a single 32x32 tile.
    ///
    /// Builds a temporary [`DecodingState`] view over this tile's 16 blocks
    /// (plus the shared global parents/queued/dirty bitmaps) and processes
    /// every block of the tile that is marked active.
    ///
    /// # Safety
    ///
    /// Caller must hold exclusive access to `self`. The raw-pointer slices
    /// created below alias fields of `self`; they are only used through the
    /// temporary `decoder` within this call and never escape it.
    unsafe fn process_tile_unsafe(&mut self, tile_idx: usize, tx: usize, ty: usize) {
        let block_offset = tile_idx * 16;
        let parent_offset = tile_idx * 1024; // 32x32 nodes

        // View of this tile's 16 blocks, reborrowed via raw pointer so we
        // can simultaneously hand other fields of `self` to the decoder.
        let blocks_ptr = self.blocks_state.as_mut_ptr().add(block_offset);
        let blocks_slice = core::slice::from_raw_parts_mut(blocks_ptr, 16);

        // Pass the full (contiguous, tile-major) parents slice so that
        // decoder.parents[decoder.parents.len() - 1] is the shared global
        // boundary node.
        let parents_ptr = self.parents.as_mut_ptr();
        let parents_len = self.parents.len();
        let parents_slice = core::slice::from_raw_parts_mut(parents_ptr, parents_len);

        // The queued/dirty bitmaps are shared globally with the tile decoder.
        let queued_mask_slice = core::slice::from_raw_parts_mut(self.queued_mask.as_mut_ptr(), self.queued_mask.len());
        let dirty_mask_slice = core::slice::from_raw_parts_mut(self.block_dirty_mask.as_mut_ptr(), self.block_dirty_mask.len());

        // Only treat edges that coincide with the real grid boundary as
        // boundaries; interior tile seams are handled by stitch_tiles.
        let config = BoundaryConfig {
            check_left: tx == 0,
            check_right: tx == self.tiles_x - 1,
            check_top: ty == 0,
            check_bottom: ty == self.tiles_y - 1,
        };

        // Temporary per-tile decoder. Fields not needed by process_block
        // receive empty slices; the masks it does write are the global ones.
        let mut decoder = DecodingState::<T, 32> {
            graph: self.tile_graph,
            width: 32,
            height: 32,
            stride_y: 32,
            row_start_mask: 0x0000000100000001,
            row_end_mask: 0x8000000080000000,

            blocks_state: blocks_slice,
            parents: parents_slice,

            defect_mask: &mut [],
            path_mark: &mut [],
            block_dirty_mask: dirty_mask_slice, // Global: consumed by sparse_reset

            active_mask: &mut [],
            queued_mask: queued_mask_slice, // Global: becomes active next cycle
            active_block_mask: 0,

            ingestion_list: &mut [],
            ingestion_count: 0,

            edge_bitmap: &mut [],
            edge_dirty_list: &mut [],
            edge_dirty_count: 0,
            edge_dirty_mask: &mut [],

            boundary_bitmap: &mut [],
            boundary_dirty_list: &mut [],
            boundary_dirty_count: 0,
            boundary_dirty_mask: &mut [],

            bfs_pred: &mut [],
            bfs_queue: &mut [],

            needs_scalar_fallback: false,
            scalar_fallback_mask: 0,
            boundary_config: config,
            parent_offset,
            _marker: core::marker::PhantomData,
        };

        // Read-only view of the global active mask to select which of the
        // tile's 16 blocks to process this cycle.
        let global_active_ptr = self.active_mask.as_ptr();

        let start_blk = tile_idx * 16;
        for i in 0..16 {
            let global_blk = start_blk + i;
            let word_idx = global_blk >> 6;
            let bit = global_blk & 63;

            let is_active = (*global_active_ptr.add(word_idx) & (1 << bit)) != 0;
            if is_active {
                // process_block writes queued/dirty bits directly into the
                // shared global masks (offset logic lives in portable_32.rs).
                decoder.process_block(i);
            }
        }
    }
552
    /// Connects clusters across tile boundaries (Phase 2 of a growth cycle).
    ///
    /// Walks the perimeter pixels of every tile and, for each neighbor in a
    /// *different* tile (per `T::for_each_neighbor` on Morton indices),
    /// unions the two nodes when either side is occupied, propagating
    /// occupancy and dirty/queued marks to newly absorbed nodes.
    ///
    /// # Safety
    ///
    /// Caller must hold exclusive access to `self`. `blocks_ptr` aliases
    /// `self.blocks_state`; it is dereferenced only for blocks addressed by
    /// in-bounds (tile, local) coordinates within this call.
    unsafe fn stitch_tiles(&mut self) {
        let blocks_ptr = self.blocks_state.as_mut_ptr();

        for ty in 0..self.tiles_y {
            for tx in 0..self.tiles_x {
                let tile_idx = ty * self.tiles_x + tx;
                let base_gx = tx * 32;
                let base_gy = ty * 32;

                // Perimeter pixels only: rows 0 and 31, columns 0 and 31.
                // Visiting a seam from both sides is harmless — union-find
                // merges are idempotent.

                // Processes one perimeter pixel: unions it with any
                // occupied neighbor that lives in a different tile.
                let mut process_pixel = |lx: usize, ly: usize| {
                    let gx = base_gx + lx;
                    let gy = base_gy + ly;

                    // Skip pixels of edge tiles that fall outside the grid.
                    if gx >= self.width || gy >= self.height {
                        return;
                    }

                    // Global Morton index drives topology neighbor lookup.
                    let m_idx = morton_encode_2d(gx as u32, gy as u32);

                    // This pixel's tiled node index and (block, bit) address.
                    let my_blk_offset = tile_idx * 16;
                    let my_local = ly * 32 + lx;
                    let my_node = (tile_idx * 1024) + my_local;

                    let my_blk = my_blk_offset + (my_local / 64);
                    let my_bit = my_local % 64;
                    let my_block = &*blocks_ptr.add(my_blk);
                    let my_active = (my_block.occupied & (1 << my_bit)) != 0;

                    T::for_each_neighbor(m_idx, |n_m| {
                         let (nx, ny) = morton_decode_2d(n_m);
                         // Out-of-grid neighbors are skipped.
                         if nx as usize >= self.width || ny as usize >= self.height {
                             return;
                         }

                         let n_tx = (nx / 32) as usize;
                         let n_ty = (ny / 32) as usize;

                         // Intra-tile edges were handled by the tile decoder;
                         // only stitch inter-tile edges here.
                         if n_tx == tx && n_ty == ty {
                             return;
                         }

                         // Process each undirected edge once by orienting it
                         // toward the larger Morton index.
                         if n_m < m_idx {
                             return;
                         }

                         // Neighbor's tiled node index and (block, bit).
                         let n_tile_idx = n_ty * self.tiles_x + n_tx;
                         let n_lx = (nx % 32) as usize;
                         let n_ly = (ny % 32) as usize;

                         let n_local = n_ly * 32 + n_lx;
                         let n_blk = n_tile_idx * 16 + (n_local / 64);
                         let n_bit = n_local % 64;

                         let n_block = &*blocks_ptr.add(n_blk);
                         let n_active = (n_block.occupied & (1 << n_bit)) != 0;

                         if my_active || n_active {
                             let n_node = (n_tile_idx * 1024) + n_local;

                             // On a real merge, propagate occupancy to the
                             // side that was not yet occupied and queue its
                             // block for the next growth cycle.
                             if self.union(my_node as u32, n_node as u32) {
                                 if !my_active {
                                     (*blocks_ptr.add(my_blk)).occupied |= 1 << my_bit;
                                     self.mark_global_dirty(my_blk);
                                 }
                                 if !n_active {
                                     (*blocks_ptr.add(n_blk)).occupied |= 1 << n_bit;
                                     self.mark_global_dirty(n_blk);
                                 }
                             }
                         }
                    });
                };

                // Top & bottom rows (corners included here).
                for x in 0..32 {
                    process_pixel(x, 0);
                    process_pixel(x, 31);
                }
                // Left & right columns, excluding the corners already done.
                for y in 1..31 {
                    process_pixel(0, y);
                    process_pixel(31, y);
                }
            }
        }
    }
655
656    #[inline(always)]
657    fn mark_global_dirty(&mut self, blk_idx: usize) {
658        let mask_idx = blk_idx >> 6;
659        let mask_bit = blk_idx & 63;
660        self.queued_mask[mask_idx] |= 1 << mask_bit;
661        self.block_dirty_mask[mask_idx] |= 1 << mask_bit;
662    }
663
664    // Union Find Helpers (Copied from state.rs/union_find.rs or forwarded)
665
    /// Finds the cluster root for node `i` with path compression.
    ///
    /// Uses path halving compression for O(α(n)) amortized complexity.
    pub fn find(&mut self, mut i: u32) -> u32 {
        // SAFETY: callers pass node indices < parents.len(), and parent
        // entries only ever hold valid node indices (see initialize /
        // sparse_reset / union), so every get_unchecked stays in bounds.
        unsafe {
            let mut p = *self.parents.get_unchecked(i as usize);
            if p != i {
                // Grandparent lookahead: if the parent is already the root,
                // return it without writing anything.
                let mut gp = *self.parents.get_unchecked(p as usize);
                if gp != p {
                    loop {
                        // Path halving: point i at its grandparent, then hop.
                        *self.parents.get_unchecked_mut(i as usize) = gp;
                        i = gp;
                        p = *self.parents.get_unchecked(gp as usize);
                        if p == gp {
                            // gp is the root.
                            break;
                        }
                        gp = *self.parents.get_unchecked(p as usize);
                        if gp == p {
                            // p is the root; finish compressing and stop
                            // (gp == p here, so returning gp is correct).
                            *self.parents.get_unchecked_mut(i as usize) = p;
                            break;
                        }
                    }
                }
                // Whichever path was taken, gp now holds the root.
                return gp;
            }
            i
        }
    }
694
695    /// Merges clusters containing nodes `u` and `v`.
696    ///
697    /// Returns `true` if clusters were different and got merged.
698    pub fn union(&mut self, u: u32, v: u32) -> bool {
699        let root_u = self.find(u);
700        let root_v = self.find(v);
701        if root_u != root_v {
702            let (small, big) = if root_u < root_v {
703                (root_u, root_v)
704            } else {
705                (root_v, root_u)
706            };
707            unsafe {
708                *self.parents.get_unchecked_mut(small as usize) = big;
709                let tile = small as usize / 1024;
710                if tile < self.tiles_x * self.tiles_y {
711                    let local = small as usize % 1024;
712                    let blk = tile * 16 + (local / 64);
713                    self.mark_global_dirty(blk);
714                }
715            }
716            return true;
717        }
718        false
719    }
720
    /// Extracts edge corrections from grown clusters.
    ///
    /// Traces paths from syndrome nodes through the Union Find forest,
    /// converting logical edges to physical corrections.
    ///
    /// # Arguments
    ///
    /// * `corrections` - Output buffer for edge corrections.
    ///
    /// # Returns
    ///
    /// Number of corrections written to the buffer.
    pub fn peel_forest(&mut self, corrections: &mut [EdgeCorrection]) -> usize {
        // Clear path marks from any previous peel.
        self.path_mark.fill(0);

        // Phase 1: for every syndrome bit, XOR-mark the parent chain from
        // the syndrome up to its root. Edges crossed by an even number of
        // syndrome paths cancel out via the XOR.
        for blk_idx in 0..self.defect_mask.len() {
            let mut word = unsafe { *self.defect_mask.get_unchecked(blk_idx) };
            if word == 0 {
                continue;
            }

            let base_node = blk_idx * 64;
            while word != 0 {
                let bit = tzcnt(word) as usize;
                word &= word - 1;
                let u = (base_node + bit) as u32;

                // Trace path from syndrome u to root/boundary.
                self.trace_path(u);
            }
        }

        // Phase 2: a surviving mark on node u means the logical forest edge
        // (u, parents[u]) is used by an odd number of paths, so it must be
        // realized as physical corrections.
        for blk_idx in 0..self.path_mark.len() {
            let mut word = unsafe { *self.path_mark.get_unchecked(blk_idx) };
            if word == 0 {
                continue;
            }

            let base_node = blk_idx * 64;
            while word != 0 {
                let bit = tzcnt(word) as usize;
                word &= word - 1;
                let u = (base_node + bit) as u32;

                let v = unsafe { *self.parents.get_unchecked(u as usize) };

                // Roots (u == parents[u]) carry no edge.
                if u != v {
                    self.trace_manhattan_tiled(u, v);
                }
            }
        }

        // Phase 3: collect corrections accumulated in the edge bitmaps.
        self.reconstruct_tiled_corrections(corrections)
    }
782
783    fn trace_path(&mut self, u: u32) {
784        let mut curr = u;
785        loop {
786            let next = unsafe { *self.parents.get_unchecked(curr as usize) };
787            if curr == next {
788                break;
789            }
790
791            let blk = (curr as usize) / 64;
792            let bit = (curr as usize) % 64;
793
794            unsafe {
795                *self.path_mark.get_unchecked_mut(blk) ^= 1 << bit;
796            }
797            curr = next;
798        }
799    }
800
801    fn trace_manhattan_tiled(&mut self, u: u32, v: u32) {
802        let boundary_node = (self.parents.len() - 1) as u32;
803
804        if u == boundary_node {
805            self.emit_tiled_edge(v, u32::MAX);
806            return;
807        }
808        if v == boundary_node {
809            self.emit_tiled_edge(u, u32::MAX);
810            return;
811        }
812
813        let (ux, uy) = self.get_global_coord(u);
814        let (vx, vy) = self.get_global_coord(v);
815
816        // Simple Manhattan routing: Move X, then Move Y.
817        // We assume the path is valid in the topology or at least sufficient for correction.
818        // For Square/Rect: Always valid.
819        // For Triangular/Honeycomb: Manhattan path exists (using only cardinal moves).
820        
821        let dx = (vx as isize) - (ux as isize);
822        let dy = (vy as isize) - (uy as isize);
823
824        let mut curr = u;
825        let mut curr_x = ux as isize;
826        let mut curr_y = uy as isize;
827
828        // Move in X
829        if dx != 0 {
830            let step = if dx > 0 { 1 } else { -1 };
831            let steps = dx.abs();
832            for _ in 0..steps {
833                let next_x = curr_x + step;
834                if let Some(next) = self.get_node_idx(next_x as usize, curr_y as usize) {
835                    self.emit_tiled_edge(curr, next);
836                    curr = next;
837                    curr_x = next_x;
838                } else {
839                    break; // Should not happen if bounds correct
840                }
841            }
842        }
843
844        // Move in Y
845        if dy != 0 {
846            let step = if dy > 0 { 1 } else { -1 };
847            let steps = dy.abs();
848            for _ in 0..steps {
849                let next_y = curr_y + step;
850                if let Some(next) = self.get_node_idx(curr_x as usize, next_y as usize) {
851                    self.emit_tiled_edge(curr, next);
852                    curr = next;
853                    curr_y = next_y;
854                } else {
855                    break;
856                }
857            }
858        }
859    }
860
861    fn get_node_idx(&self, x: usize, y: usize) -> Option<u32> {
862        if x < self.width && y < self.height {
863            let tx = x / 32;
864            let ty = y / 32;
865            let lx = x % 32;
866            let ly = y % 32;
867            let node = (ty * self.tiles_x + tx) * 1024 + (ly * 32 + lx);
868            Some(node as u32)
869        } else {
870            None
871        }
872    }
873
874    fn get_global_coord(&self, u: u32) -> (usize, usize) {
875        let tile_idx = (u as usize) / 1024;
876        let local_idx = (u as usize) % 1024;
877
878        let tx = tile_idx % self.tiles_x;
879        let ty = tile_idx / self.tiles_x;
880
881        let lx = local_idx % 32;
882        let ly = local_idx / 32;
883
884        (tx * 32 + lx, ty * 32 + ly)
885    }
886
887    fn emit_tiled_edge(&mut self, u: u32, v: u32) {
888        if v == u32::MAX {
889            // Boundary edge
890            let blk_idx = (u as usize) / 64;
891            let bit_idx = (u as usize) % 64;
892
893            let mask_idx = blk_idx >> 6;
894            let mask_bit = blk_idx & 63;
895
896            unsafe {
897                let m_ptr = self.boundary_dirty_mask.get_unchecked_mut(mask_idx);
898                if (*m_ptr & (1 << mask_bit)) == 0 {
899                    *m_ptr |= 1 << mask_bit;
900                    *self
901                        .boundary_dirty_list
902                        .get_unchecked_mut(self.boundary_dirty_count) = blk_idx as u32;
903                    self.boundary_dirty_count += 1;
904                }
905                *self.boundary_bitmap.get_unchecked_mut(blk_idx) ^= 1 << bit_idx;
906            }
907            return;
908        }
909
910        // Regular edge
911        // We need to determine an edge index for (u, v).
912        // Canonical order: u < v.
913        let (src, dst) = if u < v { (u, v) } else { (v, u) };
914
915        // Determine 'dir' (0..2) relative to src.
916        // We need to find dst in src's neighbors and get its index?
917        // Or calculate relative position.
918        let (ux, uy) = self.get_global_coord(src);
919        let (vx, vy) = self.get_global_coord(dst);
920        
921        // Differences (dst - src)
922        // Since src < dst (mostly), and layout is Row-Major...
923        // vy >= uy.
924        // If vy == uy, vx > ux (Right).
925        // If vy > uy, vx could be anything.
926        
927        let dx = (vx as isize) - (ux as isize);
928        let dy = (vy as isize) - (uy as isize);
929
930        // TriangularGrid only has ONE diagonal type per node, so both diagonal
931        // directions (Down-Right and Down-Left) map to slot 2.
932        let dir = if dy == 0 && dx == 1 {
933            0 // Right
934        } else if dy == 1 && dx == 0 {
935            1 // Down
936        } else if dy == 1 && (dx == 1 || dx == -1) {
937            2 // Diagonal (slot 2 for either direction)
938        } else {
939            // Should not happen for handled topologies
940            0
941        };
942
943        // Edge Index
944        let edge_idx = (src as usize) * 3 + dir;
945        let word_idx = edge_idx / 64;
946        let bit_idx = edge_idx % 64;
947
948        let mask_idx = word_idx >> 6;
949        let mask_bit = word_idx & 63;
950
951        unsafe {
952            let m_ptr = self.edge_dirty_mask.get_unchecked_mut(mask_idx);
953            if (*m_ptr & (1 << mask_bit)) == 0 {
954                *m_ptr |= 1 << mask_bit;
955                *self
956                    .edge_dirty_list
957                    .get_unchecked_mut(self.edge_dirty_count) = word_idx as u32;
958                self.edge_dirty_count += 1;
959            }
960            *self.edge_bitmap.get_unchecked_mut(word_idx) ^= 1 << bit_idx;
961        }
962    }
963
964    fn reconstruct_tiled_corrections(&mut self, corrections: &mut [EdgeCorrection]) -> usize {
965        let mut count = 0;
966
967        // Process Dirty Edges
968        for i in 0..self.edge_dirty_count {
969            let word_idx = unsafe { *self.edge_dirty_list.get_unchecked(i) } as usize;
970
971            // Clear mask
972            let mask_idx = word_idx >> 6;
973            let mask_bit = word_idx & 63;
974            unsafe {
975                *self.edge_dirty_mask.get_unchecked_mut(mask_idx) &= !(1 << mask_bit);
976            }
977
978            let word_ptr = unsafe { self.edge_bitmap.get_unchecked_mut(word_idx) };
979            let mut word = *word_ptr;
980            *word_ptr = 0;
981
982            let base_idx = word_idx * 64;
983            while word != 0 {
984                let bit = tzcnt(word) as usize;
985                word &= word - 1;
986
987                let edge_idx = base_idx + bit;
988                let u = (edge_idx / 3) as u32;
989                let dir = edge_idx % 3;
990                
991                let (ux, uy) = self.get_global_coord(u);
992                
993                // Recover v from u and dir
994                let (vx, vy) = match dir {
995                    0 => (ux + 1, uy),
996                    1 => (ux, uy + 1),
997                    2 => {
998                         // Determine diagonal direction based on Topology?
999                         // Or assume fixed diagonal for now?
1000                         // TriangularGrid logic:
1001                         // if (idx.count_ones() & 1) == 0 -> Up-Right.
1002                         // But we are at 'u' looking for neighbor 'v' such that u < v.
1003                         // Neighbors of u: Right, Down, maybe Diag.
1004                         // If u has Down-Right diagonal: (ux+1, uy+1).
1005                         // If u has Down-Left diagonal: (ux-1, uy+1).
1006                         // We need T to know which one.
1007                         // BUT we are generic.
1008                         // We can assume for TriangularGrid, Diag is always "the diagonal".
1009                         // However, T::for_each_neighbor is the truth.
1010                         // If we assume TriangularGrid matches `prav-core/src/topology.rs`:
1011                         // Parity check uses Morton index.
1012                         let m_idx = morton_encode_2d(ux as u32, uy as u32);
1013                         if (m_idx.count_ones() & 1) == 0 {
1014                             // "Has Right and Up". Diag is Up-Right. (x+1, y-1).
1015                             // But v must be > u. (y-1) < y. So u > v.
1016                             // This edge would be stored at v (the smaller node).
1017                             // So if we are at u, we store edges to v > u.
1018                             // For this node u, valid v > u neighbors are:
1019                             // Right (x+1, y). Down (x, y+1).
1020                             // Does it have Down-Right or Down-Left?
1021                             // If neighbor w (Down-Left) exists, w < u? No, w.y > u.y. So w > u.
1022                             // So Down-Left is a valid forward edge.
1023                             // If neighbor z (Down-Right) exists, z > u.
1024                             
1025                             // We need to know which one 'dir=2' represents.
1026                             // In emit_tiled_edge:
1027                             // dy=1, dx=1 -> Down-Right.
1028                             // dy=1, dx=-1 -> Down-Left.
1029                             // Both mapped to dir=2.
1030                             // This implies a node has AT MOST one "forward diagonal" neighbor?
1031                             // TriangularGrid: 6 neighbors.
1032                             // Left, Right, Up, Down.
1033                             // Plus ONE diagonal pair.
1034                             // If Parity 0: Diag is / (Up-Right, Down-Left).
1035                             // Forward neighbors (>u): Right, Down, Down-Left.
1036                             // If Parity 1: Diag is \ (Up-Left, Down-Right).
1037                             // Forward neighbors (>u): Right, Down, Down-Right.
1038                             
1039                             // So yes! Only one "Down" diagonal per node.
1040                             // Parity 0 -> Down-Left.
1041                             // Parity 1 -> Down-Right.
1042                             if (m_idx.count_ones() & 1) == 0 {
1043                                 (ux.wrapping_sub(1), uy + 1)
1044                             } else {
1045                                 (ux + 1, uy + 1)
1046                             }
1047                         } else {
1048                             // Copy-paste error in reasoning above?
1049                             // Re-read Topology.rs:
1050                             // if parity 0: if has_right && has_up { dec(right, Y) -> (x+1, y-1) }
1051                             //    This is Up-Right.
1052                             //    Does it have Down-Left?
1053                             //    No, "else if parity 1: if has_left && has_down { inc(left, Y) -> (x-1, y+1) }"
1054                             //    This is Down-Left.
1055                             
1056                             // So Node Parity 0 has Up-Right.
1057                             // Node Parity 1 has Down-Left.
1058                             
1059                             // Let's verify reciprocal.
1060                             // Edge (u, v). u < v.
1061                             // v is "Down" relative to u (mostly).
1062                             // If Edge is Up-Right from u? v = (x+1, y-1).
1063                             // v.y < u.y. So v < u (mostly).
1064                             // So Up-Right is a BACKWARD edge.
1065                             // We don't store it at u. We store it at v.
1066                             
1067                             // If Edge is Down-Left from u? v = (x-1, y+1).
1068                             // v.y > u.y. So v > u.
1069                             // So Down-Left is a FORWARD edge.
1070                             // Does Parity 0 have Down-Left? No.
1071                             // Does Parity 1 have Down-Left? Yes.
1072                             
1073                             // So Parity 1 nodes have a generic "Diagonal Forward" (Down-Left).
1074                             // What about Parity 0 nodes?
1075                             // They have Up-Right (Backward).
1076                             // Do they have Down-Right? No.
1077                             // So Parity 0 nodes have NO Forward Diagonal?
1078                             // Wait.
1079                             // Triangular grid is connected.
1080                             // Edges are undirected.
1081                             // Edge between A(0,0) and B(1,1)?
1082                             // A=0 (Parity 0). B=3 (Parity 0).
1083                             // A has Up-Right? No (y=0).
1084                             // B has Up-Right? (2, 0).
1085                             // This assumes specific layout.
1086                             
1087                             // Let's rely on coordinates from emit_tiled_edge logic.
1088                             // We map dir=2 to "The Diagonal".
1089                             // But reconstructing requires knowing WHICH diagonal.
1090                             // We can use the Parity logic again.
1091                             
1092                             if (m_idx.count_ones() & 1) != 0 {
1093                                 // Parity 1 has Down-Left (Forward).
1094                                 (ux.wrapping_sub(1), uy + 1)
1095                             } else {
1096                                 // Parity 0 has... NO forward diagonal?
1097                                 // Check neighbors of Parity 0.
1098                                 // Right (x+1, y) -> >u.
1099                                 // Down (x, y+1) -> >u.
1100                                 // Up-Right (x+1, y-1) -> <u.
1101                                 // Left, Up -> <u.
1102                                 // So Parity 0 only has Right and Down as forward edges?
1103                                 // If so, dir=2 should never happen for Parity 0!
1104                                 // UNLESS I messed up u < v logic.
1105                                 // If u < v, and v is Up-Right of u?
1106                                 // v.y < u.y. v < u. Contradiction.
1107                                 
1108                                 // So, Parity 0 nodes ONLY have dir=0 and dir=1.
1109                                 // Parity 1 nodes have dir=0, dir=1, dir=2 (Down-Left).
1110                                 
1111                                 // Wait, what about Down-Right?
1112                                 // Topology.rs doesn't seem to implement Down-Right for anyone?
1113                                 // TriangularGrid:
1114                                 // Parity 0: Up-Right.
1115                                 // Parity 1: Down-Left.
1116                                 // This forms diagonals like ///.
1117                                 // So only one type of diagonal exists in the whole grid ( ///// ).
1118                                 // (x, y) connected to (x+1, y-1).
1119                                 // Equivalent to (x, y) connected to (x-1, y+1).
1120                                 // So yes, all diagonals are "Up-Right / Down-Left" type.
1121                                 // "Down-Right" (\) does not exist.
1122                                 
1123                                 // So dir=2 ALWAYS means Down-Left (x-1, y+1).
1124                                 (ux.wrapping_sub(1), uy + 1)
1125                             }
1126                         }
1127                    }
1128                    _ => (ux, uy),
1129                };
1130                
1131                // Map global v back to tiled u32
1132                if vx < self.width && vy < self.height {
1133                    let v_tx = vx / 32;
1134                    let v_ty = vy / 32;
1135                    let v_lx = vx % 32;
1136                    let v_ly = vy % 32;
1137                    let v_node = (v_ty * self.tiles_x + v_tx) * 1024 + (v_ly * 32 + v_lx);
1138                    
1139                    if count < corrections.len() {
1140                        unsafe {
1141                            *corrections.get_unchecked_mut(count) = EdgeCorrection { u, v: v_node as u32 };
1142                        }
1143                        count += 1;
1144                    }
1145                }
1146            }
1147        }
1148        self.edge_dirty_count = 0;
1149
1150        // Process Dirty Boundaries
1151        for i in 0..self.boundary_dirty_count {
1152            let blk_idx = unsafe { *self.boundary_dirty_list.get_unchecked(i) } as usize;
1153
1154            // Clear mask
1155            let mask_idx = blk_idx >> 6;
1156            let mask_bit = blk_idx & 63;
1157            unsafe {
1158                *self.boundary_dirty_mask.get_unchecked_mut(mask_idx) &= !(1 << mask_bit);
1159            }
1160
1161            let word_ptr = unsafe { self.boundary_bitmap.get_unchecked_mut(blk_idx) };
1162            let mut word = *word_ptr;
1163            *word_ptr = 0;
1164
1165            let base_u = blk_idx * 64;
1166            while word != 0 {
1167                let bit = tzcnt(word) as usize;
1168                word &= word - 1;
1169                let u = (base_u + bit) as u32;
1170
1171                if count < corrections.len() {
1172                    unsafe {
1173                        *corrections.get_unchecked_mut(count) = EdgeCorrection { u, v: u32::MAX };
1174                    }
1175                    count += 1;
1176                }
1177            }
1178        }
1179        self.boundary_dirty_count = 0;
1180
1181        count
1182    }
1183}
1184
#[cfg(test)]
mod tests {
    use super::*;
    use crate::arena::Arena;
    use crate::topology::SquareGrid;

    extern crate std;

    /// Marks a node (given as block index + bit) as an occupied boundary seed
    /// and activates its block so cluster growth expands from it.
    fn seed(decoder: &mut TiledDecodingState<'_, SquareGrid>, blk: usize, bit: usize) {
        decoder.blocks_state[blk].occupied |= 1 << bit;
        decoder.blocks_state[blk].boundary |= 1 << bit;
        decoder.active_mask[blk >> 6] |= 1 << (blk & 63);
    }

    #[test]
    fn test_tiled_horizontal_stitching() {
        // 64x32 grid: two 32x32 tiles side by side.
        let mut memory = std::vec![0u8; 1024 * 1024 * 16];
        let mut arena = Arena::new(&mut memory);
        let mut decoder = TiledDecodingState::<SquareGrid>::new(&mut arena, 64, 32);

        // Node A = (31, 0): right edge of Tile 0 -> node 31, block 0, bit 31.
        // Node B = (32, 0): left edge of Tile 1 -> node 1024, block 16, bit 0.
        let (node_a, node_b) = (31, 1024);
        seed(&mut decoder, 0, 31);
        seed(&mut decoder, 16, 0);

        decoder.grow_clusters();

        assert_eq!(
            decoder.find(node_a),
            decoder.find(node_b),
            "Horizontal stitching failed between Tile 0 and Tile 1"
        );
    }

    #[test]
    fn test_tiled_vertical_stitching() {
        // 32x64 grid: two 32x32 tiles stacked vertically.
        let mut memory = std::vec![0u8; 1024 * 1024 * 16];
        let mut arena = Arena::new(&mut memory);
        let mut decoder = TiledDecodingState::<SquareGrid>::new(&mut arena, 32, 64);

        // Node A = (0, 31): bottom edge of Tile 0 -> node 992 (31*32), block 15, bit 32.
        // Node B = (0, 32): top edge of Tile 1 -> node 1024, block 16, bit 0.
        let (node_a, node_b) = (992, 1024);
        seed(&mut decoder, 15, 32);
        seed(&mut decoder, 16, 0);

        decoder.grow_clusters();

        assert_eq!(
            decoder.find(node_a),
            decoder.find(node_b),
            "Vertical stitching failed between Tile 0 and Tile 1"
        );
    }
}