autd3_backend_arrayfire/
lib.rs

1#![allow(unknown_lints)]
2#![allow(clippy::manual_slice_size_calculation)]
3
4use std::collections::HashMap;
5
6use arrayfire::*;
7
8use autd3_core::{
9    acoustics::{
10        directivity::{Directivity, Sphere},
11        propagate,
12    },
13    gain::BitVec,
14    geometry::Geometry,
15};
16use autd3_gain_holo::{
17    Complex, HoloError, LinAlgBackend, MatrixX, MatrixXc, Trans, VectorX, VectorXc,
18};
19
20pub type AFBackend = arrayfire::Backend;
21pub type AFDeviceInfo = (String, String, String, String);
22
23fn convert(trans: Trans) -> MatProp {
24    match trans {
25        Trans::NoTrans => MatProp::NONE,
26        Trans::Trans => MatProp::TRANS,
27        Trans::ConjTrans => MatProp::CTRANS,
28    }
29}
30
31pub struct ArrayFireBackend<D: Directivity> {
32    _phantom: std::marker::PhantomData<D>,
33}
34
35impl ArrayFireBackend<Sphere> {
36    pub fn get_available_backends() -> Vec<AFBackend> {
37        arrayfire::get_available_backends()
38    }
39
40    pub fn set_backend(backend: AFBackend) {
41        arrayfire::set_backend(backend);
42    }
43
44    pub fn set_device(device: i32) {
45        arrayfire::set_device(device);
46    }
47
48    pub fn get_available_devices() -> Vec<AFDeviceInfo> {
49        let cur_dev = arrayfire::get_device();
50        let r = (0..arrayfire::device_count())
51            .map(|i| {
52                arrayfire::set_device(i);
53                arrayfire::device_info()
54            })
55            .collect();
56        arrayfire::set_device(cur_dev);
57        r
58    }
59}
60
61impl Default for ArrayFireBackend<Sphere> {
62    fn default() -> Self {
63        Self {
64            _phantom: Default::default(),
65        }
66    }
67}
68
69impl<D: Directivity> ArrayFireBackend<D> {
70    pub fn new() -> Self {
71        Self {
72            _phantom: std::marker::PhantomData,
73        }
74    }
75}
76
77impl<D: Directivity> LinAlgBackend<D> for ArrayFireBackend<D> {
78    type MatrixXc = Array<c32>;
79    type MatrixX = Array<f32>;
80    type VectorXc = Array<c32>;
81    type VectorX = Array<f32>;
82
83    fn generate_propagation_matrix(
84        &self,
85        geometry: &Geometry,
86        foci: &[autd3_core::geometry::Point3],
87        filter: Option<&HashMap<usize, BitVec>>,
88    ) -> Result<Self::MatrixXc, HoloError> {
89        let g = if let Some(filter) = filter {
90            geometry
91                .devices()
92                .flat_map(|dev| {
93                    dev.iter().filter_map(move |tr| {
94                        if let Some(filter) = filter.get(&dev.idx()) {
95                            if filter[tr.idx()] {
96                                Some(foci.iter().map(move |fp| {
97                                    propagate::<D>(tr, dev.wavenumber(), dev.axial_direction(), fp)
98                                }))
99                            } else {
100                                None
101                            }
102                        } else {
103                            None
104                        }
105                    })
106                })
107                .flatten()
108                .collect::<Vec<_>>()
109        } else {
110            geometry
111                .devices()
112                .flat_map(|dev| {
113                    dev.iter().flat_map(move |tr| {
114                        foci.iter().map(move |fp| {
115                            propagate::<D>(tr, dev.wavenumber(), dev.axial_direction(), fp)
116                        })
117                    })
118                })
119                .collect::<Vec<_>>()
120        };
121
122        unsafe {
123            Ok(Array::new(
124                std::slice::from_raw_parts(g.as_ptr() as *const c32, g.len()),
125                Dim4::new(&[foci.len() as u64, (g.len() / foci.len()) as _, 1, 1]),
126            ))
127        }
128    }
129
130    fn alloc_v(&self, size: usize) -> Result<Self::VectorX, HoloError> {
131        Ok(Array::new_empty(Dim4::new(&[size as _, 1, 1, 1])))
132    }
133
134    fn alloc_m(&self, rows: usize, cols: usize) -> Result<Self::MatrixX, HoloError> {
135        Ok(Array::new_empty(Dim4::new(&[rows as _, cols as _, 1, 1])))
136    }
137
138    fn alloc_cv(&self, size: usize) -> Result<Self::VectorXc, HoloError> {
139        Ok(Array::new_empty(Dim4::new(&[size as _, 1, 1, 1])))
140    }
141
142    fn alloc_cm(&self, rows: usize, cols: usize) -> Result<Self::MatrixXc, HoloError> {
143        Ok(Array::new_empty(Dim4::new(&[rows as _, cols as _, 1, 1])))
144    }
145
146    fn alloc_zeros_v(&self, size: usize) -> Result<Self::VectorX, HoloError> {
147        Ok(arrayfire::constant(0., Dim4::new(&[size as _, 1, 1, 1])))
148    }
149
150    fn alloc_zeros_cv(&self, size: usize) -> Result<Self::VectorXc, HoloError> {
151        Ok(arrayfire::constant(
152            c32::new(0., 0.),
153            Dim4::new(&[size as _, 1, 1, 1]),
154        ))
155    }
156
157    fn alloc_zeros_cm(&self, rows: usize, cols: usize) -> Result<Self::MatrixXc, HoloError> {
158        Ok(arrayfire::constant(
159            c32::new(0., 0.),
160            Dim4::new(&[rows as _, cols as _, 1, 1]),
161        ))
162    }
163
164    fn to_host_v(&self, v: Self::VectorX) -> Result<VectorX, HoloError> {
165        let mut r = VectorX::zeros(v.elements());
166        v.host(r.as_mut_slice());
167        Ok(r)
168    }
169
170    fn to_host_m(&self, v: Self::MatrixX) -> Result<MatrixX, HoloError> {
171        let mut r = MatrixX::zeros(v.dims()[0] as _, v.dims()[1] as _);
172        v.host(r.as_mut_slice());
173        Ok(r)
174    }
175
176    fn to_host_cv(&self, v: Self::VectorXc) -> Result<VectorXc, HoloError> {
177        let n = v.elements();
178        let mut r = VectorXc::zeros(n);
179        unsafe {
180            v.host(std::slice::from_raw_parts_mut(
181                r.as_mut_ptr() as *mut c32,
182                n,
183            ));
184        }
185        Ok(r)
186    }
187
188    fn to_host_cm(&self, v: Self::MatrixXc) -> Result<MatrixXc, HoloError> {
189        let n = v.elements();
190        let mut r = MatrixXc::zeros(v.dims()[0] as _, v.dims()[1] as _);
191        unsafe {
192            v.host(std::slice::from_raw_parts_mut(
193                r.as_mut_ptr() as *mut c32,
194                n,
195            ));
196        }
197        Ok(r)
198    }
199
200    fn from_slice_v(&self, v: &[f32]) -> Result<Self::VectorX, HoloError> {
201        Ok(Array::new(v, Dim4::new(&[v.len() as _, 1, 1, 1])))
202    }
203
204    fn from_slice_m(
205        &self,
206        rows: usize,
207        cols: usize,
208        v: &[f32],
209    ) -> Result<Self::MatrixX, HoloError> {
210        Ok(Array::new(v, Dim4::new(&[rows as _, cols as _, 1, 1])))
211    }
212
213    fn from_slice_cv(&self, v: &[f32]) -> Result<Self::VectorXc, HoloError> {
214        let r = Array::new(v, Dim4::new(&[v.len() as _, 1, 1, 1]));
215        Ok(arrayfire::cplx(&r))
216    }
217
218    fn from_slice2_cv(&self, r: &[f32], i: &[f32]) -> Result<Self::VectorXc, HoloError> {
219        let r = Array::new(r, Dim4::new(&[r.len() as _, 1, 1, 1]));
220        let i = Array::new(i, Dim4::new(&[i.len() as _, 1, 1, 1]));
221        Ok(arrayfire::cplx2(&r, &i, false).cast())
222    }
223
224    fn from_slice2_cm(
225        &self,
226        rows: usize,
227        cols: usize,
228        r: &[f32],
229        i: &[f32],
230    ) -> Result<Self::MatrixXc, HoloError> {
231        let r = Array::new(r, Dim4::new(&[rows as _, cols as _, 1, 1]));
232        let i = Array::new(i, Dim4::new(&[rows as _, cols as _, 1, 1]));
233        Ok(arrayfire::cplx2(&r, &i, false).cast())
234    }
235
236    fn copy_from_slice_v(&self, v: &[f32], dst: &mut Self::VectorX) -> Result<(), HoloError> {
237        let n = v.len();
238        if n == 0 {
239            return Ok(());
240        }
241        let v = self.from_slice_v(v)?;
242        let seqs = [Seq::new(0u32, n as u32 - 1, 1)];
243        arrayfire::assign_seq(dst, &seqs, &v);
244        Ok(())
245    }
246
247    fn copy_to_v(&self, src: &Self::VectorX, dst: &mut Self::VectorX) -> Result<(), HoloError> {
248        let seqs = [Seq::new(0u32, src.elements() as u32 - 1, 1)];
249        arrayfire::assign_seq(dst, &seqs, src);
250        Ok(())
251    }
252
253    fn copy_to_m(&self, src: &Self::MatrixX, dst: &mut Self::MatrixX) -> Result<(), HoloError> {
254        let seqs = [
255            Seq::new(0u32, src.dims()[0] as u32 - 1, 1),
256            Seq::new(0u32, src.dims()[1] as u32 - 1, 1),
257        ];
258        arrayfire::assign_seq(dst, &seqs, src);
259        Ok(())
260    }
261
262    fn clone_v(&self, v: &Self::VectorX) -> Result<Self::VectorX, HoloError> {
263        Ok(v.copy())
264    }
265
266    fn clone_m(&self, v: &Self::MatrixX) -> Result<Self::MatrixX, HoloError> {
267        Ok(v.copy())
268    }
269
270    fn clone_cv(&self, v: &Self::VectorXc) -> Result<Self::VectorXc, HoloError> {
271        Ok(v.copy())
272    }
273
274    fn clone_cm(&self, v: &Self::MatrixXc) -> Result<Self::MatrixXc, HoloError> {
275        Ok(v.copy())
276    }
277
278    fn make_complex2_v(
279        &self,
280        real: &Self::VectorX,
281        imag: &Self::VectorX,
282        v: &mut Self::VectorXc,
283    ) -> Result<(), HoloError> {
284        *v = arrayfire::cplx2(real, imag, false).cast();
285        Ok(())
286    }
287
288    fn create_diagonal(&self, v: &Self::VectorX, a: &mut Self::MatrixX) -> Result<(), HoloError> {
289        *a = arrayfire::diag_create(v, 0);
290        Ok(())
291    }
292
293    fn create_diagonal_c(
294        &self,
295        v: &Self::VectorXc,
296        a: &mut Self::MatrixXc,
297    ) -> Result<(), HoloError> {
298        *a = arrayfire::diag_create(v, 0);
299        Ok(())
300    }
301
302    fn get_diagonal(&self, a: &Self::MatrixX, v: &mut Self::VectorX) -> Result<(), HoloError> {
303        *v = arrayfire::diag_extract(a, 0);
304        Ok(())
305    }
306
307    fn real_cm(&self, a: &Self::MatrixXc, b: &mut Self::MatrixX) -> Result<(), HoloError> {
308        *b = arrayfire::real(a);
309        Ok(())
310    }
311
312    fn imag_cm(&self, a: &Self::MatrixXc, b: &mut Self::MatrixX) -> Result<(), HoloError> {
313        *b = arrayfire::imag(a);
314        Ok(())
315    }
316
317    fn scale_assign_cv(
318        &self,
319        a: autd3_gain_holo::Complex,
320        b: &mut Self::VectorXc,
321    ) -> Result<(), HoloError> {
322        let a = c32::new(a.re, a.im);
323        *b = arrayfire::mul(b, &a, false);
324        Ok(())
325    }
326
327    fn conj_assign_v(&self, b: &mut Self::VectorXc) -> Result<(), HoloError> {
328        *b = arrayfire::conjg(b);
329        Ok(())
330    }
331
332    fn exp_assign_cv(&self, v: &mut Self::VectorXc) -> Result<(), HoloError> {
333        *v = arrayfire::exp(v);
334        Ok(())
335    }
336
337    fn concat_col_cm(
338        &self,
339        a: &Self::MatrixXc,
340        b: &Self::MatrixXc,
341        c: &mut Self::MatrixXc,
342    ) -> Result<(), HoloError> {
343        *c = arrayfire::join(1, a, b);
344        Ok(())
345    }
346
347    fn max_v(&self, m: &Self::VectorX) -> Result<f32, HoloError> {
348        Ok(arrayfire::max_all(m).0)
349    }
350
351    fn hadamard_product_cm(
352        &self,
353        x: &Self::MatrixXc,
354        y: &Self::MatrixXc,
355        z: &mut Self::MatrixXc,
356    ) -> Result<(), HoloError> {
357        *z = arrayfire::mul(x, y, false);
358        Ok(())
359    }
360
361    fn dot(&self, x: &Self::VectorX, y: &Self::VectorX) -> Result<f32, HoloError> {
362        let r = arrayfire::dot(x, y, MatProp::NONE, MatProp::NONE);
363        let mut v = [0.];
364        r.host(&mut v);
365        Ok(v[0])
366    }
367
368    fn dot_c(
369        &self,
370        x: &Self::VectorXc,
371        y: &Self::VectorXc,
372    ) -> Result<autd3_gain_holo::Complex, HoloError> {
373        let r = arrayfire::dot(x, y, MatProp::CONJ, MatProp::NONE);
374        let mut v = [c32::new(0., 0.)];
375        r.host(&mut v);
376        Ok(autd3_gain_holo::Complex::new(v[0].re, v[0].im))
377    }
378
379    fn add_v(&self, alpha: f32, a: &Self::VectorX, b: &mut Self::VectorX) -> Result<(), HoloError> {
380        *b = arrayfire::add(&arrayfire::mul(a, &alpha, false), b, false);
381        Ok(())
382    }
383
384    fn add_m(&self, alpha: f32, a: &Self::MatrixX, b: &mut Self::MatrixX) -> Result<(), HoloError> {
385        *b = arrayfire::add(&arrayfire::mul(a, &alpha, false), b, false);
386        Ok(())
387    }
388
389    fn gevv_c(
390        &self,
391        trans_a: autd3_gain_holo::Trans,
392        trans_b: autd3_gain_holo::Trans,
393        alpha: autd3_gain_holo::Complex,
394        a: &Self::VectorXc,
395        x: &Self::VectorXc,
396        beta: autd3_gain_holo::Complex,
397        y: &mut Self::MatrixXc,
398    ) -> Result<(), HoloError> {
399        let alpha = vec![c32::new(alpha.re, alpha.im)];
400        let beta = vec![c32::new(beta.re, beta.im)];
401        let trans_a = convert(trans_a);
402        let trans_b = convert(trans_b);
403        arrayfire::gemm(y, trans_a, trans_b, alpha, a, x, beta);
404        Ok(())
405    }
406
407    fn gemv_c(
408        &self,
409        trans: autd3_gain_holo::Trans,
410        alpha: autd3_gain_holo::Complex,
411        a: &Self::MatrixXc,
412        x: &Self::VectorXc,
413        beta: autd3_gain_holo::Complex,
414        y: &mut Self::VectorXc,
415    ) -> Result<(), HoloError> {
416        let alpha = vec![c32::new(alpha.re, alpha.im)];
417        let beta = vec![c32::new(beta.re, beta.im)];
418        let trans = convert(trans);
419        arrayfire::gemm(y, trans, MatProp::NONE, alpha, a, x, beta);
420        Ok(())
421    }
422
423    fn gemm_c(
424        &self,
425        trans_a: autd3_gain_holo::Trans,
426        trans_b: autd3_gain_holo::Trans,
427        alpha: autd3_gain_holo::Complex,
428        a: &Self::MatrixXc,
429        b: &Self::MatrixXc,
430        beta: autd3_gain_holo::Complex,
431        y: &mut Self::MatrixXc,
432    ) -> Result<(), HoloError> {
433        let alpha = vec![c32::new(alpha.re, alpha.im)];
434        let beta = vec![c32::new(beta.re, beta.im)];
435        let trans_a = convert(trans_a);
436        let trans_b = convert(trans_b);
437        arrayfire::gemm(y, trans_a, trans_b, alpha, a, b, beta);
438        Ok(())
439    }
440
441    fn solve_inplace(&self, a: &Self::MatrixX, x: &mut Self::VectorX) -> Result<(), HoloError> {
442        *x = arrayfire::solve(a, x, MatProp::NONE);
443        Ok(())
444    }
445
446    fn reduce_col(&self, a: &Self::MatrixX, b: &mut Self::VectorX) -> Result<(), HoloError> {
447        *b = arrayfire::sum(a, 1);
448        Ok(())
449    }
450
451    fn cols_c(&self, m: &Self::MatrixXc) -> Result<usize, HoloError> {
452        Ok(m.dims()[1] as _)
453    }
454
455    fn scaled_to_cv(
456        &self,
457        a: &Self::VectorXc,
458        b: &Self::VectorXc,
459        c: &mut Self::VectorXc,
460    ) -> Result<(), HoloError> {
461        let tmp = arrayfire::div(a, &arrayfire::abs(a), false);
462        *c = arrayfire::mul(&tmp, b, false);
463        Ok(())
464    }
465
466    fn scaled_to_assign_cv(
467        &self,
468        a: &Self::VectorXc,
469        b: &mut Self::VectorXc,
470    ) -> Result<(), HoloError> {
471        *b = arrayfire::div(b, &arrayfire::abs(b), false);
472        *b = arrayfire::mul(a, b, false);
473        Ok(())
474    }
475
476    fn gen_back_prop(
477        &self,
478        m: usize,
479        n: usize,
480        transfer: &Self::MatrixXc,
481    ) -> Result<Self::MatrixXc, HoloError> {
482        let mut b = self.alloc_zeros_cm(m, n)?;
483
484        let mut tmp = self.alloc_zeros_cm(n, n)?;
485
486        self.gemm_c(
487            Trans::NoTrans,
488            Trans::ConjTrans,
489            Complex::new(1., 0.),
490            transfer,
491            transfer,
492            Complex::new(0., 0.),
493            &mut tmp,
494        )?;
495
496        let mut denominator = arrayfire::diag_extract(&tmp, 0);
497        let a = c32::new(1., 0.);
498        denominator = arrayfire::div(&a, &denominator, false);
499
500        self.create_diagonal_c(&denominator, &mut tmp)?;
501
502        self.gemm_c(
503            Trans::ConjTrans,
504            Trans::NoTrans,
505            Complex::new(1., 0.),
506            transfer,
507            &tmp,
508            Complex::new(0., 0.),
509            &mut b,
510        )?;
511
512        Ok(b)
513    }
514
515    fn norm_squared_cv(&self, a: &Self::VectorXc, b: &mut Self::VectorX) -> Result<(), HoloError> {
516        *b = arrayfire::abs(a);
517        *b = arrayfire::mul(b, b, false);
518        Ok(())
519    }
520}
521#[cfg(test)]
522mod tests {
523    use autd3::driver::autd3_device::AUTD3;
524    use autd3_core::{
525        acoustics::directivity::Sphere,
526        defined::PI,
527        geometry::{IntoDevice, Point3, UnitQuaternion},
528    };
529
530    use nalgebra::{ComplexField, Normed};
531
532    use autd3_gain_holo::{Amplitude, Pa, Trans};
533
534    use super::*;
535
536    use rand::Rng;
537
// Base problem size used to dimension vectors and matrices in the tests below.
const N: usize = 10;
// Presumably the tolerance for approximate floating-point comparisons in later tests
// (it is not used in the part of the module shown here).
const EPS: f32 = 0.001;
540
541    fn generate_geometry(size: usize) -> Geometry {
542        Geometry::new(
543            (0..size)
544                .flat_map(|i| {
545                    (0..size).map(move |j| {
546                        AUTD3 {
547                            pos: Point3::new(
548                                i as f32 * AUTD3::DEVICE_WIDTH,
549                                j as f32 * AUTD3::DEVICE_HEIGHT,
550                                0.,
551                            ),
552                            rot: UnitQuaternion::identity(),
553                        }
554                        .into_device((j + i * size) as _)
555                    })
556                })
557                .collect(),
558        )
559    }
560
561    fn gen_foci(n: usize) -> impl Iterator<Item = (Point3, Amplitude)> {
562        (0..n).map(move |i| {
563            (
564                Point3::new(
565                    90. + 10. * (2.0 * PI * i as f32 / n as f32).cos(),
566                    70. + 10. * (2.0 * PI * i as f32 / n as f32).sin(),
567                    150.,
568                ),
569                10e3 * Pa,
570            )
571        })
572    }
573
574    fn make_random_v(
575        backend: &ArrayFireBackend<Sphere>,
576        size: usize,
577    ) -> Result<<ArrayFireBackend<Sphere> as LinAlgBackend<Sphere>>::VectorX, HoloError> {
578        let mut rng = rand::rng();
579        let v: Vec<f32> = (&mut rng)
580            .sample_iter(rand::distr::StandardUniform)
581            .take(size)
582            .collect();
583        backend.from_slice_v(&v)
584    }
585
586    fn make_random_m(
587        backend: &ArrayFireBackend<Sphere>,
588        rows: usize,
589        cols: usize,
590    ) -> Result<<ArrayFireBackend<Sphere> as LinAlgBackend<Sphere>>::MatrixX, HoloError> {
591        let mut rng = rand::rng();
592        let v: Vec<f32> = (&mut rng)
593            .sample_iter(rand::distr::StandardUniform)
594            .take(rows * cols)
595            .collect();
596        backend.from_slice_m(rows, cols, &v)
597    }
598
599    fn make_random_cv(
600        backend: &ArrayFireBackend<Sphere>,
601        size: usize,
602    ) -> Result<<ArrayFireBackend<Sphere> as LinAlgBackend<Sphere>>::VectorXc, HoloError> {
603        let mut rng = rand::rng();
604        let real: Vec<f32> = (&mut rng)
605            .sample_iter(rand::distr::StandardUniform)
606            .take(size)
607            .collect();
608        let imag: Vec<f32> = (&mut rng)
609            .sample_iter(rand::distr::StandardUniform)
610            .take(size)
611            .collect();
612        backend.from_slice2_cv(&real, &imag)
613    }
614
615    fn make_random_cm(
616        backend: &ArrayFireBackend<Sphere>,
617        rows: usize,
618        cols: usize,
619    ) -> Result<<ArrayFireBackend<Sphere> as LinAlgBackend<Sphere>>::MatrixXc, HoloError> {
620        let mut rng = rand::rng();
621        let real: Vec<f32> = (&mut rng)
622            .sample_iter(rand::distr::StandardUniform)
623            .take(rows * cols)
624            .collect();
625        let imag: Vec<f32> = (&mut rng)
626            .sample_iter(rand::distr::StandardUniform)
627            .take(rows * cols)
628            .collect();
629        backend.from_slice2_cm(rows, cols, &real, &imag)
630    }
631
632    #[rstest::fixture]
633    fn backend() -> ArrayFireBackend<Sphere> {
634        ArrayFireBackend::set_backend(AFBackend::CPU);
635        ArrayFireBackend {
636            _phantom: std::marker::PhantomData,
637        }
638    }
639
640    #[rstest::rstest]
641    #[test]
642    #[cfg_attr(miri, ignore)]
643    fn test_alloc_v(backend: ArrayFireBackend<Sphere>) -> Result<(), HoloError> {
644        let v = backend.alloc_v(N)?;
645        let v = backend.to_host_v(v)?;
646
647        assert_eq!(N, v.len());
648        Ok(())
649    }
650
651    #[rstest::rstest]
652    #[test]
653    #[cfg_attr(miri, ignore)]
654    fn test_alloc_m(backend: ArrayFireBackend<Sphere>) -> Result<(), HoloError> {
655        let m = backend.alloc_m(N, 2 * N)?;
656        let m = backend.to_host_m(m)?;
657
658        assert_eq!(N, m.nrows());
659        assert_eq!(2 * N, m.ncols());
660        Ok(())
661    }
662
663    #[rstest::rstest]
664    #[test]
665    #[cfg_attr(miri, ignore)]
666    fn test_alloc_cv(backend: ArrayFireBackend<Sphere>) -> Result<(), HoloError> {
667        let v = backend.alloc_cv(N)?;
668        let v = backend.to_host_cv(v)?;
669
670        assert_eq!(N, v.len());
671        Ok(())
672    }
673
674    #[rstest::rstest]
675    #[test]
676    #[cfg_attr(miri, ignore)]
677    fn test_alloc_cm(backend: ArrayFireBackend<Sphere>) -> Result<(), HoloError> {
678        let m = backend.alloc_cm(N, 2 * N)?;
679        let m = backend.to_host_cm(m)?;
680
681        assert_eq!(N, m.nrows());
682        assert_eq!(2 * N, m.ncols());
683        Ok(())
684    }
685
686    #[rstest::rstest]
687    #[test]
688    #[cfg_attr(miri, ignore)]
689    fn test_alloc_zeros_v(backend: ArrayFireBackend<Sphere>) -> Result<(), HoloError> {
690        let v = backend.alloc_zeros_v(N)?;
691        let v = backend.to_host_v(v)?;
692
693        assert_eq!(N, v.len());
694        assert!(v.iter().all(|&v| v == 0.));
695        Ok(())
696    }
697
698    #[rstest::rstest]
699    #[test]
700    #[cfg_attr(miri, ignore)]
701    fn test_alloc_zeros_cv(backend: ArrayFireBackend<Sphere>) -> Result<(), HoloError> {
702        let v = backend.alloc_zeros_cv(N)?;
703        let v = backend.to_host_cv(v)?;
704
705        assert_eq!(N, v.len());
706        assert!(v.iter().all(|&v| v == Complex::new(0., 0.)));
707        Ok(())
708    }
709
710    #[rstest::rstest]
711    #[test]
712    #[cfg_attr(miri, ignore)]
713    fn test_alloc_zeros_cm(backend: ArrayFireBackend<Sphere>) -> Result<(), HoloError> {
714        let m = backend.alloc_zeros_cm(N, 2 * N)?;
715        let m = backend.to_host_cm(m)?;
716
717        assert_eq!(N, m.nrows());
718        assert_eq!(2 * N, m.ncols());
719        assert!(m.iter().all(|&v| v == Complex::new(0., 0.)));
720        Ok(())
721    }
722
723    #[rstest::rstest]
724    #[test]
725    #[cfg_attr(miri, ignore)]
726    fn test_cols_c(backend: ArrayFireBackend<Sphere>) -> Result<(), HoloError> {
727        let m = backend.alloc_cm(N, 2 * N)?;
728
729        assert_eq!(2 * N, backend.cols_c(&m)?);
730
731        Ok(())
732    }
733
734    #[rstest::rstest]
735    #[test]
736    #[cfg_attr(miri, ignore)]
737    fn test_from_slice_v(backend: ArrayFireBackend<Sphere>) -> Result<(), HoloError> {
738        let rng = rand::rng();
739
740        let v: Vec<f32> = rng
741            .sample_iter(rand::distr::StandardUniform)
742            .take(N)
743            .collect();
744
745        let c = backend.from_slice_v(&v)?;
746        let c = backend.to_host_v(c)?;
747
748        assert_eq!(N, c.len());
749        v.iter().zip(c.iter()).for_each(|(&r, &c)| {
750            assert_eq!(r, c);
751        });
752        Ok(())
753    }
754
755    #[rstest::rstest]
756    #[test]
757    #[cfg_attr(miri, ignore)]
758    fn test_from_slice_m(backend: ArrayFireBackend<Sphere>) -> Result<(), HoloError> {
759        let rng = rand::rng();
760
761        let v: Vec<f32> = rng
762            .sample_iter(rand::distr::StandardUniform)
763            .take(N * 2 * N)
764            .collect();
765
766        let c = backend.from_slice_m(N, 2 * N, &v)?;
767        let c = backend.to_host_m(c)?;
768
769        assert_eq!(N, c.nrows());
770        assert_eq!(2 * N, c.ncols());
771        (0..2 * N).for_each(|col| {
772            (0..N).for_each(|row| {
773                assert_eq!(v[col * N + row], c[(row, col)]);
774            })
775        });
776        Ok(())
777    }
778
779    #[rstest::rstest]
780    #[test]
781    #[cfg_attr(miri, ignore)]
782    fn test_from_slice_cv(backend: ArrayFireBackend<Sphere>) -> Result<(), HoloError> {
783        let rng = rand::rng();
784
785        let real: Vec<f32> = rng
786            .sample_iter(rand::distr::StandardUniform)
787            .take(N)
788            .collect();
789
790        let c = backend.from_slice_cv(&real)?;
791        let c = backend.to_host_cv(c)?;
792
793        assert_eq!(N, c.len());
794        real.iter().zip(c.iter()).for_each(|(r, c)| {
795            assert_eq!(r, &c.re);
796            assert_eq!(0.0, c.im);
797        });
798        Ok(())
799    }
800
801    #[rstest::rstest]
802    #[test]
803    #[cfg_attr(miri, ignore)]
804    fn test_from_slice2_cv(backend: ArrayFireBackend<Sphere>) -> Result<(), HoloError> {
805        let mut rng = rand::rng();
806
807        let real: Vec<f32> = (&mut rng)
808            .sample_iter(rand::distr::StandardUniform)
809            .take(N)
810            .collect();
811        let imag: Vec<f32> = (&mut rng)
812            .sample_iter(rand::distr::StandardUniform)
813            .take(N)
814            .collect();
815
816        let c = backend.from_slice2_cv(&real, &imag)?;
817        let c = backend.to_host_cv(c)?;
818
819        assert_eq!(N, c.len());
820        real.iter()
821            .zip(imag.iter())
822            .zip(c.iter())
823            .for_each(|((r, i), c)| {
824                assert_eq!(r, &c.re);
825                assert_eq!(i, &c.im);
826            });
827        Ok(())
828    }
829
830    #[rstest::rstest]
831    #[test]
832    #[cfg_attr(miri, ignore)]
833    fn test_from_slice2_cm(backend: ArrayFireBackend<Sphere>) -> Result<(), HoloError> {
834        let mut rng = rand::rng();
835
836        let real: Vec<f32> = (&mut rng)
837            .sample_iter(rand::distr::StandardUniform)
838            .take(N * 2 * N)
839            .collect();
840        let imag: Vec<f32> = (&mut rng)
841            .sample_iter(rand::distr::StandardUniform)
842            .take(N * 2 * N)
843            .collect();
844
845        let c = backend.from_slice2_cm(N, 2 * N, &real, &imag)?;
846        let c = backend.to_host_cm(c)?;
847
848        assert_eq!(N, c.nrows());
849        assert_eq!(2 * N, c.ncols());
850        (0..2 * N).for_each(|col| {
851            (0..N).for_each(|row| {
852                assert_eq!(real[col * N + row], c[(row, col)].re);
853                assert_eq!(imag[col * N + row], c[(row, col)].im);
854            })
855        });
856        Ok(())
857    }
858
859    #[rstest::rstest]
860    #[test]
861    #[cfg_attr(miri, ignore)]
862    fn test_copy_from_slice_v(backend: ArrayFireBackend<Sphere>) -> Result<(), HoloError> {
863        {
864            let mut a = backend.alloc_zeros_v(N)?;
865            let mut rng = rand::rng();
866            let v = (&mut rng)
867                .sample_iter(rand::distr::StandardUniform)
868                .take(N / 2)
869                .collect::<Vec<f32>>();
870
871            backend.copy_from_slice_v(&v, &mut a)?;
872
873            let a = backend.to_host_v(a)?;
874            (0..N / 2).for_each(|i| {
875                assert_eq!(v[i], a[i]);
876            });
877            (N / 2..N).for_each(|i| {
878                assert_eq!(0., a[i]);
879            });
880        }
881
882        {
883            let mut a = backend.alloc_zeros_v(N)?;
884            let v = [];
885
886            backend.copy_from_slice_v(&v, &mut a)?;
887
888            let a = backend.to_host_v(a)?;
889            a.iter().for_each(|&a| {
890                assert_eq!(0., a);
891            });
892        }
893
894        Ok(())
895    }
896
897    #[rstest::rstest]
898    #[test]
899    #[cfg_attr(miri, ignore)]
900    fn test_copy_to_v(backend: ArrayFireBackend<Sphere>) -> Result<(), HoloError> {
901        let a = make_random_v(&backend, N)?;
902        let mut b = backend.alloc_v(N)?;
903
904        backend.copy_to_v(&a, &mut b)?;
905
906        let a = backend.to_host_v(a)?;
907        let b = backend.to_host_v(b)?;
908        a.iter().zip(b.iter()).for_each(|(a, b)| {
909            assert_eq!(a, b);
910        });
911        Ok(())
912    }
913
914    #[rstest::rstest]
915    #[test]
916    #[cfg_attr(miri, ignore)]
917    fn test_copy_to_m(backend: ArrayFireBackend<Sphere>) -> Result<(), HoloError> {
918        let a = make_random_m(&backend, N, N)?;
919        let mut b = backend.alloc_m(N, N)?;
920
921        backend.copy_to_m(&a, &mut b)?;
922
923        let a = backend.to_host_m(a)?;
924        let b = backend.to_host_m(b)?;
925        a.iter().zip(b.iter()).for_each(|(a, b)| {
926            assert_eq!(a, b);
927        });
928        Ok(())
929    }
930
931    #[rstest::rstest]
932    #[test]
933    #[cfg_attr(miri, ignore)]
934    fn test_clone_v(backend: ArrayFireBackend<Sphere>) -> Result<(), HoloError> {
935        let c = make_random_v(&backend, N)?;
936        let c2 = backend.clone_v(&c)?;
937
938        let c = backend.to_host_v(c)?;
939        let c2 = backend.to_host_v(c2)?;
940
941        c.iter().zip(c2.iter()).for_each(|(c, c2)| {
942            assert_eq!(c, c2);
943        });
944        Ok(())
945    }
946
947    #[rstest::rstest]
948    #[test]
949    #[cfg_attr(miri, ignore)]
950    fn test_clone_m(backend: ArrayFireBackend<Sphere>) -> Result<(), HoloError> {
951        let c = make_random_m(&backend, N, N)?;
952        let c2 = backend.clone_m(&c)?;
953
954        let c = backend.to_host_m(c)?;
955        let c2 = backend.to_host_m(c2)?;
956
957        c.iter().zip(c2.iter()).for_each(|(c, c2)| {
958            assert_eq!(c, c2);
959        });
960        Ok(())
961    }
962
963    #[rstest::rstest]
964    #[test]
965    #[cfg_attr(miri, ignore)]
966    fn test_clone_cv(backend: ArrayFireBackend<Sphere>) -> Result<(), HoloError> {
967        let c = make_random_cv(&backend, N)?;
968        let c2 = backend.clone_cv(&c)?;
969
970        let c = backend.to_host_cv(c)?;
971        let c2 = backend.to_host_cv(c2)?;
972
973        c.iter().zip(c2.iter()).for_each(|(c, c2)| {
974            assert_eq!(c.re, c2.re);
975            assert_eq!(c.im, c2.im);
976        });
977        Ok(())
978    }
979
980    #[rstest::rstest]
981    #[test]
982    #[cfg_attr(miri, ignore)]
983    fn test_clone_cm(backend: ArrayFireBackend<Sphere>) -> Result<(), HoloError> {
984        let c = make_random_cm(&backend, N, N)?;
985        let c2 = backend.clone_cm(&c)?;
986
987        let c = backend.to_host_cm(c)?;
988        let c2 = backend.to_host_cm(c2)?;
989
990        c.iter().zip(c2.iter()).for_each(|(c, c2)| {
991            assert_eq!(c.re, c2.re);
992            assert_eq!(c.im, c2.im);
993        });
994        Ok(())
995    }
996
997    #[rstest::rstest]
998    #[test]
999    #[cfg_attr(miri, ignore)]
1000    fn test_make_complex2_v(backend: ArrayFireBackend<Sphere>) -> Result<(), HoloError> {
1001        let real = make_random_v(&backend, N)?;
1002        let imag = make_random_v(&backend, N)?;
1003
1004        let mut c = backend.alloc_cv(N)?;
1005        backend.make_complex2_v(&real, &imag, &mut c)?;
1006
1007        let real = backend.to_host_v(real)?;
1008        let imag = backend.to_host_v(imag)?;
1009        let c = backend.to_host_cv(c)?;
1010        real.iter()
1011            .zip(imag.iter())
1012            .zip(c.iter())
1013            .for_each(|((r, i), c)| {
1014                assert_eq!(r, &c.re);
1015                assert_eq!(i, &c.im);
1016            });
1017        Ok(())
1018    }
1019
1020    #[rstest::rstest]
1021    #[test]
1022    #[cfg_attr(miri, ignore)]
1023    fn test_create_diagonal(backend: ArrayFireBackend<Sphere>) -> Result<(), HoloError> {
1024        let diagonal = make_random_v(&backend, N)?;
1025
1026        let mut c = backend.alloc_m(N, N)?;
1027
1028        backend.create_diagonal(&diagonal, &mut c)?;
1029
1030        let diagonal = backend.to_host_v(diagonal)?;
1031        let c = backend.to_host_m(c)?;
1032        (0..N).for_each(|i| {
1033            (0..N).for_each(|j| {
1034                if i == j {
1035                    assert_eq!(diagonal[i], c[(i, j)]);
1036                } else {
1037                    assert_eq!(0.0, c[(i, j)]);
1038                }
1039            })
1040        });
1041        Ok(())
1042    }
1043
1044    #[rstest::rstest]
1045    #[test]
1046    #[cfg_attr(miri, ignore)]
1047    fn test_create_diagonal_c(backend: ArrayFireBackend<Sphere>) -> Result<(), HoloError> {
1048        let diagonal = make_random_cv(&backend, N)?;
1049
1050        let mut c = backend.alloc_cm(N, N)?;
1051
1052        backend.create_diagonal_c(&diagonal, &mut c)?;
1053
1054        let diagonal = backend.to_host_cv(diagonal)?;
1055        let c = backend.to_host_cm(c)?;
1056        (0..N).for_each(|i| {
1057            (0..N).for_each(|j| {
1058                if i == j {
1059                    assert_eq!(diagonal[i].re, c[(i, j)].re);
1060                    assert_eq!(diagonal[i].im, c[(i, j)].im);
1061                } else {
1062                    assert_eq!(0.0, c[(i, j)].re);
1063                    assert_eq!(0.0, c[(i, j)].im);
1064                }
1065            })
1066        });
1067        Ok(())
1068    }
1069
1070    #[rstest::rstest]
1071    #[test]
1072    #[cfg_attr(miri, ignore)]
1073    fn test_get_diagonal(backend: ArrayFireBackend<Sphere>) -> Result<(), HoloError> {
1074        let m = make_random_m(&backend, N, N)?;
1075        let mut diagonal = backend.alloc_v(N)?;
1076
1077        backend.get_diagonal(&m, &mut diagonal)?;
1078
1079        let m = backend.to_host_m(m)?;
1080        let diagonal = backend.to_host_v(diagonal)?;
1081        (0..N).for_each(|i| {
1082            assert_eq!(m[(i, i)], diagonal[i]);
1083        });
1084        Ok(())
1085    }
1086
1087    #[rstest::rstest]
1088    #[test]
1089    #[cfg_attr(miri, ignore)]
1090    fn test_norm_squared_cv(backend: ArrayFireBackend<Sphere>) -> Result<(), HoloError> {
1091        let v = make_random_cv(&backend, N)?;
1092
1093        let mut abs = backend.alloc_v(N)?;
1094        backend.norm_squared_cv(&v, &mut abs)?;
1095
1096        let v = backend.to_host_cv(v)?;
1097        let abs = backend.to_host_v(abs)?;
1098        v.iter().zip(abs.iter()).for_each(|(v, abs)| {
1099            assert_approx_eq::assert_approx_eq!(v.norm_squared(), abs, EPS);
1100        });
1101        Ok(())
1102    }
1103
1104    #[rstest::rstest]
1105    #[test]
1106    #[cfg_attr(miri, ignore)]
1107    fn test_real_cm(backend: ArrayFireBackend<Sphere>) -> Result<(), HoloError> {
1108        let v = make_random_cm(&backend, N, N)?;
1109        let mut r = backend.alloc_m(N, N)?;
1110
1111        backend.real_cm(&v, &mut r)?;
1112
1113        let v = backend.to_host_cm(v)?;
1114        let r = backend.to_host_m(r)?;
1115        (0..N).for_each(|i| {
1116            (0..N).for_each(|j| {
1117                assert_approx_eq::assert_approx_eq!(v[(i, j)].re, r[(i, j)], EPS);
1118            })
1119        });
1120        Ok(())
1121    }
1122
1123    #[rstest::rstest]
1124    #[test]
1125    #[cfg_attr(miri, ignore)]
1126    fn test_imag_cm(backend: ArrayFireBackend<Sphere>) -> Result<(), HoloError> {
1127        let v = make_random_cm(&backend, N, N)?;
1128        let mut r = backend.alloc_m(N, N)?;
1129
1130        backend.imag_cm(&v, &mut r)?;
1131
1132        let v = backend.to_host_cm(v)?;
1133        let r = backend.to_host_m(r)?;
1134        (0..N).for_each(|i| {
1135            (0..N).for_each(|j| {
1136                assert_approx_eq::assert_approx_eq!(v[(i, j)].im, r[(i, j)], EPS);
1137            })
1138        });
1139        Ok(())
1140    }
1141
1142    #[rstest::rstest]
1143    #[test]
1144    #[cfg_attr(miri, ignore)]
1145    fn test_scale_assign_cv(backend: ArrayFireBackend<Sphere>) -> Result<(), HoloError> {
1146        let mut v = make_random_cv(&backend, N)?;
1147        let vc = backend.clone_cv(&v)?;
1148        let mut rng = rand::rng();
1149        let scale = Complex::new(rng.random(), rng.random());
1150
1151        backend.scale_assign_cv(scale, &mut v)?;
1152
1153        let v = backend.to_host_cv(v)?;
1154        let vc = backend.to_host_cv(vc)?;
1155        v.iter().zip(vc.iter()).for_each(|(&v, &vc)| {
1156            assert_approx_eq::assert_approx_eq!(scale * vc, v, EPS);
1157        });
1158        Ok(())
1159    }
1160
1161    #[rstest::rstest]
1162    #[test]
1163    #[cfg_attr(miri, ignore)]
1164    fn test_conj_assign_v(backend: ArrayFireBackend<Sphere>) -> Result<(), HoloError> {
1165        let mut v = make_random_cv(&backend, N)?;
1166        let vc = backend.clone_cv(&v)?;
1167
1168        backend.conj_assign_v(&mut v)?;
1169
1170        let v = backend.to_host_cv(v)?;
1171        let vc = backend.to_host_cv(vc)?;
1172        v.iter().zip(vc.iter()).for_each(|(&v, &vc)| {
1173            assert_eq!(vc.re, v.re);
1174            assert_eq!(vc.im, -v.im);
1175        });
1176        Ok(())
1177    }
1178
1179    #[rstest::rstest]
1180    #[test]
1181    #[cfg_attr(miri, ignore)]
1182    fn test_exp_assign_cv(backend: ArrayFireBackend<Sphere>) -> Result<(), HoloError> {
1183        let mut v = make_random_cv(&backend, N)?;
1184        let vc = backend.clone_cv(&v)?;
1185
1186        backend.exp_assign_cv(&mut v)?;
1187
1188        let v = backend.to_host_cv(v)?;
1189        let vc = backend.to_host_cv(vc)?;
1190        v.iter().zip(vc.iter()).for_each(|(v, vc)| {
1191            assert_approx_eq::assert_approx_eq!(vc.exp(), v, EPS);
1192        });
1193        Ok(())
1194    }
1195
1196    #[rstest::rstest]
1197    #[test]
1198    #[cfg_attr(miri, ignore)]
1199    fn test_concat_col_cm(backend: ArrayFireBackend<Sphere>) -> Result<(), HoloError> {
1200        let a = make_random_cm(&backend, N, N)?;
1201        let b = make_random_cm(&backend, N, 2 * N)?;
1202        let mut c = backend.alloc_cm(N, N + 2 * N)?;
1203
1204        backend.concat_col_cm(&a, &b, &mut c)?;
1205
1206        let a = backend.to_host_cm(a)?;
1207        let b = backend.to_host_cm(b)?;
1208        let c = backend.to_host_cm(c)?;
1209        (0..N).for_each(|col| (0..N).for_each(|row| assert_eq!(a[(row, col)], c[(row, col)])));
1210        (0..2 * N)
1211            .for_each(|col| (0..N).for_each(|row| assert_eq!(b[(row, col)], c[(row, N + col)])));
1212        Ok(())
1213    }
1214
1215    #[rstest::rstest]
1216    #[test]
1217    #[cfg_attr(miri, ignore)]
1218    fn test_max_v(backend: ArrayFireBackend<Sphere>) -> Result<(), HoloError> {
1219        let v = make_random_v(&backend, N)?;
1220
1221        let max = backend.max_v(&v)?;
1222
1223        let v = backend.to_host_v(v)?;
1224        assert_eq!(
1225            *v.iter().max_by(|a, b| a.partial_cmp(b).unwrap()).unwrap(),
1226            max
1227        );
1228        Ok(())
1229    }
1230
1231    #[rstest::rstest]
1232    #[test]
1233    #[cfg_attr(miri, ignore)]
1234    fn test_hadamard_product_cm(backend: ArrayFireBackend<Sphere>) -> Result<(), HoloError> {
1235        let a = make_random_cm(&backend, N, N)?;
1236        let b = make_random_cm(&backend, N, N)?;
1237        let mut c = backend.alloc_cm(N, N)?;
1238
1239        backend.hadamard_product_cm(&a, &b, &mut c)?;
1240
1241        let a = backend.to_host_cm(a)?;
1242        let b = backend.to_host_cm(b)?;
1243        let c = backend.to_host_cm(c)?;
1244        c.iter()
1245            .zip(a.iter())
1246            .zip(b.iter())
1247            .for_each(|((c, a), b)| {
1248                assert_approx_eq::assert_approx_eq!(a.re * b.re - a.im * b.im, c.re, EPS);
1249                assert_approx_eq::assert_approx_eq!(a.re * b.im + a.im * b.re, c.im, EPS);
1250            });
1251        Ok(())
1252    }
1253
1254    #[rstest::rstest]
1255    #[test]
1256    #[cfg_attr(miri, ignore)]
1257    fn test_dot(backend: ArrayFireBackend<Sphere>) -> Result<(), HoloError> {
1258        let a = make_random_v(&backend, N)?;
1259        let b = make_random_v(&backend, N)?;
1260
1261        let dot = backend.dot(&a, &b)?;
1262
1263        let a = backend.to_host_v(a)?;
1264        let b = backend.to_host_v(b)?;
1265        let expect = a.iter().zip(b.iter()).map(|(a, b)| a * b).sum::<f32>();
1266        assert_approx_eq::assert_approx_eq!(dot, expect, EPS);
1267        Ok(())
1268    }
1269
1270    #[rstest::rstest]
1271    #[test]
1272    #[cfg_attr(miri, ignore)]
1273    fn test_dot_c(backend: ArrayFireBackend<Sphere>) -> Result<(), HoloError> {
1274        let a = make_random_cv(&backend, N)?;
1275        let b = make_random_cv(&backend, N)?;
1276
1277        let dot = backend.dot_c(&a, &b)?;
1278
1279        let a = backend.to_host_cv(a)?;
1280        let b = backend.to_host_cv(b)?;
1281        let expect = a
1282            .iter()
1283            .zip(b.iter())
1284            .map(|(a, b)| a.conj() * b)
1285            .sum::<Complex>();
1286        assert_approx_eq::assert_approx_eq!(dot.re, expect.re, EPS);
1287        assert_approx_eq::assert_approx_eq!(dot.im, expect.im, EPS);
1288        Ok(())
1289    }
1290
1291    #[rstest::rstest]
1292    #[test]
1293    #[cfg_attr(miri, ignore)]
1294    fn test_add_v(backend: ArrayFireBackend<Sphere>) -> Result<(), HoloError> {
1295        let a = make_random_v(&backend, N)?;
1296        let mut b = make_random_v(&backend, N)?;
1297        let bc = backend.clone_v(&b)?;
1298
1299        let mut rng = rand::rng();
1300        let alpha = rng.random();
1301
1302        backend.add_v(alpha, &a, &mut b)?;
1303
1304        let a = backend.to_host_v(a)?;
1305        let b = backend.to_host_v(b)?;
1306        let bc = backend.to_host_v(bc)?;
1307        b.iter()
1308            .zip(a.iter())
1309            .zip(bc.iter())
1310            .for_each(|((b, a), bc)| {
1311                assert_approx_eq::assert_approx_eq!(alpha * a + bc, b, EPS);
1312            });
1313        Ok(())
1314    }
1315
1316    #[rstest::rstest]
1317    #[test]
1318    #[cfg_attr(miri, ignore)]
1319    fn test_add_m(backend: ArrayFireBackend<Sphere>) -> Result<(), HoloError> {
1320        let a = make_random_m(&backend, N, N)?;
1321        let mut b = make_random_m(&backend, N, N)?;
1322        let bc = backend.clone_m(&b)?;
1323
1324        let mut rng = rand::rng();
1325        let alpha = rng.random();
1326
1327        backend.add_m(alpha, &a, &mut b)?;
1328
1329        let a = backend.to_host_m(a)?;
1330        let b = backend.to_host_m(b)?;
1331        let bc = backend.to_host_m(bc)?;
1332        b.iter()
1333            .zip(a.iter())
1334            .zip(bc.iter())
1335            .for_each(|((b, a), bc)| {
1336                assert_approx_eq::assert_approx_eq!(alpha * a + bc, b, EPS);
1337            });
1338        Ok(())
1339    }
1340
1341    #[rstest::rstest]
1342    #[test]
1343    #[cfg_attr(miri, ignore)]
1344    fn test_gevv_c(backend: ArrayFireBackend<Sphere>) -> Result<(), HoloError> {
1345        let mut rng = rand::rng();
1346
1347        {
1348            let a = make_random_cv(&backend, N)?;
1349            let b = make_random_cv(&backend, N)?;
1350            let mut c = make_random_cm(&backend, N, N)?;
1351            let cc = backend.clone_cm(&c)?;
1352
1353            let alpha = Complex::new(rng.random(), rng.random());
1354            let beta = Complex::new(rng.random(), rng.random());
1355            backend.gevv_c(Trans::NoTrans, Trans::Trans, alpha, &a, &b, beta, &mut c)?;
1356
1357            let a = backend.to_host_cv(a)?;
1358            let b = backend.to_host_cv(b)?;
1359            let c = backend.to_host_cm(c)?;
1360            let cc = backend.to_host_cm(cc)?;
1361            let expected = a * b.transpose() * alpha + cc * beta;
1362            c.iter().zip(expected.iter()).for_each(|(c, expected)| {
1363                assert_approx_eq::assert_approx_eq!(c.re, expected.re, EPS);
1364                assert_approx_eq::assert_approx_eq!(c.im, expected.im, EPS);
1365            });
1366        }
1367
1368        {
1369            let a = make_random_cv(&backend, N)?;
1370            let b = make_random_cv(&backend, N)?;
1371            let mut c = make_random_cm(&backend, N, N)?;
1372            let cc = backend.clone_cm(&c)?;
1373
1374            let alpha = Complex::new(rng.random(), rng.random());
1375            let beta = Complex::new(rng.random(), rng.random());
1376            backend.gevv_c(
1377                Trans::NoTrans,
1378                Trans::ConjTrans,
1379                alpha,
1380                &a,
1381                &b,
1382                beta,
1383                &mut c,
1384            )?;
1385
1386            let a = backend.to_host_cv(a)?;
1387            let b = backend.to_host_cv(b)?;
1388            let c = backend.to_host_cm(c)?;
1389            let cc = backend.to_host_cm(cc)?;
1390            let expected = a * b.adjoint() * alpha + cc * beta;
1391            c.iter().zip(expected.iter()).for_each(|(c, expected)| {
1392                assert_approx_eq::assert_approx_eq!(c.re, expected.re, EPS);
1393                assert_approx_eq::assert_approx_eq!(c.im, expected.im, EPS);
1394            });
1395        }
1396
1397        {
1398            let a = make_random_cv(&backend, N)?;
1399            let b = make_random_cv(&backend, N)?;
1400            let mut c = make_random_cm(&backend, 1, 1)?;
1401            let cc = backend.clone_cm(&c)?;
1402
1403            let alpha = Complex::new(rng.random(), rng.random());
1404            let beta = Complex::new(rng.random(), rng.random());
1405            backend.gevv_c(Trans::Trans, Trans::NoTrans, alpha, &a, &b, beta, &mut c)?;
1406
1407            let a = backend.to_host_cv(a)?;
1408            let b = backend.to_host_cv(b)?;
1409            let c = backend.to_host_cm(c)?;
1410            let cc = backend.to_host_cm(cc)?;
1411            let expected = a.transpose() * b * alpha + cc * beta;
1412            c.iter().zip(expected.iter()).for_each(|(c, expected)| {
1413                assert_approx_eq::assert_approx_eq!(c.re, expected.re, EPS);
1414                assert_approx_eq::assert_approx_eq!(c.im, expected.im, EPS);
1415            });
1416        }
1417
1418        {
1419            let a = make_random_cv(&backend, N)?;
1420            let b = make_random_cv(&backend, N)?;
1421            let mut c = make_random_cm(&backend, 1, 1)?;
1422            let cc = backend.clone_cm(&c)?;
1423
1424            let alpha = Complex::new(rng.random(), rng.random());
1425            let beta = Complex::new(rng.random(), rng.random());
1426            backend.gevv_c(
1427                Trans::ConjTrans,
1428                Trans::NoTrans,
1429                alpha,
1430                &a,
1431                &b,
1432                beta,
1433                &mut c,
1434            )?;
1435
1436            let a = backend.to_host_cv(a)?;
1437            let b = backend.to_host_cv(b)?;
1438            let c = backend.to_host_cm(c)?;
1439            let cc = backend.to_host_cm(cc)?;
1440            let expected = a.adjoint() * b * alpha + cc * beta;
1441            c.iter().zip(expected.iter()).for_each(|(c, expected)| {
1442                assert_approx_eq::assert_approx_eq!(c.re, expected.re, EPS);
1443                assert_approx_eq::assert_approx_eq!(c.im, expected.im, EPS);
1444            });
1445        }
1446
1447        Ok(())
1448    }
1449
1450    #[rstest::rstest]
1451    #[test]
1452    #[cfg_attr(miri, ignore)]
1453    fn test_gemv_c(backend: ArrayFireBackend<Sphere>) -> Result<(), HoloError> {
1454        let m = N;
1455        let n = 2 * N;
1456
1457        let mut rng = rand::rng();
1458
1459        {
1460            let a = make_random_cm(&backend, m, n)?;
1461            let b = make_random_cv(&backend, n)?;
1462            let mut c = make_random_cv(&backend, m)?;
1463            let cc = backend.clone_cv(&c)?;
1464
1465            let alpha = Complex::new(rng.random(), rng.random());
1466            let beta = Complex::new(rng.random(), rng.random());
1467            backend.gemv_c(Trans::NoTrans, alpha, &a, &b, beta, &mut c)?;
1468
1469            let a = backend.to_host_cm(a)?;
1470            let b = backend.to_host_cv(b)?;
1471            let c = backend.to_host_cv(c)?;
1472            let cc = backend.to_host_cv(cc)?;
1473            let expected = a * b * alpha + cc * beta;
1474            c.iter().zip(expected.iter()).for_each(|(c, expected)| {
1475                assert_approx_eq::assert_approx_eq!(c.re, expected.re, EPS);
1476                assert_approx_eq::assert_approx_eq!(c.im, expected.im, EPS);
1477            });
1478        }
1479
1480        {
1481            let a = make_random_cm(&backend, n, m)?;
1482            let b = make_random_cv(&backend, n)?;
1483            let mut c = make_random_cv(&backend, m)?;
1484            let cc = backend.clone_cv(&c)?;
1485
1486            let alpha = Complex::new(rng.random(), rng.random());
1487            let beta = Complex::new(rng.random(), rng.random());
1488            backend.gemv_c(Trans::Trans, alpha, &a, &b, beta, &mut c)?;
1489
1490            let a = backend.to_host_cm(a)?;
1491            let b = backend.to_host_cv(b)?;
1492            let c = backend.to_host_cv(c)?;
1493            let cc = backend.to_host_cv(cc)?;
1494            let expected = a.transpose() * b * alpha + cc * beta;
1495            c.iter().zip(expected.iter()).for_each(|(c, expected)| {
1496                assert_approx_eq::assert_approx_eq!(c.re, expected.re, EPS);
1497                assert_approx_eq::assert_approx_eq!(c.im, expected.im, EPS);
1498            });
1499        }
1500
1501        {
1502            let a = make_random_cm(&backend, n, m)?;
1503            let b = make_random_cv(&backend, n)?;
1504            let mut c = make_random_cv(&backend, m)?;
1505            let cc = backend.clone_cv(&c)?;
1506
1507            let alpha = Complex::new(rng.random(), rng.random());
1508            let beta = Complex::new(rng.random(), rng.random());
1509            backend.gemv_c(Trans::ConjTrans, alpha, &a, &b, beta, &mut c)?;
1510
1511            let a = backend.to_host_cm(a)?;
1512            let b = backend.to_host_cv(b)?;
1513            let c = backend.to_host_cv(c)?;
1514            let cc = backend.to_host_cv(cc)?;
1515            let expected = a.adjoint() * b * alpha + cc * beta;
1516            c.iter().zip(expected.iter()).for_each(|(c, expected)| {
1517                assert_approx_eq::assert_approx_eq!(c.re, expected.re, EPS);
1518                assert_approx_eq::assert_approx_eq!(c.im, expected.im, EPS);
1519            });
1520        }
1521        Ok(())
1522    }
1523
1524    #[rstest::rstest]
1525    #[test]
1526    #[cfg_attr(miri, ignore)]
1527    fn test_gemm_c(backend: ArrayFireBackend<Sphere>) -> Result<(), HoloError> {
1528        let m = N;
1529        let n = 2 * N;
1530        let k = 3 * N;
1531
1532        let mut rng = rand::rng();
1533
1534        {
1535            let a = make_random_cm(&backend, m, k)?;
1536            let b = make_random_cm(&backend, k, n)?;
1537            let mut c = make_random_cm(&backend, m, n)?;
1538            let cc = backend.clone_cm(&c)?;
1539
1540            let alpha = Complex::new(rng.random(), rng.random());
1541            let beta = Complex::new(rng.random(), rng.random());
1542            backend.gemm_c(Trans::NoTrans, Trans::NoTrans, alpha, &a, &b, beta, &mut c)?;
1543
1544            let a = backend.to_host_cm(a)?;
1545            let b = backend.to_host_cm(b)?;
1546            let c = backend.to_host_cm(c)?;
1547            let cc = backend.to_host_cm(cc)?;
1548            let expected = a * b * alpha + cc * beta;
1549            c.iter().zip(expected.iter()).for_each(|(c, expected)| {
1550                assert_approx_eq::assert_approx_eq!(c.re, expected.re, EPS);
1551                assert_approx_eq::assert_approx_eq!(c.im, expected.im, EPS);
1552            });
1553        }
1554
1555        {
1556            let a = make_random_cm(&backend, m, k)?;
1557            let b = make_random_cm(&backend, n, k)?;
1558            let mut c = make_random_cm(&backend, m, n)?;
1559            let cc = backend.clone_cm(&c)?;
1560
1561            let alpha = Complex::new(rng.random(), rng.random());
1562            let beta = Complex::new(rng.random(), rng.random());
1563            backend.gemm_c(Trans::NoTrans, Trans::Trans, alpha, &a, &b, beta, &mut c)?;
1564
1565            let a = backend.to_host_cm(a)?;
1566            let b = backend.to_host_cm(b)?;
1567            let c = backend.to_host_cm(c)?;
1568            let cc = backend.to_host_cm(cc)?;
1569            let expected = a * b.transpose() * alpha + cc * beta;
1570            c.iter().zip(expected.iter()).for_each(|(c, expected)| {
1571                assert_approx_eq::assert_approx_eq!(c.re, expected.re, EPS);
1572                assert_approx_eq::assert_approx_eq!(c.im, expected.im, EPS);
1573            });
1574        }
1575
1576        {
1577            let a = make_random_cm(&backend, m, k)?;
1578            let b = make_random_cm(&backend, n, k)?;
1579            let mut c = make_random_cm(&backend, m, n)?;
1580            let cc = backend.clone_cm(&c)?;
1581
1582            let alpha = Complex::new(rng.random(), rng.random());
1583            let beta = Complex::new(rng.random(), rng.random());
1584            backend.gemm_c(
1585                Trans::NoTrans,
1586                Trans::ConjTrans,
1587                alpha,
1588                &a,
1589                &b,
1590                beta,
1591                &mut c,
1592            )?;
1593
1594            let a = backend.to_host_cm(a)?;
1595            let b = backend.to_host_cm(b)?;
1596            let c = backend.to_host_cm(c)?;
1597            let cc = backend.to_host_cm(cc)?;
1598            let expected = a * b.adjoint() * alpha + cc * beta;
1599            c.iter().zip(expected.iter()).for_each(|(c, expected)| {
1600                assert_approx_eq::assert_approx_eq!(c.re, expected.re, EPS);
1601                assert_approx_eq::assert_approx_eq!(c.im, expected.im, EPS);
1602            });
1603        }
1604
1605        {
1606            let a = make_random_cm(&backend, k, m)?;
1607            let b = make_random_cm(&backend, k, n)?;
1608            let mut c = make_random_cm(&backend, m, n)?;
1609            let cc = backend.clone_cm(&c)?;
1610
1611            let alpha = Complex::new(rng.random(), rng.random());
1612            let beta = Complex::new(rng.random(), rng.random());
1613            backend.gemm_c(Trans::Trans, Trans::NoTrans, alpha, &a, &b, beta, &mut c)?;
1614
1615            let a = backend.to_host_cm(a)?;
1616            let b = backend.to_host_cm(b)?;
1617            let c = backend.to_host_cm(c)?;
1618            let cc = backend.to_host_cm(cc)?;
1619            let expected = a.transpose() * b * alpha + cc * beta;
1620            c.iter().zip(expected.iter()).for_each(|(c, expected)| {
1621                assert_approx_eq::assert_approx_eq!(c.re, expected.re, EPS);
1622                assert_approx_eq::assert_approx_eq!(c.im, expected.im, EPS);
1623            });
1624        }
1625
1626        {
1627            let a = make_random_cm(&backend, k, m)?;
1628            let b = make_random_cm(&backend, n, k)?;
1629            let mut c = make_random_cm(&backend, m, n)?;
1630            let cc = backend.clone_cm(&c)?;
1631
1632            let alpha = Complex::new(rng.random(), rng.random());
1633            let beta = Complex::new(rng.random(), rng.random());
1634            backend.gemm_c(Trans::Trans, Trans::Trans, alpha, &a, &b, beta, &mut c)?;
1635
1636            let a = backend.to_host_cm(a)?;
1637            let b = backend.to_host_cm(b)?;
1638            let c = backend.to_host_cm(c)?;
1639            let cc = backend.to_host_cm(cc)?;
1640            let expected = a.transpose() * b.transpose() * alpha + cc * beta;
1641            c.iter().zip(expected.iter()).for_each(|(c, expected)| {
1642                assert_approx_eq::assert_approx_eq!(c.re, expected.re, EPS);
1643                assert_approx_eq::assert_approx_eq!(c.im, expected.im, EPS);
1644            });
1645        }
1646
1647        {
1648            let a = make_random_cm(&backend, k, m)?;
1649            let b = make_random_cm(&backend, n, k)?;
1650            let mut c = make_random_cm(&backend, m, n)?;
1651            let cc = backend.clone_cm(&c)?;
1652
1653            let alpha = Complex::new(rng.random(), rng.random());
1654            let beta = Complex::new(rng.random(), rng.random());
1655            backend.gemm_c(Trans::Trans, Trans::ConjTrans, alpha, &a, &b, beta, &mut c)?;
1656
1657            let a = backend.to_host_cm(a)?;
1658            let b = backend.to_host_cm(b)?;
1659            let c = backend.to_host_cm(c)?;
1660            let cc = backend.to_host_cm(cc)?;
1661            let expected = a.transpose() * b.adjoint() * alpha + cc * beta;
1662            c.iter().zip(expected.iter()).for_each(|(c, expected)| {
1663                assert_approx_eq::assert_approx_eq!(c.re, expected.re, EPS);
1664                assert_approx_eq::assert_approx_eq!(c.im, expected.im, EPS);
1665            });
1666        }
1667
1668        {
1669            let a = make_random_cm(&backend, k, m)?;
1670            let b = make_random_cm(&backend, k, n)?;
1671            let mut c = make_random_cm(&backend, m, n)?;
1672            let cc = backend.clone_cm(&c)?;
1673
1674            let alpha = Complex::new(rng.random(), rng.random());
1675            let beta = Complex::new(rng.random(), rng.random());
1676            backend.gemm_c(
1677                Trans::ConjTrans,
1678                Trans::NoTrans,
1679                alpha,
1680                &a,
1681                &b,
1682                beta,
1683                &mut c,
1684            )?;
1685
1686            let a = backend.to_host_cm(a)?;
1687            let b = backend.to_host_cm(b)?;
1688            let c = backend.to_host_cm(c)?;
1689            let cc = backend.to_host_cm(cc)?;
1690            let expected = a.adjoint() * b * alpha + cc * beta;
1691            c.iter().zip(expected.iter()).for_each(|(c, expected)| {
1692                assert_approx_eq::assert_approx_eq!(c.re, expected.re, EPS);
1693                assert_approx_eq::assert_approx_eq!(c.im, expected.im, EPS);
1694            });
1695        }
1696
1697        {
1698            let a = make_random_cm(&backend, k, m)?;
1699            let b = make_random_cm(&backend, n, k)?;
1700            let mut c = make_random_cm(&backend, m, n)?;
1701            let cc = backend.clone_cm(&c)?;
1702
1703            let alpha = Complex::new(rng.random(), rng.random());
1704            let beta = Complex::new(rng.random(), rng.random());
1705            backend.gemm_c(Trans::ConjTrans, Trans::Trans, alpha, &a, &b, beta, &mut c)?;
1706
1707            let a = backend.to_host_cm(a)?;
1708            let b = backend.to_host_cm(b)?;
1709            let c = backend.to_host_cm(c)?;
1710            let cc = backend.to_host_cm(cc)?;
1711            let expected = a.adjoint() * b.transpose() * alpha + cc * beta;
1712            c.iter().zip(expected.iter()).for_each(|(c, expected)| {
1713                assert_approx_eq::assert_approx_eq!(c.re, expected.re, EPS);
1714                assert_approx_eq::assert_approx_eq!(c.im, expected.im, EPS);
1715            });
1716        }
1717
1718        {
1719            let a = make_random_cm(&backend, k, m)?;
1720            let b = make_random_cm(&backend, n, k)?;
1721            let mut c = make_random_cm(&backend, m, n)?;
1722            let cc = backend.clone_cm(&c)?;
1723
1724            let alpha = Complex::new(rng.random(), rng.random());
1725            let beta = Complex::new(rng.random(), rng.random());
1726            backend.gemm_c(
1727                Trans::ConjTrans,
1728                Trans::ConjTrans,
1729                alpha,
1730                &a,
1731                &b,
1732                beta,
1733                &mut c,
1734            )?;
1735
1736            let a = backend.to_host_cm(a)?;
1737            let b = backend.to_host_cm(b)?;
1738            let c = backend.to_host_cm(c)?;
1739            let cc = backend.to_host_cm(cc)?;
1740            let expected = a.adjoint() * b.adjoint() * alpha + cc * beta;
1741            c.iter().zip(expected.iter()).for_each(|(c, expected)| {
1742                assert_approx_eq::assert_approx_eq!(c.re, expected.re, EPS);
1743                assert_approx_eq::assert_approx_eq!(c.im, expected.im, EPS);
1744            });
1745        }
1746        Ok(())
1747    }
1748
1749    #[rstest::rstest]
1750    #[test]
1751    #[cfg_attr(miri, ignore)]
1752    fn test_solve_inplace(backend: ArrayFireBackend<Sphere>) -> Result<(), HoloError> {
1753        {
1754            let tmp = make_random_m(&backend, N, N)?;
1755            let tmp = backend.to_host_m(tmp)?;
1756
1757            let a = &tmp * tmp.adjoint();
1758
1759            let mut rng = rand::rng();
1760            let x = VectorX::from_iterator(N, (0..N).map(|_| rng.random()));
1761
1762            let b = &a * &x;
1763
1764            let aa = backend.from_slice_m(N, N, a.as_slice())?;
1765            let mut bb = backend.from_slice_v(b.as_slice())?;
1766
1767            backend.solve_inplace(&aa, &mut bb)?;
1768
1769            let b2 = &a * backend.to_host_v(bb)?;
1770            assert!(approx::relative_eq!(b, b2, epsilon = 1e-3));
1771        }
1772
1773        Ok(())
1774    }
1775
1776    #[rstest::rstest]
1777    #[test]
1778    #[cfg_attr(miri, ignore)]
1779    fn test_reduce_col(backend: ArrayFireBackend<Sphere>) -> Result<(), HoloError> {
1780        let a = make_random_m(&backend, N, N)?;
1781
1782        let mut b = backend.alloc_v(N)?;
1783
1784        backend.reduce_col(&a, &mut b)?;
1785
1786        let a = backend.to_host_m(a)?;
1787        let b = backend.to_host_v(b)?;
1788
1789        (0..N).for_each(|row| {
1790            let sum = a.row(row).iter().sum::<f32>();
1791            assert_approx_eq::assert_approx_eq!(sum, b[row], EPS);
1792        });
1793        Ok(())
1794    }
1795
1796    #[rstest::rstest]
1797    #[test]
1798    #[cfg_attr(miri, ignore)]
1799    fn test_scaled_to_cv(backend: ArrayFireBackend<Sphere>) -> Result<(), HoloError> {
1800        let a = make_random_cv(&backend, N)?;
1801        let b = make_random_cv(&backend, N)?;
1802        let mut c = backend.alloc_cv(N)?;
1803
1804        backend.scaled_to_cv(&a, &b, &mut c)?;
1805
1806        let a = backend.to_host_cv(a)?;
1807        let b = backend.to_host_cv(b)?;
1808        let c = backend.to_host_cv(c)?;
1809        c.iter()
1810            .zip(a.iter())
1811            .zip(b.iter())
1812            .for_each(|((&c, &a), &b)| {
1813                assert_approx_eq::assert_approx_eq!(c, a / a.abs() * b, EPS);
1814            });
1815
1816        Ok(())
1817    }
1818
1819    #[rstest::rstest]
1820    #[test]
1821    #[cfg_attr(miri, ignore)]
1822    fn test_scaled_to_assign_cv(backend: ArrayFireBackend<Sphere>) -> Result<(), HoloError> {
1823        let a = make_random_cv(&backend, N)?;
1824        let mut b = make_random_cv(&backend, N)?;
1825        let bc = backend.clone_cv(&b)?;
1826
1827        backend.scaled_to_assign_cv(&a, &mut b)?;
1828
1829        let a = backend.to_host_cv(a)?;
1830        let b = backend.to_host_cv(b)?;
1831        let bc = backend.to_host_cv(bc)?;
1832        b.iter()
1833            .zip(a.iter())
1834            .zip(bc.iter())
1835            .for_each(|((&b, &a), &bc)| {
1836                assert_approx_eq::assert_approx_eq!(b, bc / bc.abs() * a, EPS);
1837            });
1838
1839        Ok(())
1840    }
1841
1842    #[rstest::rstest]
1843    #[test]
1844    #[case(1, 2)]
1845    #[case(2, 1)]
1846    fn test_generate_propagation_matrix(
1847        #[case] dev_num: usize,
1848        #[case] foci_num: usize,
1849        backend: ArrayFireBackend<Sphere>,
1850    ) -> Result<(), HoloError> {
1851        let reference = |geometry: Geometry, foci: Vec<Point3>| {
1852            let mut g = MatrixXc::zeros(
1853                foci.len(),
1854                geometry
1855                    .iter()
1856                    .map(|dev| dev.num_transducers())
1857                    .sum::<usize>(),
1858            );
1859            let transducers = geometry
1860                .iter()
1861                .flat_map(|dev| dev.iter().map(|tr| (dev.idx(), tr)))
1862                .collect::<Vec<_>>();
1863            (0..foci.len()).for_each(|i| {
1864                (0..transducers.len()).for_each(|j| {
1865                    g[(i, j)] = propagate::<Sphere>(
1866                        transducers[j].1,
1867                        geometry[transducers[j].0].wavenumber(),
1868                        geometry[transducers[j].0].axial_direction(),
1869                        &foci[i],
1870                    )
1871                })
1872            });
1873            g
1874        };
1875
1876        let geometry = generate_geometry(dev_num);
1877        let foci = gen_foci(foci_num).map(|(p, _)| p).collect::<Vec<_>>();
1878
1879        let g = backend.generate_propagation_matrix(&geometry, &foci, None)?;
1880        let g = backend.to_host_cm(g)?;
1881        reference(geometry, foci)
1882            .iter()
1883            .zip(g.iter())
1884            .for_each(|(r, g)| {
1885                assert_approx_eq::assert_approx_eq!(r.re, g.re, EPS);
1886                assert_approx_eq::assert_approx_eq!(r.im, g.im, EPS);
1887            });
1888
1889        Ok(())
1890    }
1891
1892    #[rstest::rstest]
1893    #[test]
1894    #[case(1, 2)]
1895    #[case(2, 1)]
1896    fn test_generate_propagation_matrix_with_filter(
1897        #[case] dev_num: usize,
1898        #[case] foci_num: usize,
1899        backend: ArrayFireBackend<Sphere>,
1900    ) -> Result<(), HoloError> {
1901        use std::collections::HashMap;
1902
1903        let filter = |geometry: &Geometry| {
1904            geometry
1905                .iter()
1906                .map(|dev| {
1907                    let mut filter = BitVec::new();
1908                    dev.iter().for_each(|tr| {
1909                        filter.push(tr.idx() > dev.num_transducers() / 2);
1910                    });
1911                    (dev.idx(), filter)
1912                })
1913                .collect::<HashMap<_, _>>()
1914        };
1915
1916        let reference = |geometry, foci: Vec<Point3>| {
1917            let filter = filter(&geometry);
1918            let transducers = geometry
1919                .iter()
1920                .flat_map(|dev| {
1921                    dev.iter().filter_map(|tr| {
1922                        if filter[&dev.idx()][tr.idx()] {
1923                            Some((dev.idx(), tr))
1924                        } else {
1925                            None
1926                        }
1927                    })
1928                })
1929                .collect::<Vec<_>>();
1930
1931            let mut g = MatrixXc::zeros(foci.len(), transducers.len());
1932            (0..foci.len()).for_each(|i| {
1933                (0..transducers.len()).for_each(|j| {
1934                    g[(i, j)] = propagate::<Sphere>(
1935                        transducers[j].1,
1936                        geometry[transducers[j].0].wavenumber(),
1937                        geometry[transducers[j].0].axial_direction(),
1938                        &foci[i],
1939                    )
1940                })
1941            });
1942            g
1943        };
1944
1945        let geometry = generate_geometry(dev_num);
1946        let foci = gen_foci(foci_num).map(|(p, _)| p).collect::<Vec<_>>();
1947        let filter = filter(&geometry);
1948
1949        let g = backend.generate_propagation_matrix(&geometry, &foci, Some(&filter))?;
1950        let g = backend.to_host_cm(g)?;
1951        assert_eq!(g.nrows(), foci.len());
1952        assert_eq!(
1953            g.ncols(),
1954            geometry
1955                .iter()
1956                .map(|dev| dev.num_transducers() / 2)
1957                .sum::<usize>()
1958        );
1959        reference(geometry, foci)
1960            .iter()
1961            .zip(g.iter())
1962            .for_each(|(r, g)| {
1963                assert_approx_eq::assert_approx_eq!(r.re, g.re, EPS);
1964                assert_approx_eq::assert_approx_eq!(r.im, g.im, EPS);
1965            });
1966
1967        Ok(())
1968    }
1969
1970    #[rstest::rstest]
1971    #[test]
1972    fn test_gen_back_prop(backend: ArrayFireBackend<Sphere>) -> Result<(), HoloError> {
1973        let geometry = generate_geometry(1);
1974        let foci = gen_foci(2).map(|(p, _)| p).collect::<Vec<_>>();
1975
1976        let m = geometry
1977            .iter()
1978            .map(|dev| dev.num_transducers())
1979            .sum::<usize>();
1980        let n = foci.len();
1981
1982        let g = backend.generate_propagation_matrix(&geometry, &foci, None)?;
1983
1984        let b = backend.gen_back_prop(m, n, &g)?;
1985        let g = backend.to_host_cm(g)?;
1986        let reference = {
1987            let mut b = MatrixXc::zeros(m, n);
1988            (0..n).for_each(|i| {
1989                let x = 1.0 / g.rows(i, 1).iter().map(|x| x.norm_sqr()).sum::<f32>();
1990                (0..m).for_each(|j| {
1991                    b[(j, i)] = g[(i, j)].conj() * x;
1992                })
1993            });
1994            b
1995        };
1996
1997        let b = backend.to_host_cm(b)?;
1998        reference.iter().zip(b.iter()).for_each(|(r, b)| {
1999            assert_approx_eq::assert_approx_eq!(r.re, b.re, EPS);
2000            assert_approx_eq::assert_approx_eq!(r.im, b.im, EPS);
2001        });
2002        Ok(())
2003    }
2004}