1use num_complex::Complex;
2
3use mdarray::{DTensor, tensor};
4
5use crate::prelude::*;
6
7use crate::matmul::{Triangle, Type};
8
9pub fn test_eval_and_overwrite(bd: impl MatVec<f64>) {
10 let n = 3;
11 let x = DTensor::<f64, 1>::from_elem(n, 1.);
12 let a = DTensor::<f64, 2>::from_fn([n, n], |i| (i[0] * n + i[1] + 1) as f64);
13 let y_result = bd.matvec(&a, &x).scale(2.).eval();
14 let y = DTensor::<f64, 1>::from_fn([n], |i| 2. * (6. + i[0] as f64 * 9.));
15 assert_eq!(y_result, y);
16
17 let mut y_overwritten = DTensor::<f64, 1>::from_elem(n, 0.);
18 bd.matvec(&a, &x).scale(2.).overwrite(&mut y_overwritten);
19 assert_eq!(y_overwritten, y);
20}
21
22pub fn test_add_to_scaled(bd: impl MatVec<f64>) {
23 let n = 3;
24 let x = DTensor::<f64, 1>::from_elem(n, 1.);
25 let mut x2 = DTensor::<f64, 1>::from_elem(n, 1.);
26 let a = DTensor::<f64, 2>::from_fn([n, n], |i| (i[0] * 2 + i[1] + 1) as f64);
27 bd.matvec(&a, &x).add_to_scaled(&mut x2, 2.);
28 let y = DTensor::<f64, 1>::from_fn([n], |i| 2.0 * 1.0 + (6.0 + i[0] as f64 * 6.0));
29
30 assert_eq!(x2, y);
31}
32
33pub fn test_add_to(bd: impl MatVec<f64>) {
34 let n = 3;
35 let x = DTensor::<f64, 1>::from_elem(n, 1.);
36 let mut x2 = DTensor::<f64, 1>::from_elem(n, 1.);
37 let a = DTensor::<f64, 2>::from_fn([n, n], |i| (i[1] * 2 + i[0] + 1) as f64);
38 bd.matvec(&a, &x).add_to(&mut x2);
39 let y = DTensor::<f64, 1>::from_fn([n], |i| 10. + 3. * i[0] as f64);
40 assert_eq!(x2, y);
41}
42
43pub fn test_add_outer_basic(bd: impl MatVec<f64>) {
44 let m = 2;
45 let n = 3;
46
47 let x = DTensor::<f64, 1>::from_fn([m], |i| (i[0] + 1) as f64);
48 let y = DTensor::<f64, 1>::from_fn([n], |i| 10f64.powi(i[0] as i32));
49 let a = DTensor::<f64, 2>::from_fn([m, n], |i| if i[0] == i[1] { 1.0 } else { 0.0 });
50 let beta = 2.0;
51 let a_updated = bd.matvec(&a, &x).add_outer(&y, beta);
52
53 let expected = DTensor::<f64, 2>::from_fn([m, n], |i| {
54 let (row, col) = (i[0], i[1]);
55 let a_val = if row == col { 1.0 } else { 0.0 };
56 a_val + beta * (x[[row]]) * (y[[col]])
57 });
58
59 assert_eq!(a_updated, expected);
60}
61
62pub fn test_add_outer_sym(bd: impl MatVec<f64>) {
63 let n = 3;
64
65 let x = DTensor::<f64, 1>::from_fn([n], |i| (i[0] + 1) as f64);
66 let a = DTensor::<f64, 2>::from_fn([n, n], |i| {
67 let (row, col) = (i[0], i[1]);
68 if row == col { 2.0 } else { 1.0 }
69 });
70 let beta = 0.5;
71
72 let a_updated = bd
73 .matvec(&a, &x)
74 .add_outer_special(beta, Type::Sym, Triangle::Upper);
75
76 let expected = DTensor::<f64, 2>::from_fn([n, n], |i| {
77 let (row, col) = (i[0], i[1]);
78 let a_val = if row == col { 2.0 } else { 1.0 };
79 if row <= col {
80 a_val + beta * (x[[row]]) * (x[[col]])
81 } else {
82 a_val
83 }
84 });
85
86 assert_eq!(a_updated, expected);
87}
88
89pub fn test_add_outer_her(bd: impl MatVec<Complex<f64>>) {
90 use num_complex::Complex64;
91
92 let n = 3;
93
94 let x = DTensor::<Complex64, 1>::from_fn([n], |i| {
95 Complex64::new((i[0] + 1) as f64, (i[0] as f64) * 0.5)
96 });
97
98 let a = DTensor::<Complex64, 2>::from_fn([n, n], |i| {
99 let (row, col) = (i[0], i[1]);
100 if row == col {
101 Complex64::new(2.0, 0.0)
102 } else if row < col {
103 Complex64::new(1.0, 0.5)
104 } else {
105 Complex64::new(1.0, -0.5)
106 }
107 });
108 let beta = 0.3;
109
110 let a_updated =
111 bd.matvec(&a, &x)
112 .add_outer_special(Complex64::new(beta, 0.0), Type::Her, Triangle::Upper);
113
114 let expected = DTensor::<Complex64, 2>::from_fn([n, n], |i| {
115 let (row, col) = (i[0], i[1]);
116 let a_val = if row == col {
117 Complex64::new(2.0, 0.0)
118 } else if row < col {
119 Complex64::new(1.0, 0.5)
120 } else {
121 Complex64::new(1.0, -0.5)
122 };
123
124 if row <= col {
125 a_val + Complex64::new(beta, 0.0) * x[[row]] * x[[col]].conj()
126 } else {
127 a_val
128 }
129 });
130
131 assert_eq!(a_updated, expected);
132}
133
134pub fn test_add_to_scaled_vecvec(bd: impl VecOps<f64>) {
135 let n = 3;
136 let alpha = 2.0;
137 let x = DTensor::<f64, 1>::from_fn([n], |i| (i[0] + 1) as f64); let mut y = DTensor::<f64, 1>::from_elem(n, 1.0); bd.add_to_scaled(alpha, &x, &mut y);
141
142 let expected = DTensor::<f64, 1>::from_fn([n], |i| 1.0 + alpha * (i[0] + 1) as f64);
143 assert_eq!(y, expected);
144}
145
146pub fn test_dot_real(bd: impl VecOps<f64>) {
147 let n = 3;
148 let x = DTensor::<f64, 1>::from_fn([n], |i| (i[0] + 1) as f64); let y = DTensor::<f64, 1>::from_fn([n], |i| (2 * (i[0] + 1)) as f64); assert_eq!(bd.dot(&x, &y), 28.0);
153}
154
155pub fn test_dot_complex(bd: impl VecOps<Complex<f64>>) {
156 use num_complex::Complex64;
157 let n = 3;
158 let x = DTensor::<Complex64, 1>::from_fn([n], |i| Complex64::new((i[0] + 1) as f64, 0.)); let y = DTensor::<Complex64, 1>::from_fn([n], |i| Complex64::new(0., (2 * (i[0] + 1)) as f64)); let expected = Complex64::new(0.0, 28.0);
162
163 assert_eq!(bd.dot(&x, &y), expected);
164}
165
166pub fn test_dotc_complex(bd: impl VecOps<Complex<f64>>) {
167 use num_complex::Complex64;
168
169 let n = 2;
170 let x = DTensor::<Complex64, 1>::from_fn([n], |i| {
171 Complex64::new((i[0] + 1) as f64, (i[0] + 2) as f64)
172 }); let y = DTensor::<Complex64, 1>::from_fn([n], |i| {
174 Complex64::new((i[0] + 3) as f64, (i[0] + 4) as f64)
175 }); let result = bd.dotc(&x, &y);
178
179 println!("{result:?}");
180
181 let expected = x[[0]].conj() * y[[0]] + x[[1]].conj() * y[[1]];
183 assert_eq!(result, expected);
184}
185
186pub fn test_norm1_complex(bd: impl VecOps<Complex<f64>>) {
187 use num_complex::Complex64;
188
189 let n = 3;
190 let x = DTensor::<Complex64, 1>::from_fn([n], |i| {
191 Complex64::new((i[0] + 1) as f64, (i[0] + 2) as f64)
192 });
193 let expected: f64 = x.iter().map(|z| z.re.abs() + z.im.abs()).sum();
196
197 let result = bd.norm1(&x);
198
199 println!("{result}");
200 println!("{expected}");
201
202 assert!((result - expected).abs() < 1e-12);
203}
204
205pub fn test_norm2_complex(bd: impl VecOps<Complex<f64>>) {
206 use num_complex::Complex64;
207
208 let n = 3;
209 let x = DTensor::<Complex64, 1>::from_fn([n], |i| {
210 Complex64::new((i[0] + 1) as f64, (i[0] + 2) as f64)
211 });
212 let expected: f64 = x.iter().map(|z| z.norm_sqr()).sum::<f64>().sqrt();
215
216 let result = bd.norm2(&x);
217
218 assert!((result - expected).abs() < 1e-12);
219}
220
221pub fn test_argmax_real(bd: impl Argmax<f64>) {
222 use mdarray::DTensor;
223
224 let x = DTensor::<f64, 1>::from_fn([0], |_| 0.0);
226 let idx = bd.argmax(&x);
227 println!("Empty: {idx:?}");
228 assert_eq!(idx, None);
229
230 let x = tensor![42.];
232 let idx = bd.argmax(&x).unwrap();
233 println!("Scalar: {idx:?}");
234 assert_eq!(idx, vec![0]); let n = 5;
238 let x = DTensor::<f64, 1>::from_fn([n], |i| (i[0] + 1) as f64);
239 let idx = bd.argmax(&x.view(..)).unwrap();
240 println!("{idx:?}");
241 assert_eq!(idx, vec![4]);
242
243 let x = DTensor::<f64, 2>::from_fn([2, 3], |i| (i[0] * 3 + i[1]) as f64);
245
246 let idx = bd.argmax(&x.view(.., ..).into_dyn()).unwrap();
249 println!("{idx:?}");
250 assert_eq!(idx, vec![1, 2]);
251
252 let x = DTensor::<f64, 3>::from_fn([2, 2, 2], |i| (i[0] * 4 + i[1] * 2 + i[2]) as f64);
254 let idx = bd.argmax(&x.view(.., .., ..).into_dyn()).unwrap();
255 println!("{idx:?}");
256 assert_eq!(idx, vec![1, 1, 1]);
257}
258
259pub fn test_argmax_overwrite_real(bd: impl Argmax<f64>) {
260 let mut output = Vec::new();
261
262 let x = DTensor::<f64, 1>::from_fn([0], |_| 0.0);
264 let success = bd.argmax_overwrite(&x, &mut output);
265 assert!(!success);
266 assert_eq!(output, vec![]);
267
268 let x = tensor![42.];
270 let success = bd.argmax_overwrite(&x, &mut output);
271 assert!(success);
272 assert_eq!(output, vec![0]);
273
274 let n = 5;
276 let x = DTensor::<f64, 1>::from_fn([n], |i| (i[0] + 1) as f64);
277 let success = bd.argmax_overwrite(&x.view(..), &mut output);
278 assert!(success);
279 assert_eq!(output, vec![4]);
280
281 let x = DTensor::<f64, 2>::from_fn([2, 3], |i| (i[0] * 3 + i[1]) as f64);
283 let success = bd.argmax_overwrite(&x.view(.., ..).into_dyn(), &mut output);
286 assert!(success);
287 assert_eq!(output, vec![1, 2]);
288
289 let x = DTensor::<f64, 3>::from_fn([2, 2, 2], |i| (i[0] * 4 + i[1] * 2 + i[2]) as f64);
291 let success = bd.argmax_overwrite(&x.view(.., .., ..).into_dyn(), &mut output);
292 assert!(success);
293 assert_eq!(output, vec![1, 1, 1]);
294
295 output = vec![99, 99, 99];
297 let x = DTensor::<f64, 1>::from_fn([3], |i| (3 - i[0]) as f64);
298 let success = bd.argmax_overwrite(&x.view(..), &mut output);
299 assert!(success);
300 assert_eq!(output, vec![0]); }
302
303pub fn test_argmax_abs(bd: impl Argmax<f64>) {
304 use mdarray::DTensor;
305
306 let x = DTensor::<f64, 1>::from_fn([0], |_| 0.0);
308 let idx = bd.argmax_abs(&x);
309 println!("Empty: {idx:?}");
310 assert_eq!(idx, None);
311
312 let x = tensor![42.];
314 let idx = bd.argmax_abs(&x).unwrap();
315 println!("Scalar: {idx:?}");
316 assert_eq!(idx, vec![0]); let n = 6;
320 let x = DTensor::<f64, 1>::from_fn([n], |i| {
321 if i[0] % 2 == 0 {
322 (i[0] as i32 + 1) as f64
323 } else {
324 -(i[0] as i32 + 1) as f64
325 }
326 });
327 let idx = bd.argmax_abs(&x.view(..)).unwrap();
328 println!("{idx:?}");
329 assert_eq!(idx, vec![5]);
330
331 let x = DTensor::<f64, 2>::from_fn([2, 3], |i| (i[0] * 3 + i[1]) as f64);
333
334 let idx = bd.argmax_abs(&x.view(.., ..).into_dyn()).unwrap();
337 println!("{idx:?}");
338 assert_eq!(idx, vec![1, 2]);
339
340 let x = DTensor::<f64, 3>::from_fn([2, 2, 2], |i| (i[0] * 4 + i[1] * 2 + i[2]) as f64);
342 let idx = bd.argmax_abs(&x.view(.., .., ..).into_dyn()).unwrap();
343 println!("{idx:?}");
344 assert_eq!(idx, vec![1, 1, 1]);
345}