use scirs2_core::num_complex::Complex64;
use super::types::FlatCirculant;
use crate::error::{LinalgError, LinalgResult};
/// Naive forward discrete Fourier transform of a real sequence.
///
/// Computes `X[k] = Σ_j x[j] · exp(-2πi · j·k / n)` for `k = 0..n` with a
/// direct double loop, so the cost is quadratic in `x.len()`. An empty input
/// yields an empty output.
fn dft(x: &[f64]) -> Vec<Complex64> {
    use std::f64::consts::TAU;

    let n = x.len();
    (0..n)
        .map(|k| {
            // `idx` tracks `(k * j) mod n` incrementally. Forming `k * j`
            // directly can overflow `usize` on 32-bit targets, and feeding an
            // unreduced angle to sin/cos loses accuracy as the product grows.
            // Since `idx < n` and `k < n`, `idx + k` cannot overflow.
            let mut idx = 0usize;
            x.iter()
                .map(|&xj| {
                    let angle = -TAU * idx as f64 / n as f64;
                    idx += k;
                    if idx >= n {
                        idx -= n;
                    }
                    Complex64::new(xj, 0.0) * Complex64::from_polar(1.0, angle)
                })
                .sum()
        })
        .collect()
}
/// Naive inverse discrete Fourier transform, returning only the real part.
///
/// Computes `Re( (1/n) · Σ_j x[j] · exp(+2πi · j·k / n) )` for `k = 0..n`
/// with a direct double loop (quadratic cost). The imaginary part of each
/// output sample is discarded. An empty input yields an empty output.
fn idft(x: &[Complex64]) -> Vec<f64> {
    use std::f64::consts::TAU;

    let n = x.len();
    (0..n)
        .map(|k| {
            // `idx` tracks `(k * j) mod n` incrementally. Forming `k * j`
            // directly can overflow `usize` on 32-bit targets, and feeding an
            // unreduced angle to sin/cos loses accuracy as the product grows.
            // Since `idx < n` and `k < n`, `idx + k` cannot overflow.
            let mut idx = 0usize;
            let total: f64 = x
                .iter()
                .map(|&xj| {
                    let angle = TAU * idx as f64 / n as f64;
                    idx += k;
                    if idx >= n {
                        idx -= n;
                    }
                    (xj * Complex64::from_polar(1.0, angle)).re
                })
                .sum();
            total / n as f64
        })
        .collect()
}
/// Returns the largest power of two that is less than or equal to `n`.
///
/// Returns `0` when `n == 0`, since no power of two is `<= 0`.
fn prev_pow2(n: usize) -> usize {
    if n == 0 {
        return 0;
    }
    // Isolate the highest set bit. Unlike repeated doubling, this cannot
    // overflow when `n` already has its top bit set.
    1usize << (usize::BITS - 1 - n.leading_zeros())
}
/// Reports whether `n` is an exact power of two (`1, 2, 4, ...`).
///
/// Zero is not a power of two.
fn is_power_of_two(n: usize) -> bool {
    n.is_power_of_two()
}
/// Forward, unnormalised radix-2 FFT performed in place on `buf`.
///
/// The length of `buf` must be a power of two; this is only verified by a
/// `debug_assert!`, so release builds do not check it.
fn fft_inplace(buf: &mut [Complex64]) {
    use std::f64::consts::TAU;

    let n = buf.len();
    debug_assert!(
        is_power_of_two(n),
        "fft_inplace requires power-of-two length"
    );

    // Reorder the samples into bit-reversed index order. Swapping only when
    // `i < j` ensures every pair is exchanged exactly once.
    let bits = n.trailing_zeros() as usize;
    for i in 0..n {
        let j = bit_reverse(i, bits);
        if i < j {
            buf.swap(i, j);
        }
    }

    // Butterfly passes over groups of size 2, 4, 8, ..., n.
    let mut span = 2usize;
    while span <= n {
        let half = span / 2;
        // Primitive `span`-th root of unity for the forward transform.
        let step = Complex64::from_polar(1.0, -TAU / span as f64);
        for start in (0..n).step_by(span) {
            let (lower, upper) = buf[start..start + span].split_at_mut(half);
            // The twiddle factor restarts at 1 for every group and advances
            // by one multiplication per butterfly.
            let mut twiddle = Complex64::new(1.0, 0.0);
            for (lo, hi) in lower.iter_mut().zip(upper.iter_mut()) {
                let even = *lo;
                let odd = *hi * twiddle;
                *lo = even + odd;
                *hi = even - odd;
                twiddle *= step;
            }
        }
        span *= 2;
    }
}
/// Inverse FFT performed in place on `buf`, including the `1/n` scaling.
///
/// Implemented by conjugating the input, running the forward transform, and
/// then conjugating again while dividing by the length. The length
/// requirement is whatever `fft_inplace` imposes.
fn ifft_inplace(buf: &mut [Complex64]) {
    buf.iter_mut().for_each(|z| *z = z.conj());
    fft_inplace(buf);
    let scale = buf.len() as f64;
    buf.iter_mut().for_each(|z| *z = z.conj() / scale);
}
/// Reverses the lowest `bits` bits of `x`.
///
/// Bits of `x` above position `bits` are ignored; with `bits == 0` the result
/// is `0`.
fn bit_reverse(x: usize, bits: usize) -> usize {
    // Carry (accumulated result, remaining input) through the fold: each step
    // moves the lowest remaining input bit onto the low end of the result.
    let (reversed, _) = (0..bits).fold((0usize, x), |(acc, rest), _| {
        ((acc << 1) | (rest & 1), rest >> 1)
    });
    reversed
}
/// Forward DFT of a real sequence, choosing the fastest available routine.
///
/// Power-of-two lengths go through `fft_inplace`; every other length
/// (including zero) falls back to the direct `dft`.
fn fast_dft(x: &[f64]) -> Vec<Complex64> {
    if !is_power_of_two(x.len()) {
        return dft(x);
    }
    // Promote the real samples to complex values with zero imaginary part.
    let mut spectrum: Vec<Complex64> = x.iter().map(|&re| Complex64::new(re, 0.0)).collect();
    fft_inplace(&mut spectrum);
    spectrum
}
/// Inverse DFT returning only the real part of each output sample.
///
/// Power-of-two lengths go through `ifft_inplace` on a copy of `x`; every
/// other length (including zero) falls back to the direct `idft`.
fn fast_idft(x: &[Complex64]) -> Vec<f64> {
    if !is_power_of_two(x.len()) {
        return idft(x);
    }
    let mut work = x.to_vec();
    ifft_inplace(&mut work);
    work.iter().map(|z| z.re).collect()
}
impl FlatCirculant {
    /// Multiplies the circulant matrix by the vector `x`.
    ///
    /// The product is evaluated in the Fourier domain as
    /// `IDFT(DFT(first_row) ⊙ DFT(x))`, i.e. the circular convolution of
    /// `first_row` with `x`. Applying it to the first standard basis vector
    /// therefore returns `first_row` itself.
    ///
    /// # Errors
    ///
    /// Returns [`LinalgError::ShapeError`] when `x.len()` differs from
    /// `self.n`.
    pub fn matvec(&self, x: &[f64]) -> LinalgResult<Vec<f64>> {
        let n = self.n;
        if x.len() != n {
            return Err(LinalgError::ShapeError(format!(
                "x has length {} but matrix is {}×{}",
                x.len(),
                n,
                n
            )));
        }
        let c_hat = fast_dft(&self.first_row);
        let x_hat = fast_dft(x);
        // Circular convolution becomes a pointwise product of the spectra.
        let y_hat: Vec<Complex64> = c_hat
            .iter()
            .zip(x_hat.iter())
            .map(|(ci, xi)| ci * xi)
            .collect();
        Ok(fast_idft(&y_hat))
    }

    /// Returns the DFT of `first_row`.
    ///
    /// These are the per-frequency multipliers that `matvec` applies and that
    /// `solve` divides by.
    pub fn eigenvalues(&self) -> Vec<Complex64> {
        fast_dft(&self.first_row)
    }

    /// Solves `C x = b` by dividing the spectrum of `b` by the eigenvalues.
    ///
    /// # Errors
    ///
    /// * [`LinalgError::ShapeError`] when `b.len()` differs from `self.n`.
    /// * [`LinalgError::SingularMatrixError`] when any eigenvalue has a
    ///   magnitude no larger than `1e-14` times the largest eigenvalue
    ///   magnitude. This includes the case where every eigenvalue is zero.
    pub fn solve(&self, b: &[f64]) -> LinalgResult<Vec<f64>> {
        // Relative cut-off below which an eigenvalue is treated as zero.
        const REL_TOL: f64 = 1e-14;

        let n = self.n;
        if b.len() != n {
            return Err(LinalgError::ShapeError(format!(
                "b has length {} but matrix is {}×{}",
                b.len(),
                n,
                n
            )));
        }
        let c_hat = fast_dft(&self.first_row);
        let b_hat = fast_dft(b);
        let thresh = REL_TOL * c_hat.iter().map(|z| z.norm()).fold(0.0_f64, f64::max);
        let x_hat: Vec<Complex64> = c_hat
            .iter()
            .zip(b_hat.iter())
            .map(|(ci, bi)| {
                // `<=` rather than `<`: when all eigenvalues are zero the
                // threshold is zero too, and a strict comparison would let
                // the division by zero through and return NaN/inf as `Ok`.
                if ci.norm() <= thresh {
                    Err(LinalgError::SingularMatrixError(
                        "circulant matrix is singular (zero eigenvalue)".into(),
                    ))
                } else {
                    Ok(bi / ci)
                }
            })
            .collect::<LinalgResult<Vec<_>>>()?;
        Ok(fast_idft(&x_hat))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    /// Builds the first standard basis vector of length `n` (panics if `n == 0`).
    fn e1(n: usize) -> Vec<f64> {
        let mut basis = vec![0.0_f64; n];
        basis[0] = 1.0;
        basis
    }

    /// Multiplying by the first basis vector must reproduce the defining
    /// coefficients.
    #[test]
    fn test_circulant_matvec_first_basis() {
        let coeffs = vec![1.0_f64, 2.0, 3.0, 4.0];
        let circ = FlatCirculant::new(coeffs.clone()).expect("failed");
        let product = circ.matvec(&e1(4)).expect("matvec failed");
        for (i, &expected) in coeffs.iter().enumerate() {
            assert_relative_eq!(product[i], expected, epsilon = 1e-10);
        }
    }

    /// The first three entries of the dense form are `c[0], c[n-1], c[n-2]`.
    #[test]
    fn test_circulant_first_row_def() {
        let coeffs = vec![5.0_f64, 3.0, 2.0];
        let order = 3;
        let circ = FlatCirculant::new(coeffs.clone()).expect("failed");
        let dense = circ.to_dense();
        assert_relative_eq!(dense[0], coeffs[0], epsilon = 1e-12);
        assert_relative_eq!(dense[1], coeffs[order - 1], epsilon = 1e-12);
        assert_relative_eq!(dense[2], coeffs[order - 2], epsilon = 1e-12);
    }

    /// Solving and then multiplying must return the original right-hand side.
    #[test]
    fn test_circulant_solve_roundtrip() {
        let circ = FlatCirculant::new(vec![4.0_f64, 1.0, 2.0, 1.0]).expect("failed");
        let rhs = vec![1.0_f64, 2.0, 3.0, 4.0];
        let solution = circ.solve(&rhs).expect("solve failed");
        let recovered = circ.matvec(&solution).expect("matvec failed");
        for (i, &expected) in rhs.iter().enumerate() {
            assert_relative_eq!(recovered[i], expected, epsilon = 1e-8);
        }
    }

    /// One eigenvalue is produced per matrix dimension.
    #[test]
    fn test_circulant_eigenvalues_length() {
        let circ = FlatCirculant::new(vec![1.0_f64, 2.0, 3.0, 4.0]).expect("failed");
        let spectrum = circ.eigenvalues();
        assert_eq!(spectrum.len(), 4);
    }
}