mdarray_linalg/testing/matvec/
mod.rs

1use num_complex::Complex;
2
3use mdarray::{DTensor, tensor};
4
5use crate::prelude::*;
6
7use crate::matmul::{Triangle, Type};
8
9pub fn test_eval_and_overwrite(bd: impl MatVec<f64>) {
10    let n = 3;
11    let x = DTensor::<f64, 1>::from_elem(n, 1.);
12    let a = DTensor::<f64, 2>::from_fn([n, n], |i| (i[0] * n + i[1] + 1) as f64);
13    let y_result = bd.matvec(&a, &x).scale(2.).eval();
14    let y = DTensor::<f64, 1>::from_fn([n], |i| 2. * (6. + i[0] as f64 * 9.));
15    assert_eq!(y_result, y);
16
17    let mut y_overwritten = DTensor::<f64, 1>::from_elem(n, 0.);
18    bd.matvec(&a, &x).scale(2.).overwrite(&mut y_overwritten);
19    assert_eq!(y_overwritten, y);
20}
21
22pub fn test_add_to_scaled(bd: impl MatVec<f64>) {
23    let n = 3;
24    let x = DTensor::<f64, 1>::from_elem(n, 1.);
25    let mut x2 = DTensor::<f64, 1>::from_elem(n, 1.);
26    let a = DTensor::<f64, 2>::from_fn([n, n], |i| (i[0] * 2 + i[1] + 1) as f64);
27    bd.matvec(&a, &x).add_to_scaled(&mut x2, 2.);
28    let y = DTensor::<f64, 1>::from_fn([n], |i| 2.0 * 1.0 + (6.0 + i[0] as f64 * 6.0));
29
30    assert_eq!(x2, y);
31}
32
33pub fn test_add_to(bd: impl MatVec<f64>) {
34    let n = 3;
35    let x = DTensor::<f64, 1>::from_elem(n, 1.);
36    let mut x2 = DTensor::<f64, 1>::from_elem(n, 1.);
37    let a = DTensor::<f64, 2>::from_fn([n, n], |i| (i[1] * 2 + i[0] + 1) as f64);
38    bd.matvec(&a, &x).add_to(&mut x2);
39    let y = DTensor::<f64, 1>::from_fn([n], |i| 10. + 3. * i[0] as f64);
40    assert_eq!(x2, y);
41}
42
43pub fn test_add_outer_basic(bd: impl MatVec<f64>) {
44    let m = 2;
45    let n = 3;
46
47    let x = DTensor::<f64, 1>::from_fn([m], |i| (i[0] + 1) as f64);
48    let y = DTensor::<f64, 1>::from_fn([n], |i| 10f64.powi(i[0] as i32));
49    let a = DTensor::<f64, 2>::from_fn([m, n], |i| if i[0] == i[1] { 1.0 } else { 0.0 });
50    let beta = 2.0;
51    let a_updated = bd.matvec(&a, &x).add_outer(&y, beta);
52
53    let expected = DTensor::<f64, 2>::from_fn([m, n], |i| {
54        let (row, col) = (i[0], i[1]);
55        let a_val = if row == col { 1.0 } else { 0.0 };
56        a_val + beta * (x[[row]]) * (y[[col]])
57    });
58
59    assert_eq!(a_updated, expected);
60}
61
62pub fn test_add_outer_sym(bd: impl MatVec<f64>) {
63    let n = 3;
64
65    let x = DTensor::<f64, 1>::from_fn([n], |i| (i[0] + 1) as f64);
66    let a = DTensor::<f64, 2>::from_fn([n, n], |i| {
67        let (row, col) = (i[0], i[1]);
68        if row == col { 2.0 } else { 1.0 }
69    });
70    let beta = 0.5;
71
72    let a_updated = bd
73        .matvec(&a, &x)
74        .add_outer_special(beta, Type::Sym, Triangle::Upper);
75
76    let expected = DTensor::<f64, 2>::from_fn([n, n], |i| {
77        let (row, col) = (i[0], i[1]);
78        let a_val = if row == col { 2.0 } else { 1.0 };
79        if row <= col {
80            a_val + beta * (x[[row]]) * (x[[col]])
81        } else {
82            a_val
83        }
84    });
85
86    assert_eq!(a_updated, expected);
87}
88
89pub fn test_add_outer_her(bd: impl MatVec<Complex<f64>>) {
90    use num_complex::Complex64;
91
92    let n = 3;
93
94    let x = DTensor::<Complex64, 1>::from_fn([n], |i| {
95        Complex64::new((i[0] + 1) as f64, (i[0] as f64) * 0.5)
96    });
97
98    let a = DTensor::<Complex64, 2>::from_fn([n, n], |i| {
99        let (row, col) = (i[0], i[1]);
100        if row == col {
101            Complex64::new(2.0, 0.0)
102        } else if row < col {
103            Complex64::new(1.0, 0.5)
104        } else {
105            Complex64::new(1.0, -0.5)
106        }
107    });
108    let beta = 0.3;
109
110    let a_updated =
111        bd.matvec(&a, &x)
112            .add_outer_special(Complex64::new(beta, 0.0), Type::Her, Triangle::Upper);
113
114    let expected = DTensor::<Complex64, 2>::from_fn([n, n], |i| {
115        let (row, col) = (i[0], i[1]);
116        let a_val = if row == col {
117            Complex64::new(2.0, 0.0)
118        } else if row < col {
119            Complex64::new(1.0, 0.5)
120        } else {
121            Complex64::new(1.0, -0.5)
122        };
123
124        if row <= col {
125            a_val + Complex64::new(beta, 0.0) * x[[row]] * x[[col]].conj()
126        } else {
127            a_val
128        }
129    });
130
131    assert_eq!(a_updated, expected);
132}
133
134pub fn test_add_to_scaled_vecvec(bd: impl VecOps<f64>) {
135    let n = 3;
136    let alpha = 2.0;
137    let x = DTensor::<f64, 1>::from_fn([n], |i| (i[0] + 1) as f64); // [1., 2., 3.]
138    let mut y = DTensor::<f64, 1>::from_elem(n, 1.0); // [1., 1., 1.]
139
140    bd.add_to_scaled(alpha, &x, &mut y);
141
142    let expected = DTensor::<f64, 1>::from_fn([n], |i| 1.0 + alpha * (i[0] + 1) as f64);
143    assert_eq!(y, expected);
144}
145
146pub fn test_dot_real(bd: impl VecOps<f64>) {
147    let n = 3;
148    let x = DTensor::<f64, 1>::from_fn([n], |i| (i[0] + 1) as f64); // [1., 2., 3.]
149    let y = DTensor::<f64, 1>::from_fn([n], |i| (2 * (i[0] + 1)) as f64); // [2., 4., 6.]
150
151    // dot(x, y) = 1*2 + 2*4 + 3*6 = 28
152    assert_eq!(bd.dot(&x, &y), 28.0);
153}
154
155pub fn test_dot_complex(bd: impl VecOps<Complex<f64>>) {
156    use num_complex::Complex64;
157    let n = 3;
158    let x = DTensor::<Complex64, 1>::from_fn([n], |i| Complex64::new((i[0] + 1) as f64, 0.)); // [1., 2., 3.]
159    let y = DTensor::<Complex64, 1>::from_fn([n], |i| Complex64::new(0., (2 * (i[0] + 1)) as f64)); // [2i, 4i, 6i]
160
161    let expected = Complex64::new(0.0, 28.0);
162
163    assert_eq!(bd.dot(&x, &y), expected);
164}
165
166pub fn test_dotc_complex(bd: impl VecOps<Complex<f64>>) {
167    use num_complex::Complex64;
168
169    let n = 2;
170    let x = DTensor::<Complex64, 1>::from_fn([n], |i| {
171        Complex64::new((i[0] + 1) as f64, (i[0] + 2) as f64)
172    }); // [(1+2i), (2+3i)]
173    let y = DTensor::<Complex64, 1>::from_fn([n], |i| {
174        Complex64::new((i[0] + 3) as f64, (i[0] + 4) as f64)
175    }); // [(3+4i), (4+5i)]
176
177    let result = bd.dotc(&x, &y);
178
179    println!("{result:?}");
180
181    // dotc(x, y) = conj(x1)*y1 + conj(x2)*y2
182    let expected = x[[0]].conj() * y[[0]] + x[[1]].conj() * y[[1]];
183    assert_eq!(result, expected);
184}
185
186pub fn test_norm1_complex(bd: impl VecOps<Complex<f64>>) {
187    use num_complex::Complex64;
188
189    let n = 3;
190    let x = DTensor::<Complex64, 1>::from_fn([n], |i| {
191        Complex64::new((i[0] + 1) as f64, (i[0] + 2) as f64)
192    });
193    // x = [1+2i, 2+3i, 3+4i]
194    // norm1 = sum(|z_k|)
195    let expected: f64 = x.iter().map(|z| z.re.abs() + z.im.abs()).sum();
196
197    let result = bd.norm1(&x);
198
199    println!("{result}");
200    println!("{expected}");
201
202    assert!((result - expected).abs() < 1e-12);
203}
204
205pub fn test_norm2_complex(bd: impl VecOps<Complex<f64>>) {
206    use num_complex::Complex64;
207
208    let n = 3;
209    let x = DTensor::<Complex64, 1>::from_fn([n], |i| {
210        Complex64::new((i[0] + 1) as f64, (i[0] + 2) as f64)
211    });
212    // x = [1+2i, 2+3i, 3+4i]
213    // norm2 = sqrt(sum(|z_k|²))
214    let expected: f64 = x.iter().map(|z| z.norm_sqr()).sum::<f64>().sqrt();
215
216    let result = bd.norm2(&x);
217
218    assert!((result - expected).abs() < 1e-12);
219}
220
221pub fn test_argmax_real(bd: impl Argmax<f64>) {
222    use mdarray::DTensor;
223
224    // ----- Empty tensor -----
225    let x = DTensor::<f64, 1>::from_fn([0], |_| 0.0);
226    let idx = bd.argmax(&x);
227    println!("Empty: {idx:?}");
228    assert_eq!(idx, None);
229
230    // ----- Scalar (rank 0) -----
231    let x = tensor![42.];
232    let idx = bd.argmax(&x).unwrap();
233    println!("Scalar: {idx:?}");
234    assert_eq!(idx, vec![0]); // Empty vec for scalar
235
236    // ----- 1D -----
237    let n = 5;
238    let x = DTensor::<f64, 1>::from_fn([n], |i| (i[0] + 1) as f64);
239    let idx = bd.argmax(&x.view(..)).unwrap();
240    println!("{idx:?}");
241    assert_eq!(idx, vec![4]);
242
243    // ----- 2D -----
244    let x = DTensor::<f64, 2>::from_fn([2, 3], |i| (i[0] * 3 + i[1]) as f64);
245
246    // [[0., 1., 2.],
247    //  [3., 4., 5.]]
248    let idx = bd.argmax(&x.view(.., ..).into_dyn()).unwrap();
249    println!("{idx:?}");
250    assert_eq!(idx, vec![1, 2]);
251
252    // ----- 3D -----
253    let x = DTensor::<f64, 3>::from_fn([2, 2, 2], |i| (i[0] * 4 + i[1] * 2 + i[2]) as f64);
254    let idx = bd.argmax(&x.view(.., .., ..).into_dyn()).unwrap();
255    println!("{idx:?}");
256    assert_eq!(idx, vec![1, 1, 1]);
257}
258
259pub fn test_argmax_overwrite_real(bd: impl Argmax<f64>) {
260    let mut output = Vec::new();
261
262    // ----- Empty tensor -----
263    let x = DTensor::<f64, 1>::from_fn([0], |_| 0.0);
264    let success = bd.argmax_overwrite(&x, &mut output);
265    assert!(!success);
266    assert_eq!(output, vec![]);
267
268    // ----- Scalar (rank 0) -----
269    let x = tensor![42.];
270    let success = bd.argmax_overwrite(&x, &mut output);
271    assert!(success);
272    assert_eq!(output, vec![0]);
273
274    // ----- 1D -----
275    let n = 5;
276    let x = DTensor::<f64, 1>::from_fn([n], |i| (i[0] + 1) as f64);
277    let success = bd.argmax_overwrite(&x.view(..), &mut output);
278    assert!(success);
279    assert_eq!(output, vec![4]);
280
281    // ----- 2D -----
282    let x = DTensor::<f64, 2>::from_fn([2, 3], |i| (i[0] * 3 + i[1]) as f64);
283    // [[0., 1., 2.],
284    //  [3., 4., 5.]]
285    let success = bd.argmax_overwrite(&x.view(.., ..).into_dyn(), &mut output);
286    assert!(success);
287    assert_eq!(output, vec![1, 2]);
288
289    // ----- 3D -----
290    let x = DTensor::<f64, 3>::from_fn([2, 2, 2], |i| (i[0] * 4 + i[1] * 2 + i[2]) as f64);
291    let success = bd.argmax_overwrite(&x.view(.., .., ..).into_dyn(), &mut output);
292    assert!(success);
293    assert_eq!(output, vec![1, 1, 1]);
294
295    // ----- Test reuse of output buffer -----
296    output = vec![99, 99, 99];
297    let x = DTensor::<f64, 1>::from_fn([3], |i| (3 - i[0]) as f64);
298    let success = bd.argmax_overwrite(&x.view(..), &mut output);
299    assert!(success);
300    assert_eq!(output, vec![0]); // Should be cleared and contain only result
301}
302
303pub fn test_argmax_abs(bd: impl Argmax<f64>) {
304    use mdarray::DTensor;
305
306    // ----- Empty tensor -----
307    let x = DTensor::<f64, 1>::from_fn([0], |_| 0.0);
308    let idx = bd.argmax_abs(&x);
309    println!("Empty: {idx:?}");
310    assert_eq!(idx, None);
311
312    // ----- Scalar (rank 0) -----
313    let x = tensor![42.];
314    let idx = bd.argmax_abs(&x).unwrap();
315    println!("Scalar: {idx:?}");
316    assert_eq!(idx, vec![0]); // Empty vec for scalar
317
318    // ----- 1D -----
319    let n = 6;
320    let x = DTensor::<f64, 1>::from_fn([n], |i| {
321        if i[0] % 2 == 0 {
322            (i[0] as i32 + 1) as f64
323        } else {
324            -(i[0] as i32 + 1) as f64
325        }
326    });
327    let idx = bd.argmax_abs(&x.view(..)).unwrap();
328    println!("{idx:?}");
329    assert_eq!(idx, vec![5]);
330
331    // ----- 2D -----
332    let x = DTensor::<f64, 2>::from_fn([2, 3], |i| (i[0] * 3 + i[1]) as f64);
333
334    // [[0., 1., 2.],
335    //  [3., 4., 5.]]
336    let idx = bd.argmax_abs(&x.view(.., ..).into_dyn()).unwrap();
337    println!("{idx:?}");
338    assert_eq!(idx, vec![1, 2]);
339
340    // ----- 3D -----
341    let x = DTensor::<f64, 3>::from_fn([2, 2, 2], |i| (i[0] * 4 + i[1] * 2 + i[2]) as f64);
342    let idx = bd.argmax_abs(&x.view(.., .., ..).into_dyn()).unwrap();
343    println!("{idx:?}");
344    assert_eq!(idx, vec![1, 1, 1]);
345}