rstsr_core/dev_utilities/
mod.rs

1#![allow(dead_code)]
2
3use crate::prelude::*;
4use crate::prelude_dev::*;
5use num::complex::ComplexFloat;
6
7/// Compare two tensors with f64 data type.
8///
9/// This function assumes c-contiguous iteration, and will not check two
10/// dimensions are broadcastable.
11pub fn allclose_f64<RA, RB, DA, DB, BA, BB>(a: &TensorAny<RA, f64, BA, DA>, b: &TensorAny<RB, f64, BB, DB>) -> bool
12where
13    RA: DataAPI<Data = <BA as DeviceRawAPI<f64>>::Raw>,
14    RB: DataAPI<Data = <BB as DeviceRawAPI<f64>>::Raw>,
15    DA: DimAPI,
16    DB: DimAPI,
17    BA: DeviceAPI<f64, Raw = Vec<f64>>,
18    BB: DeviceAPI<f64, Raw = Vec<f64>>,
19{
20    let la = a.layout().reverse_axes();
21    let lb = b.layout().reverse_axes();
22    if la.size() != lb.size() {
23        return false;
24    }
25    let it_la = IterLayoutColMajor::new(&la).unwrap();
26    let it_lb = IterLayoutColMajor::new(&lb).unwrap();
27    let data_a = a.raw();
28    let data_b = b.raw();
29    let atol = 1e-8;
30    let rtol = 1e-5;
31    for (idx_a, idx_b) in izip!(it_la, it_lb) {
32        let va = data_a[idx_a];
33        let vb = data_b[idx_b];
34        let comp = (va - vb).abs() <= atol + rtol * vb.abs();
35        if !comp {
36            return false;
37        }
38    }
39    return true;
40}
41
42/// Get a somehow unique fingerprint of a tensor.
43///
44/// # See also
45///
46/// PySCF `pyscf.lib.misc.fingerprint`
47/// <https://github.com/pyscf/pyscf/blob/6f6d3741bf42543e02ccaa1d4ef43d9bf83b3dda/pyscf/lib/misc.py#L1249-L1253>
48pub fn fingerprint<R, T, B, D>(a: &TensorAny<R, T, B, D>) -> T
49where
50    T: ComplexFloat,
51    D: DimAPI,
52    B: DeviceAPI<T>
53        + DeviceRawAPI<MaybeUninit<T>>
54        + DeviceCreationComplexFloatAPI<T>
55        + DeviceCosAPI<T, IxD, TOut = T>
56        + DeviceCreationAnyAPI<T>
57        + OpAssignAPI<T, IxD>
58        + OpAssignArbitaryAPI<T, IxD, D>
59        + OpAssignArbitaryAPI<T, D, D>
60        + DeviceMatMulAPI<T, T, T, IxD, IxD, IxD>,
61    R: DataAPI<Data = <B as DeviceRawAPI<T>>::Raw>,
62    <B as DeviceRawAPI<T>>::Raw: Clone,
63    for<'a> R: DataIntoCowAPI<'a>,
64{
65    let range = linspace((T::zero(), T::from(a.size()).unwrap(), a.size(), false, a.device()));
66    let val = a.to_contig(RowMajor).reshape(-1) % range.cos();
67    val.to_scalar()
68}