prav_core/decoder/tiled.rs
1//! Tiled decoder for large grids.
2//!
3//! This module provides [`TiledDecodingState`], which breaks large grids into
4//! 32x32 tiles for improved cache locality and SWAR efficiency. Each tile is
5//! processed using the optimized Stride-32 path, then tiles are stitched together.
6
7#![allow(unsafe_op_in_unsafe_fn)]
8
9use crate::arena::Arena;
10use crate::decoder::state::DecodingState;
11use crate::decoder::types::{BlockStateHot, BoundaryConfig, EdgeCorrection, FLAG_VALID_FULL};
12use crate::intrinsics::{morton_decode_2d, morton_encode_2d, tzcnt};
13use crate::topology::Topology;
14
/// A tiled decoder that manages large grids by breaking them into 32x32 tiles.
///
/// For grids larger than ~4096 nodes, the standard decoder's cache utilization
/// degrades. `TiledDecodingState` addresses this by:
///
/// 1. Dividing the grid into 32x32 tiles (1024 nodes each)
/// 2. Running the optimized Stride-32 decoder on each tile
/// 3. Stitching tiles together at boundaries using Union Find
///
/// # When to Use
///
/// Use `TiledDecodingState` for grids larger than 64x64 nodes. For smaller grids,
/// the standard [`DecodingState`] is more efficient.
///
/// # Tile Layout
///
/// ```text
/// Global grid (96x64):
/// +--------+--------+--------+
/// | Tile 0 | Tile 1 | Tile 2 | (each 32x32)
/// +--------+--------+--------+
/// | Tile 3 | Tile 4 | Tile 5 |
/// +--------+--------+--------+
/// ```
///
/// Each tile contains 16 blocks of 64 nodes, arranged in Stride-32 format.
///
/// # Memory Layout
///
/// All per-block and per-node slices are tile-major: tile `t` owns blocks
/// `t*16 .. t*16+16` and nodes `t*1024 .. t*1024+1024`. `parents` has one
/// extra trailing slot that serves as the virtual boundary node.
pub struct TiledDecodingState<'a, T: Topology> {
    /// Total grid width in nodes.
    pub width: usize,
    /// Total grid height in nodes.
    pub height: usize,

    // Tiling configuration
    /// Number of tiles in X direction (`width.div_ceil(32)`).
    pub tiles_x: usize,
    /// Number of tiles in Y direction (`height.div_ceil(32)`).
    pub tiles_y: usize,

    // We hold the raw memory, but organized in tiles.
    // Each tile has 16 blocks (1024 nodes).
    // Total blocks = tiles_x * tiles_y * 16.
    /// Block state for all tiles (16 blocks per tile).
    pub blocks_state: &'a mut [BlockStateHot],
    /// Union Find parent pointers for all nodes (last slot = boundary node).
    pub parents: &'a mut [u32],

    // Auxiliary state needed for DecodingState
    /// Defect (syndrome) mask per block.
    pub defect_mask: &'a mut [u64],
    /// Path marking for peeling (XOR-toggled while tracing syndrome paths).
    pub path_mark: &'a mut [u64],
    /// Dirty block tracking for sparse reset (one bit per block).
    pub block_dirty_mask: &'a mut [u64],
    /// Currently active blocks (one bit per block).
    pub active_mask: &'a mut [u64],
    /// Blocks queued for next iteration (one bit per block).
    pub queued_mask: &'a mut [u64],
    /// Syndrome ingestion worklist.
    pub ingestion_list: &'a mut [u32],
    /// Edge correction bitmap (3 direction slots per node; XOR semantics).
    pub edge_bitmap: &'a mut [u64],
    /// Dirty edge word list (words of `edge_bitmap` touched this cycle).
    pub edge_dirty_list: &'a mut [u32],
    /// Dirty edge mask (dedup filter for `edge_dirty_list`).
    pub edge_dirty_mask: &'a mut [u64],
    /// Boundary correction bitmap (one bit per node; XOR semantics).
    pub boundary_bitmap: &'a mut [u64],
    /// Dirty boundary block list.
    pub boundary_dirty_list: &'a mut [u32],
    /// Dirty boundary mask (dedup filter for `boundary_dirty_list`).
    pub boundary_dirty_mask: &'a mut [u64],
    /// BFS predecessor array for path tracing.
    pub bfs_pred: &'a mut [u16],
    /// BFS queue for path tracing.
    pub bfs_queue: &'a mut [u16],

    // Static graph for tiles (shared because all tiles are 32x32 Stride 32)
    /// Static graph metadata for 32x32 tiles (shared by all tiles).
    pub tile_graph: &'a crate::decoder::graph::StaticGraph,

    /// Syndrome ingestion count.
    pub ingestion_count: usize,
    /// Active block mask (for small-grid compatibility).
    pub active_block_mask: u64,
    /// Number of valid entries in `edge_dirty_list`.
    pub edge_dirty_count: usize,
    /// Number of valid entries in `boundary_dirty_list`.
    pub boundary_dirty_count: usize,

    /// Phantom marker for topology type.
    pub _marker: core::marker::PhantomData<T>,
}
107
108impl<'a, T: Topology> TiledDecodingState<'a, T> {
    /// Creates a new tiled decoder for the given grid dimensions.
    ///
    /// # Arguments
    ///
    /// * `arena` - Arena allocator for all internal allocations.
    /// * `width` - Total grid width in nodes.
    /// * `height` - Total grid height in nodes.
    ///
    /// # Tile Calculation
    ///
    /// The grid is divided into ceil(width/32) x ceil(height/32) tiles.
    /// Each tile contains 1024 nodes (32x32) in 16 blocks.
    ///
    /// # Panics
    ///
    /// Panics if `arena` cannot satisfy any of the allocations below
    /// (every arena call is `unwrap`ped).
    pub fn new(arena: &mut Arena<'a>, width: usize, height: usize) -> Self {
        let tiles_x = width.div_ceil(32);
        let tiles_y = height.div_ceil(32);
        let num_tiles = tiles_x * tiles_y;

        let blocks_per_tile = 16; // 1024 nodes / 64 nodes per block
        let nodes_per_tile = 1024; // 32 x 32

        let total_blocks = num_tiles * blocks_per_tile;
        let total_nodes = num_tiles * nodes_per_tile + 1; // +1 for boundary

        // Allocate memory. Hot per-block/per-node state is cache-line aligned.
        let blocks_state = arena
            .alloc_slice_aligned::<BlockStateHot>(total_blocks, 64)
            .unwrap();
        let parents = arena.alloc_slice_aligned::<u32>(total_nodes, 64).unwrap();

        let defect_mask = arena.alloc_slice_aligned::<u64>(total_blocks, 64).unwrap();
        let path_mark = arena.alloc_slice_aligned::<u64>(total_blocks, 64).unwrap();

        // One bit per block in the dirty/active/queued bitmasks.
        let block_dirty_mask = arena
            .alloc_slice_aligned::<u64>(total_blocks.div_ceil(64), 64)
            .unwrap();

        let num_bitmask_words = total_blocks.div_ceil(64);
        let active_mask = arena
            .alloc_slice_aligned::<u64>(num_bitmask_words, 64)
            .unwrap();
        let queued_mask = arena
            .alloc_slice_aligned::<u64>(num_bitmask_words, 64)
            .unwrap();

        let ingestion_list = arena.alloc_slice::<u32>(total_blocks).unwrap();

        // Edge/Boundary tracking sized for total capacity: 3 direction slots
        // per node (see `emit_tiled_edge`).
        let num_edges = total_nodes * 3;
        let num_edge_words = num_edges.div_ceil(64);
        let edge_bitmap = arena
            .alloc_slice_aligned::<u64>(num_edge_words, 64)
            .unwrap();
        let edge_dirty_list = arena.alloc_slice::<u32>(num_edge_words * 8).unwrap();
        let edge_dirty_mask = arena
            .alloc_slice_aligned::<u64>(num_edge_words.div_ceil(64), 64)
            .unwrap();

        let boundary_bitmap = arena.alloc_slice_aligned::<u64>(total_blocks, 64).unwrap();
        let boundary_dirty_list = arena.alloc_slice::<u32>(total_blocks * 8).unwrap();
        let boundary_dirty_mask = arena
            .alloc_slice_aligned::<u64>(total_blocks.div_ceil(64), 64)
            .unwrap();

        let bfs_pred = arena.alloc_slice::<u16>(total_nodes).unwrap();
        let bfs_queue = arena.alloc_slice::<u16>(total_nodes).unwrap();

        // Create a StaticGraph for a 32x32 tile (Stride 32), shared by every
        // tile. Neighbor traversal uses SWAR bit operations
        // (spread_syndrome_*) rather than lookup tables.
        let tile_graph = crate::decoder::graph::StaticGraph {
            width: 32,
            height: 32,
            depth: 1,
            stride_x: 1,
            stride_y: 32,
            stride_z: 1024,
            blk_stride_y: 0, // Not used for Stride 32
            shift_y: 5,
            shift_z: 10,
            row_end_mask: 0x8000000080000000,
            row_start_mask: 0x0000000100000001,
        };
        let tile_graph_ref = arena.alloc_value(tile_graph).unwrap();

        let mut state = Self {
            width,
            height,
            tiles_x,
            tiles_y,
            blocks_state,
            parents,
            defect_mask,
            path_mark,
            block_dirty_mask,
            active_mask,
            queued_mask,
            ingestion_list,
            edge_bitmap,
            edge_dirty_list,
            edge_dirty_mask,
            boundary_bitmap,
            boundary_dirty_list,
            boundary_dirty_mask,
            bfs_pred,
            bfs_queue,
            tile_graph: tile_graph_ref,
            ingestion_count: 0,
            active_block_mask: 0,
            edge_dirty_count: 0,
            boundary_dirty_count: 0,
            _marker: core::marker::PhantomData,
        };

        // Build valid masks and identity parents before first use.
        state.initialize();
        state
    }
224
225 /// Initializes or reinitializes the decoder state.
226 ///
227 /// Sets up valid masks for all tiles based on actual grid dimensions
228 /// and resets all dynamic state.
229 pub fn initialize(&mut self) {
230 for block in self.blocks_state.iter_mut() {
231 *block = BlockStateHot::default();
232 }
233
234 // Initialize parents
235 for (i, p) in self.parents.iter_mut().enumerate() {
236 *p = i as u32;
237 }
238
239 // Initialize valid masks based on actual width/height
240 let _total_blocks = self.tiles_x * self.tiles_y * 16;
241
242 for ty in 0..self.tiles_y {
243 for tx in 0..self.tiles_x {
244 let tile_idx = ty * self.tiles_x + tx;
245 let block_offset = tile_idx * 16;
246
247 // Determine valid region for this tile
248 let tile_base_x = tx * 32;
249 let tile_base_y = ty * 32;
250
251 for i in 0..1024 {
252 let lx = i % 32;
253 let ly = i / 32;
254
255 let gx = tile_base_x + lx;
256 let gy = tile_base_y + ly;
257
258 if gx < self.width && gy < self.height {
259 let blk = block_offset + (i / 64);
260 let bit = i % 64;
261 self.blocks_state[blk].valid_mask |= 1 << bit;
262 }
263 }
264 }
265 }
266
267 for block in self.blocks_state.iter_mut() {
268 let valid = block.valid_mask;
269 block.effective_mask = valid;
270 if valid == !0 {
271 block.flags |= FLAG_VALID_FULL;
272 }
273 }
274 }
275
    /// Resets only modified blocks for efficient reuse.
    ///
    /// Similar to [`DecodingState::sparse_reset`], this only resets blocks
    /// that were touched during the previous decoding cycle, as recorded in
    /// `block_dirty_mask`. The dirty mask is consumed (zeroed) while it is
    /// being iterated.
    pub fn sparse_reset(&mut self) {
        for (word_idx, word_ref) in self.block_dirty_mask.iter_mut().enumerate() {
            // Take and clear the dirty word, then visit each set bit.
            let mut w = *word_ref;
            *word_ref = 0;
            while w != 0 {
                let bit = w.trailing_zeros();
                w &= w - 1; // clear lowest set bit
                let blk_idx = word_idx * 64 + bit as usize;

                // SAFETY: dirty bits are only set for block indices in
                // 0..total_blocks, and defect_mask has one word per block.
                unsafe {
                    let block = self.blocks_state.get_unchecked_mut(blk_idx);
                    block.boundary = 0;
                    block.occupied = 0;
                    block.root = u32::MAX;
                    *self.defect_mask.get_unchecked_mut(blk_idx) = 0;
                }

                // Restore identity parents for this block's nodes. Parents are
                // laid out tile-major, so block b owns nodes b*64..b*64+64 and
                // the mapping is direct; the min() guards the trailing
                // boundary slot at the end of `parents`.
                let start_node = blk_idx * 64;
                let end_node = (start_node + 64).min(self.parents.len());
                for node in start_node..end_node {
                    // SAFETY: node < self.parents.len() by the min() above.
                    unsafe {
                        *self.parents.get_unchecked_mut(node) = node as u32;
                    }
                }
            }
        }

        self.queued_mask.fill(0);
        self.active_mask.fill(0);

        // The virtual boundary node is always its own root.
        let boundary_idx = self.parents.len() - 1;
        self.parents[boundary_idx] = boundary_idx as u32;
    }
318
319 /// Loads syndrome measurements from a dense row-major bitarray.
320 ///
321 /// Converts from row-major input format to the internal tiled format.
322 ///
323 /// # Arguments
324 ///
325 /// * `syndromes` - Dense bitarray in row-major order with power-of-2 stride.
326 pub fn load_dense_syndromes(&mut self, syndromes: &[u64]) {
327 // Input `syndromes` is Row-Major (Stride = Width.next_power_of_two()).
328 // We need to map to Tiled Layout.
329
330 let max_dim = self.width.max(self.height);
331 let stride_y = max_dim.next_power_of_two();
332 let stride_shift = stride_y.trailing_zeros();
333 let stride_mask = stride_y - 1;
334
335 let _blk_stride = stride_y / 64; // Blocks per row in input
336
337 // This is slow (scalar), but it's just loading.
338 // Iterate over set bits in syndromes and map to tiled.
339
340 for (blk_idx, &word) in syndromes.iter().enumerate() {
341 if word == 0 {
342 continue;
343 }
344
345 // Input Block covers 64 bits. Stride is `stride_y`.
346 // Wait, input format depends on `generate_defects`.
347 // In growth_bench, stride is power of 2.
348 // Let's assume input matches that.
349
350 // Map each bit
351 let mut w = word;
352 while w != 0 {
353 let bit = w.trailing_zeros();
354 w &= w - 1;
355
356 let global_idx = blk_idx * 64 + bit as usize;
357 let gy = global_idx >> stride_shift;
358 let gx = global_idx & stride_mask;
359
360 if gx >= self.width || gy >= self.height {
361 continue;
362 }
363
364 let tx = gx / 32;
365 let ty = gy / 32;
366 let lx = gx % 32;
367 let ly = gy % 32;
368
369 let tile_idx = ty * self.tiles_x + tx;
370 let local_idx = ly * 32 + lx;
371
372 let target_blk = tile_idx * 16 + (local_idx / 64);
373 let target_bit = local_idx % 64;
374
375 // Mark dirty
376 let mask_idx = target_blk >> 6;
377 let mask_bit = target_blk & 63;
378 unsafe {
379 *self.block_dirty_mask.get_unchecked_mut(mask_idx) |= 1 << mask_bit;
380 let block = self.blocks_state.get_unchecked_mut(target_blk);
381 block.boundary |= 1 << target_bit;
382 block.occupied |= 1 << target_bit;
383 *self.defect_mask.get_unchecked_mut(target_blk) |= 1 << target_bit;
384
385 // Mark active
386 let active_word = target_blk >> 6;
387 let active_bit = target_blk & 63;
388 *self.active_mask.get_unchecked_mut(active_word) |= 1 << active_bit;
389 }
390 }
391 }
392 }
393
    /// Expands cluster boundaries until convergence.
    ///
    /// Each cycle processes tiles in two phases:
    /// 1. **Intra-tile growth**: Run Stride-32 decoder within each tile
    /// 2. **Inter-tile stitching**: Connect clusters across tile boundaries
    ///
    /// Iterates until no block is active, capped at `max(width, height) * 16`
    /// cycles as a safety bound against non-convergence.
    pub fn grow_clusters(&mut self) {
        let max_cycles = self.width.max(self.height) * 16;

        // Extract a raw pointer to circumvent the borrow checker in the loop.
        // Safety: the pointer is only read here (per-block activity test);
        // mutation happens through `self` methods that manage disjoint access
        // manually.
        let active_mask_ptr = self.active_mask.as_mut_ptr();

        for _ in 0..max_cycles {
            // Converged: nothing scheduled for this cycle.
            if self.active_mask.iter().all(|&w| w == 0) {
                break;
            }

            // 1. Intra-Tile Growth (Phase 1): visit every tile that has at
            // least one active block.
            for ty in 0..self.tiles_y {
                for tx in 0..self.tiles_x {
                    let tile_idx = ty * self.tiles_x + tx;

                    // This tile's block range (16 blocks per tile).
                    let start_blk = tile_idx * 16;
                    let end_blk = start_blk + 16;

                    // Quick activity scan using the raw pointer to avoid a
                    // full borrow of self.
                    let mut tile_active = false;
                    for blk in start_blk..end_blk {
                        let word_idx = blk >> 6;
                        let bit = blk & 63;
                        unsafe {
                            if (*active_mask_ptr.add(word_idx) & (1 << bit)) != 0 {
                                tile_active = true;
                                break;
                            }
                        }
                    }

                    if !tile_active {
                        continue;
                    }

                    unsafe {
                        self.process_tile_unsafe(tile_idx, tx, ty);
                    }
                }
            }

            // 2. Inter-Tile Stitching (Phase 2)
            unsafe {
                self.stitch_tiles();
            }

            // Blocks queued during this cycle become next cycle's active set.
            core::mem::swap(&mut self.active_mask, &mut self.queued_mask);
            self.queued_mask.fill(0);
        }
    }
454
    /// Runs one growth step of the Stride-32 decoder over a single tile.
    ///
    /// Reconstructs a temporary [`DecodingState`] view over this tile's 16
    /// blocks (sharing the global parents / queued / dirty masks) and calls
    /// `process_block` for each of the tile's active blocks.
    ///
    /// # Safety
    ///
    /// `tile_idx` (with coordinates `tx`, `ty`) must be a valid tile index.
    /// The raw slices created below mutably alias fields of `self`; the
    /// temporary decoder must not escape this call and the method must not
    /// be re-entered concurrently.
    unsafe fn process_tile_unsafe(&mut self, tile_idx: usize, tx: usize, ty: usize) {
        let block_offset = tile_idx * 16;
        let parent_offset = tile_idx * 1024; // first node of this 32x32 tile

        // View over this tile's 16 blocks.
        // SAFETY(review): block_offset + 16 <= total_blocks for a valid tile.
        let blocks_ptr = self.blocks_state.as_mut_ptr().add(block_offset);
        let blocks_slice = core::slice::from_raw_parts_mut(blocks_ptr, 16);

        // Pass the FULL parents slice: the decoder treats
        // parents[parents.len() - 1] as the Global Boundary node, and
        // self.parents is contiguous across tiles with that trailing slot.
        // Per-tile node indexing goes through `parent_offset` below.
        let parents_ptr = self.parents.as_mut_ptr();
        let parents_len = self.parents.len();
        let parents_slice = core::slice::from_raw_parts_mut(parents_ptr, parents_len);

        let queued_mask_slice = core::slice::from_raw_parts_mut(self.queued_mask.as_mut_ptr(), self.queued_mask.len());
        let dirty_mask_slice = core::slice::from_raw_parts_mut(self.block_dirty_mask.as_mut_ptr(), self.block_dirty_mask.len());

        // The remaining DecodingState fields (defects, peeling, edge tracking)
        // are not used by process_block, so empty slices suffice.

        // Only tiles on the outer rim of the tile grid see the physical grid
        // boundary on the corresponding side.
        let config = BoundaryConfig {
            check_left: tx == 0,
            check_right: tx == self.tiles_x - 1,
            check_top: ty == 0,
            check_bottom: ty == self.tiles_y - 1,
        };

        // Create a manual DecodingState over the tile.
        // Note: queued_mask / block_dirty_mask are the GLOBAL masks;
        // process_block indexes them with global block offsets.
        let mut decoder = DecodingState::<T, 32> {
            graph: self.tile_graph,
            width: 32,
            height: 32,
            stride_y: 32,
            row_start_mask: 0x0000000100000001,
            row_end_mask: 0x8000000080000000,

            blocks_state: blocks_slice,
            parents: parents_slice,

            defect_mask: &mut [],
            path_mark: &mut [],
            block_dirty_mask: dirty_mask_slice, // Global

            active_mask: &mut [],
            queued_mask: queued_mask_slice, // Global
            active_block_mask: 0,

            ingestion_list: &mut [],
            ingestion_count: 0,

            edge_bitmap: &mut [],
            edge_dirty_list: &mut [],
            edge_dirty_count: 0,
            edge_dirty_mask: &mut [],

            boundary_bitmap: &mut [],
            boundary_dirty_list: &mut [],
            boundary_dirty_count: 0,
            boundary_dirty_mask: &mut [],

            bfs_pred: &mut [],
            bfs_queue: &mut [],

            needs_scalar_fallback: false,
            scalar_fallback_mask: 0,
            boundary_config: config,
            parent_offset,
            _marker: core::marker::PhantomData,
        };

        // Read-only view of the global active mask to pick blocks to run.
        let global_active_ptr = self.active_mask.as_ptr();

        // Process every active block of this tile (local indices 0..16).
        let start_blk = tile_idx * 16;
        for i in 0..16 {
            let global_blk = start_blk + i;
            let word_idx = global_blk >> 6;
            let bit = global_blk & 63;

            let is_active = (*global_active_ptr.add(word_idx) & (1 << bit)) != 0;
            if is_active {
                // process_block writes directly to the global queued/dirty
                // masks using offset logic in portable_32.rs.
                decoder.process_block(i);
            }
        }
    }
552
    /// Connects clusters across tile boundaries (Phase 2 of growth).
    ///
    /// Walks the rim pixels of every tile (rows 0 and 31, columns 0 and 31)
    /// and, for each topology neighbor that lives in a DIFFERENT tile, unions
    /// the two nodes when either endpoint is occupied. A newly occupied
    /// endpoint has its bit set and its block re-queued for the next cycle.
    ///
    /// # Safety
    ///
    /// Creates a raw pointer into `blocks_state` that is dereferenced while
    /// `&mut self` methods (`union`, `mark_global_dirty`) are also called;
    /// no other reference into `blocks_state` may be live during this call.
    unsafe fn stitch_tiles(&mut self) {
        // Generic stitching using T::for_each_neighbor. Raw pointer so the
        // closure can read/write blocks while also calling &mut self methods.
        let blocks_ptr = self.blocks_state.as_mut_ptr();

        for ty in 0..self.tiles_y {
            for tx in 0..self.tiles_x {
                let tile_idx = ty * self.tiles_x + tx;
                let base_gx = tx * 32;
                let base_gy = ty * 32;

                // Iterate boundary pixels of this tile: rows 0/31, cols 0/31.
                // Duplicate checks are fine (Union Find union is idempotent).

                // Helper closure to process one rim pixel.
                let mut process_pixel = |lx: usize, ly: usize| {
                    let gx = base_gx + lx;
                    let gy = base_gy + ly;

                    // Rim pixels of edge tiles may fall outside the grid.
                    if gx >= self.width || gy >= self.height {
                        return;
                    }

                    // Morton index drives the topology's neighbor iteration.
                    let m_idx = morton_encode_2d(gx as u32, gy as u32);

                    // This pixel's tiled node / block / bit coordinates.
                    let my_blk_offset = tile_idx * 16;
                    let my_local = ly * 32 + lx;
                    let my_node = (tile_idx * 1024) + my_local;

                    // Check whether this pixel is already part of a cluster.
                    let my_blk = my_blk_offset + (my_local / 64);
                    let my_bit = my_local % 64;
                    let my_block = &*blocks_ptr.add(my_blk);
                    let my_active = (my_block.occupied & (1 << my_bit)) != 0;

                    // Iterate neighbors as defined by the topology.
                    T::for_each_neighbor(m_idx, |n_m| {
                        let (nx, ny) = morton_decode_2d(n_m);
                        // Skip neighbors outside the grid.
                        if nx as usize >= self.width || ny as usize >= self.height {
                            return;
                        }

                        let n_tx = (nx / 32) as usize;
                        let n_ty = (ny / 32) as usize;

                        // Only stitch inter-tile edges; intra-tile edges were
                        // already handled by Phase 1.
                        if n_tx == tx && n_ty == ty {
                            return;
                        }

                        // Process each undirected edge once: only from the
                        // endpoint with the smaller Morton index.
                        if n_m < m_idx {
                            return;
                        }

                        let n_tile_idx = n_ty * self.tiles_x + n_tx;
                        let n_lx = (nx % 32) as usize;
                        let n_ly = (ny % 32) as usize;

                        let n_local = n_ly * 32 + n_lx;
                        let n_blk = n_tile_idx * 16 + (n_local / 64);
                        let n_bit = n_local % 64;

                        let n_block = &*blocks_ptr.add(n_blk);
                        let n_active = (n_block.occupied & (1 << n_bit)) != 0;

                        // Grow across the seam if either side is occupied.
                        if my_active || n_active {
                            let n_node = (n_tile_idx * 1024) + n_local;

                            if self.union(my_node as u32, n_node as u32) {
                                // Occupy whichever endpoint was empty and
                                // re-queue its block for the next cycle.
                                if !my_active {
                                    (*blocks_ptr.add(my_blk)).occupied |= 1 << my_bit;
                                    self.mark_global_dirty(my_blk);
                                }
                                if !n_active {
                                    (*blocks_ptr.add(n_blk)).occupied |= 1 << n_bit;
                                    self.mark_global_dirty(n_blk);
                                }
                            }
                        }
                    });
                };

                // Top & bottom rows (includes all four corners).
                for x in 0..32 {
                    process_pixel(x, 0);
                    process_pixel(x, 31);
                }
                // Left & right columns; corners were covered above.
                for y in 1..31 {
                    process_pixel(0, y);
                    process_pixel(31, y);
                }
            }
        }
    }
655
656 #[inline(always)]
657 fn mark_global_dirty(&mut self, blk_idx: usize) {
658 let mask_idx = blk_idx >> 6;
659 let mask_bit = blk_idx & 63;
660 self.queued_mask[mask_idx] |= 1 << mask_bit;
661 self.block_dirty_mask[mask_idx] |= 1 << mask_bit;
662 }
663
664 // Union Find Helpers (Copied from state.rs/union_find.rs or forwarded)
665
    /// Finds the cluster root for node `i` with path compression.
    ///
    /// Uses path halving compression for O(α(n)) amortized complexity:
    /// while walking up the tree, visited nodes are re-pointed at their
    /// grandparent.
    pub fn find(&mut self, mut i: u32) -> u32 {
        // SAFETY: every index read below is a parent value, and parent values
        // are always valid node indices (< parents.len()) by construction
        // (initialize/sparse_reset write identities; union writes roots).
        unsafe {
            let mut p = *self.parents.get_unchecked(i as usize);
            if p != i {
                let mut gp = *self.parents.get_unchecked(p as usize);
                if gp != p {
                    // Walk upward two levels at a time, compressing as we go.
                    loop {
                        *self.parents.get_unchecked_mut(i as usize) = gp;
                        i = gp;
                        p = *self.parents.get_unchecked(gp as usize);
                        if p == gp {
                            // gp is the root.
                            break;
                        }
                        gp = *self.parents.get_unchecked(p as usize);
                        if gp == p {
                            // p is the root; compress the final hop too.
                            *self.parents.get_unchecked_mut(i as usize) = p;
                            break;
                        }
                    }
                }
                // On both break paths gp holds the root; if the inner `if`
                // was skipped, gp == p was already the root.
                return gp;
            }
            // i was already a root.
            i
        }
    }
694
695 /// Merges clusters containing nodes `u` and `v`.
696 ///
697 /// Returns `true` if clusters were different and got merged.
698 pub fn union(&mut self, u: u32, v: u32) -> bool {
699 let root_u = self.find(u);
700 let root_v = self.find(v);
701 if root_u != root_v {
702 let (small, big) = if root_u < root_v {
703 (root_u, root_v)
704 } else {
705 (root_v, root_u)
706 };
707 unsafe {
708 *self.parents.get_unchecked_mut(small as usize) = big;
709 let tile = small as usize / 1024;
710 if tile < self.tiles_x * self.tiles_y {
711 let local = small as usize % 1024;
712 let blk = tile * 16 + (local / 64);
713 self.mark_global_dirty(blk);
714 }
715 }
716 return true;
717 }
718 false
719 }
720
721 /// Extracts edge corrections from grown clusters.
722 ///
723 /// Traces paths from syndrome nodes through the Union Find forest,
724 /// converting logical edges to physical corrections.
725 ///
726 /// # Arguments
727 ///
728 /// * `corrections` - Output buffer for edge corrections.
729 ///
730 /// # Returns
731 ///
732 /// Number of corrections written to the buffer.
733 pub fn peel_forest(&mut self, corrections: &mut [EdgeCorrection]) -> usize {
734 // Clear auxiliary buffers
735 self.path_mark.fill(0);
736
737 // 1. Identify Syndromes and Trace Paths to Root/Boundary
738 // Iterate over defect mask to find syndromes
739 for blk_idx in 0..self.defect_mask.len() {
740 let mut word = unsafe { *self.defect_mask.get_unchecked(blk_idx) };
741 if word == 0 {
742 continue;
743 }
744
745 let base_node = blk_idx * 64;
746 while word != 0 {
747 let bit = tzcnt(word) as usize;
748 word &= word - 1;
749 let u = (base_node + bit) as u32;
750
751 // Trace path from syndrome u to root/boundary
752 self.trace_path(u);
753 }
754 }
755
756 // 2. Process Marked Paths (Peeling)
757 // Iterate over path_mark. If a node u is marked, it means the edge (u, parent[u]) is part of an odd number of paths.
758 // We need to 'realize' this logical edge as physical edges (corrections).
759 for blk_idx in 0..self.path_mark.len() {
760 let mut word = unsafe { *self.path_mark.get_unchecked(blk_idx) };
761 if word == 0 {
762 continue;
763 }
764
765 let base_node = blk_idx * 64;
766 while word != 0 {
767 let bit = tzcnt(word) as usize;
768 word &= word - 1;
769 let u = (base_node + bit) as u32;
770
771 let v = unsafe { *self.parents.get_unchecked(u as usize) };
772
773 if u != v {
774 self.trace_manhattan_tiled(u, v);
775 }
776 }
777 }
778
779 // 3. Collect Corrections from dirty edges
780 self.reconstruct_tiled_corrections(corrections)
781 }
782
783 fn trace_path(&mut self, u: u32) {
784 let mut curr = u;
785 loop {
786 let next = unsafe { *self.parents.get_unchecked(curr as usize) };
787 if curr == next {
788 break;
789 }
790
791 let blk = (curr as usize) / 64;
792 let bit = (curr as usize) % 64;
793
794 unsafe {
795 *self.path_mark.get_unchecked_mut(blk) ^= 1 << bit;
796 }
797 curr = next;
798 }
799 }
800
801 fn trace_manhattan_tiled(&mut self, u: u32, v: u32) {
802 let boundary_node = (self.parents.len() - 1) as u32;
803
804 if u == boundary_node {
805 self.emit_tiled_edge(v, u32::MAX);
806 return;
807 }
808 if v == boundary_node {
809 self.emit_tiled_edge(u, u32::MAX);
810 return;
811 }
812
813 let (ux, uy) = self.get_global_coord(u);
814 let (vx, vy) = self.get_global_coord(v);
815
816 // Simple Manhattan routing: Move X, then Move Y.
817 // We assume the path is valid in the topology or at least sufficient for correction.
818 // For Square/Rect: Always valid.
819 // For Triangular/Honeycomb: Manhattan path exists (using only cardinal moves).
820
821 let dx = (vx as isize) - (ux as isize);
822 let dy = (vy as isize) - (uy as isize);
823
824 let mut curr = u;
825 let mut curr_x = ux as isize;
826 let mut curr_y = uy as isize;
827
828 // Move in X
829 if dx != 0 {
830 let step = if dx > 0 { 1 } else { -1 };
831 let steps = dx.abs();
832 for _ in 0..steps {
833 let next_x = curr_x + step;
834 if let Some(next) = self.get_node_idx(next_x as usize, curr_y as usize) {
835 self.emit_tiled_edge(curr, next);
836 curr = next;
837 curr_x = next_x;
838 } else {
839 break; // Should not happen if bounds correct
840 }
841 }
842 }
843
844 // Move in Y
845 if dy != 0 {
846 let step = if dy > 0 { 1 } else { -1 };
847 let steps = dy.abs();
848 for _ in 0..steps {
849 let next_y = curr_y + step;
850 if let Some(next) = self.get_node_idx(curr_x as usize, next_y as usize) {
851 self.emit_tiled_edge(curr, next);
852 curr = next;
853 curr_y = next_y;
854 } else {
855 break;
856 }
857 }
858 }
859 }
860
861 fn get_node_idx(&self, x: usize, y: usize) -> Option<u32> {
862 if x < self.width && y < self.height {
863 let tx = x / 32;
864 let ty = y / 32;
865 let lx = x % 32;
866 let ly = y % 32;
867 let node = (ty * self.tiles_x + tx) * 1024 + (ly * 32 + lx);
868 Some(node as u32)
869 } else {
870 None
871 }
872 }
873
874 fn get_global_coord(&self, u: u32) -> (usize, usize) {
875 let tile_idx = (u as usize) / 1024;
876 let local_idx = (u as usize) % 1024;
877
878 let tx = tile_idx % self.tiles_x;
879 let ty = tile_idx / self.tiles_x;
880
881 let lx = local_idx % 32;
882 let ly = local_idx / 32;
883
884 (tx * 32 + lx, ty * 32 + ly)
885 }
886
    /// Records one correction edge in the XOR bitmaps.
    ///
    /// `v == u32::MAX` denotes a boundary edge at node `u`. Otherwise the
    /// undirected edge (u, v) is canonicalized to its smaller endpoint and
    /// stored under one of three direction slots. All writes are XOR toggles,
    /// so an edge traversed twice cancels out.
    fn emit_tiled_edge(&mut self, u: u32, v: u32) {
        if v == u32::MAX {
            // Boundary edge: toggle node u's bit and record the touched word
            // in the dirty list exactly once (deduped via the dirty mask).
            let blk_idx = (u as usize) / 64;
            let bit_idx = (u as usize) % 64;

            let mask_idx = blk_idx >> 6;
            let mask_bit = blk_idx & 63;

            // SAFETY: u is a real grid node, so blk_idx < total_blocks, and
            // the boundary bitmaps are allocated with total_blocks words.
            unsafe {
                let m_ptr = self.boundary_dirty_mask.get_unchecked_mut(mask_idx);
                if (*m_ptr & (1 << mask_bit)) == 0 {
                    *m_ptr |= 1 << mask_bit;
                    *self
                        .boundary_dirty_list
                        .get_unchecked_mut(self.boundary_dirty_count) = blk_idx as u32;
                    self.boundary_dirty_count += 1;
                }
                *self.boundary_bitmap.get_unchecked_mut(blk_idx) ^= 1 << bit_idx;
            }
            return;
        }

        // Regular edge: canonical order u < v so each undirected edge has a
        // single storage slot at its smaller endpoint.
        let (src, dst) = if u < v { (u, v) } else { (v, u) };

        // Determine the direction slot (0..2) of dst relative to src from
        // the global coordinate delta.
        let (ux, uy) = self.get_global_coord(src);
        let (vx, vy) = self.get_global_coord(dst);

        // Deltas (dst - src). Since src < dst and layout is row-major within
        // a tile, forward edges have dy >= 0.
        let dx = (vx as isize) - (ux as isize);
        let dy = (vy as isize) - (uy as isize);

        // Slot layout: 0 = Right, 1 = Down, 2 = diagonal. TriangularGrid
        // gives each node only ONE diagonal type, so both Down-Right and
        // Down-Left map to slot 2 without ambiguity.
        let dir = if dy == 0 && dx == 1 {
            0 // Right
        } else if dy == 1 && dx == 0 {
            1 // Down
        } else if dy == 1 && (dx == 1 || dx == -1) {
            2 // Diagonal (slot 2 for either direction)
        } else {
            // Should not happen for handled topologies; fall back to slot 0.
            0
        };

        // Edge index: 3 slots per node. Toggle the bit and record the touched
        // word once in the dirty list.
        let edge_idx = (src as usize) * 3 + dir;
        let word_idx = edge_idx / 64;
        let bit_idx = edge_idx % 64;

        let mask_idx = word_idx >> 6;
        let mask_bit = word_idx & 63;

        // SAFETY: edge_idx < total_nodes * 3, matching the allocation of
        // edge_bitmap / edge_dirty_mask / edge_dirty_list in `new`.
        unsafe {
            let m_ptr = self.edge_dirty_mask.get_unchecked_mut(mask_idx);
            if (*m_ptr & (1 << mask_bit)) == 0 {
                *m_ptr |= 1 << mask_bit;
                *self
                    .edge_dirty_list
                    .get_unchecked_mut(self.edge_dirty_count) = word_idx as u32;
                self.edge_dirty_count += 1;
            }
            *self.edge_bitmap.get_unchecked_mut(word_idx) ^= 1 << bit_idx;
        }
    }
963
964 fn reconstruct_tiled_corrections(&mut self, corrections: &mut [EdgeCorrection]) -> usize {
965 let mut count = 0;
966
967 // Process Dirty Edges
968 for i in 0..self.edge_dirty_count {
969 let word_idx = unsafe { *self.edge_dirty_list.get_unchecked(i) } as usize;
970
971 // Clear mask
972 let mask_idx = word_idx >> 6;
973 let mask_bit = word_idx & 63;
974 unsafe {
975 *self.edge_dirty_mask.get_unchecked_mut(mask_idx) &= !(1 << mask_bit);
976 }
977
978 let word_ptr = unsafe { self.edge_bitmap.get_unchecked_mut(word_idx) };
979 let mut word = *word_ptr;
980 *word_ptr = 0;
981
982 let base_idx = word_idx * 64;
983 while word != 0 {
984 let bit = tzcnt(word) as usize;
985 word &= word - 1;
986
987 let edge_idx = base_idx + bit;
988 let u = (edge_idx / 3) as u32;
989 let dir = edge_idx % 3;
990
991 let (ux, uy) = self.get_global_coord(u);
992
993 // Recover v from u and dir
994 let (vx, vy) = match dir {
995 0 => (ux + 1, uy),
996 1 => (ux, uy + 1),
997 2 => {
998 // Determine diagonal direction based on Topology?
999 // Or assume fixed diagonal for now?
1000 // TriangularGrid logic:
1001 // if (idx.count_ones() & 1) == 0 -> Up-Right.
1002 // But we are at 'u' looking for neighbor 'v' such that u < v.
1003 // Neighbors of u: Right, Down, maybe Diag.
1004 // If u has Down-Right diagonal: (ux+1, uy+1).
1005 // If u has Down-Left diagonal: (ux-1, uy+1).
1006 // We need T to know which one.
1007 // BUT we are generic.
1008 // We can assume for TriangularGrid, Diag is always "the diagonal".
1009 // However, T::for_each_neighbor is the truth.
1010 // If we assume TriangularGrid matches `prav-core/src/topology.rs`:
1011 // Parity check uses Morton index.
1012 let m_idx = morton_encode_2d(ux as u32, uy as u32);
1013 if (m_idx.count_ones() & 1) == 0 {
1014 // "Has Right and Up". Diag is Up-Right. (x+1, y-1).
1015 // But v must be > u. (y-1) < y. So u > v.
1016 // This edge would be stored at v (the smaller node).
1017 // So if we are at u, we store edges to v > u.
1018 // For this node u, valid v > u neighbors are:
1019 // Right (x+1, y). Down (x, y+1).
1020 // Does it have Down-Right or Down-Left?
1021 // If neighbor w (Down-Left) exists, w < u? No, w.y > u.y. So w > u.
1022 // So Down-Left is a valid forward edge.
1023 // If neighbor z (Down-Right) exists, z > u.
1024
1025 // We need to know which one 'dir=2' represents.
1026 // In emit_tiled_edge:
1027 // dy=1, dx=1 -> Down-Right.
1028 // dy=1, dx=-1 -> Down-Left.
1029 // Both mapped to dir=2.
1030 // This implies a node has AT MOST one "forward diagonal" neighbor?
1031 // TriangularGrid: 6 neighbors.
1032 // Left, Right, Up, Down.
1033 // Plus ONE diagonal pair.
1034 // If Parity 0: Diag is / (Up-Right, Down-Left).
1035 // Forward neighbors (>u): Right, Down, Down-Left.
1036 // If Parity 1: Diag is \ (Up-Left, Down-Right).
1037 // Forward neighbors (>u): Right, Down, Down-Right.
1038
1039 // So yes! Only one "Down" diagonal per node.
1040 // Parity 0 -> Down-Left.
1041 // Parity 1 -> Down-Right.
1042 if (m_idx.count_ones() & 1) == 0 {
1043 (ux.wrapping_sub(1), uy + 1)
1044 } else {
1045 (ux + 1, uy + 1)
1046 }
1047 } else {
1048 // Copy-paste error in reasoning above?
1049 // Re-read Topology.rs:
1050 // if parity 0: if has_right && has_up { dec(right, Y) -> (x+1, y-1) }
1051 // This is Up-Right.
1052 // Does it have Down-Left?
1053 // No, "else if parity 1: if has_left && has_down { inc(left, Y) -> (x-1, y+1) }"
1054 // This is Down-Left.
1055
1056 // So Node Parity 0 has Up-Right.
1057 // Node Parity 1 has Down-Left.
1058
1059 // Let's verify reciprocal.
1060 // Edge (u, v). u < v.
1061 // v is "Down" relative to u (mostly).
1062 // If Edge is Up-Right from u? v = (x+1, y-1).
1063 // v.y < u.y. So v < u (mostly).
1064 // So Up-Right is a BACKWARD edge.
1065 // We don't store it at u. We store it at v.
1066
1067 // If Edge is Down-Left from u? v = (x-1, y+1).
1068 // v.y > u.y. So v > u.
1069 // So Down-Left is a FORWARD edge.
1070 // Does Parity 0 have Down-Left? No.
1071 // Does Parity 1 have Down-Left? Yes.
1072
1073 // So Parity 1 nodes have a generic "Diagonal Forward" (Down-Left).
1074 // What about Parity 0 nodes?
1075 // They have Up-Right (Backward).
1076 // Do they have Down-Right? No.
1077 // So Parity 0 nodes have NO Forward Diagonal?
1078 // Wait.
1079 // Triangular grid is connected.
1080 // Edges are undirected.
1081 // Edge between A(0,0) and B(1,1)?
1082 // A=0 (Parity 0). B=3 (Parity 0).
1083 // A has Up-Right? No (y=0).
1084 // B has Up-Right? (2, 0).
1085 // This assumes specific layout.
1086
1087 // Let's rely on coordinates from emit_tiled_edge logic.
1088 // We map dir=2 to "The Diagonal".
1089 // But reconstructing requires knowing WHICH diagonal.
1090 // We can use the Parity logic again.
1091
1092 if (m_idx.count_ones() & 1) != 0 {
1093 // Parity 1 has Down-Left (Forward).
1094 (ux.wrapping_sub(1), uy + 1)
1095 } else {
1096 // Parity 0 has... NO forward diagonal?
1097 // Check neighbors of Parity 0.
1098 // Right (x+1, y) -> >u.
1099 // Down (x, y+1) -> >u.
1100 // Up-Right (x+1, y-1) -> <u.
1101 // Left, Up -> <u.
1102 // So Parity 0 only has Right and Down as forward edges?
1103 // If so, dir=2 should never happen for Parity 0!
1104 // UNLESS I messed up u < v logic.
1105 // If u < v, and v is Up-Right of u?
1106 // v.y < u.y. v < u. Contradiction.
1107
1108 // So, Parity 0 nodes ONLY have dir=0 and dir=1.
1109 // Parity 1 nodes have dir=0, dir=1, dir=2 (Down-Left).
1110
1111 // Wait, what about Down-Right?
1112 // Topology.rs doesn't seem to implement Down-Right for anyone?
1113 // TriangularGrid:
1114 // Parity 0: Up-Right.
1115 // Parity 1: Down-Left.
1116 // This forms diagonals like ///.
1117 // So only one type of diagonal exists in the whole grid ( ///// ).
1118 // (x, y) connected to (x+1, y-1).
1119 // Equivalent to (x, y) connected to (x-1, y+1).
1120 // So yes, all diagonals are "Up-Right / Down-Left" type.
1121 // "Down-Right" (\) does not exist.
1122
1123 // So dir=2 ALWAYS means Down-Left (x-1, y+1).
1124 (ux.wrapping_sub(1), uy + 1)
1125 }
1126 }
1127 }
1128 _ => (ux, uy),
1129 };
1130
1131 // Map global v back to tiled u32
1132 if vx < self.width && vy < self.height {
1133 let v_tx = vx / 32;
1134 let v_ty = vy / 32;
1135 let v_lx = vx % 32;
1136 let v_ly = vy % 32;
1137 let v_node = (v_ty * self.tiles_x + v_tx) * 1024 + (v_ly * 32 + v_lx);
1138
1139 if count < corrections.len() {
1140 unsafe {
1141 *corrections.get_unchecked_mut(count) = EdgeCorrection { u, v: v_node as u32 };
1142 }
1143 count += 1;
1144 }
1145 }
1146 }
1147 }
1148 self.edge_dirty_count = 0;
1149
1150 // Process Dirty Boundaries
1151 for i in 0..self.boundary_dirty_count {
1152 let blk_idx = unsafe { *self.boundary_dirty_list.get_unchecked(i) } as usize;
1153
1154 // Clear mask
1155 let mask_idx = blk_idx >> 6;
1156 let mask_bit = blk_idx & 63;
1157 unsafe {
1158 *self.boundary_dirty_mask.get_unchecked_mut(mask_idx) &= !(1 << mask_bit);
1159 }
1160
1161 let word_ptr = unsafe { self.boundary_bitmap.get_unchecked_mut(blk_idx) };
1162 let mut word = *word_ptr;
1163 *word_ptr = 0;
1164
1165 let base_u = blk_idx * 64;
1166 while word != 0 {
1167 let bit = tzcnt(word) as usize;
1168 word &= word - 1;
1169 let u = (base_u + bit) as u32;
1170
1171 if count < corrections.len() {
1172 unsafe {
1173 *corrections.get_unchecked_mut(count) = EdgeCorrection { u, v: u32::MAX };
1174 }
1175 count += 1;
1176 }
1177 }
1178 }
1179 self.boundary_dirty_count = 0;
1180
1181 count
1182 }
1183}
1184
#[cfg(test)]
mod tests {
    use super::*;
    use crate::arena::Arena;
    use crate::topology::SquareGrid;

    extern crate std;

    /// Marks a single node — identified by its block index and bit position —
    /// as occupied and on the cluster boundary, and raises the block's bit in
    /// the active mask so the next `grow_clusters` pass processes it.
    ///
    /// This simulates a defect injected at that node without going through the
    /// normal syndrome-ingestion path.
    fn seed_node(decoder: &mut TiledDecodingState<'_, SquareGrid>, blk: usize, bit: usize) {
        decoder.blocks_state[blk].occupied |= 1 << bit;
        decoder.blocks_state[blk].boundary |= 1 << bit;
        decoder.active_mask[blk >> 6] |= 1 << (blk & 63);
    }

    #[test]
    fn test_tiled_horizontal_stitching() {
        let mut memory = std::vec![0u8; 1024 * 1024 * 16];
        let mut arena = Arena::new(&mut memory);

        // 64x32 grid: 2 tiles wide, 1 tile high.
        let width = 64;
        let height = 32;
        let mut decoder = TiledDecodingState::<SquareGrid>::new(&mut arena, width, height);

        // Node A: global (31, 0) -> right edge of Tile 0.
        //   Tiled index = tile 0 base (0) + local 0*32 + 31 = 31.
        //   Block 31/64 = 0, bit 31 % 64 = 31.
        // Node B: global (32, 0) -> left edge of Tile 1.
        //   Tiled index = tile 1 base (1024) + local 0 = 1024.
        //   Block 1024/64 = 16, bit 0.
        let node_a = 31;
        let node_b = 1024;

        // Manually inject defects/active state on both sides of the seam.
        seed_node(&mut decoder, 0, 31);
        seed_node(&mut decoder, 16, 0);

        decoder.grow_clusters();

        // A and B are adjacent across the vertical tile seam, so growth must
        // have unioned them into one cluster.
        let root_a = decoder.find(node_a);
        let root_b = decoder.find(node_b);
        assert_eq!(
            root_a, root_b,
            "Horizontal stitching failed between Tile 0 and Tile 1"
        );
    }

    #[test]
    fn test_tiled_vertical_stitching() {
        let mut memory = std::vec![0u8; 1024 * 1024 * 16];
        let mut arena = Arena::new(&mut memory);

        // 32x64 grid: 1 tile wide, 2 tiles high.
        let width = 32;
        let height = 64;
        let mut decoder = TiledDecodingState::<SquareGrid>::new(&mut arena, width, height);

        // Node A: global (0, 31) -> bottom edge of Tile 0.
        //   Tiled index = local 31*32 + 0 = 992.
        //   Block 992/64 = 15, bit 992 % 64 = 32.
        // Node B: global (0, 32) -> top edge of Tile 1.
        //   Tiled index = tile 1 base (1024) + local 0 = 1024.
        //   Block 16, bit 0.
        let node_a = 992;
        let node_b = 1024;

        // Manually inject defects/active state on both sides of the seam.
        seed_node(&mut decoder, 15, 32);
        seed_node(&mut decoder, 16, 0);

        decoder.grow_clusters();

        // A and B are adjacent across the horizontal tile seam, so growth
        // must have unioned them into one cluster.
        let root_a = decoder.find(node_a);
        let root_b = decoder.find(node_b);
        assert_eq!(
            root_a, root_b,
            "Vertical stitching failed between Tile 0 and Tile 1"
        );
    }
}