1use ndarray::{Array1, Array2};
13use num_complex::Complex64;
14
15pub struct SlfmmMatvecWorkspace {
19 pub multipoles: Array2<Complex64>,
21 pub locals: Array2<Complex64>,
23 pub dof_buffer: Array1<Complex64>,
25 pub num_clusters: usize,
27 pub num_sphere_points: usize,
29 pub num_dofs: usize,
31}
32
33impl SlfmmMatvecWorkspace {
34 pub fn new(num_clusters: usize, num_sphere_points: usize, num_dofs: usize) -> Self {
36 Self {
37 multipoles: Array2::zeros((num_clusters, num_sphere_points)),
38 locals: Array2::zeros((num_clusters, num_sphere_points)),
39 dof_buffer: Array1::zeros(num_dofs),
40 num_clusters,
41 num_sphere_points,
42 num_dofs,
43 }
44 }
45
46 pub fn clear(&mut self) {
48 self.multipoles.fill(Complex64::new(0.0, 0.0));
49 self.locals.fill(Complex64::new(0.0, 0.0));
50 self.dof_buffer.fill(Complex64::new(0.0, 0.0));
51 }
52}
53
54pub fn batched_t_matrix_apply(
63 t_matrices: &[Array2<Complex64>],
64 x: &Array1<Complex64>,
65 cluster_dof_indices: &[Vec<usize>],
66 multipoles: &mut Array2<Complex64>,
67) {
68 use rayon::prelude::*;
70
71 let results: Vec<(usize, Array1<Complex64>)> = t_matrices
73 .par_iter()
74 .enumerate()
75 .filter_map(|(cluster_idx, t_mat)| {
76 let cluster_dofs = &cluster_dof_indices[cluster_idx];
77 if cluster_dofs.is_empty() || t_mat.is_empty() {
78 return None;
79 }
80
81 let x_local: Array1<Complex64> =
83 Array1::from_iter(cluster_dofs.iter().map(|&i| x[i]));
84
85 let result = t_mat.dot(&x_local);
87 Some((cluster_idx, result))
88 })
89 .collect();
90
91 for (cluster_idx, result) in results {
93 multipoles.row_mut(cluster_idx).assign(&result);
94 }
95}
96
97pub fn batched_s_matrix_apply(
103 s_matrices: &[Array2<Complex64>],
104 locals: &Array2<Complex64>,
105 cluster_dof_indices: &[Vec<usize>],
106 y: &mut Array1<Complex64>,
107) {
108 use rayon::prelude::*;
109
110 let results: Vec<Vec<(usize, Complex64)>> = s_matrices
112 .par_iter()
113 .enumerate()
114 .filter_map(|(cluster_idx, s_mat)| {
115 let cluster_dofs = &cluster_dof_indices[cluster_idx];
116 if cluster_dofs.is_empty() || s_mat.is_empty() {
117 return None;
118 }
119
120 let y_local = s_mat.dot(&locals.row(cluster_idx));
122
123 let contributions: Vec<(usize, Complex64)> = cluster_dofs
125 .iter()
126 .enumerate()
127 .map(|(local_j, &global_j)| (global_j, y_local[local_j]))
128 .collect();
129
130 Some(contributions)
131 })
132 .collect();
133
134 for contributions in results {
136 for (idx, val) in contributions {
137 y[idx] += val;
138 }
139 }
140}
141
142pub fn batched_d_matrix_apply(
147 d_matrices: &[super::super::assembly::slfmm::DMatrixEntry],
148 multipoles: &Array2<Complex64>,
149 locals: &mut Array2<Complex64>,
150) {
151 use rayon::prelude::*;
152
153 let results: Vec<(usize, Array1<Complex64>)> = d_matrices
155 .par_iter()
156 .map(|d_entry| {
157 let src_mult = multipoles.row(d_entry.source_cluster);
158 let translated = d_entry.coefficients.dot(&src_mult);
159 (d_entry.field_cluster, translated)
160 })
161 .collect();
162
163 for (field_cluster, translated) in results {
165 for i in 0..translated.len() {
166 locals[[field_cluster, i]] += translated[i];
167 }
168 }
169}
170
171pub fn batched_near_field_apply(
176 near_blocks: &[super::super::assembly::slfmm::NearFieldBlock],
177 x: &Array1<Complex64>,
178 cluster_dof_indices: &[Vec<usize>],
179 y: &mut Array1<Complex64>,
180) {
181 use rayon::prelude::*;
182
183 let contributions: Vec<Vec<(usize, Complex64)>> = near_blocks
185 .par_iter()
186 .flat_map(|block| {
187 let src_dofs = &cluster_dof_indices[block.source_cluster];
188 let fld_dofs = &cluster_dof_indices[block.field_cluster];
189
190 let mut result = Vec::new();
191
192 let x_local: Array1<Complex64> = Array1::from_iter(fld_dofs.iter().map(|&i| x[i]));
194
195 let y_local = block.coefficients.dot(&x_local);
197
198 for (local_i, &global_i) in src_dofs.iter().enumerate() {
200 result.push((global_i, y_local[local_i]));
201 }
202
203 if block.source_cluster != block.field_cluster {
205 let x_src: Array1<Complex64> = Array1::from_iter(src_dofs.iter().map(|&i| x[i]));
206 let y_fld = block.coefficients.t().dot(&x_src);
207 for (local_j, &global_j) in fld_dofs.iter().enumerate() {
208 result.push((global_j, y_fld[local_j]));
209 }
210 }
211
212 vec![result]
213 })
214 .collect();
215
216 for block_contributions in contributions {
218 for (idx, val) in block_contributions {
219 y[idx] += val;
220 }
221 }
222}
223
224pub fn slfmm_matvec_batched(
229 system: &super::super::assembly::slfmm::SlfmmSystem,
230 x: &Array1<Complex64>,
231 workspace: &mut SlfmmMatvecWorkspace,
232) -> Array1<Complex64> {
233 workspace.clear();
235
236 let mut y = Array1::zeros(system.num_dofs);
238
239 batched_near_field_apply(&system.near_matrix, x, &system.cluster_dof_indices, &mut y);
241
242 batched_t_matrix_apply(
246 &system.t_matrices,
247 x,
248 &system.cluster_dof_indices,
249 &mut workspace.multipoles,
250 );
251
252 batched_d_matrix_apply(&system.d_matrices, &workspace.multipoles, &mut workspace.locals);
254
255 batched_s_matrix_apply(&system.s_matrices, &workspace.locals, &system.cluster_dof_indices, &mut y);
257
258 y
259}
260
261pub fn create_batched_matvec<'a>(
268 system: &'a super::super::assembly::slfmm::SlfmmSystem,
269) -> impl FnMut(&Array1<Complex64>) -> Array1<Complex64> + 'a {
270 let mut workspace =
272 SlfmmMatvecWorkspace::new(system.num_clusters, system.num_sphere_points, system.num_dofs);
273
274 move |x: &Array1<Complex64>| slfmm_matvec_batched(system, x, &mut workspace)
275}
276
#[cfg(test)]
mod tests {
    use super::*;

    /// `new` records the requested sizes and allocates buffers of matching shape.
    #[test]
    fn test_workspace_creation() {
        let workspace = SlfmmMatvecWorkspace::new(10, 32, 100);
        assert_eq!(workspace.num_clusters, 10);
        assert_eq!(workspace.num_sphere_points, 32);
        assert_eq!(workspace.num_dofs, 100);
        assert_eq!(workspace.multipoles.shape(), &[10, 32]);
        assert_eq!(workspace.locals.shape(), &[10, 32]);
        assert_eq!(workspace.dof_buffer.len(), 100);
    }

    /// `clear` zeroes every buffer (not just `multipoles`) and keeps their shapes.
    #[test]
    fn test_workspace_clear() {
        let mut workspace = SlfmmMatvecWorkspace::new(2, 4, 8);
        let marker = Complex64::new(1.0, 2.0);
        workspace.multipoles[[0, 0]] = marker;
        workspace.locals[[1, 3]] = marker;
        workspace.dof_buffer[7] = marker;

        workspace.clear();

        let zero = Complex64::new(0.0, 0.0);
        assert!(workspace.multipoles.iter().all(|&v| v == zero));
        assert!(workspace.locals.iter().all(|&v| v == zero));
        assert!(workspace.dof_buffer.iter().all(|&v| v == zero));

        assert_eq!(workspace.multipoles.shape(), &[2, 4]);
        assert_eq!(workspace.locals.shape(), &[2, 4]);
        assert_eq!(workspace.dof_buffer.len(), 8);
    }
}