use p3_util::{log2_strict_usize, reverse_bits_len};
use crate::Matrix;
use crate::dense::{DenseMatrix, DenseStorage, RowMajorMatrix};
use crate::row_index_mapped::{RowIndexMap, RowIndexMappedView};
use crate::util::reverse_matrix_index_bits;
/// A matrix whose rows can be rearranged into bit-reversed order.
///
/// The result of the rearrangement is itself bit-reversible, so the operation
/// can be applied repeatedly.
pub trait BitReversibleMatrix<T>: Matrix<T>
where
    T: Send + Sync + Clone,
{
    /// The matrix type obtained after bit-reversing the rows.
    type BitRev: BitReversibleMatrix<T>;

    /// Consumes `self` and returns the matrix with its rows in bit-reversed order.
    fn bit_reverse_rows(self) -> Self::BitRev;
}
/// A row-index map that sends each row index to its bit-reversal.
///
/// Indices are reversed over `log_height` bits, and the height reported by
/// the map is `1 << log_height`.
#[derive(Debug)]
pub struct BitReversalPerm {
/// Base-2 logarithm of the number of rows being permuted.
log_height: usize,
}
impl BitReversalPerm {
    /// Wraps `inner` in a view whose rows are indexed in bit-reversed order.
    ///
    /// The matrix itself is moved into the view unchanged; only the row-index
    /// mapping is attached.
    ///
    /// # Panics
    ///
    /// Panics (inside `log2_strict_usize`) when the height of `inner` is not a
    /// power of two.
    pub fn new_view<T, Inner>(inner: Inner) -> BitReversedMatrixView<Inner>
    where
        T: Send + Sync + Clone,
        Inner: Matrix<T>,
    {
        // The height is read before `inner` is moved into the view.
        let log_height = log2_strict_usize(inner.height());
        let index_map = Self { log_height };
        RowIndexMappedView { index_map, inner }
    }
}
impl RowIndexMap for BitReversalPerm {
    /// Number of rows covered by the permutation, i.e. `2^log_height`.
    fn height(&self) -> usize {
        1usize << self.log_height
    }

    /// Maps row index `r` to the index obtained by reversing its lowest
    /// `log_height` bits.
    fn map_row_index(&self, r: usize) -> usize {
        let bit_len = self.log_height;
        reverse_bits_len(r, bit_len)
    }

    /// Produces a row-major copy of `inner` with the rows physically stored in
    /// bit-reversed order.
    ///
    /// `inner` is first converted to a row-major matrix and then permuted in
    /// place by `reverse_matrix_index_bits`; the stored `log_height` is not
    /// consulted here.
    fn to_row_major_matrix<T, Inner>(&self, inner: Inner) -> RowMajorMatrix<T>
    where
        T: Clone + Send + Sync,
        Inner: Matrix<T>,
    {
        let mut dense = inner.to_row_major_matrix();
        reverse_matrix_index_bits(&mut dense);
        dense
    }
}
/// A view over `Inner` whose row indices are remapped through [`BitReversalPerm`].
pub type BitReversedMatrixView<Inner> = RowIndexMappedView<BitReversalPerm, Inner>;
impl<T, S> BitReversibleMatrix<T> for BitReversedMatrixView<DenseMatrix<T, S>>
where
    T: Clone + Send + Sync,
    S: DenseStorage<T>,
{
    /// Bit-reversing an already bit-reversed view yields the wrapped dense matrix.
    type BitRev = DenseMatrix<T, S>;

    /// Discards the index map and hands back the underlying dense matrix;
    /// no data is moved or copied.
    fn bit_reverse_rows(self) -> Self::BitRev {
        let Self { inner, .. } = self;
        inner
    }
}
impl<T, S> BitReversibleMatrix<T> for DenseMatrix<T, S>
where
    T: Clone + Send + Sync,
    S: DenseStorage<T>,
{
    /// A dense matrix is bit-reversed lazily, by wrapping it in a view.
    type BitRev = BitReversedMatrixView<Self>;

    /// Wraps the matrix in a bit-reversed view without touching its storage.
    ///
    /// # Panics
    ///
    /// Panics when the matrix height is not a power of two, because
    /// `BitReversalPerm::new_view` requires one.
    fn bit_reverse_rows(self) -> Self::BitRev {
        BitReversalPerm::new_view(self)
    }
}
#[cfg(test)]
mod tests {
    use alloc::vec;
    use alloc::vec::Vec;

    use super::*;

    /// Every 3-bit index must map to its bit-reversal.
    #[test]
    fn test_bit_reversal_perm_map_index() {
        let perm = BitReversalPerm { log_height: 3 };
        let reversed = [0, 4, 2, 6, 1, 5, 3, 7];
        for (index, &want) in reversed.iter().enumerate() {
            assert_eq!(perm.map_row_index(index), want);
        }
    }

    /// The height reported by the permutation is `2^log_height`.
    #[test]
    fn test_bit_reversal_perm_height() {
        let perm = BitReversalPerm { log_height: 3 };
        let height = perm.height();
        assert_eq!(height, 8);
    }

    /// Row `i` of the view must be row `bit_reverse(i)` of the wrapped matrix.
    #[test]
    fn test_new_view_reverses_indices_correctly() {
        let values: Vec<_> = (0u32..8).collect();
        let view = BitReversalPerm::new_view(RowMajorMatrix::new(values, 1));

        let expected_rows = [0, 4, 2, 6, 1, 5, 3, 7];
        for (row_idx, &value) in expected_rows.iter().enumerate() {
            let row: Vec<_> = view.row(row_idx).unwrap().into_iter().collect();
            assert_eq!(row, vec![value]);
        }
    }

    /// Materializing through the permutation stores the rows in bit-reversed order.
    #[test]
    fn test_to_row_major_matrix_applies_reverse_matrix_index_bits() {
        let values: Vec<_> = (0u32..8).collect();
        let source = RowMajorMatrix::new(values, 1);
        let perm = BitReversalPerm { log_height: 3 };

        let reordered = perm.to_row_major_matrix(source);

        assert_eq!(reordered.values, vec![0, 4, 2, 6, 1, 5, 3, 7]);
    }

    /// Reversing twice through the trait returns the original dense matrix.
    #[test]
    fn test_bit_reversible_matrix_trait_forward_reverse() {
        let original = RowMajorMatrix::new((0u32..8).collect::<Vec<_>>(), 1);

        let round_trip = original.clone().bit_reverse_rows().bit_reverse_rows();

        assert_eq!(original.values, round_trip.values);
        assert_eq!(original.width(), round_trip.width());
    }

    /// A height that is not a power of two must be rejected when building the view.
    #[test]
    #[should_panic]
    fn test_new_view_panics_non_pow2_height() {
        let three_rows = RowMajorMatrix::new(vec![1, 2, 3], 1);
        let _ = BitReversalPerm::new_view(three_rows);
    }
}