use crate::hofft::{nd_col_major_strides, nd_fiber_base, Engine, EngineND, Extension};
use lobatto::collocation::{CollocationBasis, Gauss};
use nalgebra::*;
use rayon::prelude::*;
use rustfft::num_complex::Complex;
use std::f64::consts::PI;
/// Boundary treatment selectable per axis for the solvers in this file.
///
/// Each variant is translated into an `Extension` of the transform engine by
/// `BoundaryCondition::extension`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BoundaryCondition {
/// Periodic axis; mapped to `Extension::Periodic`. `Poisson::new` uses `n` frequencies.
Periodic,
/// Mapped to `Extension::Odd`; `Poisson::new` uses `n + 1` frequencies and the solvers
/// shift the local node index by one on such an axis.
/// NOTE(review): presumably homogeneous Dirichlet data — confirm against the engine.
Dirichlet,
/// Mapped to `Extension::Even`; `Poisson::new` uses `n + 1` frequencies.
/// NOTE(review): presumably homogeneous Neumann data — confirm against the engine.
Neumann,
}
impl BoundaryCondition {
    /// Returns the signal extension the transform engine has to use for this boundary
    /// condition: periodic stays periodic, Dirichlet becomes odd, Neumann becomes even.
    pub(crate) fn extension(self) -> Extension {
        match self {
            Self::Periodic => Extension::Periodic,
            Self::Dirichlet => Extension::Odd,
            Self::Neumann => Extension::Even,
        }
    }
}
/// One-dimensional solver that stores, for every frequency, the eigen-decomposition of the
/// `r × r` symbol assembled in `Poisson::new`.
pub struct Poisson {
// Size of each per-frequency block (the `r` passed to `new`).
r: usize,
/// Transform engine built with `Engine::new(n, r, l, xl, bc.extension())`.
pub planner: Engine,
// One vector of `r` real eigenvalues per frequency.
eigenvalues: Vec<DVector<f64>>,
// One `r × r` eigenvector matrix per frequency, already pre-multiplied by the inverse
// square root of the diagonal mass weights below.
eigenvectors: Vec<DMatrix<Complex<f64>>>,
// Diagonal mass weights of length `r`, scaled by the cell width; entry 0 carries twice
// the weight of the first reference node.
periodic_mass_matrix: DVector<f64>,
// Boundary condition the solver was built for; read by `solve_in_place` to pick the
// local index offset.
bc: BoundaryCondition,
}
impl Poisson {
pub fn new(
l: f64,
xl: f64,
n: usize,
r: usize,
alpha: f64,
beta: f64,
bc: BoundaryCondition,
) -> Self {
let lobatto_basis = CollocationBasis::new(vec![(r + 1, Gauss::Lobatto)]);
let h = l / (n as f64);
let planner = Engine::new(n, r, l, xl, bc.extension());
let n_freq = match bc {
BoundaryCondition::Periodic => n,
BoundaryCondition::Dirichlet => n + 1,
BoundaryCondition::Neumann => n + 1,
};
let coeff_phase = if bc == BoundaryCondition::Periodic {
2.0
} else {
1.0
};
let mut symbol = vec![DMatrix::<Complex::<f64>>::zeros(r, r); n_freq];
let mass_matrix = lobatto_basis.weights_matrix();
let periodic_mass_matrix = DVector::from_fn(r, |i, _| {
if i == 0 {
h * 2.0 * mass_matrix[(0, 0)]
} else {
h * mass_matrix[(i, i)]
}
});
let diff_matrix = lobatto_basis.diff_matrix(0);
let stiffness_matrix: Matrix<f64, Dyn, Dyn, VecStorage<f64, Dyn, Dyn>> =
diff_matrix.transpose() * mass_matrix.clone() * diff_matrix;
for k in 0..n_freq {
let phase = -coeff_phase * PI * (k as f64) / (n as f64);
let cos = phase.cos();
let sin = phase.sin();
let exp = cos + Complex::<f64>::I * sin;
for i in 1..r {
for j in 1..r {
symbol[k][(i, j)] = (alpha * h * mass_matrix[(i, j)]
+ beta * stiffness_matrix[(i, j)] / h)
.into();
}
}
symbol[k][(0, 0)] = (2.0 * alpha * h * mass_matrix[(0, 0)]
+ 2.0 * beta * stiffness_matrix[(0, 0)] / h
+ 2.0 * beta * stiffness_matrix[(0, r)] * cos / h)
.into();
for i in 1..r {
symbol[k][(0, i)] =
beta * stiffness_matrix[(0, i)] / h + beta * stiffness_matrix[(r, i)] * exp / h;
symbol[k][(i, 0)] = symbol[k][(0, i)].conj();
}
}
let mass_sqrt_inv_c: DMatrix<Complex<f64>> =
DMatrix::from_diagonal(&DVector::from_fn(r, |i, _| {
Complex::new(1.0 / periodic_mass_matrix[i].sqrt(), 0.0)
}));
let eigs: Vec<SymmetricEigen<Complex<f64>, Dyn>> = symbol
.iter()
.map(|s| SymmetricEigen::new(&mass_sqrt_inv_c * s * &mass_sqrt_inv_c))
.collect();
let eigenvalues: Vec<DVector<f64>> = eigs.iter().map(|e| e.eigenvalues.clone()).collect();
let eigenvectors: Vec<DMatrix<Complex<f64>>> = eigs
.iter()
.map(|e| &mass_sqrt_inv_c * e.eigenvectors.clone())
.collect();
Self {
r,
planner,
eigenvalues,
eigenvectors,
periodic_mass_matrix,
bc,
}
}
pub fn get_x(&self) -> Vec<f64> {
self.planner.get_x()
}
pub fn solve_in_place(&self, f: &mut [Complex<f64>]) {
let j_offset = if self.bc == BoundaryCondition::Dirichlet {
1
} else {
0
};
for (idx, v) in f.iter_mut().enumerate() {
*v *= self.periodic_mass_matrix[(idx + j_offset) % self.r];
}
self.planner.forward(f);
let n_freq = self.eigenvalues.len();
let mut col = DVector::<Complex<f64>>::zeros(self.r);
const PINV_TOL: f64 = 1e-12;
for k in 0..n_freq {
self.planner.get_values(k, f, col.as_mut_slice());
let mut v = self.eigenvectors[k].ad_mul(&col);
for j in 0..self.r {
let lam = self.eigenvalues[k][j];
if lam.abs() > PINV_TOL {
v[j] /= lam;
} else {
v[j] = Complex::new(0.0, 0.0);
}
}
col.gemv(
Complex::new(1.0, 0.0),
&self.eigenvectors[k],
&v,
Complex::new(0.0, 0.0),
);
self.planner.set_values(k, f, col.as_slice());
}
self.planner.inverse(f);
}
}
/// `N`-dimensional solver assembled from one `Poisson` eigen-decomposition per axis.
pub struct PoissonND<const N: usize> {
// Per-axis 1-D solvers; `new` builds them with `alpha = 0.0`, `beta = 1.0`, so their
// eigenvalues are combined with this struct's own `alpha`/`beta` at solve time.
solvers: [Poisson; N],
/// `N`-dimensional transform engine built by `EngineND::new`.
pub fft: EngineND<N>,
// Product of the per-axis diagonal mass weights, flattened with axis 0 varying fastest.
periodic_mass_matrix: DVector<f64>,
// Constant added to every combined eigenvalue in `solve_in_place`.
alpha: f64,
// Factor applied to each per-axis eigenvalue in `solve_in_place`.
beta: f64,
// Number of frequencies per axis (length of each solver's eigenvalue list).
n_freqs: [usize; N],
// Product of `n_freqs`.
n_freq_total: usize,
// Strides returned by `nd_col_major_strides(&n_freqs)`; used to unflatten a frequency index.
n_freq_strides: [usize; N],
// Local index shift per axis: 1 for Dirichlet axes, 0 otherwise.
j_offsets: [usize; N],
}
impl<const N: usize> PoissonND<N> {
pub fn new(
ls: [f64; N],
xls: [f64; N],
ns: [usize; N],
rs: [usize; N],
alpha: f64,
beta: f64,
bc: [BoundaryCondition; N],
) -> Self {
let solvers =
std::array::from_fn(|d| Poisson::new(ls[d], xls[d], ns[d], rs[d], 0.0, 1.0, bc[d]));
let periodic_mass_matrix =
(1..N).fold(solvers[0].periodic_mass_matrix.clone(), |acc, d| {
let m_d = &solvers[d].periodic_mass_matrix;
DVector::from_iterator(
acc.len() * m_d.len(),
m_d.iter().flat_map(|&w| acc.iter().map(move |&v| v * w)),
)
});
let extension = std::array::from_fn(|d| bc[d].extension());
let fft = EngineND::new(ns, rs, ls, xls, extension);
let n_freqs: [usize; N] = std::array::from_fn(|d| solvers[d].eigenvalues.len());
let n_freq_total = n_freqs.iter().product();
let n_freq_strides = nd_col_major_strides(&n_freqs);
let j_offsets: [usize; N] = std::array::from_fn(|d| {
if bc[d] == BoundaryCondition::Dirichlet {
1
} else {
0
}
});
Self {
solvers,
fft,
periodic_mass_matrix,
alpha,
beta,
n_freqs,
n_freq_total,
n_freq_strides,
j_offsets,
}
}
pub fn get_x(&self) -> Vec<[f64; N]> {
self.fft.get_x()
}
pub fn solve_in_place(&self, f: &mut [Complex<f64>]) {
let rs = &self.fft.rs;
let r_strides = &self.fft.r_strides;
let r_total = self.fft.r_total;
f.par_iter_mut().enumerate().for_each(|(p, fp)| {
let lj = (0..N).fold(0, |acc, d| {
let dof_d = (p / self.fft.strides[d]) % self.fft.ndofs[d];
acc + ((dof_d + self.j_offsets[d]) % rs[d]) * r_strides[d]
});
*fp *= self.periodic_mass_matrix[lj];
});
self.fft.forward(f);
let ptr = f.as_mut_ptr() as usize;
(0..self.n_freq_total).into_par_iter().for_each(|k_flat| {
let ks: [usize; N] =
std::array::from_fn(|d| (k_flat / self.n_freq_strides[d]) % self.n_freqs[d]);
let mut v = DVector::<Complex<f64>>::zeros(r_total);
{
let f_ro = unsafe {
std::slice::from_raw_parts(ptr as *const Complex<f64>, self.fft.total)
};
self.fft.get_values(&ks, f_ro, v.as_mut_slice());
}
for d in 0..N {
let r_d = rs[d];
let step = r_strides[d] - 1;
for s in 0..r_total / r_d {
let bj = nd_fiber_base(s, d, rs, r_strides);
let x = self.solvers[d].eigenvectors[ks[d]]
.ad_mul(&v.rows_with_step(bj, r_d, step));
v.rows_with_step_mut(bj, r_d, step).set_column(0, &x);
}
}
const PINV_TOL: f64 = 1e-12;
for j_flat in 0..r_total {
let lambda = self.alpha
+ (0..N)
.map(|d| {
let j_d = (j_flat / r_strides[d]) % rs[d];
self.beta * self.solvers[d].eigenvalues[ks[d]][j_d]
})
.sum::<f64>();
if lambda.abs() > PINV_TOL {
v[j_flat] /= lambda;
} else {
v[j_flat] = Complex::new(0.0, 0.0);
}
}
for d in 0..N {
let r_d = rs[d];
let step = r_strides[d] - 1;
for s in 0..r_total / r_d {
let bj = nd_fiber_base(s, d, rs, r_strides);
let x = &self.solvers[d].eigenvectors[ks[d]] * v.rows_with_step(bj, r_d, step);
v.rows_with_step_mut(bj, r_d, step).set_column(0, &x);
}
}
{
let f_rw = unsafe {
std::slice::from_raw_parts_mut(ptr as *mut Complex<f64>, self.fft.total)
};
self.fft.set_values(&ks, f_rw, v.as_slice());
}
});
self.fft.inverse(f);
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use nalgebra::{DMatrix, DVector};

    /// For periodic solvers over a grid of `(r, n)`, every stored eigenvector block `V` must
    /// satisfy `V† M V = I` and `V V† = M⁻¹`, and must survive a pseudo-inverse round trip.
    #[test]
    fn test_symbol_eigendecomposition() {
        const ORTHO_TOL: f64 = 1e-10;
        const RECON_TOL: f64 = 1e-8;
        for &r in &[1, 2, 3, 4, 5, 6] {
            for &n in &[4, 16, 32, 64, 128] {
                let solver = Poisson::new(1.0, 0.0, n, r, 1.0, 1.0, BoundaryCondition::Periodic);
                let id = DMatrix::<Complex<f64>>::identity(r, r);
                let m_c = DMatrix::from_diagonal(&DVector::from_fn(r, |i, _| {
                    Complex::new(solver.periodic_mass_matrix[i], 0.0)
                }));
                let m_inv_c = DMatrix::from_diagonal(&DVector::from_fn(r, |i, _| {
                    Complex::new(1.0 / solver.periodic_mass_matrix[i], 0.0)
                }));
                for k in 0..n {
                    let v = &solver.eigenvectors[k];

                    // Mass-orthonormality of the eigenvectors.
                    let vtmv = v.adjoint() * &m_c * v;
                    let ortho_err = (&vtmv - &id).norm();
                    assert!(
                        ortho_err < ORTHO_TOL,
                        "r={r} n={n} k={k}: eigenvectors not M-orthonormal (err={ortho_err:.2e})"
                    );

                    // Completeness: V V† reproduces the inverse mass matrix.
                    let vvt = v * v.adjoint();
                    let vvt_err = (&vvt - &m_inv_c).norm();
                    assert!(
                        vvt_err < ORTHO_TOL,
                        "r={r} n={n} k={k}: V V† ≠ M⁻¹ (err={vvt_err:.2e})"
                    );

                    // Applying the symbol and then its pseudo-inverse must return V.
                    let lambda_diag = DMatrix::from_diagonal(&DVector::from_fn(r, |i, _| {
                        Complex::new(solver.eigenvalues[k][i], 0.0)
                    }));
                    let lambda_inv = DMatrix::from_diagonal(&DVector::from_fn(r, |i, _| {
                        let lam = solver.eigenvalues[k][i];
                        Complex::new(if lam.abs() > 1e-12 { 1.0 / lam } else { 0.0 }, 0.0)
                    }));
                    let rhs = v * &lambda_inv * v.adjoint() * &m_c * v * &lambda_diag;
                    let recon_err = (v - &rhs).norm() / v.norm().max(1.0);
                    assert!(
                        recon_err < RECON_TOL,
                        "r={r} n={n} k={k}: eigendecomposition round-trip failed (err={recon_err:.2e})"
                    );
                }
            }
        }
    }

    /// The Kronecker product of the per-axis eigenvectors must reproduce the inverse of the
    /// tensorised diagonal mass stored in `PoissonND`.
    ///
    /// `PoissonND::new` flattens the mass weights with axis 0 varying fastest, so the
    /// eigenvector product is built as `V_{N-1} ⊗ … ⊗ V_0` (the right-hand factor of
    /// `kron` is the fastest index).
    #[test]
    fn test_nd_mass_matrix_inverse_tensorial() {
        const TOL: f64 = 1e-12;
        let kron = |a: &DMatrix<Complex<f64>>, b: &DMatrix<Complex<f64>>| {
            let (ra, ca) = (a.nrows(), a.ncols());
            let (rb, cb) = (b.nrows(), b.ncols());
            DMatrix::from_fn(ra * rb, ca * cb, |i, j| {
                a[(i / rb, j / cb)] * b[(i % rb, j % cb)]
            })
        };
        // Entry (i, j) of the diagonal matrix diag(1 / mass).
        let inverse_mass_entry = |mass: &DVector<f64>, i: usize, j: usize| {
            if i == j {
                Complex::new(1.0 / mass[i], 0.0)
            } else {
                Complex::new(0.0, 0.0)
            }
        };
        {
            let solver = PoissonND::<1>::new(
                [1.0],
                [0.0],
                [4],
                [3],
                1.0,
                0.0,
                [BoundaryCondition::Periodic; 1],
            );
            let v = &solver.solvers[0].eigenvectors[0];
            let vvt = v * v.adjoint();
            let r = v.nrows();
            for i in 0..r {
                for j in 0..r {
                    let expected =
                        inverse_mass_entry(&solver.solvers[0].periodic_mass_matrix, i, j);
                    assert!((vvt[(i, j)] - expected).norm() < TOL, "1D [{i},{j}]");
                }
            }
        }
        {
            let solver = PoissonND::<2>::new(
                [1.0; 2],
                [0.0; 2],
                [4; 2],
                [2; 2],
                1.0,
                0.0,
                [BoundaryCondition::Periodic; 2],
            );
            // Axis 0 is the fastest index, so it is the right-hand Kronecker factor.
            let v_nd = kron(
                &solver.solvers[1].eigenvectors[0],
                &solver.solvers[0].eigenvectors[0],
            );
            let vvt = &v_nd * v_nd.adjoint();
            let r_total = v_nd.nrows();
            for i in 0..r_total {
                for j in 0..r_total {
                    let expected = inverse_mass_entry(&solver.periodic_mass_matrix, i, j);
                    assert!((vvt[(i, j)] - expected).norm() < TOL, "2D [{i},{j}]");
                }
            }
        }
        {
            let solver = PoissonND::<3>::new(
                [1.0; 3],
                [0.0; 3],
                [4; 3],
                [2; 3],
                1.0,
                0.0,
                [BoundaryCondition::Periodic; 3],
            );
            let v_nd = kron(
                &solver.solvers[2].eigenvectors[0],
                &kron(
                    &solver.solvers[1].eigenvectors[0],
                    &solver.solvers[0].eigenvectors[0],
                ),
            );
            let vvt = &v_nd * v_nd.adjoint();
            let r_total = v_nd.nrows();
            for i in 0..r_total {
                for j in 0..r_total {
                    let expected = inverse_mass_entry(&solver.periodic_mass_matrix, i, j);
                    // Diagonal entries grow with the dimension, so compare relatively.
                    let scale = expected.norm().max(1.0_f64);
                    assert!(
                        (vvt[(i, j)] - expected).norm() / scale < TOL,
                        "3D [{i},{j}]"
                    );
                }
            }
        }
    }
}