Skip to main content

math_audio_bem/core/solver/
batched_blas.rs

//! Batched BLAS operations for complex matrices
//!
//! This module provides optimized batched matrix operations that are intended
//! to leverage BLAS Level 3 (GEMM) for better performance on modern CPUs.
//!
//! Key optimizations:
//! - Pre-allocated workspace buffers
//! - Batched GEMM operations when possible
//! - Reduced memory allocations in hot paths
//! - Contiguous memory layouts for cache efficiency
//!
//! The routines currently in this file issue one small product per cluster,
//! near-field block, or translation entry, and distribute those products
//! across threads with rayon.
//!
//! **Note**: This module requires the `native` feature because it uses rayon for parallel processing.
//! The feature gate is applied at the module declaration in `mod.rs`.
14
15use ndarray::{Array1, Array2};
16use num_complex::Complex64;
17
18/// Workspace for batched SLFMM matvec operations
19///
20/// Pre-allocates all necessary buffers to avoid allocations in hot path
21pub struct SlfmmMatvecWorkspace {
22    /// Workspace for multipole expansions: [num_clusters, num_sphere_points]
23    pub multipoles: Array2<Complex64>,
24    /// Workspace for local expansions: [num_clusters, num_sphere_points]
25    pub locals: Array2<Complex64>,
26    /// Workspace for DOF scatter/gather: \[num_dofs\]
27    pub dof_buffer: Array1<Complex64>,
28    /// Number of clusters
29    pub num_clusters: usize,
30    /// Number of sphere integration points
31    pub num_sphere_points: usize,
32    /// Number of DOFs
33    pub num_dofs: usize,
34}
35
36impl SlfmmMatvecWorkspace {
37    /// Create a new workspace with pre-allocated buffers
38    pub fn new(num_clusters: usize, num_sphere_points: usize, num_dofs: usize) -> Self {
39        Self {
40            multipoles: Array2::zeros((num_clusters, num_sphere_points)),
41            locals: Array2::zeros((num_clusters, num_sphere_points)),
42            dof_buffer: Array1::zeros(num_dofs),
43            num_clusters,
44            num_sphere_points,
45            num_dofs,
46        }
47    }
48
49    /// Clear all workspace buffers (zero out)
50    pub fn clear(&mut self) {
51        self.multipoles.fill(Complex64::new(0.0, 0.0));
52        self.locals.fill(Complex64::new(0.0, 0.0));
53        self.dof_buffer.fill(Complex64::new(0.0, 0.0));
54    }
55}
56
57/// Batched matrix-vector multiply for T-matrix application
58///
59/// Computes: `multipoles[c] = T[c] * x[cluster_dofs[c]]` for all clusters
60///
61/// This is more efficient than individual GEMV calls because:
62/// 1. Single allocation for gathering x values
63/// 2. Better memory locality
64/// 3. Can use GEMM when clusters have similar sizes
65pub fn batched_t_matrix_apply(
66    t_matrices: &[Array2<Complex64>],
67    x: &Array1<Complex64>,
68    cluster_dof_indices: &[Vec<usize>],
69    multipoles: &mut Array2<Complex64>,
70) {
71    // Process clusters in parallel using rayon
72    use rayon::prelude::*;
73
74    // Parallel computation with immediate storage
75    let results: Vec<(usize, Array1<Complex64>)> = t_matrices
76        .par_iter()
77        .enumerate()
78        .filter_map(|(cluster_idx, t_mat)| {
79            let cluster_dofs = &cluster_dof_indices[cluster_idx];
80            if cluster_dofs.is_empty() || t_mat.is_empty() {
81                return None;
82            }
83
84            // Gather x values (avoiding allocation by reusing iterator)
85            let x_local: Array1<Complex64> = Array1::from_iter(cluster_dofs.iter().map(|&i| x[i]));
86
87            // Apply T-matrix
88            let result = t_mat.dot(&x_local);
89            Some((cluster_idx, result))
90        })
91        .collect();
92
93    // Store results (sequential to avoid race conditions)
94    for (cluster_idx, result) in results {
95        multipoles.row_mut(cluster_idx).assign(&result);
96    }
97}
98
99/// Batched matrix-vector multiply for S-matrix application
100///
101/// Computes: `y[cluster_dofs[c]] += S[c] * locals[c]` for all clusters
102///
103/// Similar optimization strategy to T-matrix application
104pub fn batched_s_matrix_apply(
105    s_matrices: &[Array2<Complex64>],
106    locals: &Array2<Complex64>,
107    cluster_dof_indices: &[Vec<usize>],
108    y: &mut Array1<Complex64>,
109) {
110    use rayon::prelude::*;
111
112    // Parallel computation
113    let results: Vec<Vec<(usize, Complex64)>> = s_matrices
114        .par_iter()
115        .enumerate()
116        .filter_map(|(cluster_idx, s_mat)| {
117            let cluster_dofs = &cluster_dof_indices[cluster_idx];
118            if cluster_dofs.is_empty() || s_mat.is_empty() {
119                return None;
120            }
121
122            // Apply S-matrix to local expansion
123            let y_local = s_mat.dot(&locals.row(cluster_idx));
124
125            // Collect contributions
126            let contributions: Vec<(usize, Complex64)> = cluster_dofs
127                .iter()
128                .enumerate()
129                .map(|(local_j, &global_j)| (global_j, y_local[local_j]))
130                .collect();
131
132            Some(contributions)
133        })
134        .collect();
135
136    // Accumulate results
137    for contributions in results {
138        for (idx, val) in contributions {
139            y[idx] += val;
140        }
141    }
142}
143
144/// Batched D-matrix translation
145///
146/// Computes: `locals[field_cluster] += D[entry] * multipoles[source_cluster]`
147/// for all D-matrix entries
148pub fn batched_d_matrix_apply(
149    d_matrices: &[super::super::assembly::slfmm::DMatrixEntry],
150    multipoles: &Array2<Complex64>,
151    locals: &mut Array2<Complex64>,
152) {
153    use rayon::prelude::*;
154
155    // Parallel computation of all D-matrix translations
156    // D-matrix is diagonal, so D*x is element-wise multiplication
157    let results: Vec<(usize, Array1<Complex64>)> = d_matrices
158        .par_iter()
159        .map(|d_entry| {
160            let src_mult = multipoles.row(d_entry.source_cluster);
161            // Diagonal matrix-vector multiply: translated[i] = diagonal[i] * src_mult[i]
162            let translated: Array1<Complex64> = d_entry
163                .diagonal
164                .iter()
165                .zip(src_mult.iter())
166                .map(|(&d, &s)| d * s)
167                .collect();
168            (d_entry.field_cluster, translated)
169        })
170        .collect();
171
172    // Accumulate into locals (sequential to avoid race)
173    for (field_cluster, translated) in results {
174        for i in 0..translated.len() {
175            locals[[field_cluster, i]] += translated[i];
176        }
177    }
178}
179
180/// Batched near-field block application
181///
182/// Computes: y += N * x where N is the sparse near-field matrix
183/// represented as dense blocks
184pub fn batched_near_field_apply(
185    near_blocks: &[super::super::assembly::slfmm::NearFieldBlock],
186    x: &Array1<Complex64>,
187    cluster_dof_indices: &[Vec<usize>],
188    y: &mut Array1<Complex64>,
189) {
190    use rayon::prelude::*;
191
192    // Parallel computation of all block contributions
193    let contributions: Vec<Vec<(usize, Complex64)>> = near_blocks
194        .par_iter()
195        .flat_map(|block| {
196            let src_dofs = &cluster_dof_indices[block.source_cluster];
197            let fld_dofs = &cluster_dof_indices[block.field_cluster];
198
199            let mut result = Vec::new();
200
201            // Gather x values from field cluster
202            let x_local: Array1<Complex64> = Array1::from_iter(fld_dofs.iter().map(|&i| x[i]));
203
204            // Apply block matrix
205            let y_local = block.coefficients.dot(&x_local);
206
207            // Collect contributions for source DOFs
208            for (local_i, &global_i) in src_dofs.iter().enumerate() {
209                result.push((global_i, y_local[local_i]));
210            }
211
212            // Handle symmetric storage
213            if block.source_cluster != block.field_cluster {
214                let x_src: Array1<Complex64> = Array1::from_iter(src_dofs.iter().map(|&i| x[i]));
215                let y_fld = block.coefficients.t().dot(&x_src);
216                for (local_j, &global_j) in fld_dofs.iter().enumerate() {
217                    result.push((global_j, y_fld[local_j]));
218                }
219            }
220
221            vec![result]
222        })
223        .collect();
224
225    // Accumulate contributions
226    for block_contributions in contributions {
227        for (idx, val) in block_contributions {
228            y[idx] += val;
229        }
230    }
231}
232
233/// Optimized SLFMM matvec using batched operations and pre-allocated workspace
234///
235/// This version avoids allocations in the hot path by using a pre-allocated workspace.
236/// Call `SlfmmMatvecWorkspace::new()` once before solving, then reuse for all matvec calls.
237pub fn slfmm_matvec_batched(
238    system: &super::super::assembly::slfmm::SlfmmSystem,
239    x: &Array1<Complex64>,
240    workspace: &mut SlfmmMatvecWorkspace,
241) -> Array1<Complex64> {
242    // Clear workspace
243    workspace.clear();
244
245    // Initialize output vector
246    let mut y = Array1::zeros(system.num_dofs);
247
248    // === Near-field contribution ===
249    batched_near_field_apply(&system.near_matrix, x, &system.cluster_dof_indices, &mut y);
250
251    // === Far-field contribution: y += [S][D][T] * x ===
252
253    // Step 1: T-matrix application: multipoles = T * x
254    batched_t_matrix_apply(
255        &system.t_matrices,
256        x,
257        &system.cluster_dof_indices,
258        &mut workspace.multipoles,
259    );
260
261    // Step 2: D-matrix translation: locals = D * multipoles
262    batched_d_matrix_apply(
263        &system.d_matrices,
264        &workspace.multipoles,
265        &mut workspace.locals,
266    );
267
268    // Step 3: S-matrix application: y += S * locals
269    batched_s_matrix_apply(
270        &system.s_matrices,
271        &workspace.locals,
272        &system.cluster_dof_indices,
273        &mut y,
274    );
275
276    y
277}
278
279/// Create a batched matvec closure for use with iterative solvers
280///
281/// Returns a closure that applies the SLFMM operator using batched BLAS operations.
282/// The workspace is created once and reused for all matvec calls.
283///
284/// Note: Returns FnMut because the workspace is mutated on each call.
285pub fn create_batched_matvec<'a>(
286    system: &'a super::super::assembly::slfmm::SlfmmSystem,
287) -> impl FnMut(&Array1<Complex64>) -> Array1<Complex64> + 'a {
288    // Pre-allocate workspace
289    let mut workspace = SlfmmMatvecWorkspace::new(
290        system.num_clusters,
291        system.num_sphere_points,
292        system.num_dofs,
293    );
294
295    move |x: &Array1<Complex64>| slfmm_matvec_batched(system, x, &mut workspace)
296}
297
#[cfg(test)]
mod tests {
    use super::*;

    /// A new workspace records its dimensions, shapes its buffers accordingly,
    /// and starts out filled with zeros.
    #[test]
    fn test_workspace_creation() {
        let workspace = SlfmmMatvecWorkspace::new(10, 32, 100);
        assert_eq!(workspace.num_clusters, 10);
        assert_eq!(workspace.num_sphere_points, 32);
        assert_eq!(workspace.num_dofs, 100);
        assert_eq!(workspace.multipoles.shape(), &[10, 32]);
        assert_eq!(workspace.locals.shape(), &[10, 32]);
        assert_eq!(workspace.dof_buffer.len(), 100);

        let zero = Complex64::new(0.0, 0.0);
        assert!(workspace.multipoles.iter().all(|&v| v == zero));
        assert!(workspace.locals.iter().all(|&v| v == zero));
        assert!(workspace.dof_buffer.iter().all(|&v| v == zero));
    }

    /// `clear` must zero all three buffers, not just `multipoles`, and must
    /// leave their shapes alone.
    #[test]
    fn test_workspace_clear() {
        let mut workspace = SlfmmMatvecWorkspace::new(2, 4, 8);
        let marker = Complex64::new(1.0, 2.0);

        // Dirty one element in each buffer (including last-index positions).
        workspace.multipoles[[0, 0]] = marker;
        workspace.locals[[1, 3]] = marker;
        workspace.dof_buffer[7] = marker;

        workspace.clear();

        let zero = Complex64::new(0.0, 0.0);
        assert_eq!(workspace.multipoles[[0, 0]], zero);
        assert!(workspace.multipoles.iter().all(|&v| v == zero));
        assert!(workspace.locals.iter().all(|&v| v == zero));
        assert!(workspace.dof_buffer.iter().all(|&v| v == zero));

        assert_eq!(workspace.multipoles.shape(), &[2, 4]);
        assert_eq!(workspace.locals.shape(), &[2, 4]);
        assert_eq!(workspace.dof_buffer.len(), 8);
    }
}