Skip to main content

math_audio_bem/core/
bem_solver.rs

//! High-level BEM Solver API
//!
//! This module provides a unified, high-level interface for acoustic BEM simulations.
//! It integrates mesh generation, system assembly, linear solving, and post-processing.
//!
//! The Direct solver method works in both native and WASM modes. When the `native` feature
//! is enabled, it uses optimized BLAS/LAPACK. Otherwise, it falls back to a pure Rust
//! LU factorization implementation.
//!
//! # Example
//!
//! ```ignore
//! use math_audio_bem::core::{BemProblem, BemSolver, SolverMethod};
//!
//! // Create a rigid sphere scattering problem. The mesh resolution is chosen from ka;
//! // use `BemProblem::rigid_sphere_scattering_custom` to set it explicitly.
//! let problem = BemProblem::rigid_sphere_scattering(
//!     0.1,        // radius (m)
//!     1000.0,     // frequency (Hz)
//!     343.0,      // speed of sound (m/s)
//!     1.21,       // density (kg/m³)
//! );
//!
//! // Configure solver
//! let solver = BemSolver::new()
//!     .with_solver_method(SolverMethod::Direct);
//!
//! // Solve
//! let solution = solver.solve(&problem)?;
//!
//! // Evaluate field at a point
//! let pressure = solution.evaluate_pressure(&[0.0, 0.0, 0.2]);
//! ```
34
35use ndarray::{Array1, Array2};
36use num_complex::Complex64;
37use std::f64::consts::PI;
38
39use crate::core::assembly::slfmm::{SlfmmSystem, build_slfmm_system};
40use crate::core::assembly::tbem::build_tbem_system_with_beta;
41use crate::core::incident::IncidentField;
42use crate::core::mesh::generators::{generate_icosphere_mesh, generate_sphere_mesh};
43use crate::core::postprocess::pressure::{FieldPoint, compute_total_field};
44use crate::core::types::{BoundaryCondition, Element, Mesh, PhysicsParams};
45use math_audio_solvers::direct::lu_solve;
46use math_audio_solvers::iterative::{BiCgstabConfig, bicgstab};
47use math_audio_solvers::traits::LinearOperator;
48
/// Linear solver used for the assembled BEM system.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum SolverMethod {
    /// Direct LU factorization (best for small problems).
    ///
    /// With SLFMM assembly this is only accepted up to 2000 DOFs.
    #[default]
    Direct,
    /// Conjugate Gradient Squared (iterative).
    ///
    /// Note: `BemSolver` has no dedicated CGS implementation yet; this variant currently
    /// runs the same BiCGSTAB solver as [`SolverMethod::BiCgStab`].
    Cgs,
    /// BiCGSTAB (iterative, more stable than CGS)
    BiCgStab,
}
60
/// Strategy used to assemble the BEM operator.
///
/// `Tbem` is the default. `Slfmm` requires the `native` or `wasm` feature, and `Mlfmm` is not
/// yet wired into `BemSolver` (solving with it yields `BemError::NotImplemented`).
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum AssemblyMethod {
    /// Traditional BEM producing a dense O(N^2) matrix
    #[default]
    Tbem,
    /// Single-Level Fast Multipole Method
    Slfmm,
    /// Multi-Level Fast Multipole Method
    Mlfmm,
}
72
/// Kind of boundary condition applied to every element of a problem's surface.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum BoundaryConditionType {
    /// Sound-hard surface: zero normal velocity
    Rigid,
    /// Sound-soft surface: zero pressure
    Soft,
    /// Impedance surface; `BemSolver` uses zero velocity with an admittance of 1/(ρc)
    Impedance,
}
83
84/// Definition of a BEM problem
85#[derive(Debug, Clone)]
86pub struct BemProblem {
87    /// Problem geometry mesh
88    pub mesh: Mesh,
89    /// Physical parameters
90    pub physics: PhysicsParams,
91    /// Incident field
92    pub incident_field: IncidentField,
93    /// Boundary condition type
94    pub bc_type: BoundaryConditionType,
95    /// Use Burton-Miller formulation (recommended for exterior problems)
96    pub use_burton_miller: bool,
97}
98
99impl BemProblem {
100    /// Create a rigid sphere scattering problem with plane wave incidence
101    ///
102    /// # Arguments
103    /// * `radius` - Sphere radius (m)
104    /// * `frequency` - Excitation frequency (Hz)
105    /// * `speed_of_sound` - Speed of sound (m/s)
106    /// * `density` - Medium density (kg/m�)
107    pub fn rigid_sphere_scattering(
108        radius: f64,
109        frequency: f64,
110        speed_of_sound: f64,
111        density: f64,
112    ) -> Self {
113        // Determine mesh resolution based on ka
114        let k = 2.0 * PI * frequency / speed_of_sound;
115        let ka = k * radius;
116
117        // Rule of thumb: ~10 elements per wavelength
118        // For a sphere, this translates to subdivision level
119        let subdivisions = if ka < 1.0 {
120            2 // Low frequency: coarse mesh ok
121        } else if ka < 5.0 {
122            3 // Medium frequency
123        } else {
124            4 // High frequency: need finer mesh
125        };
126
127        let mesh = generate_icosphere_mesh(radius, subdivisions);
128        let physics = PhysicsParams::new(frequency, speed_of_sound, density, false);
129        let incident_field = IncidentField::plane_wave_z();
130
131        Self {
132            mesh,
133            physics,
134            incident_field,
135            bc_type: BoundaryConditionType::Rigid,
136            use_burton_miller: true,
137        }
138    }
139
140    /// Create a rigid sphere scattering problem with custom mesh resolution
141    pub fn rigid_sphere_scattering_custom(
142        radius: f64,
143        frequency: f64,
144        speed_of_sound: f64,
145        density: f64,
146        n_theta: usize,
147        n_phi: usize,
148    ) -> Self {
149        let mesh = generate_sphere_mesh(radius, n_theta, n_phi);
150        let physics = PhysicsParams::new(frequency, speed_of_sound, density, false);
151        let incident_field = IncidentField::plane_wave_z();
152
153        Self {
154            mesh,
155            physics,
156            incident_field,
157            bc_type: BoundaryConditionType::Rigid,
158            use_burton_miller: true,
159        }
160    }
161
162    /// Set the incident field
163    pub fn with_incident_field(mut self, field: IncidentField) -> Self {
164        self.incident_field = field;
165        self
166    }
167
168    /// Set the boundary condition type
169    pub fn with_boundary_condition(mut self, bc_type: BoundaryConditionType) -> Self {
170        self.bc_type = bc_type;
171        self
172    }
173
174    /// Enable/disable Burton-Miller formulation
175    pub fn with_burton_miller(mut self, use_bm: bool) -> Self {
176        self.use_burton_miller = use_bm;
177        self
178    }
179
180    /// Get the wave number times radius (ka)
181    pub fn ka(&self) -> f64 {
182        self.physics.wave_number * self.mesh_radius()
183    }
184
185    /// Estimate mesh radius from bounding box
186    fn mesh_radius(&self) -> f64 {
187        // Estimate radius from mesh nodes
188        let mut max_r = 0.0f64;
189        for i in 0..self.mesh.nodes.nrows() {
190            let r = (self.mesh.nodes[[i, 0]].powi(2)
191                + self.mesh.nodes[[i, 1]].powi(2)
192                + self.mesh.nodes[[i, 2]].powi(2))
193            .sqrt();
194            max_r = max_r.max(r);
195        }
196        max_r
197    }
198}
199
200/// BEM solver configuration
201#[derive(Debug, Clone)]
202pub struct BemSolver {
203    /// Linear solver method
204    pub solver_method: SolverMethod,
205    /// Matrix assembly method
206    pub assembly_method: AssemblyMethod,
207    /// Maximum iterations for iterative solvers
208    pub max_iterations: usize,
209    /// Tolerance for iterative solvers
210    pub tolerance: f64,
211    /// Verbose output
212    pub verbose: bool,
213    /// Burton-Miller β scale factor (default: 4.0 for best accuracy at ka ~ 1)
214    pub beta_scale: f64,
215}
216
217impl Default for BemSolver {
218    fn default() -> Self {
219        Self {
220            solver_method: SolverMethod::Direct,
221            assembly_method: AssemblyMethod::Tbem,
222            max_iterations: 1000,
223            tolerance: 1e-8,
224            verbose: false,
225            beta_scale: 4.0, // Empirically optimal for ka ~ 1
226        }
227    }
228}
229
230impl BemSolver {
231    /// Create a new solver with default settings
232    pub fn new() -> Self {
233        Self::default()
234    }
235
236    /// Set the linear solver method
237    pub fn with_solver_method(mut self, method: SolverMethod) -> Self {
238        self.solver_method = method;
239        self
240    }
241
242    /// Set the assembly method
243    pub fn with_assembly_method(mut self, method: AssemblyMethod) -> Self {
244        self.assembly_method = method;
245        self
246    }
247
248    /// Set maximum iterations for iterative solvers
249    pub fn with_max_iterations(mut self, max_iter: usize) -> Self {
250        self.max_iterations = max_iter;
251        self
252    }
253
254    /// Set tolerance for iterative solvers
255    pub fn with_tolerance(mut self, tol: f64) -> Self {
256        self.tolerance = tol;
257        self
258    }
259
260    /// Enable verbose output
261    pub fn with_verbose(mut self, verbose: bool) -> Self {
262        self.verbose = verbose;
263        self
264    }
265
266    /// Solve a BEM problem
267    ///
268    /// # Arguments
269    /// * `problem` - The BEM problem definition
270    ///
271    /// # Returns
272    /// Solution containing surface pressures and methods to evaluate fields
273    pub fn solve(&self, problem: &BemProblem) -> Result<BemSolution, BemError> {
274        if self.verbose {
275            log::info!(
276                "Solving BEM problem: {} elements, ka = {:.3}",
277                problem.mesh.elements.len(),
278                problem.ka()
279            );
280        }
281
282        let elements = self.prepare_elements(problem);
283
284        let assembly_result =
285            self.assemble_system(&elements, &problem.mesh.nodes, &problem.physics)?;
286
287        let rhs = match &assembly_result {
288            AssemblyResult::Dense(_, rhs) => rhs.clone(),
289            AssemblyResult::Slfmm(system) => system.rhs.clone(),
290        };
291
292        let rhs = self.add_incident_field_rhs(
293            rhs,
294            &elements,
295            &problem.incident_field,
296            &problem.physics,
297            problem.use_burton_miller,
298        );
299
300        let surface_pressure = match assembly_result {
301            AssemblyResult::Dense(matrix, _) => self.solve_dense_system(&matrix, &rhs)?,
302            AssemblyResult::Slfmm(system) => self.solve_fmm_system(system, &rhs)?,
303        };
304
305        if self.verbose {
306            log::info!(
307                "Solution complete. Max surface pressure: {:.6}",
308                surface_pressure
309                    .iter()
310                    .map(|p| p.norm())
311                    .fold(0.0f64, f64::max)
312            );
313        }
314
315        Ok(BemSolution {
316            surface_pressure,
317            elements,
318            nodes: problem.mesh.nodes.clone(),
319            incident_field: problem.incident_field.clone(),
320            physics: problem.physics.clone(),
321        })
322    }
323
324    /// Prepare elements with boundary conditions
325    fn prepare_elements(&self, problem: &BemProblem) -> Vec<Element> {
326        let mut elements = problem.mesh.elements.clone();
327
328        // Set boundary conditions based on problem type
329        let bc = match problem.bc_type {
330            BoundaryConditionType::Rigid => {
331                // Zero normal velocity (Neumann BC)
332                BoundaryCondition::Velocity(vec![Complex64::new(0.0, 0.0)])
333            }
334            BoundaryConditionType::Soft => {
335                // Zero pressure (Dirichlet BC)
336                BoundaryCondition::Pressure(vec![Complex64::new(0.0, 0.0)])
337            }
338            BoundaryConditionType::Impedance => {
339                // Default impedance (plane wave)
340                let z0 = problem.physics.density * problem.physics.speed_of_sound;
341                BoundaryCondition::VelocityWithAdmittance {
342                    velocity: vec![Complex64::new(0.0, 0.0)],
343                    admittance: Complex64::new(1.0 / z0, 0.0),
344                }
345            }
346        };
347
348        // Assign BC and DOF addresses to each element
349        for (i, elem) in elements.iter_mut().enumerate() {
350            elem.boundary_condition = bc.clone();
351            elem.dof_addresses = vec![i];
352        }
353
354        elements
355    }
356
357    /// Assemble the BEM system matrix
358    fn assemble_system(
359        &self,
360        elements: &[Element],
361        nodes: &Array2<f64>,
362        physics: &PhysicsParams,
363    ) -> Result<AssemblyResult, BemError> {
364        match self.assembly_method {
365            AssemblyMethod::Tbem => {
366                let beta = physics.burton_miller_beta_scaled(self.beta_scale);
367                let system = build_tbem_system_with_beta(elements, nodes, physics, beta);
368                Ok(AssemblyResult::Dense(system.matrix, system.rhs))
369            }
370            AssemblyMethod::Slfmm => {
371                #[cfg(any(feature = "native", feature = "wasm"))]
372                {
373                    use crate::core::types::Cluster;
374
375                    let num_elements = elements.len();
376                    let _elements_per_cluster = 16usize;
377
378                    let cluster = Cluster::new(Array1::from_vec(vec![0.0, 0.0, 0.0]));
379                    let mut clusters = vec![cluster];
380                    clusters[0].element_indices = (0..num_elements).collect();
381
382                    let n_theta = 6;
383                    let n_phi = 12;
384                    let n_terms = 5;
385
386                    let system = build_slfmm_system(
387                        elements, nodes, &clusters, physics, n_theta, n_phi, n_terms,
388                    );
389                    Ok(AssemblyResult::Slfmm(system))
390                }
391                #[cfg(not(any(feature = "native", feature = "wasm")))]
392                Err(BemError::NotImplemented(
393                    "SLFMM requires native or wasm feature".to_string(),
394                ))
395            }
396            AssemblyMethod::Mlfmm => Err(BemError::NotImplemented(
397                "MLFMM not yet integrated in high-level API".to_string(),
398            )),
399        }
400    }
401
402    /// Add incident field contribution to RHS
403    fn add_incident_field_rhs(
404        &self,
405        mut rhs: Array1<Complex64>,
406        elements: &[Element],
407        incident_field: &IncidentField,
408        physics: &PhysicsParams,
409        use_burton_miller: bool,
410    ) -> Array1<Complex64> {
411        let n = elements.len();
412        let mut centers = Array2::zeros((n, 3));
413        let mut normals = Array2::zeros((n, 3));
414
415        for (i, elem) in elements.iter().enumerate() {
416            for j in 0..3 {
417                centers[[i, j]] = elem.center[j];
418                normals[[i, j]] = elem.normal[j];
419            }
420        }
421
422        let incident_rhs = if use_burton_miller {
423            let beta = physics.burton_miller_beta_scaled(self.beta_scale);
424            incident_field.compute_rhs_with_beta(&centers, &normals, physics, beta)
425        } else {
426            incident_field.compute_rhs(&centers, &normals, physics, false)
427        };
428
429        rhs = rhs + incident_rhs;
430
431        rhs
432    }
433
434    /// Solve dense linear system
435    fn solve_dense_system(
436        &self,
437        matrix: &Array2<Complex64>,
438        rhs: &Array1<Complex64>,
439    ) -> Result<Array1<Complex64>, BemError> {
440        match self.solver_method {
441            SolverMethod::Direct => {
442                lu_solve(matrix, rhs).map_err(|e| BemError::SolverFailed(e.to_string()))
443            }
444            SolverMethod::Cgs | SolverMethod::BiCgStab => {
445                let config = BiCgstabConfig {
446                    max_iterations: self.max_iterations,
447                    tolerance: self.tolerance,
448                    print_interval: 0,
449                };
450
451                match bicgstab(&DenseMatrixOperator(matrix), rhs, &config) {
452                    sol if sol.converged => Ok(sol.x),
453                    sol => Err(BemError::SolverFailed(format!(
454                        "BiCGSTAB did not converge: residual = {}",
455                        sol.residual
456                    ))),
457                }
458            }
459        }
460    }
461
462    /// Solve FMM system using matrix-vector products
463    fn solve_fmm_system(
464        &self,
465        system: SlfmmSystem,
466        rhs: &Array1<Complex64>,
467    ) -> Result<Array1<Complex64>, BemError> {
468        match self.solver_method {
469            SolverMethod::Direct => {
470                if system.num_dofs <= 2000 {
471                    let matrix = system.extract_near_field_matrix();
472                    lu_solve(&matrix, rhs).map_err(|e| BemError::SolverFailed(e.to_string()))
473                } else {
474                    Err(BemError::NotImplemented(
475                        "Direct solver not available for large FMM problems".to_string(),
476                    ))
477                }
478            }
479            SolverMethod::Cgs | SolverMethod::BiCgStab => {
480                let config = BiCgstabConfig {
481                    max_iterations: self.max_iterations,
482                    tolerance: self.tolerance,
483                    print_interval: 0,
484                };
485
486                match bicgstab(&system, rhs, &config) {
487                    sol if sol.converged => Ok(sol.x),
488                    sol => Err(BemError::SolverFailed(format!(
489                        "BiCGSTAB did not converge: residual = {}",
490                        sol.residual
491                    ))),
492                }
493            }
494        }
495    }
496}
497
498/// Solution of a BEM problem
499#[derive(Debug, Clone)]
500pub struct BemSolution {
501    /// Surface pressure at each element
502    pub surface_pressure: Array1<Complex64>,
503    /// Elements used in the solution
504    pub elements: Vec<Element>,
505    /// Node coordinates
506    pub nodes: Array2<f64>,
507    /// Incident field used
508    pub incident_field: IncidentField,
509    /// Physics parameters
510    pub physics: PhysicsParams,
511}
512
513impl BemSolution {
514    /// Evaluate total pressure at a single point
515    pub fn evaluate_pressure(&self, point: &[f64; 3]) -> Complex64 {
516        let eval_points =
517            Array2::from_shape_vec((1, 3), vec![point[0], point[1], point[2]]).unwrap();
518
519        let field_points = compute_total_field(
520            &eval_points,
521            &self.elements,
522            &self.nodes,
523            &self.surface_pressure,
524            None,
525            &self.incident_field,
526            &self.physics,
527        );
528
529        field_points[0].p_total
530    }
531
532    /// Evaluate total pressure at multiple points
533    pub fn evaluate_pressure_field(&self, points: &Array2<f64>) -> Vec<FieldPoint> {
534        compute_total_field(
535            points,
536            &self.elements,
537            &self.nodes,
538            &self.surface_pressure,
539            None,
540            &self.incident_field,
541            &self.physics,
542        )
543    }
544
545    /// Get max surface pressure magnitude
546    pub fn max_surface_pressure(&self) -> f64 {
547        self.surface_pressure
548            .iter()
549            .map(|p| p.norm())
550            .fold(0.0f64, f64::max)
551    }
552
553    /// Get mean surface pressure magnitude
554    pub fn mean_surface_pressure(&self) -> f64 {
555        let sum: f64 = self.surface_pressure.iter().map(|p| p.norm()).sum();
556        sum / self.surface_pressure.len() as f64
557    }
558
559    /// Number of DOFs in the solution
560    pub fn num_dofs(&self) -> usize {
561        self.surface_pressure.len()
562    }
563}
564
/// BEM solver errors
///
/// Derives `PartialEq`/`Eq` so callers and tests can compare errors directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum BemError {
    /// Feature not yet implemented
    NotImplemented(String),
    /// Linear solver failed
    SolverFailed(String),
    /// Invalid mesh
    InvalidMesh(String),
    /// Invalid parameters
    InvalidParameters(String),
}

impl std::fmt::Display for BemError {
    /// Formats as `"<category>: <message>"`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let (category, msg) = match self {
            BemError::NotImplemented(msg) => ("Not implemented", msg),
            BemError::SolverFailed(msg) => ("Solver failed", msg),
            BemError::InvalidMesh(msg) => ("Invalid mesh", msg),
            BemError::InvalidParameters(msg) => ("Invalid parameters", msg),
        };
        write!(f, "{}: {}", category, msg)
    }
}

impl std::error::Error for BemError {}
590
591/// Result of system assembly (dense matrix or FMM system)
592enum AssemblyResult {
593    /// Dense matrix for TBEM
594    Dense(Array2<Complex64>, Array1<Complex64>),
595    /// SLFMM system for iterative solvers
596    Slfmm(SlfmmSystem),
597}
598
599impl AssemblyResult {
600    #[allow(dead_code)]
601    fn num_dofs(&self) -> usize {
602        match self {
603            AssemblyResult::Dense(m, _) => m.nrows(),
604            AssemblyResult::Slfmm(s) => s.num_dofs,
605        }
606    }
607}
608
609struct DenseMatrixOperator<'a>(&'a Array2<Complex64>);
610
611impl<'a> LinearOperator<Complex64> for DenseMatrixOperator<'a> {
612    fn num_rows(&self) -> usize {
613        self.0.nrows()
614    }
615
616    fn num_cols(&self) -> usize {
617        self.0.ncols()
618    }
619
620    fn apply(&self, x: &Array1<Complex64>) -> Array1<Complex64> {
621        self.0.dot(x)
622    }
623
624    fn apply_transpose(&self, x: &Array1<Complex64>) -> Array1<Complex64> {
625        self.0.t().dot(x)
626    }
627}
628
#[cfg(test)]
mod tests {
    use super::*;

    /// Tiny, low-frequency rigid sphere that solves quickly with the default solver.
    fn small_sphere_problem() -> BemProblem {
        BemProblem::rigid_sphere_scattering_custom(
            0.1,   // radius (m)
            100.0, // very low frequency for a quick test (Hz)
            343.0, // speed of sound (m/s)
            1.21,  // density (kg/m³)
            4,     // n_theta: coarse mesh
            8,     // n_phi
        )
    }

    #[test]
    fn test_bem_problem_creation() {
        let problem = BemProblem::rigid_sphere_scattering(0.1, 1000.0, 343.0, 1.21);

        assert!(!problem.mesh.elements.is_empty());
        assert!(problem.mesh.nodes.nrows() > 0);
        assert!(problem.ka() > 0.0);
    }

    #[test]
    fn test_bem_solver_creation() {
        let solver = BemSolver::new()
            .with_solver_method(SolverMethod::Direct)
            .with_assembly_method(AssemblyMethod::Tbem)
            .with_verbose(false);

        assert_eq!(solver.solver_method, SolverMethod::Direct);
        assert_eq!(solver.assembly_method, AssemblyMethod::Tbem);
    }

    #[test]
    fn test_bem_solver_defaults() {
        let solver = BemSolver::default();

        assert_eq!(solver.solver_method, SolverMethod::Direct);
        assert_eq!(solver.assembly_method, AssemblyMethod::Tbem);
        assert_eq!(solver.max_iterations, 1000);
        assert_eq!(solver.tolerance, 1e-8);
        assert!(!solver.verbose);
        assert_eq!(solver.beta_scale, 4.0);
    }

    // Direct solver tests work in both native and WASM modes
    #[test]
    fn test_bem_solver_small_problem() {
        let solution = BemSolver::new()
            .solve(&small_sphere_problem())
            .expect("small rigid-sphere problem should solve");

        assert!(solution.num_dofs() > 0);
        assert!(solution.max_surface_pressure() > 0.0);
        assert!(solution.mean_surface_pressure() <= solution.max_surface_pressure());
    }

    #[test]
    fn test_field_evaluation() {
        let solution = BemSolver::new()
            .solve(&small_sphere_problem())
            .expect("small rigid-sphere problem should solve");

        // Evaluate at a point outside the sphere
        let p = solution.evaluate_pressure(&[0.0, 0.0, 0.2]);

        // Should have some finite pressure (incident + scattered)
        assert!(p.norm().is_finite());
        assert!(p.norm() > 0.0);
    }

    #[test]
    fn test_mlfmm_not_implemented() {
        let result = BemSolver::new()
            .with_assembly_method(AssemblyMethod::Mlfmm)
            .solve(&small_sphere_problem());

        assert!(matches!(result, Err(BemError::NotImplemented(_))));
    }
}