p3_matrix/
bitrev.rs

1use p3_util::{log2_strict_usize, reverse_bits_len};
2
3use crate::Matrix;
4use crate::dense::{DenseMatrix, DenseStorage, RowMajorMatrix};
5use crate::row_index_mapped::{RowIndexMap, RowIndexMappedView};
6use crate::util::reverse_matrix_index_bits;
7
8/// A trait for matrices that support *bit-reversed row reordering*.
9///
10/// Implementors of this trait can switch between row-major order and bit-reversed
11/// row order (i.e., reversing the binary representation of each row index).
12///
13/// This trait allows interoperability between regular matrices and views
14/// that access their rows in a bit-reversed order.
15pub trait BitReversibleMatrix<T: Send + Sync + Clone>: Matrix<T> {
16    /// The type returned when this matrix is viewed in bit-reversed order.
17    type BitRev: BitReversibleMatrix<T>;
18
19    /// Return a version of the matrix with its row order reversed by bit index.
20    fn bit_reverse_rows(self) -> Self::BitRev;
21}
22
/// A row index permutation that sends every row index to its bit-reversed
/// counterpart.
///
/// This is the index map behind `BitReversedMatrixView`. It holds a single
/// `usize`, so it is cheap to copy and can be compared for equality.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct BitReversalPerm {
    /// The base-2 logarithm of the matrix height: for height `h`, this is `log2(h)`.
    ///
    /// The logarithm must be exact, so the height has to be a power of two.
    log_height: usize,
}
33
34impl BitReversalPerm {
35    /// Create a new bit-reversal view over the given matrix.
36    ///
37    /// # Panics
38    /// Panics if the height of the matrix is not a power of two.
39    ///
40    /// # Arguments
41    /// - `inner`: The matrix to wrap in a bit-reversed row view.
42    ///
43    /// # Returns
44    /// A `BitReversedMatrixView` that wraps the input with row reordering.
45    pub fn new_view<T: Send + Sync + Clone, Inner: Matrix<T>>(
46        inner: Inner,
47    ) -> BitReversedMatrixView<Inner> {
48        RowIndexMappedView {
49            index_map: Self {
50                log_height: log2_strict_usize(inner.height()),
51            },
52            inner,
53        }
54    }
55}
56
57impl RowIndexMap for BitReversalPerm {
58    fn height(&self) -> usize {
59        1 << self.log_height
60    }
61
62    fn map_row_index(&self, r: usize) -> usize {
63        reverse_bits_len(r, self.log_height)
64    }
65
66    // This might not be more efficient than the lazy generic impl
67    // if we have a nested view.
68    fn to_row_major_matrix<T: Clone + Send + Sync, Inner: Matrix<T>>(
69        &self,
70        inner: Inner,
71    ) -> RowMajorMatrix<T> {
72        let mut inner = inner.to_row_major_matrix();
73        reverse_matrix_index_bits(&mut inner);
74        inner
75    }
76}
77
78/// A matrix view that reorders its rows using bit-reversal.
79///
80/// This type is produced by applying `BitReversibleMatrix::bit_reverse_rows()`
81/// to a `DenseMatrix`.
82pub type BitReversedMatrixView<Inner> = RowIndexMappedView<BitReversalPerm, Inner>;
83
84impl<T: Clone + Send + Sync, S: DenseStorage<T>> BitReversibleMatrix<T>
85    for BitReversedMatrixView<DenseMatrix<T, S>>
86{
87    type BitRev = DenseMatrix<T, S>;
88
89    fn bit_reverse_rows(self) -> Self::BitRev {
90        self.inner
91    }
92}
93
94impl<T: Clone + Send + Sync, S: DenseStorage<T>> BitReversibleMatrix<T> for DenseMatrix<T, S> {
95    type BitRev = BitReversedMatrixView<Self>;
96
97    fn bit_reverse_rows(self) -> Self::BitRev {
98        BitReversalPerm::new_view(self)
99    }
100}
101
#[cfg(test)]
mod tests {
    use alloc::vec;
    use alloc::vec::Vec;

    use super::*;

    /// Single-column matrix of height 8 (2^3) holding the values `0..8` in order.
    fn counting_column() -> RowMajorMatrix<u32> {
        RowMajorMatrix::new((0u32..8).collect::<Vec<_>>(), 1)
    }

    #[test]
    fn test_bit_reversal_perm_map_index() {
        let perm = BitReversalPerm { log_height: 3 }; // height = 8

        // (row index, expected bit-reversed index) over three bits.
        let cases = [
            (0b000, 0b000),
            (0b001, 0b100),
            (0b010, 0b010),
            (0b011, 0b110),
            (0b100, 0b001),
            (0b101, 0b101),
            (0b110, 0b011),
            (0b111, 0b111),
        ];
        for (r, expected) in cases {
            assert_eq!(perm.map_row_index(r), expected);
        }
    }

    #[test]
    fn test_bit_reversal_perm_height() {
        // 2^3 rows.
        assert_eq!(BitReversalPerm { log_height: 3 }.height(), 8);
    }

    #[test]
    fn test_new_view_reverses_indices_correctly() {
        let bitrev = BitReversalPerm::new_view(counting_column());

        // Row `i` of the view must be the row stored at the bit-reversed index.
        let expected = [0, 4, 2, 6, 1, 5, 3, 7];
        for (i, expected_row_idx) in expected.iter().copied().enumerate() {
            let row: Vec<_> = bitrev.row(i).unwrap().into_iter().collect();
            assert_eq!(row, vec![expected_row_idx]);
        }
    }

    #[test]
    fn test_to_row_major_matrix_applies_reverse_matrix_index_bits() {
        let perm = BitReversalPerm { log_height: 3 };

        let reordered = perm.to_row_major_matrix(counting_column());

        // Values appear in bit-reversed row order.
        assert_eq!(reordered.values, vec![0, 4, 2, 6, 1, 5, 3, 7]);
    }

    #[test]
    fn test_bit_reversible_matrix_trait_forward_reverse() {
        let original = counting_column();

        // DenseMatrix -> BitReversedMatrixView -> DenseMatrix.
        let round_trip = original.clone().bit_reverse_rows().bit_reverse_rows();

        assert_eq!(original.values, round_trip.values);
        assert_eq!(original.width(), round_trip.width());
    }

    #[test]
    #[should_panic]
    fn test_new_view_panics_non_pow2_height() {
        // Height 3 is not a power of two, so building the view must panic.
        let matrix = RowMajorMatrix::new(vec![1, 2, 3], 1);
        let _ = BitReversalPerm::new_view(matrix);
    }
}