use super::filter::CosineFilter;
use crate::assembly::HelmholtzAssembler;
use math_audio_solvers::{
AmgConfig, AmgPreconditioner, AmgSmoother, CgConfig, CsrMatrix, pcg, cg,
};
use ndarray::Array1;
use std::collections::HashMap;
/// Implicit two-step time stepper for the semi-discrete wave equation.
///
/// Each step solves `a_impl · w^{n+1} = b_rhs · w^n − a_impl · w^{n−1} + dt² · f^n`,
/// where both operators are assembled once at construction time.
pub struct WaveTimeStepper {
    /// Step size: one `steps_per_period`-th of the period `2π / ω`.
    dt: f64,
    /// Cached `dt * dt`.
    dt_sq: f64,
    /// Number of degrees of freedom (rows of the assembled operators).
    ndofs: usize,
    /// Left-hand-side operator solved at every step.
    a_impl: CsrMatrix<f64>,
    /// Operator applied to the current state on the right-hand side.
    b_rhs: CsrMatrix<f64>,
    /// Settings forwarded to the (P)CG solver.
    cg_config: CgConfig<f64>,
    /// AMG preconditioner built from `a_impl`; `None` selects plain CG.
    amg_precond: Option<AmgPreconditioner<f64>>,
}
impl WaveTimeStepper {
    /// Creates a stepper without any boundary coefficients.
    ///
    /// Shorthand for [`WaveTimeStepper::new_with_boundaries`] with an empty
    /// `boundary_coeffs` map; see that function for argument meanings and
    /// panic conditions.
    pub fn new(
        assembler: &HelmholtzAssembler,
        omega: f64,
        steps_per_period: usize,
        cg_config: CgConfig<f64>,
        use_amg: bool,
    ) -> Self {
        Self::new_with_boundaries(
            assembler,
            omega,
            steps_per_period,
            cg_config,
            use_amg,
            &HashMap::new(),
        )
    }

    /// Creates a stepper whose step size resolves one period `2π / omega` in
    /// `steps_per_period` steps.
    ///
    /// Two operators are assembled once, with the coefficient pairs
    /// `(dt²/4, 1)` for `a_impl` and `(-dt²/2, 2)` for `b_rhs`. When
    /// `boundary_coeffs` is non-empty both are assembled through
    /// `assemble_real_with_boundaries`, otherwise through `assemble_real`.
    /// If `use_amg` is true, an AMG preconditioner (FEM defaults with a
    /// symmetric Gauss–Seidel smoother) is built from `a_impl` and used for
    /// every implicit solve; otherwise plain CG is used.
    ///
    /// # Panics
    ///
    /// Panics unless `omega` is finite and strictly positive and
    /// `steps_per_period` is non-zero. Any other input would give an infinite,
    /// NaN or negative `dt` and meaningless operators.
    pub fn new_with_boundaries(
        assembler: &HelmholtzAssembler,
        omega: f64,
        steps_per_period: usize,
        cg_config: CgConfig<f64>,
        use_amg: bool,
        boundary_coeffs: &HashMap<usize, f64>,
    ) -> Self {
        assert!(
            omega.is_finite() && omega > 0.0,
            "omega must be finite and positive, got {}",
            omega
        );
        assert!(steps_per_period > 0, "steps_per_period must be non-zero");

        let period = 2.0 * std::f64::consts::PI / omega;
        let dt = period / steps_per_period as f64;
        let dt_sq = dt * dt;
        let ndofs = assembler.num_rows;

        // The two operators differ only in the scalar pair handed to the assembler.
        // NOTE(review): judging by the update formula these look like
        // (stiffness, mass) weights, i.e. A = M + dt²/4·K and B = 2M − dt²/2·K —
        // confirm against `HelmholtzAssembler`.
        let assemble = |first: f64, second: f64| {
            if boundary_coeffs.is_empty() {
                assembler.assemble_real(first, second)
            } else {
                assembler.assemble_real_with_boundaries(first, second, boundary_coeffs)
            }
        };
        let a_impl = assemble(dt_sq / 4.0, 1.0);
        let b_rhs = assemble(-dt_sq / 2.0, 2.0);

        let amg_precond = use_amg.then(|| {
            let mut amg_config = AmgConfig::for_fem();
            amg_config.smoother = AmgSmoother::SymmetricGaussSeidel;
            AmgPreconditioner::from_csr(&a_impl, amg_config)
        });

        Self {
            a_impl,
            b_rhs,
            dt,
            dt_sq,
            ndofs,
            cg_config,
            amg_precond,
        }
    }

    /// Solves `a_impl · x = rhs` and returns the solver's final iterate.
    ///
    /// Uses PCG when an AMG preconditioner was built, plain CG otherwise.
    ///
    /// NOTE(review): the solver's `converged` flag is not inspected, so a
    /// non-converged iterate is returned silently.
    fn solve_implicit(&self, rhs: &Array1<f64>) -> Array1<f64> {
        let result = if let Some(ref precond) = self.amg_precond {
            pcg(&self.a_impl, precond, rhs, &self.cg_config)
        } else {
            cg(&self.a_impl, rhs, &self.cg_config)
        };
        result.x
    }

    /// Number of degrees of freedom (rows of the assembled operators).
    pub fn ndofs(&self) -> usize {
        self.ndofs
    }

    /// Time step `2π / (omega · steps_per_period)` chosen at construction.
    pub fn dt(&self) -> f64 {
        self.dt
    }

    /// Propagates `w0` through `filter.n_steps()` time steps and returns the
    /// filter-weighted accumulation of the states `w^0, w^1, …, w^{n_steps}`.
    ///
    /// The first step uses the halved update
    /// `a_impl · w^1 = (b_rhs · w^0 + dt² · f) / 2`. This corresponds to the
    /// symmetric start `w^{-1} = w^1` (zero initial velocity). Every later step
    /// uses the two-step recurrence. When `forcing` is given, the step producing
    /// `w^n` adds `dt² · cos(omega · t_{n-1}) · forcing` with `t_k = k · dt`.
    ///
    /// # Panics
    ///
    /// Panics if `w0` (or `forcing`, when present) does not have `ndofs()`
    /// entries.
    pub fn propagate_filtered(
        &self,
        w0: &Array1<f64>,
        forcing: Option<&Array1<f64>>,
        filter: &CosineFilter,
        omega: f64,
    ) -> Array1<f64> {
        assert_eq!(w0.len(), self.ndofs, "initial state length must equal ndofs");
        if let Some(f) = forcing {
            // Checked explicitly: `scaled_add` would otherwise broadcast a
            // mismatched (e.g. length-1) forcing vector without complaint.
            assert_eq!(f.len(), self.ndofs, "forcing length must equal ndofs");
        }

        let n_steps = filter.n_steps();
        let mut accumulator = Array1::zeros(self.ndofs);
        let mut w_curr = w0.clone();
        filter.accumulate(0, &w_curr, &mut accumulator);

        // With a zero-length filter only the initial state is weighted. Taking
        // the first step would accumulate index 1, past `n_steps`, the last
        // index the main loop below uses.
        if n_steps == 0 {
            return accumulator;
        }

        // First step (forcing at t = 0, where cos(omega · 0) = 1).
        let mut w_prev = {
            let mut rhs = self.b_rhs.matvec(&w_curr);
            if let Some(f) = forcing {
                rhs.scaled_add(self.dt_sq, f);
            }
            rhs.mapv_inplace(|v| v * 0.5);
            let w_next = self.solve_implicit(&rhs);
            std::mem::replace(&mut w_curr, w_next)
        };
        filter.accumulate(1, &w_curr, &mut accumulator);

        for n in 2..=n_steps {
            let mut rhs = self.b_rhs.matvec(&w_curr);
            rhs -= &self.a_impl.matvec(&w_prev);
            if let Some(f) = forcing {
                // Forcing is sampled at the centre time level t_{n-1}.
                let t = (n - 1) as f64 * self.dt;
                let cos_wt = (omega * t).cos();
                // Skip negligible contributions (e.g. near zeros of the cosine).
                if cos_wt.abs() > 1e-20 {
                    rhs.scaled_add(self.dt_sq * cos_wt, f);
                }
            }
            let w_next = self.solve_implicit(&rhs);
            w_prev = std::mem::replace(&mut w_curr, w_next);
            filter.accumulate(n, &w_curr, &mut accumulator);
        }
        accumulator
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::assembly::HelmholtzAssembler;
    use crate::basis::PolynomialDegree;
    use crate::mesh::unit_square_triangles;

    /// Quiet CG settings shared by the tests; only the tolerance varies.
    fn quiet_cg_config(tolerance: f64) -> CgConfig<f64> {
        CgConfig {
            max_iterations: 500,
            tolerance,
            print_interval: 0,
        }
    }

    /// Deterministic, non-trivial probe vector `v_i = sin(i · freq)`.
    fn probe(n: usize, freq: f64) -> Array1<f64> {
        Array1::from_iter((0..n).map(|i| (i as f64 * freq).sin()))
    }

    #[test]
    fn test_a_impl_is_spd() {
        let mesh = unit_square_triangles(4);
        let assembler = HelmholtzAssembler::new(&mesh, PolynomialDegree::P1);
        let omega = std::f64::consts::PI;
        let stepper = WaveTimeStepper::new(&assembler, omega, 10, quiet_cg_config(1e-10), false);
        let n = stepper.ndofs();
        let u = probe(n, 0.37);
        let v = probe(n, 1.13);

        // Symmetry: u·(A v) must match v·(A u) up to round-off.
        let u_av = u.dot(&stepper.a_impl.matvec(&v));
        let v_au = v.dot(&stepper.a_impl.matvec(&u));
        let scale = u_av.abs().max(v_au.abs()).max(1.0);
        assert!(
            (u_av - v_au).abs() <= 1e-10 * scale,
            "A_impl should be symmetric: u·Av = {:.6e}, v·Au = {:.6e}",
            u_av,
            v_au
        );

        // Positive definiteness (sampled): x·(A x) > 0 for non-zero x.
        for x in [&u, &v] {
            let energy = x.dot(&stepper.a_impl.matvec(x));
            assert!(energy > 0.0, "A_impl should be positive definite, got x·Ax = {:.6e}", energy);
        }

        // CG must converge on an SPD system.
        let result = cg(&stepper.a_impl, &u, &stepper.cg_config);
        assert!(
            result.converged,
            "CG should converge on A_impl (SPD). Residual: {:.2e} after {} iters",
            result.residual,
            result.iterations
        );
    }

    #[test]
    fn test_standing_wave_one_period() {
        let mesh = unit_square_triangles(8);
        let assembler = HelmholtzAssembler::new(&mesh, PolynomialDegree::P1);
        let omega = std::f64::consts::PI;
        let n_steps = 20;
        let stepper = WaveTimeStepper::new(&assembler, omega, n_steps, quiet_cg_config(1e-12), false);
        let ndofs = stepper.ndofs();

        // NOTE(review): assumes the P1 dofs of `unit_square_triangles(8)` are the
        // 9 × 9 grid vertices numbered row-major (x fastest) — confirm against the
        // mesh generator.
        let n_side = 9;
        let w0: Array1<f64> = Array1::from_iter((0..ndofs).map(|i| {
            let x = (i % n_side) as f64 / (n_side - 1) as f64;
            let y = (i / n_side) as f64 / (n_side - 1) as f64;
            (std::f64::consts::PI * x).sin() * (std::f64::consts::PI * y).sin()
        }));

        let filter = CosineFilter::new(omega, stepper.dt(), n_steps);
        let result = stepper.propagate_filtered(&w0, None, &filter, omega);

        assert!(
            result.iter().all(|v| v.is_finite()),
            "Filtered result should be finite"
        );
        let result_norm: f64 = result.iter().map(|v| v * v).sum::<f64>().sqrt();
        assert!(result_norm > 1e-6, "Filtered result should be non-trivial");
    }
}