use scirs2_core::numeric::Complex;
use std::f64::consts::PI;
use crate::error::{LinalgError, LinalgResult};
/// Double-precision complex scalar used for every spectrum in this module.
type C64 = Complex<f64>;
/// Returns a copy of `v` zero-extended to the next power-of-two length.
///
/// An empty input yields a single zero, because `0.next_power_of_two() == 1`.
fn pad_to_power_of_two(v: &[C64]) -> Vec<C64> {
    let target_len = v.len().next_power_of_two();
    let mut padded = Vec::with_capacity(target_len);
    padded.extend_from_slice(v);
    padded.resize(target_len, C64::new(0.0, 0.0));
    padded
}
/// In-place iterative radix-2 Cooley–Tukey FFT (decimation in time).
///
/// The forward transform (`inverse == false`) computes `X[k] = Σ_j x[j]·e^{-2πi·jk/n}`.
/// The inverse uses `e^{+2πi·jk/n}` and divides by `n`, so forward followed by inverse
/// is the identity.
///
/// # Panics
/// If `data.len()` is at least 2 and not a power of two. Lengths 0 and 1 are left unchanged.
fn fft_inplace(data: &mut [C64], inverse: bool) {
    let n = data.len();
    if n <= 1 {
        return;
    }
    // A hard check rather than a debug-only one: with a non-power-of-two length the
    // butterfly passes below would not cover the whole buffer and release builds would
    // return garbage.
    assert!(n.is_power_of_two(), "FFT size must be a power of 2");

    // Put the input into bit-reversed order so the butterflies can run in place.
    let log2n = n.trailing_zeros() as usize;
    for i in 0..n {
        let j = reverse_bits(i, log2n);
        if j > i {
            data.swap(i, j);
        }
    }

    let sign = if inverse { 1.0_f64 } else { -1.0_f64 };
    let mut twiddles: Vec<C64> = Vec::with_capacity(n / 2);
    let mut h = 1usize;
    while h < n {
        // Stage twiddles w_j = e^{sign·iπ·j/h}. Each one is evaluated directly, which
        // keeps its error at O(ε) instead of letting a running product accumulate
        // rounding error proportional to h.
        twiddles.clear();
        twiddles.extend((0..h).map(|j| {
            let theta = sign * PI * j as f64 / h as f64;
            C64::new(theta.cos(), theta.sin())
        }));
        // Each block of length 2h combines its two halves with the same twiddles.
        for block in data.chunks_exact_mut(2 * h) {
            let (lo, hi) = block.split_at_mut(h);
            for ((a, b), &w) in lo.iter_mut().zip(hi.iter_mut()).zip(twiddles.iter()) {
                let u = *a;
                let v = w * *b;
                *a = u + v;
                *b = u - v;
            }
        }
        h *= 2;
    }

    if inverse {
        let scale = 1.0 / n as f64;
        for d in data.iter_mut() {
            *d = *d * scale;
        }
    }
}
/// Reverses the lowest `bits` bits of `x`. Higher bits of `x` are ignored, and
/// `bits == 0` yields 0.
///
/// Example: `reverse_bits(0b001, 3) == 0b100`.
fn reverse_bits(x: usize, bits: usize) -> usize {
    // Carry (reversed-so-far, remaining-input) through one step per bit.
    (0..bits)
        .fold((0usize, x), |(acc, rest), _| ((acc << 1) | (rest & 1), rest >> 1))
        .0
}
/// Forward FFT of a real signal, returned as the full complex spectrum of the same length.
///
/// The length must be a power of two, as required by `fft_inplace`.
fn rfft(x: &[f64]) -> Vec<C64> {
    let mut spectrum = Vec::with_capacity(x.len());
    for &sample in x {
        spectrum.push(C64::new(sample, 0.0));
    }
    fft_inplace(&mut spectrum, false);
    spectrum
}
/// Inverse FFT (scaled by `1/len`) that returns only the real parts of the result.
///
/// Side effect: `x` is overwritten with the full complex inverse transform.
/// The length must be a power of two, as required by `fft_inplace`.
fn irfft(x: &mut [C64]) -> Vec<f64> {
    fft_inplace(x, true);
    let mut real_parts = Vec::with_capacity(x.len());
    real_parts.extend(x.iter().map(|z| z.re));
    real_parts
}
/// Computes `y = C·x` for the `n × n` circulant matrix with first column `c`,
/// i.e. the circular convolution `y[i] = Σ_j c[(i - j) mod n] · x[j]`.
///
/// For power-of-two `n`, the product is taken directly in the length-`n` Fourier domain.
/// Otherwise the linear convolution (length `2n - 1`) is computed with zero-padded FFTs
/// and its tail is folded back onto the head.
///
/// # Errors
/// `LinalgError::DimensionError` if `c` and `x` differ in length. Empty inputs give an
/// empty result.
pub fn circulant_matmul(c: &[f64], x: &[f64]) -> LinalgResult<Vec<f64>> {
    let n = c.len();
    if x.len() != n {
        return Err(LinalgError::DimensionError(format!(
            "circulant_matmul: c has length {} but x has length {}",
            n,
            x.len()
        )));
    }
    if n == 0 {
        return Ok(Vec::new());
    }

    // Pointwise spectral product of two equal-length power-of-two real signals,
    // transformed back to the time domain.
    let spectral_product = |lhs: &[f64], rhs: &[f64]| -> Vec<f64> {
        let mut prod: Vec<C64> = rfft(lhs)
            .into_iter()
            .zip(rfft(rhs))
            .map(|(a, b)| a * b)
            .collect();
        irfft(&mut prod)
    };

    if n.is_power_of_two() {
        let mut y = spectral_product(c, x);
        y.truncate(n);
        return Ok(y);
    }

    let conv_len = 2 * n - 1;
    let fft_len = conv_len.next_power_of_two();
    let zero_padded = |v: &[f64]| -> Vec<f64> {
        let mut buf = vec![0.0f64; fft_len];
        buf[..n].copy_from_slice(v);
        buf
    };
    let linear = spectral_product(&zero_padded(c), &zero_padded(x));

    // Circular result = head of the linear convolution plus its wrapped-around tail.
    let mut y = linear[..n].to_vec();
    for (head, &tail) in y.iter_mut().zip(linear[n..conv_len].iter()) {
        *head += tail;
    }
    Ok(y)
}
/// Computes `y = T·x` for the `n × n` Toeplitz matrix `T` described by `t`, where
/// `n = x.len()`.
///
/// `t` has length `2n - 1` and `T[i][j] = t[n - 1 + i - j]`:
/// - `t[n - 1]` is the main diagonal;
/// - `t[n..]` runs down the first column below the diagonal;
/// - `t[..n - 1]`, read in reverse, runs along the first row to the right of the diagonal.
///
/// `T` is embedded in a circulant of power-of-two size `m >= 2n - 1`, so the product is a
/// single circular convolution evaluated with FFTs of length `m`.
///
/// # Errors
/// `LinalgError::DimensionError` if `t.len() != 2n - 1`. An empty `x` yields an empty
/// result without inspecting `t`.
pub fn toeplitz_matmul(t: &[f64], x: &[f64]) -> LinalgResult<Vec<f64>> {
    let n = x.len();
    if n == 0 {
        return Ok(Vec::new());
    }
    if t.len() != 2 * n - 1 {
        return Err(LinalgError::DimensionError(format!(
            "toeplitz_matmul: t must have length 2n-1={} for n={}, got {}",
            2 * n - 1,
            n,
            t.len()
        )));
    }

    // The circulant must be laid out at the FFT length itself. Building it at size 2n and
    // then zero-padding to the next power of two misplaces the wrap-around entries whenever
    // 2n is not a power of two, silently dropping every superdiagonal term.
    // With m >= 2n - 1, the indices 0..n and m-(n-1)..m do not overlap.
    let m = (2 * n - 1).next_power_of_two();
    let mut c = vec![0.0f64; m];
    // First column of T, diagonal and below: c[k] = T[k][0] = t[n - 1 + k].
    c[..n].copy_from_slice(&t[n - 1..]);
    // First row of T, right of the diagonal, wraps to the end: c[m - k] = T[0][k] = t[n - 1 - k].
    for k in 1..n {
        c[m - k] = t[n - 1 - k];
    }

    let mut x_ext = vec![0.0f64; m];
    x_ext[..n].copy_from_slice(x);

    let c_fft = rfft(&c);
    let x_fft = rfft(&x_ext);
    let mut prod: Vec<C64> = c_fft
        .iter()
        .zip(x_fft.iter())
        .map(|(&cf, &xf)| cf * xf)
        .collect();
    let mut y = irfft(&mut prod);
    y.truncate(n);
    Ok(y)
}
/// Full linear convolution `h = f * g` of length `f.len() + g.len() - 1`, computed with
/// zero-padded power-of-two FFTs.
///
/// If either input is empty the result is empty. This function never returns an error.
pub fn convolve_matmul(f: &[f64], g: &[f64]) -> LinalgResult<Vec<f64>> {
    if f.is_empty() || g.is_empty() {
        return Ok(Vec::new());
    }
    let full_len = f.len() + g.len() - 1;
    let fft_len = full_len.next_power_of_two();

    // Padding to at least `full_len` makes the circular convolution equal the linear one.
    let zero_padded = |signal: &[f64]| -> Vec<f64> {
        let mut buf = Vec::with_capacity(fft_len);
        buf.extend_from_slice(signal);
        buf.resize(fft_len, 0.0);
        buf
    };

    let f_hat = rfft(&zero_padded(f));
    let g_hat = rfft(&zero_padded(g));
    let mut spectrum: Vec<C64> = f_hat
        .into_iter()
        .zip(g_hat)
        .map(|(a, b)| a * b)
        .collect();
    let mut h = irfft(&mut spectrum);
    h.truncate(full_len);
    Ok(h)
}
/// Matrix structure handled by an [`FFTBasedSolver`].
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum FFTSolverType {
/// Circulant matrix; `solve` divides by its spectrum directly in the Fourier domain.
Circulant,
/// Toeplitz matrix; `solve` runs circulant-preconditioned conjugate gradient.
Toeplitz,
}
/// FFT-accelerated solver for structured `n × n` linear systems.
///
/// * [`FFTSolverType::Circulant`]: `C[i][j] = c[(i - j) mod n]` with first column `c`.
///   `solve` is direct, `x = IDFT(DFT(b) / DFT(c))`, using exact length-`n` transforms.
/// * [`FFTSolverType::Toeplitz`]: `T[i][j] = t[n - 1 + i - j]` with `t.len() == 2n - 1`.
///   `solve` runs preconditioned conjugate gradient with T. Chan's optimal circulant
///   preconditioner. CG is only guaranteed to converge when `T` is symmetric
///   positive definite.
#[derive(Debug, Clone)]
pub struct FFTBasedSolver {
    /// Structure described by `defining_vector`.
    pub solver_type: FFTSolverType,
    /// Circulant: the first column `c` (length `n`). Toeplitz: `t` (length `2n - 1`).
    defining_vector: Vec<f64>,
    /// Matrix dimension.
    pub n: usize,
    /// Length-`n` DFT, i.e. the eigenvalues, of the circulant matrix itself (circulant case)
    /// or of the circulant preconditioner (Toeplitz case).
    fft_buf: Vec<C64>,
}

impl FFTBasedSolver {
    /// Builds a direct solver for the circulant matrix with first column `c`.
    ///
    /// # Errors
    /// * `LinalgError::ShapeError` if `c` is empty.
    /// * `LinalgError::SingularMatrixError` if the smallest eigenvalue magnitude is at most
    ///   `n · ε · max|λ|`, i.e. the matrix is singular to working precision.
    pub fn new_circulant(c: &[f64]) -> LinalgResult<Self> {
        let n = c.len();
        if n == 0 {
            return Err(LinalgError::ShapeError(
                "FFTBasedSolver: circulant size must be positive".to_string(),
            ));
        }
        // The eigenvalues of an n×n circulant are the length-n DFT of its first column.
        // Zero-padding c to a power of two would instead give the spectrum of a different,
        // larger circulant, so an exact arbitrary-length DFT is required here.
        let fft_buf = exact_dft(&real_to_complex(c), false);
        let (min_abs, max_abs) = fft_buf
            .iter()
            .map(|z| z.norm())
            .fold((f64::INFINITY, 0.0_f64), |(lo, hi), a| (lo.min(a), hi.max(a)));
        // Relative test: eigenvalues that are zero only up to rounding (~1e-16) would pass an
        // absolute tiny threshold and then blow up the solve.
        if min_abs <= max_abs * n as f64 * f64::EPSILON {
            return Err(LinalgError::SingularMatrixError(
                "FFTBasedSolver: circulant matrix is singular (zero eigenvalue)".to_string(),
            ));
        }
        Ok(Self {
            solver_type: FFTSolverType::Circulant,
            defining_vector: c.to_vec(),
            n,
            fft_buf,
        })
    }

    /// Builds a PCG solver for the `n × n` Toeplitz matrix `T[i][j] = t[n - 1 + i - j]`.
    ///
    /// # Errors
    /// * `LinalgError::ShapeError` if `n == 0`.
    /// * `LinalgError::DimensionError` if `t.len() != 2n - 1`.
    pub fn new_toeplitz(t: &[f64], n: usize) -> LinalgResult<Self> {
        if n == 0 {
            return Err(LinalgError::ShapeError(
                "FFTBasedSolver: Toeplitz size must be positive".to_string(),
            ));
        }
        if t.len() != 2 * n - 1 {
            return Err(LinalgError::DimensionError(format!(
                "FFTBasedSolver: Toeplitz t must have length 2n-1={}, got {}",
                2 * n - 1,
                t.len()
            )));
        }
        // T. Chan's optimal (Frobenius-nearest) circulant preconditioner:
        //   c_k = ((n - k)·a_k + k·a_{k-n}) / n,  where a_d = T[i][j] for i - j = d,
        // so a_d = t[n - 1 + d]: a_k = t[n - 1 + k] lies below the diagonal and
        // a_{k-n} = t[k - 1] lies above it. For symmetric T this yields a symmetric
        // circulant (c_k = c_{n-k}), as preconditioned CG requires.
        let nf = n as f64;
        let chan_c: Vec<C64> = (0..n)
            .map(|k| {
                let value = if k == 0 {
                    t[n - 1]
                } else {
                    let below = t[n - 1 + k];
                    let above = t[k - 1];
                    ((n - k) as f64 * below + k as f64 * above) / nf
                };
                C64::new(value, 0.0)
            })
            .collect();
        let fft_buf = exact_dft(&chan_c, false);
        Ok(Self {
            solver_type: FFTSolverType::Toeplitz,
            defining_vector: t.to_vec(),
            n,
            fft_buf,
        })
    }

    /// Solves `A·x = b`.
    ///
    /// For Toeplitz systems, `max_iter` defaults to 50 and `tol` (relative residual
    /// `‖r‖/‖b‖`) defaults to 1e-10. The last iterate is returned even if the tolerance was
    /// not reached. Both arguments are ignored for circulant systems.
    ///
    /// # Errors
    /// `LinalgError::DimensionError` if `b.len() != n`.
    pub fn solve(
        &self,
        b: &[f64],
        max_iter: Option<usize>,
        tol: Option<f64>,
    ) -> LinalgResult<Vec<f64>> {
        if b.len() != self.n {
            return Err(LinalgError::DimensionError(format!(
                "FFTBasedSolver::solve: b has length {} but n={}",
                b.len(),
                self.n
            )));
        }
        match self.solver_type {
            FFTSolverType::Circulant => self.solve_circulant(b),
            FFTSolverType::Toeplitz => {
                let max_it = max_iter.unwrap_or(50);
                let tolerance = tol.unwrap_or(1e-10);
                self.solve_toeplitz_pcg(b, max_it, tolerance)
            }
        }
    }

    /// Computes `A·x` for the stored matrix.
    ///
    /// # Errors
    /// `LinalgError::DimensionError` if `x.len() != n`.
    pub fn matvec(&self, x: &[f64]) -> LinalgResult<Vec<f64>> {
        if x.len() != self.n {
            return Err(LinalgError::DimensionError(format!(
                "FFTBasedSolver::matvec: x has length {} but n={}",
                x.len(),
                self.n
            )));
        }
        match self.solver_type {
            FFTSolverType::Circulant => circulant_matmul(&self.defining_vector, x),
            FFTSolverType::Toeplitz => toeplitz_matmul(&self.defining_vector, x),
        }
    }

    /// Direct circulant solve: `x = Re(IDFT(DFT(b) / λ))`.
    fn solve_circulant(&self, b: &[f64]) -> LinalgResult<Vec<f64>> {
        Ok(self.divide_by_spectrum(b))
    }

    /// Applies the inverse of the circulant whose eigenvalues are stored in `fft_buf`,
    /// returning `Re(IDFT(DFT(v) ./ fft_buf))`. Eigenvalues with `|λ|² < 1e-300`
    /// contribute zero.
    fn divide_by_spectrum(&self, v: &[f64]) -> Vec<f64> {
        let v_hat = exact_dft(&real_to_complex(v), false);
        let quotient: Vec<C64> = v_hat
            .iter()
            .zip(self.fft_buf.iter())
            .map(|(&num, &den)| spectral_divide(num, den))
            .collect();
        exact_dft(&quotient, true).iter().map(|z| z.re).collect()
    }

    /// Preconditioned conjugate gradient, starting from `x0 = 0`.
    ///
    /// Stops when `‖r‖/‖b‖ < tol`, when `pᵀTp` underflows, or after `max_iter` iterations.
    fn solve_toeplitz_pcg(
        &self,
        b: &[f64],
        max_iter: usize,
        tol: f64,
    ) -> LinalgResult<Vec<f64>> {
        let n = self.n;
        let b_norm = dot_f64(b, b).sqrt().max(1e-300);
        let mut x = vec![0.0f64; n];
        let mut r = b.to_vec();
        let mut p = self.divide_by_spectrum(&r);
        let mut rz = dot_f64(&r, &p);
        for _ in 0..max_iter {
            if dot_f64(&r, &r).sqrt() / b_norm < tol {
                break;
            }
            let q = toeplitz_matmul(&self.defining_vector, &p)?;
            let pq = dot_f64(&p, &q);
            if pq.abs() < 1e-300 {
                break;
            }
            let alpha = rz / pq;
            for (xi, &pi) in x.iter_mut().zip(&p) {
                *xi += alpha * pi;
            }
            for (ri, &qi) in r.iter_mut().zip(&q) {
                *ri -= alpha * qi;
            }
            let z = self.divide_by_spectrum(&r);
            let rz_new = dot_f64(&r, &z);
            let beta = rz_new / rz.max(1e-300);
            rz = rz_new;
            for (pi, &zi) in p.iter_mut().zip(&z) {
                *pi = zi + beta * *pi;
            }
        }
        Ok(x)
    }
}

/// Lifts a real slice to complex values with zero imaginary part.
fn real_to_complex(v: &[f64]) -> Vec<C64> {
    v.iter().map(|&re| C64::new(re, 0.0)).collect()
}

/// Computes `num / den` as `num·conj(den) / |den|²`, returning 0 when `|den|² < 1e-300`.
fn spectral_divide(num: C64, den: C64) -> C64 {
    let denom = den.norm_sqr();
    if denom < 1e-300 {
        C64::new(0.0, 0.0)
    } else {
        let q = num * den.conj();
        C64::new(q.re / denom, q.im / denom)
    }
}

/// Complex DFT of power-of-two length, assembled from two real FFTs (`rfft`).
///
/// Forward: `FFT(re + i·im) = FFT(re) + i·FFT(im)`.
/// Inverse: `conj(FFT(conj(v))) / len`, i.e. the `1/len`-scaled inverse transform.
fn pow2_dft(v: &[C64], inverse: bool) -> Vec<C64> {
    let re: Vec<f64> = v.iter().map(|z| z.re).collect();
    let im: Vec<f64> = v
        .iter()
        .map(|z| if inverse { -z.im } else { z.im })
        .collect();
    let scale = if inverse { 1.0 / v.len() as f64 } else { 1.0 };
    rfft(&re)
        .into_iter()
        .zip(rfft(&im))
        .map(|(a, b)| {
            let re_part = (a.re - b.im) * scale;
            let im_part = (a.im + b.re) * scale;
            if inverse {
                C64::new(re_part, -im_part)
            } else {
                C64::new(re_part, im_part)
            }
        })
        .collect()
}

/// Exact DFT of arbitrary length.
///
/// The forward transform uses `e^{-2πi·jk/n}`; the inverse uses `e^{+2πi·jk/n}` and is
/// scaled by `1/n`. Power-of-two lengths use `pow2_dft` directly; other lengths use
/// Bluestein's chirp-z algorithm, which rewrites the DFT as a convolution evaluated with
/// power-of-two FFTs of length `m >= 2n - 1`.
fn exact_dft(input: &[C64], inverse: bool) -> Vec<C64> {
    let n = input.len();
    if n == 0 {
        return Vec::new();
    }
    if n.is_power_of_two() {
        return pow2_dft(input, inverse);
    }
    let sign = if inverse { 1.0_f64 } else { -1.0_f64 };
    // chirp_j = e^{sign·iπ·j²/n}. Because it is periodic in j² with period 2n, reducing
    // j² mod 2n first keeps the angle small and accurate.
    let period = 2 * (n as u128);
    let chirp: Vec<C64> = (0..n)
        .map(|j| {
            let j = j as u128;
            let theta = sign * PI * ((j * j) % period) as f64 / n as f64;
            C64::new(theta.cos(), theta.sin())
        })
        .collect();
    // Since jk = (j² + k² - (k-j)²) / 2, we have
    //   X_k = chirp_k · Σ_j (x_j·chirp_j) · conj(chirp_{k-j}),
    // i.e. a linear convolution of length 2n - 1.
    let m = (2 * n - 1).next_power_of_two();
    let zero = C64::new(0.0, 0.0);
    let mut a = vec![zero; m];
    for (slot, (&x, &w)) in a.iter_mut().zip(input.iter().zip(chirp.iter())) {
        *slot = x * w;
    }
    // conj(chirp_d) for d in (-(n-1), n-1); negative lags wrap to the end of the buffer.
    let mut kernel = vec![zero; m];
    for (j, w) in chirp.iter().enumerate() {
        let w = w.conj();
        kernel[j] = w;
        if j > 0 {
            kernel[m - j] = w;
        }
    }
    let a_hat = pow2_dft(&a, false);
    let k_hat = pow2_dft(&kernel, false);
    let prod: Vec<C64> = a_hat
        .iter()
        .zip(k_hat.iter())
        .map(|(&x, &y)| x * y)
        .collect();
    let conv = pow2_dft(&prod, true);
    let scale = if inverse { 1.0 / n as f64 } else { 1.0 };
    conv.iter()
        .zip(chirp.iter())
        .map(|(&v, &w)| w * v * scale)
        .collect()
}
/// Two-level Toeplitz operator (block Toeplitz with Toeplitz blocks) of size
/// `(n1·n2) × (n1·n2)`, acting on row-major `n1 × n2` grids.
///
/// `kernel` has `2·n1 - 1` rows of `2·n2 - 1` entries. Output cell `(i, j)` receives
/// `kernel[n1 - 1 + i - k][n2 - 1 + j - l] · x[k][l]` summed over all `(k, l)`.
/// The kernel is wrapped into an `m1 × m2` periodic array (`m = (2n).next_power_of_two()`)
/// whose 2-D FFT is cached, so each product costs one forward and one inverse 2-D FFT.
#[derive(Debug, Clone)]
pub struct LevelToeplitz {
    /// Number of grid rows.
    pub n1: usize,
    /// Number of grid columns.
    pub n2: usize,
    /// The kernel as supplied; not read by the methods below.
    kernel: Vec<Vec<f64>>,
    /// 2-D FFT of the wrapped, zero-padded kernel (`m1 × m2`).
    kernel_fft: Vec<Vec<C64>>,
    /// Padded FFT row count, `(2·n1).next_power_of_two()`.
    m1: usize,
    /// Padded FFT column count, `(2·n2).next_power_of_two()`.
    m2: usize,
}

impl LevelToeplitz {
    /// Validates the kernel shape and precomputes its padded 2-D spectrum.
    ///
    /// # Errors
    /// * `LinalgError::ShapeError` if `n1` or `n2` is zero.
    /// * `LinalgError::DimensionError` if the kernel does not have `2·n1 - 1` rows, or if
    ///   any row (the first offending one is reported) does not have `2·n2 - 1` entries.
    pub fn new(n1: usize, n2: usize, kernel: Vec<Vec<f64>>) -> LinalgResult<Self> {
        if n1 == 0 || n2 == 0 {
            return Err(LinalgError::ShapeError(
                "LevelToeplitz: n1 and n2 must be positive".to_string(),
            ));
        }
        let (rows, cols) = (2 * n1 - 1, 2 * n2 - 1);
        if kernel.len() != rows {
            return Err(LinalgError::DimensionError(format!(
                "LevelToeplitz: kernel must have {} rows (2*n1-1), got {}",
                rows,
                kernel.len()
            )));
        }
        if let Some((r, row)) = kernel
            .iter()
            .enumerate()
            .find(|(_, row)| row.len() != cols)
        {
            return Err(LinalgError::DimensionError(format!(
                "LevelToeplitz: kernel row {} must have {} cols (2*n2-1), got {}",
                r,
                cols,
                row.len()
            )));
        }

        let m1 = (2 * n1).next_power_of_two();
        let m2 = (2 * n2).next_power_of_two();
        // Map a signed lag onto its slot in a periodic buffer of length `m`.
        let wrap = |lag: isize, m: usize| -> usize { lag.rem_euclid(m as isize) as usize };
        let mut embedded = vec![vec![0.0f64; m2]; m1];
        for (r, row) in kernel.iter().enumerate() {
            let dst = &mut embedded[wrap(r as isize - (n1 as isize - 1), m1)];
            for (c_idx, &value) in row.iter().enumerate() {
                dst[wrap(c_idx as isize - (n2 as isize - 1), m2)] = value;
            }
        }
        let kernel_fft = fft2d_real(&embedded, m1, m2)?;
        Ok(Self {
            n1,
            n2,
            kernel,
            kernel_fft,
            m1,
            m2,
        })
    }

    /// Computes `y = A·x` for a row-major grid `x` of length `n1·n2`.
    ///
    /// # Errors
    /// `LinalgError::DimensionError` if `x.len() != n1·n2`.
    pub fn matvec(&self, x: &[f64]) -> LinalgResult<Vec<f64>> {
        let total = self.n1 * self.n2;
        if x.len() != total {
            return Err(LinalgError::DimensionError(format!(
                "LevelToeplitz::matvec: x has length {} but n1*n2={}",
                x.len(),
                total
            )));
        }
        let (m1, m2, n2) = (self.m1, self.m2, self.n2);

        // Place the grid in the top-left corner of a zero m1 × m2 array.
        let mut grid = vec![vec![0.0f64; m2]; m1];
        for i in 0..self.n1 {
            grid[i][..n2].copy_from_slice(&x[i * n2..(i + 1) * n2]);
        }

        let x_hat = fft2d_real(&grid, m1, m2)?;
        let mut spectrum: Vec<Vec<C64>> = self
            .kernel_fft
            .iter()
            .zip(x_hat.iter())
            .map(|(k_row, x_row)| {
                k_row
                    .iter()
                    .zip(x_row.iter())
                    .map(|(&k, &v)| k * v)
                    .collect()
            })
            .collect();
        let y_full = ifft2d(&mut spectrum, m1, m2)?;

        // Keep only the top-left n1 × n2 window, flattened row-major.
        let y: Vec<f64> = (0..self.n1)
            .flat_map(|i| y_full[i][..n2].iter().copied())
            .collect();
        Ok(y)
    }

    /// Solves `A·x = b` with unpreconditioned conjugate gradient, starting from `x0 = 0`.
    ///
    /// Iteration stops when `‖r‖/‖b‖ < tol` (default 1e-10), when `pᵀAp` underflows, or
    /// after `max_iter` iterations (default 100). The last iterate is returned without a
    /// convergence error. CG is only guaranteed to converge when `A` is symmetric
    /// positive definite.
    ///
    /// # Errors
    /// `LinalgError::DimensionError` if `b.len() != n1·n2`.
    pub fn solve(
        &self,
        b: &[f64],
        max_iter: Option<usize>,
        tol: Option<f64>,
    ) -> LinalgResult<Vec<f64>> {
        let total = self.n1 * self.n2;
        if b.len() != total {
            return Err(LinalgError::DimensionError(format!(
                "LevelToeplitz::solve: b has length {} but n1*n2={}",
                b.len(),
                total
            )));
        }
        let max_it = max_iter.unwrap_or(100);
        let tolerance = tol.unwrap_or(1e-10);
        let b_norm = dot_f64(b, b).sqrt().max(1e-300);

        let mut x = vec![0.0f64; total];
        let mut residual = b.to_vec();
        let mut direction = residual.clone();
        let mut rr = dot_f64(&residual, &residual);
        for _ in 0..max_it {
            if rr.sqrt() / b_norm < tolerance {
                break;
            }
            let a_dir = self.matvec(&direction)?;
            let curvature = dot_f64(&direction, &a_dir);
            if curvature.abs() < 1e-300 {
                break;
            }
            let alpha = rr / curvature;
            for (((xi, ri), &di), &adi) in x
                .iter_mut()
                .zip(residual.iter_mut())
                .zip(direction.iter())
                .zip(a_dir.iter())
            {
                *xi += alpha * di;
                *ri -= alpha * adi;
            }
            let rr_next = dot_f64(&residual, &residual);
            let beta = rr_next / rr.max(1e-300);
            rr = rr_next;
            for (di, &ri) in direction.iter_mut().zip(residual.iter()) {
                *di = ri + beta * *di;
            }
        }
        Ok(x)
    }
}
/// Forward 2-D FFT of a real array.
///
/// Each input row is zero-padded (or truncated) to `m2` entries and transformed. Zero rows
/// are appended until there are at least `m1` rows, and then the first `m1` rows are
/// transformed column by column. This function never returns an error.
fn fft2d_real(input: &[Vec<f64>], m1: usize, m2: usize) -> LinalgResult<Vec<Vec<C64>>> {
    let zero = C64::new(0.0, 0.0);

    // Row pass.
    let mut grid: Vec<Vec<C64>> = input
        .iter()
        .map(|row| {
            let mut buf: Vec<C64> = row.iter().map(|&v| C64::new(v, 0.0)).collect();
            buf.resize(m2, zero);
            fft_inplace(&mut buf, false);
            buf
        })
        .collect();
    let have = grid.len();
    grid.extend((have..m1).map(|_| vec![zero; m2]));

    // Column pass, reusing one scratch buffer of length m1.
    let mut column = vec![zero; m1];
    for j in 0..m2 {
        for (slot, row) in column.iter_mut().zip(grid.iter()) {
            *slot = row[j];
        }
        fft_inplace(&mut column, false);
        for (row, &val) in grid.iter_mut().zip(column.iter()) {
            row[j] = val;
        }
    }
    Ok(grid)
}
fn ifft2d(input: &mut [Vec<C64>], m1: usize, m2: usize) -> LinalgResult<Vec<Vec<f64>>> {
for j in 0..m2 {
let mut col: Vec<C64> = (0..m1).map(|i| input[i][j]).collect();
fft_inplace(&mut col, true);
for (i, val) in col.into_iter().enumerate() {
input[i][j] = val;
}
}
let mut result = vec![vec![0.0f64; m2]; m1];
for (i, row) in input.iter_mut().enumerate() {
fft_inplace(row, true);
for (j, val) in row.iter().enumerate() {
result[i][j] = val.re;
}
}
Ok(result)
}
/// Inner product `Σ a[i]·b[i]`. Trailing elements of the longer slice are ignored.
fn dot_f64(a: &[f64], b: &[f64]) -> f64 {
    let len = a.len().min(b.len());
    let (a, b) = (&a[..len], &b[..len]);
    let products = a.iter().zip(b.iter()).map(|(x, y)| x * y);
    products.sum()
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_circulant_matmul_identity() {
        // First column e0 makes C the identity.
        let c = vec![1.0, 0.0, 0.0];
        let x = vec![3.0, 1.0, 4.0];
        let y = circulant_matmul(&c, &x).expect("circulant_matmul failed");
        for i in 0..3 {
            assert!((y[i] - x[i]).abs() < 1e-10, "Identity circulant failed at {i}");
        }
    }

    #[test]
    fn test_circulant_matmul_known() {
        // C·e0 is the first column of C, i.e. c itself.
        let c = vec![1.0, 2.0, 3.0];
        let x = vec![1.0, 0.0, 0.0];
        let y = circulant_matmul(&c, &x).expect("circulant_matmul failed");
        assert!((y[0] - 1.0).abs() < 1e-9);
        assert!((y[1] - 2.0).abs() < 1e-9);
        assert!((y[2] - 3.0).abs() < 1e-9);
    }

    #[test]
    fn test_toeplitz_matmul_identity_diag() {
        // Only the main diagonal t[n - 1] is non-zero, so T = I.
        let t = vec![0.0, 0.0, 1.0, 0.0, 0.0];
        let x = vec![3.0, 1.0, 4.0];
        let y = toeplitz_matmul(&t, &x).expect("toeplitz_matmul failed");
        for i in 0..3 {
            assert!((y[i] - x[i]).abs() < 1e-9, "Identity Toeplitz failed at {i}: got {}", y[i]);
        }
    }

    #[test]
    fn test_toeplitz_matmul_first_column() {
        // With n = 4, T[i][0] = t[3 + i], so T·e0 = [1, 2, 4, 6].
        let t = vec![7.0, 5.0, 3.0, 1.0, 2.0, 4.0, 6.0];
        let x = vec![1.0, 0.0, 0.0, 0.0];
        let y = toeplitz_matmul(&t, &x).expect("toeplitz_matmul failed");
        assert!((y[0] - 1.0).abs() < 1e-9, "y[0]={}", y[0]);
        assert!((y[1] - 2.0).abs() < 1e-9, "y[1]={}", y[1]);
        assert!((y[2] - 4.0).abs() < 1e-9, "y[2]={}", y[2]);
        assert!((y[3] - 6.0).abs() < 1e-9, "y[3]={}", y[3]);
    }

    #[test]
    fn test_convolve_matmul() {
        let f = vec![1.0, 2.0, 3.0];
        let g = vec![1.0, 1.0];
        let h = convolve_matmul(&f, &g).expect("convolve_matmul failed");
        assert_eq!(h.len(), 4);
        assert!((h[0] - 1.0).abs() < 1e-10);
        assert!((h[1] - 3.0).abs() < 1e-10);
        assert!((h[2] - 5.0).abs() < 1e-10);
        assert!((h[3] - 3.0).abs() < 1e-10);
    }

    #[test]
    fn test_fft_solver_circulant() {
        // n = 3 is deliberately not a power of two.
        let c = vec![4.0, 1.0, 1.0];
        let solver = FFTBasedSolver::new_circulant(&c).expect("new_circulant failed");
        let x_true = vec![1.0, 2.0, 3.0];
        let b = circulant_matmul(&c, &x_true).expect("matvec failed");
        let x_sol = solver.solve(&b, None, None).expect("solve failed");
        for i in 0..3 {
            assert!(
                (x_sol[i] - x_true[i]).abs() < 1e-8,
                "Circulant solve error at {i}: {} vs {}",
                x_sol[i],
                x_true[i]
            );
        }
    }

    #[test]
    fn test_fft_solver_toeplitz() {
        // Symmetric, strictly diagonally dominant (hence SPD) Toeplitz matrix.
        let n = 4;
        let t = vec![0.1, 0.2, 0.3, 4.0, 0.3, 0.2, 0.1];
        let solver = FFTBasedSolver::new_toeplitz(&t, n).expect("new_toeplitz failed");
        let x_true = vec![1.0, -1.0, 2.0, 0.5];
        let b = toeplitz_matmul(&t, &x_true).expect("matvec for b failed");
        let x_sol = solver.solve(&b, Some(100), Some(1e-10)).expect("solve failed");
        for i in 0..n {
            assert!(
                (x_sol[i] - x_true[i]).abs() < 1e-6,
                "Toeplitz solve error at {i}: {} vs {}",
                x_sol[i],
                x_true[i]
            );
        }
    }

    #[test]
    fn test_level_toeplitz_matvec() {
        // A kernel that is 1 at the zero-lag centre and 0 elsewhere gives the identity.
        let n1 = 2;
        let n2 = 2;
        let mut kernel = vec![vec![0.0f64; 3]; 3];
        kernel[n1 - 1][n2 - 1] = 1.0;
        let lt = LevelToeplitz::new(n1, n2, kernel).expect("LevelToeplitz::new failed");
        let x = vec![1.0, 2.0, 3.0, 4.0];
        let y = lt.matvec(&x).expect("matvec failed");
        for i in 0..4 {
            assert!(
                (y[i] - x[i]).abs() < 1e-9,
                "Level-Toeplitz identity failed at {i}: {} vs {}",
                y[i],
                x[i]
            );
        }
    }
}