1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
#[cfg(not(feature = "threading"))]
use std::{
rc::Rc,
cell::RefCell,
};
#[cfg(feature = "threading")]
use {
std::sync::Arc,
parking_lot::RwLock,
};
use rand::Rng;
use crate::{
scalar::Real,
};
// Single-threaded builds: plain reference counting with `RefCell` borrow checks.

/// Shared-ownership pointer used throughout the crate (`Rc` without `threading`).
#[cfg(not(feature = "threading"))]
pub type RcT<T> = Rc<T>;
/// Shared, interior-mutable value (`Rc<RefCell<T>>` without `threading`).
#[cfg(not(feature = "threading"))]
pub type RcCell<T> = Rc<RefCell<T>>;

// `threading` builds: atomic reference counting with a `parking_lot` read-write lock.

/// Shared-ownership pointer used throughout the crate (`Arc` with `threading`).
#[cfg(feature = "threading")]
pub type RcT<T> = Arc<T>;
/// Shared, interior-mutable value (`Arc<RwLock<T>>` with `threading`).
#[cfg(feature = "threading")]
pub type RcCell<T> = Arc<RwLock<T>>;
/// Resolves a possibly-negative index `i` against a length `n`.
///
/// Non-negative indices are returned unchanged, with no upper-bound check.
/// A negative index counts back from `n`: `-1` maps to `n - 1`. When
/// `start_behind` is `true`, it maps to `n` instead, so every negative index
/// shifts one position later.
/// NOTE(review): `start_behind` presumably exists to address insertion
/// positions one past the end. Confirm against callers.
///
/// # Panics
///
/// Panics if a negative index resolves to a position below zero.
#[inline]
pub fn negative_index(i: isize, n: usize, start_behind: bool) -> usize {
    if i >= 0 {
        return i as usize;
    }
    // `true` contributes a one-slot shift, `false` contributes nothing.
    let resolved = n as isize + i + isize::from(start_behind);
    assert!(resolved >= 0, "Negative index {i} into rank {n} shape");
    resolved as usize
}
/// Draws a pair of independent standard-normal samples (mean 0, variance 1)
/// using the Marsaglia polar method.
///
/// Candidate points `(u, v)` are drawn uniformly from `[-1, 1)` on each axis.
/// A candidate is rejected unless it lies strictly inside the unit circle.
/// The origin is rejected too, because `ln(0)` is `-inf`. An accepted point
/// is scaled by `sqrt(-2 ln(r) / r)`, where `r = u² + v²`.
///
/// # Panics
///
/// Panics if the constant `-2.0` cannot be converted into `T`.
pub fn randn<T: Real>() -> (T, T) {
    let mut rng = rand::thread_rng();
    let minus_two = T::from(-2.0).unwrap();
    // Retry in a loop rather than by recursion. Rust does not guarantee
    // tail-call elimination, and this keeps one RNG handle for all attempts.
    // About 21% (1 - π/4) of candidates are rejected.
    loop {
        let u = rng.gen_range(-T::one(), T::one());
        let v = rng.gen_range(-T::one(), T::one());
        let r = u * u + v * v;
        if r == T::zero() || r >= T::one() {
            continue;
        }
        let c = (minus_two * r.ln() / r).sqrt();
        return (u * c, v * c);
    }
}