1#![allow(unknown_lints)]
2#![allow(clippy::manual_slice_size_calculation)]
3
4use std::collections::HashMap;
5
6use arrayfire::*;
7
8use autd3_core::{
9 acoustics::{
10 directivity::{Directivity, Sphere},
11 propagate,
12 },
13 gain::BitVec,
14 geometry::Geometry,
15};
16use autd3_gain_holo::{
17 Complex, HoloError, LinAlgBackend, MatrixX, MatrixXc, Trans, VectorX, VectorXc,
18};
19
/// Alias for `arrayfire::Backend`; accepted by `ArrayFireBackend::set_backend` and
/// returned by `ArrayFireBackend::get_available_backends`.
pub type AFBackend = arrayfire::Backend;
/// Per-device information tuple as returned by `arrayfire::device_info()`
/// (collected by `ArrayFireBackend::get_available_devices`).
/// NOTE(review): field order is presumably (name, platform, toolkit, compute) per the
/// ArrayFire docs — confirm before relying on positions.
pub type AFDeviceInfo = (String, String, String, String);
22
23fn convert(trans: Trans) -> MatProp {
24 match trans {
25 Trans::NoTrans => MatProp::NONE,
26 Trans::Trans => MatProp::TRANS,
27 Trans::ConjTrans => MatProp::CTRANS,
28 }
29}
30
31pub struct ArrayFireBackend<D: Directivity> {
32 _phantom: std::marker::PhantomData<D>,
33}
34
35impl ArrayFireBackend<Sphere> {
36 pub fn get_available_backends() -> Vec<AFBackend> {
37 arrayfire::get_available_backends()
38 }
39
40 pub fn set_backend(backend: AFBackend) {
41 arrayfire::set_backend(backend);
42 }
43
44 pub fn set_device(device: i32) {
45 arrayfire::set_device(device);
46 }
47
48 pub fn get_available_devices() -> Vec<AFDeviceInfo> {
49 let cur_dev = arrayfire::get_device();
50 let r = (0..arrayfire::device_count())
51 .map(|i| {
52 arrayfire::set_device(i);
53 arrayfire::device_info()
54 })
55 .collect();
56 arrayfire::set_device(cur_dev);
57 r
58 }
59}
60
61impl Default for ArrayFireBackend<Sphere> {
62 fn default() -> Self {
63 Self {
64 _phantom: Default::default(),
65 }
66 }
67}
68
69impl<D: Directivity> ArrayFireBackend<D> {
70 pub fn new() -> Self {
71 Self {
72 _phantom: std::marker::PhantomData,
73 }
74 }
75}
76
77impl<D: Directivity> LinAlgBackend<D> for ArrayFireBackend<D> {
78 type MatrixXc = Array<c32>;
79 type MatrixX = Array<f32>;
80 type VectorXc = Array<c32>;
81 type VectorX = Array<f32>;
82
83 fn generate_propagation_matrix(
84 &self,
85 geometry: &Geometry,
86 foci: &[autd3_core::geometry::Point3],
87 filter: Option<&HashMap<usize, BitVec>>,
88 ) -> Result<Self::MatrixXc, HoloError> {
89 let g = if let Some(filter) = filter {
90 geometry
91 .devices()
92 .flat_map(|dev| {
93 dev.iter().filter_map(move |tr| {
94 if let Some(filter) = filter.get(&dev.idx()) {
95 if filter[tr.idx()] {
96 Some(foci.iter().map(move |fp| {
97 propagate::<D>(tr, dev.wavenumber(), dev.axial_direction(), fp)
98 }))
99 } else {
100 None
101 }
102 } else {
103 None
104 }
105 })
106 })
107 .flatten()
108 .collect::<Vec<_>>()
109 } else {
110 geometry
111 .devices()
112 .flat_map(|dev| {
113 dev.iter().flat_map(move |tr| {
114 foci.iter().map(move |fp| {
115 propagate::<D>(tr, dev.wavenumber(), dev.axial_direction(), fp)
116 })
117 })
118 })
119 .collect::<Vec<_>>()
120 };
121
122 unsafe {
123 Ok(Array::new(
124 std::slice::from_raw_parts(g.as_ptr() as *const c32, g.len()),
125 Dim4::new(&[foci.len() as u64, (g.len() / foci.len()) as _, 1, 1]),
126 ))
127 }
128 }
129
130 fn alloc_v(&self, size: usize) -> Result<Self::VectorX, HoloError> {
131 Ok(Array::new_empty(Dim4::new(&[size as _, 1, 1, 1])))
132 }
133
134 fn alloc_m(&self, rows: usize, cols: usize) -> Result<Self::MatrixX, HoloError> {
135 Ok(Array::new_empty(Dim4::new(&[rows as _, cols as _, 1, 1])))
136 }
137
138 fn alloc_cv(&self, size: usize) -> Result<Self::VectorXc, HoloError> {
139 Ok(Array::new_empty(Dim4::new(&[size as _, 1, 1, 1])))
140 }
141
142 fn alloc_cm(&self, rows: usize, cols: usize) -> Result<Self::MatrixXc, HoloError> {
143 Ok(Array::new_empty(Dim4::new(&[rows as _, cols as _, 1, 1])))
144 }
145
146 fn alloc_zeros_v(&self, size: usize) -> Result<Self::VectorX, HoloError> {
147 Ok(arrayfire::constant(0., Dim4::new(&[size as _, 1, 1, 1])))
148 }
149
150 fn alloc_zeros_cv(&self, size: usize) -> Result<Self::VectorXc, HoloError> {
151 Ok(arrayfire::constant(
152 c32::new(0., 0.),
153 Dim4::new(&[size as _, 1, 1, 1]),
154 ))
155 }
156
157 fn alloc_zeros_cm(&self, rows: usize, cols: usize) -> Result<Self::MatrixXc, HoloError> {
158 Ok(arrayfire::constant(
159 c32::new(0., 0.),
160 Dim4::new(&[rows as _, cols as _, 1, 1]),
161 ))
162 }
163
164 fn to_host_v(&self, v: Self::VectorX) -> Result<VectorX, HoloError> {
165 let mut r = VectorX::zeros(v.elements());
166 v.host(r.as_mut_slice());
167 Ok(r)
168 }
169
170 fn to_host_m(&self, v: Self::MatrixX) -> Result<MatrixX, HoloError> {
171 let mut r = MatrixX::zeros(v.dims()[0] as _, v.dims()[1] as _);
172 v.host(r.as_mut_slice());
173 Ok(r)
174 }
175
176 fn to_host_cv(&self, v: Self::VectorXc) -> Result<VectorXc, HoloError> {
177 let n = v.elements();
178 let mut r = VectorXc::zeros(n);
179 unsafe {
180 v.host(std::slice::from_raw_parts_mut(
181 r.as_mut_ptr() as *mut c32,
182 n,
183 ));
184 }
185 Ok(r)
186 }
187
188 fn to_host_cm(&self, v: Self::MatrixXc) -> Result<MatrixXc, HoloError> {
189 let n = v.elements();
190 let mut r = MatrixXc::zeros(v.dims()[0] as _, v.dims()[1] as _);
191 unsafe {
192 v.host(std::slice::from_raw_parts_mut(
193 r.as_mut_ptr() as *mut c32,
194 n,
195 ));
196 }
197 Ok(r)
198 }
199
200 fn from_slice_v(&self, v: &[f32]) -> Result<Self::VectorX, HoloError> {
201 Ok(Array::new(v, Dim4::new(&[v.len() as _, 1, 1, 1])))
202 }
203
204 fn from_slice_m(
205 &self,
206 rows: usize,
207 cols: usize,
208 v: &[f32],
209 ) -> Result<Self::MatrixX, HoloError> {
210 Ok(Array::new(v, Dim4::new(&[rows as _, cols as _, 1, 1])))
211 }
212
213 fn from_slice_cv(&self, v: &[f32]) -> Result<Self::VectorXc, HoloError> {
214 let r = Array::new(v, Dim4::new(&[v.len() as _, 1, 1, 1]));
215 Ok(arrayfire::cplx(&r))
216 }
217
218 fn from_slice2_cv(&self, r: &[f32], i: &[f32]) -> Result<Self::VectorXc, HoloError> {
219 let r = Array::new(r, Dim4::new(&[r.len() as _, 1, 1, 1]));
220 let i = Array::new(i, Dim4::new(&[i.len() as _, 1, 1, 1]));
221 Ok(arrayfire::cplx2(&r, &i, false).cast())
222 }
223
224 fn from_slice2_cm(
225 &self,
226 rows: usize,
227 cols: usize,
228 r: &[f32],
229 i: &[f32],
230 ) -> Result<Self::MatrixXc, HoloError> {
231 let r = Array::new(r, Dim4::new(&[rows as _, cols as _, 1, 1]));
232 let i = Array::new(i, Dim4::new(&[rows as _, cols as _, 1, 1]));
233 Ok(arrayfire::cplx2(&r, &i, false).cast())
234 }
235
236 fn copy_from_slice_v(&self, v: &[f32], dst: &mut Self::VectorX) -> Result<(), HoloError> {
237 let n = v.len();
238 if n == 0 {
239 return Ok(());
240 }
241 let v = self.from_slice_v(v)?;
242 let seqs = [Seq::new(0u32, n as u32 - 1, 1)];
243 arrayfire::assign_seq(dst, &seqs, &v);
244 Ok(())
245 }
246
247 fn copy_to_v(&self, src: &Self::VectorX, dst: &mut Self::VectorX) -> Result<(), HoloError> {
248 let seqs = [Seq::new(0u32, src.elements() as u32 - 1, 1)];
249 arrayfire::assign_seq(dst, &seqs, src);
250 Ok(())
251 }
252
253 fn copy_to_m(&self, src: &Self::MatrixX, dst: &mut Self::MatrixX) -> Result<(), HoloError> {
254 let seqs = [
255 Seq::new(0u32, src.dims()[0] as u32 - 1, 1),
256 Seq::new(0u32, src.dims()[1] as u32 - 1, 1),
257 ];
258 arrayfire::assign_seq(dst, &seqs, src);
259 Ok(())
260 }
261
262 fn clone_v(&self, v: &Self::VectorX) -> Result<Self::VectorX, HoloError> {
263 Ok(v.copy())
264 }
265
266 fn clone_m(&self, v: &Self::MatrixX) -> Result<Self::MatrixX, HoloError> {
267 Ok(v.copy())
268 }
269
270 fn clone_cv(&self, v: &Self::VectorXc) -> Result<Self::VectorXc, HoloError> {
271 Ok(v.copy())
272 }
273
274 fn clone_cm(&self, v: &Self::MatrixXc) -> Result<Self::MatrixXc, HoloError> {
275 Ok(v.copy())
276 }
277
278 fn make_complex2_v(
279 &self,
280 real: &Self::VectorX,
281 imag: &Self::VectorX,
282 v: &mut Self::VectorXc,
283 ) -> Result<(), HoloError> {
284 *v = arrayfire::cplx2(real, imag, false).cast();
285 Ok(())
286 }
287
288 fn create_diagonal(&self, v: &Self::VectorX, a: &mut Self::MatrixX) -> Result<(), HoloError> {
289 *a = arrayfire::diag_create(v, 0);
290 Ok(())
291 }
292
293 fn create_diagonal_c(
294 &self,
295 v: &Self::VectorXc,
296 a: &mut Self::MatrixXc,
297 ) -> Result<(), HoloError> {
298 *a = arrayfire::diag_create(v, 0);
299 Ok(())
300 }
301
302 fn get_diagonal(&self, a: &Self::MatrixX, v: &mut Self::VectorX) -> Result<(), HoloError> {
303 *v = arrayfire::diag_extract(a, 0);
304 Ok(())
305 }
306
307 fn real_cm(&self, a: &Self::MatrixXc, b: &mut Self::MatrixX) -> Result<(), HoloError> {
308 *b = arrayfire::real(a);
309 Ok(())
310 }
311
312 fn imag_cm(&self, a: &Self::MatrixXc, b: &mut Self::MatrixX) -> Result<(), HoloError> {
313 *b = arrayfire::imag(a);
314 Ok(())
315 }
316
317 fn scale_assign_cv(
318 &self,
319 a: autd3_gain_holo::Complex,
320 b: &mut Self::VectorXc,
321 ) -> Result<(), HoloError> {
322 let a = c32::new(a.re, a.im);
323 *b = arrayfire::mul(b, &a, false);
324 Ok(())
325 }
326
327 fn conj_assign_v(&self, b: &mut Self::VectorXc) -> Result<(), HoloError> {
328 *b = arrayfire::conjg(b);
329 Ok(())
330 }
331
332 fn exp_assign_cv(&self, v: &mut Self::VectorXc) -> Result<(), HoloError> {
333 *v = arrayfire::exp(v);
334 Ok(())
335 }
336
337 fn concat_col_cm(
338 &self,
339 a: &Self::MatrixXc,
340 b: &Self::MatrixXc,
341 c: &mut Self::MatrixXc,
342 ) -> Result<(), HoloError> {
343 *c = arrayfire::join(1, a, b);
344 Ok(())
345 }
346
347 fn max_v(&self, m: &Self::VectorX) -> Result<f32, HoloError> {
348 Ok(arrayfire::max_all(m).0)
349 }
350
351 fn hadamard_product_cm(
352 &self,
353 x: &Self::MatrixXc,
354 y: &Self::MatrixXc,
355 z: &mut Self::MatrixXc,
356 ) -> Result<(), HoloError> {
357 *z = arrayfire::mul(x, y, false);
358 Ok(())
359 }
360
361 fn dot(&self, x: &Self::VectorX, y: &Self::VectorX) -> Result<f32, HoloError> {
362 let r = arrayfire::dot(x, y, MatProp::NONE, MatProp::NONE);
363 let mut v = [0.];
364 r.host(&mut v);
365 Ok(v[0])
366 }
367
368 fn dot_c(
369 &self,
370 x: &Self::VectorXc,
371 y: &Self::VectorXc,
372 ) -> Result<autd3_gain_holo::Complex, HoloError> {
373 let r = arrayfire::dot(x, y, MatProp::CONJ, MatProp::NONE);
374 let mut v = [c32::new(0., 0.)];
375 r.host(&mut v);
376 Ok(autd3_gain_holo::Complex::new(v[0].re, v[0].im))
377 }
378
379 fn add_v(&self, alpha: f32, a: &Self::VectorX, b: &mut Self::VectorX) -> Result<(), HoloError> {
380 *b = arrayfire::add(&arrayfire::mul(a, &alpha, false), b, false);
381 Ok(())
382 }
383
384 fn add_m(&self, alpha: f32, a: &Self::MatrixX, b: &mut Self::MatrixX) -> Result<(), HoloError> {
385 *b = arrayfire::add(&arrayfire::mul(a, &alpha, false), b, false);
386 Ok(())
387 }
388
389 fn gevv_c(
390 &self,
391 trans_a: autd3_gain_holo::Trans,
392 trans_b: autd3_gain_holo::Trans,
393 alpha: autd3_gain_holo::Complex,
394 a: &Self::VectorXc,
395 x: &Self::VectorXc,
396 beta: autd3_gain_holo::Complex,
397 y: &mut Self::MatrixXc,
398 ) -> Result<(), HoloError> {
399 let alpha = vec![c32::new(alpha.re, alpha.im)];
400 let beta = vec![c32::new(beta.re, beta.im)];
401 let trans_a = convert(trans_a);
402 let trans_b = convert(trans_b);
403 arrayfire::gemm(y, trans_a, trans_b, alpha, a, x, beta);
404 Ok(())
405 }
406
407 fn gemv_c(
408 &self,
409 trans: autd3_gain_holo::Trans,
410 alpha: autd3_gain_holo::Complex,
411 a: &Self::MatrixXc,
412 x: &Self::VectorXc,
413 beta: autd3_gain_holo::Complex,
414 y: &mut Self::VectorXc,
415 ) -> Result<(), HoloError> {
416 let alpha = vec![c32::new(alpha.re, alpha.im)];
417 let beta = vec![c32::new(beta.re, beta.im)];
418 let trans = convert(trans);
419 arrayfire::gemm(y, trans, MatProp::NONE, alpha, a, x, beta);
420 Ok(())
421 }
422
423 fn gemm_c(
424 &self,
425 trans_a: autd3_gain_holo::Trans,
426 trans_b: autd3_gain_holo::Trans,
427 alpha: autd3_gain_holo::Complex,
428 a: &Self::MatrixXc,
429 b: &Self::MatrixXc,
430 beta: autd3_gain_holo::Complex,
431 y: &mut Self::MatrixXc,
432 ) -> Result<(), HoloError> {
433 let alpha = vec![c32::new(alpha.re, alpha.im)];
434 let beta = vec![c32::new(beta.re, beta.im)];
435 let trans_a = convert(trans_a);
436 let trans_b = convert(trans_b);
437 arrayfire::gemm(y, trans_a, trans_b, alpha, a, b, beta);
438 Ok(())
439 }
440
441 fn solve_inplace(&self, a: &Self::MatrixX, x: &mut Self::VectorX) -> Result<(), HoloError> {
442 *x = arrayfire::solve(a, x, MatProp::NONE);
443 Ok(())
444 }
445
446 fn reduce_col(&self, a: &Self::MatrixX, b: &mut Self::VectorX) -> Result<(), HoloError> {
447 *b = arrayfire::sum(a, 1);
448 Ok(())
449 }
450
451 fn cols_c(&self, m: &Self::MatrixXc) -> Result<usize, HoloError> {
452 Ok(m.dims()[1] as _)
453 }
454
455 fn scaled_to_cv(
456 &self,
457 a: &Self::VectorXc,
458 b: &Self::VectorXc,
459 c: &mut Self::VectorXc,
460 ) -> Result<(), HoloError> {
461 let tmp = arrayfire::div(a, &arrayfire::abs(a), false);
462 *c = arrayfire::mul(&tmp, b, false);
463 Ok(())
464 }
465
466 fn scaled_to_assign_cv(
467 &self,
468 a: &Self::VectorXc,
469 b: &mut Self::VectorXc,
470 ) -> Result<(), HoloError> {
471 *b = arrayfire::div(b, &arrayfire::abs(b), false);
472 *b = arrayfire::mul(a, b, false);
473 Ok(())
474 }
475
476 fn gen_back_prop(
477 &self,
478 m: usize,
479 n: usize,
480 transfer: &Self::MatrixXc,
481 ) -> Result<Self::MatrixXc, HoloError> {
482 let mut b = self.alloc_zeros_cm(m, n)?;
483
484 let mut tmp = self.alloc_zeros_cm(n, n)?;
485
486 self.gemm_c(
487 Trans::NoTrans,
488 Trans::ConjTrans,
489 Complex::new(1., 0.),
490 transfer,
491 transfer,
492 Complex::new(0., 0.),
493 &mut tmp,
494 )?;
495
496 let mut denominator = arrayfire::diag_extract(&tmp, 0);
497 let a = c32::new(1., 0.);
498 denominator = arrayfire::div(&a, &denominator, false);
499
500 self.create_diagonal_c(&denominator, &mut tmp)?;
501
502 self.gemm_c(
503 Trans::ConjTrans,
504 Trans::NoTrans,
505 Complex::new(1., 0.),
506 transfer,
507 &tmp,
508 Complex::new(0., 0.),
509 &mut b,
510 )?;
511
512 Ok(b)
513 }
514
515 fn norm_squared_cv(&self, a: &Self::VectorXc, b: &mut Self::VectorX) -> Result<(), HoloError> {
516 *b = arrayfire::abs(a);
517 *b = arrayfire::mul(b, b, false);
518 Ok(())
519 }
520}
521#[cfg(test)]
522mod tests {
523 use autd3::driver::autd3_device::AUTD3;
524 use autd3_core::{
525 acoustics::directivity::Sphere,
526 defined::PI,
527 geometry::{IntoDevice, Point3, UnitQuaternion},
528 };
529
530 use nalgebra::{ComplexField, Normed};
531
532 use autd3_gain_holo::{Amplitude, Pa, Trans};
533
534 use super::*;
535
536 use rand::Rng;
537
538 const N: usize = 10;
539 const EPS: f32 = 1e-3;
540
541 fn generate_geometry(size: usize) -> Geometry {
542 Geometry::new(
543 (0..size)
544 .flat_map(|i| {
545 (0..size).map(move |j| {
546 AUTD3 {
547 pos: Point3::new(
548 i as f32 * AUTD3::DEVICE_WIDTH,
549 j as f32 * AUTD3::DEVICE_HEIGHT,
550 0.,
551 ),
552 rot: UnitQuaternion::identity(),
553 }
554 .into_device((j + i * size) as _)
555 })
556 })
557 .collect(),
558 )
559 }
560
561 fn gen_foci(n: usize) -> impl Iterator<Item = (Point3, Amplitude)> {
562 (0..n).map(move |i| {
563 (
564 Point3::new(
565 90. + 10. * (2.0 * PI * i as f32 / n as f32).cos(),
566 70. + 10. * (2.0 * PI * i as f32 / n as f32).sin(),
567 150.,
568 ),
569 10e3 * Pa,
570 )
571 })
572 }
573
574 fn make_random_v(
575 backend: &ArrayFireBackend<Sphere>,
576 size: usize,
577 ) -> Result<<ArrayFireBackend<Sphere> as LinAlgBackend<Sphere>>::VectorX, HoloError> {
578 let mut rng = rand::rng();
579 let v: Vec<f32> = (&mut rng)
580 .sample_iter(rand::distr::StandardUniform)
581 .take(size)
582 .collect();
583 backend.from_slice_v(&v)
584 }
585
586 fn make_random_m(
587 backend: &ArrayFireBackend<Sphere>,
588 rows: usize,
589 cols: usize,
590 ) -> Result<<ArrayFireBackend<Sphere> as LinAlgBackend<Sphere>>::MatrixX, HoloError> {
591 let mut rng = rand::rng();
592 let v: Vec<f32> = (&mut rng)
593 .sample_iter(rand::distr::StandardUniform)
594 .take(rows * cols)
595 .collect();
596 backend.from_slice_m(rows, cols, &v)
597 }
598
599 fn make_random_cv(
600 backend: &ArrayFireBackend<Sphere>,
601 size: usize,
602 ) -> Result<<ArrayFireBackend<Sphere> as LinAlgBackend<Sphere>>::VectorXc, HoloError> {
603 let mut rng = rand::rng();
604 let real: Vec<f32> = (&mut rng)
605 .sample_iter(rand::distr::StandardUniform)
606 .take(size)
607 .collect();
608 let imag: Vec<f32> = (&mut rng)
609 .sample_iter(rand::distr::StandardUniform)
610 .take(size)
611 .collect();
612 backend.from_slice2_cv(&real, &imag)
613 }
614
615 fn make_random_cm(
616 backend: &ArrayFireBackend<Sphere>,
617 rows: usize,
618 cols: usize,
619 ) -> Result<<ArrayFireBackend<Sphere> as LinAlgBackend<Sphere>>::MatrixXc, HoloError> {
620 let mut rng = rand::rng();
621 let real: Vec<f32> = (&mut rng)
622 .sample_iter(rand::distr::StandardUniform)
623 .take(rows * cols)
624 .collect();
625 let imag: Vec<f32> = (&mut rng)
626 .sample_iter(rand::distr::StandardUniform)
627 .take(rows * cols)
628 .collect();
629 backend.from_slice2_cm(rows, cols, &real, &imag)
630 }
631
632 #[rstest::fixture]
633 fn backend() -> ArrayFireBackend<Sphere> {
634 ArrayFireBackend::set_backend(AFBackend::CPU);
635 ArrayFireBackend {
636 _phantom: std::marker::PhantomData,
637 }
638 }
639
640 #[rstest::rstest]
641 #[test]
642 #[cfg_attr(miri, ignore)]
643 fn test_alloc_v(backend: ArrayFireBackend<Sphere>) -> Result<(), HoloError> {
644 let v = backend.alloc_v(N)?;
645 let v = backend.to_host_v(v)?;
646
647 assert_eq!(N, v.len());
648 Ok(())
649 }
650
651 #[rstest::rstest]
652 #[test]
653 #[cfg_attr(miri, ignore)]
654 fn test_alloc_m(backend: ArrayFireBackend<Sphere>) -> Result<(), HoloError> {
655 let m = backend.alloc_m(N, 2 * N)?;
656 let m = backend.to_host_m(m)?;
657
658 assert_eq!(N, m.nrows());
659 assert_eq!(2 * N, m.ncols());
660 Ok(())
661 }
662
663 #[rstest::rstest]
664 #[test]
665 #[cfg_attr(miri, ignore)]
666 fn test_alloc_cv(backend: ArrayFireBackend<Sphere>) -> Result<(), HoloError> {
667 let v = backend.alloc_cv(N)?;
668 let v = backend.to_host_cv(v)?;
669
670 assert_eq!(N, v.len());
671 Ok(())
672 }
673
674 #[rstest::rstest]
675 #[test]
676 #[cfg_attr(miri, ignore)]
677 fn test_alloc_cm(backend: ArrayFireBackend<Sphere>) -> Result<(), HoloError> {
678 let m = backend.alloc_cm(N, 2 * N)?;
679 let m = backend.to_host_cm(m)?;
680
681 assert_eq!(N, m.nrows());
682 assert_eq!(2 * N, m.ncols());
683 Ok(())
684 }
685
686 #[rstest::rstest]
687 #[test]
688 #[cfg_attr(miri, ignore)]
689 fn test_alloc_zeros_v(backend: ArrayFireBackend<Sphere>) -> Result<(), HoloError> {
690 let v = backend.alloc_zeros_v(N)?;
691 let v = backend.to_host_v(v)?;
692
693 assert_eq!(N, v.len());
694 assert!(v.iter().all(|&v| v == 0.));
695 Ok(())
696 }
697
698 #[rstest::rstest]
699 #[test]
700 #[cfg_attr(miri, ignore)]
701 fn test_alloc_zeros_cv(backend: ArrayFireBackend<Sphere>) -> Result<(), HoloError> {
702 let v = backend.alloc_zeros_cv(N)?;
703 let v = backend.to_host_cv(v)?;
704
705 assert_eq!(N, v.len());
706 assert!(v.iter().all(|&v| v == Complex::new(0., 0.)));
707 Ok(())
708 }
709
710 #[rstest::rstest]
711 #[test]
712 #[cfg_attr(miri, ignore)]
713 fn test_alloc_zeros_cm(backend: ArrayFireBackend<Sphere>) -> Result<(), HoloError> {
714 let m = backend.alloc_zeros_cm(N, 2 * N)?;
715 let m = backend.to_host_cm(m)?;
716
717 assert_eq!(N, m.nrows());
718 assert_eq!(2 * N, m.ncols());
719 assert!(m.iter().all(|&v| v == Complex::new(0., 0.)));
720 Ok(())
721 }
722
723 #[rstest::rstest]
724 #[test]
725 #[cfg_attr(miri, ignore)]
726 fn test_cols_c(backend: ArrayFireBackend<Sphere>) -> Result<(), HoloError> {
727 let m = backend.alloc_cm(N, 2 * N)?;
728
729 assert_eq!(2 * N, backend.cols_c(&m)?);
730
731 Ok(())
732 }
733
734 #[rstest::rstest]
735 #[test]
736 #[cfg_attr(miri, ignore)]
737 fn test_from_slice_v(backend: ArrayFireBackend<Sphere>) -> Result<(), HoloError> {
738 let rng = rand::rng();
739
740 let v: Vec<f32> = rng
741 .sample_iter(rand::distr::StandardUniform)
742 .take(N)
743 .collect();
744
745 let c = backend.from_slice_v(&v)?;
746 let c = backend.to_host_v(c)?;
747
748 assert_eq!(N, c.len());
749 v.iter().zip(c.iter()).for_each(|(&r, &c)| {
750 assert_eq!(r, c);
751 });
752 Ok(())
753 }
754
755 #[rstest::rstest]
756 #[test]
757 #[cfg_attr(miri, ignore)]
758 fn test_from_slice_m(backend: ArrayFireBackend<Sphere>) -> Result<(), HoloError> {
759 let rng = rand::rng();
760
761 let v: Vec<f32> = rng
762 .sample_iter(rand::distr::StandardUniform)
763 .take(N * 2 * N)
764 .collect();
765
766 let c = backend.from_slice_m(N, 2 * N, &v)?;
767 let c = backend.to_host_m(c)?;
768
769 assert_eq!(N, c.nrows());
770 assert_eq!(2 * N, c.ncols());
771 (0..2 * N).for_each(|col| {
772 (0..N).for_each(|row| {
773 assert_eq!(v[col * N + row], c[(row, col)]);
774 })
775 });
776 Ok(())
777 }
778
779 #[rstest::rstest]
780 #[test]
781 #[cfg_attr(miri, ignore)]
782 fn test_from_slice_cv(backend: ArrayFireBackend<Sphere>) -> Result<(), HoloError> {
783 let rng = rand::rng();
784
785 let real: Vec<f32> = rng
786 .sample_iter(rand::distr::StandardUniform)
787 .take(N)
788 .collect();
789
790 let c = backend.from_slice_cv(&real)?;
791 let c = backend.to_host_cv(c)?;
792
793 assert_eq!(N, c.len());
794 real.iter().zip(c.iter()).for_each(|(r, c)| {
795 assert_eq!(r, &c.re);
796 assert_eq!(0.0, c.im);
797 });
798 Ok(())
799 }
800
801 #[rstest::rstest]
802 #[test]
803 #[cfg_attr(miri, ignore)]
804 fn test_from_slice2_cv(backend: ArrayFireBackend<Sphere>) -> Result<(), HoloError> {
805 let mut rng = rand::rng();
806
807 let real: Vec<f32> = (&mut rng)
808 .sample_iter(rand::distr::StandardUniform)
809 .take(N)
810 .collect();
811 let imag: Vec<f32> = (&mut rng)
812 .sample_iter(rand::distr::StandardUniform)
813 .take(N)
814 .collect();
815
816 let c = backend.from_slice2_cv(&real, &imag)?;
817 let c = backend.to_host_cv(c)?;
818
819 assert_eq!(N, c.len());
820 real.iter()
821 .zip(imag.iter())
822 .zip(c.iter())
823 .for_each(|((r, i), c)| {
824 assert_eq!(r, &c.re);
825 assert_eq!(i, &c.im);
826 });
827 Ok(())
828 }
829
830 #[rstest::rstest]
831 #[test]
832 #[cfg_attr(miri, ignore)]
833 fn test_from_slice2_cm(backend: ArrayFireBackend<Sphere>) -> Result<(), HoloError> {
834 let mut rng = rand::rng();
835
836 let real: Vec<f32> = (&mut rng)
837 .sample_iter(rand::distr::StandardUniform)
838 .take(N * 2 * N)
839 .collect();
840 let imag: Vec<f32> = (&mut rng)
841 .sample_iter(rand::distr::StandardUniform)
842 .take(N * 2 * N)
843 .collect();
844
845 let c = backend.from_slice2_cm(N, 2 * N, &real, &imag)?;
846 let c = backend.to_host_cm(c)?;
847
848 assert_eq!(N, c.nrows());
849 assert_eq!(2 * N, c.ncols());
850 (0..2 * N).for_each(|col| {
851 (0..N).for_each(|row| {
852 assert_eq!(real[col * N + row], c[(row, col)].re);
853 assert_eq!(imag[col * N + row], c[(row, col)].im);
854 })
855 });
856 Ok(())
857 }
858
859 #[rstest::rstest]
860 #[test]
861 #[cfg_attr(miri, ignore)]
862 fn test_copy_from_slice_v(backend: ArrayFireBackend<Sphere>) -> Result<(), HoloError> {
863 {
864 let mut a = backend.alloc_zeros_v(N)?;
865 let mut rng = rand::rng();
866 let v = (&mut rng)
867 .sample_iter(rand::distr::StandardUniform)
868 .take(N / 2)
869 .collect::<Vec<f32>>();
870
871 backend.copy_from_slice_v(&v, &mut a)?;
872
873 let a = backend.to_host_v(a)?;
874 (0..N / 2).for_each(|i| {
875 assert_eq!(v[i], a[i]);
876 });
877 (N / 2..N).for_each(|i| {
878 assert_eq!(0., a[i]);
879 });
880 }
881
882 {
883 let mut a = backend.alloc_zeros_v(N)?;
884 let v = [];
885
886 backend.copy_from_slice_v(&v, &mut a)?;
887
888 let a = backend.to_host_v(a)?;
889 a.iter().for_each(|&a| {
890 assert_eq!(0., a);
891 });
892 }
893
894 Ok(())
895 }
896
897 #[rstest::rstest]
898 #[test]
899 #[cfg_attr(miri, ignore)]
900 fn test_copy_to_v(backend: ArrayFireBackend<Sphere>) -> Result<(), HoloError> {
901 let a = make_random_v(&backend, N)?;
902 let mut b = backend.alloc_v(N)?;
903
904 backend.copy_to_v(&a, &mut b)?;
905
906 let a = backend.to_host_v(a)?;
907 let b = backend.to_host_v(b)?;
908 a.iter().zip(b.iter()).for_each(|(a, b)| {
909 assert_eq!(a, b);
910 });
911 Ok(())
912 }
913
914 #[rstest::rstest]
915 #[test]
916 #[cfg_attr(miri, ignore)]
917 fn test_copy_to_m(backend: ArrayFireBackend<Sphere>) -> Result<(), HoloError> {
918 let a = make_random_m(&backend, N, N)?;
919 let mut b = backend.alloc_m(N, N)?;
920
921 backend.copy_to_m(&a, &mut b)?;
922
923 let a = backend.to_host_m(a)?;
924 let b = backend.to_host_m(b)?;
925 a.iter().zip(b.iter()).for_each(|(a, b)| {
926 assert_eq!(a, b);
927 });
928 Ok(())
929 }
930
931 #[rstest::rstest]
932 #[test]
933 #[cfg_attr(miri, ignore)]
934 fn test_clone_v(backend: ArrayFireBackend<Sphere>) -> Result<(), HoloError> {
935 let c = make_random_v(&backend, N)?;
936 let c2 = backend.clone_v(&c)?;
937
938 let c = backend.to_host_v(c)?;
939 let c2 = backend.to_host_v(c2)?;
940
941 c.iter().zip(c2.iter()).for_each(|(c, c2)| {
942 assert_eq!(c, c2);
943 });
944 Ok(())
945 }
946
947 #[rstest::rstest]
948 #[test]
949 #[cfg_attr(miri, ignore)]
950 fn test_clone_m(backend: ArrayFireBackend<Sphere>) -> Result<(), HoloError> {
951 let c = make_random_m(&backend, N, N)?;
952 let c2 = backend.clone_m(&c)?;
953
954 let c = backend.to_host_m(c)?;
955 let c2 = backend.to_host_m(c2)?;
956
957 c.iter().zip(c2.iter()).for_each(|(c, c2)| {
958 assert_eq!(c, c2);
959 });
960 Ok(())
961 }
962
963 #[rstest::rstest]
964 #[test]
965 #[cfg_attr(miri, ignore)]
966 fn test_clone_cv(backend: ArrayFireBackend<Sphere>) -> Result<(), HoloError> {
967 let c = make_random_cv(&backend, N)?;
968 let c2 = backend.clone_cv(&c)?;
969
970 let c = backend.to_host_cv(c)?;
971 let c2 = backend.to_host_cv(c2)?;
972
973 c.iter().zip(c2.iter()).for_each(|(c, c2)| {
974 assert_eq!(c.re, c2.re);
975 assert_eq!(c.im, c2.im);
976 });
977 Ok(())
978 }
979
980 #[rstest::rstest]
981 #[test]
982 #[cfg_attr(miri, ignore)]
983 fn test_clone_cm(backend: ArrayFireBackend<Sphere>) -> Result<(), HoloError> {
984 let c = make_random_cm(&backend, N, N)?;
985 let c2 = backend.clone_cm(&c)?;
986
987 let c = backend.to_host_cm(c)?;
988 let c2 = backend.to_host_cm(c2)?;
989
990 c.iter().zip(c2.iter()).for_each(|(c, c2)| {
991 assert_eq!(c.re, c2.re);
992 assert_eq!(c.im, c2.im);
993 });
994 Ok(())
995 }
996
997 #[rstest::rstest]
998 #[test]
999 #[cfg_attr(miri, ignore)]
1000 fn test_make_complex2_v(backend: ArrayFireBackend<Sphere>) -> Result<(), HoloError> {
1001 let real = make_random_v(&backend, N)?;
1002 let imag = make_random_v(&backend, N)?;
1003
1004 let mut c = backend.alloc_cv(N)?;
1005 backend.make_complex2_v(&real, &imag, &mut c)?;
1006
1007 let real = backend.to_host_v(real)?;
1008 let imag = backend.to_host_v(imag)?;
1009 let c = backend.to_host_cv(c)?;
1010 real.iter()
1011 .zip(imag.iter())
1012 .zip(c.iter())
1013 .for_each(|((r, i), c)| {
1014 assert_eq!(r, &c.re);
1015 assert_eq!(i, &c.im);
1016 });
1017 Ok(())
1018 }
1019
1020 #[rstest::rstest]
1021 #[test]
1022 #[cfg_attr(miri, ignore)]
1023 fn test_create_diagonal(backend: ArrayFireBackend<Sphere>) -> Result<(), HoloError> {
1024 let diagonal = make_random_v(&backend, N)?;
1025
1026 let mut c = backend.alloc_m(N, N)?;
1027
1028 backend.create_diagonal(&diagonal, &mut c)?;
1029
1030 let diagonal = backend.to_host_v(diagonal)?;
1031 let c = backend.to_host_m(c)?;
1032 (0..N).for_each(|i| {
1033 (0..N).for_each(|j| {
1034 if i == j {
1035 assert_eq!(diagonal[i], c[(i, j)]);
1036 } else {
1037 assert_eq!(0.0, c[(i, j)]);
1038 }
1039 })
1040 });
1041 Ok(())
1042 }
1043
1044 #[rstest::rstest]
1045 #[test]
1046 #[cfg_attr(miri, ignore)]
1047 fn test_create_diagonal_c(backend: ArrayFireBackend<Sphere>) -> Result<(), HoloError> {
1048 let diagonal = make_random_cv(&backend, N)?;
1049
1050 let mut c = backend.alloc_cm(N, N)?;
1051
1052 backend.create_diagonal_c(&diagonal, &mut c)?;
1053
1054 let diagonal = backend.to_host_cv(diagonal)?;
1055 let c = backend.to_host_cm(c)?;
1056 (0..N).for_each(|i| {
1057 (0..N).for_each(|j| {
1058 if i == j {
1059 assert_eq!(diagonal[i].re, c[(i, j)].re);
1060 assert_eq!(diagonal[i].im, c[(i, j)].im);
1061 } else {
1062 assert_eq!(0.0, c[(i, j)].re);
1063 assert_eq!(0.0, c[(i, j)].im);
1064 }
1065 })
1066 });
1067 Ok(())
1068 }
1069
1070 #[rstest::rstest]
1071 #[test]
1072 #[cfg_attr(miri, ignore)]
1073 fn test_get_diagonal(backend: ArrayFireBackend<Sphere>) -> Result<(), HoloError> {
1074 let m = make_random_m(&backend, N, N)?;
1075 let mut diagonal = backend.alloc_v(N)?;
1076
1077 backend.get_diagonal(&m, &mut diagonal)?;
1078
1079 let m = backend.to_host_m(m)?;
1080 let diagonal = backend.to_host_v(diagonal)?;
1081 (0..N).for_each(|i| {
1082 assert_eq!(m[(i, i)], diagonal[i]);
1083 });
1084 Ok(())
1085 }
1086
1087 #[rstest::rstest]
1088 #[test]
1089 #[cfg_attr(miri, ignore)]
1090 fn test_norm_squared_cv(backend: ArrayFireBackend<Sphere>) -> Result<(), HoloError> {
1091 let v = make_random_cv(&backend, N)?;
1092
1093 let mut abs = backend.alloc_v(N)?;
1094 backend.norm_squared_cv(&v, &mut abs)?;
1095
1096 let v = backend.to_host_cv(v)?;
1097 let abs = backend.to_host_v(abs)?;
1098 v.iter().zip(abs.iter()).for_each(|(v, abs)| {
1099 assert_approx_eq::assert_approx_eq!(v.norm_squared(), abs, EPS);
1100 });
1101 Ok(())
1102 }
1103
1104 #[rstest::rstest]
1105 #[test]
1106 #[cfg_attr(miri, ignore)]
1107 fn test_real_cm(backend: ArrayFireBackend<Sphere>) -> Result<(), HoloError> {
1108 let v = make_random_cm(&backend, N, N)?;
1109 let mut r = backend.alloc_m(N, N)?;
1110
1111 backend.real_cm(&v, &mut r)?;
1112
1113 let v = backend.to_host_cm(v)?;
1114 let r = backend.to_host_m(r)?;
1115 (0..N).for_each(|i| {
1116 (0..N).for_each(|j| {
1117 assert_approx_eq::assert_approx_eq!(v[(i, j)].re, r[(i, j)], EPS);
1118 })
1119 });
1120 Ok(())
1121 }
1122
1123 #[rstest::rstest]
1124 #[test]
1125 #[cfg_attr(miri, ignore)]
1126 fn test_imag_cm(backend: ArrayFireBackend<Sphere>) -> Result<(), HoloError> {
1127 let v = make_random_cm(&backend, N, N)?;
1128 let mut r = backend.alloc_m(N, N)?;
1129
1130 backend.imag_cm(&v, &mut r)?;
1131
1132 let v = backend.to_host_cm(v)?;
1133 let r = backend.to_host_m(r)?;
1134 (0..N).for_each(|i| {
1135 (0..N).for_each(|j| {
1136 assert_approx_eq::assert_approx_eq!(v[(i, j)].im, r[(i, j)], EPS);
1137 })
1138 });
1139 Ok(())
1140 }
1141
1142 #[rstest::rstest]
1143 #[test]
1144 #[cfg_attr(miri, ignore)]
1145 fn test_scale_assign_cv(backend: ArrayFireBackend<Sphere>) -> Result<(), HoloError> {
1146 let mut v = make_random_cv(&backend, N)?;
1147 let vc = backend.clone_cv(&v)?;
1148 let mut rng = rand::rng();
1149 let scale = Complex::new(rng.random(), rng.random());
1150
1151 backend.scale_assign_cv(scale, &mut v)?;
1152
1153 let v = backend.to_host_cv(v)?;
1154 let vc = backend.to_host_cv(vc)?;
1155 v.iter().zip(vc.iter()).for_each(|(&v, &vc)| {
1156 assert_approx_eq::assert_approx_eq!(scale * vc, v, EPS);
1157 });
1158 Ok(())
1159 }
1160
1161 #[rstest::rstest]
1162 #[test]
1163 #[cfg_attr(miri, ignore)]
1164 fn test_conj_assign_v(backend: ArrayFireBackend<Sphere>) -> Result<(), HoloError> {
1165 let mut v = make_random_cv(&backend, N)?;
1166 let vc = backend.clone_cv(&v)?;
1167
1168 backend.conj_assign_v(&mut v)?;
1169
1170 let v = backend.to_host_cv(v)?;
1171 let vc = backend.to_host_cv(vc)?;
1172 v.iter().zip(vc.iter()).for_each(|(&v, &vc)| {
1173 assert_eq!(vc.re, v.re);
1174 assert_eq!(vc.im, -v.im);
1175 });
1176 Ok(())
1177 }
1178
1179 #[rstest::rstest]
1180 #[test]
1181 #[cfg_attr(miri, ignore)]
1182 fn test_exp_assign_cv(backend: ArrayFireBackend<Sphere>) -> Result<(), HoloError> {
1183 let mut v = make_random_cv(&backend, N)?;
1184 let vc = backend.clone_cv(&v)?;
1185
1186 backend.exp_assign_cv(&mut v)?;
1187
1188 let v = backend.to_host_cv(v)?;
1189 let vc = backend.to_host_cv(vc)?;
1190 v.iter().zip(vc.iter()).for_each(|(v, vc)| {
1191 assert_approx_eq::assert_approx_eq!(vc.exp(), v, EPS);
1192 });
1193 Ok(())
1194 }
1195
1196 #[rstest::rstest]
1197 #[test]
1198 #[cfg_attr(miri, ignore)]
1199 fn test_concat_col_cm(backend: ArrayFireBackend<Sphere>) -> Result<(), HoloError> {
1200 let a = make_random_cm(&backend, N, N)?;
1201 let b = make_random_cm(&backend, N, 2 * N)?;
1202 let mut c = backend.alloc_cm(N, N + 2 * N)?;
1203
1204 backend.concat_col_cm(&a, &b, &mut c)?;
1205
1206 let a = backend.to_host_cm(a)?;
1207 let b = backend.to_host_cm(b)?;
1208 let c = backend.to_host_cm(c)?;
1209 (0..N).for_each(|col| (0..N).for_each(|row| assert_eq!(a[(row, col)], c[(row, col)])));
1210 (0..2 * N)
1211 .for_each(|col| (0..N).for_each(|row| assert_eq!(b[(row, col)], c[(row, N + col)])));
1212 Ok(())
1213 }
1214
1215 #[rstest::rstest]
1216 #[test]
1217 #[cfg_attr(miri, ignore)]
1218 fn test_max_v(backend: ArrayFireBackend<Sphere>) -> Result<(), HoloError> {
1219 let v = make_random_v(&backend, N)?;
1220
1221 let max = backend.max_v(&v)?;
1222
1223 let v = backend.to_host_v(v)?;
1224 assert_eq!(
1225 *v.iter().max_by(|a, b| a.partial_cmp(b).unwrap()).unwrap(),
1226 max
1227 );
1228 Ok(())
1229 }
1230
1231 #[rstest::rstest]
1232 #[test]
1233 #[cfg_attr(miri, ignore)]
1234 fn test_hadamard_product_cm(backend: ArrayFireBackend<Sphere>) -> Result<(), HoloError> {
1235 let a = make_random_cm(&backend, N, N)?;
1236 let b = make_random_cm(&backend, N, N)?;
1237 let mut c = backend.alloc_cm(N, N)?;
1238
1239 backend.hadamard_product_cm(&a, &b, &mut c)?;
1240
1241 let a = backend.to_host_cm(a)?;
1242 let b = backend.to_host_cm(b)?;
1243 let c = backend.to_host_cm(c)?;
1244 c.iter()
1245 .zip(a.iter())
1246 .zip(b.iter())
1247 .for_each(|((c, a), b)| {
1248 assert_approx_eq::assert_approx_eq!(a.re * b.re - a.im * b.im, c.re, EPS);
1249 assert_approx_eq::assert_approx_eq!(a.re * b.im + a.im * b.re, c.im, EPS);
1250 });
1251 Ok(())
1252 }
1253
1254 #[rstest::rstest]
1255 #[test]
1256 #[cfg_attr(miri, ignore)]
1257 fn test_dot(backend: ArrayFireBackend<Sphere>) -> Result<(), HoloError> {
1258 let a = make_random_v(&backend, N)?;
1259 let b = make_random_v(&backend, N)?;
1260
1261 let dot = backend.dot(&a, &b)?;
1262
1263 let a = backend.to_host_v(a)?;
1264 let b = backend.to_host_v(b)?;
1265 let expect = a.iter().zip(b.iter()).map(|(a, b)| a * b).sum::<f32>();
1266 assert_approx_eq::assert_approx_eq!(dot, expect, EPS);
1267 Ok(())
1268 }
1269
1270 #[rstest::rstest]
1271 #[test]
1272 #[cfg_attr(miri, ignore)]
1273 fn test_dot_c(backend: ArrayFireBackend<Sphere>) -> Result<(), HoloError> {
1274 let a = make_random_cv(&backend, N)?;
1275 let b = make_random_cv(&backend, N)?;
1276
1277 let dot = backend.dot_c(&a, &b)?;
1278
1279 let a = backend.to_host_cv(a)?;
1280 let b = backend.to_host_cv(b)?;
1281 let expect = a
1282 .iter()
1283 .zip(b.iter())
1284 .map(|(a, b)| a.conj() * b)
1285 .sum::<Complex>();
1286 assert_approx_eq::assert_approx_eq!(dot.re, expect.re, EPS);
1287 assert_approx_eq::assert_approx_eq!(dot.im, expect.im, EPS);
1288 Ok(())
1289 }
1290
1291 #[rstest::rstest]
1292 #[test]
1293 #[cfg_attr(miri, ignore)]
1294 fn test_add_v(backend: ArrayFireBackend<Sphere>) -> Result<(), HoloError> {
1295 let a = make_random_v(&backend, N)?;
1296 let mut b = make_random_v(&backend, N)?;
1297 let bc = backend.clone_v(&b)?;
1298
1299 let mut rng = rand::rng();
1300 let alpha = rng.random();
1301
1302 backend.add_v(alpha, &a, &mut b)?;
1303
1304 let a = backend.to_host_v(a)?;
1305 let b = backend.to_host_v(b)?;
1306 let bc = backend.to_host_v(bc)?;
1307 b.iter()
1308 .zip(a.iter())
1309 .zip(bc.iter())
1310 .for_each(|((b, a), bc)| {
1311 assert_approx_eq::assert_approx_eq!(alpha * a + bc, b, EPS);
1312 });
1313 Ok(())
1314 }
1315
1316 #[rstest::rstest]
1317 #[test]
1318 #[cfg_attr(miri, ignore)]
1319 fn test_add_m(backend: ArrayFireBackend<Sphere>) -> Result<(), HoloError> {
1320 let a = make_random_m(&backend, N, N)?;
1321 let mut b = make_random_m(&backend, N, N)?;
1322 let bc = backend.clone_m(&b)?;
1323
1324 let mut rng = rand::rng();
1325 let alpha = rng.random();
1326
1327 backend.add_m(alpha, &a, &mut b)?;
1328
1329 let a = backend.to_host_m(a)?;
1330 let b = backend.to_host_m(b)?;
1331 let bc = backend.to_host_m(bc)?;
1332 b.iter()
1333 .zip(a.iter())
1334 .zip(bc.iter())
1335 .for_each(|((b, a), bc)| {
1336 assert_approx_eq::assert_approx_eq!(alpha * a + bc, b, EPS);
1337 });
1338 Ok(())
1339 }
1340
1341 #[rstest::rstest]
1342 #[test]
1343 #[cfg_attr(miri, ignore)]
1344 fn test_gevv_c(backend: ArrayFireBackend<Sphere>) -> Result<(), HoloError> {
1345 let mut rng = rand::rng();
1346
1347 {
1348 let a = make_random_cv(&backend, N)?;
1349 let b = make_random_cv(&backend, N)?;
1350 let mut c = make_random_cm(&backend, N, N)?;
1351 let cc = backend.clone_cm(&c)?;
1352
1353 let alpha = Complex::new(rng.random(), rng.random());
1354 let beta = Complex::new(rng.random(), rng.random());
1355 backend.gevv_c(Trans::NoTrans, Trans::Trans, alpha, &a, &b, beta, &mut c)?;
1356
1357 let a = backend.to_host_cv(a)?;
1358 let b = backend.to_host_cv(b)?;
1359 let c = backend.to_host_cm(c)?;
1360 let cc = backend.to_host_cm(cc)?;
1361 let expected = a * b.transpose() * alpha + cc * beta;
1362 c.iter().zip(expected.iter()).for_each(|(c, expected)| {
1363 assert_approx_eq::assert_approx_eq!(c.re, expected.re, EPS);
1364 assert_approx_eq::assert_approx_eq!(c.im, expected.im, EPS);
1365 });
1366 }
1367
1368 {
1369 let a = make_random_cv(&backend, N)?;
1370 let b = make_random_cv(&backend, N)?;
1371 let mut c = make_random_cm(&backend, N, N)?;
1372 let cc = backend.clone_cm(&c)?;
1373
1374 let alpha = Complex::new(rng.random(), rng.random());
1375 let beta = Complex::new(rng.random(), rng.random());
1376 backend.gevv_c(
1377 Trans::NoTrans,
1378 Trans::ConjTrans,
1379 alpha,
1380 &a,
1381 &b,
1382 beta,
1383 &mut c,
1384 )?;
1385
1386 let a = backend.to_host_cv(a)?;
1387 let b = backend.to_host_cv(b)?;
1388 let c = backend.to_host_cm(c)?;
1389 let cc = backend.to_host_cm(cc)?;
1390 let expected = a * b.adjoint() * alpha + cc * beta;
1391 c.iter().zip(expected.iter()).for_each(|(c, expected)| {
1392 assert_approx_eq::assert_approx_eq!(c.re, expected.re, EPS);
1393 assert_approx_eq::assert_approx_eq!(c.im, expected.im, EPS);
1394 });
1395 }
1396
1397 {
1398 let a = make_random_cv(&backend, N)?;
1399 let b = make_random_cv(&backend, N)?;
1400 let mut c = make_random_cm(&backend, 1, 1)?;
1401 let cc = backend.clone_cm(&c)?;
1402
1403 let alpha = Complex::new(rng.random(), rng.random());
1404 let beta = Complex::new(rng.random(), rng.random());
1405 backend.gevv_c(Trans::Trans, Trans::NoTrans, alpha, &a, &b, beta, &mut c)?;
1406
1407 let a = backend.to_host_cv(a)?;
1408 let b = backend.to_host_cv(b)?;
1409 let c = backend.to_host_cm(c)?;
1410 let cc = backend.to_host_cm(cc)?;
1411 let expected = a.transpose() * b * alpha + cc * beta;
1412 c.iter().zip(expected.iter()).for_each(|(c, expected)| {
1413 assert_approx_eq::assert_approx_eq!(c.re, expected.re, EPS);
1414 assert_approx_eq::assert_approx_eq!(c.im, expected.im, EPS);
1415 });
1416 }
1417
1418 {
1419 let a = make_random_cv(&backend, N)?;
1420 let b = make_random_cv(&backend, N)?;
1421 let mut c = make_random_cm(&backend, 1, 1)?;
1422 let cc = backend.clone_cm(&c)?;
1423
1424 let alpha = Complex::new(rng.random(), rng.random());
1425 let beta = Complex::new(rng.random(), rng.random());
1426 backend.gevv_c(
1427 Trans::ConjTrans,
1428 Trans::NoTrans,
1429 alpha,
1430 &a,
1431 &b,
1432 beta,
1433 &mut c,
1434 )?;
1435
1436 let a = backend.to_host_cv(a)?;
1437 let b = backend.to_host_cv(b)?;
1438 let c = backend.to_host_cm(c)?;
1439 let cc = backend.to_host_cm(cc)?;
1440 let expected = a.adjoint() * b * alpha + cc * beta;
1441 c.iter().zip(expected.iter()).for_each(|(c, expected)| {
1442 assert_approx_eq::assert_approx_eq!(c.re, expected.re, EPS);
1443 assert_approx_eq::assert_approx_eq!(c.im, expected.im, EPS);
1444 });
1445 }
1446
1447 Ok(())
1448 }
1449
1450 #[rstest::rstest]
1451 #[test]
1452 #[cfg_attr(miri, ignore)]
1453 fn test_gemv_c(backend: ArrayFireBackend<Sphere>) -> Result<(), HoloError> {
1454 let m = N;
1455 let n = 2 * N;
1456
1457 let mut rng = rand::rng();
1458
1459 {
1460 let a = make_random_cm(&backend, m, n)?;
1461 let b = make_random_cv(&backend, n)?;
1462 let mut c = make_random_cv(&backend, m)?;
1463 let cc = backend.clone_cv(&c)?;
1464
1465 let alpha = Complex::new(rng.random(), rng.random());
1466 let beta = Complex::new(rng.random(), rng.random());
1467 backend.gemv_c(Trans::NoTrans, alpha, &a, &b, beta, &mut c)?;
1468
1469 let a = backend.to_host_cm(a)?;
1470 let b = backend.to_host_cv(b)?;
1471 let c = backend.to_host_cv(c)?;
1472 let cc = backend.to_host_cv(cc)?;
1473 let expected = a * b * alpha + cc * beta;
1474 c.iter().zip(expected.iter()).for_each(|(c, expected)| {
1475 assert_approx_eq::assert_approx_eq!(c.re, expected.re, EPS);
1476 assert_approx_eq::assert_approx_eq!(c.im, expected.im, EPS);
1477 });
1478 }
1479
1480 {
1481 let a = make_random_cm(&backend, n, m)?;
1482 let b = make_random_cv(&backend, n)?;
1483 let mut c = make_random_cv(&backend, m)?;
1484 let cc = backend.clone_cv(&c)?;
1485
1486 let alpha = Complex::new(rng.random(), rng.random());
1487 let beta = Complex::new(rng.random(), rng.random());
1488 backend.gemv_c(Trans::Trans, alpha, &a, &b, beta, &mut c)?;
1489
1490 let a = backend.to_host_cm(a)?;
1491 let b = backend.to_host_cv(b)?;
1492 let c = backend.to_host_cv(c)?;
1493 let cc = backend.to_host_cv(cc)?;
1494 let expected = a.transpose() * b * alpha + cc * beta;
1495 c.iter().zip(expected.iter()).for_each(|(c, expected)| {
1496 assert_approx_eq::assert_approx_eq!(c.re, expected.re, EPS);
1497 assert_approx_eq::assert_approx_eq!(c.im, expected.im, EPS);
1498 });
1499 }
1500
1501 {
1502 let a = make_random_cm(&backend, n, m)?;
1503 let b = make_random_cv(&backend, n)?;
1504 let mut c = make_random_cv(&backend, m)?;
1505 let cc = backend.clone_cv(&c)?;
1506
1507 let alpha = Complex::new(rng.random(), rng.random());
1508 let beta = Complex::new(rng.random(), rng.random());
1509 backend.gemv_c(Trans::ConjTrans, alpha, &a, &b, beta, &mut c)?;
1510
1511 let a = backend.to_host_cm(a)?;
1512 let b = backend.to_host_cv(b)?;
1513 let c = backend.to_host_cv(c)?;
1514 let cc = backend.to_host_cv(cc)?;
1515 let expected = a.adjoint() * b * alpha + cc * beta;
1516 c.iter().zip(expected.iter()).for_each(|(c, expected)| {
1517 assert_approx_eq::assert_approx_eq!(c.re, expected.re, EPS);
1518 assert_approx_eq::assert_approx_eq!(c.im, expected.im, EPS);
1519 });
1520 }
1521 Ok(())
1522 }
1523
1524 #[rstest::rstest]
1525 #[test]
1526 #[cfg_attr(miri, ignore)]
1527 fn test_gemm_c(backend: ArrayFireBackend<Sphere>) -> Result<(), HoloError> {
1528 let m = N;
1529 let n = 2 * N;
1530 let k = 3 * N;
1531
1532 let mut rng = rand::rng();
1533
1534 {
1535 let a = make_random_cm(&backend, m, k)?;
1536 let b = make_random_cm(&backend, k, n)?;
1537 let mut c = make_random_cm(&backend, m, n)?;
1538 let cc = backend.clone_cm(&c)?;
1539
1540 let alpha = Complex::new(rng.random(), rng.random());
1541 let beta = Complex::new(rng.random(), rng.random());
1542 backend.gemm_c(Trans::NoTrans, Trans::NoTrans, alpha, &a, &b, beta, &mut c)?;
1543
1544 let a = backend.to_host_cm(a)?;
1545 let b = backend.to_host_cm(b)?;
1546 let c = backend.to_host_cm(c)?;
1547 let cc = backend.to_host_cm(cc)?;
1548 let expected = a * b * alpha + cc * beta;
1549 c.iter().zip(expected.iter()).for_each(|(c, expected)| {
1550 assert_approx_eq::assert_approx_eq!(c.re, expected.re, EPS);
1551 assert_approx_eq::assert_approx_eq!(c.im, expected.im, EPS);
1552 });
1553 }
1554
1555 {
1556 let a = make_random_cm(&backend, m, k)?;
1557 let b = make_random_cm(&backend, n, k)?;
1558 let mut c = make_random_cm(&backend, m, n)?;
1559 let cc = backend.clone_cm(&c)?;
1560
1561 let alpha = Complex::new(rng.random(), rng.random());
1562 let beta = Complex::new(rng.random(), rng.random());
1563 backend.gemm_c(Trans::NoTrans, Trans::Trans, alpha, &a, &b, beta, &mut c)?;
1564
1565 let a = backend.to_host_cm(a)?;
1566 let b = backend.to_host_cm(b)?;
1567 let c = backend.to_host_cm(c)?;
1568 let cc = backend.to_host_cm(cc)?;
1569 let expected = a * b.transpose() * alpha + cc * beta;
1570 c.iter().zip(expected.iter()).for_each(|(c, expected)| {
1571 assert_approx_eq::assert_approx_eq!(c.re, expected.re, EPS);
1572 assert_approx_eq::assert_approx_eq!(c.im, expected.im, EPS);
1573 });
1574 }
1575
1576 {
1577 let a = make_random_cm(&backend, m, k)?;
1578 let b = make_random_cm(&backend, n, k)?;
1579 let mut c = make_random_cm(&backend, m, n)?;
1580 let cc = backend.clone_cm(&c)?;
1581
1582 let alpha = Complex::new(rng.random(), rng.random());
1583 let beta = Complex::new(rng.random(), rng.random());
1584 backend.gemm_c(
1585 Trans::NoTrans,
1586 Trans::ConjTrans,
1587 alpha,
1588 &a,
1589 &b,
1590 beta,
1591 &mut c,
1592 )?;
1593
1594 let a = backend.to_host_cm(a)?;
1595 let b = backend.to_host_cm(b)?;
1596 let c = backend.to_host_cm(c)?;
1597 let cc = backend.to_host_cm(cc)?;
1598 let expected = a * b.adjoint() * alpha + cc * beta;
1599 c.iter().zip(expected.iter()).for_each(|(c, expected)| {
1600 assert_approx_eq::assert_approx_eq!(c.re, expected.re, EPS);
1601 assert_approx_eq::assert_approx_eq!(c.im, expected.im, EPS);
1602 });
1603 }
1604
1605 {
1606 let a = make_random_cm(&backend, k, m)?;
1607 let b = make_random_cm(&backend, k, n)?;
1608 let mut c = make_random_cm(&backend, m, n)?;
1609 let cc = backend.clone_cm(&c)?;
1610
1611 let alpha = Complex::new(rng.random(), rng.random());
1612 let beta = Complex::new(rng.random(), rng.random());
1613 backend.gemm_c(Trans::Trans, Trans::NoTrans, alpha, &a, &b, beta, &mut c)?;
1614
1615 let a = backend.to_host_cm(a)?;
1616 let b = backend.to_host_cm(b)?;
1617 let c = backend.to_host_cm(c)?;
1618 let cc = backend.to_host_cm(cc)?;
1619 let expected = a.transpose() * b * alpha + cc * beta;
1620 c.iter().zip(expected.iter()).for_each(|(c, expected)| {
1621 assert_approx_eq::assert_approx_eq!(c.re, expected.re, EPS);
1622 assert_approx_eq::assert_approx_eq!(c.im, expected.im, EPS);
1623 });
1624 }
1625
1626 {
1627 let a = make_random_cm(&backend, k, m)?;
1628 let b = make_random_cm(&backend, n, k)?;
1629 let mut c = make_random_cm(&backend, m, n)?;
1630 let cc = backend.clone_cm(&c)?;
1631
1632 let alpha = Complex::new(rng.random(), rng.random());
1633 let beta = Complex::new(rng.random(), rng.random());
1634 backend.gemm_c(Trans::Trans, Trans::Trans, alpha, &a, &b, beta, &mut c)?;
1635
1636 let a = backend.to_host_cm(a)?;
1637 let b = backend.to_host_cm(b)?;
1638 let c = backend.to_host_cm(c)?;
1639 let cc = backend.to_host_cm(cc)?;
1640 let expected = a.transpose() * b.transpose() * alpha + cc * beta;
1641 c.iter().zip(expected.iter()).for_each(|(c, expected)| {
1642 assert_approx_eq::assert_approx_eq!(c.re, expected.re, EPS);
1643 assert_approx_eq::assert_approx_eq!(c.im, expected.im, EPS);
1644 });
1645 }
1646
1647 {
1648 let a = make_random_cm(&backend, k, m)?;
1649 let b = make_random_cm(&backend, n, k)?;
1650 let mut c = make_random_cm(&backend, m, n)?;
1651 let cc = backend.clone_cm(&c)?;
1652
1653 let alpha = Complex::new(rng.random(), rng.random());
1654 let beta = Complex::new(rng.random(), rng.random());
1655 backend.gemm_c(Trans::Trans, Trans::ConjTrans, alpha, &a, &b, beta, &mut c)?;
1656
1657 let a = backend.to_host_cm(a)?;
1658 let b = backend.to_host_cm(b)?;
1659 let c = backend.to_host_cm(c)?;
1660 let cc = backend.to_host_cm(cc)?;
1661 let expected = a.transpose() * b.adjoint() * alpha + cc * beta;
1662 c.iter().zip(expected.iter()).for_each(|(c, expected)| {
1663 assert_approx_eq::assert_approx_eq!(c.re, expected.re, EPS);
1664 assert_approx_eq::assert_approx_eq!(c.im, expected.im, EPS);
1665 });
1666 }
1667
1668 {
1669 let a = make_random_cm(&backend, k, m)?;
1670 let b = make_random_cm(&backend, k, n)?;
1671 let mut c = make_random_cm(&backend, m, n)?;
1672 let cc = backend.clone_cm(&c)?;
1673
1674 let alpha = Complex::new(rng.random(), rng.random());
1675 let beta = Complex::new(rng.random(), rng.random());
1676 backend.gemm_c(
1677 Trans::ConjTrans,
1678 Trans::NoTrans,
1679 alpha,
1680 &a,
1681 &b,
1682 beta,
1683 &mut c,
1684 )?;
1685
1686 let a = backend.to_host_cm(a)?;
1687 let b = backend.to_host_cm(b)?;
1688 let c = backend.to_host_cm(c)?;
1689 let cc = backend.to_host_cm(cc)?;
1690 let expected = a.adjoint() * b * alpha + cc * beta;
1691 c.iter().zip(expected.iter()).for_each(|(c, expected)| {
1692 assert_approx_eq::assert_approx_eq!(c.re, expected.re, EPS);
1693 assert_approx_eq::assert_approx_eq!(c.im, expected.im, EPS);
1694 });
1695 }
1696
1697 {
1698 let a = make_random_cm(&backend, k, m)?;
1699 let b = make_random_cm(&backend, n, k)?;
1700 let mut c = make_random_cm(&backend, m, n)?;
1701 let cc = backend.clone_cm(&c)?;
1702
1703 let alpha = Complex::new(rng.random(), rng.random());
1704 let beta = Complex::new(rng.random(), rng.random());
1705 backend.gemm_c(Trans::ConjTrans, Trans::Trans, alpha, &a, &b, beta, &mut c)?;
1706
1707 let a = backend.to_host_cm(a)?;
1708 let b = backend.to_host_cm(b)?;
1709 let c = backend.to_host_cm(c)?;
1710 let cc = backend.to_host_cm(cc)?;
1711 let expected = a.adjoint() * b.transpose() * alpha + cc * beta;
1712 c.iter().zip(expected.iter()).for_each(|(c, expected)| {
1713 assert_approx_eq::assert_approx_eq!(c.re, expected.re, EPS);
1714 assert_approx_eq::assert_approx_eq!(c.im, expected.im, EPS);
1715 });
1716 }
1717
1718 {
1719 let a = make_random_cm(&backend, k, m)?;
1720 let b = make_random_cm(&backend, n, k)?;
1721 let mut c = make_random_cm(&backend, m, n)?;
1722 let cc = backend.clone_cm(&c)?;
1723
1724 let alpha = Complex::new(rng.random(), rng.random());
1725 let beta = Complex::new(rng.random(), rng.random());
1726 backend.gemm_c(
1727 Trans::ConjTrans,
1728 Trans::ConjTrans,
1729 alpha,
1730 &a,
1731 &b,
1732 beta,
1733 &mut c,
1734 )?;
1735
1736 let a = backend.to_host_cm(a)?;
1737 let b = backend.to_host_cm(b)?;
1738 let c = backend.to_host_cm(c)?;
1739 let cc = backend.to_host_cm(cc)?;
1740 let expected = a.adjoint() * b.adjoint() * alpha + cc * beta;
1741 c.iter().zip(expected.iter()).for_each(|(c, expected)| {
1742 assert_approx_eq::assert_approx_eq!(c.re, expected.re, EPS);
1743 assert_approx_eq::assert_approx_eq!(c.im, expected.im, EPS);
1744 });
1745 }
1746 Ok(())
1747 }
1748
1749 #[rstest::rstest]
1750 #[test]
1751 #[cfg_attr(miri, ignore)]
1752 fn test_solve_inplace(backend: ArrayFireBackend<Sphere>) -> Result<(), HoloError> {
1753 {
1754 let tmp = make_random_m(&backend, N, N)?;
1755 let tmp = backend.to_host_m(tmp)?;
1756
1757 let a = &tmp * tmp.adjoint();
1758
1759 let mut rng = rand::rng();
1760 let x = VectorX::from_iterator(N, (0..N).map(|_| rng.random()));
1761
1762 let b = &a * &x;
1763
1764 let aa = backend.from_slice_m(N, N, a.as_slice())?;
1765 let mut bb = backend.from_slice_v(b.as_slice())?;
1766
1767 backend.solve_inplace(&aa, &mut bb)?;
1768
1769 let b2 = &a * backend.to_host_v(bb)?;
1770 assert!(approx::relative_eq!(b, b2, epsilon = 1e-3));
1771 }
1772
1773 Ok(())
1774 }
1775
1776 #[rstest::rstest]
1777 #[test]
1778 #[cfg_attr(miri, ignore)]
1779 fn test_reduce_col(backend: ArrayFireBackend<Sphere>) -> Result<(), HoloError> {
1780 let a = make_random_m(&backend, N, N)?;
1781
1782 let mut b = backend.alloc_v(N)?;
1783
1784 backend.reduce_col(&a, &mut b)?;
1785
1786 let a = backend.to_host_m(a)?;
1787 let b = backend.to_host_v(b)?;
1788
1789 (0..N).for_each(|row| {
1790 let sum = a.row(row).iter().sum::<f32>();
1791 assert_approx_eq::assert_approx_eq!(sum, b[row], EPS);
1792 });
1793 Ok(())
1794 }
1795
1796 #[rstest::rstest]
1797 #[test]
1798 #[cfg_attr(miri, ignore)]
1799 fn test_scaled_to_cv(backend: ArrayFireBackend<Sphere>) -> Result<(), HoloError> {
1800 let a = make_random_cv(&backend, N)?;
1801 let b = make_random_cv(&backend, N)?;
1802 let mut c = backend.alloc_cv(N)?;
1803
1804 backend.scaled_to_cv(&a, &b, &mut c)?;
1805
1806 let a = backend.to_host_cv(a)?;
1807 let b = backend.to_host_cv(b)?;
1808 let c = backend.to_host_cv(c)?;
1809 c.iter()
1810 .zip(a.iter())
1811 .zip(b.iter())
1812 .for_each(|((&c, &a), &b)| {
1813 assert_approx_eq::assert_approx_eq!(c, a / a.abs() * b, EPS);
1814 });
1815
1816 Ok(())
1817 }
1818
1819 #[rstest::rstest]
1820 #[test]
1821 #[cfg_attr(miri, ignore)]
1822 fn test_scaled_to_assign_cv(backend: ArrayFireBackend<Sphere>) -> Result<(), HoloError> {
1823 let a = make_random_cv(&backend, N)?;
1824 let mut b = make_random_cv(&backend, N)?;
1825 let bc = backend.clone_cv(&b)?;
1826
1827 backend.scaled_to_assign_cv(&a, &mut b)?;
1828
1829 let a = backend.to_host_cv(a)?;
1830 let b = backend.to_host_cv(b)?;
1831 let bc = backend.to_host_cv(bc)?;
1832 b.iter()
1833 .zip(a.iter())
1834 .zip(bc.iter())
1835 .for_each(|((&b, &a), &bc)| {
1836 assert_approx_eq::assert_approx_eq!(b, bc / bc.abs() * a, EPS);
1837 });
1838
1839 Ok(())
1840 }
1841
1842 #[rstest::rstest]
1843 #[test]
1844 #[case(1, 2)]
1845 #[case(2, 1)]
1846 fn test_generate_propagation_matrix(
1847 #[case] dev_num: usize,
1848 #[case] foci_num: usize,
1849 backend: ArrayFireBackend<Sphere>,
1850 ) -> Result<(), HoloError> {
1851 let reference = |geometry: Geometry, foci: Vec<Point3>| {
1852 let mut g = MatrixXc::zeros(
1853 foci.len(),
1854 geometry
1855 .iter()
1856 .map(|dev| dev.num_transducers())
1857 .sum::<usize>(),
1858 );
1859 let transducers = geometry
1860 .iter()
1861 .flat_map(|dev| dev.iter().map(|tr| (dev.idx(), tr)))
1862 .collect::<Vec<_>>();
1863 (0..foci.len()).for_each(|i| {
1864 (0..transducers.len()).for_each(|j| {
1865 g[(i, j)] = propagate::<Sphere>(
1866 transducers[j].1,
1867 geometry[transducers[j].0].wavenumber(),
1868 geometry[transducers[j].0].axial_direction(),
1869 &foci[i],
1870 )
1871 })
1872 });
1873 g
1874 };
1875
1876 let geometry = generate_geometry(dev_num);
1877 let foci = gen_foci(foci_num).map(|(p, _)| p).collect::<Vec<_>>();
1878
1879 let g = backend.generate_propagation_matrix(&geometry, &foci, None)?;
1880 let g = backend.to_host_cm(g)?;
1881 reference(geometry, foci)
1882 .iter()
1883 .zip(g.iter())
1884 .for_each(|(r, g)| {
1885 assert_approx_eq::assert_approx_eq!(r.re, g.re, EPS);
1886 assert_approx_eq::assert_approx_eq!(r.im, g.im, EPS);
1887 });
1888
1889 Ok(())
1890 }
1891
1892 #[rstest::rstest]
1893 #[test]
1894 #[case(1, 2)]
1895 #[case(2, 1)]
1896 fn test_generate_propagation_matrix_with_filter(
1897 #[case] dev_num: usize,
1898 #[case] foci_num: usize,
1899 backend: ArrayFireBackend<Sphere>,
1900 ) -> Result<(), HoloError> {
1901 use std::collections::HashMap;
1902
1903 let filter = |geometry: &Geometry| {
1904 geometry
1905 .iter()
1906 .map(|dev| {
1907 let mut filter = BitVec::new();
1908 dev.iter().for_each(|tr| {
1909 filter.push(tr.idx() > dev.num_transducers() / 2);
1910 });
1911 (dev.idx(), filter)
1912 })
1913 .collect::<HashMap<_, _>>()
1914 };
1915
1916 let reference = |geometry, foci: Vec<Point3>| {
1917 let filter = filter(&geometry);
1918 let transducers = geometry
1919 .iter()
1920 .flat_map(|dev| {
1921 dev.iter().filter_map(|tr| {
1922 if filter[&dev.idx()][tr.idx()] {
1923 Some((dev.idx(), tr))
1924 } else {
1925 None
1926 }
1927 })
1928 })
1929 .collect::<Vec<_>>();
1930
1931 let mut g = MatrixXc::zeros(foci.len(), transducers.len());
1932 (0..foci.len()).for_each(|i| {
1933 (0..transducers.len()).for_each(|j| {
1934 g[(i, j)] = propagate::<Sphere>(
1935 transducers[j].1,
1936 geometry[transducers[j].0].wavenumber(),
1937 geometry[transducers[j].0].axial_direction(),
1938 &foci[i],
1939 )
1940 })
1941 });
1942 g
1943 };
1944
1945 let geometry = generate_geometry(dev_num);
1946 let foci = gen_foci(foci_num).map(|(p, _)| p).collect::<Vec<_>>();
1947 let filter = filter(&geometry);
1948
1949 let g = backend.generate_propagation_matrix(&geometry, &foci, Some(&filter))?;
1950 let g = backend.to_host_cm(g)?;
1951 assert_eq!(g.nrows(), foci.len());
1952 assert_eq!(
1953 g.ncols(),
1954 geometry
1955 .iter()
1956 .map(|dev| dev.num_transducers() / 2)
1957 .sum::<usize>()
1958 );
1959 reference(geometry, foci)
1960 .iter()
1961 .zip(g.iter())
1962 .for_each(|(r, g)| {
1963 assert_approx_eq::assert_approx_eq!(r.re, g.re, EPS);
1964 assert_approx_eq::assert_approx_eq!(r.im, g.im, EPS);
1965 });
1966
1967 Ok(())
1968 }
1969
1970 #[rstest::rstest]
1971 #[test]
1972 fn test_gen_back_prop(backend: ArrayFireBackend<Sphere>) -> Result<(), HoloError> {
1973 let geometry = generate_geometry(1);
1974 let foci = gen_foci(2).map(|(p, _)| p).collect::<Vec<_>>();
1975
1976 let m = geometry
1977 .iter()
1978 .map(|dev| dev.num_transducers())
1979 .sum::<usize>();
1980 let n = foci.len();
1981
1982 let g = backend.generate_propagation_matrix(&geometry, &foci, None)?;
1983
1984 let b = backend.gen_back_prop(m, n, &g)?;
1985 let g = backend.to_host_cm(g)?;
1986 let reference = {
1987 let mut b = MatrixXc::zeros(m, n);
1988 (0..n).for_each(|i| {
1989 let x = 1.0 / g.rows(i, 1).iter().map(|x| x.norm_sqr()).sum::<f32>();
1990 (0..m).for_each(|j| {
1991 b[(j, i)] = g[(i, j)].conj() * x;
1992 })
1993 });
1994 b
1995 };
1996
1997 let b = backend.to_host_cm(b)?;
1998 reference.iter().zip(b.iter()).for_each(|(r, b)| {
1999 assert_approx_eq::assert_approx_eq!(r.re, b.re, EPS);
2000 assert_approx_eq::assert_approx_eq!(r.im, b.im, EPS);
2001 });
2002 Ok(())
2003 }
2004}