1use ndarray::{Array1, Array2};
36use num_complex::Complex64;
37use std::f64::consts::PI;
38
39use crate::core::assembly::slfmm::{SlfmmSystem, build_slfmm_system};
40use crate::core::assembly::tbem::build_tbem_system_with_beta;
41use crate::core::incident::IncidentField;
42use crate::core::mesh::generators::{generate_icosphere_mesh, generate_sphere_mesh};
43use crate::core::postprocess::pressure::{FieldPoint, compute_total_field};
44use crate::core::types::{BoundaryCondition, Element, Mesh, PhysicsParams};
45use math_audio_solvers::direct::lu_solve;
46use math_audio_solvers::iterative::{BiCgstabConfig, bicgstab};
47use math_audio_solvers::traits::LinearOperator;
48
/// Strategy used to solve the assembled linear system.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum SolverMethod {
    /// Dense LU factorisation (the default).
    #[default]
    Direct,
    /// Conjugate Gradient Squared. `BemSolver` currently dispatches this to BiCGSTAB.
    Cgs,
    /// Stabilised bi-conjugate gradient iteration.
    BiCgStab,
}

/// Strategy used to assemble the boundary-element system.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum AssemblyMethod {
    /// Traditional BEM: a fully populated dense matrix (the default).
    #[default]
    Tbem,
    /// Single-level fast multipole method.
    Slfmm,
    /// Multi-level fast multipole method (not yet wired into `BemSolver`).
    Mlfmm,
}

/// Boundary condition applied uniformly to every surface element.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum BoundaryConditionType {
    /// Sound-hard surface: zero prescribed velocity.
    Rigid,
    /// Sound-soft surface: zero prescribed pressure.
    Soft,
    /// Zero prescribed velocity with an admittance of `1 / (density * speed_of_sound)`.
    Impedance,
}
83
84#[derive(Debug, Clone)]
86pub struct BemProblem {
87 pub mesh: Mesh,
89 pub physics: PhysicsParams,
91 pub incident_field: IncidentField,
93 pub bc_type: BoundaryConditionType,
95 pub use_burton_miller: bool,
97}
98
99impl BemProblem {
100 pub fn rigid_sphere_scattering(
108 radius: f64,
109 frequency: f64,
110 speed_of_sound: f64,
111 density: f64,
112 ) -> Self {
113 let k = 2.0 * PI * frequency / speed_of_sound;
115 let ka = k * radius;
116
117 let subdivisions = if ka < 1.0 {
120 2 } else if ka < 5.0 {
122 3 } else {
124 4 };
126
127 let mesh = generate_icosphere_mesh(radius, subdivisions);
128 let physics = PhysicsParams::new(frequency, speed_of_sound, density, false);
129 let incident_field = IncidentField::plane_wave_z();
130
131 Self {
132 mesh,
133 physics,
134 incident_field,
135 bc_type: BoundaryConditionType::Rigid,
136 use_burton_miller: true,
137 }
138 }
139
140 pub fn rigid_sphere_scattering_custom(
142 radius: f64,
143 frequency: f64,
144 speed_of_sound: f64,
145 density: f64,
146 n_theta: usize,
147 n_phi: usize,
148 ) -> Self {
149 let mesh = generate_sphere_mesh(radius, n_theta, n_phi);
150 let physics = PhysicsParams::new(frequency, speed_of_sound, density, false);
151 let incident_field = IncidentField::plane_wave_z();
152
153 Self {
154 mesh,
155 physics,
156 incident_field,
157 bc_type: BoundaryConditionType::Rigid,
158 use_burton_miller: true,
159 }
160 }
161
162 pub fn with_incident_field(mut self, field: IncidentField) -> Self {
164 self.incident_field = field;
165 self
166 }
167
168 pub fn with_boundary_condition(mut self, bc_type: BoundaryConditionType) -> Self {
170 self.bc_type = bc_type;
171 self
172 }
173
174 pub fn with_burton_miller(mut self, use_bm: bool) -> Self {
176 self.use_burton_miller = use_bm;
177 self
178 }
179
180 pub fn ka(&self) -> f64 {
182 self.physics.wave_number * self.mesh_radius()
183 }
184
185 fn mesh_radius(&self) -> f64 {
187 let mut max_r = 0.0f64;
189 for i in 0..self.mesh.nodes.nrows() {
190 let r = (self.mesh.nodes[[i, 0]].powi(2)
191 + self.mesh.nodes[[i, 1]].powi(2)
192 + self.mesh.nodes[[i, 2]].powi(2))
193 .sqrt();
194 max_r = max_r.max(r);
195 }
196 max_r
197 }
198}
199
200#[derive(Debug, Clone)]
202pub struct BemSolver {
203 pub solver_method: SolverMethod,
205 pub assembly_method: AssemblyMethod,
207 pub max_iterations: usize,
209 pub tolerance: f64,
211 pub verbose: bool,
213 pub beta_scale: f64,
215}
216
217impl Default for BemSolver {
218 fn default() -> Self {
219 Self {
220 solver_method: SolverMethod::Direct,
221 assembly_method: AssemblyMethod::Tbem,
222 max_iterations: 1000,
223 tolerance: 1e-8,
224 verbose: false,
225 beta_scale: 4.0, }
227 }
228}
229
230impl BemSolver {
231 pub fn new() -> Self {
233 Self::default()
234 }
235
236 pub fn with_solver_method(mut self, method: SolverMethod) -> Self {
238 self.solver_method = method;
239 self
240 }
241
242 pub fn with_assembly_method(mut self, method: AssemblyMethod) -> Self {
244 self.assembly_method = method;
245 self
246 }
247
248 pub fn with_max_iterations(mut self, max_iter: usize) -> Self {
250 self.max_iterations = max_iter;
251 self
252 }
253
254 pub fn with_tolerance(mut self, tol: f64) -> Self {
256 self.tolerance = tol;
257 self
258 }
259
260 pub fn with_verbose(mut self, verbose: bool) -> Self {
262 self.verbose = verbose;
263 self
264 }
265
266 pub fn solve(&self, problem: &BemProblem) -> Result<BemSolution, BemError> {
274 if self.verbose {
275 log::info!(
276 "Solving BEM problem: {} elements, ka = {:.3}",
277 problem.mesh.elements.len(),
278 problem.ka()
279 );
280 }
281
282 let elements = self.prepare_elements(problem);
283
284 let assembly_result =
285 self.assemble_system(&elements, &problem.mesh.nodes, &problem.physics)?;
286
287 let rhs = match &assembly_result {
288 AssemblyResult::Dense(_, rhs) => rhs.clone(),
289 AssemblyResult::Slfmm(system) => system.rhs.clone(),
290 };
291
292 let rhs = self.add_incident_field_rhs(
293 rhs,
294 &elements,
295 &problem.incident_field,
296 &problem.physics,
297 problem.use_burton_miller,
298 );
299
300 let surface_pressure = match assembly_result {
301 AssemblyResult::Dense(matrix, _) => self.solve_dense_system(&matrix, &rhs)?,
302 AssemblyResult::Slfmm(system) => self.solve_fmm_system(system, &rhs)?,
303 };
304
305 if self.verbose {
306 log::info!(
307 "Solution complete. Max surface pressure: {:.6}",
308 surface_pressure
309 .iter()
310 .map(|p| p.norm())
311 .fold(0.0f64, f64::max)
312 );
313 }
314
315 Ok(BemSolution {
316 surface_pressure,
317 elements,
318 nodes: problem.mesh.nodes.clone(),
319 incident_field: problem.incident_field.clone(),
320 physics: problem.physics.clone(),
321 })
322 }
323
324 fn prepare_elements(&self, problem: &BemProblem) -> Vec<Element> {
326 let mut elements = problem.mesh.elements.clone();
327
328 let bc = match problem.bc_type {
330 BoundaryConditionType::Rigid => {
331 BoundaryCondition::Velocity(vec![Complex64::new(0.0, 0.0)])
333 }
334 BoundaryConditionType::Soft => {
335 BoundaryCondition::Pressure(vec![Complex64::new(0.0, 0.0)])
337 }
338 BoundaryConditionType::Impedance => {
339 let z0 = problem.physics.density * problem.physics.speed_of_sound;
341 BoundaryCondition::VelocityWithAdmittance {
342 velocity: vec![Complex64::new(0.0, 0.0)],
343 admittance: Complex64::new(1.0 / z0, 0.0),
344 }
345 }
346 };
347
348 for (i, elem) in elements.iter_mut().enumerate() {
350 elem.boundary_condition = bc.clone();
351 elem.dof_addresses = vec![i];
352 }
353
354 elements
355 }
356
357 fn assemble_system(
359 &self,
360 elements: &[Element],
361 nodes: &Array2<f64>,
362 physics: &PhysicsParams,
363 ) -> Result<AssemblyResult, BemError> {
364 match self.assembly_method {
365 AssemblyMethod::Tbem => {
366 let beta = physics.burton_miller_beta_scaled(self.beta_scale);
367 let system = build_tbem_system_with_beta(elements, nodes, physics, beta);
368 Ok(AssemblyResult::Dense(system.matrix, system.rhs))
369 }
370 AssemblyMethod::Slfmm => {
371 #[cfg(any(feature = "native", feature = "wasm"))]
372 {
373 use crate::core::types::Cluster;
374
375 let num_elements = elements.len();
376 let _elements_per_cluster = 16usize;
377
378 let cluster = Cluster::new(Array1::from_vec(vec![0.0, 0.0, 0.0]));
379 let mut clusters = vec![cluster];
380 clusters[0].element_indices = (0..num_elements).collect();
381
382 let n_theta = 6;
383 let n_phi = 12;
384 let n_terms = 5;
385
386 let system = build_slfmm_system(
387 elements, nodes, &clusters, physics, n_theta, n_phi, n_terms,
388 );
389 Ok(AssemblyResult::Slfmm(system))
390 }
391 #[cfg(not(any(feature = "native", feature = "wasm")))]
392 Err(BemError::NotImplemented(
393 "SLFMM requires native or wasm feature".to_string(),
394 ))
395 }
396 AssemblyMethod::Mlfmm => Err(BemError::NotImplemented(
397 "MLFMM not yet integrated in high-level API".to_string(),
398 )),
399 }
400 }
401
402 fn add_incident_field_rhs(
404 &self,
405 mut rhs: Array1<Complex64>,
406 elements: &[Element],
407 incident_field: &IncidentField,
408 physics: &PhysicsParams,
409 use_burton_miller: bool,
410 ) -> Array1<Complex64> {
411 let n = elements.len();
412 let mut centers = Array2::zeros((n, 3));
413 let mut normals = Array2::zeros((n, 3));
414
415 for (i, elem) in elements.iter().enumerate() {
416 for j in 0..3 {
417 centers[[i, j]] = elem.center[j];
418 normals[[i, j]] = elem.normal[j];
419 }
420 }
421
422 let incident_rhs = if use_burton_miller {
423 let beta = physics.burton_miller_beta_scaled(self.beta_scale);
424 incident_field.compute_rhs_with_beta(¢ers, &normals, physics, beta)
425 } else {
426 incident_field.compute_rhs(¢ers, &normals, physics, false)
427 };
428
429 rhs = rhs + incident_rhs;
430
431 rhs
432 }
433
434 fn solve_dense_system(
436 &self,
437 matrix: &Array2<Complex64>,
438 rhs: &Array1<Complex64>,
439 ) -> Result<Array1<Complex64>, BemError> {
440 match self.solver_method {
441 SolverMethod::Direct => {
442 lu_solve(matrix, rhs).map_err(|e| BemError::SolverFailed(e.to_string()))
443 }
444 SolverMethod::Cgs | SolverMethod::BiCgStab => {
445 let config = BiCgstabConfig {
446 max_iterations: self.max_iterations,
447 tolerance: self.tolerance,
448 print_interval: 0,
449 };
450
451 match bicgstab(&DenseMatrixOperator(matrix), rhs, &config) {
452 sol if sol.converged => Ok(sol.x),
453 sol => Err(BemError::SolverFailed(format!(
454 "BiCGSTAB did not converge: residual = {}",
455 sol.residual
456 ))),
457 }
458 }
459 }
460 }
461
462 fn solve_fmm_system(
464 &self,
465 system: SlfmmSystem,
466 rhs: &Array1<Complex64>,
467 ) -> Result<Array1<Complex64>, BemError> {
468 match self.solver_method {
469 SolverMethod::Direct => {
470 if system.num_dofs <= 2000 {
471 let matrix = system.extract_near_field_matrix();
472 lu_solve(&matrix, rhs).map_err(|e| BemError::SolverFailed(e.to_string()))
473 } else {
474 Err(BemError::NotImplemented(
475 "Direct solver not available for large FMM problems".to_string(),
476 ))
477 }
478 }
479 SolverMethod::Cgs | SolverMethod::BiCgStab => {
480 let config = BiCgstabConfig {
481 max_iterations: self.max_iterations,
482 tolerance: self.tolerance,
483 print_interval: 0,
484 };
485
486 match bicgstab(&system, rhs, &config) {
487 sol if sol.converged => Ok(sol.x),
488 sol => Err(BemError::SolverFailed(format!(
489 "BiCGSTAB did not converge: residual = {}",
490 sol.residual
491 ))),
492 }
493 }
494 }
495 }
496}
497
498#[derive(Debug, Clone)]
500pub struct BemSolution {
501 pub surface_pressure: Array1<Complex64>,
503 pub elements: Vec<Element>,
505 pub nodes: Array2<f64>,
507 pub incident_field: IncidentField,
509 pub physics: PhysicsParams,
511}
512
513impl BemSolution {
514 pub fn evaluate_pressure(&self, point: &[f64; 3]) -> Complex64 {
516 let eval_points =
517 Array2::from_shape_vec((1, 3), vec![point[0], point[1], point[2]]).unwrap();
518
519 let field_points = compute_total_field(
520 &eval_points,
521 &self.elements,
522 &self.nodes,
523 &self.surface_pressure,
524 None,
525 &self.incident_field,
526 &self.physics,
527 );
528
529 field_points[0].p_total
530 }
531
532 pub fn evaluate_pressure_field(&self, points: &Array2<f64>) -> Vec<FieldPoint> {
534 compute_total_field(
535 points,
536 &self.elements,
537 &self.nodes,
538 &self.surface_pressure,
539 None,
540 &self.incident_field,
541 &self.physics,
542 )
543 }
544
545 pub fn max_surface_pressure(&self) -> f64 {
547 self.surface_pressure
548 .iter()
549 .map(|p| p.norm())
550 .fold(0.0f64, f64::max)
551 }
552
553 pub fn mean_surface_pressure(&self) -> f64 {
555 let sum: f64 = self.surface_pressure.iter().map(|p| p.norm()).sum();
556 sum / self.surface_pressure.len() as f64
557 }
558
559 pub fn num_dofs(&self) -> usize {
561 self.surface_pressure.len()
562 }
563}
564
/// Errors reported by the high-level BEM API.
#[derive(Clone, Debug)]
pub enum BemError {
    /// The requested feature or configuration is not available.
    NotImplemented(String),
    /// The linear solver failed or did not converge.
    SolverFailed(String),
    /// The mesh cannot be used for a solve.
    InvalidMesh(String),
    /// A solver or problem parameter is out of range.
    InvalidParameters(String),
}

impl std::fmt::Display for BemError {
    /// Renders as `"<category>: <detail>"`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::NotImplemented(detail) => write!(f, "Not implemented: {}", detail),
            Self::SolverFailed(detail) => write!(f, "Solver failed: {}", detail),
            Self::InvalidMesh(detail) => write!(f, "Invalid mesh: {}", detail),
            Self::InvalidParameters(detail) => write!(f, "Invalid parameters: {}", detail),
        }
    }
}

impl std::error::Error for BemError {}
590
591enum AssemblyResult {
593 Dense(Array2<Complex64>, Array1<Complex64>),
595 Slfmm(SlfmmSystem),
597}
598
599impl AssemblyResult {
600 #[allow(dead_code)]
601 fn num_dofs(&self) -> usize {
602 match self {
603 AssemblyResult::Dense(m, _) => m.nrows(),
604 AssemblyResult::Slfmm(s) => s.num_dofs,
605 }
606 }
607}
608
609struct DenseMatrixOperator<'a>(&'a Array2<Complex64>);
610
611impl<'a> LinearOperator<Complex64> for DenseMatrixOperator<'a> {
612 fn num_rows(&self) -> usize {
613 self.0.nrows()
614 }
615
616 fn num_cols(&self) -> usize {
617 self.0.ncols()
618 }
619
620 fn apply(&self, x: &Array1<Complex64>) -> Array1<Complex64> {
621 self.0.dot(x)
622 }
623
624 fn apply_transpose(&self, x: &Array1<Complex64>) -> Array1<Complex64> {
625 self.0.t().dot(x)
626 }
627}
628
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_bem_problem_creation() {
        let problem = BemProblem::rigid_sphere_scattering(0.1, 1000.0, 343.0, 1.21);

        assert!(!problem.mesh.elements.is_empty());
        assert!(problem.mesh.nodes.nrows() > 0);
        assert!(problem.ka() > 0.0);
    }

    #[test]
    fn test_bem_solver_creation() {
        let solver = BemSolver::new()
            .with_solver_method(SolverMethod::Direct)
            .with_assembly_method(AssemblyMethod::Tbem)
            .with_verbose(false);

        assert_eq!(solver.solver_method, SolverMethod::Direct);
        assert_eq!(solver.assembly_method, AssemblyMethod::Tbem);
    }

    #[test]
    fn test_bem_solver_small_problem() {
        // Coarse n_theta = 4, n_phi = 8 sphere mesh keeps the dense solve cheap.
        let problem = BemProblem::rigid_sphere_scattering_custom(0.1, 100.0, 343.0, 1.21, 4, 8);

        // `expect` surfaces the BemError on failure instead of a bare `is_ok()` check.
        let solution = BemSolver::new()
            .solve(&problem)
            .expect("direct TBEM solve of a small rigid sphere should succeed");

        assert!(solution.num_dofs() > 0);
        assert!(solution.max_surface_pressure() > 0.0);
    }

    #[test]
    fn test_field_evaluation() {
        let problem = BemProblem::rigid_sphere_scattering_custom(0.1, 100.0, 343.0, 1.21, 4, 8);

        let solution = BemSolver::new()
            .solve(&problem)
            .expect("direct TBEM solve of a small rigid sphere should succeed");

        let p = solution.evaluate_pressure(&[0.0, 0.0, 0.2]);

        assert!(p.norm() > 0.0);
    }

    #[test]
    fn test_mlfmm_reports_not_implemented() {
        let problem = BemProblem::rigid_sphere_scattering_custom(0.1, 100.0, 343.0, 1.21, 4, 8);

        let result = BemSolver::new()
            .with_assembly_method(AssemblyMethod::Mlfmm)
            .solve(&problem);

        assert!(matches!(result, Err(BemError::NotImplemented(_))));
    }
}