1use p3_util::{log2_strict_usize, reverse_bits_len};
2
3use crate::Matrix;
4use crate::dense::{DenseMatrix, DenseStorage, RowMajorMatrix};
5use crate::row_index_mapped::{RowIndexMap, RowIndexMappedView};
6use crate::util::reverse_matrix_index_bits;
7
8pub trait BitReversibleMatrix<T: Send + Sync + Clone>: Matrix<T> {
16 type BitRev: BitReversibleMatrix<T>;
18
19 fn bit_reverse_rows(self) -> Self::BitRev;
21}
22
/// Row-index permutation that sends row `r` to the bit reversal of `r` over
/// `log_height` bits.
///
/// It is the index map behind [`BitReversedMatrixView`]. Bit reversal is an involution,
/// so applying the permutation twice restores the original row order.
///
/// The type is a single `usize`, so it derives the common value traits. This also lets
/// views built on it be cloned or copied whenever the wrapped matrix can be.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub struct BitReversalPerm {
    /// Base-2 logarithm of the number of rows; the permutation acts on `1 << log_height` rows.
    log_height: usize,
}
33
34impl BitReversalPerm {
35 pub fn new_view<T: Send + Sync + Clone, Inner: Matrix<T>>(
46 inner: Inner,
47 ) -> BitReversedMatrixView<Inner> {
48 RowIndexMappedView {
49 index_map: Self {
50 log_height: log2_strict_usize(inner.height()),
51 },
52 inner,
53 }
54 }
55}
56
57impl RowIndexMap for BitReversalPerm {
58 fn height(&self) -> usize {
59 1 << self.log_height
60 }
61
62 fn map_row_index(&self, r: usize) -> usize {
63 reverse_bits_len(r, self.log_height)
64 }
65
66 fn to_row_major_matrix<T: Clone + Send + Sync, Inner: Matrix<T>>(
69 &self,
70 inner: Inner,
71 ) -> RowMajorMatrix<T> {
72 let mut inner = inner.to_row_major_matrix();
73 reverse_matrix_index_bits(&mut inner);
74 inner
75 }
76}
77
78pub type BitReversedMatrixView<Inner> = RowIndexMappedView<BitReversalPerm, Inner>;
83
84impl<T: Clone + Send + Sync, S: DenseStorage<T>> BitReversibleMatrix<T>
85 for BitReversedMatrixView<DenseMatrix<T, S>>
86{
87 type BitRev = DenseMatrix<T, S>;
88
89 fn bit_reverse_rows(self) -> Self::BitRev {
90 self.inner
91 }
92}
93
94impl<T: Clone + Send + Sync, S: DenseStorage<T>> BitReversibleMatrix<T> for DenseMatrix<T, S> {
95 type BitRev = BitReversedMatrixView<Self>;
96
97 fn bit_reverse_rows(self) -> Self::BitRev {
98 BitReversalPerm::new_view(self)
99 }
100}
101
#[cfg(test)]
mod tests {
    use alloc::vec;
    use alloc::vec::Vec;

    use super::*;

    /// Bit reversal of `0..8` over 3 bits: the source row read by each row of a
    /// height-8 bit-reversed view.
    const BITREV_8: [usize; 8] = [0, 4, 2, 6, 1, 5, 3, 7];

    #[test]
    fn test_bit_reversal_perm_map_index() {
        let perm = BitReversalPerm { log_height: 3 };
        for (r, &expected) in BITREV_8.iter().enumerate() {
            assert_eq!(perm.map_row_index(r), expected, "row {r}");
        }
    }

    #[test]
    fn test_bit_reversal_perm_is_involution() {
        let perm = BitReversalPerm { log_height: 3 };
        for r in 0..perm.height() {
            assert_eq!(perm.map_row_index(perm.map_row_index(r)), r, "row {r}");
        }
    }

    #[test]
    fn test_bit_reversal_perm_height() {
        let perm = BitReversalPerm { log_height: 3 };
        assert_eq!(perm.height(), 8);
    }

    #[test]
    fn test_new_view_reverses_indices_correctly() {
        // One column holding the row index, so each row's value names its source row.
        let matrix = RowMajorMatrix::new((0usize..8).collect::<Vec<_>>(), 1);
        let bitrev = BitReversalPerm::new_view(matrix);

        assert_eq!(bitrev.height(), 8);
        assert_eq!(bitrev.width(), 1);
        for (i, &expected_row_idx) in BITREV_8.iter().enumerate() {
            let row: Vec<_> = bitrev.row(i).unwrap().into_iter().collect();
            assert_eq!(row, vec![expected_row_idx]);
        }
    }

    #[test]
    fn test_new_view_permutes_whole_rows() {
        // Width 2: row j holds [2j, 2j + 1]. A permutation of individual elements rather
        // than whole rows would pass the width-1 tests but fail here.
        let matrix = RowMajorMatrix::new((0usize..16).collect::<Vec<_>>(), 2);
        let bitrev = BitReversalPerm::new_view(matrix);

        assert_eq!(bitrev.width(), 2);
        for (i, &src) in BITREV_8.iter().enumerate() {
            let row: Vec<_> = bitrev.row(i).unwrap().into_iter().collect();
            assert_eq!(row, vec![2 * src, 2 * src + 1]);
        }
    }

    #[test]
    fn test_to_row_major_matrix_applies_reverse_matrix_index_bits() {
        let matrix = RowMajorMatrix::new((0u32..8).collect::<Vec<_>>(), 1);
        let perm = BitReversalPerm { log_height: 3 };

        let reordered = perm.to_row_major_matrix(matrix);
        let expected_values = vec![0, 4, 2, 6, 1, 5, 3, 7];
        assert_eq!(reordered.values, expected_values);
    }

    #[test]
    fn test_to_row_major_matrix_moves_whole_rows() {
        let matrix = RowMajorMatrix::new((0usize..16).collect::<Vec<_>>(), 2);
        let perm = BitReversalPerm { log_height: 3 };

        let reordered = perm.to_row_major_matrix(matrix);
        let expected: Vec<usize> = BITREV_8
            .iter()
            .flat_map(|&src| [2 * src, 2 * src + 1])
            .collect();
        assert_eq!(reordered.values, expected);
        assert_eq!(reordered.width(), 2);
    }

    #[test]
    fn test_bit_reversible_matrix_trait_forward_reverse() {
        let original = RowMajorMatrix::new((0u32..8).collect::<Vec<_>>(), 1);
        let reversed_view = original.clone().bit_reverse_rows();

        // The forward direction must actually permute rows, not just round-trip.
        let second_row: Vec<_> = reversed_view.row(1).unwrap().into_iter().collect();
        assert_eq!(second_row, vec![4]);

        let back_to_dense = reversed_view.bit_reverse_rows();
        assert_eq!(original.values, back_to_dense.values);
        assert_eq!(original.width(), back_to_dense.width());
    }

    #[test]
    #[should_panic]
    fn test_new_view_panics_non_pow2_height() {
        // Height 3 is not a power of two, so building the view must panic.
        let matrix = RowMajorMatrix::new(vec![1, 2, 3], 1);
        let _ = BitReversalPerm::new_view(matrix);
    }
}