// prav_core/decoder/peeling/mod.rs
#![allow(unsafe_op_in_unsafe_fn)]
use crate::decoder::state::DecodingState;
use crate::decoder::types::EdgeCorrection;
use crate::intrinsics::tzcnt;
use crate::topology::Topology;
37pub trait Peeling {
61 fn decode(&mut self, corrections: &mut [EdgeCorrection]) -> usize;
74
75 fn peel_forest(&mut self, corrections: &mut [EdgeCorrection]) -> usize;
96
97 fn reconstruct_corrections(&mut self, corrections: &mut [EdgeCorrection]) -> usize;
110
111 fn trace_path(&mut self, u: u32, boundary_node: u32);
121
122 fn trace_bfs(&mut self, u: u32, v: u32, mask: u64);
133
134 fn trace_bitmask_bfs(&mut self, start_node: u32);
143
144 fn trace_manhattan(&mut self, u: u32, v: u32);
154
155 fn emit_linear(&mut self, u: u32, v: u32);
165
166 fn get_coord(&self, u: u32) -> (usize, usize, usize);
178}
179
180impl<'a, T: Topology, const STRIDE_Y: usize> Peeling for DecodingState<'a, T, STRIDE_Y> {
181 fn decode(&mut self, corrections: &mut [EdgeCorrection]) -> usize {
182 self.grow_clusters();
183 self.peel_forest(corrections)
184 }
185
186 fn peel_forest(&mut self, corrections: &mut [EdgeCorrection]) -> usize {
187 self.path_mark.fill(0);
188 let boundary_node = (self.parents.len() - 1) as u32;
189 let _is_small_grid = STRIDE_Y <= 32 && self.blocks_state.len() <= 17;
190
191 for blk_idx in 0..self.defect_mask.len() {
192 let mut word = unsafe { *self.defect_mask.get_unchecked(blk_idx) };
193 if word == 0 {
194 continue;
195 }
196
197 let base_node = blk_idx * 64;
198 while word != 0 {
199 let bit = tzcnt(word) as usize;
200 word &= word - 1;
201 let u = (base_node + bit) as u32;
202
203 if _is_small_grid {
204 let root = self.find(u);
205 if root == u && root != boundary_node {
206 let occ = unsafe { self.blocks_state.get_unchecked(blk_idx).occupied };
207 if (occ & (1 << bit)) != 0 {
208 self.trace_bitmask_bfs(u);
209 continue;
210 }
211 }
212 }
213
214 self.trace_path(u, boundary_node);
215 }
216 }
217
218 for blk_idx in 0..self.path_mark.len() {
219 let mut word = unsafe { *self.path_mark.get_unchecked(blk_idx) };
220 if word == 0 {
221 continue;
222 }
223
224 let base_node = blk_idx * 64;
225 while word != 0 {
226 let bit = tzcnt(word) as usize;
227 word &= word - 1;
228 let u = (base_node + bit) as u32;
229 let v = unsafe { *self.parents.get_unchecked(u as usize) };
230
231 if u != v {
232 self.trace_manhattan(u, v);
233 }
234 }
235 }
236
237 self.reconstruct_corrections(corrections)
238 }
239
240 fn reconstruct_corrections(&mut self, corrections: &mut [EdgeCorrection]) -> usize {
241 let mut count = 0;
242
243 for i in 0..self.edge_dirty_count {
244 let word_idx = unsafe { *self.edge_dirty_list.get_unchecked(i) } as usize;
245
246 let mask_idx = word_idx >> 6;
248 let mask_bit = word_idx & 63;
249 unsafe {
250 *self.edge_dirty_mask.get_unchecked_mut(mask_idx) &= !(1 << mask_bit);
251 }
252
253 let word_ptr = unsafe { self.edge_bitmap.get_unchecked_mut(word_idx) };
254 let mut word = *word_ptr;
255 *word_ptr = 0;
256
257 let base_idx = word_idx * 64;
258 while word != 0 {
259 let bit = tzcnt(word) as usize;
260 word &= word - 1;
261
262 let global_idx = base_idx + bit;
263 let u = (global_idx / 3) as u32;
264 let dir = global_idx % 3;
265
266 let v = match dir {
267 0 => u + 1,
268 1 => u + self.stride_y as u32,
269 2 => u + self.graph.stride_z as u32,
270 _ => unsafe { core::hint::unreachable_unchecked() },
271 };
272
273 if count < corrections.len() {
274 unsafe {
275 *corrections.get_unchecked_mut(count) = EdgeCorrection { u, v };
276 }
277 count += 1;
278 }
279 }
280 }
281 self.edge_dirty_count = 0;
282
283 for i in 0..self.boundary_dirty_count {
284 let blk_idx = unsafe { *self.boundary_dirty_list.get_unchecked(i) } as usize;
285
286 let mask_idx = blk_idx >> 6;
288 let mask_bit = blk_idx & 63;
289 unsafe {
290 *self.boundary_dirty_mask.get_unchecked_mut(mask_idx) &= !(1 << mask_bit);
291 }
292
293 let word_ptr = unsafe { self.boundary_bitmap.get_unchecked_mut(blk_idx) };
294 let mut word = *word_ptr;
295 *word_ptr = 0;
296
297 let base_u = blk_idx * 64;
298 while word != 0 {
299 let bit = tzcnt(word) as usize;
300 word &= word - 1;
301 let u = (base_u + bit) as u32;
302 if count < corrections.len() {
303 unsafe {
304 *corrections.get_unchecked_mut(count) = EdgeCorrection { u, v: u32::MAX };
305 }
306 count += 1;
307 }
308 }
309 }
310 self.boundary_dirty_count = 0;
311 count
312 }
313
314 fn trace_path(&mut self, u: u32, _boundary_node: u32) {
315 let mut curr = u;
316 loop {
317 let next = unsafe { *self.parents.get_unchecked(curr as usize) };
318
319 if curr == next {
320 break;
321 }
322
323 let blk = (curr as usize) / 64;
324 let bit = (curr as usize) % 64;
325 let mark_ptr = unsafe { self.path_mark.get_unchecked_mut(blk) };
326 let mask = 1 << bit;
327
328 *mark_ptr ^= mask;
329
330 curr = next;
331 }
332 }
333
334 fn trace_bfs(&mut self, u: u32, v: u32, mask: u64) {
335 if u == v {
336 return;
337 }
338
339 let u_local = (u % 64) as usize;
340 let v_local = (v % 64) as usize;
341
342 let mut pred = [64u8; 64];
343 let mut visited = 0u64;
344 let mut queue = 0u64;
345
346 visited |= 1 << u_local;
347 queue |= 1 << u_local;
348 pred[u_local] = u_local as u8;
349
350 let stride_y = self.stride_y;
351 let do_vertical = stride_y < 64;
352
353 let mut found = false;
354
355 while queue != 0 {
356 let curr_bit = tzcnt(queue) as usize;
357 queue &= queue - 1;
358
359 if curr_bit == v_local {
360 found = true;
361 break;
362 }
363
364 if curr_bit > 0 && (self.row_start_mask & (1 << curr_bit)) == 0 {
365 try_queue(
366 curr_bit - 1,
367 curr_bit,
368 mask,
369 &mut visited,
370 &mut queue,
371 &mut pred,
372 );
373 }
374 if curr_bit < 63 && (self.row_end_mask & (1 << curr_bit)) == 0 {
375 try_queue(
376 curr_bit + 1,
377 curr_bit,
378 mask,
379 &mut visited,
380 &mut queue,
381 &mut pred,
382 );
383 }
384 if do_vertical {
385 if curr_bit >= stride_y {
386 try_queue(
387 curr_bit - stride_y,
388 curr_bit,
389 mask,
390 &mut visited,
391 &mut queue,
392 &mut pred,
393 );
394 }
395 if curr_bit + stride_y < 64 {
396 try_queue(
397 curr_bit + stride_y,
398 curr_bit,
399 mask,
400 &mut visited,
401 &mut queue,
402 &mut pred,
403 );
404 }
405 }
406 }
407
408 if found {
409 let mut curr = v_local;
410 let base = (u / 64) * 64;
411 while curr != u_local {
412 let p = pred[curr] as usize;
413 let u_abs = base + p as u32;
414 let v_abs = base + curr as u32;
415 self.emit_linear(u_abs, v_abs);
416 curr = p;
417 }
418 } else {
419 self.trace_manhattan(u, v);
420 }
421 }
422
423 fn trace_bitmask_bfs(&mut self, start_node: u32) {
424 if STRIDE_Y <= 32 {
425 let mut visited = [0u64; 17];
426 self.trace_bitmask_bfs_impl(start_node, &mut visited);
427 } else {
428 let mut visited = [0u64; 65];
429 self.trace_bitmask_bfs_impl(start_node, &mut visited);
430 }
431 }
432
433 fn trace_manhattan(&mut self, u: u32, v: u32) {
434 if u == v {
435 return;
436 }
437
438 let boundary_node = (self.parents.len() - 1) as u32;
439 if u == boundary_node {
440 self.emit_linear(v, u32::MAX);
441 return;
442 }
443 if v == boundary_node {
444 self.emit_linear(u, u32::MAX);
445 return;
446 }
447
448 let (ux, uy, uz) = self.get_coord(u);
449 let (vx, vy, vz) = self.get_coord(v);
450
451 let dx = ux.abs_diff(vx);
452 let dy = uy.abs_diff(vy);
453 let dz = uz.abs_diff(vz);
454
455 let mut curr_idx = u as usize;
456
457 if dx > 0 {
458 let stride = self.graph.stride_x;
459 let step = if ux < vx {
460 stride as isize
461 } else {
462 -(stride as isize)
463 };
464 for _ in 0..dx {
465 let next_idx = (curr_idx as isize + step) as usize;
466 self.emit_linear(curr_idx as u32, next_idx as u32);
467 curr_idx = next_idx;
468 }
469 }
470
471 if dy > 0 {
472 let stride = self.stride_y;
473 let step = if uy < vy {
474 stride as isize
475 } else {
476 -(stride as isize)
477 };
478 for _ in 0..dy {
479 let next_idx = (curr_idx as isize + step) as usize;
480 self.emit_linear(curr_idx as u32, next_idx as u32);
481 curr_idx = next_idx;
482 }
483 }
484
485 if dz > 0 {
486 let stride = self.graph.stride_z;
487 let step = if uz < vz {
488 stride as isize
489 } else {
490 -(stride as isize)
491 };
492 for _ in 0..dz {
493 let next_idx = (curr_idx as isize + step) as usize;
494 self.emit_linear(curr_idx as u32, next_idx as u32);
495 curr_idx = next_idx;
496 }
497 }
498 }
499
500 fn emit_linear(&mut self, u: u32, v: u32) {
501 if v == u32::MAX {
502 let blk_idx = (u as usize) / 64;
503 let bit_idx = (u as usize) % 64;
504
505 let mask_idx = blk_idx >> 6;
506 let mask_bit = blk_idx & 63;
507 let m_ptr = unsafe { self.boundary_dirty_mask.get_unchecked_mut(mask_idx) };
508 if (*m_ptr & (1 << mask_bit)) == 0 {
509 *m_ptr |= 1 << mask_bit;
510 unsafe {
511 *self
512 .boundary_dirty_list
513 .get_unchecked_mut(self.boundary_dirty_count) = blk_idx as u32;
514 }
515 self.boundary_dirty_count += 1;
516 }
517 let word_ptr = unsafe { self.boundary_bitmap.get_unchecked_mut(blk_idx) };
518 *word_ptr ^= 1 << bit_idx;
519 return;
520 }
521
522 let (u, v) = if u < v { (u, v) } else { (v, u) };
523 let diff = v - u;
524
525 let dir = if diff == 1 {
526 0
527 } else if diff == self.stride_y as u32 {
528 1
529 } else if diff == self.graph.stride_z as u32 {
530 2
531 } else {
532 return;
533 };
534
535 let idx = (u as usize) * 3 + dir;
536 let word_idx = idx / 64;
537 let bit_idx = idx % 64;
538
539 let mask_idx = word_idx >> 6;
540 let mask_bit = word_idx & 63;
541 let m_ptr = unsafe { self.edge_dirty_mask.get_unchecked_mut(mask_idx) };
542 if (*m_ptr & (1 << mask_bit)) == 0 {
543 *m_ptr |= 1 << mask_bit;
544 unsafe {
545 *self
546 .edge_dirty_list
547 .get_unchecked_mut(self.edge_dirty_count) = word_idx as u32;
548 }
549 self.edge_dirty_count += 1;
550 }
551 let word_ptr = unsafe { self.edge_bitmap.get_unchecked_mut(word_idx) };
552 *word_ptr ^= 1 << bit_idx;
553 }
554
555 fn get_coord(&self, u: u32) -> (usize, usize, usize) {
556 let u = u as usize;
557 if self.graph.depth > 1 {
558 let z = u >> self.graph.shift_z;
559 let rem = u & (self.graph.stride_z - 1);
560 let y = rem >> self.graph.shift_y;
561 let x = rem & (self.stride_y - 1);
562 (x, y, z)
563 } else {
564 let y = u >> self.graph.shift_y;
565 let x = u & (self.stride_y - 1);
566 (x, y, 0)
567 }
568 }
569}
570
571impl<'a, T: Topology, const STRIDE_Y: usize> DecodingState<'a, T, STRIDE_Y> {
572 #[inline(always)]
573 fn trace_bitmask_bfs_impl(&mut self, start_node: u32, visited: &mut [u64]) {
574 let visited_len = self.blocks_state.len().min(visited.len());
576
577 let bfs_pred = &mut *self.bfs_pred;
579 let bfs_queue = &mut *self.bfs_queue;
580 let edge_bitmap = &mut *self.edge_bitmap;
581 let edge_dirty_list = &mut *self.edge_dirty_list;
582 let edge_dirty_count = &mut self.edge_dirty_count;
583 let boundary_bitmap = &mut *self.boundary_bitmap;
584 let boundary_dirty_list = &mut *self.boundary_dirty_list;
585 let boundary_dirty_count = &mut self.boundary_dirty_count;
586
587 let blocks_state = &*self.blocks_state;
588
589 let stride_y = self.stride_y as u32;
590 let stride_z = self.graph.stride_z as u32;
591 let shift_y = self.graph.shift_y;
592 let shift_z = self.graph.shift_z;
593 let mask_y = self.stride_y - 1;
594 let mask_z = self.graph.stride_z - 1;
595
596 let width = self.width;
597 let height = self.height;
598 let depth = self.graph.depth;
599 let is_3d = depth > 1;
600
601 let mut head = 0;
602 let mut tail = 0;
603
604 let start_blk = (start_node as usize) / 64;
605 let start_bit = (start_node as usize) % 64;
606
607 if start_blk < visited.len() {
608 visited[start_blk] |= 1 << start_bit;
609 }
610
611 bfs_queue[tail] = start_node as u16;
612 tail += 1;
613
614 let mut boundary_hit = u32::MAX;
615
616 let mut emit_linear = |u: u32, v: u32| {
617 if v == u32::MAX {
618 let blk_idx = (u as usize) / 64;
619 let bit_idx = (u as usize) % 64;
620
621 let word_ptr = unsafe { boundary_bitmap.get_unchecked_mut(blk_idx) };
622 if *word_ptr == 0 {
623 unsafe {
624 *boundary_dirty_list.get_unchecked_mut(*boundary_dirty_count) =
625 blk_idx as u32;
626 }
627 *boundary_dirty_count += 1;
628 }
629 *word_ptr ^= 1 << bit_idx;
630 return;
631 }
632
633 let (u, v) = if u < v { (u, v) } else { (v, u) };
634 let diff = v - u;
635
636 let dir = if diff == 1 {
637 0
638 } else if diff == stride_y {
639 1
640 } else if diff == stride_z {
641 2
642 } else {
643 return;
644 };
645
646 let idx = (u as usize) * 3 + dir;
647 let word_idx = idx / 64;
648 let bit_idx = idx % 64;
649
650 let word_ptr = unsafe { edge_bitmap.get_unchecked_mut(word_idx) };
651 if *word_ptr == 0 {
652 unsafe {
653 *edge_dirty_list.get_unchecked_mut(*edge_dirty_count) = word_idx as u32;
654 }
655 *edge_dirty_count += 1;
656 }
657 *word_ptr ^= 1 << bit_idx;
658 };
659
660 if STRIDE_Y == 32 && !is_3d {
662 while head != tail {
663 let u = bfs_queue[head] as u32;
664 head += 1;
665
666 let x = u & 31;
668 let y = u >> 5;
669 if x == 0 || x == (width as u32 - 1) || y == 0 || y == (height as u32 - 1) {
670 boundary_hit = u;
671 break;
672 }
673
674 let n = u - 1;
676 let n_blk = (n as usize) >> 6;
677 let n_bit = (n as usize) & 63;
678 if n_blk < 17 {
679 let n_occ = unsafe { blocks_state.get_unchecked(n_blk).occupied };
680 if (n_occ & (1 << n_bit)) != 0 && (visited[n_blk] & (1 << n_bit)) == 0 {
681 visited[n_blk] |= 1 << n_bit;
682 bfs_pred[n as usize] = u as u16;
683 bfs_queue[tail] = n as u16;
684 tail += 1;
685 }
686 }
687
688 let n = u + 1;
690 let n_blk = (n as usize) >> 6;
691 let n_bit = (n as usize) & 63;
692 if n_blk < 17 {
693 let n_occ = unsafe { blocks_state.get_unchecked(n_blk).occupied };
694 if (n_occ & (1 << n_bit)) != 0 && (visited[n_blk] & (1 << n_bit)) == 0 {
695 visited[n_blk] |= 1 << n_bit;
696 bfs_pred[n as usize] = u as u16;
697 bfs_queue[tail] = n as u16;
698 tail += 1;
699 }
700 }
701
702 let n = u - 32;
704 let n_blk = (n as usize) >> 6;
705 let n_bit = (n as usize) & 63;
706 if n_blk < 17 {
707 let n_occ = unsafe { blocks_state.get_unchecked(n_blk).occupied };
708 if (n_occ & (1 << n_bit)) != 0 && (visited[n_blk] & (1 << n_bit)) == 0 {
709 visited[n_blk] |= 1 << n_bit;
710 bfs_pred[n as usize] = u as u16;
711 bfs_queue[tail] = n as u16;
712 tail += 1;
713 }
714 }
715
716 let n = u + 32;
718 let n_blk = (n as usize) >> 6;
719 let n_bit = (n as usize) & 63;
720 if n_blk < 17 {
721 let n_occ = unsafe { blocks_state.get_unchecked(n_blk).occupied };
722 if (n_occ & (1 << n_bit)) != 0 && (visited[n_blk] & (1 << n_bit)) == 0 {
723 visited[n_blk] |= 1 << n_bit;
724 bfs_pred[n as usize] = u as u16;
725 bfs_queue[tail] = n as u16;
726 tail += 1;
727 }
728 }
729 }
730 } else {
731 while head != tail {
733 let u = bfs_queue[head] as u32;
734 head += 1;
735
736 let (x, y, z) = if is_3d {
738 let z_coord = (u as usize) >> shift_z;
739 let rem = (u as usize) & mask_z;
740 let y_coord = rem >> shift_y;
741 let x_coord = rem & mask_y;
742 (x_coord, y_coord, z_coord)
743 } else {
744 let y_coord = (u as usize) >> shift_y;
745 let x_coord = (u as usize) & mask_y;
746 (x_coord, y_coord, 0)
747 };
748
749 if x == 0
750 || x == width - 1
751 || y == 0
752 || y == height - 1
753 || (is_3d && (z == 0 || z == depth - 1))
754 {
755 boundary_hit = u;
756 break;
757 }
758
759 if x > 0 {
761 let n = u - 1;
762 let n_blk = (n as usize) / 64;
763 let n_bit = (n as usize) % 64;
764 if n_blk < visited.len() {
765 let n_occ = blocks_state[n_blk].occupied;
766 if (n_occ & (1 << n_bit)) != 0 && (visited[n_blk] & (1 << n_bit)) == 0 {
767 visited[n_blk] |= 1 << n_bit;
768 bfs_pred[n as usize] = u as u16;
769 bfs_queue[tail] = n as u16;
770 tail += 1;
771 }
772 }
773 }
774 if x < width - 1 {
776 let n = u + 1;
777 let n_blk = (n as usize) / 64;
778 let n_bit = (n as usize) % 64;
779 if n_blk < visited.len() {
780 let n_occ = blocks_state[n_blk].occupied;
781 if (n_occ & (1 << n_bit)) != 0 && (visited[n_blk] & (1 << n_bit)) == 0 {
782 visited[n_blk] |= 1 << n_bit;
783 bfs_pred[n as usize] = u as u16;
784 bfs_queue[tail] = n as u16;
785 tail += 1;
786 }
787 }
788 }
789 if y > 0 {
791 let n = u - stride_y;
792 let n_blk = (n as usize) / 64;
793 let n_bit = (n as usize) % 64;
794 if n_blk < visited.len() {
795 let n_occ = blocks_state[n_blk].occupied;
796 if (n_occ & (1 << n_bit)) != 0 && (visited[n_blk] & (1 << n_bit)) == 0 {
797 visited[n_blk] |= 1 << n_bit;
798 bfs_pred[n as usize] = u as u16;
799 bfs_queue[tail] = n as u16;
800 tail += 1;
801 }
802 }
803 }
804 if y < height - 1 {
806 let n = u + stride_y;
807 let n_blk = (n as usize) / 64;
808 let n_bit = (n as usize) % 64;
809 if n_blk < visited.len() {
810 let n_occ = blocks_state[n_blk].occupied;
811 if (n_occ & (1 << n_bit)) != 0 && (visited[n_blk] & (1 << n_bit)) == 0 {
812 visited[n_blk] |= 1 << n_bit;
813 bfs_pred[n as usize] = u as u16;
814 bfs_queue[tail] = n as u16;
815 tail += 1;
816 }
817 }
818 }
819 if is_3d {
821 if z > 0 {
822 let n = u - stride_z;
823 let n_blk = (n as usize) / 64;
824 let n_bit = (n as usize) % 64;
825 if n_blk < visited.len() {
826 let n_occ = blocks_state[n_blk].occupied;
827 if (n_occ & (1 << n_bit)) != 0 && (visited[n_blk] & (1 << n_bit)) == 0 {
828 visited[n_blk] |= 1 << n_bit;
829 bfs_pred[n as usize] = u as u16;
830 bfs_queue[tail] = n as u16;
831 tail += 1;
832 }
833 }
834 }
835 if z < depth - 1 {
836 let n = u + stride_z;
837 let n_blk = (n as usize) / 64;
838 let n_bit = (n as usize) % 64;
839 if n_blk < visited.len() {
840 let n_occ = blocks_state[n_blk].occupied;
841 if (n_occ & (1 << n_bit)) != 0 && (visited[n_blk] & (1 << n_bit)) == 0 {
842 visited[n_blk] |= 1 << n_bit;
843 bfs_pred[n as usize] = u as u16;
844 bfs_queue[tail] = n as u16;
845 tail += 1;
846 }
847 }
848 }
849 }
850 }
851 }
852
853 if boundary_hit != u32::MAX {
854 let mut curr = boundary_hit;
855 emit_linear(curr, u32::MAX);
856
857 while curr != start_node {
858 let p = bfs_pred[curr as usize];
859 if p == u16::MAX {
860 break;
861 }
862 emit_linear(p as u32, curr);
863 curr = p as u32;
864 }
865 }
866
867 #[allow(clippy::needless_range_loop)]
869 for i in 0..visited_len {
870 if visited[i] != 0 && let Some(dm) = self.defect_mask.get_mut(i) {
871 *dm &= !visited[i];
872 }
873 }
874 }
875}
876
/// BFS helper for the single-word bitmask search: enqueues `next` with
/// predecessor `curr` if it is traversable (set in `mask`) and not yet
/// visited; otherwise does nothing.
#[inline(always)]
fn try_queue(
    next: usize,
    curr: usize,
    mask: u64,
    visited: &mut u64,
    queue: &mut u64,
    pred: &mut [u8; 64],
) {
    let bit = 1u64 << next;
    let traversable = (mask & bit) != 0;
    let unseen = (*visited & bit) == 0;
    if traversable && unseen {
        *visited |= bit;
        *queue |= bit;
        pred[next] = curr as u8;
    }
}
892
893
894
#[cfg(kani)]
mod kani_proofs;
898
899