math_audio_bem/core/solver/
batched_blas.rs1use ndarray::{Array1, Array2};
16use num_complex::Complex64;
17
18pub struct SlfmmMatvecWorkspace {
22 pub multipoles: Array2<Complex64>,
24 pub locals: Array2<Complex64>,
26 pub dof_buffer: Array1<Complex64>,
28 pub num_clusters: usize,
30 pub num_sphere_points: usize,
32 pub num_dofs: usize,
34}
35
36impl SlfmmMatvecWorkspace {
37 pub fn new(num_clusters: usize, num_sphere_points: usize, num_dofs: usize) -> Self {
39 Self {
40 multipoles: Array2::zeros((num_clusters, num_sphere_points)),
41 locals: Array2::zeros((num_clusters, num_sphere_points)),
42 dof_buffer: Array1::zeros(num_dofs),
43 num_clusters,
44 num_sphere_points,
45 num_dofs,
46 }
47 }
48
49 pub fn clear(&mut self) {
51 self.multipoles.fill(Complex64::new(0.0, 0.0));
52 self.locals.fill(Complex64::new(0.0, 0.0));
53 self.dof_buffer.fill(Complex64::new(0.0, 0.0));
54 }
55}
56
57pub fn batched_t_matrix_apply(
66 t_matrices: &[Array2<Complex64>],
67 x: &Array1<Complex64>,
68 cluster_dof_indices: &[Vec<usize>],
69 multipoles: &mut Array2<Complex64>,
70) {
71 use rayon::prelude::*;
73
74 let results: Vec<(usize, Array1<Complex64>)> = t_matrices
76 .par_iter()
77 .enumerate()
78 .filter_map(|(cluster_idx, t_mat)| {
79 let cluster_dofs = &cluster_dof_indices[cluster_idx];
80 if cluster_dofs.is_empty() || t_mat.is_empty() {
81 return None;
82 }
83
84 let x_local: Array1<Complex64> = Array1::from_iter(cluster_dofs.iter().map(|&i| x[i]));
86
87 let result = t_mat.dot(&x_local);
89 Some((cluster_idx, result))
90 })
91 .collect();
92
93 for (cluster_idx, result) in results {
95 multipoles.row_mut(cluster_idx).assign(&result);
96 }
97}
98
99pub fn batched_s_matrix_apply(
105 s_matrices: &[Array2<Complex64>],
106 locals: &Array2<Complex64>,
107 cluster_dof_indices: &[Vec<usize>],
108 y: &mut Array1<Complex64>,
109) {
110 use rayon::prelude::*;
111
112 let results: Vec<Vec<(usize, Complex64)>> = s_matrices
114 .par_iter()
115 .enumerate()
116 .filter_map(|(cluster_idx, s_mat)| {
117 let cluster_dofs = &cluster_dof_indices[cluster_idx];
118 if cluster_dofs.is_empty() || s_mat.is_empty() {
119 return None;
120 }
121
122 let y_local = s_mat.dot(&locals.row(cluster_idx));
124
125 let contributions: Vec<(usize, Complex64)> = cluster_dofs
127 .iter()
128 .enumerate()
129 .map(|(local_j, &global_j)| (global_j, y_local[local_j]))
130 .collect();
131
132 Some(contributions)
133 })
134 .collect();
135
136 for contributions in results {
138 for (idx, val) in contributions {
139 y[idx] += val;
140 }
141 }
142}
143
144pub fn batched_d_matrix_apply(
149 d_matrices: &[super::super::assembly::slfmm::DMatrixEntry],
150 multipoles: &Array2<Complex64>,
151 locals: &mut Array2<Complex64>,
152) {
153 use rayon::prelude::*;
154
155 let results: Vec<(usize, Array1<Complex64>)> = d_matrices
158 .par_iter()
159 .map(|d_entry| {
160 let src_mult = multipoles.row(d_entry.source_cluster);
161 let translated: Array1<Complex64> = d_entry
163 .diagonal
164 .iter()
165 .zip(src_mult.iter())
166 .map(|(&d, &s)| d * s)
167 .collect();
168 (d_entry.field_cluster, translated)
169 })
170 .collect();
171
172 for (field_cluster, translated) in results {
174 for i in 0..translated.len() {
175 locals[[field_cluster, i]] += translated[i];
176 }
177 }
178}
179
180pub fn batched_near_field_apply(
185 near_blocks: &[super::super::assembly::slfmm::NearFieldBlock],
186 x: &Array1<Complex64>,
187 cluster_dof_indices: &[Vec<usize>],
188 y: &mut Array1<Complex64>,
189) {
190 use rayon::prelude::*;
191
192 let contributions: Vec<Vec<(usize, Complex64)>> = near_blocks
194 .par_iter()
195 .flat_map(|block| {
196 let src_dofs = &cluster_dof_indices[block.source_cluster];
197 let fld_dofs = &cluster_dof_indices[block.field_cluster];
198
199 let mut result = Vec::new();
200
201 let x_local: Array1<Complex64> = Array1::from_iter(fld_dofs.iter().map(|&i| x[i]));
203
204 let y_local = block.coefficients.dot(&x_local);
206
207 for (local_i, &global_i) in src_dofs.iter().enumerate() {
209 result.push((global_i, y_local[local_i]));
210 }
211
212 if block.source_cluster != block.field_cluster {
214 let x_src: Array1<Complex64> = Array1::from_iter(src_dofs.iter().map(|&i| x[i]));
215 let y_fld = block.coefficients.t().dot(&x_src);
216 for (local_j, &global_j) in fld_dofs.iter().enumerate() {
217 result.push((global_j, y_fld[local_j]));
218 }
219 }
220
221 vec![result]
222 })
223 .collect();
224
225 for block_contributions in contributions {
227 for (idx, val) in block_contributions {
228 y[idx] += val;
229 }
230 }
231}
232
233pub fn slfmm_matvec_batched(
238 system: &super::super::assembly::slfmm::SlfmmSystem,
239 x: &Array1<Complex64>,
240 workspace: &mut SlfmmMatvecWorkspace,
241) -> Array1<Complex64> {
242 workspace.clear();
244
245 let mut y = Array1::zeros(system.num_dofs);
247
248 batched_near_field_apply(&system.near_matrix, x, &system.cluster_dof_indices, &mut y);
250
251 batched_t_matrix_apply(
255 &system.t_matrices,
256 x,
257 &system.cluster_dof_indices,
258 &mut workspace.multipoles,
259 );
260
261 batched_d_matrix_apply(
263 &system.d_matrices,
264 &workspace.multipoles,
265 &mut workspace.locals,
266 );
267
268 batched_s_matrix_apply(
270 &system.s_matrices,
271 &workspace.locals,
272 &system.cluster_dof_indices,
273 &mut y,
274 );
275
276 y
277}
278
279pub fn create_batched_matvec<'a>(
286 system: &'a super::super::assembly::slfmm::SlfmmSystem,
287) -> impl FnMut(&Array1<Complex64>) -> Array1<Complex64> + 'a {
288 let mut workspace = SlfmmMatvecWorkspace::new(
290 system.num_clusters,
291 system.num_sphere_points,
292 system.num_dofs,
293 );
294
295 move |x: &Array1<Complex64>| slfmm_matvec_batched(system, x, &mut workspace)
296}
297
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for a purely real complex number.
    fn re(value: f64) -> Complex64 {
        Complex64::new(value, 0.0)
    }

    #[test]
    fn test_workspace_creation() {
        let workspace = SlfmmMatvecWorkspace::new(10, 32, 100);
        assert_eq!(workspace.num_clusters, 10);
        assert_eq!(workspace.num_sphere_points, 32);
        assert_eq!(workspace.num_dofs, 100);
        assert_eq!(workspace.multipoles.shape(), &[10, 32]);
        assert_eq!(workspace.locals.shape(), &[10, 32]);
        assert_eq!(workspace.dof_buffer.len(), 100);
    }

    #[test]
    fn test_workspace_clear() {
        let mut workspace = SlfmmMatvecWorkspace::new(2, 4, 8);
        workspace.multipoles[[0, 0]] = Complex64::new(1.0, 2.0);
        workspace.locals[[1, 3]] = Complex64::new(-3.0, 4.0);
        workspace.dof_buffer[7] = Complex64::new(5.0, -6.0);

        workspace.clear();

        let zero = Complex64::new(0.0, 0.0);
        assert!(workspace.multipoles.iter().all(|&v| v == zero));
        assert!(workspace.locals.iter().all(|&v| v == zero));
        assert!(workspace.dof_buffer.iter().all(|&v| v == zero));
    }

    #[test]
    fn test_t_matrix_apply_gathers_cluster_dofs() {
        // Cluster 0 owns DOFs [2, 0]; its T matrix is the 1 x 2 row [1, 10].
        let t = vec![Array2::from_shape_vec((1, 2), vec![re(1.0), re(10.0)]).unwrap()];
        let x = Array1::from(vec![re(3.0), re(5.0), re(7.0)]);
        let dofs = vec![vec![2, 0]];
        let mut multipoles = Array2::zeros((1, 1));

        batched_t_matrix_apply(&t, &x, &dofs, &mut multipoles);

        // 1 * x[2] + 10 * x[0] = 7 + 30
        assert_eq!(multipoles[[0, 0]], re(37.0));
    }

    #[test]
    fn test_s_matrix_apply_accumulates_into_y() {
        // Cluster 0 owns DOFs [1, 2]; its S matrix is the 2 x 1 column [2, 3]; its local row is [4].
        // Cluster 1 has no DOFs and must be skipped.
        let s = vec![
            Array2::from_shape_vec((2, 1), vec![re(2.0), re(3.0)]).unwrap(),
            Array2::from_shape_vec((1, 1), vec![re(100.0)]).unwrap(),
        ];
        let locals = Array2::from_elem((2, 1), re(4.0));
        let dofs = vec![vec![1, 2], vec![]];
        let mut y = Array1::from(vec![re(1.0), re(1.0), re(1.0)]);

        batched_s_matrix_apply(&s, &locals, &dofs, &mut y);

        assert_eq!(y, Array1::from(vec![re(1.0), re(9.0), re(13.0)]));
    }
}