use crate::kernel::{Complex, Float};
/// Computes the multiplicative inverse of `a` modulo `m` using the extended
/// Euclidean algorithm.
///
/// Returns `Some(x)` with `0 <= x < m` and `a * x ≡ 1 (mod m)` when
/// `gcd(a, m) == 1`. Returns `None` when no inverse exists or when `m <= 1`.
///
/// `a` is reduced modulo `m` before the algorithm runs and all intermediate
/// arithmetic is done in `i128`, so operands larger than `isize::MAX` are
/// handled correctly instead of wrapping to negative values.
pub fn mod_inv(a: usize, m: usize) -> Option<usize> {
    if m <= 1 {
        return None;
    }
    // `usize` is at most 64 bits wide, so every value is representable in
    // `i128` and the Bezout coefficients (bounded by `m`) cannot overflow.
    let modulus = m as i128;
    let (mut old_r, mut r) = ((a % m) as i128, modulus);
    let (mut old_s, mut s) = (1_i128, 0_i128);
    while r != 0 {
        let q = old_r / r;
        (old_r, r) = (r, old_r - q * r);
        (old_s, s) = (s, old_s - q * s);
    }
    // `old_r` now holds gcd(a, m); an inverse exists only when it is 1.
    if old_r != 1 {
        return None;
    }
    // The Bezout coefficient may be negative; normalise it into `0..m`.
    Some(old_s.rem_euclid(modulus) as usize)
}
/// Composes a length-`n1 * n2` transform from two smaller in-place kernels
/// using a prime-factor (Good–Thomas style) index mapping, so no twiddle
/// factors are applied between the two passes.
///
/// The input is gathered into an `n1 x n2` row-major work buffer through a
/// CRT-based index map. `kernel2` is then run on every length-`n2` row and
/// `kernel1` on every length-`n1` column, and the result is scattered into
/// `output` at index `(k1 * n2 + k2 * n1) % (n1 * n2)`.
///
/// * `input` / `output` — slices of length `n1 * n2`.
/// * `kernel1` — called with a mutable slice of length `n1` and `sign`.
/// * `kernel2` — called with a mutable slice of length `n2` and `sign`.
/// * `sign` — forwarded unchanged to both kernels.
///
/// A factor equal to 1 is accepted: it is coprime with everything and its
/// index contribution is always zero.
///
/// # Panics
///
/// * If `n1` or `n2` is zero.
/// * If `n1` and `n2` are not coprime.
/// * In debug builds, if `input` or `output` does not have length `n1 * n2`.
pub fn pfa_compose<T, K1, K2>(
    input: &[Complex<T>],
    output: &mut [Complex<T>],
    n1: usize,
    n2: usize,
    kernel1: K1,
    kernel2: K2,
    sign: i32,
) where
    T: Float,
    K1: Fn(&mut [Complex<T>], i32),
    K2: Fn(&mut [Complex<T>], i32),
{
    assert!(n1 != 0 && n2 != 0, "n1 and n2 must be non-zero");
    let n = n1 * n2;
    debug_assert_eq!(input.len(), n, "input length must equal n1*n2");
    debug_assert_eq!(output.len(), n, "output length must equal n1*n2");

    // `mod_inv` rejects a modulus of 1, yet a length-1 factor is coprime with
    // anything and its index along that axis is always 0, so the inverse value
    // is irrelevant there; use 0 instead of panicking.
    let inv_n2_mod_n1 = if n1 == 1 {
        0
    } else {
        mod_inv(n2 % n1, n1).expect("n2 must be invertible modulo n1 (gcd must be 1)")
    };
    let inv_n1_mod_n2 = if n2 == 1 {
        0
    } else {
        mod_inv(n1 % n2, n2).expect("n1 must be invertible modulo n2 (gcd must be 1)")
    };

    // Gather: CRT input map into a row-major n1 x n2 work buffer.
    let mut work = vec![Complex::<T>::zero(); n];
    for n1_idx in 0..n1 {
        let row_term = n1_idx * n2 * inv_n2_mod_n1;
        for n2_idx in 0..n2 {
            let src_idx = (row_term + n2_idx * n1 * inv_n1_mod_n2) % n;
            work[n1_idx * n2 + n2_idx] = input[src_idx];
        }
    }

    // Row pass: each contiguous row has length n2.
    for row in work.chunks_exact_mut(n2) {
        kernel2(row, sign);
    }

    // Column pass: columns are strided, so copy each one through a scratch
    // buffer of length n1, transform it, and copy it back.
    let mut col_buf = vec![Complex::<T>::zero(); n1];
    for col in 0..n2 {
        for (row, slot) in col_buf.iter_mut().enumerate() {
            *slot = work[row * n2 + col];
        }
        kernel1(&mut col_buf, sign);
        for (row, value) in col_buf.iter().enumerate() {
            work[row * n2 + col] = *value;
        }
    }

    // Scatter: output index map k = (k1 * n2 + k2 * n1) mod n.
    for k1 in 0..n1 {
        for k2 in 0..n2 {
            let dst_idx = (k1 * n2 + k2 * n1) % n;
            output[dst_idx] = work[k1 * n2 + k2];
        }
    }
}
#[cfg(test)]
mod tests {
    use super::super::winograd::{winograd_3, winograd_5, winograd_7};
    use super::*;

    /// Reference O(n^2) DFT: `X[k] = sum_j x[j] * exp(i * sign * 2*pi * j*k / n)`.
    fn naive_dft(x: &[Complex<f64>], sign: i32) -> Vec<Complex<f64>> {
        let n = x.len();
        let sign_f = f64::from(sign);
        (0..n)
            .map(|k| {
                x.iter()
                    .enumerate()
                    .map(|(j, xj)| {
                        let angle = sign_f * core::f64::consts::TAU * (j * k) as f64 / n as f64;
                        let (s, c) = angle.sin_cos();
                        // Complex multiplication of `xj` by (c + i*s).
                        Complex::new(xj.re * c - xj.im * s, xj.re * s + xj.im * c)
                    })
                    .fold(Complex::new(0.0, 0.0), |acc, v| acc + v)
            })
            .collect()
    }

    /// Asserts that `got` and `exp` have the same length and agree
    /// component-wise to within `tol`.
    fn check_eq(got: &[Complex<f64>], exp: &[Complex<f64>], tol: f64, label: &str) {
        // `zip` stops at the shorter slice, so a length mismatch would
        // otherwise go unnoticed.
        assert_eq!(got.len(), exp.len(), "{label}: length mismatch");
        for (i, (g, e)) in got.iter().zip(exp.iter()).enumerate() {
            let diff_re = (g.re - e.re).abs();
            let diff_im = (g.im - e.im).abs();
            assert!(
                diff_re < tol && diff_im < tol,
                "{label}[{i}]: got ({}, {}), expected ({}, {}), diff ({diff_re:.2e}, {diff_im:.2e})",
                g.re, g.im, e.re, e.im
            );
        }
    }

    /// Checks `pfa_compose` for an `n1 x n2` factorisation: the forward
    /// transform must match `naive_dft`, and a forward transform followed by
    /// an inverse transform scaled by `1/n` must reproduce the input.
    fn pfa_test(
        n1: usize,
        n2: usize,
        k1: impl Fn(&mut [Complex<f64>], i32),
        k2: impl Fn(&mut [Complex<f64>], i32),
    ) {
        let n = n1 * n2;
        let input: Vec<Complex<f64>> = (0..n)
            .map(|i| Complex::new((i + 1) as f64 * 0.25, (i as f64 * 0.6).cos()))
            .collect();
        let tol = 1e-10 * n as f64;

        let expected = naive_dft(&input, -1);
        let mut output = vec![Complex::new(0.0, 0.0); n];
        pfa_compose(&input, &mut output, n1, n2, &k1, &k2, -1);
        check_eq(&output, &expected, tol, &format!("pfa_fwd({n1}x{n2})"));

        let mut rt = vec![Complex::new(0.0, 0.0); n];
        pfa_compose(&output, &mut rt, n1, n2, &k1, &k2, 1);
        // The kernels are unnormalised, so divide by n after the round trip.
        let n_f = n as f64;
        for v in &mut rt {
            *v = Complex::new(v.re / n_f, v.im / n_f);
        }
        check_eq(&rt, &input, tol, &format!("pfa_roundtrip({n1}x{n2})"));
    }

    #[test]
    fn test_pfa_15() {
        pfa_test(3, 5, winograd_3, winograd_5);
    }

    #[test]
    fn test_pfa_21() {
        pfa_test(3, 7, winograd_3, winograd_7);
    }

    #[test]
    fn test_pfa_35() {
        pfa_test(5, 7, winograd_5, winograd_7);
    }

    #[test]
    fn test_mod_inv() {
        assert_eq!(mod_inv(2, 5), Some(3));
        assert_eq!(mod_inv(3, 5), Some(2));
        assert_eq!(mod_inv(2, 3), Some(2));
        assert_eq!(mod_inv(5, 3), Some(2));
        // Not coprime, or degenerate modulus: no inverse.
        assert_eq!(mod_inv(2, 4), None);
        assert_eq!(mod_inv(1, 1), None);
        assert_eq!(mod_inv(0, 5), None);
        assert_eq!(mod_inv(3, 0), None);
    }
}