// numrs/backend/cpu/random.rs

//! Pseudo-random number generation for internal use (Dropout, etc.)
//!
//! Uses a simple Xorshift128+ algorithm for speed.
//! Not cryptographically secure, but sufficient for ML randomness.

use std::cell::RefCell;

9thread_local! {
10    static RNG: RefCell<Xorshift128Plus> = RefCell::new(Xorshift128Plus::new_from_time());
11}
12
/// Xorshift128+ pseudo-random generator.
///
/// Fast and small, but not cryptographically secure.
#[derive(Debug, Clone)]
struct Xorshift128Plus {
    /// The 128 bits of generator state. An all-zero state would make
    /// `next_u64` return zero forever, so it must be avoided.
    state: [u64; 2],
}

impl Xorshift128Plus {
    /// Seed used by [`Self::new_from_time`].
    const DEFAULT_SEED: u64 = 123456789;

    /// Creates a generator whose state is derived from `seed`.
    fn new(seed: u64) -> Self {
        let mut rng = Self { state: [0, 0] };
        rng.seed(seed);
        rng
    }

    /// Creates the default generator.
    ///
    /// Despite the name, no clock is read: a fixed seed is used, so the
    /// sequence is reproducible by default and identical for every instance
    /// created this way until it is reseeded.
    fn new_from_time() -> Self {
        Self::new(Self::DEFAULT_SEED)
    }

    /// One SplitMix64 step: adds the golden-ratio increment to `x` and
    /// returns the mixed result.
    fn splitmix64(x: u64) -> u64 {
        let mut z = x.wrapping_add(0x9e3779b97f4a7c15);
        z = (z ^ (z >> 30)).wrapping_mul(0xbf58476d1ce4e5b9);
        z = (z ^ (z >> 27)).wrapping_mul(0x94d049bb133111eb);
        z ^ (z >> 31)
    }

    /// Replaces the state with one derived from `seed` via SplitMix64.
    ///
    /// The second word is derived from the first, so the same `seed` always
    /// yields the same sequence. The two words cannot both be zero: the mixer
    /// is a bijection that maps only zero to zero, and the second word's
    /// mixer input is non-zero whenever the first word is zero.
    fn seed(&mut self, seed: u64) {
        let s0 = Self::splitmix64(seed);
        let s1 = Self::splitmix64(s0);
        self.state = [s0, s1];
    }

    /// Advances the generator and returns the next 64-bit output.
    fn next_u64(&mut self) -> u64 {
        let mut s1 = self.state[0];
        let s0 = self.state[1];
        self.state[0] = s0;
        s1 ^= s1 << 23; // a
        self.state[1] = s1 ^ s0 ^ (s1 >> 17) ^ (s0 >> 26); // b, c
        self.state[1].wrapping_add(s0)
    }

    /// Returns a float uniformly distributed on `[0, 1)`.
    ///
    /// Only the top 24 bits of the 64-bit output are used: an `f32` has a
    /// 24-bit significand, so every such integer converts exactly and the
    /// largest result is `1 - 2^-24`. Dividing a full 32-bit value by
    /// `u32::MAX as f32` (which rounds to 2^32) could round up to exactly 1.0.
    fn next_f32(&mut self) -> f32 {
        const INV_2_POW_24: f32 = 1.0 / 16_777_216.0;
        let bits = (self.next_u64() >> 40) as u32;
        bits as f32 * INV_2_POW_24
    }
}

67/// Fills a buffer with random values from Uniform[0, 1)
68pub fn rand_uniform(data: &mut [f32]) {
69    RNG.with(|rng| {
70        let mut r = rng.borrow_mut();
71        for val in data.iter_mut() {
72            *val = r.next_f32();
73        }
74    })
75}
76
77/// Generates a Bernoulli mask (1.0 with prob p, 0.0 with prob 1-p)
78/// In dropout terms: 'p' usually means "probability of zeroing out" (Torch convention) 
79/// or "probability of keeping" (TensorFlow convention).
80/// PyTorch: p = prob of dropout (zeroing). 
81/// Here we use p = prob of Dropout (Zeroing).
82/// 
83/// output[i] = 1 if keep, 0 if drop.
84/// keep_prob = 1 - p
85pub fn bernoulli_mask(data: &mut [f32], p: f32) {
86    let threshold = 1.0 - p;
87    RNG.with(|rng| {
88        let mut r = rng.borrow_mut();
89        for val in data.iter_mut() {
90            let rnd = r.next_f32();
91            *val = if rnd < threshold { 1.0 } else { 0.0 };
92        }
93    })
94}
95
96/// Seeds the thread-local RNG
97pub fn seed(seed: u64) {
98    RNG.with(|rng| {
99        rng.borrow_mut().seed(seed);
100    })
101}