bem/core/solver/
batched_blas.rs

1//! Batched BLAS operations for complex matrices
2//!
3//! This module provides optimized batched matrix operations that leverage
4//! BLAS Level 3 (GEMM) for better performance on modern CPUs.
5//!
6//! Key optimizations:
7//! - Pre-allocated workspace buffers
8//! - Batched GEMM operations when possible
9//! - Reduced memory allocations in hot paths
10//! - Contiguous memory layouts for cache efficiency
11
12use ndarray::{Array1, Array2};
13use num_complex::Complex64;
14
15/// Workspace for batched SLFMM matvec operations
16///
17/// Pre-allocates all necessary buffers to avoid allocations in hot path
18pub struct SlfmmMatvecWorkspace {
19    /// Workspace for multipole expansions: [num_clusters, num_sphere_points]
20    pub multipoles: Array2<Complex64>,
21    /// Workspace for local expansions: [num_clusters, num_sphere_points]
22    pub locals: Array2<Complex64>,
23    /// Workspace for DOF scatter/gather: [num_dofs]
24    pub dof_buffer: Array1<Complex64>,
25    /// Number of clusters
26    pub num_clusters: usize,
27    /// Number of sphere integration points
28    pub num_sphere_points: usize,
29    /// Number of DOFs
30    pub num_dofs: usize,
31}
32
33impl SlfmmMatvecWorkspace {
34    /// Create a new workspace with pre-allocated buffers
35    pub fn new(num_clusters: usize, num_sphere_points: usize, num_dofs: usize) -> Self {
36        Self {
37            multipoles: Array2::zeros((num_clusters, num_sphere_points)),
38            locals: Array2::zeros((num_clusters, num_sphere_points)),
39            dof_buffer: Array1::zeros(num_dofs),
40            num_clusters,
41            num_sphere_points,
42            num_dofs,
43        }
44    }
45
46    /// Clear all workspace buffers (zero out)
47    pub fn clear(&mut self) {
48        self.multipoles.fill(Complex64::new(0.0, 0.0));
49        self.locals.fill(Complex64::new(0.0, 0.0));
50        self.dof_buffer.fill(Complex64::new(0.0, 0.0));
51    }
52}
53
54/// Batched matrix-vector multiply for T-matrix application
55///
56/// Computes: multipoles[c] = T[c] * x[cluster_dofs[c]] for all clusters
57///
58/// This is more efficient than individual GEMV calls because:
59/// 1. Single allocation for gathering x values
60/// 2. Better memory locality
61/// 3. Can use GEMM when clusters have similar sizes
62pub fn batched_t_matrix_apply(
63    t_matrices: &[Array2<Complex64>],
64    x: &Array1<Complex64>,
65    cluster_dof_indices: &[Vec<usize>],
66    multipoles: &mut Array2<Complex64>,
67) {
68    // Process clusters in parallel using rayon
69    use rayon::prelude::*;
70
71    // Parallel computation with immediate storage
72    let results: Vec<(usize, Array1<Complex64>)> = t_matrices
73        .par_iter()
74        .enumerate()
75        .filter_map(|(cluster_idx, t_mat)| {
76            let cluster_dofs = &cluster_dof_indices[cluster_idx];
77            if cluster_dofs.is_empty() || t_mat.is_empty() {
78                return None;
79            }
80
81            // Gather x values (avoiding allocation by reusing iterator)
82            let x_local: Array1<Complex64> =
83                Array1::from_iter(cluster_dofs.iter().map(|&i| x[i]));
84
85            // Apply T-matrix
86            let result = t_mat.dot(&x_local);
87            Some((cluster_idx, result))
88        })
89        .collect();
90
91    // Store results (sequential to avoid race conditions)
92    for (cluster_idx, result) in results {
93        multipoles.row_mut(cluster_idx).assign(&result);
94    }
95}
96
97/// Batched matrix-vector multiply for S-matrix application
98///
99/// Computes: y[cluster_dofs[c]] += S[c] * locals[c] for all clusters
100///
101/// Similar optimization strategy to T-matrix application
102pub fn batched_s_matrix_apply(
103    s_matrices: &[Array2<Complex64>],
104    locals: &Array2<Complex64>,
105    cluster_dof_indices: &[Vec<usize>],
106    y: &mut Array1<Complex64>,
107) {
108    use rayon::prelude::*;
109
110    // Parallel computation
111    let results: Vec<Vec<(usize, Complex64)>> = s_matrices
112        .par_iter()
113        .enumerate()
114        .filter_map(|(cluster_idx, s_mat)| {
115            let cluster_dofs = &cluster_dof_indices[cluster_idx];
116            if cluster_dofs.is_empty() || s_mat.is_empty() {
117                return None;
118            }
119
120            // Apply S-matrix to local expansion
121            let y_local = s_mat.dot(&locals.row(cluster_idx));
122
123            // Collect contributions
124            let contributions: Vec<(usize, Complex64)> = cluster_dofs
125                .iter()
126                .enumerate()
127                .map(|(local_j, &global_j)| (global_j, y_local[local_j]))
128                .collect();
129
130            Some(contributions)
131        })
132        .collect();
133
134    // Accumulate results
135    for contributions in results {
136        for (idx, val) in contributions {
137            y[idx] += val;
138        }
139    }
140}
141
142/// Batched D-matrix translation
143///
144/// Computes: locals[field_cluster] += D[entry] * multipoles[source_cluster]
145/// for all D-matrix entries
146pub fn batched_d_matrix_apply(
147    d_matrices: &[super::super::assembly::slfmm::DMatrixEntry],
148    multipoles: &Array2<Complex64>,
149    locals: &mut Array2<Complex64>,
150) {
151    use rayon::prelude::*;
152
153    // Parallel computation of all D-matrix translations
154    let results: Vec<(usize, Array1<Complex64>)> = d_matrices
155        .par_iter()
156        .map(|d_entry| {
157            let src_mult = multipoles.row(d_entry.source_cluster);
158            let translated = d_entry.coefficients.dot(&src_mult);
159            (d_entry.field_cluster, translated)
160        })
161        .collect();
162
163    // Accumulate into locals (sequential to avoid race)
164    for (field_cluster, translated) in results {
165        for i in 0..translated.len() {
166            locals[[field_cluster, i]] += translated[i];
167        }
168    }
169}
170
171/// Batched near-field block application
172///
173/// Computes: y += N * x where N is the sparse near-field matrix
174/// represented as dense blocks
175pub fn batched_near_field_apply(
176    near_blocks: &[super::super::assembly::slfmm::NearFieldBlock],
177    x: &Array1<Complex64>,
178    cluster_dof_indices: &[Vec<usize>],
179    y: &mut Array1<Complex64>,
180) {
181    use rayon::prelude::*;
182
183    // Parallel computation of all block contributions
184    let contributions: Vec<Vec<(usize, Complex64)>> = near_blocks
185        .par_iter()
186        .flat_map(|block| {
187            let src_dofs = &cluster_dof_indices[block.source_cluster];
188            let fld_dofs = &cluster_dof_indices[block.field_cluster];
189
190            let mut result = Vec::new();
191
192            // Gather x values from field cluster
193            let x_local: Array1<Complex64> = Array1::from_iter(fld_dofs.iter().map(|&i| x[i]));
194
195            // Apply block matrix
196            let y_local = block.coefficients.dot(&x_local);
197
198            // Collect contributions for source DOFs
199            for (local_i, &global_i) in src_dofs.iter().enumerate() {
200                result.push((global_i, y_local[local_i]));
201            }
202
203            // Handle symmetric storage
204            if block.source_cluster != block.field_cluster {
205                let x_src: Array1<Complex64> = Array1::from_iter(src_dofs.iter().map(|&i| x[i]));
206                let y_fld = block.coefficients.t().dot(&x_src);
207                for (local_j, &global_j) in fld_dofs.iter().enumerate() {
208                    result.push((global_j, y_fld[local_j]));
209                }
210            }
211
212            vec![result]
213        })
214        .collect();
215
216    // Accumulate contributions
217    for block_contributions in contributions {
218        for (idx, val) in block_contributions {
219            y[idx] += val;
220        }
221    }
222}
223
224/// Optimized SLFMM matvec using batched operations and pre-allocated workspace
225///
226/// This version avoids allocations in the hot path by using a pre-allocated workspace.
227/// Call `SlfmmMatvecWorkspace::new()` once before solving, then reuse for all matvec calls.
228pub fn slfmm_matvec_batched(
229    system: &super::super::assembly::slfmm::SlfmmSystem,
230    x: &Array1<Complex64>,
231    workspace: &mut SlfmmMatvecWorkspace,
232) -> Array1<Complex64> {
233    // Clear workspace
234    workspace.clear();
235
236    // Initialize output vector
237    let mut y = Array1::zeros(system.num_dofs);
238
239    // === Near-field contribution ===
240    batched_near_field_apply(&system.near_matrix, x, &system.cluster_dof_indices, &mut y);
241
242    // === Far-field contribution: y += [S][D][T] * x ===
243
244    // Step 1: T-matrix application: multipoles = T * x
245    batched_t_matrix_apply(
246        &system.t_matrices,
247        x,
248        &system.cluster_dof_indices,
249        &mut workspace.multipoles,
250    );
251
252    // Step 2: D-matrix translation: locals = D * multipoles
253    batched_d_matrix_apply(&system.d_matrices, &workspace.multipoles, &mut workspace.locals);
254
255    // Step 3: S-matrix application: y += S * locals
256    batched_s_matrix_apply(&system.s_matrices, &workspace.locals, &system.cluster_dof_indices, &mut y);
257
258    y
259}
260
261/// Create a batched matvec closure for use with iterative solvers
262///
263/// Returns a closure that applies the SLFMM operator using batched BLAS operations.
264/// The workspace is created once and reused for all matvec calls.
265///
266/// Note: Returns FnMut because the workspace is mutated on each call.
267pub fn create_batched_matvec<'a>(
268    system: &'a super::super::assembly::slfmm::SlfmmSystem,
269) -> impl FnMut(&Array1<Complex64>) -> Array1<Complex64> + 'a {
270    // Pre-allocate workspace
271    let mut workspace =
272        SlfmmMatvecWorkspace::new(system.num_clusters, system.num_sphere_points, system.num_dofs);
273
274    move |x: &Array1<Complex64>| slfmm_matvec_batched(system, x, &mut workspace)
275}
276
277#[cfg(test)]
278mod tests {
279    use super::*;
280
281    #[test]
282    fn test_workspace_creation() {
283        let workspace = SlfmmMatvecWorkspace::new(10, 32, 100);
284        assert_eq!(workspace.num_clusters, 10);
285        assert_eq!(workspace.num_sphere_points, 32);
286        assert_eq!(workspace.num_dofs, 100);
287        assert_eq!(workspace.multipoles.shape(), &[10, 32]);
288        assert_eq!(workspace.locals.shape(), &[10, 32]);
289        assert_eq!(workspace.dof_buffer.len(), 100);
290    }
291
292    #[test]
293    fn test_workspace_clear() {
294        let mut workspace = SlfmmMatvecWorkspace::new(2, 4, 8);
295        workspace.multipoles[[0, 0]] = Complex64::new(1.0, 2.0);
296        workspace.clear();
297        assert_eq!(workspace.multipoles[[0, 0]], Complex64::new(0.0, 0.0));
298    }
299}