use super::mass::{MassMatrix, assemble_mass};
use super::stiffness::{StiffnessMatrix, assemble_stiffness};
use crate::basis::PolynomialDegree;
use crate::mesh::Mesh;
use math_audio_solvers::CsrMatrix;
use num_complex::Complex64;
use rayon::prelude::*;
use std::collections::HashMap;
/// Pre-merged sparsity pattern and per-entry real coefficients from which a
/// complex system matrix can be rebuilt cheaply for different parameters.
///
/// The pattern is stored in CSR form (`row_ptrs` / `col_indices`). The value
/// arrays `k_values`, `m_values` and every vector in `boundary_values` are
/// aligned with `col_indices`: index `i` in each of them refers to the same
/// stored matrix entry.
#[derive(Debug, Clone)]
pub struct HelmholtzAssembler {
    /// Dimension of the (square) system; used as both row and column count.
    pub num_rows: usize,
    /// CSR row pointers, `num_rows + 1` entries long.
    pub row_ptrs: Vec<usize>,
    /// Column index of each stored entry.
    pub col_indices: Vec<usize>,
    /// Summed stiffness contribution of each stored entry.
    pub k_values: Vec<f64>,
    /// Summed mass contribution of each stored entry.
    pub m_values: Vec<f64>,
    /// Summed boundary contribution of each stored entry, keyed by boundary
    /// tag. Entries a boundary does not touch hold `0.0`.
    pub boundary_values: HashMap<usize, Vec<f64>>,
}
impl HelmholtzAssembler {
pub fn new(mesh: &Mesh, degree: PolynomialDegree) -> Self {
let stiffness = assemble_stiffness(mesh, degree);
let mass = assemble_mass(mesh, degree);
Self::from_matrices(&stiffness, &mass, &[])
}
pub fn from_matrices(
stiffness: &StiffnessMatrix,
mass: &MassMatrix,
boundaries: &[(usize, MassMatrix)],
) -> Self {
assert_eq!(stiffness.dim, mass.dim);
let num_rows = stiffness.dim;
#[derive(Debug, Clone, Copy)]
struct Entry {
row: usize,
col: usize,
source_idx: i32, value: f64,
}
let total_nnz =
stiffness.nnz() + mass.nnz() + boundaries.iter().map(|(_, m)| m.nnz()).sum::<usize>();
let mut entries = Vec::with_capacity(total_nnz);
for i in 0..stiffness.nnz() {
entries.push(Entry {
row: stiffness.rows[i],
col: stiffness.cols[i],
source_idx: -1,
value: stiffness.values[i],
});
}
for i in 0..mass.nnz() {
entries.push(Entry {
row: mass.rows[i],
col: mass.cols[i],
source_idx: -2,
value: mass.values[i],
});
}
for (tag, matrix) in boundaries {
for i in 0..matrix.nnz() {
entries.push(Entry {
row: matrix.rows[i],
col: matrix.cols[i],
source_idx: *tag as i32,
value: matrix.values[i],
});
}
}
entries.par_sort_unstable_by(|a, b| {
if a.row != b.row {
a.row.cmp(&b.row)
} else {
a.col.cmp(&b.col)
}
});
let mut row_ptrs = vec![0; num_rows + 1];
let mut col_indices = Vec::with_capacity(entries.len());
let mut k_values = Vec::with_capacity(entries.len());
let mut m_values = Vec::with_capacity(entries.len());
let mut boundary_values_map: HashMap<usize, Vec<f64>> = HashMap::new();
for (tag, _) in boundaries {
boundary_values_map.insert(*tag, Vec::with_capacity(entries.len()));
}
if entries.is_empty() {
return Self {
num_rows,
row_ptrs,
col_indices,
k_values,
m_values,
boundary_values: boundary_values_map,
};
}
let mut last_r = entries[0].row;
let mut last_c = entries[0].col;
let mut acc_k = 0.0;
let mut acc_m = 0.0;
let mut acc_boundaries: HashMap<usize, f64> = HashMap::new();
let accumulate =
|entry: &Entry, k: &mut f64, m: &mut f64, b_map: &mut HashMap<usize, f64>| match entry
.source_idx
{
-1 => *k += entry.value,
-2 => *m += entry.value,
tag => {
let t = tag as usize;
*b_map.entry(t).or_insert(0.0) += entry.value;
}
};
accumulate(&entries[0], &mut acc_k, &mut acc_m, &mut acc_boundaries);
for r in 0..last_r {
row_ptrs[r + 1] = 0;
}
for entry in entries.iter().skip(1) {
if entry.row == last_r && entry.col == last_c {
accumulate(entry, &mut acc_k, &mut acc_m, &mut acc_boundaries);
} else {
k_values.push(acc_k);
m_values.push(acc_m);
for (tag, vec) in boundary_values_map.iter_mut() {
vec.push(acc_boundaries.get(tag).copied().unwrap_or(0.0));
}
col_indices.push(last_c);
if entry.row != last_r {
row_ptrs[last_r + 1] = k_values.len();
for r in (last_r + 1)..entry.row {
row_ptrs[r + 1] = k_values.len();
}
}
last_r = entry.row;
last_c = entry.col;
acc_k = 0.0;
acc_m = 0.0;
acc_boundaries.clear();
accumulate(entry, &mut acc_k, &mut acc_m, &mut acc_boundaries);
}
}
k_values.push(acc_k);
m_values.push(acc_m);
for (tag, vec) in boundary_values_map.iter_mut() {
vec.push(acc_boundaries.get(tag).copied().unwrap_or(0.0));
}
col_indices.push(last_c);
row_ptrs[last_r + 1] = k_values.len();
for r in (last_r + 1)..num_rows {
row_ptrs[r + 1] = k_values.len();
}
Self {
num_rows,
row_ptrs,
col_indices,
k_values,
m_values,
boundary_values: boundary_values_map,
}
}
pub fn assemble(
&self,
wavenumber: Complex64,
boundary_coeffs: &HashMap<usize, Complex64>,
) -> CsrMatrix<Complex64> {
let k_sq = wavenumber * wavenumber;
let nnz = self.k_values.len();
let values: Vec<Complex64> = (0..nnz)
.into_par_iter()
.map(|i| {
let mut val = Complex64::new(self.k_values[i], 0.0)
- k_sq * Complex64::new(self.m_values[i], 0.0);
if !self.boundary_values.is_empty() {
#[allow(clippy::collapsible_if)]
for (tag, coeffs) in boundary_coeffs {
if let Some(b_vals) = self.boundary_values.get(tag) {
if b_vals[i] != 0.0 {
val += coeffs * Complex64::new(b_vals[i], 0.0);
}
}
}
}
val
})
.collect();
CsrMatrix::from_raw_parts(
self.num_rows,
self.num_rows, self.row_ptrs.clone(),
self.col_indices.clone(),
values,
)
}
pub fn memory_usage(&self) -> usize {
let usize_size = std::mem::size_of::<usize>();
let f64_size = std::mem::size_of::<f64>();
let mut mem = self.row_ptrs.len() * usize_size
+ self.col_indices.len() * usize_size
+ self.k_values.len() * f64_size
+ self.m_values.len() * f64_size;
for v in self.boundary_values.values() {
mem += v.len() * f64_size;
}
mem
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::mesh::unit_square_triangles;

    /// An assembler built straight from a mesh yields a non-empty matrix with
    /// one row per mesh node.
    #[test]
    fn test_assembler_simple() {
        let mesh = unit_square_triangles(2);
        let assembler = HelmholtzAssembler::new(&mesh, PolynomialDegree::P1);

        let no_boundary_terms = HashMap::new();
        let system = assembler.assemble(Complex64::new(1.0, 0.0), &no_boundary_terms);

        assert_eq!(system.num_rows, mesh.num_nodes());
        assert!(system.nnz() > 0);
    }

    /// Supplying a coefficient for a registered boundary must change the
    /// assembled values compared with the boundary-free system.
    #[test]
    fn test_assembler_with_boundary() {
        use crate::assembly::mass::assemble_boundary_mass;

        let mut mesh = unit_square_triangles(2);
        mesh.detect_boundaries();

        let stiffness = assemble_stiffness(&mesh, PolynomialDegree::P1);
        let mass = assemble_mass(&mesh, PolynomialDegree::P1);
        let tagged = vec![(0, assemble_boundary_mass(&mesh, PolynomialDegree::P1, 0))];
        let wavenumber = Complex64::new(1.0, 0.0);

        // System including the boundary contribution for tag 0.
        let with_boundary = HelmholtzAssembler::from_matrices(&stiffness, &mass, &tagged);
        let mut boundary_coeffs = HashMap::new();
        boundary_coeffs.insert(0, Complex64::new(0.5, 0.0));
        let system = with_boundary.assemble(wavenumber, &boundary_coeffs);

        // Reference system built from stiffness and mass only.
        let without_boundary = HelmholtzAssembler::from_matrices(&stiffness, &mass, &[]);
        let reference = without_boundary.assemble(wavenumber, &HashMap::new());

        let total: Complex64 = system.values.iter().sum();
        let reference_total: Complex64 = reference.values.iter().sum();
        assert!((total - reference_total).norm() > 1e-10);
    }
}