use crate::basis::PolynomialDegree;
use crate::mesh::Mesh;
use crate::multigrid::{MultigridHierarchy, TransferMatrix};
use math_audio_solvers::CsrMatrix;
use num_complex::Complex64;
use super::chebyshev::estimate_alpha;
/// A multigrid hierarchy for a wave problem at a fixed wavenumber, together with
/// per-level operators and resolution metrics cached at construction time.
///
/// All `Vec` fields are indexed by level, in the same order as `hierarchy.levels`
/// (the tests expect `kh` to grow with the level index, i.e. index 0 is the finest).
pub struct WaveHierarchy {
/// The underlying assembled multigrid hierarchy.
pub hierarchy: MultigridHierarchy,
/// Per-level assembled system matrix, converted with `to_csr()`.
pub operators: Vec<CsrMatrix<Complex64>>,
/// Per-level stiffness matrix with real entries promoted to complex (`v + 0i`).
pub stiffness_csrs: Vec<CsrMatrix<Complex64>>,
/// Per-level mass matrix with real entries promoted to complex (`v + 0i`).
pub mass_csrs: Vec<CsrMatrix<Complex64>>,
/// Per-level product `k * h`, where `h` is the level's mesh size.
pub kh_values: Vec<f64>,
/// Per-level mesh size as reported by `eikonal::mesh_size`.
pub h_values: Vec<f64>,
/// Per-level value of `estimate_alpha(k, h, level)`.
pub alpha_values: Vec<f64>,
/// Index of the level whose `kh` is closest to 1, never below 1.
pub adr_level: usize,
/// The real wavenumber `k` the hierarchy was built for.
pub wavenumber: f64,
/// Set equal to `k` by `build`. NOTE(review): presumably assumes unit wave speed — confirm.
pub omega: f64,
/// The `gamma` passed to `build`, stored unchanged (its role is defined by callers).
pub gamma: f64,
}
impl WaveHierarchy {
    /// Builds and assembles a multigrid hierarchy for wavenumber `k`, caching per-level
    /// operators and resolution metrics.
    ///
    /// * `mesh` — the fine mesh handed to `MultigridHierarchy::from_fine_mesh`.
    /// * `degree` — polynomial degree used on every level.
    /// * `k` — real wavenumber; passed to the hierarchy as `k + 0i` and stored as both
    ///   `wavenumber` and `omega`.
    /// * `gamma` — stored unchanged in the result.
    ///
    /// The requested level count is derived from the fine-grid `k * h`. The hierarchy
    /// may end up with a different count, so every per-level vector is sized from
    /// the levels actually built.
    ///
    /// # Panics
    ///
    /// Panics if a level lacks its assembled system, stiffness or mass matrix after
    /// `assemble_all()`.
    pub fn build(mesh: Mesh, degree: PolynomialDegree, k: f64, gamma: f64) -> Self {
        let h_fine = super::eikonal::mesh_size(&mesh);
        let n_levels = Self::requested_levels(k * h_fine);

        let wavenumber_complex = Complex64::new(k, 0.0);
        let mut hierarchy =
            MultigridHierarchy::from_fine_mesh(mesh, n_levels, degree, wavenumber_complex);
        let actual_levels = hierarchy.num_levels();
        hierarchy.assemble_all();

        let mut operators = Vec::with_capacity(actual_levels);
        let mut stiffness_csrs = Vec::with_capacity(actual_levels);
        let mut mass_csrs = Vec::with_capacity(actual_levels);
        let mut kh_values = Vec::with_capacity(actual_levels);
        let mut h_values = Vec::with_capacity(actual_levels);
        let mut alpha_values = Vec::with_capacity(actual_levels);

        for (level_idx, level) in hierarchy.levels.iter().enumerate() {
            let h = super::eikonal::mesh_size(&level.mesh);

            let system = level.system.as_ref().expect("System not assembled");
            operators.push(system.to_csr());

            let stiffness = level.stiffness.as_ref().expect("Stiffness not assembled");
            let mass = level.mass.as_ref().expect("Mass not assembled");
            stiffness_csrs.push(real_to_complex_csr(stiffness));
            mass_csrs.push(real_to_complex_csr_mass(mass));

            alpha_values.push(estimate_alpha(k, h, level_idx));
            kh_values.push(k * h);
            h_values.push(h);
        }

        let adr_level = Self::select_adr_level(&kh_values);

        WaveHierarchy {
            hierarchy,
            operators,
            stiffness_csrs,
            mass_csrs,
            kh_values,
            h_values,
            alpha_values,
            adr_level,
            wavenumber: k,
            // NOTE(review): omega is stored equal to k (presumably unit wave speed) — confirm.
            omega: k,
            gamma,
        }
    }

    /// Number of levels to request from the hierarchy for a fine-grid resolution `kh_fine`.
    ///
    /// For `kh_fine > 0.01` this is `ceil(log2(pi / kh_fine))`, the number of doublings
    /// of `kh` needed to reach pi, clamped to `[2, 8]`. Otherwise (including NaN) it is 4.
    fn requested_levels(kh_fine: f64) -> usize {
        if kh_fine > 0.01 {
            // A negative or NaN log2 saturates to 0 under `as usize`, so `.max(2)` applies.
            let levels_needed =
                ((std::f64::consts::PI / kh_fine).log2().ceil() as usize).max(2);
            levels_needed.min(8)
        } else {
            4
        }
    }

    /// Index of the level whose `kh` is closest to 1, never smaller than 1.
    ///
    /// Ties resolve to the lowest index. NaN distances sort last instead of panicking.
    /// An empty slice yields 1.
    /// NOTE(review): with a single level this returns 1, which is out of range.
    fn select_adr_level(kh_values: &[f64]) -> usize {
        kh_values
            .iter()
            .enumerate()
            .min_by(|(_, a), (_, b)| (**a - 1.0).abs().total_cmp(&(**b - 1.0).abs()))
            .map_or(1, |(i, _)| i)
            .max(1)
    }

    /// Number of levels in the underlying hierarchy.
    pub fn num_levels(&self) -> usize {
        self.hierarchy.num_levels()
    }

    /// Restriction operator stored on `level`.
    ///
    /// Returns `None` when the level has none or `level` is out of range.
    pub fn restriction(&self, level: usize) -> Option<&TransferMatrix> {
        self.hierarchy.levels.get(level)?.restriction.as_ref()
    }

    /// Prolongation operator associated with `level`, which is stored on level `level + 1`.
    ///
    /// Returns `None` when `level` is the last level, is out of range, or has no
    /// prolongation.
    pub fn prolongation(&self, level: usize) -> Option<&TransferMatrix> {
        let next = level.checked_add(1)?;
        if next >= self.num_levels() {
            return None;
        }
        self.hierarchy.levels.get(next)?.prolongation.as_ref()
    }

    /// Number of degrees of freedom on `level`.
    ///
    /// # Panics
    ///
    /// Panics if `level` is out of range.
    pub fn n_dofs(&self, level: usize) -> usize {
        self.hierarchy.levels[level].n_dofs
    }

    /// Mesh of `level`.
    ///
    /// # Panics
    ///
    /// Panics if `level` is out of range.
    pub fn mesh(&self, level: usize) -> &Mesh {
        &self.hierarchy.levels[level].mesh
    }
}
/// Converts an assembled real stiffness matrix into a square complex CSR matrix.
///
/// The input is read as parallel `rows` / `cols` / `values` arrays. Each value `v`
/// becomes `v + 0i`, and the result has dimension `stiffness.dim` x `stiffness.dim`.
/// If the three arrays differ in length, entries beyond the shortest are ignored.
fn real_to_complex_csr(stiffness: &crate::assembly::StiffnessMatrix) -> CsrMatrix<Complex64> {
    let mut entries: Vec<(usize, usize, Complex64)> =
        Vec::with_capacity(stiffness.values.len());
    for ((&row, &col), &value) in stiffness
        .rows
        .iter()
        .zip(&stiffness.cols)
        .zip(&stiffness.values)
    {
        entries.push((row, col, Complex64::from(value)));
    }
    CsrMatrix::from_triplets(stiffness.dim, stiffness.dim, entries)
}
/// Converts an assembled real mass matrix into a square complex CSR matrix.
///
/// The input is read as parallel `rows` / `cols` / `values` arrays. Each value `v`
/// becomes `v + 0i`, and the result has dimension `mass.dim` x `mass.dim`.
/// If the three arrays differ in length, entries beyond the shortest are ignored.
fn real_to_complex_csr_mass(mass: &crate::assembly::MassMatrix) -> CsrMatrix<Complex64> {
    let row_indices = mass.rows.iter().copied();
    let col_indices = mass.cols.iter().copied();
    let coefficients = mass.values.iter().copied();
    let entries: Vec<(usize, usize, Complex64)> = row_indices
        .zip(col_indices)
        .zip(coefficients)
        .map(|((row, col), value)| (row, col, Complex64::from(value)))
        .collect();
    let n = mass.dim;
    CsrMatrix::from_triplets(n, n, entries)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::mesh::unit_square_triangles;

    /// Checks level count, ADR-level bounds, and that kh grows (within 10%) level to level.
    #[test]
    fn test_wave_hierarchy_build() {
        let (k, gamma) = (5.0, 0.5);
        let wh = WaveHierarchy::build(unit_square_triangles(8), PolynomialDegree::P1, k, gamma);
        let levels = wh.num_levels();
        assert!(levels >= 2);
        assert!(wh.adr_level >= 1);
        assert!(wh.adr_level < levels);
        for finer in 0..levels - 1 {
            let coarser = finer + 1;
            let kh_finer = wh.kh_values[finer];
            let kh_coarser = wh.kh_values[coarser];
            assert!(
                kh_coarser >= kh_finer * 0.9,
                "kh should generally increase: level {} kh={}, level {} kh={}",
                finer,
                kh_finer,
                coarser,
                kh_coarser
            );
        }
    }

    /// Every level operator is square, non-empty, and sized to that level's DOF count.
    #[test]
    fn test_wave_hierarchy_operators() {
        let (k, gamma) = (2.0, 0.5);
        let wh = WaveHierarchy::build(unit_square_triangles(4), PolynomialDegree::P1, k, gamma);
        for level in 0..wh.operators.len() {
            let op = &wh.operators[level];
            assert_eq!(op.num_rows, op.num_cols);
            assert!(op.nnz() > 0, "Level {} operator should have entries", level);
            assert_eq!(op.num_rows, wh.n_dofs(level));
        }
    }

    /// The selected ADR level should have a moderate kh.
    #[test]
    fn test_wave_hierarchy_adr_level() {
        let (k, gamma) = (10.0, 0.5);
        let wh = WaveHierarchy::build(unit_square_triangles(16), PolynomialDegree::P1, k, gamma);
        let adr_kh = wh.kh_values[wh.adr_level];
        assert!(
            adr_kh < 3.0,
            "ADR level kh should be moderate, got {}",
            adr_kh
        );
    }
}