1use crate::geometry::{Axis, Geometry, Grid};
2use crate::weight_functions::*;
3use ndarray::linalg::Dot;
4use ndarray::prelude::*;
5use ndarray::{Axis as Axis_nd, RemoveAxis, Slice};
6use num_dual::*;
7use num_traits::Zero;
8use rustdct::DctNum;
9use std::ops::{AddAssign, MulAssign, SubAssign};
10use std::sync::Arc;
11
12mod periodic_convolver;
13mod transform;
14pub use periodic_convolver::PeriodicConvolver;
15use transform::*;
16
/// Evaluates the convolutions needed by a density functional on a specific grid.
///
/// The constructors in this module return implementations as
/// `Arc<dyn Convolver<T, D>>`; the `Send + Sync` supertraits allow such a
/// shared convolver to be used from multiple threads.
pub trait Convolver<T, D: Dimension>: Send + Sync {
    /// Convolves a single `profile` with `weight_function` and returns the
    /// result on the same grid.
    fn convolve(&self, profile: Array<T, D>, weight_function: &WeightFunction<T>) -> Array<T, D>;

    /// Calculates the weighted densities of `density`.
    ///
    /// The leading axis of `density` indexes segments; the remaining axes are
    /// the spatial grid. One array of weighted densities is returned for every
    /// weight function set the convolver was built with, in the same order.
    fn weighted_densities(&self, density: &Array<T, D::Larger>) -> Vec<Array<T, D::Larger>>;

    /// Calculates the functional derivative from `partial_derivatives`, the
    /// derivatives (presumably of the Helmholtz energy density — confirm with
    /// callers) with respect to the weighted densities.
    ///
    /// `partial_derivatives` holds one array per weight function set, laid out
    /// like the output of `weighted_densities`. The leading axis of the result
    /// indexes segments.
    fn functional_derivative(
        &self,
        partial_derivatives: &[Array<T, D::Larger>],
    ) -> Array<T, D::Larger>;
}
38
/// Convolver for homogeneous (bulk) systems without spatial resolution.
///
/// Every "convolution" reduces to a product with a constant weight matrix, so
/// it implements `Convolver<T, Ix0>` (no spatial axes).
pub(crate) struct BulkConvolver<T> {
    /// One weight-constant matrix per weight function set, as returned by
    /// `WeightFunctionInfo::weight_constants` with a zero first argument
    /// (presumably the wave vector k = 0 — confirm) and `0` dimensions.
    weight_constants: Vec<Array2<T>>,
}
42
43impl<T: DualNum<f64> + Copy + Send + Sync> BulkConvolver<T> {
44 #[expect(clippy::new_ret_no_self)]
45 pub(crate) fn new(weight_functions: Vec<WeightFunctionInfo<T>>) -> Arc<dyn Convolver<T, Ix0>> {
46 let weight_constants = weight_functions
47 .into_iter()
48 .map(|w| w.weight_constants(Zero::zero(), 0))
49 .collect();
50 Arc::new(Self { weight_constants })
51 }
52}
53
54impl<T: DualNum<f64> + Copy + Send + Sync> Convolver<T, Ix0> for BulkConvolver<T>
55where
56 Array2<T>: Dot<Array1<T>, Output = Array1<T>>,
57{
58 fn convolve(&self, _: Array0<T>, _: &WeightFunction<T>) -> Array0<T> {
59 unreachable!()
60 }
61
62 fn weighted_densities(&self, density: &Array1<T>) -> Vec<Array1<T>> {
63 self.weight_constants
64 .iter()
65 .map(|w| w.dot(density))
66 .collect()
67 }
68
69 fn functional_derivative(&self, partial_derivatives: &[Array1<T>]) -> Array1<T> {
70 self.weight_constants
71 .iter()
72 .zip(partial_derivatives.iter())
73 .map(|(w, pd)| pd.dot(w))
74 .reduce(|a, b| a + b)
75 .unwrap()
76 }
77}
78
/// Fourier-space representation of one weight function set; the fields mirror
/// those of `WeightFunctionInfo` after transformation.
#[derive(Debug, Clone)]
struct FFTWeightFunctions<T, D: Dimension> {
    /// Number of segments (`component_index.len()` of the source set).
    pub(crate) segments: usize,
    /// Whether the density itself is copied into the weighted densities.
    pub(crate) local_density: bool,
    /// Scalar weight functions whose weighted densities are kept per segment.
    pub(crate) scalar_component_weighted_densities: Vec<Array<T, D::Larger>>,
    /// Vector weight functions kept per segment; the leading axis holds the
    /// vector components.
    pub(crate) vector_component_weighted_densities: Vec<Array<T, <D::Larger as Dimension>::Larger>>,
    /// Scalar weight functions whose weighted densities are summed over all
    /// segments (FMT style).
    pub(crate) scalar_fmt_weighted_densities: Vec<Array<T, D::Larger>>,
    /// Vector weight functions whose weighted densities are summed over all
    /// segments (FMT style); the leading axis holds the vector components.
    pub(crate) vector_fmt_weighted_densities: Vec<Array<T, <D::Larger as Dimension>::Larger>>,
}
98
99impl<T, D: Dimension> FFTWeightFunctions<T, D> {
100 pub fn n_weighted_densities(&self, dimensions: usize) -> usize {
103 (if self.local_density { self.segments } else { 0 })
104 + self.scalar_component_weighted_densities.len() * self.segments
105 + self.vector_component_weighted_densities.len() * self.segments * dimensions
106 + self.scalar_fmt_weighted_densities.len()
107 + self.vector_fmt_weighted_densities.len() * dimensions
108 }
109}
110
/// Convolver that evaluates convolutions as products in Fourier space.
///
/// The first grid axis is transformed according to its geometry (or not at
/// all); every further axis uses a cartesian transform.
pub struct ConvolverFFT<T, D: Dimension> {
    /// Magnitude of the wave vector at every point of the Fourier grid.
    k_abs: Array<f64, D>,
    /// Fourier-transformed weight functions, one entry per weight function set.
    weight_functions: Vec<FFTWeightFunctions<T, D>>,
    /// Optional Lanczos sigma factors, passed to the weight-function transforms.
    lanczos_sigma: Option<Array<f64, D>>,
    /// Transform used along the first grid axis.
    transform: Box<dyn FourierTransform<T>>,
    /// Cartesian transforms of the remaining axes, in axis order.
    cartesian_transforms: Vec<CartesianTransform<T>>,
}
127
128impl<T, D: Dimension + RemoveAxis + 'static> ConvolverFFT<T, D>
129where
130 T: DctNum + DualNum<f64>,
131 D::Larger: Dimension<Smaller = D>,
132 D::Smaller: Dimension<Larger = D>,
133 <D::Larger as Dimension>::Larger: Dimension<Smaller = D::Larger>,
134{
135 pub fn plan(
137 grid: &Grid,
138 weight_functions: &[WeightFunctionInfo<T>],
139 lanczos: Option<i32>,
140 ) -> Arc<dyn Convolver<T, D>> {
141 match grid {
142 Grid::Polar(r) => CurvilinearConvolver::new(r, &[], weight_functions, lanczos),
143 Grid::Spherical(r) => CurvilinearConvolver::new(r, &[], weight_functions, lanczos),
144 Grid::Cartesian1(z) => Self::new(Some(z), &[], weight_functions, lanczos),
145 Grid::Cylindrical { r, z } => {
146 CurvilinearConvolver::new(r, &[z], weight_functions, lanczos)
147 }
148 Grid::Cartesian2(x, y) => Self::new(Some(x), &[y], weight_functions, lanczos),
149 Grid::Periodical2(x, y, alpha) => {
150 PeriodicConvolver::new_2d(&[x, y], *alpha, weight_functions, lanczos)
151 }
152 Grid::Cartesian3(x, y, z) => Self::new(Some(x), &[y, z], weight_functions, lanczos),
153 Grid::Periodical3(x, y, z, angles) => {
154 PeriodicConvolver::new_3d(&[x, y, z], *angles, weight_functions, lanczos)
155 }
156 }
157 }
158}
159
160impl<T, D: Dimension + 'static> ConvolverFFT<T, D>
161where
162 T: DctNum + DualNum<f64>,
163 D::Larger: Dimension<Smaller = D>,
164 <D::Larger as Dimension>::Larger: Dimension<Smaller = D::Larger>,
165{
166 #[expect(clippy::new_ret_no_self)]
167 fn new(
168 axis: Option<&Axis>,
169 cartesian_axes: &[&Axis],
170 weight_functions: &[WeightFunctionInfo<T>],
171 lanczos: Option<i32>,
172 ) -> Arc<dyn Convolver<T, D>> {
173 let mut cartesian_transforms = Vec::with_capacity(cartesian_axes.len());
175 let mut k_vec = Vec::with_capacity(cartesian_axes.len() + 1);
176 let mut lengths = Vec::with_capacity(cartesian_axes.len() + 1);
177 let (transform, k_x) = match axis {
178 Some(axis) => match axis.geometry {
179 Geometry::Cartesian => CartesianTransform::new(axis),
180 Geometry::Cylindrical => PolarTransform::new(axis),
181 Geometry::Spherical => SphericalTransform::new(axis),
182 },
183 None => NoTransform::new(),
184 };
185 k_vec.push(k_x);
186 lengths.push(axis.map_or(1.0, |axis| axis.length()));
187 for ax in cartesian_axes {
188 let (transform, k_x) = CartesianTransform::new_cartesian(ax);
189 cartesian_transforms.push(transform);
190 k_vec.push(k_x);
191 lengths.push(ax.length());
192 }
193
194 let mut dim = vec![k_vec.len()];
196 k_vec.iter().for_each(|k_x| dim.push(k_x.len()));
197 let mut k: Array<_, D::Larger> = Array::zeros(dim).into_dimensionality().unwrap();
198 let mut k_abs = Array::zeros(k.raw_dim().remove_axis(Axis(0)));
199 for (i, (mut k_i, k_x)) in k.outer_iter_mut().zip(k_vec.iter()).enumerate() {
200 k_i.lanes_mut(Axis_nd(i))
201 .into_iter()
202 .for_each(|mut l| l.assign(k_x));
203 k_abs.add_assign(&k_i.mapv(|k| k.powi(2)));
204 }
205 k_abs.map_inplace(|k| *k = k.sqrt());
206
207 let lanczos_sigma = lanczos.map(|exp| {
209 let mut lanczos = Array::ones(k_abs.raw_dim());
210 for (i, (k_x, &l)) in k_vec.iter().zip(lengths.iter()).enumerate() {
211 let points = k_x.len();
212 let m2 = if points % 2 == 0 {
213 points as f64 + 2.0
214 } else {
215 points as f64 + 1.0
216 };
217 let l_x = k_x.mapv(|k| (k * l / m2).sph_j0().powi(exp));
218 for mut l in lanczos.lanes_mut(Axis_nd(i)) {
219 l.mul_assign(&l_x);
220 }
221 }
222 lanczos
223 });
224
225 let mut fft_weight_functions = Vec::with_capacity(weight_functions.len());
227 for wf in weight_functions {
228 let mut scal_comp = Vec::with_capacity(wf.scalar_component_weighted_densities.len());
231 for wf_i in &wf.scalar_component_weighted_densities {
233 scal_comp.push(wf_i.fft_scalar_weight_functions(&k_abs, &lanczos_sigma));
234 }
235
236 let mut vec_comp = Vec::with_capacity(wf.vector_component_weighted_densities.len());
238 for wf_i in &wf.vector_component_weighted_densities {
240 vec_comp.push(wf_i.fft_vector_weight_functions(&k_abs, &k, &lanczos_sigma));
241 }
242
243 let mut scal_fmt = Vec::with_capacity(wf.scalar_fmt_weighted_densities.len());
245 for wf_i in &wf.scalar_fmt_weighted_densities {
247 scal_fmt.push(wf_i.fft_scalar_weight_functions(&k_abs, &lanczos_sigma));
248 }
249
250 let mut vec_fmt = Vec::with_capacity(wf.vector_fmt_weighted_densities.len());
252 for wf_i in &wf.vector_fmt_weighted_densities {
254 vec_fmt.push(wf_i.fft_vector_weight_functions(&k_abs, &k, &lanczos_sigma));
255 }
256
257 fft_weight_functions.push(FFTWeightFunctions::<_, D> {
259 segments: wf.component_index.len(),
260 local_density: wf.local_density,
261 scalar_component_weighted_densities: scal_comp,
262 vector_component_weighted_densities: vec_comp,
263 scalar_fmt_weighted_densities: scal_fmt,
264 vector_fmt_weighted_densities: vec_fmt,
265 });
266 }
267
268 Arc::new(Self {
270 k_abs,
271 weight_functions: fft_weight_functions,
272 lanczos_sigma,
273 transform,
274 cartesian_transforms,
275 })
276 }
277}
278
279impl<T, D: Dimension> ConvolverFFT<T, D>
280where
281 T: DctNum + DualNum<f64>,
282 D::Larger: Dimension<Smaller = D>,
283 <D::Larger as Dimension>::Larger: Dimension<Smaller = D::Larger>,
284{
285 fn forward_transform(&self, f: ArrayView<T, D>, vector_index: Option<usize>) -> Array<T, D> {
286 let mut dim = vec![self.k_abs.shape()[0]];
287 f.shape().iter().skip(1).for_each(|&d| dim.push(d));
288 let mut result: Array<_, D> = Array::zeros(dim.clone()).into_dimensionality().unwrap();
289 for (f, r) in f
290 .lanes(Axis_nd(0))
291 .into_iter()
292 .zip(result.lanes_mut(Axis_nd(0)))
293 {
294 self.transform
295 .forward_transform(f, r, vector_index != Some(0));
296 }
297 for (i, transform) in self.cartesian_transforms.iter().enumerate() {
298 dim[i + 1] = self.k_abs.shape()[i + 1];
299 let mut res: Array<_, D> = Array::zeros(dim.clone()).into_dimensionality().unwrap();
300 for (f, r) in result
301 .lanes(Axis_nd(i + 1))
302 .into_iter()
303 .zip(res.lanes_mut(Axis_nd(i + 1)))
304 {
305 transform.forward_transform(f, r, vector_index.is_none_or(|ind| ind != i + 1));
306 }
307 result = res;
308 }
309
310 result
311 }
312
313 fn forward_transform_comps(
314 &self,
315 f: ArrayView<T, D::Larger>,
316 vector_index: Option<usize>,
317 ) -> Array<T, D::Larger> {
318 let mut dim = vec![f.shape()[0]];
319 self.k_abs.shape().iter().for_each(|&d| dim.push(d));
320 let mut result = Array::zeros(dim).into_dimensionality().unwrap();
321 for (f, mut r) in f.outer_iter().zip(result.outer_iter_mut()) {
322 r.assign(&self.forward_transform(f, vector_index));
323 }
324 result
325 }
326
327 fn back_transform(
328 &self,
329 mut f: ArrayViewMut<T, D>,
330 mut result: ArrayViewMut<T, D>,
331 vector_index: Option<usize>,
332 ) {
333 let mut dim = vec![result.shape()[0]];
334 f.shape().iter().skip(1).for_each(|&d| dim.push(d));
335 let mut res: Array<_, D> = Array::zeros(dim.clone()).into_dimensionality().unwrap();
336 for (f, r) in f
337 .lanes_mut(Axis_nd(0))
338 .into_iter()
339 .zip(res.lanes_mut(Axis_nd(0)))
340 {
341 self.transform.back_transform(f, r, vector_index != Some(0));
342 }
343 for (i, transform) in self.cartesian_transforms.iter().enumerate() {
344 dim[i + 1] = result.shape()[i + 1];
345 let mut res2: Array<_, D> = Array::zeros(dim.clone()).into_dimensionality().unwrap();
346 for (f, r) in res
347 .lanes_mut(Axis_nd(i + 1))
348 .into_iter()
349 .zip(res2.lanes_mut(Axis_nd(i + 1)))
350 {
351 transform.back_transform(f, r, vector_index.is_none_or(|ind| ind != i + 1));
352 }
353 res = res2;
354 }
355
356 result.assign(&res);
357 }
358
359 fn back_transform_comps(
360 &self,
361 mut f: Array<T, D::Larger>,
362 mut result: ArrayViewMut<T, D::Larger>,
363 vector_index: Option<usize>,
364 ) {
365 for (f, r) in f.outer_iter_mut().zip(result.outer_iter_mut()) {
366 self.back_transform(f, r, vector_index);
367 }
368 }
369}
370
371impl<T, D: Dimension> Convolver<T, D> for ConvolverFFT<T, D>
372where
373 T: DctNum + DualNum<f64>,
374 D::Larger: Dimension<Smaller = D>,
375 <D::Larger as Dimension>::Larger: Dimension<Smaller = D::Larger>,
376{
377 fn convolve(&self, profile: Array<T, D>, weight_function: &WeightFunction<T>) -> Array<T, D> {
378 let f_k = self.forward_transform(profile.view(), None);
380
381 let w = weight_function
383 .fft_scalar_weight_functions(&self.k_abs, &self.lanczos_sigma)
384 .index_axis_move(Axis(0), 0);
385
386 let mut result = Array::zeros(profile.raw_dim());
388 self.back_transform((f_k * w).view_mut(), result.view_mut(), None);
389 result
390 }
391
392 fn weighted_densities(&self, density: &Array<T, D::Larger>) -> Vec<Array<T, D::Larger>> {
393 let rho_k = self.forward_transform_comps(density.view(), None);
395
396 let mut weighted_densities_vec = Vec::with_capacity(self.weight_functions.len());
398 for wf in &self.weight_functions {
399 let n_wd = wf.n_weighted_densities(density.ndim() - 1);
401
402 let mut dim = vec![n_wd];
404 density.shape().iter().skip(1).for_each(|&d| dim.push(d));
405 let mut weighted_densities = Array::zeros(dim).into_dimensionality().unwrap();
406
407 let mut k = 0;
409
410 if wf.local_density {
412 weighted_densities
413 .slice_axis_mut(Axis(0), Slice::from(0..wf.segments))
414 .assign(density);
415 k += wf.segments;
416 }
417
418 for wf_i in &wf.scalar_component_weighted_densities {
420 self.back_transform_comps(
421 &rho_k * wf_i,
422 weighted_densities.slice_axis_mut(Axis(0), Slice::from(k..k + wf.segments)),
423 None,
424 );
425 k += wf.segments;
426 }
427
428 for wf_i in &wf.vector_component_weighted_densities {
430 for (i, wf_i) in wf_i.outer_iter().enumerate() {
431 self.back_transform_comps(
432 &rho_k * &wf_i,
433 weighted_densities.slice_axis_mut(Axis(0), Slice::from(k..k + wf.segments)),
434 Some(i),
435 );
436 k += wf.segments;
437 }
438 }
439
440 for wf_i in &wf.scalar_fmt_weighted_densities {
442 self.back_transform(
443 (&rho_k * wf_i).sum_axis(Axis(0)).view_mut(),
444 weighted_densities.index_axis_mut(Axis(0), k),
445 None,
446 );
447 k += 1;
448 }
449
450 for wf_i in &wf.vector_fmt_weighted_densities {
452 for (i, wf_i) in wf_i.outer_iter().enumerate() {
453 self.back_transform(
454 (&rho_k * &wf_i).sum_axis(Axis(0)).view_mut(),
455 weighted_densities.index_axis_mut(Axis(0), k),
456 Some(i),
457 );
458 k += 1;
459 }
460 }
461
462 weighted_densities_vec.push(weighted_densities);
464 }
465 weighted_densities_vec
467 }
468
469 fn functional_derivative(
470 &self,
471 partial_derivatives: &[Array<T, D::Larger>],
472 ) -> Array<T, D::Larger> {
473 let mut dim = vec![self.weight_functions[0].segments];
476 partial_derivatives[0]
477 .shape()
478 .iter()
479 .skip(1)
480 .for_each(|&d| dim.push(d));
481 let mut functional_deriv = Array::zeros(dim).into_dimensionality().unwrap();
482 let mut functional_deriv_local = Array::zeros(functional_deriv.raw_dim());
483 let mut dim = vec![self.weight_functions[0].segments];
484 self.k_abs.shape().iter().for_each(|&d| dim.push(d));
485 let mut functional_deriv_k = Array::zeros(dim).into_dimensionality().unwrap();
486
487 for (pd, wf) in partial_derivatives.iter().zip(&self.weight_functions) {
489 let mut k = 0;
495
496 if wf.local_density {
498 functional_deriv_local += &pd.slice_axis(Axis(0), Slice::from(..wf.segments));
499 k += wf.segments;
500 }
501
502 for wf_i in &wf.scalar_component_weighted_densities {
504 let pd_k = self.forward_transform_comps(
505 pd.slice_axis(Axis(0), Slice::from(k..k + wf.segments)),
506 None,
507 );
508 functional_deriv_k.add_assign(&(&pd_k * wf_i));
509 k += wf.segments;
510 }
511
512 for wf_i in &wf.vector_component_weighted_densities {
514 for (i, wf_i) in wf_i.outer_iter().enumerate() {
515 let pd_k = self.forward_transform_comps(
516 pd.slice_axis(Axis(0), Slice::from(k..k + wf.segments)),
517 Some(i),
518 );
519 functional_deriv_k.add_assign(&(pd_k * &wf_i));
520 k += wf.segments;
521 }
522 }
523
524 for wf_i in &wf.scalar_fmt_weighted_densities {
526 let pd_k = self.forward_transform(pd.index_axis(Axis(0), k), None);
527 functional_deriv_k.add_assign(&(wf_i * &pd_k));
528 k += 1;
529 }
530
531 for wf_i in &wf.vector_fmt_weighted_densities {
533 for (i, wf_i) in wf_i.outer_iter().enumerate() {
534 let pd_k = self.forward_transform(pd.index_axis(Axis(0), k), Some(i));
535 functional_deriv_k.add_assign(&(&wf_i * &pd_k));
536 k += 1;
537 }
538 }
539 }
540
541 self.back_transform_comps(functional_deriv_k, functional_deriv.view_mut(), None);
543
544 functional_deriv + functional_deriv_local
546 }
547}
548
/// Convolver for grids whose first axis is radial (polar, spherical, or
/// cylindrical).
///
/// Profiles are split into their value at the last radial grid point (the
/// boundary value) and the deviation from it. The deviation is convolved with
/// the full curvilinear transform, and the boundary value with a convolver that
/// leaves the radial axis untransformed.
struct CurvilinearConvolver<T, D> {
    /// Convolver including the radial transform; applied to the deviation.
    convolver: Arc<dyn Convolver<T, D>>,
    /// Convolver without radial transform; applied to the boundary value.
    convolver_boundary: Arc<dyn Convolver<T, D>>,
}
555
556impl<T, D: Dimension + RemoveAxis + 'static> CurvilinearConvolver<T, D>
557where
558 T: DctNum + DualNum<f64>,
559 D::Larger: Dimension<Smaller = D>,
560 D::Smaller: Dimension<Larger = D>,
561 <D::Larger as Dimension>::Larger: Dimension<Smaller = D::Larger>,
562{
563 #[expect(clippy::new_ret_no_self)]
564 fn new(
565 r: &Axis,
566 z: &[&Axis],
567 weight_functions: &[WeightFunctionInfo<T>],
568 lanczos: Option<i32>,
569 ) -> Arc<dyn Convolver<T, D>> {
570 Arc::new(Self {
571 convolver: ConvolverFFT::new(Some(r), z, weight_functions, lanczos),
572 convolver_boundary: ConvolverFFT::new(None, z, weight_functions, lanczos),
573 })
574 }
575}
576
577impl<T, D: Dimension + RemoveAxis> Convolver<T, D> for CurvilinearConvolver<T, D>
578where
579 T: DctNum + DualNum<f64>,
580 D::Smaller: Dimension<Larger = D>,
581 D::Larger: Dimension<Smaller = D>,
582{
583 fn convolve(
584 &self,
585 mut profile: Array<T, D>,
586 weight_function: &WeightFunction<T>,
587 ) -> Array<T, D> {
588 let profile_boundary = profile
590 .index_axis(Axis(0), profile.shape()[0] - 1)
591 .into_owned();
592 for mut lane in profile.outer_iter_mut() {
593 lane.sub_assign(&profile_boundary);
594 }
595
596 let mut result = self.convolver.convolve(profile, weight_function);
598
599 let profile_boundary = profile_boundary.insert_axis(Axis(0));
601 let result_boundary = self
602 .convolver_boundary
603 .convolve(profile_boundary, weight_function);
604
605 let result_boundary = result_boundary.index_axis(Axis(0), 0);
607 for mut lane in result.outer_iter_mut() {
608 lane.add_assign(&result_boundary);
609 }
610 result
611 }
612
613 fn weighted_densities(&self, density: &Array<T, D::Larger>) -> Vec<Array<T, D::Larger>> {
615 let density_boundary = density.index_axis(Axis(1), density.shape()[1] - 1);
617 let mut density = density.to_owned();
618 for mut lane in density.axis_iter_mut(Axis(1)) {
619 lane.sub_assign(&density_boundary);
620 }
621
622 let mut wd = self.convolver.weighted_densities(&density);
624
625 let density_boundary = density_boundary.insert_axis(Axis(1));
627 let wd_boundary = self
628 .convolver_boundary
629 .weighted_densities(&density_boundary.to_owned());
630
631 for (wd, wd_boundary) in wd.iter_mut().zip(wd_boundary.iter()) {
633 let wd_view = wd_boundary.index_axis(Axis(1), 0);
634 for mut lane in wd.axis_iter_mut(Axis(1)) {
635 lane.add_assign(&wd_view);
636 }
637 }
638
639 wd
640 }
641
642 fn functional_derivative(
645 &self,
646 partial_derivatives: &[Array<T, D::Larger>],
647 ) -> Array<T, D::Larger> {
648 let mut partial_derivatives_full = Vec::new();
650 let mut partial_derivatives_boundary = Vec::new();
651 for pd in partial_derivatives {
652 let pd_boundary = pd.index_axis(Axis(1), pd.shape()[1] - 1).to_owned();
653 let mut pd_full = pd.to_owned();
654 for mut lane in pd_full.axis_iter_mut(Axis(1)) {
655 lane.sub_assign(&pd_boundary);
656 }
657 partial_derivatives_full.push(pd_full);
658 partial_derivatives_boundary.push(pd_boundary);
659 }
660
661 let mut functional_derivative = self
663 .convolver
664 .functional_derivative(&partial_derivatives_full);
665
666 let mut partial_derivatives_boundary = Vec::new();
668 for pd in partial_derivatives {
669 let mut pd_boundary = pd.view();
670 pd_boundary.collapse_axis(Axis(1), pd.shape()[1] - 1);
671 partial_derivatives_boundary.push(pd_boundary.to_owned());
672 }
673 let functional_derivative_boundary = self
674 .convolver_boundary
675 .functional_derivative(&partial_derivatives_boundary);
676
677 let functional_derivative_view = functional_derivative_boundary.index_axis(Axis(1), 0);
679 for mut lane in functional_derivative.axis_iter_mut(Axis(1)) {
680 lane.add_assign(&functional_derivative_view);
681 }
682
683 functional_derivative
684 }
685}