basic/basic.rs
1use approx::assert_abs_diff_eq;
2use nalgebra::{Const, SVector};
3use raddy::make::var;
4use raddy::Ad;
5use rand::{thread_rng, Rng};
6
/// Absolute tolerance used when comparing the scalar AD derivatives in `main`
/// against their closed-form values.
const EPS: f64 = 1.0e-10;
8
9fn main() {
10 // 1.
11 // ################ scalar ################
12 let mut rng = thread_rng();
13 let val = rng.gen_range(0.0..10.0);
14
15 let var = var::scalar(val);
16 let var = &var;
17 let y = var.sin() * var + var.ln();
18 let g = val * val.cos() + val.sin() + val.recip();
19 let h = -val * val.sin() + 2.0 * val.cos() - val.powi(-2);
20
21 assert_abs_diff_eq!(y.grad()[(0, 0)], g, epsilon = EPS);
22 assert_abs_diff_eq!(y.hess()[(0, 0)], h, epsilon = EPS);
23
24 // 2.
25 // ############################# Matrix #############################
26
27 const N_TEST_MAT_4: usize = 4;
28 type NaConst = Const<N_TEST_MAT_4>;
29 const N_VEC_4: usize = N_TEST_MAT_4 * N_TEST_MAT_4;
30
31 let vals: &[f64] = &(0..N_VEC_4)
32 .map(|_| rng.gen_range(-4.0..4.0))
33 .collect::<Vec<_>>();
34
35 let s: SVector<Ad<N_VEC_4>, N_VEC_4> = var::vector_from_slice(vals);
36 let z = s
37 .clone()
38 // This reshape is COL MAJOR!!!!!!!!!!!!!
39 .reshape_generic(NaConst {}, NaConst {})
40 .transpose();
41
42 let det = z.determinant();
43 let _grad = det.grad();
44 let _hess = det.hess();
45 // core logic ends ####################################################
46
47 // correctness
48 // let expected_grad = grad_det4(
49 // vals[0], vals[1], vals[2], vals[3], vals[4], vals[5], vals[6], vals[7], vals[8], vals[9],
50 // vals[10], vals[11], vals[12], vals[13], vals[14], vals[15],
51 // );
52 // let g_diff = (expected_grad - det.grad()).norm_squared();
53 // assert_abs_diff_eq!(g_diff, 0.0, epsilon = EPS);
54
55 // let expected_hess = hess_det4(
56 // vals[0], vals[1], vals[2], vals[3], vals[4], vals[5], vals[6], vals[7], vals[8], vals[9],
57 // vals[10], vals[11], vals[12], vals[13], vals[14], vals[15],
58 // );
59 // let h_diff = (det.hess() - expected_hess).norm_squared();
60 // assert_abs_diff_eq!(h_diff, 0.0, epsilon = EPS);
61}