use core::borrow::BorrowMut;
use p3_maybe_rayon::prelude::*;
use p3_util::{log2_strict_usize, reverse_bits_len};
use tracing::instrument;
use crate::Matrix;
use crate::dense::{DenseMatrix, DenseStorage, RowMajorMatrix};
/// Reorders the rows of `mat` in place by bit-reversing their indices: the row
/// at index `i` ends up at index `reverse_bits_len(i, log2(height))`.
///
/// The row indices are driven through `into_par_iter`, so individual swaps may
/// run concurrently.
///
/// # Panics
///
/// Panics if the height of `mat` is not a power of two (the height is passed to
/// `log2_strict_usize`).
#[instrument(level = "debug", skip_all)]
pub fn reverse_matrix_index_bits<'a, F, S>(mat: &mut DenseMatrix<F, S>)
where
    F: Clone + Send + Sync + 'a,
    S: DenseStorage<F> + BorrowMut<[F]>,
{
    let row_len = mat.width();
    let num_rows = mat.height();
    let bits = log2_strict_usize(num_rows);

    // Raw pointers are neither `Send` nor `Sync`, so the buffer's base address is
    // carried into the (possibly parallel) closure as a plain integer and turned
    // back into a pointer inside each task.
    let base_addr = mat.values.borrow_mut().as_mut_ptr() as usize;

    (0..num_rows).into_par_iter().for_each(move |row| {
        let target = reverse_bits_len(row, bits);
        // Each pair {row, target} shows up twice in the iteration. Only the
        // smaller index performs the swap, so the pair is exchanged exactly once.
        // Fixed points (row == target) need no work.
        if row >= target {
            return;
        }
        // SAFETY: `mat` is mutably borrowed for the whole call, so the buffer at
        // `base_addr` stays alive and is not accessed through any other path.
        // `row != target`, so the two rows are distinct. This relies on
        // `reverse_bits_len(_, bits)` permuting `0..num_rows` as an involution
        // (bit reversal). That keeps `target` in bounds and means no two tasks
        // ever touch the same row.
        // NOTE(review): also assumes the storage holds at least
        // `row_len * num_rows` elements; confirm against `DenseMatrix` invariants.
        unsafe { swap_rows_raw(base_addr as *mut F, row_len, row, target) };
    });
}
/// Swaps rows `i` and `j` of `mat` in place.
///
/// The two indices may be passed in either order. Swapping a row with itself
/// is a no-op.
///
/// # Panics
///
/// Panics if either `i` or `j` does not index a full row of `mat`.
pub fn swap_rows<F: Clone + Send + Sync>(mat: &mut RowMajorMatrix<F>, i: usize, j: usize) {
    let w = mat.width();
    // Normalise the order so that, after splitting at the start of row `hi`,
    // row `lo` lies in the first half and row `hi` begins the second half.
    let (lo, hi) = if i <= j { (i, j) } else { (j, i) };
    assert!(
        (hi + 1) * w <= mat.values.len(),
        "row index {hi} out of bounds for matrix of width {w} with {} values",
        mat.values.len()
    );
    if lo == hi {
        // A row swapped with itself is unchanged. The split below would also
        // leave no room for two distinct slices.
        return;
    }
    let (upper, lower) = mat.values.split_at_mut(hi * w);
    let row_lo = &mut upper[lo * w..(lo + 1) * w];
    let row_hi = &mut lower[..w];
    row_lo.swap_with_slice(row_hi);
}
/// Swaps rows `i` and `j` of a row-major buffer of row width `w`, addressed
/// through the raw base pointer `mat`.
///
/// Swapping a row with itself is a no-op.
///
/// # Safety
///
/// - `mat` must be non-null, properly aligned, and valid for reads and writes of
///   the `w` elements starting at `mat.add(i * w)` and the `w` elements starting
///   at `mat.add(j * w)`. Both ranges must lie within a single allocation.
/// - No other reference may access either row for the duration of the call.
unsafe fn swap_rows_raw<F>(mat: *mut F, w: usize, i: usize, j: usize) {
    if i == j {
        // Two views of the same row would alias, and there is nothing to swap.
        return;
    }
    // SAFETY: the caller guarantees both rows are in bounds of one allocation and
    // are not otherwise accessed. Because `i != j` and each row spans exactly `w`
    // elements, `[i*w, i*w + w)` and `[j*w, j*w + w)` are disjoint, which is what
    // `swap_nonoverlapping` requires.
    unsafe {
        core::ptr::swap_nonoverlapping(mat.add(i * w), mat.add(j * w), w);
    }
}
#[cfg(test)]
mod tests {
    use alloc::vec;

    use super::*;
    use crate::dense::RowMajorMatrix;

    // Swapping rows 0 and 2 of a 4x3 matrix must leave rows 1 and 3 untouched.
    #[test]
    fn test_swap_rows_basic() {
        let width = 3;
        let mut matrix = RowMajorMatrix::new(
            vec![
                1, 2, 3, //
                4, 5, 6, //
                7, 8, 9, //
                10, 11, 12,
            ],
            width,
        );

        swap_rows(&mut matrix, 0, 2);

        let expected = vec![
            7, 8, 9, //
            4, 5, 6, //
            1, 2, 3, //
            10, 11, 12,
        ];
        assert_eq!(matrix.values, expected);
    }

    // The raw-pointer helper must produce the same result as a safe swap of
    // the first and last rows of a 3x3 matrix.
    #[test]
    fn test_swap_rows_raw_basic() {
        let mut matrix = RowMajorMatrix::new(
            vec![
                1, 2, 3, //
                4, 5, 6, //
                7, 8, 9,
            ],
            3,
        );
        let width = matrix.width();
        let base = matrix.values.as_mut_ptr();

        // SAFETY: `base` points at the 9-element backing buffer, rows 0 and 2 are
        // distinct and in bounds, and `matrix` is not otherwise accessed here.
        unsafe { swap_rows_raw(base, width, 0, 2) };

        let expected = vec![
            7, 8, 9, //
            4, 5, 6, //
            1, 2, 3,
        ];
        assert_eq!(matrix.values, expected);
    }

    // Eight rows of width 2 (row r holds [2r, 2r + 1]). With 3 index bits the
    // row order becomes 0, 4, 2, 6, 1, 5, 3, 7.
    #[test]
    fn test_reverse_matrix_index_bits_pow2_height() {
        let mut matrix = RowMajorMatrix::new(
            vec![
                0, 1, //
                2, 3, //
                4, 5, //
                6, 7, //
                8, 9, //
                10, 11, //
                12, 13, //
                14, 15,
            ],
            2,
        );

        reverse_matrix_index_bits(&mut matrix);

        let expected = vec![
            0, 1, //
            8, 9, //
            4, 5, //
            12, 13, //
            2, 3, //
            10, 11, //
            6, 7, //
            14, 15,
        ];
        assert_eq!(matrix.values, expected);
    }

    // A single-row matrix has zero index bits, so the permutation is the identity.
    #[test]
    fn test_reverse_matrix_index_bits_height_1() {
        let mut matrix = RowMajorMatrix::new(vec![42, 43], 2);

        reverse_matrix_index_bits(&mut matrix);

        let expected = vec![42, 43];
        assert_eq!(matrix.values, expected);
    }

    // Height 3 is not a power of two, so the permutation is undefined and must panic.
    #[test]
    #[should_panic]
    fn test_reverse_matrix_index_bits_non_power_of_two_should_panic() {
        let mut matrix = RowMajorMatrix::new(
            vec![
                1, 2, //
                3, 4, //
                5, 6,
            ],
            2,
        );

        reverse_matrix_index_bits(&mut matrix);
    }
}