autd3_backend_arrayfire/
lib.rs

1#![allow(unknown_lints)]
2#![allow(clippy::manual_slice_size_calculation)]
3
4use arrayfire::*;
5
6use autd3_core::{
7    acoustics::{
8        directivity::{Directivity, Sphere},
9        propagate,
10    },
11    environment::Environment,
12    gain::TransducerFilter,
13    geometry::Geometry,
14};
15use autd3_gain_holo::{
16    Complex, HoloError, LinAlgBackend, MatrixX, MatrixXc, Trans, VectorX, VectorXc,
17};
18
19pub type AFBackend = arrayfire::Backend;
20pub type AFDeviceInfo = (String, String, String, String);
21
22fn convert(trans: Trans) -> MatProp {
23    match trans {
24        Trans::NoTrans => MatProp::NONE,
25        Trans::Trans => MatProp::TRANS,
26        Trans::ConjTrans => MatProp::CTRANS,
27    }
28}
29
30pub struct ArrayFireBackend<D: Directivity> {
31    _phantom: std::marker::PhantomData<D>,
32}
33
34impl ArrayFireBackend<Sphere> {
35    pub fn get_available_backends() -> Vec<AFBackend> {
36        arrayfire::get_available_backends()
37    }
38
39    pub fn set_backend(backend: AFBackend) {
40        arrayfire::set_backend(backend);
41    }
42
43    pub fn set_device(device: i32) {
44        arrayfire::set_device(device);
45    }
46
47    pub fn get_available_devices() -> Vec<AFDeviceInfo> {
48        let cur_dev = arrayfire::get_device();
49        let r = (0..arrayfire::device_count())
50            .map(|i| {
51                arrayfire::set_device(i);
52                arrayfire::device_info()
53            })
54            .collect();
55        arrayfire::set_device(cur_dev);
56        r
57    }
58}
59
60impl Default for ArrayFireBackend<Sphere> {
61    fn default() -> Self {
62        Self {
63            _phantom: Default::default(),
64        }
65    }
66}
67
68impl<D: Directivity> ArrayFireBackend<D> {
69    pub fn new() -> Self {
70        Self {
71            _phantom: std::marker::PhantomData,
72        }
73    }
74}
75
76impl<D: Directivity> LinAlgBackend<D> for ArrayFireBackend<D> {
77    type MatrixXc = Array<c32>;
78    type MatrixX = Array<f32>;
79    type VectorXc = Array<c32>;
80    type VectorX = Array<f32>;
81
82    fn generate_propagation_matrix(
83        &self,
84        geometry: &Geometry,
85        env: &Environment,
86        foci: &[autd3_core::geometry::Point3],
87        filter: &TransducerFilter,
88    ) -> Result<Self::MatrixXc, HoloError> {
89        let g = if filter.is_all_enabled() {
90            geometry
91                .iter()
92                .flat_map(|dev| {
93                    dev.iter().flat_map(move |tr| {
94                        foci.iter().map(move |fp| {
95                            propagate::<D>(tr, env.wavenumber(), dev.axial_direction(), fp)
96                        })
97                    })
98                })
99                .collect::<Vec<_>>()
100        } else {
101            geometry
102                .iter()
103                .filter(|dev| filter.is_enabled_device(dev))
104                .flat_map(|dev| {
105                    dev.iter()
106                        .filter(|tr| filter.is_enabled(tr))
107                        .map(move |tr| {
108                            foci.iter().map(move |fp| {
109                                propagate::<D>(tr, env.wavenumber(), dev.axial_direction(), fp)
110                            })
111                        })
112                })
113                .flatten()
114                .collect::<Vec<_>>()
115        };
116
117        unsafe {
118            Ok(Array::new(
119                std::slice::from_raw_parts(g.as_ptr() as *const c32, g.len()),
120                Dim4::new(&[foci.len() as u64, (g.len() / foci.len()) as _, 1, 1]),
121            ))
122        }
123    }
124
125    fn alloc_v(&self, size: usize) -> Result<Self::VectorX, HoloError> {
126        Ok(Array::new_empty(Dim4::new(&[size as _, 1, 1, 1])))
127    }
128
129    fn alloc_m(&self, rows: usize, cols: usize) -> Result<Self::MatrixX, HoloError> {
130        Ok(Array::new_empty(Dim4::new(&[rows as _, cols as _, 1, 1])))
131    }
132
133    fn alloc_cv(&self, size: usize) -> Result<Self::VectorXc, HoloError> {
134        Ok(Array::new_empty(Dim4::new(&[size as _, 1, 1, 1])))
135    }
136
137    fn alloc_cm(&self, rows: usize, cols: usize) -> Result<Self::MatrixXc, HoloError> {
138        Ok(Array::new_empty(Dim4::new(&[rows as _, cols as _, 1, 1])))
139    }
140
141    fn alloc_zeros_v(&self, size: usize) -> Result<Self::VectorX, HoloError> {
142        Ok(arrayfire::constant(0., Dim4::new(&[size as _, 1, 1, 1])))
143    }
144
145    fn alloc_zeros_cv(&self, size: usize) -> Result<Self::VectorXc, HoloError> {
146        Ok(arrayfire::constant(
147            c32::new(0., 0.),
148            Dim4::new(&[size as _, 1, 1, 1]),
149        ))
150    }
151
152    fn alloc_zeros_cm(&self, rows: usize, cols: usize) -> Result<Self::MatrixXc, HoloError> {
153        Ok(arrayfire::constant(
154            c32::new(0., 0.),
155            Dim4::new(&[rows as _, cols as _, 1, 1]),
156        ))
157    }
158
159    fn to_host_v(&self, v: Self::VectorX) -> Result<VectorX, HoloError> {
160        let mut r = VectorX::zeros(v.elements());
161        v.host(r.as_mut_slice());
162        Ok(r)
163    }
164
165    fn to_host_m(&self, v: Self::MatrixX) -> Result<MatrixX, HoloError> {
166        let mut r = MatrixX::zeros(v.dims()[0] as _, v.dims()[1] as _);
167        v.host(r.as_mut_slice());
168        Ok(r)
169    }
170
171    fn to_host_cv(&self, v: Self::VectorXc) -> Result<VectorXc, HoloError> {
172        let n = v.elements();
173        let mut r = VectorXc::zeros(n);
174        unsafe {
175            v.host(std::slice::from_raw_parts_mut(
176                r.as_mut_ptr() as *mut c32,
177                n,
178            ));
179        }
180        Ok(r)
181    }
182
183    fn to_host_cm(&self, v: Self::MatrixXc) -> Result<MatrixXc, HoloError> {
184        let n = v.elements();
185        let mut r = MatrixXc::zeros(v.dims()[0] as _, v.dims()[1] as _);
186        unsafe {
187            v.host(std::slice::from_raw_parts_mut(
188                r.as_mut_ptr() as *mut c32,
189                n,
190            ));
191        }
192        Ok(r)
193    }
194
195    fn from_slice_v(&self, v: &[f32]) -> Result<Self::VectorX, HoloError> {
196        Ok(Array::new(v, Dim4::new(&[v.len() as _, 1, 1, 1])))
197    }
198
199    fn from_slice_m(
200        &self,
201        rows: usize,
202        cols: usize,
203        v: &[f32],
204    ) -> Result<Self::MatrixX, HoloError> {
205        Ok(Array::new(v, Dim4::new(&[rows as _, cols as _, 1, 1])))
206    }
207
208    fn from_slice_cv(&self, v: &[f32]) -> Result<Self::VectorXc, HoloError> {
209        let r = Array::new(v, Dim4::new(&[v.len() as _, 1, 1, 1]));
210        Ok(arrayfire::cplx(&r))
211    }
212
213    fn from_slice2_cv(&self, r: &[f32], i: &[f32]) -> Result<Self::VectorXc, HoloError> {
214        let r = Array::new(r, Dim4::new(&[r.len() as _, 1, 1, 1]));
215        let i = Array::new(i, Dim4::new(&[i.len() as _, 1, 1, 1]));
216        Ok(arrayfire::cplx2(&r, &i, false).cast())
217    }
218
219    fn from_slice2_cm(
220        &self,
221        rows: usize,
222        cols: usize,
223        r: &[f32],
224        i: &[f32],
225    ) -> Result<Self::MatrixXc, HoloError> {
226        let r = Array::new(r, Dim4::new(&[rows as _, cols as _, 1, 1]));
227        let i = Array::new(i, Dim4::new(&[rows as _, cols as _, 1, 1]));
228        Ok(arrayfire::cplx2(&r, &i, false).cast())
229    }
230
231    fn copy_from_slice_v(&self, v: &[f32], dst: &mut Self::VectorX) -> Result<(), HoloError> {
232        let n = v.len();
233        if n == 0 {
234            return Ok(());
235        }
236        let v = self.from_slice_v(v)?;
237        let seqs = [Seq::new(0u32, n as u32 - 1, 1)];
238        arrayfire::assign_seq(dst, &seqs, &v);
239        Ok(())
240    }
241
242    fn copy_to_v(&self, src: &Self::VectorX, dst: &mut Self::VectorX) -> Result<(), HoloError> {
243        let seqs = [Seq::new(0u32, src.elements() as u32 - 1, 1)];
244        arrayfire::assign_seq(dst, &seqs, src);
245        Ok(())
246    }
247
248    fn copy_to_m(&self, src: &Self::MatrixX, dst: &mut Self::MatrixX) -> Result<(), HoloError> {
249        let seqs = [
250            Seq::new(0u32, src.dims()[0] as u32 - 1, 1),
251            Seq::new(0u32, src.dims()[1] as u32 - 1, 1),
252        ];
253        arrayfire::assign_seq(dst, &seqs, src);
254        Ok(())
255    }
256
257    fn clone_v(&self, v: &Self::VectorX) -> Result<Self::VectorX, HoloError> {
258        Ok(v.copy())
259    }
260
261    fn clone_m(&self, v: &Self::MatrixX) -> Result<Self::MatrixX, HoloError> {
262        Ok(v.copy())
263    }
264
265    fn clone_cv(&self, v: &Self::VectorXc) -> Result<Self::VectorXc, HoloError> {
266        Ok(v.copy())
267    }
268
269    fn clone_cm(&self, v: &Self::MatrixXc) -> Result<Self::MatrixXc, HoloError> {
270        Ok(v.copy())
271    }
272
273    fn make_complex2_v(
274        &self,
275        real: &Self::VectorX,
276        imag: &Self::VectorX,
277        v: &mut Self::VectorXc,
278    ) -> Result<(), HoloError> {
279        *v = arrayfire::cplx2(real, imag, false).cast();
280        Ok(())
281    }
282
283    fn create_diagonal(&self, v: &Self::VectorX, a: &mut Self::MatrixX) -> Result<(), HoloError> {
284        *a = arrayfire::diag_create(v, 0);
285        Ok(())
286    }
287
288    fn create_diagonal_c(
289        &self,
290        v: &Self::VectorXc,
291        a: &mut Self::MatrixXc,
292    ) -> Result<(), HoloError> {
293        *a = arrayfire::diag_create(v, 0);
294        Ok(())
295    }
296
297    fn get_diagonal(&self, a: &Self::MatrixX, v: &mut Self::VectorX) -> Result<(), HoloError> {
298        *v = arrayfire::diag_extract(a, 0);
299        Ok(())
300    }
301
302    fn real_cm(&self, a: &Self::MatrixXc, b: &mut Self::MatrixX) -> Result<(), HoloError> {
303        *b = arrayfire::real(a);
304        Ok(())
305    }
306
307    fn imag_cm(&self, a: &Self::MatrixXc, b: &mut Self::MatrixX) -> Result<(), HoloError> {
308        *b = arrayfire::imag(a);
309        Ok(())
310    }
311
312    fn scale_assign_cv(
313        &self,
314        a: autd3_gain_holo::Complex,
315        b: &mut Self::VectorXc,
316    ) -> Result<(), HoloError> {
317        let a = c32::new(a.re, a.im);
318        *b = arrayfire::mul(b, &a, false);
319        Ok(())
320    }
321
322    fn conj_assign_v(&self, b: &mut Self::VectorXc) -> Result<(), HoloError> {
323        *b = arrayfire::conjg(b);
324        Ok(())
325    }
326
327    fn exp_assign_cv(&self, v: &mut Self::VectorXc) -> Result<(), HoloError> {
328        *v = arrayfire::exp(v);
329        Ok(())
330    }
331
332    fn concat_col_cm(
333        &self,
334        a: &Self::MatrixXc,
335        b: &Self::MatrixXc,
336        c: &mut Self::MatrixXc,
337    ) -> Result<(), HoloError> {
338        *c = arrayfire::join(1, a, b);
339        Ok(())
340    }
341
342    fn max_v(&self, m: &Self::VectorX) -> Result<f32, HoloError> {
343        Ok(arrayfire::max_all(m).0)
344    }
345
346    fn hadamard_product_cm(
347        &self,
348        x: &Self::MatrixXc,
349        y: &Self::MatrixXc,
350        z: &mut Self::MatrixXc,
351    ) -> Result<(), HoloError> {
352        *z = arrayfire::mul(x, y, false);
353        Ok(())
354    }
355
356    fn dot(&self, x: &Self::VectorX, y: &Self::VectorX) -> Result<f32, HoloError> {
357        let r = arrayfire::dot(x, y, MatProp::NONE, MatProp::NONE);
358        let mut v = [0.];
359        r.host(&mut v);
360        Ok(v[0])
361    }
362
363    fn dot_c(
364        &self,
365        x: &Self::VectorXc,
366        y: &Self::VectorXc,
367    ) -> Result<autd3_gain_holo::Complex, HoloError> {
368        let r = arrayfire::dot(x, y, MatProp::CONJ, MatProp::NONE);
369        let mut v = [c32::new(0., 0.)];
370        r.host(&mut v);
371        Ok(autd3_gain_holo::Complex::new(v[0].re, v[0].im))
372    }
373
374    fn add_v(&self, alpha: f32, a: &Self::VectorX, b: &mut Self::VectorX) -> Result<(), HoloError> {
375        *b = arrayfire::add(&arrayfire::mul(a, &alpha, false), b, false);
376        Ok(())
377    }
378
379    fn add_m(&self, alpha: f32, a: &Self::MatrixX, b: &mut Self::MatrixX) -> Result<(), HoloError> {
380        *b = arrayfire::add(&arrayfire::mul(a, &alpha, false), b, false);
381        Ok(())
382    }
383
384    fn gevv_c(
385        &self,
386        trans_a: autd3_gain_holo::Trans,
387        trans_b: autd3_gain_holo::Trans,
388        alpha: autd3_gain_holo::Complex,
389        a: &Self::VectorXc,
390        x: &Self::VectorXc,
391        beta: autd3_gain_holo::Complex,
392        y: &mut Self::MatrixXc,
393    ) -> Result<(), HoloError> {
394        let alpha = vec![c32::new(alpha.re, alpha.im)];
395        let beta = vec![c32::new(beta.re, beta.im)];
396        let trans_a = convert(trans_a);
397        let trans_b = convert(trans_b);
398        arrayfire::gemm(y, trans_a, trans_b, alpha, a, x, beta);
399        Ok(())
400    }
401
402    fn gemv_c(
403        &self,
404        trans: autd3_gain_holo::Trans,
405        alpha: autd3_gain_holo::Complex,
406        a: &Self::MatrixXc,
407        x: &Self::VectorXc,
408        beta: autd3_gain_holo::Complex,
409        y: &mut Self::VectorXc,
410    ) -> Result<(), HoloError> {
411        let alpha = vec![c32::new(alpha.re, alpha.im)];
412        let beta = vec![c32::new(beta.re, beta.im)];
413        let trans = convert(trans);
414        arrayfire::gemm(y, trans, MatProp::NONE, alpha, a, x, beta);
415        Ok(())
416    }
417
418    fn gemm_c(
419        &self,
420        trans_a: autd3_gain_holo::Trans,
421        trans_b: autd3_gain_holo::Trans,
422        alpha: autd3_gain_holo::Complex,
423        a: &Self::MatrixXc,
424        b: &Self::MatrixXc,
425        beta: autd3_gain_holo::Complex,
426        y: &mut Self::MatrixXc,
427    ) -> Result<(), HoloError> {
428        let alpha = vec![c32::new(alpha.re, alpha.im)];
429        let beta = vec![c32::new(beta.re, beta.im)];
430        let trans_a = convert(trans_a);
431        let trans_b = convert(trans_b);
432        arrayfire::gemm(y, trans_a, trans_b, alpha, a, b, beta);
433        Ok(())
434    }
435
436    fn solve_inplace(&self, a: &Self::MatrixX, x: &mut Self::VectorX) -> Result<(), HoloError> {
437        *x = arrayfire::solve(a, x, MatProp::NONE);
438        Ok(())
439    }
440
441    fn reduce_col(&self, a: &Self::MatrixX, b: &mut Self::VectorX) -> Result<(), HoloError> {
442        *b = arrayfire::sum(a, 1);
443        Ok(())
444    }
445
446    fn cols_c(&self, m: &Self::MatrixXc) -> Result<usize, HoloError> {
447        Ok(m.dims()[1] as _)
448    }
449
450    fn scaled_to_cv(
451        &self,
452        a: &Self::VectorXc,
453        b: &Self::VectorXc,
454        c: &mut Self::VectorXc,
455    ) -> Result<(), HoloError> {
456        let tmp = arrayfire::div(a, &arrayfire::abs(a), false);
457        *c = arrayfire::mul(&tmp, b, false);
458        Ok(())
459    }
460
461    fn scaled_to_assign_cv(
462        &self,
463        a: &Self::VectorXc,
464        b: &mut Self::VectorXc,
465    ) -> Result<(), HoloError> {
466        *b = arrayfire::div(b, &arrayfire::abs(b), false);
467        *b = arrayfire::mul(a, b, false);
468        Ok(())
469    }
470
471    fn gen_back_prop(
472        &self,
473        m: usize,
474        n: usize,
475        transfer: &Self::MatrixXc,
476    ) -> Result<Self::MatrixXc, HoloError> {
477        let mut b = self.alloc_zeros_cm(m, n)?;
478
479        let mut tmp = self.alloc_zeros_cm(n, n)?;
480
481        self.gemm_c(
482            Trans::NoTrans,
483            Trans::ConjTrans,
484            Complex::new(1., 0.),
485            transfer,
486            transfer,
487            Complex::new(0., 0.),
488            &mut tmp,
489        )?;
490
491        let mut denominator = arrayfire::diag_extract(&tmp, 0);
492        let a = c32::new(1., 0.);
493        denominator = arrayfire::div(&a, &denominator, false);
494
495        self.create_diagonal_c(&denominator, &mut tmp)?;
496
497        self.gemm_c(
498            Trans::ConjTrans,
499            Trans::NoTrans,
500            Complex::new(1., 0.),
501            transfer,
502            &tmp,
503            Complex::new(0., 0.),
504            &mut b,
505        )?;
506
507        Ok(b)
508    }
509
510    fn norm_squared_cv(&self, a: &Self::VectorXc, b: &mut Self::VectorX) -> Result<(), HoloError> {
511        *b = arrayfire::abs(a);
512        *b = arrayfire::mul(b, b, false);
513        Ok(())
514    }
515}
516#[cfg(test)]
517mod tests {
518    use std::f32::consts::PI;
519
520    use autd3::driver::autd3_device::AUTD3;
521    use autd3_core::{
522        acoustics::directivity::Sphere,
523        geometry::{Point3, Transducer, UnitQuaternion},
524    };
525
526    use nalgebra::{ComplexField, Normed};
527
528    use autd3_gain_holo::{Amplitude, Pa, Trans};
529
530    use super::*;
531
532    use rand::Rng;
533
534    const N: usize = 10;
535    const EPS: f32 = 1e-3;
536
537    fn generate_geometry(size: usize) -> Geometry {
538        Geometry::new(
539            (0..size)
540                .flat_map(|i| {
541                    (0..size).map(move |j| {
542                        AUTD3 {
543                            pos: Point3::new(
544                                i as f32 * AUTD3::DEVICE_WIDTH,
545                                j as f32 * AUTD3::DEVICE_HEIGHT,
546                                0.,
547                            ),
548                            rot: UnitQuaternion::identity(),
549                        }
550                        .into()
551                    })
552                })
553                .collect(),
554        )
555    }
556
557    fn gen_foci(n: usize) -> impl Iterator<Item = (Point3, Amplitude)> {
558        (0..n).map(move |i| {
559            (
560                Point3::new(
561                    90. + 10. * (2.0 * PI * i as f32 / n as f32).cos(),
562                    70. + 10. * (2.0 * PI * i as f32 / n as f32).sin(),
563                    150.,
564                ),
565                10e3 * Pa,
566            )
567        })
568    }
569
570    fn make_random_v(
571        backend: &ArrayFireBackend<Sphere>,
572        size: usize,
573    ) -> Result<<ArrayFireBackend<Sphere> as LinAlgBackend<Sphere>>::VectorX, HoloError> {
574        let mut rng = rand::rng();
575        let v: Vec<f32> = (&mut rng)
576            .sample_iter(rand::distr::StandardUniform)
577            .take(size)
578            .collect();
579        backend.from_slice_v(&v)
580    }
581
582    fn make_random_m(
583        backend: &ArrayFireBackend<Sphere>,
584        rows: usize,
585        cols: usize,
586    ) -> Result<<ArrayFireBackend<Sphere> as LinAlgBackend<Sphere>>::MatrixX, HoloError> {
587        let mut rng = rand::rng();
588        let v: Vec<f32> = (&mut rng)
589            .sample_iter(rand::distr::StandardUniform)
590            .take(rows * cols)
591            .collect();
592        backend.from_slice_m(rows, cols, &v)
593    }
594
595    fn make_random_cv(
596        backend: &ArrayFireBackend<Sphere>,
597        size: usize,
598    ) -> Result<<ArrayFireBackend<Sphere> as LinAlgBackend<Sphere>>::VectorXc, HoloError> {
599        let mut rng = rand::rng();
600        let real: Vec<f32> = (&mut rng)
601            .sample_iter(rand::distr::StandardUniform)
602            .take(size)
603            .collect();
604        let imag: Vec<f32> = (&mut rng)
605            .sample_iter(rand::distr::StandardUniform)
606            .take(size)
607            .collect();
608        backend.from_slice2_cv(&real, &imag)
609    }
610
611    fn make_random_cm(
612        backend: &ArrayFireBackend<Sphere>,
613        rows: usize,
614        cols: usize,
615    ) -> Result<<ArrayFireBackend<Sphere> as LinAlgBackend<Sphere>>::MatrixXc, HoloError> {
616        let mut rng = rand::rng();
617        let real: Vec<f32> = (&mut rng)
618            .sample_iter(rand::distr::StandardUniform)
619            .take(rows * cols)
620            .collect();
621        let imag: Vec<f32> = (&mut rng)
622            .sample_iter(rand::distr::StandardUniform)
623            .take(rows * cols)
624            .collect();
625        backend.from_slice2_cm(rows, cols, &real, &imag)
626    }
627
628    #[rstest::fixture]
629    fn backend() -> ArrayFireBackend<Sphere> {
630        ArrayFireBackend::set_backend(AFBackend::CPU);
631        ArrayFireBackend {
632            _phantom: std::marker::PhantomData,
633        }
634    }
635
636    #[rstest::rstest]
637    #[test]
638    #[cfg_attr(miri, ignore)]
639    fn test_alloc_v(backend: ArrayFireBackend<Sphere>) -> Result<(), HoloError> {
640        let v = backend.alloc_v(N)?;
641        let v = backend.to_host_v(v)?;
642
643        assert_eq!(N, v.len());
644        Ok(())
645    }
646
647    #[rstest::rstest]
648    #[test]
649    #[cfg_attr(miri, ignore)]
650    fn test_alloc_m(backend: ArrayFireBackend<Sphere>) -> Result<(), HoloError> {
651        let m = backend.alloc_m(N, 2 * N)?;
652        let m = backend.to_host_m(m)?;
653
654        assert_eq!(N, m.nrows());
655        assert_eq!(2 * N, m.ncols());
656        Ok(())
657    }
658
659    #[rstest::rstest]
660    #[test]
661    #[cfg_attr(miri, ignore)]
662    fn test_alloc_cv(backend: ArrayFireBackend<Sphere>) -> Result<(), HoloError> {
663        let v = backend.alloc_cv(N)?;
664        let v = backend.to_host_cv(v)?;
665
666        assert_eq!(N, v.len());
667        Ok(())
668    }
669
670    #[rstest::rstest]
671    #[test]
672    #[cfg_attr(miri, ignore)]
673    fn test_alloc_cm(backend: ArrayFireBackend<Sphere>) -> Result<(), HoloError> {
674        let m = backend.alloc_cm(N, 2 * N)?;
675        let m = backend.to_host_cm(m)?;
676
677        assert_eq!(N, m.nrows());
678        assert_eq!(2 * N, m.ncols());
679        Ok(())
680    }
681
682    #[rstest::rstest]
683    #[test]
684    #[cfg_attr(miri, ignore)]
685    fn test_alloc_zeros_v(backend: ArrayFireBackend<Sphere>) -> Result<(), HoloError> {
686        let v = backend.alloc_zeros_v(N)?;
687        let v = backend.to_host_v(v)?;
688
689        assert_eq!(N, v.len());
690        assert!(v.iter().all(|&v| v == 0.));
691        Ok(())
692    }
693
694    #[rstest::rstest]
695    #[test]
696    #[cfg_attr(miri, ignore)]
697    fn test_alloc_zeros_cv(backend: ArrayFireBackend<Sphere>) -> Result<(), HoloError> {
698        let v = backend.alloc_zeros_cv(N)?;
699        let v = backend.to_host_cv(v)?;
700
701        assert_eq!(N, v.len());
702        assert!(v.iter().all(|&v| v == Complex::new(0., 0.)));
703        Ok(())
704    }
705
706    #[rstest::rstest]
707    #[test]
708    #[cfg_attr(miri, ignore)]
709    fn test_alloc_zeros_cm(backend: ArrayFireBackend<Sphere>) -> Result<(), HoloError> {
710        let m = backend.alloc_zeros_cm(N, 2 * N)?;
711        let m = backend.to_host_cm(m)?;
712
713        assert_eq!(N, m.nrows());
714        assert_eq!(2 * N, m.ncols());
715        assert!(m.iter().all(|&v| v == Complex::new(0., 0.)));
716        Ok(())
717    }
718
719    #[rstest::rstest]
720    #[test]
721    #[cfg_attr(miri, ignore)]
722    fn test_cols_c(backend: ArrayFireBackend<Sphere>) -> Result<(), HoloError> {
723        let m = backend.alloc_cm(N, 2 * N)?;
724
725        assert_eq!(2 * N, backend.cols_c(&m)?);
726
727        Ok(())
728    }
729
730    #[rstest::rstest]
731    #[test]
732    #[cfg_attr(miri, ignore)]
733    fn test_from_slice_v(backend: ArrayFireBackend<Sphere>) -> Result<(), HoloError> {
734        let rng = rand::rng();
735
736        let v: Vec<f32> = rng
737            .sample_iter(rand::distr::StandardUniform)
738            .take(N)
739            .collect();
740
741        let c = backend.from_slice_v(&v)?;
742        let c = backend.to_host_v(c)?;
743
744        assert_eq!(N, c.len());
745        v.iter().zip(c.iter()).for_each(|(&r, &c)| {
746            assert_eq!(r, c);
747        });
748        Ok(())
749    }
750
751    #[rstest::rstest]
752    #[test]
753    #[cfg_attr(miri, ignore)]
754    fn test_from_slice_m(backend: ArrayFireBackend<Sphere>) -> Result<(), HoloError> {
755        let rng = rand::rng();
756
757        let v: Vec<f32> = rng
758            .sample_iter(rand::distr::StandardUniform)
759            .take(N * 2 * N)
760            .collect();
761
762        let c = backend.from_slice_m(N, 2 * N, &v)?;
763        let c = backend.to_host_m(c)?;
764
765        assert_eq!(N, c.nrows());
766        assert_eq!(2 * N, c.ncols());
767        (0..2 * N).for_each(|col| {
768            (0..N).for_each(|row| {
769                assert_eq!(v[col * N + row], c[(row, col)]);
770            })
771        });
772        Ok(())
773    }
774
775    #[rstest::rstest]
776    #[test]
777    #[cfg_attr(miri, ignore)]
778    fn test_from_slice_cv(backend: ArrayFireBackend<Sphere>) -> Result<(), HoloError> {
779        let rng = rand::rng();
780
781        let real: Vec<f32> = rng
782            .sample_iter(rand::distr::StandardUniform)
783            .take(N)
784            .collect();
785
786        let c = backend.from_slice_cv(&real)?;
787        let c = backend.to_host_cv(c)?;
788
789        assert_eq!(N, c.len());
790        real.iter().zip(c.iter()).for_each(|(r, c)| {
791            assert_eq!(r, &c.re);
792            assert_eq!(0.0, c.im);
793        });
794        Ok(())
795    }
796
797    #[rstest::rstest]
798    #[test]
799    #[cfg_attr(miri, ignore)]
800    fn test_from_slice2_cv(backend: ArrayFireBackend<Sphere>) -> Result<(), HoloError> {
801        let mut rng = rand::rng();
802
803        let real: Vec<f32> = (&mut rng)
804            .sample_iter(rand::distr::StandardUniform)
805            .take(N)
806            .collect();
807        let imag: Vec<f32> = (&mut rng)
808            .sample_iter(rand::distr::StandardUniform)
809            .take(N)
810            .collect();
811
812        let c = backend.from_slice2_cv(&real, &imag)?;
813        let c = backend.to_host_cv(c)?;
814
815        assert_eq!(N, c.len());
816        real.iter()
817            .zip(imag.iter())
818            .zip(c.iter())
819            .for_each(|((r, i), c)| {
820                assert_eq!(r, &c.re);
821                assert_eq!(i, &c.im);
822            });
823        Ok(())
824    }
825
826    #[rstest::rstest]
827    #[test]
828    #[cfg_attr(miri, ignore)]
829    fn test_from_slice2_cm(backend: ArrayFireBackend<Sphere>) -> Result<(), HoloError> {
830        let mut rng = rand::rng();
831
832        let real: Vec<f32> = (&mut rng)
833            .sample_iter(rand::distr::StandardUniform)
834            .take(N * 2 * N)
835            .collect();
836        let imag: Vec<f32> = (&mut rng)
837            .sample_iter(rand::distr::StandardUniform)
838            .take(N * 2 * N)
839            .collect();
840
841        let c = backend.from_slice2_cm(N, 2 * N, &real, &imag)?;
842        let c = backend.to_host_cm(c)?;
843
844        assert_eq!(N, c.nrows());
845        assert_eq!(2 * N, c.ncols());
846        (0..2 * N).for_each(|col| {
847            (0..N).for_each(|row| {
848                assert_eq!(real[col * N + row], c[(row, col)].re);
849                assert_eq!(imag[col * N + row], c[(row, col)].im);
850            })
851        });
852        Ok(())
853    }
854
855    #[rstest::rstest]
856    #[test]
857    #[cfg_attr(miri, ignore)]
858    fn test_copy_from_slice_v(backend: ArrayFireBackend<Sphere>) -> Result<(), HoloError> {
859        {
860            let mut a = backend.alloc_zeros_v(N)?;
861            let mut rng = rand::rng();
862            let v = (&mut rng)
863                .sample_iter(rand::distr::StandardUniform)
864                .take(N / 2)
865                .collect::<Vec<f32>>();
866
867            backend.copy_from_slice_v(&v, &mut a)?;
868
869            let a = backend.to_host_v(a)?;
870            (0..N / 2).for_each(|i| {
871                assert_eq!(v[i], a[i]);
872            });
873            (N / 2..N).for_each(|i| {
874                assert_eq!(0., a[i]);
875            });
876        }
877
878        {
879            let mut a = backend.alloc_zeros_v(N)?;
880            let v = [];
881
882            backend.copy_from_slice_v(&v, &mut a)?;
883
884            let a = backend.to_host_v(a)?;
885            a.iter().for_each(|&a| {
886                assert_eq!(0., a);
887            });
888        }
889
890        Ok(())
891    }
892
893    #[rstest::rstest]
894    #[test]
895    #[cfg_attr(miri, ignore)]
896    fn test_copy_to_v(backend: ArrayFireBackend<Sphere>) -> Result<(), HoloError> {
897        let a = make_random_v(&backend, N)?;
898        let mut b = backend.alloc_v(N)?;
899
900        backend.copy_to_v(&a, &mut b)?;
901
902        let a = backend.to_host_v(a)?;
903        let b = backend.to_host_v(b)?;
904        a.iter().zip(b.iter()).for_each(|(a, b)| {
905            assert_eq!(a, b);
906        });
907        Ok(())
908    }
909
910    #[rstest::rstest]
911    #[test]
912    #[cfg_attr(miri, ignore)]
913    fn test_copy_to_m(backend: ArrayFireBackend<Sphere>) -> Result<(), HoloError> {
914        let a = make_random_m(&backend, N, N)?;
915        let mut b = backend.alloc_m(N, N)?;
916
917        backend.copy_to_m(&a, &mut b)?;
918
919        let a = backend.to_host_m(a)?;
920        let b = backend.to_host_m(b)?;
921        a.iter().zip(b.iter()).for_each(|(a, b)| {
922            assert_eq!(a, b);
923        });
924        Ok(())
925    }
926
927    #[rstest::rstest]
928    #[test]
929    #[cfg_attr(miri, ignore)]
930    fn test_clone_v(backend: ArrayFireBackend<Sphere>) -> Result<(), HoloError> {
931        let c = make_random_v(&backend, N)?;
932        let c2 = backend.clone_v(&c)?;
933
934        let c = backend.to_host_v(c)?;
935        let c2 = backend.to_host_v(c2)?;
936
937        c.iter().zip(c2.iter()).for_each(|(c, c2)| {
938            assert_eq!(c, c2);
939        });
940        Ok(())
941    }
942
943    #[rstest::rstest]
944    #[test]
945    #[cfg_attr(miri, ignore)]
946    fn test_clone_m(backend: ArrayFireBackend<Sphere>) -> Result<(), HoloError> {
947        let c = make_random_m(&backend, N, N)?;
948        let c2 = backend.clone_m(&c)?;
949
950        let c = backend.to_host_m(c)?;
951        let c2 = backend.to_host_m(c2)?;
952
953        c.iter().zip(c2.iter()).for_each(|(c, c2)| {
954            assert_eq!(c, c2);
955        });
956        Ok(())
957    }
958
959    #[rstest::rstest]
960    #[test]
961    #[cfg_attr(miri, ignore)]
962    fn test_clone_cv(backend: ArrayFireBackend<Sphere>) -> Result<(), HoloError> {
963        let c = make_random_cv(&backend, N)?;
964        let c2 = backend.clone_cv(&c)?;
965
966        let c = backend.to_host_cv(c)?;
967        let c2 = backend.to_host_cv(c2)?;
968
969        c.iter().zip(c2.iter()).for_each(|(c, c2)| {
970            assert_eq!(c.re, c2.re);
971            assert_eq!(c.im, c2.im);
972        });
973        Ok(())
974    }
975
976    #[rstest::rstest]
977    #[test]
978    #[cfg_attr(miri, ignore)]
979    fn test_clone_cm(backend: ArrayFireBackend<Sphere>) -> Result<(), HoloError> {
980        let c = make_random_cm(&backend, N, N)?;
981        let c2 = backend.clone_cm(&c)?;
982
983        let c = backend.to_host_cm(c)?;
984        let c2 = backend.to_host_cm(c2)?;
985
986        c.iter().zip(c2.iter()).for_each(|(c, c2)| {
987            assert_eq!(c.re, c2.re);
988            assert_eq!(c.im, c2.im);
989        });
990        Ok(())
991    }
992
993    #[rstest::rstest]
994    #[test]
995    #[cfg_attr(miri, ignore)]
996    fn test_make_complex2_v(backend: ArrayFireBackend<Sphere>) -> Result<(), HoloError> {
997        let real = make_random_v(&backend, N)?;
998        let imag = make_random_v(&backend, N)?;
999
1000        let mut c = backend.alloc_cv(N)?;
1001        backend.make_complex2_v(&real, &imag, &mut c)?;
1002
1003        let real = backend.to_host_v(real)?;
1004        let imag = backend.to_host_v(imag)?;
1005        let c = backend.to_host_cv(c)?;
1006        real.iter()
1007            .zip(imag.iter())
1008            .zip(c.iter())
1009            .for_each(|((r, i), c)| {
1010                assert_eq!(r, &c.re);
1011                assert_eq!(i, &c.im);
1012            });
1013        Ok(())
1014    }
1015
1016    #[rstest::rstest]
1017    #[test]
1018    #[cfg_attr(miri, ignore)]
1019    fn test_create_diagonal(backend: ArrayFireBackend<Sphere>) -> Result<(), HoloError> {
1020        let diagonal = make_random_v(&backend, N)?;
1021
1022        let mut c = backend.alloc_m(N, N)?;
1023
1024        backend.create_diagonal(&diagonal, &mut c)?;
1025
1026        let diagonal = backend.to_host_v(diagonal)?;
1027        let c = backend.to_host_m(c)?;
1028        (0..N).for_each(|i| {
1029            (0..N).for_each(|j| {
1030                if i == j {
1031                    assert_eq!(diagonal[i], c[(i, j)]);
1032                } else {
1033                    assert_eq!(0.0, c[(i, j)]);
1034                }
1035            })
1036        });
1037        Ok(())
1038    }
1039
1040    #[rstest::rstest]
1041    #[test]
1042    #[cfg_attr(miri, ignore)]
1043    fn test_create_diagonal_c(backend: ArrayFireBackend<Sphere>) -> Result<(), HoloError> {
1044        let diagonal = make_random_cv(&backend, N)?;
1045
1046        let mut c = backend.alloc_cm(N, N)?;
1047
1048        backend.create_diagonal_c(&diagonal, &mut c)?;
1049
1050        let diagonal = backend.to_host_cv(diagonal)?;
1051        let c = backend.to_host_cm(c)?;
1052        (0..N).for_each(|i| {
1053            (0..N).for_each(|j| {
1054                if i == j {
1055                    assert_eq!(diagonal[i].re, c[(i, j)].re);
1056                    assert_eq!(diagonal[i].im, c[(i, j)].im);
1057                } else {
1058                    assert_eq!(0.0, c[(i, j)].re);
1059                    assert_eq!(0.0, c[(i, j)].im);
1060                }
1061            })
1062        });
1063        Ok(())
1064    }
1065
1066    #[rstest::rstest]
1067    #[test]
1068    #[cfg_attr(miri, ignore)]
1069    fn test_get_diagonal(backend: ArrayFireBackend<Sphere>) -> Result<(), HoloError> {
1070        let m = make_random_m(&backend, N, N)?;
1071        let mut diagonal = backend.alloc_v(N)?;
1072
1073        backend.get_diagonal(&m, &mut diagonal)?;
1074
1075        let m = backend.to_host_m(m)?;
1076        let diagonal = backend.to_host_v(diagonal)?;
1077        (0..N).for_each(|i| {
1078            assert_eq!(m[(i, i)], diagonal[i]);
1079        });
1080        Ok(())
1081    }
1082
1083    #[rstest::rstest]
1084    #[test]
1085    #[cfg_attr(miri, ignore)]
1086    fn test_norm_squared_cv(backend: ArrayFireBackend<Sphere>) -> Result<(), HoloError> {
1087        let v = make_random_cv(&backend, N)?;
1088
1089        let mut abs = backend.alloc_v(N)?;
1090        backend.norm_squared_cv(&v, &mut abs)?;
1091
1092        let v = backend.to_host_cv(v)?;
1093        let abs = backend.to_host_v(abs)?;
1094        v.iter().zip(abs.iter()).for_each(|(v, abs)| {
1095            assert_approx_eq::assert_approx_eq!(v.norm_squared(), abs, EPS);
1096        });
1097        Ok(())
1098    }
1099
1100    #[rstest::rstest]
1101    #[test]
1102    #[cfg_attr(miri, ignore)]
1103    fn test_real_cm(backend: ArrayFireBackend<Sphere>) -> Result<(), HoloError> {
1104        let v = make_random_cm(&backend, N, N)?;
1105        let mut r = backend.alloc_m(N, N)?;
1106
1107        backend.real_cm(&v, &mut r)?;
1108
1109        let v = backend.to_host_cm(v)?;
1110        let r = backend.to_host_m(r)?;
1111        (0..N).for_each(|i| {
1112            (0..N).for_each(|j| {
1113                assert_approx_eq::assert_approx_eq!(v[(i, j)].re, r[(i, j)], EPS);
1114            })
1115        });
1116        Ok(())
1117    }
1118
1119    #[rstest::rstest]
1120    #[test]
1121    #[cfg_attr(miri, ignore)]
1122    fn test_imag_cm(backend: ArrayFireBackend<Sphere>) -> Result<(), HoloError> {
1123        let v = make_random_cm(&backend, N, N)?;
1124        let mut r = backend.alloc_m(N, N)?;
1125
1126        backend.imag_cm(&v, &mut r)?;
1127
1128        let v = backend.to_host_cm(v)?;
1129        let r = backend.to_host_m(r)?;
1130        (0..N).for_each(|i| {
1131            (0..N).for_each(|j| {
1132                assert_approx_eq::assert_approx_eq!(v[(i, j)].im, r[(i, j)], EPS);
1133            })
1134        });
1135        Ok(())
1136    }
1137
1138    #[rstest::rstest]
1139    #[test]
1140    #[cfg_attr(miri, ignore)]
1141    fn test_scale_assign_cv(backend: ArrayFireBackend<Sphere>) -> Result<(), HoloError> {
1142        let mut v = make_random_cv(&backend, N)?;
1143        let vc = backend.clone_cv(&v)?;
1144        let mut rng = rand::rng();
1145        let scale = Complex::new(rng.random(), rng.random());
1146
1147        backend.scale_assign_cv(scale, &mut v)?;
1148
1149        let v = backend.to_host_cv(v)?;
1150        let vc = backend.to_host_cv(vc)?;
1151        v.iter().zip(vc.iter()).for_each(|(&v, &vc)| {
1152            assert_approx_eq::assert_approx_eq!(scale * vc, v, EPS);
1153        });
1154        Ok(())
1155    }
1156
1157    #[rstest::rstest]
1158    #[test]
1159    #[cfg_attr(miri, ignore)]
1160    fn test_conj_assign_v(backend: ArrayFireBackend<Sphere>) -> Result<(), HoloError> {
1161        let mut v = make_random_cv(&backend, N)?;
1162        let vc = backend.clone_cv(&v)?;
1163
1164        backend.conj_assign_v(&mut v)?;
1165
1166        let v = backend.to_host_cv(v)?;
1167        let vc = backend.to_host_cv(vc)?;
1168        v.iter().zip(vc.iter()).for_each(|(&v, &vc)| {
1169            assert_eq!(vc.re, v.re);
1170            assert_eq!(vc.im, -v.im);
1171        });
1172        Ok(())
1173    }
1174
1175    #[rstest::rstest]
1176    #[test]
1177    #[cfg_attr(miri, ignore)]
1178    fn test_exp_assign_cv(backend: ArrayFireBackend<Sphere>) -> Result<(), HoloError> {
1179        let mut v = make_random_cv(&backend, N)?;
1180        let vc = backend.clone_cv(&v)?;
1181
1182        backend.exp_assign_cv(&mut v)?;
1183
1184        let v = backend.to_host_cv(v)?;
1185        let vc = backend.to_host_cv(vc)?;
1186        v.iter().zip(vc.iter()).for_each(|(v, vc)| {
1187            assert_approx_eq::assert_approx_eq!(vc.exp(), v, EPS);
1188        });
1189        Ok(())
1190    }
1191
1192    #[rstest::rstest]
1193    #[test]
1194    #[cfg_attr(miri, ignore)]
1195    fn test_concat_col_cm(backend: ArrayFireBackend<Sphere>) -> Result<(), HoloError> {
1196        let a = make_random_cm(&backend, N, N)?;
1197        let b = make_random_cm(&backend, N, 2 * N)?;
1198        let mut c = backend.alloc_cm(N, N + 2 * N)?;
1199
1200        backend.concat_col_cm(&a, &b, &mut c)?;
1201
1202        let a = backend.to_host_cm(a)?;
1203        let b = backend.to_host_cm(b)?;
1204        let c = backend.to_host_cm(c)?;
1205        (0..N).for_each(|col| (0..N).for_each(|row| assert_eq!(a[(row, col)], c[(row, col)])));
1206        (0..2 * N)
1207            .for_each(|col| (0..N).for_each(|row| assert_eq!(b[(row, col)], c[(row, N + col)])));
1208        Ok(())
1209    }
1210
1211    #[rstest::rstest]
1212    #[test]
1213    #[cfg_attr(miri, ignore)]
1214    fn test_max_v(backend: ArrayFireBackend<Sphere>) -> Result<(), HoloError> {
1215        let v = make_random_v(&backend, N)?;
1216
1217        let max = backend.max_v(&v)?;
1218
1219        let v = backend.to_host_v(v)?;
1220        assert_eq!(
1221            *v.iter().max_by(|a, b| a.partial_cmp(b).unwrap()).unwrap(),
1222            max
1223        );
1224        Ok(())
1225    }
1226
1227    #[rstest::rstest]
1228    #[test]
1229    #[cfg_attr(miri, ignore)]
1230    fn test_hadamard_product_cm(backend: ArrayFireBackend<Sphere>) -> Result<(), HoloError> {
1231        let a = make_random_cm(&backend, N, N)?;
1232        let b = make_random_cm(&backend, N, N)?;
1233        let mut c = backend.alloc_cm(N, N)?;
1234
1235        backend.hadamard_product_cm(&a, &b, &mut c)?;
1236
1237        let a = backend.to_host_cm(a)?;
1238        let b = backend.to_host_cm(b)?;
1239        let c = backend.to_host_cm(c)?;
1240        c.iter()
1241            .zip(a.iter())
1242            .zip(b.iter())
1243            .for_each(|((c, a), b)| {
1244                assert_approx_eq::assert_approx_eq!(a.re * b.re - a.im * b.im, c.re, EPS);
1245                assert_approx_eq::assert_approx_eq!(a.re * b.im + a.im * b.re, c.im, EPS);
1246            });
1247        Ok(())
1248    }
1249
1250    #[rstest::rstest]
1251    #[test]
1252    #[cfg_attr(miri, ignore)]
1253    fn test_dot(backend: ArrayFireBackend<Sphere>) -> Result<(), HoloError> {
1254        let a = make_random_v(&backend, N)?;
1255        let b = make_random_v(&backend, N)?;
1256
1257        let dot = backend.dot(&a, &b)?;
1258
1259        let a = backend.to_host_v(a)?;
1260        let b = backend.to_host_v(b)?;
1261        let expect = a.iter().zip(b.iter()).map(|(a, b)| a * b).sum::<f32>();
1262        assert_approx_eq::assert_approx_eq!(dot, expect, EPS);
1263        Ok(())
1264    }
1265
1266    #[rstest::rstest]
1267    #[test]
1268    #[cfg_attr(miri, ignore)]
1269    fn test_dot_c(backend: ArrayFireBackend<Sphere>) -> Result<(), HoloError> {
1270        let a = make_random_cv(&backend, N)?;
1271        let b = make_random_cv(&backend, N)?;
1272
1273        let dot = backend.dot_c(&a, &b)?;
1274
1275        let a = backend.to_host_cv(a)?;
1276        let b = backend.to_host_cv(b)?;
1277        let expect = a
1278            .iter()
1279            .zip(b.iter())
1280            .map(|(a, b)| a.conj() * b)
1281            .sum::<Complex>();
1282        assert_approx_eq::assert_approx_eq!(dot.re, expect.re, EPS);
1283        assert_approx_eq::assert_approx_eq!(dot.im, expect.im, EPS);
1284        Ok(())
1285    }
1286
1287    #[rstest::rstest]
1288    #[test]
1289    #[cfg_attr(miri, ignore)]
1290    fn test_add_v(backend: ArrayFireBackend<Sphere>) -> Result<(), HoloError> {
1291        let a = make_random_v(&backend, N)?;
1292        let mut b = make_random_v(&backend, N)?;
1293        let bc = backend.clone_v(&b)?;
1294
1295        let mut rng = rand::rng();
1296        let alpha = rng.random();
1297
1298        backend.add_v(alpha, &a, &mut b)?;
1299
1300        let a = backend.to_host_v(a)?;
1301        let b = backend.to_host_v(b)?;
1302        let bc = backend.to_host_v(bc)?;
1303        b.iter()
1304            .zip(a.iter())
1305            .zip(bc.iter())
1306            .for_each(|((b, a), bc)| {
1307                assert_approx_eq::assert_approx_eq!(alpha * a + bc, b, EPS);
1308            });
1309        Ok(())
1310    }
1311
1312    #[rstest::rstest]
1313    #[test]
1314    #[cfg_attr(miri, ignore)]
1315    fn test_add_m(backend: ArrayFireBackend<Sphere>) -> Result<(), HoloError> {
1316        let a = make_random_m(&backend, N, N)?;
1317        let mut b = make_random_m(&backend, N, N)?;
1318        let bc = backend.clone_m(&b)?;
1319
1320        let mut rng = rand::rng();
1321        let alpha = rng.random();
1322
1323        backend.add_m(alpha, &a, &mut b)?;
1324
1325        let a = backend.to_host_m(a)?;
1326        let b = backend.to_host_m(b)?;
1327        let bc = backend.to_host_m(bc)?;
1328        b.iter()
1329            .zip(a.iter())
1330            .zip(bc.iter())
1331            .for_each(|((b, a), bc)| {
1332                assert_approx_eq::assert_approx_eq!(alpha * a + bc, b, EPS);
1333            });
1334        Ok(())
1335    }
1336
1337    #[rstest::rstest]
1338    #[test]
1339    #[cfg_attr(miri, ignore)]
1340    fn test_gevv_c(backend: ArrayFireBackend<Sphere>) -> Result<(), HoloError> {
1341        let mut rng = rand::rng();
1342
1343        {
1344            let a = make_random_cv(&backend, N)?;
1345            let b = make_random_cv(&backend, N)?;
1346            let mut c = make_random_cm(&backend, N, N)?;
1347            let cc = backend.clone_cm(&c)?;
1348
1349            let alpha = Complex::new(rng.random(), rng.random());
1350            let beta = Complex::new(rng.random(), rng.random());
1351            backend.gevv_c(Trans::NoTrans, Trans::Trans, alpha, &a, &b, beta, &mut c)?;
1352
1353            let a = backend.to_host_cv(a)?;
1354            let b = backend.to_host_cv(b)?;
1355            let c = backend.to_host_cm(c)?;
1356            let cc = backend.to_host_cm(cc)?;
1357            let expected = a * b.transpose() * alpha + cc * beta;
1358            c.iter().zip(expected.iter()).for_each(|(c, expected)| {
1359                assert_approx_eq::assert_approx_eq!(c.re, expected.re, EPS);
1360                assert_approx_eq::assert_approx_eq!(c.im, expected.im, EPS);
1361            });
1362        }
1363
1364        {
1365            let a = make_random_cv(&backend, N)?;
1366            let b = make_random_cv(&backend, N)?;
1367            let mut c = make_random_cm(&backend, N, N)?;
1368            let cc = backend.clone_cm(&c)?;
1369
1370            let alpha = Complex::new(rng.random(), rng.random());
1371            let beta = Complex::new(rng.random(), rng.random());
1372            backend.gevv_c(
1373                Trans::NoTrans,
1374                Trans::ConjTrans,
1375                alpha,
1376                &a,
1377                &b,
1378                beta,
1379                &mut c,
1380            )?;
1381
1382            let a = backend.to_host_cv(a)?;
1383            let b = backend.to_host_cv(b)?;
1384            let c = backend.to_host_cm(c)?;
1385            let cc = backend.to_host_cm(cc)?;
1386            let expected = a * b.adjoint() * alpha + cc * beta;
1387            c.iter().zip(expected.iter()).for_each(|(c, expected)| {
1388                assert_approx_eq::assert_approx_eq!(c.re, expected.re, EPS);
1389                assert_approx_eq::assert_approx_eq!(c.im, expected.im, EPS);
1390            });
1391        }
1392
1393        {
1394            let a = make_random_cv(&backend, N)?;
1395            let b = make_random_cv(&backend, N)?;
1396            let mut c = make_random_cm(&backend, 1, 1)?;
1397            let cc = backend.clone_cm(&c)?;
1398
1399            let alpha = Complex::new(rng.random(), rng.random());
1400            let beta = Complex::new(rng.random(), rng.random());
1401            backend.gevv_c(Trans::Trans, Trans::NoTrans, alpha, &a, &b, beta, &mut c)?;
1402
1403            let a = backend.to_host_cv(a)?;
1404            let b = backend.to_host_cv(b)?;
1405            let c = backend.to_host_cm(c)?;
1406            let cc = backend.to_host_cm(cc)?;
1407            let expected = a.transpose() * b * alpha + cc * beta;
1408            c.iter().zip(expected.iter()).for_each(|(c, expected)| {
1409                assert_approx_eq::assert_approx_eq!(c.re, expected.re, EPS);
1410                assert_approx_eq::assert_approx_eq!(c.im, expected.im, EPS);
1411            });
1412        }
1413
1414        {
1415            let a = make_random_cv(&backend, N)?;
1416            let b = make_random_cv(&backend, N)?;
1417            let mut c = make_random_cm(&backend, 1, 1)?;
1418            let cc = backend.clone_cm(&c)?;
1419
1420            let alpha = Complex::new(rng.random(), rng.random());
1421            let beta = Complex::new(rng.random(), rng.random());
1422            backend.gevv_c(
1423                Trans::ConjTrans,
1424                Trans::NoTrans,
1425                alpha,
1426                &a,
1427                &b,
1428                beta,
1429                &mut c,
1430            )?;
1431
1432            let a = backend.to_host_cv(a)?;
1433            let b = backend.to_host_cv(b)?;
1434            let c = backend.to_host_cm(c)?;
1435            let cc = backend.to_host_cm(cc)?;
1436            let expected = a.adjoint() * b * alpha + cc * beta;
1437            c.iter().zip(expected.iter()).for_each(|(c, expected)| {
1438                assert_approx_eq::assert_approx_eq!(c.re, expected.re, EPS);
1439                assert_approx_eq::assert_approx_eq!(c.im, expected.im, EPS);
1440            });
1441        }
1442
1443        Ok(())
1444    }
1445
1446    #[rstest::rstest]
1447    #[test]
1448    #[cfg_attr(miri, ignore)]
1449    fn test_gemv_c(backend: ArrayFireBackend<Sphere>) -> Result<(), HoloError> {
1450        let m = N;
1451        let n = 2 * N;
1452
1453        let mut rng = rand::rng();
1454
1455        {
1456            let a = make_random_cm(&backend, m, n)?;
1457            let b = make_random_cv(&backend, n)?;
1458            let mut c = make_random_cv(&backend, m)?;
1459            let cc = backend.clone_cv(&c)?;
1460
1461            let alpha = Complex::new(rng.random(), rng.random());
1462            let beta = Complex::new(rng.random(), rng.random());
1463            backend.gemv_c(Trans::NoTrans, alpha, &a, &b, beta, &mut c)?;
1464
1465            let a = backend.to_host_cm(a)?;
1466            let b = backend.to_host_cv(b)?;
1467            let c = backend.to_host_cv(c)?;
1468            let cc = backend.to_host_cv(cc)?;
1469            let expected = a * b * alpha + cc * beta;
1470            c.iter().zip(expected.iter()).for_each(|(c, expected)| {
1471                assert_approx_eq::assert_approx_eq!(c.re, expected.re, EPS);
1472                assert_approx_eq::assert_approx_eq!(c.im, expected.im, EPS);
1473            });
1474        }
1475
1476        {
1477            let a = make_random_cm(&backend, n, m)?;
1478            let b = make_random_cv(&backend, n)?;
1479            let mut c = make_random_cv(&backend, m)?;
1480            let cc = backend.clone_cv(&c)?;
1481
1482            let alpha = Complex::new(rng.random(), rng.random());
1483            let beta = Complex::new(rng.random(), rng.random());
1484            backend.gemv_c(Trans::Trans, alpha, &a, &b, beta, &mut c)?;
1485
1486            let a = backend.to_host_cm(a)?;
1487            let b = backend.to_host_cv(b)?;
1488            let c = backend.to_host_cv(c)?;
1489            let cc = backend.to_host_cv(cc)?;
1490            let expected = a.transpose() * b * alpha + cc * beta;
1491            c.iter().zip(expected.iter()).for_each(|(c, expected)| {
1492                assert_approx_eq::assert_approx_eq!(c.re, expected.re, EPS);
1493                assert_approx_eq::assert_approx_eq!(c.im, expected.im, EPS);
1494            });
1495        }
1496
1497        {
1498            let a = make_random_cm(&backend, n, m)?;
1499            let b = make_random_cv(&backend, n)?;
1500            let mut c = make_random_cv(&backend, m)?;
1501            let cc = backend.clone_cv(&c)?;
1502
1503            let alpha = Complex::new(rng.random(), rng.random());
1504            let beta = Complex::new(rng.random(), rng.random());
1505            backend.gemv_c(Trans::ConjTrans, alpha, &a, &b, beta, &mut c)?;
1506
1507            let a = backend.to_host_cm(a)?;
1508            let b = backend.to_host_cv(b)?;
1509            let c = backend.to_host_cv(c)?;
1510            let cc = backend.to_host_cv(cc)?;
1511            let expected = a.adjoint() * b * alpha + cc * beta;
1512            c.iter().zip(expected.iter()).for_each(|(c, expected)| {
1513                assert_approx_eq::assert_approx_eq!(c.re, expected.re, EPS);
1514                assert_approx_eq::assert_approx_eq!(c.im, expected.im, EPS);
1515            });
1516        }
1517        Ok(())
1518    }
1519
1520    #[rstest::rstest]
1521    #[test]
1522    #[cfg_attr(miri, ignore)]
1523    fn test_gemm_c(backend: ArrayFireBackend<Sphere>) -> Result<(), HoloError> {
1524        let m = N;
1525        let n = 2 * N;
1526        let k = 3 * N;
1527
1528        let mut rng = rand::rng();
1529
1530        {
1531            let a = make_random_cm(&backend, m, k)?;
1532            let b = make_random_cm(&backend, k, n)?;
1533            let mut c = make_random_cm(&backend, m, n)?;
1534            let cc = backend.clone_cm(&c)?;
1535
1536            let alpha = Complex::new(rng.random(), rng.random());
1537            let beta = Complex::new(rng.random(), rng.random());
1538            backend.gemm_c(Trans::NoTrans, Trans::NoTrans, alpha, &a, &b, beta, &mut c)?;
1539
1540            let a = backend.to_host_cm(a)?;
1541            let b = backend.to_host_cm(b)?;
1542            let c = backend.to_host_cm(c)?;
1543            let cc = backend.to_host_cm(cc)?;
1544            let expected = a * b * alpha + cc * beta;
1545            c.iter().zip(expected.iter()).for_each(|(c, expected)| {
1546                assert_approx_eq::assert_approx_eq!(c.re, expected.re, EPS);
1547                assert_approx_eq::assert_approx_eq!(c.im, expected.im, EPS);
1548            });
1549        }
1550
1551        {
1552            let a = make_random_cm(&backend, m, k)?;
1553            let b = make_random_cm(&backend, n, k)?;
1554            let mut c = make_random_cm(&backend, m, n)?;
1555            let cc = backend.clone_cm(&c)?;
1556
1557            let alpha = Complex::new(rng.random(), rng.random());
1558            let beta = Complex::new(rng.random(), rng.random());
1559            backend.gemm_c(Trans::NoTrans, Trans::Trans, alpha, &a, &b, beta, &mut c)?;
1560
1561            let a = backend.to_host_cm(a)?;
1562            let b = backend.to_host_cm(b)?;
1563            let c = backend.to_host_cm(c)?;
1564            let cc = backend.to_host_cm(cc)?;
1565            let expected = a * b.transpose() * alpha + cc * beta;
1566            c.iter().zip(expected.iter()).for_each(|(c, expected)| {
1567                assert_approx_eq::assert_approx_eq!(c.re, expected.re, EPS);
1568                assert_approx_eq::assert_approx_eq!(c.im, expected.im, EPS);
1569            });
1570        }
1571
1572        {
1573            let a = make_random_cm(&backend, m, k)?;
1574            let b = make_random_cm(&backend, n, k)?;
1575            let mut c = make_random_cm(&backend, m, n)?;
1576            let cc = backend.clone_cm(&c)?;
1577
1578            let alpha = Complex::new(rng.random(), rng.random());
1579            let beta = Complex::new(rng.random(), rng.random());
1580            backend.gemm_c(
1581                Trans::NoTrans,
1582                Trans::ConjTrans,
1583                alpha,
1584                &a,
1585                &b,
1586                beta,
1587                &mut c,
1588            )?;
1589
1590            let a = backend.to_host_cm(a)?;
1591            let b = backend.to_host_cm(b)?;
1592            let c = backend.to_host_cm(c)?;
1593            let cc = backend.to_host_cm(cc)?;
1594            let expected = a * b.adjoint() * alpha + cc * beta;
1595            c.iter().zip(expected.iter()).for_each(|(c, expected)| {
1596                assert_approx_eq::assert_approx_eq!(c.re, expected.re, EPS);
1597                assert_approx_eq::assert_approx_eq!(c.im, expected.im, EPS);
1598            });
1599        }
1600
1601        {
1602            let a = make_random_cm(&backend, k, m)?;
1603            let b = make_random_cm(&backend, k, n)?;
1604            let mut c = make_random_cm(&backend, m, n)?;
1605            let cc = backend.clone_cm(&c)?;
1606
1607            let alpha = Complex::new(rng.random(), rng.random());
1608            let beta = Complex::new(rng.random(), rng.random());
1609            backend.gemm_c(Trans::Trans, Trans::NoTrans, alpha, &a, &b, beta, &mut c)?;
1610
1611            let a = backend.to_host_cm(a)?;
1612            let b = backend.to_host_cm(b)?;
1613            let c = backend.to_host_cm(c)?;
1614            let cc = backend.to_host_cm(cc)?;
1615            let expected = a.transpose() * b * alpha + cc * beta;
1616            c.iter().zip(expected.iter()).for_each(|(c, expected)| {
1617                assert_approx_eq::assert_approx_eq!(c.re, expected.re, EPS);
1618                assert_approx_eq::assert_approx_eq!(c.im, expected.im, EPS);
1619            });
1620        }
1621
1622        {
1623            let a = make_random_cm(&backend, k, m)?;
1624            let b = make_random_cm(&backend, n, k)?;
1625            let mut c = make_random_cm(&backend, m, n)?;
1626            let cc = backend.clone_cm(&c)?;
1627
1628            let alpha = Complex::new(rng.random(), rng.random());
1629            let beta = Complex::new(rng.random(), rng.random());
1630            backend.gemm_c(Trans::Trans, Trans::Trans, alpha, &a, &b, beta, &mut c)?;
1631
1632            let a = backend.to_host_cm(a)?;
1633            let b = backend.to_host_cm(b)?;
1634            let c = backend.to_host_cm(c)?;
1635            let cc = backend.to_host_cm(cc)?;
1636            let expected = a.transpose() * b.transpose() * alpha + cc * beta;
1637            c.iter().zip(expected.iter()).for_each(|(c, expected)| {
1638                assert_approx_eq::assert_approx_eq!(c.re, expected.re, EPS);
1639                assert_approx_eq::assert_approx_eq!(c.im, expected.im, EPS);
1640            });
1641        }
1642
1643        {
1644            let a = make_random_cm(&backend, k, m)?;
1645            let b = make_random_cm(&backend, n, k)?;
1646            let mut c = make_random_cm(&backend, m, n)?;
1647            let cc = backend.clone_cm(&c)?;
1648
1649            let alpha = Complex::new(rng.random(), rng.random());
1650            let beta = Complex::new(rng.random(), rng.random());
1651            backend.gemm_c(Trans::Trans, Trans::ConjTrans, alpha, &a, &b, beta, &mut c)?;
1652
1653            let a = backend.to_host_cm(a)?;
1654            let b = backend.to_host_cm(b)?;
1655            let c = backend.to_host_cm(c)?;
1656            let cc = backend.to_host_cm(cc)?;
1657            let expected = a.transpose() * b.adjoint() * alpha + cc * beta;
1658            c.iter().zip(expected.iter()).for_each(|(c, expected)| {
1659                assert_approx_eq::assert_approx_eq!(c.re, expected.re, EPS);
1660                assert_approx_eq::assert_approx_eq!(c.im, expected.im, EPS);
1661            });
1662        }
1663
1664        {
1665            let a = make_random_cm(&backend, k, m)?;
1666            let b = make_random_cm(&backend, k, n)?;
1667            let mut c = make_random_cm(&backend, m, n)?;
1668            let cc = backend.clone_cm(&c)?;
1669
1670            let alpha = Complex::new(rng.random(), rng.random());
1671            let beta = Complex::new(rng.random(), rng.random());
1672            backend.gemm_c(
1673                Trans::ConjTrans,
1674                Trans::NoTrans,
1675                alpha,
1676                &a,
1677                &b,
1678                beta,
1679                &mut c,
1680            )?;
1681
1682            let a = backend.to_host_cm(a)?;
1683            let b = backend.to_host_cm(b)?;
1684            let c = backend.to_host_cm(c)?;
1685            let cc = backend.to_host_cm(cc)?;
1686            let expected = a.adjoint() * b * alpha + cc * beta;
1687            c.iter().zip(expected.iter()).for_each(|(c, expected)| {
1688                assert_approx_eq::assert_approx_eq!(c.re, expected.re, EPS);
1689                assert_approx_eq::assert_approx_eq!(c.im, expected.im, EPS);
1690            });
1691        }
1692
1693        {
1694            let a = make_random_cm(&backend, k, m)?;
1695            let b = make_random_cm(&backend, n, k)?;
1696            let mut c = make_random_cm(&backend, m, n)?;
1697            let cc = backend.clone_cm(&c)?;
1698
1699            let alpha = Complex::new(rng.random(), rng.random());
1700            let beta = Complex::new(rng.random(), rng.random());
1701            backend.gemm_c(Trans::ConjTrans, Trans::Trans, alpha, &a, &b, beta, &mut c)?;
1702
1703            let a = backend.to_host_cm(a)?;
1704            let b = backend.to_host_cm(b)?;
1705            let c = backend.to_host_cm(c)?;
1706            let cc = backend.to_host_cm(cc)?;
1707            let expected = a.adjoint() * b.transpose() * alpha + cc * beta;
1708            c.iter().zip(expected.iter()).for_each(|(c, expected)| {
1709                assert_approx_eq::assert_approx_eq!(c.re, expected.re, EPS);
1710                assert_approx_eq::assert_approx_eq!(c.im, expected.im, EPS);
1711            });
1712        }
1713
1714        {
1715            let a = make_random_cm(&backend, k, m)?;
1716            let b = make_random_cm(&backend, n, k)?;
1717            let mut c = make_random_cm(&backend, m, n)?;
1718            let cc = backend.clone_cm(&c)?;
1719
1720            let alpha = Complex::new(rng.random(), rng.random());
1721            let beta = Complex::new(rng.random(), rng.random());
1722            backend.gemm_c(
1723                Trans::ConjTrans,
1724                Trans::ConjTrans,
1725                alpha,
1726                &a,
1727                &b,
1728                beta,
1729                &mut c,
1730            )?;
1731
1732            let a = backend.to_host_cm(a)?;
1733            let b = backend.to_host_cm(b)?;
1734            let c = backend.to_host_cm(c)?;
1735            let cc = backend.to_host_cm(cc)?;
1736            let expected = a.adjoint() * b.adjoint() * alpha + cc * beta;
1737            c.iter().zip(expected.iter()).for_each(|(c, expected)| {
1738                assert_approx_eq::assert_approx_eq!(c.re, expected.re, EPS);
1739                assert_approx_eq::assert_approx_eq!(c.im, expected.im, EPS);
1740            });
1741        }
1742        Ok(())
1743    }
1744
1745    #[rstest::rstest]
1746    #[test]
1747    #[cfg_attr(miri, ignore)]
1748    fn test_solve_inplace(backend: ArrayFireBackend<Sphere>) -> Result<(), HoloError> {
1749        {
1750            let tmp = make_random_m(&backend, N, N)?;
1751            let tmp = backend.to_host_m(tmp)?;
1752
1753            let a = &tmp * tmp.adjoint();
1754
1755            let mut rng = rand::rng();
1756            let x = VectorX::from_iterator(N, (0..N).map(|_| rng.random()));
1757
1758            let b = &a * &x;
1759
1760            let aa = backend.from_slice_m(N, N, a.as_slice())?;
1761            let mut bb = backend.from_slice_v(b.as_slice())?;
1762
1763            backend.solve_inplace(&aa, &mut bb)?;
1764
1765            let b2 = &a * backend.to_host_v(bb)?;
1766            assert!(approx::relative_eq!(b, b2, epsilon = 1e-3));
1767        }
1768
1769        Ok(())
1770    }
1771
1772    #[rstest::rstest]
1773    #[test]
1774    #[cfg_attr(miri, ignore)]
1775    fn test_reduce_col(backend: ArrayFireBackend<Sphere>) -> Result<(), HoloError> {
1776        let a = make_random_m(&backend, N, N)?;
1777
1778        let mut b = backend.alloc_v(N)?;
1779
1780        backend.reduce_col(&a, &mut b)?;
1781
1782        let a = backend.to_host_m(a)?;
1783        let b = backend.to_host_v(b)?;
1784
1785        (0..N).for_each(|row| {
1786            let sum = a.row(row).iter().sum::<f32>();
1787            assert_approx_eq::assert_approx_eq!(sum, b[row], EPS);
1788        });
1789        Ok(())
1790    }
1791
1792    #[rstest::rstest]
1793    #[test]
1794    #[cfg_attr(miri, ignore)]
1795    fn test_scaled_to_cv(backend: ArrayFireBackend<Sphere>) -> Result<(), HoloError> {
1796        let a = make_random_cv(&backend, N)?;
1797        let b = make_random_cv(&backend, N)?;
1798        let mut c = backend.alloc_cv(N)?;
1799
1800        backend.scaled_to_cv(&a, &b, &mut c)?;
1801
1802        let a = backend.to_host_cv(a)?;
1803        let b = backend.to_host_cv(b)?;
1804        let c = backend.to_host_cv(c)?;
1805        c.iter()
1806            .zip(a.iter())
1807            .zip(b.iter())
1808            .for_each(|((&c, &a), &b)| {
1809                assert_approx_eq::assert_approx_eq!(c, a / a.abs() * b, EPS);
1810            });
1811
1812        Ok(())
1813    }
1814
1815    #[rstest::rstest]
1816    #[test]
1817    #[cfg_attr(miri, ignore)]
1818    fn test_scaled_to_assign_cv(backend: ArrayFireBackend<Sphere>) -> Result<(), HoloError> {
1819        let a = make_random_cv(&backend, N)?;
1820        let mut b = make_random_cv(&backend, N)?;
1821        let bc = backend.clone_cv(&b)?;
1822
1823        backend.scaled_to_assign_cv(&a, &mut b)?;
1824
1825        let a = backend.to_host_cv(a)?;
1826        let b = backend.to_host_cv(b)?;
1827        let bc = backend.to_host_cv(bc)?;
1828        b.iter()
1829            .zip(a.iter())
1830            .zip(bc.iter())
1831            .for_each(|((&b, &a), &bc)| {
1832                assert_approx_eq::assert_approx_eq!(b, bc / bc.abs() * a, EPS);
1833            });
1834
1835        Ok(())
1836    }
1837
1838    #[rstest::rstest]
1839    #[test]
1840    #[case(1, 2)]
1841    #[case(2, 1)]
1842    fn test_generate_propagation_matrix(
1843        #[case] dev_num: usize,
1844        #[case] foci_num: usize,
1845        backend: ArrayFireBackend<Sphere>,
1846    ) -> Result<(), HoloError> {
1847        let env = Environment::new();
1848
1849        let reference = |geometry: Geometry, foci: Vec<Point3>| {
1850            let mut g = MatrixXc::zeros(
1851                foci.len(),
1852                geometry
1853                    .iter()
1854                    .map(|dev| dev.num_transducers())
1855                    .sum::<usize>(),
1856            );
1857            let transducers = geometry
1858                .iter()
1859                .flat_map(|dev| dev.iter().map(|tr| (dev.idx(), tr)))
1860                .collect::<Vec<_>>();
1861            (0..foci.len()).for_each(|i| {
1862                (0..transducers.len()).for_each(|j| {
1863                    g[(i, j)] = propagate::<Sphere>(
1864                        transducers[j].1,
1865                        env.wavenumber(),
1866                        geometry[transducers[j].0].axial_direction(),
1867                        &foci[i],
1868                    )
1869                })
1870            });
1871            g
1872        };
1873
1874        let geometry = generate_geometry(dev_num);
1875        let foci = gen_foci(foci_num).map(|(p, _)| p).collect::<Vec<_>>();
1876
1877        let g = backend.generate_propagation_matrix(
1878            &geometry,
1879            &env,
1880            &foci,
1881            &TransducerFilter::all_enabled(),
1882        )?;
1883        let g = backend.to_host_cm(g)?;
1884        reference(geometry, foci)
1885            .iter()
1886            .zip(g.iter())
1887            .for_each(|(r, g)| {
1888                assert_approx_eq::assert_approx_eq!(r.re, g.re, EPS);
1889                assert_approx_eq::assert_approx_eq!(r.im, g.im, EPS);
1890            });
1891
1892        Ok(())
1893    }
1894
1895    #[rstest::rstest]
1896    #[test]
1897    #[case(1, 2)]
1898    #[case(2, 1)]
1899    fn test_generate_propagation_matrix_with_filter(
1900        #[case] dev_num: usize,
1901        #[case] foci_num: usize,
1902        backend: ArrayFireBackend<Sphere>,
1903    ) -> Result<(), HoloError> {
1904        let env = Environment::new();
1905
1906        let filter = |geometry: &Geometry| -> TransducerFilter {
1907            TransducerFilter::from_fn(geometry, |dev| {
1908                let num_transducers = dev.num_transducers();
1909                Some(move |tr: &Transducer| tr.idx() > num_transducers / 2)
1910            })
1911        };
1912
1913        let reference = |geometry, foci: Vec<Point3>| {
1914            let filter = filter(&geometry);
1915            let transducers = geometry
1916                .iter()
1917                .flat_map(|dev| {
1918                    dev.iter().filter_map(|tr| {
1919                        if filter.is_enabled(tr) {
1920                            Some((dev.idx(), tr))
1921                        } else {
1922                            None
1923                        }
1924                    })
1925                })
1926                .collect::<Vec<_>>();
1927
1928            let mut g = MatrixXc::zeros(foci.len(), transducers.len());
1929            (0..foci.len()).for_each(|i| {
1930                (0..transducers.len()).for_each(|j| {
1931                    g[(i, j)] = propagate::<Sphere>(
1932                        transducers[j].1,
1933                        env.wavenumber(),
1934                        geometry[transducers[j].0].axial_direction(),
1935                        &foci[i],
1936                    )
1937                })
1938            });
1939            g
1940        };
1941
1942        let geometry = generate_geometry(dev_num);
1943        let foci = gen_foci(foci_num).map(|(p, _)| p).collect::<Vec<_>>();
1944        let filter = filter(&geometry);
1945
1946        let g = backend.generate_propagation_matrix(&geometry, &env, &foci, &filter)?;
1947        let g = backend.to_host_cm(g)?;
1948        assert_eq!(g.nrows(), foci.len());
1949        assert_eq!(
1950            g.ncols(),
1951            geometry
1952                .iter()
1953                .map(|dev| dev.num_transducers() / 2)
1954                .sum::<usize>()
1955        );
1956        reference(geometry, foci)
1957            .iter()
1958            .zip(g.iter())
1959            .for_each(|(r, g)| {
1960                assert_approx_eq::assert_approx_eq!(r.re, g.re, EPS);
1961                assert_approx_eq::assert_approx_eq!(r.im, g.im, EPS);
1962            });
1963
1964        Ok(())
1965    }
1966
1967    #[rstest::rstest]
1968    #[test]
1969    fn test_gen_back_prop(backend: ArrayFireBackend<Sphere>) -> Result<(), HoloError> {
1970        let env = Environment::new();
1971
1972        let geometry = generate_geometry(1);
1973        let foci = gen_foci(2).map(|(p, _)| p).collect::<Vec<_>>();
1974
1975        let m = geometry
1976            .iter()
1977            .map(|dev| dev.num_transducers())
1978            .sum::<usize>();
1979        let n = foci.len();
1980
1981        let g = backend.generate_propagation_matrix(
1982            &geometry,
1983            &env,
1984            &foci,
1985            &TransducerFilter::all_enabled(),
1986        )?;
1987
1988        let b = backend.gen_back_prop(m, n, &g)?;
1989        let g = backend.to_host_cm(g)?;
1990        let reference = {
1991            let mut b = MatrixXc::zeros(m, n);
1992            (0..n).for_each(|i| {
1993                let x = 1.0 / g.rows(i, 1).iter().map(|x| x.norm_sqr()).sum::<f32>();
1994                (0..m).for_each(|j| {
1995                    b[(j, i)] = g[(i, j)].conj() * x;
1996                })
1997            });
1998            b
1999        };
2000
2001        let b = backend.to_host_cm(b)?;
2002        reference.iter().zip(b.iter()).for_each(|(r, b)| {
2003            assert_approx_eq::assert_approx_eq!(r.re, b.re, EPS);
2004            assert_approx_eq::assert_approx_eq!(r.im, b.im, EPS);
2005        });
2006        Ok(())
2007    }
2008}