use lobatto::collocation::{CollocationBasis, Gauss};
use lobatto::utilities::{barycentric_weights, lagrangian_interpolation};
use rayon::prelude::*;
use rustfft::{num_complex::Complex, Fft, FftPlanner};
use std::collections::HashMap;
use std::f64::consts::PI;
use std::sync::Arc;
/// Boundary treatment selected for one axis of an [`Engine`].
///
/// The variant decides the FFT length and the number of stored degrees of
/// freedom computed in `Engine::new`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Extension {
/// FFT length `n`, `r * n` degrees of freedom; `Engine::eval` wraps `x` into the domain.
Periodic,
/// FFT length `2 * n`, `r * n - 1` degrees of freedom; `Engine::forward` fills the
/// second half of each row with negated, mirrored samples.
Odd,
/// FFT length `2 * n`, `r * n + 1` degrees of freedom; `Engine::forward` fills the
/// second half of each row with mirrored samples (no sign change).
Even,
}
/// FFT wrapper that works directly on a buffer of `r * n` complex values,
/// without any packing or boundary extension.
pub struct RawEngine {
/// FFT length the plans are built for; also the cell count used by `get_x` (`h = l / n`).
n: usize,
/// Number of rows; buffers passed to `forward`/`inverse` are expected to hold `r * n` values.
r: usize,
/// Domain length used to derive the cell width in `get_x`.
l: f64,
/// Offset added to every coordinate returned by `get_x`.
xl: f64,
/// Points of the `(r + 1)`-point Gauss-Lobatto basis obtained from `CollocationBasis`.
gauss_lobatto_points: Vec<f64>,
/// Forward FFT plan of length `n`.
fw_fft: Arc<dyn Fft<f64>>,
/// Inverse FFT plan of length `n` (results are scaled by `1 / n` in `inverse`).
i_fft: Arc<dyn Fft<f64>>,
}
impl RawEngine {
/// Creates a raw engine: plans forward and inverse FFTs of length `n` and
/// stores the points of an `(r + 1)`-point Gauss-Lobatto basis.
///
/// `l` and `xl` are kept as given and only used by `get_x`.
pub fn new(n: usize, r: usize, l: f64, xl: f64) -> Self {
    let mut fft_planner = FftPlanner::<f64>::new();
    let forward_plan = fft_planner.plan_fft_forward(n);
    let inverse_plan = fft_planner.plan_fft_inverse(n);

    let basis = CollocationBasis::new(vec![(r + 1, Gauss::Lobatto)]);
    let nodes = basis.points_1d(0).to_vec();

    Self {
        n,
        r,
        l,
        xl,
        gauss_lobatto_points: nodes,
        fw_fft: forward_plan,
        i_fft: inverse_plan,
    }
}
/// Returns the `r * n` sample coordinates, one row of `n` values per basis
/// point: entry `j * n + k` is `h * (k + point[j]) + xl` with `h = l / n`.
pub fn get_x(&self) -> Vec<f64> {
    let cell_width = self.l / (self.n as f64);
    let mut coords = Vec::with_capacity(self.n * self.r);
    // Rows are emitted in order, so pushing reproduces the `j * n + k` layout.
    for j in 0..self.r {
        coords.extend(
            (0..self.n)
                .map(|k| cell_width * ((k as f64) + self.gauss_lobatto_points[j]) + self.xl),
        );
    }
    coords
}
/// Applies the forward FFT plan to `g` in place.
///
/// `g` is expected to hold `r * n` values (checked in debug builds only).
pub fn forward(&self, g: &mut [Complex<f64>]) {
    let expected_len = self.r * self.n;
    debug_assert_eq!(g.len(), expected_len);
    let plan = &self.fw_fft;
    plan.process(g);
}
/// Applies the inverse FFT plan to `g` in place and divides every entry by `n`.
///
/// `g` is expected to hold `r * n` values (checked in debug builds only).
pub fn inverse(&self, g: &mut [Complex<f64>]) {
    let expected_len = self.r * self.n;
    debug_assert_eq!(g.len(), expected_len);
    let scale = self.n as f64;
    self.i_fft.process(g);
    // The inverse plan is unnormalised, hence the explicit division.
    g.iter_mut().for_each(|value| *value /= scale);
}
}
/// One-dimensional transform engine with a selectable boundary [`Extension`].
///
/// Physical data is stored with `n_dof` entries; the FFT itself runs on an
/// internal `r * n_fft` buffer, and the tables below translate between the two.
pub struct Engine {
/// Number of cells (`h = l / n`).
n: usize,
/// FFT length: `n` for `Periodic`, `2 * n` for `Odd`/`Even`.
n_fft: usize,
/// Number of stored values: `r * n`, `r * n - 1` or `r * n + 1` depending on the extension.
n_dof: usize,
/// Number of rows of the internal FFT buffer; the basis has `r + 1` points.
r: usize,
/// Domain length.
l: f64,
/// Left end of the domain.
xl: f64,
/// Boundary treatment chosen at construction.
extension: Extension,
/// Points of the `(r + 1)`-point Gauss-Lobatto basis obtained from `CollocationBasis`.
gauss_lobatto_points: Vec<f64>,
/// Barycentric weights computed from `gauss_lobatto_points`, passed to `lagrangian_interpolation`.
bary_weights: Vec<f64>,
/// Forward FFT plan of length `n_fft`.
fw_fft: Arc<dyn Fft<f64>>,
/// Inverse FFT plan of length `n_fft`.
i_fft: Arc<dyn Fft<f64>>,
/// `k_unpacking(i)` for every index `i` of the full `r * n_fft` buffer.
unpack_table: Vec<(usize, Complex<f64>)>,
/// For every full index `i`: whether `k_packing` of its unpacked index maps back to `i`.
canonical_table: Vec<bool>,
/// `k_packing(i)` for every stored index; left empty for `Periodic`.
packing_table: Vec<usize>,
}
impl Engine {
pub fn new(n: usize, r: usize, l: f64, xl: f64, extension: Extension) -> Self {
let n_fft = match extension {
Extension::Periodic => n,
Extension::Odd | Extension::Even => 2 * n,
};
let n_dof = match extension {
Extension::Periodic => r * n,
Extension::Odd => r * n - 1,
Extension::Even => r * n + 1,
};
let mut fft_planner = FftPlanner::<f64>::new();
let fw_fft = fft_planner.plan_fft_forward(n_fft);
let i_fft = fft_planner.plan_fft_inverse(n_fft);
let gauss_lobatto_points = CollocationBasis::new(vec![(r + 1, Gauss::Lobatto)])
.points_1d(0)
.to_vec();
let bary_weights: Vec<f64> = barycentric_weights(&gauss_lobatto_points);
let mut s = Self {
n,
n_fft,
n_dof,
r,
l,
xl,
extension,
gauss_lobatto_points,
bary_weights,
fw_fft,
i_fft,
unpack_table: vec![],
canonical_table: vec![],
packing_table: vec![],
};
let table_size = s.r * s.n_fft;
s.unpack_table = (0..table_size).map(|i| s.k_unpacking(i)).collect();
s.canonical_table = (0..table_size)
.map(|i| {
let (c, _) = s.k_unpacking(i);
s.k_packing(c) == i
})
.collect();
if extension != Extension::Periodic {
s.packing_table = (0..n_dof).map(|i| s.k_packing(i)).collect();
}
s
}
/// Reads the `r` values belonging to column `k` of the full `r * n_fft`
/// table out of the packed array `g` and writes them to `g_k`.
///
/// Entry `j` is `coeff * g[index]`, where `(index, coeff)` is the unpacking
/// table entry for full index `j * n_fft + k`.
///
/// # Panics
///
/// Panics if `k >= n_fft`. Without this check an out-of-range `k` would
/// silently read the table entries of the following row.
pub fn get_values(&self, k: usize, g: &[Complex<f64>], g_k: &mut [Complex<f64>]) {
    assert!(
        k < self.n_fft,
        "get_values: k={k} outside 0..{}",
        self.n_fft
    );
    debug_assert_eq!(g.len(), self.n_dof);
    debug_assert_eq!(g_k.len(), self.r);
    for j in 0..self.r {
        let (i_g, coeff) = self.unpack_table[j * self.n_fft + k];
        g_k[j] = coeff * g[i_g];
    }
}
/// Writes the values `g_k` of column `k` into the packed array `g`.
///
/// Only the entries that are actually stored for the current extension are
/// written; for the edge columns `k == 0` and `k == n` of `Odd`/`Even` that
/// is a subset of the rows.
///
/// # Panics
///
/// Panics if `k > n`, or if `k == n` for `Periodic`.
pub fn set_values(&self, k: usize, g: &mut [Complex<f64>], g_k: &[Complex<f64>]) {
    assert!(k <= self.n);
    debug_assert_eq!(g.len(), self.n_dof);
    debug_assert_eq!(g_k.len(), self.r);

    let is_interior = (1..self.n).contains(&k);
    match self.extension {
        Extension::Periodic => {
            assert!(k < self.n);
            for j in 0..self.r {
                g[j * self.n + k] = g_k[j];
            }
        }
        Extension::Odd => {
            if is_interior {
                let row_len = self.n - 1;
                for j in 0..self.r {
                    g[j * row_len + (k - 1)] = g_k[j];
                }
            } else {
                // Edge columns live after the interior block: first the k == 0
                // segment, then the k == n segment. Row 0 is never stored.
                let interior_len = self.r * (self.n - 1);
                let (base, last_row) = if k == 0 {
                    (interior_len, (self.r - 1) / 2)
                } else {
                    (interior_len + (self.r - 1) / 2, self.r / 2)
                };
                for j in 1..=last_row {
                    g[base + (j - 1)] = g_k[j];
                }
            }
        }
        Extension::Even => {
            if is_interior {
                let row_len = self.n - 1;
                for j in 0..self.r {
                    g[j * row_len + (k - 1)] = g_k[j];
                }
            } else {
                // Same layout as the odd case, but row 0 is stored as well.
                let interior_len = self.r * (self.n - 1);
                let (base, last_row) = if k == 0 {
                    (interior_len, self.r / 2)
                } else {
                    (interior_len + self.r / 2 + 1, (self.r - 1) / 2)
                };
                for j in 0..=last_row {
                    g[base + j] = g_k[j];
                }
            }
        }
    }
}
/// Maps a packed (stored) index `i` to its index `row * n_fft + column` in
/// the full `r * n_fft` table.
///
/// For `Odd`/`Even` the packed array consists of the interior columns
/// `1..n` of every row, followed by the stored rows of column `0`, followed
/// by the stored rows of column `n`.
fn k_packing(&self, i: usize) -> usize {
    // (number of stored rows in column 0, first stored row of the edge columns)
    let (zero_edge_len, first_row) = match self.extension {
        Extension::Periodic => return i,
        Extension::Odd => {
            let interior_len = self.r * (self.n - 1);
            return self.pack_extended(i, interior_len, (self.r - 1) / 2, 1);
        }
        Extension::Even => (self.r / 2 + 1, 0),
    };
    let interior_len = self.r * (self.n - 1);
    let _ = (zero_edge_len, first_row);
    let pack = |edge_len: usize, row0: usize| -> usize {
        if i < interior_len {
            let cols = self.n - 1;
            (i / cols) * self.n_fft + (i % cols + 1)
        } else if i < interior_len + edge_len {
            (i - interior_len + row0) * self.n_fft
        } else {
            (i - interior_len - edge_len + row0) * self.n_fft + self.n
        }
    };
    pack(zero_edge_len, first_row)
}

/// Shared index arithmetic of [`Self::k_packing`] for the extended layouts.
fn pack_extended(
    &self,
    i: usize,
    interior_len: usize,
    zero_edge_len: usize,
    first_row: usize,
) -> usize {
    if i < interior_len {
        let cols = self.n - 1;
        (i / cols) * self.n_fft + (i % cols + 1)
    } else if i < interior_len + zero_edge_len {
        (i - interior_len + first_row) * self.n_fft
    } else {
        (i - interior_len - zero_edge_len + first_row) * self.n_fft + self.n
    }
}
/// Maps an index `i` of the full `r * n_fft` table to `(packed_index, coeff)`
/// such that the full-table value is `coeff * packed[packed_index]`.
///
/// Entries that have no stored counterpart are returned as `(0, 0)`: the
/// index is a placeholder and the zero coefficient cancels whatever is read.
fn k_unpacking(&self, i: usize) -> (usize, Complex<f64>) {
let one = Complex::new(1.0_f64, 0.0);
let zero_coeff = Complex::new(0.0_f64, 0.0);
let minus_one = Complex::new(-1.0_f64, 0.0);
match self.extension {
// Periodic data is stored as the full table itself.
Extension::Periodic => (i, one),
Extension::Odd => {
// Row j and column k of the full table.
let j = i / self.n_fft;
let k = i % self.n_fft;
// c1: length of the interior block; c2: length of the stored k == 0 segment.
let c1 = self.r * (self.n - 1);
let c2 = (self.r - 1) / 2;
if k >= 1 && k < self.n {
// Interior columns are stored directly, row after row.
(j * (self.n - 1) + (k - 1), one)
} else if k == 0 {
// Row 0 and, for even r, row r / 2 are not stored: zero coefficient.
if j == 0 {
return (0, zero_coeff);
} if self.r % 2 == 0 && j == self.r / 2 {
return (0, zero_coeff);
}
// Rows up to (r - 1) / 2 are stored; the others reuse row r - j with coefficient -1.
if j <= (self.r - 1) / 2 {
(c1 + (j - 1), one)
} else {
(c1 + (self.r - j - 1), minus_one)
}
} else if k == self.n {
// Row 0 is not stored; rows j and r - j share one stored entry with coefficient +1.
if j == 0 {
return (0, zero_coeff);
} let j_stored = j.min(self.r - j);
(c1 + c2 + (j_stored - 1), one)
} else {
// Columns above n are taken from the interior column km = n_fft - k.
let km = self.n_fft - k; if j == 0 {
// Row 0 reuses its own interior entry with coefficient -1.
let packed = km - 1; (packed, minus_one)
} else {
// Other rows reuse row r - j, multiplied by -exp(-i * pi * km / n).
let phase =
Complex::new(0.0_f64, -PI * (km as f64) / (self.n as f64)).exp();
let packed = (self.r - j) * (self.n - 1) + (km - 1);
(packed, -phase)
}
}
}
Extension::Even => {
// Row j and column k of the full table.
let j = i / self.n_fft;
let k = i % self.n_fft;
// c1: length of the interior block; c2: length of the stored k == 0 segment.
let c1 = self.r * (self.n - 1);
let c2 = self.r / 2 + 1;
if k >= 1 && k < self.n {
// Interior columns are stored directly, row after row.
(j * (self.n - 1) + (k - 1), one)
} else if k == 0 {
// Rows j and r - j share one stored entry with coefficient +1.
let j_stored = j.min(self.r - j);
(c1 + j_stored, one)
} else if k == self.n {
// For even r, row r / 2 is not stored: zero coefficient.
if self.r % 2 == 0 && j == self.r / 2 {
return (0, zero_coeff);
} if j <= (self.r - 1) / 2 {
// Rows up to (r - 1) / 2 are stored; the others reuse row r - j with coefficient -1.
(c1 + c2 + j, one)
} else {
(c1 + c2 + (self.r - j), minus_one)
}
} else {
// Columns above n are taken from the interior column km = n_fft - k.
let km = self.n_fft - k;
if j == 0 {
// Row 0 reuses its own interior entry unchanged.
let packed = km - 1; (packed, one)
} else {
// Other rows reuse row r - j, multiplied by exp(-i * pi * km / n).
let phase =
Complex::new(0.0_f64, -PI * (km as f64) / (self.n as f64)).exp();
let packed = (self.r - j) * (self.n - 1) + (km - 1);
(packed, phase)
}
}
}
}
}
/// Returns the coordinates of the `n_dof` stored nodes, cell by cell.
///
/// * `Periodic`: points `0..r` of every cell.
/// * `Odd`: points `1..r` of the first cell, then points `0..r` of the others.
/// * `Even`: points `0..=r` of the first cell, then points `1..=r` of the others.
pub fn get_x(&self) -> Vec<f64> {
    let h = self.l / (self.n as f64);
    let points = &self.gauss_lobatto_points;
    // Coordinate of point `j` in cell `k`.
    let node = |k: usize, j: usize| self.xl + h * ((k as f64) + points[j]);

    let mut x = Vec::with_capacity(self.n_dof);
    match self.extension {
        Extension::Periodic => {
            for k in 0..self.n {
                x.extend((0..self.r).map(|j| node(k, j)));
            }
        }
        Extension::Odd => {
            x.extend((1..self.r).map(|j| self.xl + h * points[j]));
            for k in 1..self.n {
                x.extend((0..self.r).map(|j| node(k, j)));
            }
        }
        Extension::Even => {
            x.extend((0..=self.r).map(|j| self.xl + h * points[j]));
            for k in 1..self.n {
                x.extend((1..=self.r).map(|j| node(k, j)));
            }
        }
    }
    x
}
/// Evaluates the piecewise Lagrange interpolant of the nodal values `u` at `x`.
///
/// For `Periodic`, `x` is first wrapped into the domain. The cell containing
/// the point is located, the `r + 1` basis functions are evaluated at the
/// local coordinate, and the weighted nodal values are summed.
///
/// # Panics
///
/// For `Odd`/`Even`, panics if `x` lies outside `[xl, xl + l]`.
pub fn eval(&self, x: f64, u: &[Complex<f64>]) -> Complex<f64> {
    let y = if self.extension == Extension::Periodic {
        (x - self.xl).rem_euclid(self.l) + self.xl
    } else {
        assert!(
            x >= self.xl && x <= self.xl + self.l,
            "eval: x={x} outside [{}, {}]",
            self.xl,
            self.xl + self.l
        );
        x
    };

    let h = self.l / (self.n as f64);
    let offset = y - self.xl;
    // Clamp so that the right end of the domain falls into the last cell.
    let k = ((offset / h) as usize).min(self.n - 1);
    let xi = (offset - (k as f64) * h) / h;
    let phi = lagrangian_interpolation(&self.gauss_lobatto_points, &self.bary_weights, xi);

    // Nodal value attached to local point `j` of cell `k`.
    let nodal = |j: usize| -> Complex<f64> {
        match self.extension {
            Extension::Periodic => {
                if j < self.r {
                    u[k * self.r + j]
                } else {
                    // The last point of a cell is the first point of the next one (wrapping).
                    u[((k + 1) % self.n) * self.r]
                }
            }
            Extension::Odd => {
                // The two domain end points are not stored and count as zero.
                if (k == 0 && j == 0) || (k == self.n - 1 && j == self.r) {
                    Complex::default()
                } else {
                    u[k * self.r + j - 1]
                }
            }
            Extension::Even => u[k * self.r + j],
        }
    };

    (0..=self.r).fold(Complex::new(0.0_f64, 0.0), |acc, j| acc + nodal(j) * phi[j])
}
/// Evaluates the interpolant of `u` at `lambda(x)` for every node `x`
/// returned by `get_x`, in the same order.
pub fn convect<F>(&self, lambda: F, u: &[Complex<f64>]) -> Vec<Complex<f64>>
where
    F: Fn(f64) -> f64,
{
    let nodes = self.get_x();
    let mut moved = Vec::with_capacity(nodes.len());
    for x in nodes {
        moved.push(self.eval(lambda(x), u));
    }
    moved
}
/// Returns the index into the nodal array of local point `j` (in `0..=r`)
/// of cell `k`.
///
/// For `Odd`, the two domain end points are not stored; `n_dof` is returned
/// as a sentinel for them.
pub(crate) fn eval_dof_index(&self, k: usize, j: usize) -> usize {
    let first_of_cell = k * self.r;
    match self.extension {
        // The last local point is shared with the next cell (wrapping around).
        Extension::Periodic if j >= self.r => ((k + 1) % self.n) * self.r,
        Extension::Periodic => first_of_cell + j,
        Extension::Odd => {
            let left_end = k == 0 && j == 0;
            if left_end || (k == self.n - 1 && j == self.r) {
                return self.n_dof;
            }
            first_of_cell + j - 1
        }
        Extension::Even => first_of_cell + j,
    }
}
/// Transforms nodal values `f` (length `n_dof`) in place into packed
/// spectral coefficients.
///
/// The nodal values are gathered into an `r * n_fft` buffer (one row per
/// local point, extended by mirroring for `Odd`/`Even`), the forward FFT is
/// applied, and the stored coefficients are copied back into `f`.
pub fn forward(&self, f: &mut [Complex<f64>]) {
    debug_assert_eq!(f.len(), self.n_dof);
    let (n, r, n_fft) = (self.n, self.r, self.n_fft);
    let zero = Complex::<f64>::default();
    let mut buf = vec![zero; r * n_fft];

    match self.extension {
        Extension::Periodic => {
            for k in 0..n {
                for j in 0..r {
                    buf[j * n_fft + k] = f[k * r + j];
                }
            }
        }
        Extension::Odd => {
            let sample = |j: usize, k: usize| -> Complex<f64> {
                if j == 0 && (k == 0 || k == n) {
                    // Domain end points: not stored, zero by antisymmetry.
                    zero
                } else if k < n {
                    f[k * r + j - 1]
                } else {
                    // Mirrored node, with the sign flipped.
                    -f[(n_fft - k) * r - j - 1]
                }
            };
            for j in 0..r {
                for k in 0..n_fft {
                    buf[j * n_fft + k] = sample(j, k);
                }
            }
        }
        Extension::Even => {
            let sample = |j: usize, k: usize| -> Complex<f64> {
                if k < n {
                    f[k * r + j]
                } else {
                    // Mirrored node, same sign.
                    f[(n_fft - k) * r - j]
                }
            };
            for j in 0..r {
                for k in 0..n_fft {
                    buf[j * n_fft + k] = sample(j, k);
                }
            }
        }
    }

    self.fw_fft.process(&mut buf);

    match self.extension {
        Extension::Periodic => {
            for j in 0..r {
                let dst = j * n;
                let src = j * n_fft;
                f[dst..dst + n].copy_from_slice(&buf[src..src + n]);
            }
        }
        Extension::Odd | Extension::Even => {
            // Keep only the stored coefficients.
            for i in 0..self.n_dof {
                let full_index = self.packing_table[i];
                f[i] = buf[full_index];
            }
        }
    }
}
/// Transforms packed spectral coefficients `g` (length `n_dof`) in place
/// back into nodal values.
///
/// The full `r * n_fft` table is rebuilt from the packed values, the inverse
/// FFT is applied, and the first `n` columns are scattered back into `g`,
/// each divided by `n_fft`.
pub fn inverse(&self, g: &mut [Complex<f64>]) {
    debug_assert_eq!(g.len(), self.n_dof);
    let (n, r, n_fft) = (self.n, self.r, self.n_fft);
    let mut buf = vec![Complex::<f64>::default(); r * n_fft];
    let norm = n_fft as f64;

    if self.extension == Extension::Periodic {
        buf.copy_from_slice(g);
    } else {
        // The table has exactly one entry per buffer slot.
        for (slot, &(packed, coeff)) in buf.iter_mut().zip(&self.unpack_table) {
            *slot = coeff * g[packed];
        }
    }

    self.i_fft.process(&mut buf);

    match self.extension {
        Extension::Periodic => {
            for k in 0..n {
                for j in 0..r {
                    g[k * r + j] = buf[j * n + k] / norm;
                }
            }
        }
        Extension::Odd => {
            for j in 0..r {
                for k in 0..n {
                    // The left end point is not stored.
                    if (j, k) != (0, 0) {
                        g[k * r + j - 1] = buf[j * n_fft + k] / norm;
                    }
                }
            }
        }
        Extension::Even => {
            for j in 0..r {
                for k in 0..n {
                    g[k * r + j] = buf[j * n_fft + k] / norm;
                }
            }
            // Right end point: row 0, column n of the full table.
            g[n * r] = buf[n] / norm;
        }
    }
}
}
/// `N`-dimensional engine built as a tensor product of one [`Engine`] per axis.
///
/// All flat indices use column-major strides (axis 0 varies fastest), as
/// produced by `nd_col_major_strides`.
pub struct EngineND<const N: usize> {
/// One 1D engine per axis.
planners: [Engine; N],
/// Number of stored values per axis (`n_dof` of each 1D engine).
pub(crate) ndofs: [usize; N],
/// Product of `ndofs`: length of the packed arrays.
pub(crate) total: usize,
/// Column-major strides over `ndofs`.
pub(crate) strides: [usize; N],
/// `r` of each 1D engine.
pub(crate) rs: [usize; N],
/// Column-major strides over `rs`.
pub(crate) r_strides: [usize; N],
/// Product of `rs`: length of the `g_k` blocks used by `get_values`/`set_values`.
pub(crate) r_total: usize,
/// `r * n_fft` of each 1D engine (size of its full table).
full_sizes: [usize; N],
/// Column-major strides over `full_sizes`.
full_strides: [usize; N],
}
impl<const N: usize> EngineND<N> {
/// Builds one [`Engine`] per axis from the per-axis parameters and derives
/// the sizes and column-major strides used for flat indexing.
pub fn new(
    ns: [usize; N],
    rs: [usize; N],
    ls: [f64; N],
    xls: [f64; N],
    exts: [Extension; N],
) -> Self {
    let planners: [Engine; N] =
        std::array::from_fn(|axis| Engine::new(ns[axis], rs[axis], ls[axis], xls[axis], exts[axis]));

    // Packed layout.
    let ndofs: [usize; N] = std::array::from_fn(|axis| planners[axis].n_dof);
    let strides = nd_col_major_strides(&ndofs);
    let total = ndofs.iter().product();

    // Layout of one block of local values.
    let rs_arr: [usize; N] = std::array::from_fn(|axis| planners[axis].r);
    let r_strides = nd_col_major_strides(&rs_arr);
    let r_total = rs_arr.iter().product();

    // Layout of the full (unpacked) table.
    let full_sizes: [usize; N] =
        std::array::from_fn(|axis| planners[axis].r * planners[axis].n_fft);
    let full_strides = nd_col_major_strides(&full_sizes);

    Self {
        planners,
        ndofs,
        total,
        strides,
        rs: rs_arr,
        r_strides,
        r_total,
        full_sizes,
        full_strides,
    }
}
/// Returns the coordinates of all `total` stored nodes as the tensor product
/// of the per-axis node lists, in packed (column-major) order.
pub fn get_x(&self) -> Vec<[f64; N]> {
    let axis_nodes: [Vec<f64>; N] = std::array::from_fn(|d| self.planners[d].get_x());
    let mut points = Vec::with_capacity(self.total);
    for p in 0..self.total {
        // Decode the flat index into one node index per axis.
        let point: [f64; N] =
            std::array::from_fn(|d| axis_nodes[d][(p / self.strides[d]) % self.ndofs[d]]);
        points.push(point);
    }
    points
}
/// Maps a packed flat index `i` to the corresponding flat index of the full
/// table by applying the 1D packing map on every axis.
pub fn k_packing(&self, i: usize) -> usize {
    let mut full_index = 0usize;
    for d in 0..N {
        let packed_d = (i / self.strides[d]) % self.ndofs[d];
        full_index += self.planners[d].k_packing(packed_d) * self.full_strides[d];
    }
    full_index
}
/// Maps a flat index `i` of the full table to `(packed_index, coeff)`: the
/// packed indices of the axes are combined with `strides`, and the per-axis
/// coefficients are multiplied together.
pub fn k_unpacking(&self, i: usize) -> (usize, Complex<f64>) {
    (0..N).fold(
        (0usize, Complex::new(1.0_f64, 0.0)),
        |(packed, weight), d| {
            let full_d = (i / self.full_strides[d]) % self.full_sizes[d];
            let (packed_d, weight_d) = self.planners[d].unpack_table[full_d];
            (packed + packed_d * self.strides[d], weight * weight_d)
        },
    )
}
/// Reads the block of `r_total` values belonging to the column multi-index
/// `ks` out of the packed array `g` and writes it to `g_k`.
///
/// Each entry is the packed value at the combined per-axis index, multiplied
/// by the product of the per-axis unpacking coefficients.
pub fn get_values(&self, ks: &[usize; N], g: &[Complex<f64>], g_k: &mut [Complex<f64>]) {
    debug_assert_eq!(g.len(), self.total);
    debug_assert_eq!(g_k.len(), self.r_total);
    for j_flat in 0..self.r_total {
        let (packed, weight) = (0..N).fold(
            (0usize, Complex::new(1.0_f64, 0.0)),
            |(packed, weight), d| {
                let engine = &self.planners[d];
                let row = (j_flat / self.r_strides[d]) % self.rs[d];
                let (packed_d, weight_d) = engine.unpack_table[row * engine.n_fft + ks[d]];
                (packed + packed_d * self.strides[d], weight * weight_d)
            },
        );
        g_k[j_flat] = weight * g[packed];
    }
}
/// Writes the block `g_k` belonging to the column multi-index `ks` into the
/// packed array `g`.
///
/// An entry is written only if it is canonical on every axis (see
/// `Engine::canonical_table`); all other entries of `g_k` are ignored.
pub fn set_values(&self, ks: &[usize; N], g: &mut [Complex<f64>], g_k: &[Complex<f64>]) {
    debug_assert_eq!(g.len(), self.total);
    debug_assert_eq!(g_k.len(), self.r_total);
    for j_flat in 0..self.r_total {
        // `None` as soon as one axis is not canonical.
        let target = (0..N).try_fold(0usize, |packed, d| {
            let engine = &self.planners[d];
            let row = (j_flat / self.r_strides[d]) % self.rs[d];
            let full_1d = row * engine.n_fft + ks[d];
            let (packed_d, _) = engine.unpack_table[full_1d];
            engine.canonical_table[full_1d].then(|| packed + packed_d * self.strides[d])
        });
        if let Some(packed) = target {
            g[packed] = g_k[j_flat];
        }
    }
}
pub fn eval(&self, xs: &[[f64; N]], u: &[Complex<f64>], fu: &mut [Complex<f64>]) {
debug_assert_eq!(u.len(), self.total);
debug_assert_eq!(fu.len(), xs.len());
let ns: [usize; N] = std::array::from_fn(|d| self.planners[d].n);
let rs: [usize; N] = std::array::from_fn(|d| self.planners[d].r);
let ls: [f64; N] = std::array::from_fn(|d| self.planners[d].l);
let xls: [f64; N] = std::array::from_fn(|d| self.planners[d].xl);
let hs: [f64; N] = std::array::from_fn(|d| ls[d] / ns[d] as f64);
let counts: [usize; N] = std::array::from_fn(|d| rs[d] + 1);
let count_strides = nd_col_major_strides(&counts);
let local_total: usize = counts.iter().product();
let n_strides = nd_col_major_strides(&ns);
let elem_indices: Vec<usize> = xs
.par_iter()
.map(|x| {
(0..N).fold(0, |acc, d| {
let y = (x[d] - xls[d]).rem_euclid(ls[d]);
let k_d = ((y / hs[d]) as usize).min(ns[d] - 1);
acc + k_d * n_strides[d]
})
})
.collect();
let mut groups: HashMap<usize, Vec<usize>> = HashMap::new();
for (i, &k_flat) in elem_indices.iter().enumerate() {
groups.entry(k_flat).or_default().push(i);
}
let partial: Vec<Vec<(usize, Complex<f64>)>> = groups
.into_par_iter()
.map(|(k_flat, indices)| {
let ks: [usize; N] = std::array::from_fn(|d| (k_flat / n_strides[d]) % ns[d]);
let u_local: Vec<Complex<f64>> = (0..local_total)
.map(|p| {
let mut global = 0usize;
let mut is_zero = false;
for d in 0..N {
let j = (p / count_strides[d]) % counts[d];
let dof_d = self.planners[d].eval_dof_index(ks[d], j);
if dof_d == self.planners[d].n_dof {
is_zero = true;
break;
}
global += dof_d * self.strides[d];
}
if is_zero {
Complex::default()
} else {
u[global]
}
})
.collect();
let result: Vec<(usize, Complex<f64>)> = indices
.par_iter()
.map(|&i| {
let phi: [Vec<f64>; N] = std::array::from_fn(|d| {
let y = match self.planners[d].extension {
Extension::Periodic => {
(xs[i][d] - xls[d]).rem_euclid(ls[d]) + xls[d]
}
_ => xs[i][d],
};
let xi = (y - xls[d] - (ks[d] as f64) * hs[d]) / hs[d];
lagrangian_interpolation(
&self.planners[d].gauss_lobatto_points,
&self.planners[d].bary_weights,
xi,
)
});
let c0 = counts[0];
let rest0 = local_total / c0;
let mut v: Vec<Complex<f64>> = (0..rest0)
.map(|s| {
(0..c0)
.map(|j| phi[0][j] * u_local[s * c0 + j])
.sum::<Complex<f64>>()
})
.collect();
for d in 1..N {
let c_d = counts[d];
let next = v.len() / c_d;
v = (0..next)
.map(|s| {
(0..c_d)
.map(|j| phi[d][j] * v[s * c_d + j])
.sum::<Complex<f64>>()
})
.collect();
}
(i, v[0])
})
.collect();
result
})
.collect();
for group in partial {
for (i, val) in group {
fu[i] = val;
}
}
}
/// Evaluates the interpolant of `u` at `lambda(x)` for every node `x`
/// returned by `get_x`, in the same order.
///
/// The result has one value per node. It is sized from the node list rather
/// than from `u`, so its length always matches what `eval` writes.
pub fn convect<F>(&self, lambda: F, u: &[Complex<f64>]) -> Vec<Complex<f64>>
where
    F: Fn([f64; N]) -> [f64; N] + Sync,
{
    let xs = self.get_x();
    let shifted: Vec<[f64; N]> = xs.par_iter().map(|&pt| lambda(pt)).collect();
    let mut out = vec![Complex::new(0.0_f64, 0.0); shifted.len()];
    self.eval(&shifted, u, &mut out);
    out
}
/// Applies the 1D forward transform along every axis of the packed array `f`,
/// one axis after the other; the fibers of an axis are processed in parallel.
///
/// # Panics
///
/// Panics if `f.len() != total`. This is a hard check (not debug-only)
/// because the raw-pointer accesses below depend on it for memory safety.
pub fn forward(&self, f: &mut [Complex<f64>]) {
    assert_eq!(f.len(), self.total);
    let sizes = self.ndofs.as_slice();
    let strides = self.strides.as_slice();
    for d in 0..N {
        let dim_d = self.ndofs[d];
        let n_fibers = self.total / dim_d;
        let stride_d = self.strides[d];
        // Passed as an integer so that the closure can be shared between threads.
        let addr = f.as_mut_ptr() as usize;
        (0..n_fibers).into_par_iter().for_each(|s| {
            let base = nd_fiber_base(s, d, sizes, strides);
            let ptr = addr as *mut Complex<f64>;
            // SAFETY: `base` has no component along axis `d`, and `strides`
            // are the column-major strides of `ndofs` (set up in `new`), so
            // `base + m * stride_d` is below `total == f.len()` for
            // `m < dim_d`. Different `s` give different `base` values, hence
            // the fibers touched by concurrent tasks are disjoint.
            let mut fiber: Vec<Complex<f64>> = (0..dim_d)
                .map(|m| unsafe { *ptr.add(base + m * stride_d) })
                .collect();
            self.planners[d].forward(fiber.as_mut_slice());
            for (m, &value) in fiber.iter().enumerate() {
                // SAFETY: same index set as the read above.
                unsafe { *ptr.add(base + m * stride_d) = value };
            }
        });
    }
}
/// Applies the 1D inverse transform along every axis of the packed array `g`,
/// one axis after the other; the fibers of an axis are processed in parallel.
///
/// # Panics
///
/// Panics if `g.len() != total`. This is a hard check (not debug-only)
/// because the raw-pointer accesses below depend on it for memory safety.
pub fn inverse(&self, g: &mut [Complex<f64>]) {
    assert_eq!(g.len(), self.total);
    let sizes = self.ndofs.as_slice();
    let strides = self.strides.as_slice();
    for d in 0..N {
        let dim_d = self.ndofs[d];
        let n_fibers = self.total / dim_d;
        let stride_d = self.strides[d];
        // Passed as an integer so that the closure can be shared between threads.
        let addr = g.as_mut_ptr() as usize;
        (0..n_fibers).into_par_iter().for_each(|s| {
            let base = nd_fiber_base(s, d, sizes, strides);
            let ptr = addr as *mut Complex<f64>;
            // SAFETY: `base` has no component along axis `d`, and `strides`
            // are the column-major strides of `ndofs` (set up in `new`), so
            // `base + m * stride_d` is below `total == g.len()` for
            // `m < dim_d`. Different `s` give different `base` values, hence
            // the fibers touched by concurrent tasks are disjoint.
            let mut fiber: Vec<Complex<f64>> = (0..dim_d)
                .map(|m| unsafe { *ptr.add(base + m * stride_d) })
                .collect();
            self.planners[d].inverse(fiber.as_mut_slice());
            for (m, &value) in fiber.iter().enumerate() {
                // SAFETY: same index set as the read above.
                unsafe { *ptr.add(base + m * stride_d) = value };
            }
        });
    }
}
}
/// Returns the column-major strides for an array with the given extents:
/// the stride of axis 0 is 1 and every further stride is the product of all
/// preceding extents.
pub(crate) fn nd_col_major_strides<const N: usize>(sizes: &[usize; N]) -> [usize; N] {
    let mut strides = [1usize; N];
    let mut running = 1usize;
    // Slot d (d >= 1) pairs with extent d - 1; the last extent is never used.
    for (slot, &extent) in strides.iter_mut().skip(1).zip(sizes.iter()) {
        running *= extent;
        *slot = running;
    }
    strides
}
/// Returns the flat offset of the first element of fiber number `s` along
/// axis `d`.
///
/// `s` is decoded as a mixed-radix number over all axes except `d` (lowest
/// axis first), and each digit is multiplied by the stride of its axis.
pub(crate) fn nd_fiber_base(s: usize, d: usize, sizes: &[usize], strides: &[usize]) -> usize {
    let (offset, _) = (0..sizes.len())
        .filter(|&axis| axis != d)
        .fold((0usize, s), |(offset, remaining), axis| {
            let extent = sizes[axis];
            (
                offset + (remaining % extent) * strides[axis],
                remaining / extent,
            )
        });
    offset
}