use core::borrow::BorrowMut;
use p3_field::{Field, PrimeCharacteristicRing};
use p3_matrix::Matrix;
use p3_matrix::dense::{DenseMatrix, DenseStorage, RowMajorMatrix};
use p3_util::log2_strict_usize;
use tracing::instrument;
#[instrument(skip_all, fields(dims = %mat.dimensions()))]
/// Divides every entry of `mat` by the matrix height, in place.
///
/// The scaling factor `1 / h` is built in the prime subfield as `1 / 2^log2(h)` using
/// `div_2exp_u64`. It is then embedded into `F`, so no general field inversion is needed.
///
/// # Panics
///
/// Panics (inside `log2_strict_usize`) if the height of `mat` is not a power of two.
pub fn divide_by_height<F: Field, S: DenseStorage<F> + BorrowMut<[F]>>(
    mat: &mut DenseMatrix<F, S>,
) {
    // Exact base-2 logarithm of the height; this is where non-power-of-two heights are rejected.
    let log_height = log2_strict_usize(mat.height()) as u64;
    let inv_height = F::from_prime_subfield(F::PrimeSubfield::ONE.div_2exp_u64(log_height));
    mat.scale(inv_height);
}
/// Multiplies every entry of row `i` of `mat` by `shift^i`, in place.
///
/// The weights are taken in order from `shift.powers()`, which starts at `shift^0`. Row 0 is
/// therefore left unchanged, row 1 is scaled by `shift`, row 2 by `shift^2`, and so on.
///
/// NOTE(review): the name suggests each column holds polynomial coefficients, with row `i`
/// being the coefficient of `x^i`. Under that reading this maps `p(x)` to `p(shift * x)`.
/// Confirm against the callers.
pub(crate) fn coset_shift_cols<F: Field>(mat: &mut RowMajorMatrix<F>, shift: F) {
    for (row, weight) in mat.rows_mut().zip(shift.powers()) {
        for coeff in row {
            *coeff *= weight;
        }
    }
}
#[cfg(test)]
mod tests {
    use alloc::vec;
    use p3_baby_bear::BabyBear;
    use p3_matrix::dense::RowMajorMatrix;

    use super::*;

    type F = BabyBear;

    /// Lifts a slice of small integer literals into a vector of field elements.
    fn elems(vals: &[u8]) -> vec::Vec<F> {
        vals.iter().copied().map(F::from_u8).collect()
    }

    #[test]
    fn test_divide_by_height_2x2() {
        // Height 2, so every entry is halved.
        let mut mat = RowMajorMatrix::new(elems(&[2, 4, 6, 8]), 2);
        divide_by_height(&mut mat);
        assert_eq!(mat.values, elems(&[1, 2, 3, 4]));
    }

    #[test]
    fn test_divide_by_height_1x4() {
        // A single row has height 1, so dividing by the height is a no-op.
        let mut mat = RowMajorMatrix::new_row(elems(&[10, 20, 30, 40]));
        divide_by_height(&mut mat);
        assert_eq!(mat.values, elems(&[10, 20, 30, 40]));
    }

    #[test]
    #[should_panic]
    fn test_divide_by_height_non_power_of_two_height_should_panic() {
        // Width 1 with three values gives height 3, which is not a power of two.
        let mut mat = RowMajorMatrix::new(elems(&[1, 2, 3]), 1);
        divide_by_height(&mut mat);
    }

    #[test]
    fn test_coset_shift_cols_3x2_shift_2() {
        // Rows are scaled by 2^0, 2^1 and 2^2 respectively.
        let mut mat = RowMajorMatrix::new(elems(&[1, 2, 3, 4, 5, 6]), 2);
        coset_shift_cols(&mut mat, F::from_u8(2));
        assert_eq!(mat.values, elems(&[1, 2, 6, 8, 20, 24]));
    }

    #[test]
    fn test_coset_shift_cols_identity_shift() {
        // Every power of 1 is 1, so the matrix must come back unchanged.
        let mut mat = RowMajorMatrix::new(elems(&[7, 8, 9, 10]), 2);
        coset_shift_cols(&mut mat, F::from_u8(1));
        assert_eq!(mat.values, elems(&[7, 8, 9, 10]));
    }
}