Skip to main content

math_audio_bem/room_acoustics/
solver.rs

1//! BEM solver for room acoustics
2//!
3//! Solves the Helmholtz equation in the room interior with rigid boundary conditions.
4//!
5//! This module supports three build modes:
6//! - `native`: Uses native rayon for parallel processing (fastest)
7//! - `wasm`: Uses wasm-bindgen-rayon for Web Worker parallelism
8//! - Neither: Falls back to sequential processing
9
10use super::*;
11use crate::core::parallel::{parallel_map, parallel_map_indexed};
12use ndarray::{Array1, Array2};
13use num_complex::Complex64;
14use std::f64::consts::PI;
15
16/// Green's function for 3D Helmholtz equation
17/// G(r) = exp(ikr) / (4πr)
18fn greens_function_3d(r: f64, k: f64) -> Complex64 {
19    if r < 1e-10 {
20        return Complex64::new(0.0, 0.0);
21    }
22    let ikr = Complex64::new(0.0, k * r);
23    ikr.exp() / (4.0 * PI * r)
24}
25
26/// Derivative of Green's function in normal direction
27/// ∂G/∂n = (ikr - 1) * exp(ikr) / (4πr²) * cos(angle)
28fn greens_function_derivative(r: f64, k: f64, cos_angle: f64) -> Complex64 {
29    if r < 1e-10 {
30        return Complex64::new(0.0, 0.0);
31    }
32    let ikr = Complex64::new(0.0, k * r);
33    let factor = (ikr - 1.0) * ikr.exp() / (4.0 * PI * r * r);
34    factor * cos_angle
35}
36
37/// Calculate element center and normal vector
38fn element_center_and_normal(nodes: &[Point3D]) -> (Point3D, Point3D) {
39    // Assume quadrilateral element
40    let center = Point3D::new(
41        nodes.iter().map(|n| n.x).sum::<f64>() / nodes.len() as f64,
42        nodes.iter().map(|n| n.y).sum::<f64>() / nodes.len() as f64,
43        nodes.iter().map(|n| n.z).sum::<f64>() / nodes.len() as f64,
44    );
45
46    // Normal from cross product of edges (works for both triangles and quads)
47    let v1 = Point3D::new(
48        nodes[1].x - nodes[0].x,
49        nodes[1].y - nodes[0].y,
50        nodes[1].z - nodes[0].z,
51    );
52    let v2 = Point3D::new(
53        nodes[2].x - nodes[0].x,
54        nodes[2].y - nodes[0].y,
55        nodes[2].z - nodes[0].z,
56    );
57
58    // Cross product
59    let nx = v1.y * v2.z - v1.z * v2.y;
60    let ny = v1.z * v2.x - v1.x * v2.z;
61    let nz = v1.x * v2.y - v1.y * v2.x;
62
63    let norm = (nx * nx + ny * ny + nz * nz).sqrt();
64    let normal = Point3D::new(nx / norm, ny / norm, nz / norm);
65
66    (center, normal)
67}
68
69/// Calculate element area (supports triangles and quads)
70fn element_area(nodes: &[Point3D]) -> f64 {
71    if nodes.len() == 3 {
72        // Triangle: 0.5 * |v1 × v2|
73        let v1 = Point3D::new(
74            nodes[1].x - nodes[0].x,
75            nodes[1].y - nodes[0].y,
76            nodes[1].z - nodes[0].z,
77        );
78        let v2 = Point3D::new(
79            nodes[2].x - nodes[0].x,
80            nodes[2].y - nodes[0].y,
81            nodes[2].z - nodes[0].z,
82        );
83
84        let cross_x = v1.y * v2.z - v1.z * v2.y;
85        let cross_y = v1.z * v2.x - v1.x * v2.z;
86        let cross_z = v1.x * v2.y - v1.y * v2.x;
87
88        0.5 * (cross_x * cross_x + cross_y * cross_y + cross_z * cross_z).sqrt()
89    } else if nodes.len() == 4 {
90        // Quadrilateral: split into two triangles and sum areas
91        let v1 = Point3D::new(
92            nodes[1].x - nodes[0].x,
93            nodes[1].y - nodes[0].y,
94            nodes[1].z - nodes[0].z,
95        );
96        let v2 = Point3D::new(
97            nodes[2].x - nodes[0].x,
98            nodes[2].y - nodes[0].y,
99            nodes[2].z - nodes[0].z,
100        );
101
102        let cross1_x = v1.y * v2.z - v1.z * v2.y;
103        let cross1_y = v1.z * v2.x - v1.x * v2.z;
104        let cross1_z = v1.x * v2.y - v1.y * v2.x;
105        let area1 = 0.5 * (cross1_x * cross1_x + cross1_y * cross1_y + cross1_z * cross1_z).sqrt();
106
107        let v3 = Point3D::new(
108            nodes[3].x - nodes[0].x,
109            nodes[3].y - nodes[0].y,
110            nodes[3].z - nodes[0].z,
111        );
112
113        let cross2_x = v2.y * v3.z - v2.z * v3.y;
114        let cross2_y = v2.z * v3.x - v2.x * v3.z;
115        let cross2_z = v2.x * v3.y - v2.y * v3.x;
116        let area2 = 0.5 * (cross2_x * cross2_x + cross2_y * cross2_y + cross2_z * cross2_z).sqrt();
117
118        area1 + area2
119    } else {
120        0.0
121    }
122}
123
124/// Build BEM system matrix for rigid boundaries
125pub fn build_bem_matrix(mesh: &RoomMesh, k: f64) -> Array2<Complex64> {
126    let n = mesh.elements.len();
127    let mut matrix = Array2::zeros((n, n));
128
129    // Get element centers and normals
130    let mut centers = Vec::new();
131    let mut normals = Vec::new();
132    let mut areas = Vec::new();
133
134    for element in &mesh.elements {
135        let nodes: Vec<Point3D> = element.nodes.iter().map(|&i| mesh.nodes[i]).collect();
136        let (center, normal) = element_center_and_normal(&nodes);
137        let area = element_area(&nodes);
138
139        centers.push(center);
140        normals.push(normal);
141        areas.push(area);
142    }
143
144    // Fill matrix: rigid boundary condition ∂p/∂n = 0
145    // This gives: Σ_j (∂G/∂n_i)(r_ij) * p_j * A_j = incident field derivative
146    for i in 0..n {
147        for j in 0..n {
148            let r = centers[i].distance_to(&centers[j]);
149
150            if i == j {
151                // Diagonal: use approximation for self-interaction
152                // For planar element: ∂G/∂n ≈ -ik/(2π) for small kr
153                matrix[[i, j]] = Complex64::new(0.0, -k / (2.0 * PI)) * areas[j];
154            } else {
155                // Direction from j to i
156                let dx = centers[i].x - centers[j].x;
157                let dy = centers[i].y - centers[j].y;
158                let dz = centers[i].z - centers[j].z;
159
160                // Cosine of angle between (i-j) direction and normal at i
161                let cos_angle = (dx * normals[i].x + dy * normals[i].y + dz * normals[i].z) / r;
162
163                matrix[[i, j]] = greens_function_derivative(r, k, cos_angle) * areas[j];
164            }
165        }
166    }
167
168    matrix
169}
170
171/// Calculate incident field from sources at element centers
172pub fn calculate_incident_field(
173    mesh: &RoomMesh,
174    sources: &[Source],
175    k: f64,
176    frequency: f64,
177) -> Array1<Complex64> {
178    let n = mesh.elements.len();
179    let mut incident = Array1::zeros(n);
180
181    for (i, element) in mesh.elements.iter().enumerate() {
182        let nodes: Vec<Point3D> = element.nodes.iter().map(|&idx| mesh.nodes[idx]).collect();
183        let (center, _normal) = element_center_and_normal(&nodes);
184
185        let mut total_pressure = Complex64::new(0.0, 0.0);
186
187        for source in sources {
188            let r = center.distance_to(&source.position);
189            let amplitude = source.amplitude_towards(&center, frequency);
190
191            // Incident pressure from monopole source
192            total_pressure += greens_function_3d(r, k) * amplitude;
193        }
194
195        incident[i] = total_pressure;
196    }
197
198    incident
199}
200
201/// Calculate pressure at field points using double-layer potential representation
202///
203/// Uses: p(x) = p_inc(x) + ∫∫ (∂G/∂n)(x, y) * p_surface(y) dS(y)
204pub fn calculate_field_pressure(
205    mesh: &RoomMesh,
206    surface_pressure: &Array1<Complex64>,
207    sources: &[Source],
208    field_points: &[Point3D],
209    k: f64,
210    frequency: f64,
211) -> Array1<Complex64> {
212    let n_points = field_points.len();
213    let mut pressures = Array1::zeros(n_points);
214
215    // Get element data
216    let mut centers = Vec::new();
217    let mut normals = Vec::new();
218    let mut areas = Vec::new();
219
220    for element in &mesh.elements {
221        let nodes: Vec<Point3D> = element.nodes.iter().map(|&i| mesh.nodes[i]).collect();
222        let (center, normal) = element_center_and_normal(&nodes);
223        let area = element_area(&nodes);
224
225        centers.push(center);
226        normals.push(normal);
227        areas.push(area);
228    }
229
230    for (ip, point) in field_points.iter().enumerate() {
231        // Incident field from sources
232        let mut p_incident = Complex64::new(0.0, 0.0);
233        for source in sources {
234            let r = point.distance_to(&source.position);
235            if r < 1e-10 {
236                continue;
237            }
238            let amplitude = source.amplitude_towards(point, frequency);
239            p_incident += greens_function_3d(r, k) * amplitude;
240        }
241
242        // Scattered field from boundary integral using double-layer potential
243        // p_scattered = ∫∫ (∂G/∂n)(x, y) * p_surface(y) dS(y)
244        let mut p_scattered = Complex64::new(0.0, 0.0);
245
246        for j in 0..surface_pressure.len() {
247            let r = point.distance_to(&centers[j]);
248            if r < 1e-10 {
249                continue;
250            }
251
252            // Direction from element center to field point
253            let dx = point.x - centers[j].x;
254            let dy = point.y - centers[j].y;
255            let dz = point.z - centers[j].z;
256
257            // Cosine of angle between (point - center) direction and outward normal
258            let cos_angle = (dx * normals[j].x + dy * normals[j].y + dz * normals[j].z) / r;
259
260            // Normal derivative of Green's function
261            let dg_dn = greens_function_derivative(r, k, cos_angle);
262
263            p_scattered += dg_dn * surface_pressure[j] * areas[j];
264        }
265
266        pressures[ip] = p_incident + p_scattered;
267    }
268
269    pressures
270}
271
272// pressure_to_spl is now in math_audio_xem_common::types
273
274/// Simple GMRES solver for complex linear systems
275/// Solves Ax = b using restarted GMRES
276pub fn gmres_solve(
277    a: &Array2<Complex64>,
278    b: &Array1<Complex64>,
279    max_iter: usize,
280    restart: usize,
281    tol: f64,
282) -> Result<Array1<Complex64>, String> {
283    let n = b.len();
284    if a.nrows() != n || a.ncols() != n {
285        return Err("Matrix dimensions mismatch".to_string());
286    }
287
288    let mut x = Array1::zeros(n);
289
290    for _cycle in 0..max_iter {
291        // Compute initial residual
292        let ax = a.dot(&x);
293        let r = b - &ax;
294        let beta = r.iter().map(|ri| ri.norm_sqr()).sum::<f64>().sqrt();
295
296        if beta < tol {
297            return Ok(x);
298        }
299
300        // Arnoldi iteration
301        let m = restart.min(n);
302        let mut v = vec![Array1::zeros(n); m + 1];
303        let mut h = Array2::<Complex64>::zeros((m + 1, m));
304
305        v[0] = r.mapv(|ri| ri / Complex64::new(beta, 0.0));
306
307        for j in 0..m {
308            // w = A * v[j]
309            let w = a.dot(&v[j]);
310            let mut w_orth = w.clone();
311
312            // Modified Gram-Schmidt orthogonalization
313            for i in 0..=j {
314                h[[i, j]] = v[i]
315                    .iter()
316                    .zip(w_orth.iter())
317                    .map(|(vi, wi)| vi.conj() * wi)
318                    .sum();
319
320                for k in 0..n {
321                    w_orth[k] -= h[[i, j]] * v[i][k];
322                }
323            }
324
325            let h_norm = w_orth.iter().map(|wi| wi.norm_sqr()).sum::<f64>().sqrt();
326            h[[j + 1, j]] = Complex64::new(h_norm, 0.0);
327
328            if h_norm > 1e-12 {
329                v[j + 1] = w_orth.mapv(|wi| wi / Complex64::new(h_norm, 0.0));
330            } else {
331                break;
332            }
333        }
334
335        // Solve least squares problem: minimize ||β*e1 - H*y||
336        let mut e1 = Array1::<Complex64>::zeros(m + 1);
337        e1[0] = Complex64::new(beta, 0.0);
338
339        // Use QR decomposition to solve
340        let y = solve_least_squares(&h, &e1, m)?;
341
342        // Update solution: x = x + V*y
343        for j in 0..m {
344            for k in 0..n {
345                x[k] += y[j] * v[j][k];
346            }
347        }
348
349        // Check convergence
350        let ax = a.dot(&x);
351        let r_final = b - &ax;
352        let residual = r_final.iter().map(|ri| ri.norm_sqr()).sum::<f64>().sqrt();
353
354        if residual < tol {
355            return Ok(x);
356        }
357    }
358
359    Ok(x)
360}
361
362/// Solve least squares problem using back substitution on upper triangular part
363fn solve_least_squares(
364    h: &Array2<Complex64>,
365    e1: &Array1<Complex64>,
366    m: usize,
367) -> Result<Array1<Complex64>, String> {
368    let mut y = Array1::<Complex64>::zeros(m);
369    let mut rhs = e1.slice(ndarray::s![0..m]).to_owned();
370
371    // Apply Givens rotations to make H upper triangular
372    let mut h_tri = h.slice(ndarray::s![0..m, 0..m]).to_owned();
373
374    for i in 0..m {
375        for j in (i + 1)..m {
376            if h_tri[[j, i]].norm() > 1e-12 {
377                let a = h_tri[[i, i]];
378                let b = h_tri[[j, i]];
379                let r = (a.norm_sqr() + b.norm_sqr()).sqrt();
380                let c = a.norm() / r;
381                let s = b / Complex64::new(r, 0.0);
382
383                // Apply rotation to rows i and j
384                for k in i..m {
385                    let temp = c * h_tri[[i, k]] + s * h_tri[[j, k]];
386                    h_tri[[j, k]] = -s.conj() * h_tri[[i, k]] + c * h_tri[[j, k]];
387                    h_tri[[i, k]] = temp;
388                }
389
390                let temp = c * rhs[i] + s * rhs[j];
391                rhs[j] = -s.conj() * rhs[i] + c * rhs[j];
392                rhs[i] = temp;
393            }
394        }
395    }
396
397    // Back substitution
398    for i in (0..m).rev() {
399        let mut sum = rhs[i];
400        for j in (i + 1)..m {
401            sum -= h_tri[[i, j]] * y[j];
402        }
403        if h_tri[[i, i]].norm() > 1e-12 {
404            y[i] = sum / h_tri[[i, i]];
405        }
406    }
407
408    Ok(y)
409}
410
411/// Solve BEM system using GMRES with parallel matrix assembly
412pub fn solve_bem_system(
413    mesh: &RoomMesh,
414    sources: &[Source],
415    k: f64,
416    frequency: f64,
417) -> Result<Array1<Complex64>, String> {
418    use crate::core::solver::{GmresConfig, solve_gmres};
419
420    // Build BEM matrix using double-layer potential formulation
421    let matrix = build_bem_matrix_parallel(mesh, k);
422
423    // Compute RHS from incident field normal derivative
424    let rhs = calculate_incident_field_derivative_parallel(mesh, sources, k, frequency);
425
426    // Use the core GMRES solver
427    let config = GmresConfig {
428        max_iterations: 100,
429        restart: 50,
430        tolerance: 1e-6,
431        print_interval: 0,
432    };
433
434    // Create matvec closure for dense matrix (implement LinearOperator or use wrapper)
435    // solve_gmres expects a LinearOperator.
436    // We can use DenseOperator from core::solver
437    use crate::core::solver::DenseOperator;
438    let op = DenseOperator::new(matrix);
439
440    let solution = solve_gmres(&op, &rhs, &config);
441
442    Ok(solution.x)
443}
444
445/// Build BEM matrix with parallel assembly
446///
447/// Uses portable parallel iteration that works with native rayon, WASM, or sequential fallback.
448pub fn build_bem_matrix_parallel(mesh: &RoomMesh, k: f64) -> Array2<Complex64> {
449    let n = mesh.elements.len();
450
451    // Precompute element data (parallel when available)
452    let element_data: Vec<_> = parallel_map(&mesh.elements, |element| {
453        let nodes: Vec<Point3D> = element.nodes.iter().map(|&i| mesh.nodes[i]).collect();
454        let (center, normal) = element_center_and_normal(&nodes);
455        let area = element_area(&nodes);
456        (center, normal, area)
457    });
458
459    // Build matrix rows (parallel when available)
460    let rows: Vec<_> = parallel_map_indexed(n, |i| {
461        let mut row = vec![Complex64::new(0.0, 0.0); n];
462        let (center_i, normal_i, _area_i) = &element_data[i];
463
464        for j in 0..n {
465            let (center_j, _normal_j, area_j) = &element_data[j];
466            let r = center_i.distance_to(center_j);
467
468            if i == j {
469                // Diagonal: self-interaction
470                row[j] = Complex64::new(0.0, -k / (2.0 * PI)) * area_j;
471            } else {
472                // Off-diagonal
473                let dx = center_i.x - center_j.x;
474                let dy = center_i.y - center_j.y;
475                let dz = center_i.z - center_j.z;
476
477                let cos_angle = (dx * normal_i.x + dy * normal_i.y + dz * normal_i.z) / r;
478                row[j] = greens_function_derivative(r, k, cos_angle) * area_j;
479            }
480        }
481        row
482    });
483
484    // Convert to ndarray
485    let mut matrix = Array2::zeros((n, n));
486    for (i, row) in rows.iter().enumerate() {
487        for (j, &val) in row.iter().enumerate() {
488            matrix[[i, j]] = val;
489        }
490    }
491
492    matrix
493}
494
495/// Build BEM matrix with adaptive integration for near-singular elements
496///
497/// Uses adaptive subdivision for self-interaction and nearby elements to improve accuracy.
498/// This is especially important at higher frequencies where standard point collocation
499/// becomes inaccurate.
500pub fn build_bem_matrix_adaptive(mesh: &RoomMesh, k: f64, use_adaptive: bool) -> Array2<Complex64> {
501    use crate::core::integration::singular::QuadratureParams;
502    use crate::core::types::{ElementType, PhysicsParams};
503
504    let n = mesh.elements.len();
505
506    // Physics parameters for singular integration
507    let speed_of_sound = 343.0;
508    let frequency = k * speed_of_sound / (2.0 * PI);
509    let omega = 2.0 * PI * frequency;
510
511    let physics = PhysicsParams {
512        wave_number: k,
513        density: 1.0,
514        speed_of_sound,
515        frequency,
516        omega,
517        wave_length: speed_of_sound / frequency,
518        harmonic_factor: 1.0, // exp(+ikr) convention
519        pressure_factor: 1.0 * omega * 1.0,
520        tau: -1.0, // internal problem (room interior)
521    };
522
523    // Precompute element data (parallel when available)
524    let element_data: Vec<_> = parallel_map(&mesh.elements, |element| {
525        let nodes: Vec<Point3D> = element.nodes.iter().map(|&i| mesh.nodes[i]).collect();
526        let (center, normal) = element_center_and_normal(&nodes);
527        let area = element_area(&nodes);
528        let char_length = element_characteristic_length(&nodes);
529        (center, normal, area, char_length, nodes)
530    });
531
532    // Build matrix rows (parallel when available)
533    let rows: Vec<_> = parallel_map_indexed(n, |i| {
534        let mut row = vec![Complex64::new(0.0, 0.0); n];
535        let (center_i, normal_i, _area_i, char_length_i, _nodes_i) = &element_data[i];
536
537        for j in 0..n {
538            let (center_j, _normal_j, area_j, char_length_j, nodes_j) = &element_data[j];
539            let r = center_i.distance_to(center_j);
540
541            // Criterion for near-singular: distance < 2 * characteristic length
542            let near_threshold = 2.0 * (char_length_i + char_length_j);
543            let is_near = r < near_threshold || i == j;
544
545            if use_adaptive && is_near {
546                // Use adaptive singular integration for near-field
547                let element_coords = nodes_to_array2(nodes_j);
548                let source_point = point_to_array1(center_i);
549                let source_normal = normal_to_array1(normal_i);
550
551                // Use frequency-adaptive quadrature
552                let ka = k * char_length_j;
553                let quad_params = QuadratureParams::for_ka(ka);
554
555                let result = crate::core::integration::singular::singular_integration_with_params(
556                    &source_point,
557                    &source_normal,
558                    &element_coords,
559                    ElementType::Tri3,
560                    &physics,
561                    None,
562                    0,     // Dirichlet BC type
563                    false, // don't compute RHS
564                    &quad_params,
565                );
566
567                // H matrix coefficient (∂G/∂n term) - double layer potential
568                row[j] = result.dg_dn_integral;
569            } else {
570                // Use point collocation for far-field
571                if i == j {
572                    // Diagonal: self-interaction approximation
573                    row[j] = Complex64::new(0.0, -k / (2.0 * PI)) * area_j;
574                } else {
575                    // Off-diagonal: standard collocation
576                    let dx = center_i.x - center_j.x;
577                    let dy = center_i.y - center_j.y;
578                    let dz = center_i.z - center_j.z;
579
580                    let cos_angle = (dx * normal_i.x + dy * normal_i.y + dz * normal_i.z) / r;
581                    row[j] = greens_function_derivative(r, k, cos_angle) * area_j;
582                }
583            }
584        }
585        row
586    });
587
588    // Convert to ndarray
589    let mut matrix = Array2::zeros((n, n));
590    for (i, row) in rows.iter().enumerate() {
591        for (j, &val) in row.iter().enumerate() {
592            matrix[[i, j]] = val;
593        }
594    }
595
596    matrix
597}
598
599/// Helper: compute characteristic length of an element
600fn element_characteristic_length(nodes: &[Point3D]) -> f64 {
601    if nodes.len() < 3 {
602        return 0.0;
603    }
604
605    // For triangles: use average edge length
606    let d01 = nodes[0].distance_to(&nodes[1]);
607    let d12 = nodes[1].distance_to(&nodes[2]);
608    let d20 = nodes[2].distance_to(&nodes[0]);
609
610    (d01 + d12 + d20) / 3.0
611}
612
613/// Helper: convert Point3D to Array1<f64>
614fn point_to_array1(p: &Point3D) -> Array1<f64> {
615    use ndarray::array;
616    array![p.x, p.y, p.z]
617}
618
619/// Helper: convert normal vector to Array1<f64>
620fn normal_to_array1(n: &Point3D) -> Array1<f64> {
621    use ndarray::array;
622    array![n.x, n.y, n.z]
623}
624
625/// Helper: convert nodes to Array2<f64> for singular integration
626fn nodes_to_array2(nodes: &[Point3D]) -> Array2<f64> {
627    let n = nodes.len();
628    let mut coords = Array2::zeros((n, 3));
629    for (i, node) in nodes.iter().enumerate() {
630        coords[[i, 0]] = node.x;
631        coords[[i, 1]] = node.y;
632        coords[[i, 2]] = node.z;
633    }
634    coords
635}
636
637/// Calculate incident field normal derivative in parallel
638pub fn calculate_incident_field_derivative_parallel(
639    mesh: &RoomMesh,
640    sources: &[Source],
641    k: f64,
642    frequency: f64,
643) -> Array1<Complex64> {
644    let element_data: Vec<_> = parallel_map(&mesh.elements, |element| {
645        let nodes: Vec<Point3D> = element.nodes.iter().map(|&idx| mesh.nodes[idx]).collect();
646        let (center, normal) = element_center_and_normal(&nodes);
647
648        // Compute incident field derivative
649        let mut dpdn_inc = Complex64::new(0.0, 0.0);
650
651        for source in sources {
652            let r = center.distance_to(&source.position);
653            if r < 1e-10 {
654                continue;
655            }
656
657            let amplitude = source.amplitude_towards(&center, frequency);
658
659            // Direction from source to point
660            let dx = center.x - source.position.x;
661            let dy = center.y - source.position.y;
662            let dz = center.z - source.position.z;
663
664            // Normal derivative: ∂G/∂n = ∇G · n
665            let cos_angle = (dx * normal.x + dy * normal.y + dz * normal.z) / r;
666
667            dpdn_inc += greens_function_derivative(r, k, cos_angle) * amplitude;
668        }
669
670        // For rigid boundary condition: ∂p/∂n = 0
671        // The BEM formulation gives: H * q = G * p_inc
672        // where q = ∂p/∂n is the unknown (should be zero for rigid)
673        // We solve: H * q = -∂p_inc/∂n  (to get total field with zero normal derivative)
674        -dpdn_inc
675    });
676
677    Array1::from_vec(element_data)
678}
679
680/// Calculate pressure at field points using BEM solution (parallel version)
681///
682/// Uses the double-layer potential (DLP) representation for field evaluation:
683/// p(x) = p_inc(x) + ∫∫ (∂G/∂n)(x, y) * p_surface(y) dS(y)
684///
685/// This matches the BEM formulation which solves for surface pressure using
686/// the H matrix (double-layer potential operator).
687pub fn calculate_field_pressure_bem_parallel(
688    mesh: &RoomMesh,
689    surface_pressure: &Array1<Complex64>,
690    sources: &[Source],
691    field_points: &[Point3D],
692    k: f64,
693    frequency: f64,
694) -> Array1<Complex64> {
695    // Precompute element data including normals for DLP evaluation
696    let element_data: Vec<_> = mesh
697        .elements
698        .iter()
699        .map(|element| {
700            let nodes: Vec<Point3D> = element.nodes.iter().map(|&i| mesh.nodes[i]).collect();
701            let (center, normal) = element_center_and_normal(&nodes);
702            let area = element_area(&nodes);
703            (center, normal, area)
704        })
705        .collect();
706
707    // Calculate pressure at each field point (parallel when available)
708    let pressures: Vec<_> = parallel_map(field_points, |point| {
709        // Incident field from sources
710        let mut p_incident = Complex64::new(0.0, 0.0);
711        for source in sources {
712            let r = point.distance_to(&source.position);
713            if r < 1e-10 {
714                continue;
715            }
716            let amplitude = source.amplitude_towards(point, frequency);
717            p_incident += greens_function_3d(r, k) * amplitude;
718        }
719
720        // Scattered field from boundary integral using double-layer potential
721        // p_scattered = ∫∫ (∂G/∂n)(x, y) * p_surface(y) dS(y)
722        // where ∂G/∂n is the normal derivative at the surface element (pointing outward)
723        let mut p_scattered = Complex64::new(0.0, 0.0);
724        for (j, (center_j, normal_j, area_j)) in element_data.iter().enumerate() {
725            let r = point.distance_to(center_j);
726            if r < 1e-10 {
727                continue;
728            }
729
730            // Direction from element center to field point
731            let dx = point.x - center_j.x;
732            let dy = point.y - center_j.y;
733            let dz = point.z - center_j.z;
734
735            // Cosine of angle between (point - center) direction and outward normal
736            let cos_angle = (dx * normal_j.x + dy * normal_j.y + dz * normal_j.z) / r;
737
738            // Normal derivative of Green's function: ∂G/∂n = (ikr-1) * exp(ikr) / (4πr²) * cos_angle
739            let dg_dn = greens_function_derivative(r, k, cos_angle);
740
741            p_scattered += dg_dn * surface_pressure[j] * area_j;
742        }
743
744        p_incident + p_scattered
745    });
746
747    Array1::from_vec(pressures)
748}
749
750// ============================================================================
751// FMM Integration for Room Acoustics
752// ============================================================================
753
754use crate::core::assembly::slfmm::{SlfmmSystem, build_slfmm_system};
755use crate::core::mesh::octree::Octree;
756use crate::core::solver::{
757    GmresConfig, GmresSolution, SlfmmOperator, gmres_solve_with_ilu_operator,
758};
759use crate::core::types::{
760    BoundaryCondition, Cluster, Element, ElementProperty, ElementType, PhysicsParams,
761};
762
/// Tunable parameters of the FMM-accelerated solver.
///
/// See the `Default` implementation for the values used when nothing is
/// overridden.
pub struct FmmSolverConfig {
    /// Upper bound on the number of elements stored in one octree leaf;
    /// this controls the cluster size.
    pub max_elements_per_leaf: usize,
    /// Depth limit of the octree.
    pub max_tree_depth: usize,
    /// Integration points in the theta direction on the unit sphere.
    pub n_theta: usize,
    /// Integration points in the phi direction on the unit sphere.
    pub n_phi: usize,
    /// Number of terms kept in the expansion.
    pub n_terms: usize,
    /// Ratio used to classify cluster pairs as near field or far field.
    pub separation_ratio: f64,
}
778
779impl Default for FmmSolverConfig {
780    fn default() -> Self {
781        Self {
782            max_elements_per_leaf: 50,
783            max_tree_depth: 8,
784            n_theta: 6,
785            n_phi: 12,
786            n_terms: 6,
787            separation_ratio: 1.5, // Standard FMM separation: 2/sqrt(3) ≈ 1.155
788        }
789    }
790}
791
792/// Convert RoomMesh to core Element and nodes arrays for FMM
793pub fn room_mesh_to_core_elements(mesh: &RoomMesh, _k: f64) -> (Vec<Element>, Array2<f64>) {
794    let n_nodes = mesh.nodes.len();
795    let n_elements = mesh.elements.len();
796
797    // Convert nodes to Array2
798    let mut nodes = Array2::zeros((n_nodes, 3));
799    for (i, node) in mesh.nodes.iter().enumerate() {
800        nodes[[i, 0]] = node.x;
801        nodes[[i, 1]] = node.y;
802        nodes[[i, 2]] = node.z;
803    }
804
805    // Convert elements
806    let mut elements = Vec::with_capacity(n_elements);
807    for (elem_idx, surface_elem) in mesh.elements.iter().enumerate() {
808        let elem_nodes: Vec<Point3D> = surface_elem.nodes.iter().map(|&i| mesh.nodes[i]).collect();
809
810        let (center, normal) = element_center_and_normal(&elem_nodes);
811        let area = element_area(&elem_nodes);
812
813        // Determine element type
814        let elem_type = if surface_elem.nodes.len() == 3 {
815            ElementType::Tri3
816        } else {
817            ElementType::Quad4
818        };
819
820        let element = Element {
821            connectivity: surface_elem.nodes.clone(),
822            element_type: elem_type,
823            property: ElementProperty::Surface,
824            normal: Array1::from_vec(vec![normal.x, normal.y, normal.z]),
825            node_normals: Array2::zeros((surface_elem.nodes.len(), 3)),
826            center: Array1::from_vec(vec![center.x, center.y, center.z]),
827            area,
828            boundary_condition: BoundaryCondition::Velocity(vec![Complex64::new(0.0, 0.0)]),
829            group: 0,
830            dof_addresses: vec![elem_idx],
831        };
832
833        elements.push(element);
834    }
835
836    (elements, nodes)
837}
838
839/// Build clusters from octree for FMM
840pub fn build_clusters_from_octree(octree: &Octree, elements: &[Element]) -> Vec<Cluster> {
841    let leaves = octree.leaves();
842    let mut clusters = Vec::with_capacity(leaves.len());
843
844    for &leaf_idx in &leaves {
845        let node = &octree.nodes[leaf_idx];
846        if node.element_indices.is_empty() {
847            continue;
848        }
849
850        let mut cluster = Cluster::new(node.center.clone());
851        cluster.element_indices = node.element_indices.clone();
852        cluster.num_elements = node.element_indices.len();
853        cluster.element_property = ElementProperty::Surface;
854        cluster.radius = node.radius();
855        cluster.level = node.level;
856
857        // Count DOFs
858        cluster.num_dofs = node
859            .element_indices
860            .iter()
861            .filter(|&&i| !elements[i].property.is_evaluation())
862            .count();
863        cluster.dofs_per_element = 1;
864
865        clusters.push(cluster);
866    }
867
868    // Build near/far lists using octree's computed lists
869    // Map from octree leaf indices to cluster indices
870    let leaf_to_cluster: std::collections::HashMap<usize, usize> = leaves
871        .iter()
872        .filter(|&&i| !octree.nodes[i].element_indices.is_empty())
873        .enumerate()
874        .map(|(cluster_idx, &leaf_idx)| (leaf_idx, cluster_idx))
875        .collect();
876
877    // Assign near/far cluster indices
878    for (cluster_idx, &leaf_idx) in leaves.iter().enumerate() {
879        if octree.nodes[leaf_idx].element_indices.is_empty() {
880            continue;
881        }
882
883        let octree_node = &octree.nodes[leaf_idx];
884
885        // Map near clusters
886        let near: Vec<usize> = octree_node
887            .near_clusters
888            .iter()
889            .filter_map(|&near_leaf| leaf_to_cluster.get(&near_leaf).copied())
890            .collect();
891
892        // Map far clusters
893        let far: Vec<usize> = octree_node
894            .far_clusters
895            .iter()
896            .filter_map(|&far_leaf| leaf_to_cluster.get(&far_leaf).copied())
897            .collect();
898
899        if let Some(cluster) = clusters.get_mut(cluster_idx) {
900            cluster.near_clusters = near;
901            cluster.far_clusters = far;
902        }
903    }
904
905    clusters
906}
907
908/// Build FMM system for room acoustics
909pub fn build_fmm_system(
910    mesh: &RoomMesh,
911    sources: &[Source],
912    k: f64,
913    frequency: f64,
914    fmm_config: &FmmSolverConfig,
915) -> Result<(SlfmmSystem, Vec<Element>, Array2<f64>), String> {
916    println!("  Converting room mesh to core elements...");
917    let (elements, nodes) = room_mesh_to_core_elements(mesh, k);
918    println!("    {} elements, {} nodes", elements.len(), nodes.nrows());
919
920    // Compute element centers for octree
921    println!("  Building octree...");
922    let centers: Vec<Array1<f64>> = elements.iter().map(|e| e.center.clone()).collect();
923
924    let mut octree = Octree::build(
925        &centers,
926        fmm_config.max_elements_per_leaf,
927        fmm_config.max_tree_depth,
928    );
929
930    // Compute interaction lists
931    octree.compute_interaction_lists(fmm_config.separation_ratio);
932
933    let stats = octree.stats();
934    println!(
935        "    {} leaves, {} levels, avg {:.1} elements/leaf",
936        stats.num_leaves, stats.num_levels, stats.avg_elements_per_leaf
937    );
938
939    // Build clusters from octree
940    println!("  Building clusters...");
941    let clusters = build_clusters_from_octree(&octree, &elements);
942    println!("    {} clusters", clusters.len());
943
944    // Create physics parameters
945    let speed_of_sound = 343.0;
946    let physics = PhysicsParams::new(frequency, speed_of_sound, 1.21, true);
947
948    // Build SLFMM system
949    println!("  Assembling SLFMM system...");
950    let mut system = build_slfmm_system(
951        &elements,
952        &nodes,
953        &clusters,
954        &physics,
955        fmm_config.n_theta,
956        fmm_config.n_phi,
957        fmm_config.n_terms,
958    );
959
960    // Compute RHS
961    println!("  Computing RHS...");
962    let rhs = calculate_incident_field_derivative_parallel(mesh, sources, k, frequency);
963    system.rhs = rhs;
964
965    Ok((system, elements, nodes))
966}
967
968/// Solve BEM system using FMM + GMRES + ILU
969///
970/// This is the recommended solver for large meshes (>1000 elements).
971/// Complexity: O(N log N) per iteration vs O(N²) for dense GMRES.
972///
973/// The near-field matrix for ILU preconditioning is extracted directly
974/// from the SLFMM system, avoiding the O(N²) dense matrix assembly.
975pub fn solve_bem_fmm_gmres_ilu(
976    mesh: &RoomMesh,
977    sources: &[Source],
978    k: f64,
979    frequency: f64,
980    fmm_config: &FmmSolverConfig,
981    gmres_max_iter: usize,
982    gmres_restart: usize,
983    gmres_tolerance: f64,
984) -> Result<Array1<Complex64>, String> {
985    // Build FMM system
986    let (system, _elements, _nodes) = build_fmm_system(mesh, sources, k, frequency, fmm_config)?;
987
988    // Extract near-field matrix for ILU preconditioning BEFORE moving system to operator
989    // This uses only the already-computed near-field blocks, not a full O(N²) assembly
990    println!("  Extracting near-field matrix for ILU...");
991    let nearfield_matrix = system.extract_near_field_matrix();
992    println!(
993        "    Near-field matrix: {}x{}",
994        nearfield_matrix.nrows(),
995        nearfield_matrix.ncols()
996    );
997
998    // Create FMM operator (takes ownership of system)
999    let fmm_operator = SlfmmOperator::new(system);
1000
1001    // GMRES configuration
1002    let gmres_config = GmresConfig {
1003        max_iterations: gmres_max_iter,
1004        restart: gmres_restart,
1005        tolerance: gmres_tolerance,
1006        print_interval: 0,
1007    };
1008
1009    // Get RHS from operator
1010    let rhs = fmm_operator.rhs().clone();
1011
1012    // Solve with FMM-accelerated GMRES + ILU preconditioning
1013    println!("  Solving with FMM + GMRES + ILU...");
1014    let result =
1015        gmres_solve_with_ilu_operator(&fmm_operator, &nearfield_matrix, &rhs, &gmres_config);
1016
1017    if result.converged {
1018        println!(
1019            "    Converged in {} iterations, residual: {:.2e}",
1020            result.iterations, result.residual
1021        );
1022    } else {
1023        println!(
1024            "    Warning: Did not converge after {} iterations, residual: {:.2e}",
1025            result.iterations, result.residual
1026        );
1027    }
1028
1029    Ok(result.x)
1030}
1031
1032/// Solve BEM system using FMM + GMRES with hierarchical preconditioner
1033///
1034/// This is an alternative to `solve_bem_fmm_gmres_ilu` that uses a hierarchical
1035/// block-diagonal preconditioner based on the FMM near-field blocks.
1036///
1037/// ## Advantages
1038/// - O(N) preconditioner setup (vs O(N²) for ILU on extracted dense matrix)
1039/// - Parallel LU factorization of each cluster block
1040/// - No dense matrix extraction needed
1041///
1042/// ## When to use
1043/// - For very large problems where ILU setup time dominates
1044/// - When memory is constrained (no dense matrix extraction)
1045pub fn solve_bem_fmm_gmres_hierarchical(
1046    mesh: &RoomMesh,
1047    sources: &[Source],
1048    k: f64,
1049    frequency: f64,
1050    fmm_config: &FmmSolverConfig,
1051    gmres_max_iter: usize,
1052    gmres_restart: usize,
1053    gmres_tolerance: f64,
1054) -> Result<Array1<Complex64>, String> {
1055    use crate::core::solver::gmres_solve_with_hierarchical_precond;
1056
1057    // Build FMM system
1058    let (system, _elements, _nodes) = build_fmm_system(mesh, sources, k, frequency, fmm_config)?;
1059
1060    // GMRES configuration
1061    let gmres_config = GmresConfig {
1062        max_iterations: gmres_max_iter,
1063        restart: gmres_restart,
1064        tolerance: gmres_tolerance,
1065        print_interval: 0,
1066    };
1067
1068    // Get RHS
1069    let rhs = system.rhs.clone();
1070
1071    // Solve with hierarchical preconditioner
1072    println!("  Solving with FMM + GMRES + Hierarchical Preconditioner...");
1073    let result = gmres_solve_with_hierarchical_precond(&system, &rhs, &gmres_config);
1074
1075    if result.converged {
1076        println!(
1077            "    Converged in {} iterations, residual: {:.2e}",
1078            result.iterations, result.residual
1079        );
1080    } else {
1081        println!(
1082            "    Warning: Did not converge after {} iterations, residual: {:.2e}",
1083            result.iterations, result.residual
1084        );
1085    }
1086
1087    Ok(result.x)
1088}
1089
1090/// Solve BEM system using FMM + GMRES + ILU with full result
1091///
1092/// Same as `solve_bem_fmm_gmres_ilu` but returns the full GmresSolution with
1093/// convergence info.
1094pub fn solve_bem_fmm_gmres_ilu_with_result(
1095    mesh: &RoomMesh,
1096    sources: &[Source],
1097    k: f64,
1098    frequency: f64,
1099    fmm_config: &FmmSolverConfig,
1100    gmres_max_iter: usize,
1101    gmres_restart: usize,
1102    gmres_tolerance: f64,
1103) -> Result<GmresSolution<Complex64>, String> {
1104    // Build FMM system
1105    let (system, _elements, _nodes) = build_fmm_system(mesh, sources, k, frequency, fmm_config)?;
1106
1107    // Extract near-field matrix for ILU preconditioning BEFORE moving system to operator
1108    println!("  Extracting near-field matrix for ILU...");
1109    let nearfield_matrix = system.extract_near_field_matrix();
1110    println!(
1111        "    Near-field matrix: {}x{}",
1112        nearfield_matrix.nrows(),
1113        nearfield_matrix.ncols()
1114    );
1115
1116    // Create FMM operator (takes ownership of system)
1117    let fmm_operator = SlfmmOperator::new(system);
1118
1119    // GMRES configuration
1120    let gmres_config = GmresConfig {
1121        max_iterations: gmres_max_iter,
1122        restart: gmres_restart,
1123        tolerance: gmres_tolerance,
1124        print_interval: 0,
1125    };
1126
1127    // Get RHS from operator
1128    let rhs = fmm_operator.rhs().clone();
1129
1130    // Solve with FMM-accelerated GMRES + ILU preconditioning
1131    println!("  Solving with FMM + GMRES + ILU...");
1132    let result =
1133        gmres_solve_with_ilu_operator(&fmm_operator, &nearfield_matrix, &rhs, &gmres_config);
1134
1135    if result.converged {
1136        println!(
1137            "    Converged in {} iterations, residual: {:.2e}",
1138            result.iterations, result.residual
1139        );
1140    } else {
1141        println!(
1142            "    Warning: Did not converge after {} iterations, residual: {:.2e}",
1143            result.iterations, result.residual
1144        );
1145    }
1146
1147    Ok(result)
1148}
1149
#[cfg(test)]
mod tests {
    use super::*;

    /// For real `k` and `r`, |exp(ikr)| = 1, so |G(r)| must equal 1/(4πr) up
    /// to rounding. The tolerance is kept far below the expected magnitude
    /// (≈0.08 at r = 1) so that a zero or mis-scaled result cannot pass.
    #[test]
    fn test_greens_function() {
        let k = 2.0 * PI * 1000.0 / 343.0;

        // Two distances, so the 1/r decay is checked as well as the scale.
        for &r in &[1.0_f64, 2.0] {
            let g = greens_function_3d(r, k);
            let expected = 1.0 / (4.0 * PI * r);
            assert!(
                (g.norm() - expected).abs() < 1e-12,
                "|G({})| = {}, expected {}",
                r,
                g.norm(),
                expected
            );
        }
    }

    /// 1 Pa corresponds to roughly 94 dB SPL.
    #[test]
    fn test_pressure_to_spl() {
        let p = Complex64::new(1.0, 0.0); // 1 Pa
        let spl = pressure_to_spl(p);
        // 1 Pa = 94 dB SPL
        assert!((spl - 94.0).abs() < 1.0);
    }

    /// A single three-node surface element converts to one `Tri3` core
    /// element and a node array with three rows.
    #[test]
    fn test_room_mesh_to_core_elements() {
        // Create a simple test mesh
        let nodes = vec![
            Point3D::new(0.0, 0.0, 0.0),
            Point3D::new(1.0, 0.0, 0.0),
            Point3D::new(0.5, 1.0, 0.0),
        ];
        let elements = vec![SurfaceElement {
            nodes: vec![0, 1, 2],
        }];
        let mesh = RoomMesh { nodes, elements };

        let k = 2.0 * PI * 100.0 / 343.0;
        let (core_elements, core_nodes) = room_mesh_to_core_elements(&mesh, k);

        assert_eq!(core_elements.len(), 1);
        assert_eq!(core_nodes.nrows(), 3);
        assert_eq!(core_elements[0].element_type, ElementType::Tri3);
    }
}