1#![allow(unknown_lints)]
2#![allow(clippy::manual_slice_size_calculation)]
3
4use arrayfire::*;
5
6use autd3_core::{
7 acoustics::{
8 directivity::{Directivity, Sphere},
9 propagate,
10 },
11 environment::Environment,
12 gain::TransducerFilter,
13 geometry::Geometry,
14};
15use autd3_gain_holo::{
16 Complex, HoloError, LinAlgBackend, MatrixX, MatrixXc, Trans, VectorX, VectorXc,
17};
18
19pub type AFBackend = arrayfire::Backend;
/// Four-string tuple describing one device, exactly as returned by `arrayfire::device_info()`.
pub type AFDeviceInfo = (String, String, String, String);
21
22fn convert(trans: Trans) -> MatProp {
23 match trans {
24 Trans::NoTrans => MatProp::NONE,
25 Trans::Trans => MatProp::TRANS,
26 Trans::ConjTrans => MatProp::CTRANS,
27 }
28}
29
30pub struct ArrayFireBackend<D: Directivity> {
31 _phantom: std::marker::PhantomData<D>,
32}
33
34impl ArrayFireBackend<Sphere> {
35 pub fn get_available_backends() -> Vec<AFBackend> {
36 arrayfire::get_available_backends()
37 }
38
39 pub fn set_backend(backend: AFBackend) {
40 arrayfire::set_backend(backend);
41 }
42
43 pub fn set_device(device: i32) {
44 arrayfire::set_device(device);
45 }
46
47 pub fn get_available_devices() -> Vec<AFDeviceInfo> {
48 let cur_dev = arrayfire::get_device();
49 let r = (0..arrayfire::device_count())
50 .map(|i| {
51 arrayfire::set_device(i);
52 arrayfire::device_info()
53 })
54 .collect();
55 arrayfire::set_device(cur_dev);
56 r
57 }
58}
59
60impl Default for ArrayFireBackend<Sphere> {
61 fn default() -> Self {
62 Self {
63 _phantom: Default::default(),
64 }
65 }
66}
67
68impl<D: Directivity> ArrayFireBackend<D> {
69 pub fn new() -> Self {
70 Self {
71 _phantom: std::marker::PhantomData,
72 }
73 }
74}
75
76impl<D: Directivity> LinAlgBackend<D> for ArrayFireBackend<D> {
77 type MatrixXc = Array<c32>;
78 type MatrixX = Array<f32>;
79 type VectorXc = Array<c32>;
80 type VectorX = Array<f32>;
81
82 fn generate_propagation_matrix(
83 &self,
84 geometry: &Geometry,
85 env: &Environment,
86 foci: &[autd3_core::geometry::Point3],
87 filter: &TransducerFilter,
88 ) -> Result<Self::MatrixXc, HoloError> {
89 let g = if filter.is_all_enabled() {
90 geometry
91 .iter()
92 .flat_map(|dev| {
93 dev.iter().flat_map(move |tr| {
94 foci.iter().map(move |fp| {
95 propagate::<D>(tr, env.wavenumber(), dev.axial_direction(), fp)
96 })
97 })
98 })
99 .collect::<Vec<_>>()
100 } else {
101 geometry
102 .iter()
103 .filter(|dev| filter.is_enabled_device(dev))
104 .flat_map(|dev| {
105 dev.iter()
106 .filter(|tr| filter.is_enabled(tr))
107 .map(move |tr| {
108 foci.iter().map(move |fp| {
109 propagate::<D>(tr, env.wavenumber(), dev.axial_direction(), fp)
110 })
111 })
112 })
113 .flatten()
114 .collect::<Vec<_>>()
115 };
116
117 unsafe {
118 Ok(Array::new(
119 std::slice::from_raw_parts(g.as_ptr() as *const c32, g.len()),
120 Dim4::new(&[foci.len() as u64, (g.len() / foci.len()) as _, 1, 1]),
121 ))
122 }
123 }
124
125 fn alloc_v(&self, size: usize) -> Result<Self::VectorX, HoloError> {
126 Ok(Array::new_empty(Dim4::new(&[size as _, 1, 1, 1])))
127 }
128
129 fn alloc_m(&self, rows: usize, cols: usize) -> Result<Self::MatrixX, HoloError> {
130 Ok(Array::new_empty(Dim4::new(&[rows as _, cols as _, 1, 1])))
131 }
132
133 fn alloc_cv(&self, size: usize) -> Result<Self::VectorXc, HoloError> {
134 Ok(Array::new_empty(Dim4::new(&[size as _, 1, 1, 1])))
135 }
136
137 fn alloc_cm(&self, rows: usize, cols: usize) -> Result<Self::MatrixXc, HoloError> {
138 Ok(Array::new_empty(Dim4::new(&[rows as _, cols as _, 1, 1])))
139 }
140
141 fn alloc_zeros_v(&self, size: usize) -> Result<Self::VectorX, HoloError> {
142 Ok(arrayfire::constant(0., Dim4::new(&[size as _, 1, 1, 1])))
143 }
144
145 fn alloc_zeros_cv(&self, size: usize) -> Result<Self::VectorXc, HoloError> {
146 Ok(arrayfire::constant(
147 c32::new(0., 0.),
148 Dim4::new(&[size as _, 1, 1, 1]),
149 ))
150 }
151
152 fn alloc_zeros_cm(&self, rows: usize, cols: usize) -> Result<Self::MatrixXc, HoloError> {
153 Ok(arrayfire::constant(
154 c32::new(0., 0.),
155 Dim4::new(&[rows as _, cols as _, 1, 1]),
156 ))
157 }
158
159 fn to_host_v(&self, v: Self::VectorX) -> Result<VectorX, HoloError> {
160 let mut r = VectorX::zeros(v.elements());
161 v.host(r.as_mut_slice());
162 Ok(r)
163 }
164
165 fn to_host_m(&self, v: Self::MatrixX) -> Result<MatrixX, HoloError> {
166 let mut r = MatrixX::zeros(v.dims()[0] as _, v.dims()[1] as _);
167 v.host(r.as_mut_slice());
168 Ok(r)
169 }
170
171 fn to_host_cv(&self, v: Self::VectorXc) -> Result<VectorXc, HoloError> {
172 let n = v.elements();
173 let mut r = VectorXc::zeros(n);
174 unsafe {
175 v.host(std::slice::from_raw_parts_mut(
176 r.as_mut_ptr() as *mut c32,
177 n,
178 ));
179 }
180 Ok(r)
181 }
182
183 fn to_host_cm(&self, v: Self::MatrixXc) -> Result<MatrixXc, HoloError> {
184 let n = v.elements();
185 let mut r = MatrixXc::zeros(v.dims()[0] as _, v.dims()[1] as _);
186 unsafe {
187 v.host(std::slice::from_raw_parts_mut(
188 r.as_mut_ptr() as *mut c32,
189 n,
190 ));
191 }
192 Ok(r)
193 }
194
195 fn from_slice_v(&self, v: &[f32]) -> Result<Self::VectorX, HoloError> {
196 Ok(Array::new(v, Dim4::new(&[v.len() as _, 1, 1, 1])))
197 }
198
199 fn from_slice_m(
200 &self,
201 rows: usize,
202 cols: usize,
203 v: &[f32],
204 ) -> Result<Self::MatrixX, HoloError> {
205 Ok(Array::new(v, Dim4::new(&[rows as _, cols as _, 1, 1])))
206 }
207
208 fn from_slice_cv(&self, v: &[f32]) -> Result<Self::VectorXc, HoloError> {
209 let r = Array::new(v, Dim4::new(&[v.len() as _, 1, 1, 1]));
210 Ok(arrayfire::cplx(&r))
211 }
212
213 fn from_slice2_cv(&self, r: &[f32], i: &[f32]) -> Result<Self::VectorXc, HoloError> {
214 let r = Array::new(r, Dim4::new(&[r.len() as _, 1, 1, 1]));
215 let i = Array::new(i, Dim4::new(&[i.len() as _, 1, 1, 1]));
216 Ok(arrayfire::cplx2(&r, &i, false).cast())
217 }
218
219 fn from_slice2_cm(
220 &self,
221 rows: usize,
222 cols: usize,
223 r: &[f32],
224 i: &[f32],
225 ) -> Result<Self::MatrixXc, HoloError> {
226 let r = Array::new(r, Dim4::new(&[rows as _, cols as _, 1, 1]));
227 let i = Array::new(i, Dim4::new(&[rows as _, cols as _, 1, 1]));
228 Ok(arrayfire::cplx2(&r, &i, false).cast())
229 }
230
231 fn copy_from_slice_v(&self, v: &[f32], dst: &mut Self::VectorX) -> Result<(), HoloError> {
232 let n = v.len();
233 if n == 0 {
234 return Ok(());
235 }
236 let v = self.from_slice_v(v)?;
237 let seqs = [Seq::new(0u32, n as u32 - 1, 1)];
238 arrayfire::assign_seq(dst, &seqs, &v);
239 Ok(())
240 }
241
242 fn copy_to_v(&self, src: &Self::VectorX, dst: &mut Self::VectorX) -> Result<(), HoloError> {
243 let seqs = [Seq::new(0u32, src.elements() as u32 - 1, 1)];
244 arrayfire::assign_seq(dst, &seqs, src);
245 Ok(())
246 }
247
248 fn copy_to_m(&self, src: &Self::MatrixX, dst: &mut Self::MatrixX) -> Result<(), HoloError> {
249 let seqs = [
250 Seq::new(0u32, src.dims()[0] as u32 - 1, 1),
251 Seq::new(0u32, src.dims()[1] as u32 - 1, 1),
252 ];
253 arrayfire::assign_seq(dst, &seqs, src);
254 Ok(())
255 }
256
257 fn clone_v(&self, v: &Self::VectorX) -> Result<Self::VectorX, HoloError> {
258 Ok(v.copy())
259 }
260
261 fn clone_m(&self, v: &Self::MatrixX) -> Result<Self::MatrixX, HoloError> {
262 Ok(v.copy())
263 }
264
265 fn clone_cv(&self, v: &Self::VectorXc) -> Result<Self::VectorXc, HoloError> {
266 Ok(v.copy())
267 }
268
269 fn clone_cm(&self, v: &Self::MatrixXc) -> Result<Self::MatrixXc, HoloError> {
270 Ok(v.copy())
271 }
272
273 fn make_complex2_v(
274 &self,
275 real: &Self::VectorX,
276 imag: &Self::VectorX,
277 v: &mut Self::VectorXc,
278 ) -> Result<(), HoloError> {
279 *v = arrayfire::cplx2(real, imag, false).cast();
280 Ok(())
281 }
282
283 fn create_diagonal(&self, v: &Self::VectorX, a: &mut Self::MatrixX) -> Result<(), HoloError> {
284 *a = arrayfire::diag_create(v, 0);
285 Ok(())
286 }
287
288 fn create_diagonal_c(
289 &self,
290 v: &Self::VectorXc,
291 a: &mut Self::MatrixXc,
292 ) -> Result<(), HoloError> {
293 *a = arrayfire::diag_create(v, 0);
294 Ok(())
295 }
296
297 fn get_diagonal(&self, a: &Self::MatrixX, v: &mut Self::VectorX) -> Result<(), HoloError> {
298 *v = arrayfire::diag_extract(a, 0);
299 Ok(())
300 }
301
302 fn real_cm(&self, a: &Self::MatrixXc, b: &mut Self::MatrixX) -> Result<(), HoloError> {
303 *b = arrayfire::real(a);
304 Ok(())
305 }
306
307 fn imag_cm(&self, a: &Self::MatrixXc, b: &mut Self::MatrixX) -> Result<(), HoloError> {
308 *b = arrayfire::imag(a);
309 Ok(())
310 }
311
312 fn scale_assign_cv(
313 &self,
314 a: autd3_gain_holo::Complex,
315 b: &mut Self::VectorXc,
316 ) -> Result<(), HoloError> {
317 let a = c32::new(a.re, a.im);
318 *b = arrayfire::mul(b, &a, false);
319 Ok(())
320 }
321
322 fn conj_assign_v(&self, b: &mut Self::VectorXc) -> Result<(), HoloError> {
323 *b = arrayfire::conjg(b);
324 Ok(())
325 }
326
327 fn exp_assign_cv(&self, v: &mut Self::VectorXc) -> Result<(), HoloError> {
328 *v = arrayfire::exp(v);
329 Ok(())
330 }
331
332 fn concat_col_cm(
333 &self,
334 a: &Self::MatrixXc,
335 b: &Self::MatrixXc,
336 c: &mut Self::MatrixXc,
337 ) -> Result<(), HoloError> {
338 *c = arrayfire::join(1, a, b);
339 Ok(())
340 }
341
342 fn max_v(&self, m: &Self::VectorX) -> Result<f32, HoloError> {
343 Ok(arrayfire::max_all(m).0)
344 }
345
346 fn hadamard_product_cm(
347 &self,
348 x: &Self::MatrixXc,
349 y: &Self::MatrixXc,
350 z: &mut Self::MatrixXc,
351 ) -> Result<(), HoloError> {
352 *z = arrayfire::mul(x, y, false);
353 Ok(())
354 }
355
356 fn dot(&self, x: &Self::VectorX, y: &Self::VectorX) -> Result<f32, HoloError> {
357 let r = arrayfire::dot(x, y, MatProp::NONE, MatProp::NONE);
358 let mut v = [0.];
359 r.host(&mut v);
360 Ok(v[0])
361 }
362
363 fn dot_c(
364 &self,
365 x: &Self::VectorXc,
366 y: &Self::VectorXc,
367 ) -> Result<autd3_gain_holo::Complex, HoloError> {
368 let r = arrayfire::dot(x, y, MatProp::CONJ, MatProp::NONE);
369 let mut v = [c32::new(0., 0.)];
370 r.host(&mut v);
371 Ok(autd3_gain_holo::Complex::new(v[0].re, v[0].im))
372 }
373
374 fn add_v(&self, alpha: f32, a: &Self::VectorX, b: &mut Self::VectorX) -> Result<(), HoloError> {
375 *b = arrayfire::add(&arrayfire::mul(a, &alpha, false), b, false);
376 Ok(())
377 }
378
379 fn add_m(&self, alpha: f32, a: &Self::MatrixX, b: &mut Self::MatrixX) -> Result<(), HoloError> {
380 *b = arrayfire::add(&arrayfire::mul(a, &alpha, false), b, false);
381 Ok(())
382 }
383
384 fn gevv_c(
385 &self,
386 trans_a: autd3_gain_holo::Trans,
387 trans_b: autd3_gain_holo::Trans,
388 alpha: autd3_gain_holo::Complex,
389 a: &Self::VectorXc,
390 x: &Self::VectorXc,
391 beta: autd3_gain_holo::Complex,
392 y: &mut Self::MatrixXc,
393 ) -> Result<(), HoloError> {
394 let alpha = vec![c32::new(alpha.re, alpha.im)];
395 let beta = vec![c32::new(beta.re, beta.im)];
396 let trans_a = convert(trans_a);
397 let trans_b = convert(trans_b);
398 arrayfire::gemm(y, trans_a, trans_b, alpha, a, x, beta);
399 Ok(())
400 }
401
402 fn gemv_c(
403 &self,
404 trans: autd3_gain_holo::Trans,
405 alpha: autd3_gain_holo::Complex,
406 a: &Self::MatrixXc,
407 x: &Self::VectorXc,
408 beta: autd3_gain_holo::Complex,
409 y: &mut Self::VectorXc,
410 ) -> Result<(), HoloError> {
411 let alpha = vec![c32::new(alpha.re, alpha.im)];
412 let beta = vec![c32::new(beta.re, beta.im)];
413 let trans = convert(trans);
414 arrayfire::gemm(y, trans, MatProp::NONE, alpha, a, x, beta);
415 Ok(())
416 }
417
418 fn gemm_c(
419 &self,
420 trans_a: autd3_gain_holo::Trans,
421 trans_b: autd3_gain_holo::Trans,
422 alpha: autd3_gain_holo::Complex,
423 a: &Self::MatrixXc,
424 b: &Self::MatrixXc,
425 beta: autd3_gain_holo::Complex,
426 y: &mut Self::MatrixXc,
427 ) -> Result<(), HoloError> {
428 let alpha = vec![c32::new(alpha.re, alpha.im)];
429 let beta = vec![c32::new(beta.re, beta.im)];
430 let trans_a = convert(trans_a);
431 let trans_b = convert(trans_b);
432 arrayfire::gemm(y, trans_a, trans_b, alpha, a, b, beta);
433 Ok(())
434 }
435
436 fn solve_inplace(&self, a: &Self::MatrixX, x: &mut Self::VectorX) -> Result<(), HoloError> {
437 *x = arrayfire::solve(a, x, MatProp::NONE);
438 Ok(())
439 }
440
441 fn reduce_col(&self, a: &Self::MatrixX, b: &mut Self::VectorX) -> Result<(), HoloError> {
442 *b = arrayfire::sum(a, 1);
443 Ok(())
444 }
445
446 fn cols_c(&self, m: &Self::MatrixXc) -> Result<usize, HoloError> {
447 Ok(m.dims()[1] as _)
448 }
449
450 fn scaled_to_cv(
451 &self,
452 a: &Self::VectorXc,
453 b: &Self::VectorXc,
454 c: &mut Self::VectorXc,
455 ) -> Result<(), HoloError> {
456 let tmp = arrayfire::div(a, &arrayfire::abs(a), false);
457 *c = arrayfire::mul(&tmp, b, false);
458 Ok(())
459 }
460
461 fn scaled_to_assign_cv(
462 &self,
463 a: &Self::VectorXc,
464 b: &mut Self::VectorXc,
465 ) -> Result<(), HoloError> {
466 *b = arrayfire::div(b, &arrayfire::abs(b), false);
467 *b = arrayfire::mul(a, b, false);
468 Ok(())
469 }
470
471 fn gen_back_prop(
472 &self,
473 m: usize,
474 n: usize,
475 transfer: &Self::MatrixXc,
476 ) -> Result<Self::MatrixXc, HoloError> {
477 let mut b = self.alloc_zeros_cm(m, n)?;
478
479 let mut tmp = self.alloc_zeros_cm(n, n)?;
480
481 self.gemm_c(
482 Trans::NoTrans,
483 Trans::ConjTrans,
484 Complex::new(1., 0.),
485 transfer,
486 transfer,
487 Complex::new(0., 0.),
488 &mut tmp,
489 )?;
490
491 let mut denominator = arrayfire::diag_extract(&tmp, 0);
492 let a = c32::new(1., 0.);
493 denominator = arrayfire::div(&a, &denominator, false);
494
495 self.create_diagonal_c(&denominator, &mut tmp)?;
496
497 self.gemm_c(
498 Trans::ConjTrans,
499 Trans::NoTrans,
500 Complex::new(1., 0.),
501 transfer,
502 &tmp,
503 Complex::new(0., 0.),
504 &mut b,
505 )?;
506
507 Ok(b)
508 }
509
510 fn norm_squared_cv(&self, a: &Self::VectorXc, b: &mut Self::VectorX) -> Result<(), HoloError> {
511 *b = arrayfire::abs(a);
512 *b = arrayfire::mul(b, b, false);
513 Ok(())
514 }
515}
516#[cfg(test)]
517mod tests {
518 use std::f32::consts::PI;
519
520 use autd3::driver::autd3_device::AUTD3;
521 use autd3_core::{
522 acoustics::directivity::Sphere,
523 geometry::{Point3, Transducer, UnitQuaternion},
524 };
525
526 use nalgebra::{ComplexField, Normed};
527
528 use autd3_gain_holo::{Amplitude, Pa, Trans};
529
530 use super::*;
531
532 use rand::Rng;
533
534 const N: usize = 10;
535 const EPS: f32 = 1e-3;
536
537 fn generate_geometry(size: usize) -> Geometry {
538 Geometry::new(
539 (0..size)
540 .flat_map(|i| {
541 (0..size).map(move |j| {
542 AUTD3 {
543 pos: Point3::new(
544 i as f32 * AUTD3::DEVICE_WIDTH,
545 j as f32 * AUTD3::DEVICE_HEIGHT,
546 0.,
547 ),
548 rot: UnitQuaternion::identity(),
549 }
550 .into()
551 })
552 })
553 .collect(),
554 )
555 }
556
557 fn gen_foci(n: usize) -> impl Iterator<Item = (Point3, Amplitude)> {
558 (0..n).map(move |i| {
559 (
560 Point3::new(
561 90. + 10. * (2.0 * PI * i as f32 / n as f32).cos(),
562 70. + 10. * (2.0 * PI * i as f32 / n as f32).sin(),
563 150.,
564 ),
565 10e3 * Pa,
566 )
567 })
568 }
569
570 fn make_random_v(
571 backend: &ArrayFireBackend<Sphere>,
572 size: usize,
573 ) -> Result<<ArrayFireBackend<Sphere> as LinAlgBackend<Sphere>>::VectorX, HoloError> {
574 let mut rng = rand::rng();
575 let v: Vec<f32> = (&mut rng)
576 .sample_iter(rand::distr::StandardUniform)
577 .take(size)
578 .collect();
579 backend.from_slice_v(&v)
580 }
581
582 fn make_random_m(
583 backend: &ArrayFireBackend<Sphere>,
584 rows: usize,
585 cols: usize,
586 ) -> Result<<ArrayFireBackend<Sphere> as LinAlgBackend<Sphere>>::MatrixX, HoloError> {
587 let mut rng = rand::rng();
588 let v: Vec<f32> = (&mut rng)
589 .sample_iter(rand::distr::StandardUniform)
590 .take(rows * cols)
591 .collect();
592 backend.from_slice_m(rows, cols, &v)
593 }
594
595 fn make_random_cv(
596 backend: &ArrayFireBackend<Sphere>,
597 size: usize,
598 ) -> Result<<ArrayFireBackend<Sphere> as LinAlgBackend<Sphere>>::VectorXc, HoloError> {
599 let mut rng = rand::rng();
600 let real: Vec<f32> = (&mut rng)
601 .sample_iter(rand::distr::StandardUniform)
602 .take(size)
603 .collect();
604 let imag: Vec<f32> = (&mut rng)
605 .sample_iter(rand::distr::StandardUniform)
606 .take(size)
607 .collect();
608 backend.from_slice2_cv(&real, &imag)
609 }
610
611 fn make_random_cm(
612 backend: &ArrayFireBackend<Sphere>,
613 rows: usize,
614 cols: usize,
615 ) -> Result<<ArrayFireBackend<Sphere> as LinAlgBackend<Sphere>>::MatrixXc, HoloError> {
616 let mut rng = rand::rng();
617 let real: Vec<f32> = (&mut rng)
618 .sample_iter(rand::distr::StandardUniform)
619 .take(rows * cols)
620 .collect();
621 let imag: Vec<f32> = (&mut rng)
622 .sample_iter(rand::distr::StandardUniform)
623 .take(rows * cols)
624 .collect();
625 backend.from_slice2_cm(rows, cols, &real, &imag)
626 }
627
628 #[rstest::fixture]
629 fn backend() -> ArrayFireBackend<Sphere> {
630 ArrayFireBackend::set_backend(AFBackend::CPU);
631 ArrayFireBackend {
632 _phantom: std::marker::PhantomData,
633 }
634 }
635
636 #[rstest::rstest]
637 #[test]
638 #[cfg_attr(miri, ignore)]
639 fn test_alloc_v(backend: ArrayFireBackend<Sphere>) -> Result<(), HoloError> {
640 let v = backend.alloc_v(N)?;
641 let v = backend.to_host_v(v)?;
642
643 assert_eq!(N, v.len());
644 Ok(())
645 }
646
647 #[rstest::rstest]
648 #[test]
649 #[cfg_attr(miri, ignore)]
650 fn test_alloc_m(backend: ArrayFireBackend<Sphere>) -> Result<(), HoloError> {
651 let m = backend.alloc_m(N, 2 * N)?;
652 let m = backend.to_host_m(m)?;
653
654 assert_eq!(N, m.nrows());
655 assert_eq!(2 * N, m.ncols());
656 Ok(())
657 }
658
659 #[rstest::rstest]
660 #[test]
661 #[cfg_attr(miri, ignore)]
662 fn test_alloc_cv(backend: ArrayFireBackend<Sphere>) -> Result<(), HoloError> {
663 let v = backend.alloc_cv(N)?;
664 let v = backend.to_host_cv(v)?;
665
666 assert_eq!(N, v.len());
667 Ok(())
668 }
669
670 #[rstest::rstest]
671 #[test]
672 #[cfg_attr(miri, ignore)]
673 fn test_alloc_cm(backend: ArrayFireBackend<Sphere>) -> Result<(), HoloError> {
674 let m = backend.alloc_cm(N, 2 * N)?;
675 let m = backend.to_host_cm(m)?;
676
677 assert_eq!(N, m.nrows());
678 assert_eq!(2 * N, m.ncols());
679 Ok(())
680 }
681
682 #[rstest::rstest]
683 #[test]
684 #[cfg_attr(miri, ignore)]
685 fn test_alloc_zeros_v(backend: ArrayFireBackend<Sphere>) -> Result<(), HoloError> {
686 let v = backend.alloc_zeros_v(N)?;
687 let v = backend.to_host_v(v)?;
688
689 assert_eq!(N, v.len());
690 assert!(v.iter().all(|&v| v == 0.));
691 Ok(())
692 }
693
694 #[rstest::rstest]
695 #[test]
696 #[cfg_attr(miri, ignore)]
697 fn test_alloc_zeros_cv(backend: ArrayFireBackend<Sphere>) -> Result<(), HoloError> {
698 let v = backend.alloc_zeros_cv(N)?;
699 let v = backend.to_host_cv(v)?;
700
701 assert_eq!(N, v.len());
702 assert!(v.iter().all(|&v| v == Complex::new(0., 0.)));
703 Ok(())
704 }
705
706 #[rstest::rstest]
707 #[test]
708 #[cfg_attr(miri, ignore)]
709 fn test_alloc_zeros_cm(backend: ArrayFireBackend<Sphere>) -> Result<(), HoloError> {
710 let m = backend.alloc_zeros_cm(N, 2 * N)?;
711 let m = backend.to_host_cm(m)?;
712
713 assert_eq!(N, m.nrows());
714 assert_eq!(2 * N, m.ncols());
715 assert!(m.iter().all(|&v| v == Complex::new(0., 0.)));
716 Ok(())
717 }
718
719 #[rstest::rstest]
720 #[test]
721 #[cfg_attr(miri, ignore)]
722 fn test_cols_c(backend: ArrayFireBackend<Sphere>) -> Result<(), HoloError> {
723 let m = backend.alloc_cm(N, 2 * N)?;
724
725 assert_eq!(2 * N, backend.cols_c(&m)?);
726
727 Ok(())
728 }
729
730 #[rstest::rstest]
731 #[test]
732 #[cfg_attr(miri, ignore)]
733 fn test_from_slice_v(backend: ArrayFireBackend<Sphere>) -> Result<(), HoloError> {
734 let rng = rand::rng();
735
736 let v: Vec<f32> = rng
737 .sample_iter(rand::distr::StandardUniform)
738 .take(N)
739 .collect();
740
741 let c = backend.from_slice_v(&v)?;
742 let c = backend.to_host_v(c)?;
743
744 assert_eq!(N, c.len());
745 v.iter().zip(c.iter()).for_each(|(&r, &c)| {
746 assert_eq!(r, c);
747 });
748 Ok(())
749 }
750
751 #[rstest::rstest]
752 #[test]
753 #[cfg_attr(miri, ignore)]
754 fn test_from_slice_m(backend: ArrayFireBackend<Sphere>) -> Result<(), HoloError> {
755 let rng = rand::rng();
756
757 let v: Vec<f32> = rng
758 .sample_iter(rand::distr::StandardUniform)
759 .take(N * 2 * N)
760 .collect();
761
762 let c = backend.from_slice_m(N, 2 * N, &v)?;
763 let c = backend.to_host_m(c)?;
764
765 assert_eq!(N, c.nrows());
766 assert_eq!(2 * N, c.ncols());
767 (0..2 * N).for_each(|col| {
768 (0..N).for_each(|row| {
769 assert_eq!(v[col * N + row], c[(row, col)]);
770 })
771 });
772 Ok(())
773 }
774
775 #[rstest::rstest]
776 #[test]
777 #[cfg_attr(miri, ignore)]
778 fn test_from_slice_cv(backend: ArrayFireBackend<Sphere>) -> Result<(), HoloError> {
779 let rng = rand::rng();
780
781 let real: Vec<f32> = rng
782 .sample_iter(rand::distr::StandardUniform)
783 .take(N)
784 .collect();
785
786 let c = backend.from_slice_cv(&real)?;
787 let c = backend.to_host_cv(c)?;
788
789 assert_eq!(N, c.len());
790 real.iter().zip(c.iter()).for_each(|(r, c)| {
791 assert_eq!(r, &c.re);
792 assert_eq!(0.0, c.im);
793 });
794 Ok(())
795 }
796
797 #[rstest::rstest]
798 #[test]
799 #[cfg_attr(miri, ignore)]
800 fn test_from_slice2_cv(backend: ArrayFireBackend<Sphere>) -> Result<(), HoloError> {
801 let mut rng = rand::rng();
802
803 let real: Vec<f32> = (&mut rng)
804 .sample_iter(rand::distr::StandardUniform)
805 .take(N)
806 .collect();
807 let imag: Vec<f32> = (&mut rng)
808 .sample_iter(rand::distr::StandardUniform)
809 .take(N)
810 .collect();
811
812 let c = backend.from_slice2_cv(&real, &imag)?;
813 let c = backend.to_host_cv(c)?;
814
815 assert_eq!(N, c.len());
816 real.iter()
817 .zip(imag.iter())
818 .zip(c.iter())
819 .for_each(|((r, i), c)| {
820 assert_eq!(r, &c.re);
821 assert_eq!(i, &c.im);
822 });
823 Ok(())
824 }
825
826 #[rstest::rstest]
827 #[test]
828 #[cfg_attr(miri, ignore)]
829 fn test_from_slice2_cm(backend: ArrayFireBackend<Sphere>) -> Result<(), HoloError> {
830 let mut rng = rand::rng();
831
832 let real: Vec<f32> = (&mut rng)
833 .sample_iter(rand::distr::StandardUniform)
834 .take(N * 2 * N)
835 .collect();
836 let imag: Vec<f32> = (&mut rng)
837 .sample_iter(rand::distr::StandardUniform)
838 .take(N * 2 * N)
839 .collect();
840
841 let c = backend.from_slice2_cm(N, 2 * N, &real, &imag)?;
842 let c = backend.to_host_cm(c)?;
843
844 assert_eq!(N, c.nrows());
845 assert_eq!(2 * N, c.ncols());
846 (0..2 * N).for_each(|col| {
847 (0..N).for_each(|row| {
848 assert_eq!(real[col * N + row], c[(row, col)].re);
849 assert_eq!(imag[col * N + row], c[(row, col)].im);
850 })
851 });
852 Ok(())
853 }
854
855 #[rstest::rstest]
856 #[test]
857 #[cfg_attr(miri, ignore)]
858 fn test_copy_from_slice_v(backend: ArrayFireBackend<Sphere>) -> Result<(), HoloError> {
859 {
860 let mut a = backend.alloc_zeros_v(N)?;
861 let mut rng = rand::rng();
862 let v = (&mut rng)
863 .sample_iter(rand::distr::StandardUniform)
864 .take(N / 2)
865 .collect::<Vec<f32>>();
866
867 backend.copy_from_slice_v(&v, &mut a)?;
868
869 let a = backend.to_host_v(a)?;
870 (0..N / 2).for_each(|i| {
871 assert_eq!(v[i], a[i]);
872 });
873 (N / 2..N).for_each(|i| {
874 assert_eq!(0., a[i]);
875 });
876 }
877
878 {
879 let mut a = backend.alloc_zeros_v(N)?;
880 let v = [];
881
882 backend.copy_from_slice_v(&v, &mut a)?;
883
884 let a = backend.to_host_v(a)?;
885 a.iter().for_each(|&a| {
886 assert_eq!(0., a);
887 });
888 }
889
890 Ok(())
891 }
892
893 #[rstest::rstest]
894 #[test]
895 #[cfg_attr(miri, ignore)]
896 fn test_copy_to_v(backend: ArrayFireBackend<Sphere>) -> Result<(), HoloError> {
897 let a = make_random_v(&backend, N)?;
898 let mut b = backend.alloc_v(N)?;
899
900 backend.copy_to_v(&a, &mut b)?;
901
902 let a = backend.to_host_v(a)?;
903 let b = backend.to_host_v(b)?;
904 a.iter().zip(b.iter()).for_each(|(a, b)| {
905 assert_eq!(a, b);
906 });
907 Ok(())
908 }
909
910 #[rstest::rstest]
911 #[test]
912 #[cfg_attr(miri, ignore)]
913 fn test_copy_to_m(backend: ArrayFireBackend<Sphere>) -> Result<(), HoloError> {
914 let a = make_random_m(&backend, N, N)?;
915 let mut b = backend.alloc_m(N, N)?;
916
917 backend.copy_to_m(&a, &mut b)?;
918
919 let a = backend.to_host_m(a)?;
920 let b = backend.to_host_m(b)?;
921 a.iter().zip(b.iter()).for_each(|(a, b)| {
922 assert_eq!(a, b);
923 });
924 Ok(())
925 }
926
927 #[rstest::rstest]
928 #[test]
929 #[cfg_attr(miri, ignore)]
930 fn test_clone_v(backend: ArrayFireBackend<Sphere>) -> Result<(), HoloError> {
931 let c = make_random_v(&backend, N)?;
932 let c2 = backend.clone_v(&c)?;
933
934 let c = backend.to_host_v(c)?;
935 let c2 = backend.to_host_v(c2)?;
936
937 c.iter().zip(c2.iter()).for_each(|(c, c2)| {
938 assert_eq!(c, c2);
939 });
940 Ok(())
941 }
942
943 #[rstest::rstest]
944 #[test]
945 #[cfg_attr(miri, ignore)]
946 fn test_clone_m(backend: ArrayFireBackend<Sphere>) -> Result<(), HoloError> {
947 let c = make_random_m(&backend, N, N)?;
948 let c2 = backend.clone_m(&c)?;
949
950 let c = backend.to_host_m(c)?;
951 let c2 = backend.to_host_m(c2)?;
952
953 c.iter().zip(c2.iter()).for_each(|(c, c2)| {
954 assert_eq!(c, c2);
955 });
956 Ok(())
957 }
958
959 #[rstest::rstest]
960 #[test]
961 #[cfg_attr(miri, ignore)]
962 fn test_clone_cv(backend: ArrayFireBackend<Sphere>) -> Result<(), HoloError> {
963 let c = make_random_cv(&backend, N)?;
964 let c2 = backend.clone_cv(&c)?;
965
966 let c = backend.to_host_cv(c)?;
967 let c2 = backend.to_host_cv(c2)?;
968
969 c.iter().zip(c2.iter()).for_each(|(c, c2)| {
970 assert_eq!(c.re, c2.re);
971 assert_eq!(c.im, c2.im);
972 });
973 Ok(())
974 }
975
976 #[rstest::rstest]
977 #[test]
978 #[cfg_attr(miri, ignore)]
979 fn test_clone_cm(backend: ArrayFireBackend<Sphere>) -> Result<(), HoloError> {
980 let c = make_random_cm(&backend, N, N)?;
981 let c2 = backend.clone_cm(&c)?;
982
983 let c = backend.to_host_cm(c)?;
984 let c2 = backend.to_host_cm(c2)?;
985
986 c.iter().zip(c2.iter()).for_each(|(c, c2)| {
987 assert_eq!(c.re, c2.re);
988 assert_eq!(c.im, c2.im);
989 });
990 Ok(())
991 }
992
993 #[rstest::rstest]
994 #[test]
995 #[cfg_attr(miri, ignore)]
996 fn test_make_complex2_v(backend: ArrayFireBackend<Sphere>) -> Result<(), HoloError> {
997 let real = make_random_v(&backend, N)?;
998 let imag = make_random_v(&backend, N)?;
999
1000 let mut c = backend.alloc_cv(N)?;
1001 backend.make_complex2_v(&real, &imag, &mut c)?;
1002
1003 let real = backend.to_host_v(real)?;
1004 let imag = backend.to_host_v(imag)?;
1005 let c = backend.to_host_cv(c)?;
1006 real.iter()
1007 .zip(imag.iter())
1008 .zip(c.iter())
1009 .for_each(|((r, i), c)| {
1010 assert_eq!(r, &c.re);
1011 assert_eq!(i, &c.im);
1012 });
1013 Ok(())
1014 }
1015
1016 #[rstest::rstest]
1017 #[test]
1018 #[cfg_attr(miri, ignore)]
1019 fn test_create_diagonal(backend: ArrayFireBackend<Sphere>) -> Result<(), HoloError> {
1020 let diagonal = make_random_v(&backend, N)?;
1021
1022 let mut c = backend.alloc_m(N, N)?;
1023
1024 backend.create_diagonal(&diagonal, &mut c)?;
1025
1026 let diagonal = backend.to_host_v(diagonal)?;
1027 let c = backend.to_host_m(c)?;
1028 (0..N).for_each(|i| {
1029 (0..N).for_each(|j| {
1030 if i == j {
1031 assert_eq!(diagonal[i], c[(i, j)]);
1032 } else {
1033 assert_eq!(0.0, c[(i, j)]);
1034 }
1035 })
1036 });
1037 Ok(())
1038 }
1039
1040 #[rstest::rstest]
1041 #[test]
1042 #[cfg_attr(miri, ignore)]
1043 fn test_create_diagonal_c(backend: ArrayFireBackend<Sphere>) -> Result<(), HoloError> {
1044 let diagonal = make_random_cv(&backend, N)?;
1045
1046 let mut c = backend.alloc_cm(N, N)?;
1047
1048 backend.create_diagonal_c(&diagonal, &mut c)?;
1049
1050 let diagonal = backend.to_host_cv(diagonal)?;
1051 let c = backend.to_host_cm(c)?;
1052 (0..N).for_each(|i| {
1053 (0..N).for_each(|j| {
1054 if i == j {
1055 assert_eq!(diagonal[i].re, c[(i, j)].re);
1056 assert_eq!(diagonal[i].im, c[(i, j)].im);
1057 } else {
1058 assert_eq!(0.0, c[(i, j)].re);
1059 assert_eq!(0.0, c[(i, j)].im);
1060 }
1061 })
1062 });
1063 Ok(())
1064 }
1065
1066 #[rstest::rstest]
1067 #[test]
1068 #[cfg_attr(miri, ignore)]
1069 fn test_get_diagonal(backend: ArrayFireBackend<Sphere>) -> Result<(), HoloError> {
1070 let m = make_random_m(&backend, N, N)?;
1071 let mut diagonal = backend.alloc_v(N)?;
1072
1073 backend.get_diagonal(&m, &mut diagonal)?;
1074
1075 let m = backend.to_host_m(m)?;
1076 let diagonal = backend.to_host_v(diagonal)?;
1077 (0..N).for_each(|i| {
1078 assert_eq!(m[(i, i)], diagonal[i]);
1079 });
1080 Ok(())
1081 }
1082
1083 #[rstest::rstest]
1084 #[test]
1085 #[cfg_attr(miri, ignore)]
1086 fn test_norm_squared_cv(backend: ArrayFireBackend<Sphere>) -> Result<(), HoloError> {
1087 let v = make_random_cv(&backend, N)?;
1088
1089 let mut abs = backend.alloc_v(N)?;
1090 backend.norm_squared_cv(&v, &mut abs)?;
1091
1092 let v = backend.to_host_cv(v)?;
1093 let abs = backend.to_host_v(abs)?;
1094 v.iter().zip(abs.iter()).for_each(|(v, abs)| {
1095 assert_approx_eq::assert_approx_eq!(v.norm_squared(), abs, EPS);
1096 });
1097 Ok(())
1098 }
1099
1100 #[rstest::rstest]
1101 #[test]
1102 #[cfg_attr(miri, ignore)]
1103 fn test_real_cm(backend: ArrayFireBackend<Sphere>) -> Result<(), HoloError> {
1104 let v = make_random_cm(&backend, N, N)?;
1105 let mut r = backend.alloc_m(N, N)?;
1106
1107 backend.real_cm(&v, &mut r)?;
1108
1109 let v = backend.to_host_cm(v)?;
1110 let r = backend.to_host_m(r)?;
1111 (0..N).for_each(|i| {
1112 (0..N).for_each(|j| {
1113 assert_approx_eq::assert_approx_eq!(v[(i, j)].re, r[(i, j)], EPS);
1114 })
1115 });
1116 Ok(())
1117 }
1118
1119 #[rstest::rstest]
1120 #[test]
1121 #[cfg_attr(miri, ignore)]
1122 fn test_imag_cm(backend: ArrayFireBackend<Sphere>) -> Result<(), HoloError> {
1123 let v = make_random_cm(&backend, N, N)?;
1124 let mut r = backend.alloc_m(N, N)?;
1125
1126 backend.imag_cm(&v, &mut r)?;
1127
1128 let v = backend.to_host_cm(v)?;
1129 let r = backend.to_host_m(r)?;
1130 (0..N).for_each(|i| {
1131 (0..N).for_each(|j| {
1132 assert_approx_eq::assert_approx_eq!(v[(i, j)].im, r[(i, j)], EPS);
1133 })
1134 });
1135 Ok(())
1136 }
1137
1138 #[rstest::rstest]
1139 #[test]
1140 #[cfg_attr(miri, ignore)]
1141 fn test_scale_assign_cv(backend: ArrayFireBackend<Sphere>) -> Result<(), HoloError> {
1142 let mut v = make_random_cv(&backend, N)?;
1143 let vc = backend.clone_cv(&v)?;
1144 let mut rng = rand::rng();
1145 let scale = Complex::new(rng.random(), rng.random());
1146
1147 backend.scale_assign_cv(scale, &mut v)?;
1148
1149 let v = backend.to_host_cv(v)?;
1150 let vc = backend.to_host_cv(vc)?;
1151 v.iter().zip(vc.iter()).for_each(|(&v, &vc)| {
1152 assert_approx_eq::assert_approx_eq!(scale * vc, v, EPS);
1153 });
1154 Ok(())
1155 }
1156
1157 #[rstest::rstest]
1158 #[test]
1159 #[cfg_attr(miri, ignore)]
1160 fn test_conj_assign_v(backend: ArrayFireBackend<Sphere>) -> Result<(), HoloError> {
1161 let mut v = make_random_cv(&backend, N)?;
1162 let vc = backend.clone_cv(&v)?;
1163
1164 backend.conj_assign_v(&mut v)?;
1165
1166 let v = backend.to_host_cv(v)?;
1167 let vc = backend.to_host_cv(vc)?;
1168 v.iter().zip(vc.iter()).for_each(|(&v, &vc)| {
1169 assert_eq!(vc.re, v.re);
1170 assert_eq!(vc.im, -v.im);
1171 });
1172 Ok(())
1173 }
1174
1175 #[rstest::rstest]
1176 #[test]
1177 #[cfg_attr(miri, ignore)]
1178 fn test_exp_assign_cv(backend: ArrayFireBackend<Sphere>) -> Result<(), HoloError> {
1179 let mut v = make_random_cv(&backend, N)?;
1180 let vc = backend.clone_cv(&v)?;
1181
1182 backend.exp_assign_cv(&mut v)?;
1183
1184 let v = backend.to_host_cv(v)?;
1185 let vc = backend.to_host_cv(vc)?;
1186 v.iter().zip(vc.iter()).for_each(|(v, vc)| {
1187 assert_approx_eq::assert_approx_eq!(vc.exp(), v, EPS);
1188 });
1189 Ok(())
1190 }
1191
1192 #[rstest::rstest]
1193 #[test]
1194 #[cfg_attr(miri, ignore)]
1195 fn test_concat_col_cm(backend: ArrayFireBackend<Sphere>) -> Result<(), HoloError> {
1196 let a = make_random_cm(&backend, N, N)?;
1197 let b = make_random_cm(&backend, N, 2 * N)?;
1198 let mut c = backend.alloc_cm(N, N + 2 * N)?;
1199
1200 backend.concat_col_cm(&a, &b, &mut c)?;
1201
1202 let a = backend.to_host_cm(a)?;
1203 let b = backend.to_host_cm(b)?;
1204 let c = backend.to_host_cm(c)?;
1205 (0..N).for_each(|col| (0..N).for_each(|row| assert_eq!(a[(row, col)], c[(row, col)])));
1206 (0..2 * N)
1207 .for_each(|col| (0..N).for_each(|row| assert_eq!(b[(row, col)], c[(row, N + col)])));
1208 Ok(())
1209 }
1210
1211 #[rstest::rstest]
1212 #[test]
1213 #[cfg_attr(miri, ignore)]
1214 fn test_max_v(backend: ArrayFireBackend<Sphere>) -> Result<(), HoloError> {
1215 let v = make_random_v(&backend, N)?;
1216
1217 let max = backend.max_v(&v)?;
1218
1219 let v = backend.to_host_v(v)?;
1220 assert_eq!(
1221 *v.iter().max_by(|a, b| a.partial_cmp(b).unwrap()).unwrap(),
1222 max
1223 );
1224 Ok(())
1225 }
1226
1227 #[rstest::rstest]
1228 #[test]
1229 #[cfg_attr(miri, ignore)]
1230 fn test_hadamard_product_cm(backend: ArrayFireBackend<Sphere>) -> Result<(), HoloError> {
1231 let a = make_random_cm(&backend, N, N)?;
1232 let b = make_random_cm(&backend, N, N)?;
1233 let mut c = backend.alloc_cm(N, N)?;
1234
1235 backend.hadamard_product_cm(&a, &b, &mut c)?;
1236
1237 let a = backend.to_host_cm(a)?;
1238 let b = backend.to_host_cm(b)?;
1239 let c = backend.to_host_cm(c)?;
1240 c.iter()
1241 .zip(a.iter())
1242 .zip(b.iter())
1243 .for_each(|((c, a), b)| {
1244 assert_approx_eq::assert_approx_eq!(a.re * b.re - a.im * b.im, c.re, EPS);
1245 assert_approx_eq::assert_approx_eq!(a.re * b.im + a.im * b.re, c.im, EPS);
1246 });
1247 Ok(())
1248 }
1249
1250 #[rstest::rstest]
1251 #[test]
1252 #[cfg_attr(miri, ignore)]
1253 fn test_dot(backend: ArrayFireBackend<Sphere>) -> Result<(), HoloError> {
1254 let a = make_random_v(&backend, N)?;
1255 let b = make_random_v(&backend, N)?;
1256
1257 let dot = backend.dot(&a, &b)?;
1258
1259 let a = backend.to_host_v(a)?;
1260 let b = backend.to_host_v(b)?;
1261 let expect = a.iter().zip(b.iter()).map(|(a, b)| a * b).sum::<f32>();
1262 assert_approx_eq::assert_approx_eq!(dot, expect, EPS);
1263 Ok(())
1264 }
1265
1266 #[rstest::rstest]
1267 #[test]
1268 #[cfg_attr(miri, ignore)]
1269 fn test_dot_c(backend: ArrayFireBackend<Sphere>) -> Result<(), HoloError> {
1270 let a = make_random_cv(&backend, N)?;
1271 let b = make_random_cv(&backend, N)?;
1272
1273 let dot = backend.dot_c(&a, &b)?;
1274
1275 let a = backend.to_host_cv(a)?;
1276 let b = backend.to_host_cv(b)?;
1277 let expect = a
1278 .iter()
1279 .zip(b.iter())
1280 .map(|(a, b)| a.conj() * b)
1281 .sum::<Complex>();
1282 assert_approx_eq::assert_approx_eq!(dot.re, expect.re, EPS);
1283 assert_approx_eq::assert_approx_eq!(dot.im, expect.im, EPS);
1284 Ok(())
1285 }
1286
1287 #[rstest::rstest]
1288 #[test]
1289 #[cfg_attr(miri, ignore)]
1290 fn test_add_v(backend: ArrayFireBackend<Sphere>) -> Result<(), HoloError> {
1291 let a = make_random_v(&backend, N)?;
1292 let mut b = make_random_v(&backend, N)?;
1293 let bc = backend.clone_v(&b)?;
1294
1295 let mut rng = rand::rng();
1296 let alpha = rng.random();
1297
1298 backend.add_v(alpha, &a, &mut b)?;
1299
1300 let a = backend.to_host_v(a)?;
1301 let b = backend.to_host_v(b)?;
1302 let bc = backend.to_host_v(bc)?;
1303 b.iter()
1304 .zip(a.iter())
1305 .zip(bc.iter())
1306 .for_each(|((b, a), bc)| {
1307 assert_approx_eq::assert_approx_eq!(alpha * a + bc, b, EPS);
1308 });
1309 Ok(())
1310 }
1311
1312 #[rstest::rstest]
1313 #[test]
1314 #[cfg_attr(miri, ignore)]
1315 fn test_add_m(backend: ArrayFireBackend<Sphere>) -> Result<(), HoloError> {
1316 let a = make_random_m(&backend, N, N)?;
1317 let mut b = make_random_m(&backend, N, N)?;
1318 let bc = backend.clone_m(&b)?;
1319
1320 let mut rng = rand::rng();
1321 let alpha = rng.random();
1322
1323 backend.add_m(alpha, &a, &mut b)?;
1324
1325 let a = backend.to_host_m(a)?;
1326 let b = backend.to_host_m(b)?;
1327 let bc = backend.to_host_m(bc)?;
1328 b.iter()
1329 .zip(a.iter())
1330 .zip(bc.iter())
1331 .for_each(|((b, a), bc)| {
1332 assert_approx_eq::assert_approx_eq!(alpha * a + bc, b, EPS);
1333 });
1334 Ok(())
1335 }
1336
1337 #[rstest::rstest]
1338 #[test]
1339 #[cfg_attr(miri, ignore)]
1340 fn test_gevv_c(backend: ArrayFireBackend<Sphere>) -> Result<(), HoloError> {
1341 let mut rng = rand::rng();
1342
1343 {
1344 let a = make_random_cv(&backend, N)?;
1345 let b = make_random_cv(&backend, N)?;
1346 let mut c = make_random_cm(&backend, N, N)?;
1347 let cc = backend.clone_cm(&c)?;
1348
1349 let alpha = Complex::new(rng.random(), rng.random());
1350 let beta = Complex::new(rng.random(), rng.random());
1351 backend.gevv_c(Trans::NoTrans, Trans::Trans, alpha, &a, &b, beta, &mut c)?;
1352
1353 let a = backend.to_host_cv(a)?;
1354 let b = backend.to_host_cv(b)?;
1355 let c = backend.to_host_cm(c)?;
1356 let cc = backend.to_host_cm(cc)?;
1357 let expected = a * b.transpose() * alpha + cc * beta;
1358 c.iter().zip(expected.iter()).for_each(|(c, expected)| {
1359 assert_approx_eq::assert_approx_eq!(c.re, expected.re, EPS);
1360 assert_approx_eq::assert_approx_eq!(c.im, expected.im, EPS);
1361 });
1362 }
1363
1364 {
1365 let a = make_random_cv(&backend, N)?;
1366 let b = make_random_cv(&backend, N)?;
1367 let mut c = make_random_cm(&backend, N, N)?;
1368 let cc = backend.clone_cm(&c)?;
1369
1370 let alpha = Complex::new(rng.random(), rng.random());
1371 let beta = Complex::new(rng.random(), rng.random());
1372 backend.gevv_c(
1373 Trans::NoTrans,
1374 Trans::ConjTrans,
1375 alpha,
1376 &a,
1377 &b,
1378 beta,
1379 &mut c,
1380 )?;
1381
1382 let a = backend.to_host_cv(a)?;
1383 let b = backend.to_host_cv(b)?;
1384 let c = backend.to_host_cm(c)?;
1385 let cc = backend.to_host_cm(cc)?;
1386 let expected = a * b.adjoint() * alpha + cc * beta;
1387 c.iter().zip(expected.iter()).for_each(|(c, expected)| {
1388 assert_approx_eq::assert_approx_eq!(c.re, expected.re, EPS);
1389 assert_approx_eq::assert_approx_eq!(c.im, expected.im, EPS);
1390 });
1391 }
1392
1393 {
1394 let a = make_random_cv(&backend, N)?;
1395 let b = make_random_cv(&backend, N)?;
1396 let mut c = make_random_cm(&backend, 1, 1)?;
1397 let cc = backend.clone_cm(&c)?;
1398
1399 let alpha = Complex::new(rng.random(), rng.random());
1400 let beta = Complex::new(rng.random(), rng.random());
1401 backend.gevv_c(Trans::Trans, Trans::NoTrans, alpha, &a, &b, beta, &mut c)?;
1402
1403 let a = backend.to_host_cv(a)?;
1404 let b = backend.to_host_cv(b)?;
1405 let c = backend.to_host_cm(c)?;
1406 let cc = backend.to_host_cm(cc)?;
1407 let expected = a.transpose() * b * alpha + cc * beta;
1408 c.iter().zip(expected.iter()).for_each(|(c, expected)| {
1409 assert_approx_eq::assert_approx_eq!(c.re, expected.re, EPS);
1410 assert_approx_eq::assert_approx_eq!(c.im, expected.im, EPS);
1411 });
1412 }
1413
1414 {
1415 let a = make_random_cv(&backend, N)?;
1416 let b = make_random_cv(&backend, N)?;
1417 let mut c = make_random_cm(&backend, 1, 1)?;
1418 let cc = backend.clone_cm(&c)?;
1419
1420 let alpha = Complex::new(rng.random(), rng.random());
1421 let beta = Complex::new(rng.random(), rng.random());
1422 backend.gevv_c(
1423 Trans::ConjTrans,
1424 Trans::NoTrans,
1425 alpha,
1426 &a,
1427 &b,
1428 beta,
1429 &mut c,
1430 )?;
1431
1432 let a = backend.to_host_cv(a)?;
1433 let b = backend.to_host_cv(b)?;
1434 let c = backend.to_host_cm(c)?;
1435 let cc = backend.to_host_cm(cc)?;
1436 let expected = a.adjoint() * b * alpha + cc * beta;
1437 c.iter().zip(expected.iter()).for_each(|(c, expected)| {
1438 assert_approx_eq::assert_approx_eq!(c.re, expected.re, EPS);
1439 assert_approx_eq::assert_approx_eq!(c.im, expected.im, EPS);
1440 });
1441 }
1442
1443 Ok(())
1444 }
1445
1446 #[rstest::rstest]
1447 #[test]
1448 #[cfg_attr(miri, ignore)]
1449 fn test_gemv_c(backend: ArrayFireBackend<Sphere>) -> Result<(), HoloError> {
1450 let m = N;
1451 let n = 2 * N;
1452
1453 let mut rng = rand::rng();
1454
1455 {
1456 let a = make_random_cm(&backend, m, n)?;
1457 let b = make_random_cv(&backend, n)?;
1458 let mut c = make_random_cv(&backend, m)?;
1459 let cc = backend.clone_cv(&c)?;
1460
1461 let alpha = Complex::new(rng.random(), rng.random());
1462 let beta = Complex::new(rng.random(), rng.random());
1463 backend.gemv_c(Trans::NoTrans, alpha, &a, &b, beta, &mut c)?;
1464
1465 let a = backend.to_host_cm(a)?;
1466 let b = backend.to_host_cv(b)?;
1467 let c = backend.to_host_cv(c)?;
1468 let cc = backend.to_host_cv(cc)?;
1469 let expected = a * b * alpha + cc * beta;
1470 c.iter().zip(expected.iter()).for_each(|(c, expected)| {
1471 assert_approx_eq::assert_approx_eq!(c.re, expected.re, EPS);
1472 assert_approx_eq::assert_approx_eq!(c.im, expected.im, EPS);
1473 });
1474 }
1475
1476 {
1477 let a = make_random_cm(&backend, n, m)?;
1478 let b = make_random_cv(&backend, n)?;
1479 let mut c = make_random_cv(&backend, m)?;
1480 let cc = backend.clone_cv(&c)?;
1481
1482 let alpha = Complex::new(rng.random(), rng.random());
1483 let beta = Complex::new(rng.random(), rng.random());
1484 backend.gemv_c(Trans::Trans, alpha, &a, &b, beta, &mut c)?;
1485
1486 let a = backend.to_host_cm(a)?;
1487 let b = backend.to_host_cv(b)?;
1488 let c = backend.to_host_cv(c)?;
1489 let cc = backend.to_host_cv(cc)?;
1490 let expected = a.transpose() * b * alpha + cc * beta;
1491 c.iter().zip(expected.iter()).for_each(|(c, expected)| {
1492 assert_approx_eq::assert_approx_eq!(c.re, expected.re, EPS);
1493 assert_approx_eq::assert_approx_eq!(c.im, expected.im, EPS);
1494 });
1495 }
1496
1497 {
1498 let a = make_random_cm(&backend, n, m)?;
1499 let b = make_random_cv(&backend, n)?;
1500 let mut c = make_random_cv(&backend, m)?;
1501 let cc = backend.clone_cv(&c)?;
1502
1503 let alpha = Complex::new(rng.random(), rng.random());
1504 let beta = Complex::new(rng.random(), rng.random());
1505 backend.gemv_c(Trans::ConjTrans, alpha, &a, &b, beta, &mut c)?;
1506
1507 let a = backend.to_host_cm(a)?;
1508 let b = backend.to_host_cv(b)?;
1509 let c = backend.to_host_cv(c)?;
1510 let cc = backend.to_host_cv(cc)?;
1511 let expected = a.adjoint() * b * alpha + cc * beta;
1512 c.iter().zip(expected.iter()).for_each(|(c, expected)| {
1513 assert_approx_eq::assert_approx_eq!(c.re, expected.re, EPS);
1514 assert_approx_eq::assert_approx_eq!(c.im, expected.im, EPS);
1515 });
1516 }
1517 Ok(())
1518 }
1519
1520 #[rstest::rstest]
1521 #[test]
1522 #[cfg_attr(miri, ignore)]
1523 fn test_gemm_c(backend: ArrayFireBackend<Sphere>) -> Result<(), HoloError> {
1524 let m = N;
1525 let n = 2 * N;
1526 let k = 3 * N;
1527
1528 let mut rng = rand::rng();
1529
1530 {
1531 let a = make_random_cm(&backend, m, k)?;
1532 let b = make_random_cm(&backend, k, n)?;
1533 let mut c = make_random_cm(&backend, m, n)?;
1534 let cc = backend.clone_cm(&c)?;
1535
1536 let alpha = Complex::new(rng.random(), rng.random());
1537 let beta = Complex::new(rng.random(), rng.random());
1538 backend.gemm_c(Trans::NoTrans, Trans::NoTrans, alpha, &a, &b, beta, &mut c)?;
1539
1540 let a = backend.to_host_cm(a)?;
1541 let b = backend.to_host_cm(b)?;
1542 let c = backend.to_host_cm(c)?;
1543 let cc = backend.to_host_cm(cc)?;
1544 let expected = a * b * alpha + cc * beta;
1545 c.iter().zip(expected.iter()).for_each(|(c, expected)| {
1546 assert_approx_eq::assert_approx_eq!(c.re, expected.re, EPS);
1547 assert_approx_eq::assert_approx_eq!(c.im, expected.im, EPS);
1548 });
1549 }
1550
1551 {
1552 let a = make_random_cm(&backend, m, k)?;
1553 let b = make_random_cm(&backend, n, k)?;
1554 let mut c = make_random_cm(&backend, m, n)?;
1555 let cc = backend.clone_cm(&c)?;
1556
1557 let alpha = Complex::new(rng.random(), rng.random());
1558 let beta = Complex::new(rng.random(), rng.random());
1559 backend.gemm_c(Trans::NoTrans, Trans::Trans, alpha, &a, &b, beta, &mut c)?;
1560
1561 let a = backend.to_host_cm(a)?;
1562 let b = backend.to_host_cm(b)?;
1563 let c = backend.to_host_cm(c)?;
1564 let cc = backend.to_host_cm(cc)?;
1565 let expected = a * b.transpose() * alpha + cc * beta;
1566 c.iter().zip(expected.iter()).for_each(|(c, expected)| {
1567 assert_approx_eq::assert_approx_eq!(c.re, expected.re, EPS);
1568 assert_approx_eq::assert_approx_eq!(c.im, expected.im, EPS);
1569 });
1570 }
1571
1572 {
1573 let a = make_random_cm(&backend, m, k)?;
1574 let b = make_random_cm(&backend, n, k)?;
1575 let mut c = make_random_cm(&backend, m, n)?;
1576 let cc = backend.clone_cm(&c)?;
1577
1578 let alpha = Complex::new(rng.random(), rng.random());
1579 let beta = Complex::new(rng.random(), rng.random());
1580 backend.gemm_c(
1581 Trans::NoTrans,
1582 Trans::ConjTrans,
1583 alpha,
1584 &a,
1585 &b,
1586 beta,
1587 &mut c,
1588 )?;
1589
1590 let a = backend.to_host_cm(a)?;
1591 let b = backend.to_host_cm(b)?;
1592 let c = backend.to_host_cm(c)?;
1593 let cc = backend.to_host_cm(cc)?;
1594 let expected = a * b.adjoint() * alpha + cc * beta;
1595 c.iter().zip(expected.iter()).for_each(|(c, expected)| {
1596 assert_approx_eq::assert_approx_eq!(c.re, expected.re, EPS);
1597 assert_approx_eq::assert_approx_eq!(c.im, expected.im, EPS);
1598 });
1599 }
1600
1601 {
1602 let a = make_random_cm(&backend, k, m)?;
1603 let b = make_random_cm(&backend, k, n)?;
1604 let mut c = make_random_cm(&backend, m, n)?;
1605 let cc = backend.clone_cm(&c)?;
1606
1607 let alpha = Complex::new(rng.random(), rng.random());
1608 let beta = Complex::new(rng.random(), rng.random());
1609 backend.gemm_c(Trans::Trans, Trans::NoTrans, alpha, &a, &b, beta, &mut c)?;
1610
1611 let a = backend.to_host_cm(a)?;
1612 let b = backend.to_host_cm(b)?;
1613 let c = backend.to_host_cm(c)?;
1614 let cc = backend.to_host_cm(cc)?;
1615 let expected = a.transpose() * b * alpha + cc * beta;
1616 c.iter().zip(expected.iter()).for_each(|(c, expected)| {
1617 assert_approx_eq::assert_approx_eq!(c.re, expected.re, EPS);
1618 assert_approx_eq::assert_approx_eq!(c.im, expected.im, EPS);
1619 });
1620 }
1621
1622 {
1623 let a = make_random_cm(&backend, k, m)?;
1624 let b = make_random_cm(&backend, n, k)?;
1625 let mut c = make_random_cm(&backend, m, n)?;
1626 let cc = backend.clone_cm(&c)?;
1627
1628 let alpha = Complex::new(rng.random(), rng.random());
1629 let beta = Complex::new(rng.random(), rng.random());
1630 backend.gemm_c(Trans::Trans, Trans::Trans, alpha, &a, &b, beta, &mut c)?;
1631
1632 let a = backend.to_host_cm(a)?;
1633 let b = backend.to_host_cm(b)?;
1634 let c = backend.to_host_cm(c)?;
1635 let cc = backend.to_host_cm(cc)?;
1636 let expected = a.transpose() * b.transpose() * alpha + cc * beta;
1637 c.iter().zip(expected.iter()).for_each(|(c, expected)| {
1638 assert_approx_eq::assert_approx_eq!(c.re, expected.re, EPS);
1639 assert_approx_eq::assert_approx_eq!(c.im, expected.im, EPS);
1640 });
1641 }
1642
1643 {
1644 let a = make_random_cm(&backend, k, m)?;
1645 let b = make_random_cm(&backend, n, k)?;
1646 let mut c = make_random_cm(&backend, m, n)?;
1647 let cc = backend.clone_cm(&c)?;
1648
1649 let alpha = Complex::new(rng.random(), rng.random());
1650 let beta = Complex::new(rng.random(), rng.random());
1651 backend.gemm_c(Trans::Trans, Trans::ConjTrans, alpha, &a, &b, beta, &mut c)?;
1652
1653 let a = backend.to_host_cm(a)?;
1654 let b = backend.to_host_cm(b)?;
1655 let c = backend.to_host_cm(c)?;
1656 let cc = backend.to_host_cm(cc)?;
1657 let expected = a.transpose() * b.adjoint() * alpha + cc * beta;
1658 c.iter().zip(expected.iter()).for_each(|(c, expected)| {
1659 assert_approx_eq::assert_approx_eq!(c.re, expected.re, EPS);
1660 assert_approx_eq::assert_approx_eq!(c.im, expected.im, EPS);
1661 });
1662 }
1663
1664 {
1665 let a = make_random_cm(&backend, k, m)?;
1666 let b = make_random_cm(&backend, k, n)?;
1667 let mut c = make_random_cm(&backend, m, n)?;
1668 let cc = backend.clone_cm(&c)?;
1669
1670 let alpha = Complex::new(rng.random(), rng.random());
1671 let beta = Complex::new(rng.random(), rng.random());
1672 backend.gemm_c(
1673 Trans::ConjTrans,
1674 Trans::NoTrans,
1675 alpha,
1676 &a,
1677 &b,
1678 beta,
1679 &mut c,
1680 )?;
1681
1682 let a = backend.to_host_cm(a)?;
1683 let b = backend.to_host_cm(b)?;
1684 let c = backend.to_host_cm(c)?;
1685 let cc = backend.to_host_cm(cc)?;
1686 let expected = a.adjoint() * b * alpha + cc * beta;
1687 c.iter().zip(expected.iter()).for_each(|(c, expected)| {
1688 assert_approx_eq::assert_approx_eq!(c.re, expected.re, EPS);
1689 assert_approx_eq::assert_approx_eq!(c.im, expected.im, EPS);
1690 });
1691 }
1692
1693 {
1694 let a = make_random_cm(&backend, k, m)?;
1695 let b = make_random_cm(&backend, n, k)?;
1696 let mut c = make_random_cm(&backend, m, n)?;
1697 let cc = backend.clone_cm(&c)?;
1698
1699 let alpha = Complex::new(rng.random(), rng.random());
1700 let beta = Complex::new(rng.random(), rng.random());
1701 backend.gemm_c(Trans::ConjTrans, Trans::Trans, alpha, &a, &b, beta, &mut c)?;
1702
1703 let a = backend.to_host_cm(a)?;
1704 let b = backend.to_host_cm(b)?;
1705 let c = backend.to_host_cm(c)?;
1706 let cc = backend.to_host_cm(cc)?;
1707 let expected = a.adjoint() * b.transpose() * alpha + cc * beta;
1708 c.iter().zip(expected.iter()).for_each(|(c, expected)| {
1709 assert_approx_eq::assert_approx_eq!(c.re, expected.re, EPS);
1710 assert_approx_eq::assert_approx_eq!(c.im, expected.im, EPS);
1711 });
1712 }
1713
1714 {
1715 let a = make_random_cm(&backend, k, m)?;
1716 let b = make_random_cm(&backend, n, k)?;
1717 let mut c = make_random_cm(&backend, m, n)?;
1718 let cc = backend.clone_cm(&c)?;
1719
1720 let alpha = Complex::new(rng.random(), rng.random());
1721 let beta = Complex::new(rng.random(), rng.random());
1722 backend.gemm_c(
1723 Trans::ConjTrans,
1724 Trans::ConjTrans,
1725 alpha,
1726 &a,
1727 &b,
1728 beta,
1729 &mut c,
1730 )?;
1731
1732 let a = backend.to_host_cm(a)?;
1733 let b = backend.to_host_cm(b)?;
1734 let c = backend.to_host_cm(c)?;
1735 let cc = backend.to_host_cm(cc)?;
1736 let expected = a.adjoint() * b.adjoint() * alpha + cc * beta;
1737 c.iter().zip(expected.iter()).for_each(|(c, expected)| {
1738 assert_approx_eq::assert_approx_eq!(c.re, expected.re, EPS);
1739 assert_approx_eq::assert_approx_eq!(c.im, expected.im, EPS);
1740 });
1741 }
1742 Ok(())
1743 }
1744
1745 #[rstest::rstest]
1746 #[test]
1747 #[cfg_attr(miri, ignore)]
1748 fn test_solve_inplace(backend: ArrayFireBackend<Sphere>) -> Result<(), HoloError> {
1749 {
1750 let tmp = make_random_m(&backend, N, N)?;
1751 let tmp = backend.to_host_m(tmp)?;
1752
1753 let a = &tmp * tmp.adjoint();
1754
1755 let mut rng = rand::rng();
1756 let x = VectorX::from_iterator(N, (0..N).map(|_| rng.random()));
1757
1758 let b = &a * &x;
1759
1760 let aa = backend.from_slice_m(N, N, a.as_slice())?;
1761 let mut bb = backend.from_slice_v(b.as_slice())?;
1762
1763 backend.solve_inplace(&aa, &mut bb)?;
1764
1765 let b2 = &a * backend.to_host_v(bb)?;
1766 assert!(approx::relative_eq!(b, b2, epsilon = 1e-3));
1767 }
1768
1769 Ok(())
1770 }
1771
1772 #[rstest::rstest]
1773 #[test]
1774 #[cfg_attr(miri, ignore)]
1775 fn test_reduce_col(backend: ArrayFireBackend<Sphere>) -> Result<(), HoloError> {
1776 let a = make_random_m(&backend, N, N)?;
1777
1778 let mut b = backend.alloc_v(N)?;
1779
1780 backend.reduce_col(&a, &mut b)?;
1781
1782 let a = backend.to_host_m(a)?;
1783 let b = backend.to_host_v(b)?;
1784
1785 (0..N).for_each(|row| {
1786 let sum = a.row(row).iter().sum::<f32>();
1787 assert_approx_eq::assert_approx_eq!(sum, b[row], EPS);
1788 });
1789 Ok(())
1790 }
1791
1792 #[rstest::rstest]
1793 #[test]
1794 #[cfg_attr(miri, ignore)]
1795 fn test_scaled_to_cv(backend: ArrayFireBackend<Sphere>) -> Result<(), HoloError> {
1796 let a = make_random_cv(&backend, N)?;
1797 let b = make_random_cv(&backend, N)?;
1798 let mut c = backend.alloc_cv(N)?;
1799
1800 backend.scaled_to_cv(&a, &b, &mut c)?;
1801
1802 let a = backend.to_host_cv(a)?;
1803 let b = backend.to_host_cv(b)?;
1804 let c = backend.to_host_cv(c)?;
1805 c.iter()
1806 .zip(a.iter())
1807 .zip(b.iter())
1808 .for_each(|((&c, &a), &b)| {
1809 assert_approx_eq::assert_approx_eq!(c, a / a.abs() * b, EPS);
1810 });
1811
1812 Ok(())
1813 }
1814
1815 #[rstest::rstest]
1816 #[test]
1817 #[cfg_attr(miri, ignore)]
1818 fn test_scaled_to_assign_cv(backend: ArrayFireBackend<Sphere>) -> Result<(), HoloError> {
1819 let a = make_random_cv(&backend, N)?;
1820 let mut b = make_random_cv(&backend, N)?;
1821 let bc = backend.clone_cv(&b)?;
1822
1823 backend.scaled_to_assign_cv(&a, &mut b)?;
1824
1825 let a = backend.to_host_cv(a)?;
1826 let b = backend.to_host_cv(b)?;
1827 let bc = backend.to_host_cv(bc)?;
1828 b.iter()
1829 .zip(a.iter())
1830 .zip(bc.iter())
1831 .for_each(|((&b, &a), &bc)| {
1832 assert_approx_eq::assert_approx_eq!(b, bc / bc.abs() * a, EPS);
1833 });
1834
1835 Ok(())
1836 }
1837
1838 #[rstest::rstest]
1839 #[test]
1840 #[case(1, 2)]
1841 #[case(2, 1)]
1842 fn test_generate_propagation_matrix(
1843 #[case] dev_num: usize,
1844 #[case] foci_num: usize,
1845 backend: ArrayFireBackend<Sphere>,
1846 ) -> Result<(), HoloError> {
1847 let env = Environment::new();
1848
1849 let reference = |geometry: Geometry, foci: Vec<Point3>| {
1850 let mut g = MatrixXc::zeros(
1851 foci.len(),
1852 geometry
1853 .iter()
1854 .map(|dev| dev.num_transducers())
1855 .sum::<usize>(),
1856 );
1857 let transducers = geometry
1858 .iter()
1859 .flat_map(|dev| dev.iter().map(|tr| (dev.idx(), tr)))
1860 .collect::<Vec<_>>();
1861 (0..foci.len()).for_each(|i| {
1862 (0..transducers.len()).for_each(|j| {
1863 g[(i, j)] = propagate::<Sphere>(
1864 transducers[j].1,
1865 env.wavenumber(),
1866 geometry[transducers[j].0].axial_direction(),
1867 &foci[i],
1868 )
1869 })
1870 });
1871 g
1872 };
1873
1874 let geometry = generate_geometry(dev_num);
1875 let foci = gen_foci(foci_num).map(|(p, _)| p).collect::<Vec<_>>();
1876
1877 let g = backend.generate_propagation_matrix(
1878 &geometry,
1879 &env,
1880 &foci,
1881 &TransducerFilter::all_enabled(),
1882 )?;
1883 let g = backend.to_host_cm(g)?;
1884 reference(geometry, foci)
1885 .iter()
1886 .zip(g.iter())
1887 .for_each(|(r, g)| {
1888 assert_approx_eq::assert_approx_eq!(r.re, g.re, EPS);
1889 assert_approx_eq::assert_approx_eq!(r.im, g.im, EPS);
1890 });
1891
1892 Ok(())
1893 }
1894
1895 #[rstest::rstest]
1896 #[test]
1897 #[case(1, 2)]
1898 #[case(2, 1)]
1899 fn test_generate_propagation_matrix_with_filter(
1900 #[case] dev_num: usize,
1901 #[case] foci_num: usize,
1902 backend: ArrayFireBackend<Sphere>,
1903 ) -> Result<(), HoloError> {
1904 let env = Environment::new();
1905
1906 let filter = |geometry: &Geometry| -> TransducerFilter {
1907 TransducerFilter::from_fn(geometry, |dev| {
1908 let num_transducers = dev.num_transducers();
1909 Some(move |tr: &Transducer| tr.idx() > num_transducers / 2)
1910 })
1911 };
1912
1913 let reference = |geometry, foci: Vec<Point3>| {
1914 let filter = filter(&geometry);
1915 let transducers = geometry
1916 .iter()
1917 .flat_map(|dev| {
1918 dev.iter().filter_map(|tr| {
1919 if filter.is_enabled(tr) {
1920 Some((dev.idx(), tr))
1921 } else {
1922 None
1923 }
1924 })
1925 })
1926 .collect::<Vec<_>>();
1927
1928 let mut g = MatrixXc::zeros(foci.len(), transducers.len());
1929 (0..foci.len()).for_each(|i| {
1930 (0..transducers.len()).for_each(|j| {
1931 g[(i, j)] = propagate::<Sphere>(
1932 transducers[j].1,
1933 env.wavenumber(),
1934 geometry[transducers[j].0].axial_direction(),
1935 &foci[i],
1936 )
1937 })
1938 });
1939 g
1940 };
1941
1942 let geometry = generate_geometry(dev_num);
1943 let foci = gen_foci(foci_num).map(|(p, _)| p).collect::<Vec<_>>();
1944 let filter = filter(&geometry);
1945
1946 let g = backend.generate_propagation_matrix(&geometry, &env, &foci, &filter)?;
1947 let g = backend.to_host_cm(g)?;
1948 assert_eq!(g.nrows(), foci.len());
1949 assert_eq!(
1950 g.ncols(),
1951 geometry
1952 .iter()
1953 .map(|dev| dev.num_transducers() / 2)
1954 .sum::<usize>()
1955 );
1956 reference(geometry, foci)
1957 .iter()
1958 .zip(g.iter())
1959 .for_each(|(r, g)| {
1960 assert_approx_eq::assert_approx_eq!(r.re, g.re, EPS);
1961 assert_approx_eq::assert_approx_eq!(r.im, g.im, EPS);
1962 });
1963
1964 Ok(())
1965 }
1966
1967 #[rstest::rstest]
1968 #[test]
1969 fn test_gen_back_prop(backend: ArrayFireBackend<Sphere>) -> Result<(), HoloError> {
1970 let env = Environment::new();
1971
1972 let geometry = generate_geometry(1);
1973 let foci = gen_foci(2).map(|(p, _)| p).collect::<Vec<_>>();
1974
1975 let m = geometry
1976 .iter()
1977 .map(|dev| dev.num_transducers())
1978 .sum::<usize>();
1979 let n = foci.len();
1980
1981 let g = backend.generate_propagation_matrix(
1982 &geometry,
1983 &env,
1984 &foci,
1985 &TransducerFilter::all_enabled(),
1986 )?;
1987
1988 let b = backend.gen_back_prop(m, n, &g)?;
1989 let g = backend.to_host_cm(g)?;
1990 let reference = {
1991 let mut b = MatrixXc::zeros(m, n);
1992 (0..n).for_each(|i| {
1993 let x = 1.0 / g.rows(i, 1).iter().map(|x| x.norm_sqr()).sum::<f32>();
1994 (0..m).for_each(|j| {
1995 b[(j, i)] = g[(i, j)].conj() * x;
1996 })
1997 });
1998 b
1999 };
2000
2001 let b = backend.to_host_cm(b)?;
2002 reference.iter().zip(b.iter()).for_each(|(r, b)| {
2003 assert_approx_eq::assert_approx_eq!(r.re, b.re, EPS);
2004 assert_approx_eq::assert_approx_eq!(r.im, b.im, EPS);
2005 });
2006 Ok(())
2007 }
2008}