use super::adr;
use super::chebyshev;
use super::config::NeuralMultigridConfig;
use super::eikonal::{self, EikonalSolution};
use super::hierarchy::WaveHierarchy;
use crate::basis::PolynomialDegree;
use crate::mesh::Mesh;
use crate::solver::{Solution, SolverError};
use math_audio_solvers::CsrMatrix;
use ndarray::Array1;
use num_complex::Complex64;
use std::time::Instant;
/// Solves a Helmholtz problem with the wave-ADR ("neural") multigrid method.
///
/// Steps:
/// 1. Build a [`WaveHierarchy`] from `mesh`.
/// 2. Compute a homogeneous eikonal solution with unit speed. Its source is
///    `config.source_point`, or the domain center when that is unset.
/// 3. Start from a zero guess with the Dirichlet values imposed.
/// 4. Repeat `wave_adr_cycle` on the finest level, re-imposing the Dirichlet
///    values after every cycle.
///
/// Only the real part of `wavenumber` is used; its imaginary part is ignored.
///
/// # Errors
///
/// * [`SolverError::InvalidConfiguration`] if `wavenumber.re` is not a finite
///   positive number, or if a Dirichlet node index is `>= mesh.num_nodes()`.
/// * [`SolverError::DimensionMismatch`] if `rhs.len() != mesh.num_nodes()`.
/// * [`SolverError::ConvergenceFailure`] if the residual 2-norm does not reach
///   `config.tolerance * max(||rhs||, 1e-10)` within `config.max_iterations`
///   cycles, or if it becomes non-finite (NaN/inf).
pub fn solve_neural_multigrid(
    mesh: &Mesh,
    degree: PolynomialDegree,
    wavenumber: Complex64,
    rhs: &[Complex64],
    dirichlet_bcs: &[(usize, Complex64)],
    config: &NeuralMultigridConfig,
) -> Result<Solution, SolverError> {
    let k = wavenumber.re;
    // Written so that NaN (for which every comparison is false) and inf are rejected too.
    if !k.is_finite() || k <= 0.0 {
        return Err(SolverError::InvalidConfiguration(
            "Wavenumber must have positive real part".into(),
        ));
    }
    let n_dofs = mesh.num_nodes();
    if rhs.len() != n_dofs {
        return Err(SolverError::DimensionMismatch {
            expected: n_dofs,
            actual: rhs.len(),
        });
    }
    // Validate up front instead of panicking on an out-of-range index later.
    if let Some(&(node, _)) = dirichlet_bcs.iter().find(|&&(node, _)| node >= n_dofs) {
        return Err(SolverError::InvalidConfiguration(format!(
            "Dirichlet node {} is out of range for {} DOFs",
            node, n_dofs
        )));
    }

    let start = Instant::now();
    if config.verbosity > 0 {
        println!(
            " [NMG] Neural Multigrid solver: k={:.2}, {} DOFs, gamma={:.2}",
            k, n_dofs, config.damping_gamma
        );
    }

    let hierarchy_start = Instant::now();
    let wave_hierarchy = WaveHierarchy::build(mesh.clone(), degree, k, config.damping_gamma);
    let hierarchy_time = hierarchy_start.elapsed();
    if config.verbosity > 0 {
        println!(
            " [NMG] Hierarchy: {} levels, ADR level={} (kh={:.2}), build time={:.1}ms",
            wave_hierarchy.num_levels(),
            wave_hierarchy.adr_level,
            wave_hierarchy.kh_values[wave_hierarchy.adr_level],
            hierarchy_time.as_secs_f64() * 1000.0
        );
        if config.verbosity > 1 {
            for (i, kh) in wave_hierarchy.kh_values.iter().enumerate() {
                println!(
                    " Level {}: {} DOFs, h={:.4}, kh={:.3}, alpha={:.3}",
                    i,
                    wave_hierarchy.n_dofs(i),
                    wave_hierarchy.h_values[i],
                    kh,
                    wave_hierarchy.alpha_values[i]
                );
            }
        }
    }

    let eikonal_start = Instant::now();
    let source = config
        .source_point
        .unwrap_or_else(|| eikonal::domain_center(mesh));
    // Homogeneous medium with unit wave speed.
    let speed = 1.0;
    let eikonal_sol = eikonal::solve_eikonal_homogeneous(source, speed, mesh);
    let eikonal_time = eikonal_start.elapsed();
    if config.verbosity > 0 {
        println!(
            " [NMG] Eikonal: source=({:.2}, {:.2}, {:.2}), time={:.1}ms",
            source[0],
            source[1],
            source[2],
            eikonal_time.as_secs_f64() * 1000.0
        );
    }

    let mut rhs_vec = Array1::from(rhs.to_vec());
    let mut x = Array1::<Complex64>::zeros(n_dofs);
    let finest_op = &wave_hierarchy.operators[0];
    apply_dirichlet_to_system(finest_op, &mut x, &mut rhs_vec, dirichlet_bcs);

    let rhs_norm = rhs_vec.iter().map(|v| v.norm_sqr()).sum::<f64>().sqrt();
    let tol = config.tolerance * rhs_norm.max(1e-10);
    // `<=` is false for NaN, so a NaN residual (or NaN tolerance) never counts
    // as converged. The previous `residual_norm > tol` test let a NaN residual
    // exit the loop and be reported as converged.
    let converged = |r: f64| r <= tol;

    let mut iteration = 0;
    let mut residual_norm = compute_residual_norm(finest_op, &x, &rhs_vec);
    if config.verbosity > 0 {
        println!(
            " [NMG] Initial residual: {:.2e}, tolerance: {:.2e}",
            residual_norm, tol
        );
    }

    while iteration < config.max_iterations && !converged(residual_norm) {
        wave_adr_cycle(&wave_hierarchy, 0, &mut x, &rhs_vec, &eikonal_sol, config);
        // Re-impose the Dirichlet values the cycle may have perturbed.
        for &(node, value) in dirichlet_bcs {
            x[node] = value;
        }
        residual_norm = compute_residual_norm(finest_op, &x, &rhs_vec);
        iteration += 1;
        if config.verbosity > 1 {
            println!(
                " Iteration {}: residual = {:.2e}",
                iteration, residual_norm
            );
        }
        if !residual_norm.is_finite() {
            // The iterate has blown up; further cycles cannot recover it.
            break;
        }
    }

    let has_converged = converged(residual_norm);
    let total_time = start.elapsed();
    if config.verbosity > 0 {
        println!(
            " [NMG] {} in {} iterations (residual: {:.2e}, time: {:.1}ms)",
            if has_converged {
                "Converged"
            } else {
                "Did not converge"
            },
            iteration,
            residual_norm,
            total_time.as_secs_f64() * 1000.0
        );
    }
    if !has_converged {
        return Err(SolverError::ConvergenceFailure(iteration, residual_norm));
    }
    Ok(Solution {
        values: x,
        iterations: iteration,
        residual: residual_norm,
        converged: true,
    })
}
/// Runs one recursive wave-ADR multigrid cycle from `level` down to the
/// coarsest level, updating `x` in place towards a solution of
/// `operators[level] * x = b`.
///
/// Smoothing by level:
/// * Coarsest level: only Chebyshev smoothing, with
///   `config.coarsest_chebyshev_iters` iterations and the level's `alpha`.
/// * Other levels: pre-smooth (see `smooth_wave_level`), restrict the residual,
///   recurse, prolongate the coarse error and add it to `x`, then post-smooth.
/// * On `hierarchy.adr_level`, post-smoothing is Chebyshev smoothing followed by
///   `apply_adr_correction`.
///
/// NOTE(review): if `adr_level` is the coarsest level, the early return means
/// the ADR correction is never applied. Confirm this is intended.
///
/// # Panics
///
/// Panics if the hierarchy has no restriction or prolongation operator for a
/// non-coarsest level.
fn wave_adr_cycle(
    hierarchy: &WaveHierarchy,
    level: usize,
    x: &mut Array1<Complex64>,
    b: &Array1<Complex64>,
    eikonal: &EikonalSolution,
    config: &NeuralMultigridConfig,
) {
    let n_levels = hierarchy.num_levels();
    let operator = &hierarchy.operators[level];

    if level + 1 == n_levels {
        let alpha = hierarchy.alpha_values[level];
        chebyshev::chebyshev_smooth(
            operator,
            x,
            b,
            config.coarsest_chebyshev_iters,
            alpha,
        );
        return;
    }

    // Pre-smoothing.
    smooth_wave_level(hierarchy, level, x, b, config);

    // Coarse-grid correction.
    let residual = compute_residual(operator, x, b);
    let restriction = hierarchy
        .restriction(level)
        .expect("Missing restriction operator");
    let r_coarse_vec = restriction.apply(
        residual
            .as_slice()
            .expect("freshly computed residual is contiguous"),
    );
    let r_coarse = Array1::from(r_coarse_vec);
    let mut e_coarse = Array1::zeros(hierarchy.n_dofs(level + 1));
    wave_adr_cycle(hierarchy, level + 1, &mut e_coarse, &r_coarse, eikonal, config);

    let prolongation = hierarchy
        .prolongation(level)
        .expect("Missing prolongation operator");
    let e_fine = prolongation.apply(
        e_coarse
            .as_slice()
            .expect("freshly allocated coarse error is contiguous"),
    );
    debug_assert_eq!(
        e_fine.len(),
        x.len(),
        "prolongation must produce a vector of the fine-level size"
    );
    for (xi, &correction) in x.iter_mut().zip(e_fine.iter()) {
        *xi += correction;
    }

    // Post-smoothing; the ADR level adds the eikonal-based ADR correction.
    if level == hierarchy.adr_level {
        let alpha = hierarchy.alpha_values[level];
        chebyshev::chebyshev_smooth(operator, x, b, config.chebyshev_iterations, alpha);
        apply_adr_correction(hierarchy, level, x, b, eikonal, config);
    } else {
        smooth_wave_level(hierarchy, level, x, b, config);
    }
}

/// Runs one smoothing pass on a non-coarsest `level`.
///
/// Level 0 uses damped Jacobi with `optimal_jacobi_damping`. Other levels use
/// Chebyshev smoothing with `config.chebyshev_iterations` iterations and the
/// level's `alpha`.
fn smooth_wave_level(
    hierarchy: &WaveHierarchy,
    level: usize,
    x: &mut Array1<Complex64>,
    b: &Array1<Complex64>,
    config: &NeuralMultigridConfig,
) {
    let operator = &hierarchy.operators[level];
    if level == 0 {
        let omega =
            chebyshev::optimal_jacobi_damping(hierarchy.wavenumber, hierarchy.h_values[level]);
        chebyshev::jacobi_damped_smooth(operator, x, b, omega);
    } else {
        let alpha = hierarchy.alpha_values[level];
        chebyshev::chebyshev_smooth(operator, x, b, config.chebyshev_iterations, alpha);
    }
}
/// Adds an eikonal-based ADR correction to `x` on `level`.
///
/// Steps:
/// 1. Compute the current residual `b - A x`.
/// 2. Solve a homogeneous eikonal problem (unit speed) on this level's mesh
///    from `eikonal.source`.
/// 3. Transform the residual with the resulting `tau` and `hierarchy.omega`.
/// 4. Assemble the ADR system and solve it with `config.adr_gmres_iters` and
///    `config.adr_tolerance`.
/// 5. Transform the amplitude back and add it to the first `n_dofs` entries of
///    `x`.
fn apply_adr_correction(
    hierarchy: &WaveHierarchy,
    level: usize,
    x: &mut Array1<Complex64>,
    b: &Array1<Complex64>,
    eikonal: &EikonalSolution,
    config: &NeuralMultigridConfig,
) {
    // Homogeneous medium: unit wave speed, hence unit slowness everywhere.
    const SPEED: f64 = 1.0;

    let operator = &hierarchy.operators[level];
    let n_dofs = hierarchy.n_dofs(level);
    let residual = compute_residual(operator, x, b);

    // NOTE(review): the level eikonal and the ADR matrix depend only on the
    // hierarchy and the source, yet are rebuilt on every call (every cycle).
    // Caching them would avoid repeated work.
    let level_eikonal =
        eikonal::solve_eikonal_homogeneous(eikonal.source, SPEED, hierarchy.mesh(level));
    let r_hat = adr::transform_residual(&residual, &level_eikonal.tau, hierarchy.omega);

    let slowness_values = vec![1.0 / SPEED; n_dofs];
    let adr_matrix = adr::assemble_adr_system(
        &hierarchy.stiffness_csrs[level],
        &hierarchy.mass_csrs[level],
        hierarchy.omega,
        &level_eikonal,
        &slowness_values,
        hierarchy.gamma,
        n_dofs,
    );
    let amplitude = adr::solve_adr(
        &adr_matrix,
        &r_hat,
        config.adr_gmres_iters,
        config.adr_tolerance,
    );
    let correction = adr::transform_correction(&amplitude, &level_eikonal.tau, hierarchy.omega);

    // Update at most `n_dofs` entries. The zip also stops at the shorter of `x`
    // and `correction` instead of indexing past either.
    for (xi, &ci) in x.iter_mut().zip(correction.iter()).take(n_dofs) {
        *xi += ci;
    }
}
/// Returns the residual vector `r = b - A x` as a newly allocated array.
fn compute_residual(
    a: &CsrMatrix<Complex64>,
    x: &Array1<Complex64>,
    b: &Array1<Complex64>,
) -> Array1<Complex64> {
    b - &a.matvec(x)
}
fn compute_residual_norm(
a: &CsrMatrix<Complex64>,
x: &Array1<Complex64>,
b: &Array1<Complex64>,
) -> f64 {
let r = compute_residual(a, x, b);
r.iter().map(|v| v.norm_sqr()).sum::<f64>().sqrt()
}
/// Writes the Dirichlet values into the solution vector `x`.
///
/// For each `(node, value)` pair, sets `x[node] = value`. The operator `_a` and
/// the right-hand side `_rhs` are accepted but currently left untouched; the
/// linear system itself is not modified.
///
/// # Panics
///
/// Panics if any `node` is out of bounds for `x`.
fn apply_dirichlet_to_system(
    _a: &CsrMatrix<Complex64>,
    x: &mut Array1<Complex64>,
    _rhs: &mut Array1<Complex64>,
    dirichlet_bcs: &[(usize, Complex64)],
) {
    dirichlet_bcs
        .iter()
        .for_each(|&(node, value)| x[node] = value);
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::assembly::HelmholtzProblem;
    use crate::mesh::unit_square_triangles;
    use std::f64::consts::PI;

    /// Homogeneous (zero) Dirichlet data on every boundary node, sorted by node
    /// index with duplicates removed.
    fn zero_dirichlet_pairs(mesh: &Mesh) -> Vec<(usize, Complex64)> {
        let zero = Complex64::new(0.0, 0.0);
        let mut pairs: Vec<(usize, Complex64)> = mesh
            .boundaries
            .iter()
            .flat_map(|boundary| boundary.nodes.iter().map(move |&node| (node, zero)))
            .collect();
        pairs.sort_by_key(|&(node, _)| node);
        pairs.dedup_by_key(|pair| pair.0);
        pairs
    }

    #[test]
    fn test_neural_multigrid_basic() {
        let mesh = unit_square_triangles(8);
        let k_complex = Complex64::new(1.0, 0.0);
        let problem = HelmholtzProblem::assemble(
            &mesh,
            PolynomialDegree::P1,
            k_complex,
            |x, y, _z| Complex64::new((x * PI).sin() * (y * PI).sin(), 0.0),
        );
        let dirichlet_pairs = zero_dirichlet_pairs(&mesh);
        let config = NeuralMultigridConfig {
            max_iterations: 50,
            tolerance: 1e-4,
            verbosity: 0,
            ..Default::default()
        };

        let outcome = solve_neural_multigrid(
            &mesh,
            PolynomialDegree::P1,
            k_complex,
            &problem.rhs,
            &dirichlet_pairs,
            &config,
        );
        match outcome {
            Ok(sol) => {
                assert_eq!(sol.values.len(), mesh.num_nodes());
                assert!(sol.iterations > 0);
            }
            Err(SolverError::ConvergenceFailure(iters, _)) => assert!(iters > 0),
            Err(e) => panic!("Unexpected error: {}", e),
        }
    }

    #[test]
    fn test_neural_multigrid_hierarchy_and_components() {
        let mesh = unit_square_triangles(8);
        let hierarchy = WaveHierarchy::build(mesh.clone(), PolynomialDegree::P1, 2.0, 0.5);
        let levels = hierarchy.num_levels();

        assert!(levels >= 2);
        assert!(hierarchy.adr_level >= 1);
        assert!(hierarchy.adr_level < levels);

        for lvl in 0..levels {
            assert!(hierarchy.operators[lvl].nnz() > 0);
            assert!(hierarchy.stiffness_csrs[lvl].nnz() > 0);
            assert!(hierarchy.mass_csrs[lvl].nnz() > 0);
            assert!(hierarchy.kh_values[lvl] > 0.0);
            assert!(hierarchy.h_values[lvl] > 0.0);
            assert!(hierarchy.alpha_values[lvl] >= 0.01);
            assert!(hierarchy.alpha_values[lvl] <= 0.9);
        }

        // Coarser levels have (roughly) non-decreasing kh and strictly fewer DOFs.
        for lvl in 1..levels {
            assert!(hierarchy.kh_values[lvl] >= hierarchy.kh_values[lvl - 1] * 0.9);
            assert!(hierarchy.n_dofs(lvl) < hierarchy.n_dofs(lvl - 1));
        }

        let n_nodes = mesh.num_nodes();
        let eikonal_sol =
            eikonal::solve_eikonal_homogeneous(eikonal::domain_center(&mesh), 1.0, &mesh);
        assert_eq!(eikonal_sol.tau.len(), n_nodes);
        assert_eq!(eikonal_sol.grad_tau.len(), n_nodes);
        assert_eq!(eikonal_sol.laplacian_tau.len(), n_nodes);
    }
}