use core::ops::{AddAssign, Mul};
use p3_dft::TwoAdicSubgroupDft;
use p3_field::{PrimeCharacteristicRing, TwoAdicField};
/// Computes the dot product `u[0] * v[0] + u[1] * v[1] + ... + u[N-1] * v[N-1]`.
///
/// The accumulator is seeded with the first product, so `T` needs no zero element.
/// Only `Copy`, `AddAssign` and `Mul` are required. Terms are accumulated left to right.
///
/// # Panics
/// Panics if `N == 0`: the debug assertion fires in debug builds, and indexing the
/// first element fails in release builds.
#[inline(always)]
pub fn dot_product<T, const N: usize>(u: [T; N], v: [T; N]) -> T
where
    T: Copy + AddAssign + Mul<Output = T>,
{
    debug_assert_ne!(N, 0);
    let mut acc = u[0] * v[0];
    // The first pair already seeded `acc`; fold in the remaining pairs in order.
    for (&a, &b) in u.iter().zip(v.iter()).skip(1) {
        acc += a * b;
    }
    acc
}
/// Multiplies `input` by the circulant matrix whose *first row* is `circ_matrix`.
///
/// Row `i` of the matrix is the first row rotated right by `i` places. Each output
/// entry is therefore the dot product of `input` with a progressively rotated copy of
/// the row. The row is not rotated after the last dot product, so only `N - 1`
/// rotations are performed.
///
/// # Panics
/// Panics if `N == 0`.
pub fn apply_circulant<R: PrimeCharacteristicRing, const N: usize>(
    circ_matrix: &[u64; N],
    input: &[R; N],
) -> [R; N] {
    let mut row = circ_matrix.map(R::from_u64);
    let mut output = [R::ZERO; N];
    // `head` gets the first N-1 outputs (each followed by a rotation); `last` has length 1.
    let (head, last) = output.split_at_mut(N - 1);
    for slot in head {
        *slot = R::dot_product(&row, input);
        row.rotate_right(1);
    }
    last[0] = R::dot_product(&row, input);
    output
}
/// Converts the first row `v` of a circulant matrix into that matrix's first column.
///
/// Entry `0` is unchanged. For `1 <= i < N`, entry `i` of the result is `v[N - i]`,
/// which amounts to reversing every element except the first. This is a `const fn`
/// (hence the `while` loop), so it can build constant tables at compile time.
/// Empty and single-element inputs are returned unchanged.
pub const fn first_row_to_first_col<const N: usize, T: Copy>(v: &[T; N]) -> [T; N] {
    let mut col = *v;
    // Walk the source index from N-1 down to 1, writing each value to its mirrored slot.
    let mut src = N;
    while src > 1 {
        src -= 1;
        col[N - src] = v[src];
    }
    col
}
/// Multiplies `input` by the circulant matrix whose *first column* is `column`, using
/// DFTs.
///
/// The function returns `idft(dft(column) ⊙ dft(input))`, where `⊙` is pointwise
/// multiplication. By the convolution theorem this is the cyclic convolution of
/// `column` and `input`, which equals the circulant matrix-vector product. To use a
/// first row instead (as [`apply_circulant`] does), convert it with
/// [`first_row_to_first_col`].
///
/// NOTE(review): `N` presumably has to be a power of two that the field's two-adic
/// subgroup and the `fft` implementation support; confirm against the `TwoAdicSubgroupDft`
/// implementation in use.
///
/// # Panics
/// Panics if `idft` does not return exactly `N` values.
#[inline]
pub fn apply_circulant_fft<F: TwoAdicField, const N: usize, FFT: TwoAdicSubgroupDft<F>>(
    fft: &FFT,
    column: [u64; N],
    input: &[F; N],
) -> [F; N] {
    // Transform both operands; the matrix's transform is its first column's DFT.
    let column_hat = fft.dft(column.iter().map(|&c| F::from_u64(c)).collect());
    let input_hat = fft.dft(input.to_vec());
    let pointwise = column_hat
        .into_iter()
        .zip(input_hat)
        .map(|(a, b)| a * b)
        .collect();
    fft.idft(pointwise).try_into().unwrap()
}
#[cfg(test)]
mod tests {
    use super::*;

    // ---- first_row_to_first_col -------------------------------------------------

    #[test]
    fn test_first_row_to_first_col_even_length() {
        let input = [0, 1, 2, 3, 4, 5];
        let output = [0, 5, 4, 3, 2, 1];
        assert_eq!(first_row_to_first_col(&input), output);
    }

    #[test]
    fn test_first_row_to_first_col_odd_length() {
        let input = [10, 20, 30, 40, 50];
        let output = [10, 50, 40, 30, 20];
        assert_eq!(first_row_to_first_col(&input), output);
    }

    #[test]
    fn test_first_row_to_first_col_single_element() {
        let input = [42];
        let output = [42];
        assert_eq!(first_row_to_first_col(&input), output);
    }

    #[test]
    fn test_first_row_to_first_col_empty() {
        // N == 0 must not panic: the loop body never runs.
        let input: [i32; 0] = [];
        assert_eq!(first_row_to_first_col(&input), input);
    }

    #[test]
    fn test_first_row_to_first_col_all_zeros() {
        let input = [0; 6];
        let output = [0; 6];
        assert_eq!(first_row_to_first_col(&input), output);
    }

    #[test]
    fn test_first_row_to_first_col_negative_numbers() {
        let input = [-1, -2, -3, -4];
        let output = [-1, -4, -3, -2];
        assert_eq!(first_row_to_first_col(&input), output);
    }

    #[test]
    fn test_first_row_to_first_col_large_numbers() {
        let input = [1_000_000, 2_000_000, 3_000_000, 4_000_000];
        let output = [1_000_000, 4_000_000, 3_000_000, 2_000_000];
        assert_eq!(first_row_to_first_col(&input), output);
    }

    #[test]
    fn test_first_row_to_first_col_const_context() {
        // The function is a `const fn`; make sure it stays usable at compile time.
        const COL: [u8; 3] = first_row_to_first_col(&[1, 2, 3]);
        assert_eq!(COL, [1, 3, 2]);
    }

    // ---- dot_product ------------------------------------------------------------

    #[test]
    fn test_dot_product_basic() {
        let u = [1, 2, 3];
        let v = [4, 5, 6];
        assert_eq!(dot_product(u, v), 1 * 4 + 2 * 5 + 3 * 6);
    }

    #[test]
    fn test_dot_product_single_element() {
        let u = [7];
        let v = [8];
        assert_eq!(dot_product(u, v), 7 * 8);
    }

    #[test]
    fn test_dot_product_all_zeros() {
        let u = [0; 4];
        let v = [0; 4];
        assert_eq!(dot_product(u, v), 0);
    }

    #[test]
    fn test_dot_product_negative_numbers() {
        let u = [-1, -2, -3];
        let v = [-4, -5, -6];
        assert_eq!(dot_product(u, v), (-1) * (-4) + (-2) * (-5) + (-3) * (-6));
    }

    #[test]
    fn test_dot_product_large_numbers() {
        let u = [1_000_000, 2_000_000, 3_000_000];
        let v = [4, 5, 6];
        assert_eq!(
            dot_product(u, v),
            1_000_000 * 4 + 2_000_000 * 5 + 3_000_000 * 6
        );
    }

    #[test]
    fn test_dot_product_floats() {
        // Values chosen so every product and partial sum is exactly representable.
        let u = [1.5f64, 2.0];
        let v = [2.0f64, 0.25];
        assert_eq!(dot_product(u, v), 3.5);
    }
}