// entrenar/transformer/init.rs
use rand::rngs::SmallRng;
use rand::SeedableRng;
16
/// Standard deviation used to scale the normal samples produced by
/// `rand_normal_seeded`.
pub const INITIALIZER_RANGE: f32 = 0.02;
19
20pub fn rand_normal_seeded(n: usize, base_seed: u64, name: &str) -> Vec<f32> {
32 let name_hash = hash_name(name);
34 let seed = base_seed.wrapping_add(name_hash);
35 let mut rng = SmallRng::seed_from_u64(seed);
36
37 let std_dev = INITIALIZER_RANGE;
38 (0..n)
39 .map(|_| {
40 let u1: f32 = rand::Rng::random::<f32>(&mut rng).max(1e-7);
42 let u2: f32 = rand::Rng::random::<f32>(&mut rng);
43 ((-2.0 * u1.ln()).sqrt() * (2.0 * std::f32::consts::PI * u2).cos()) * std_dev
44 })
45 .collect()
46}
47
/// Hashes `name` to a `u64` with 64-bit FNV-1a over its UTF-8 bytes.
///
/// Each byte is XORed into the running hash, which is then multiplied
/// (wrapping) by the FNV prime. The empty string hashes to the offset basis.
fn hash_name(name: &str) -> u64 {
    const FNV_OFFSET_BASIS: u64 = 0xcbf2_9ce4_8422_2325;
    const FNV_PRIME: u64 = 0x0100_0000_01b3;

    name.bytes()
        .fold(FNV_OFFSET_BASIS, |h, byte| (h ^ u64::from(byte)).wrapping_mul(FNV_PRIME))
}
57
/// Global seed value; starts at 42. Written by `set_init_seed` and read by
/// `get_init_seed`.
static INIT_SEED: std::sync::atomic::AtomicU64 = std::sync::atomic::AtomicU64::new(42);

/// Mutex acquired by `lock_init_seed`. It guards no data of its own; holding
/// its guard is what serializes callers that go through `lock_init_seed`.
static INIT_SEED_LOCK: std::sync::Mutex<()> = std::sync::Mutex::new(());
69
70pub fn set_init_seed(seed: u64) {
72 INIT_SEED.store(seed, std::sync::atomic::Ordering::SeqCst);
73}
74
75pub fn get_init_seed() -> u64 {
77 INIT_SEED.load(std::sync::atomic::Ordering::SeqCst)
78}
79
80#[must_use = "the returned guard must be held until weight init finishes"]
90pub fn lock_init_seed(seed: u64) -> std::sync::MutexGuard<'static, ()> {
91 let guard = INIT_SEED_LOCK.lock().unwrap_or_else(|poisoned| poisoned.into_inner());
92 set_init_seed(seed);
93 guard
94}
95
#[cfg(test)]
mod tests {
    use super::*;

    /// Identical `(seed, name)` inputs must reproduce the same samples.
    #[test]
    fn test_rand_normal_seeded_deterministic() {
        let a = rand_normal_seeded(100, 42, "test");
        let b = rand_normal_seeded(100, 42, "test");
        assert_eq!(a, b, "Same seed+name must produce identical output");
    }

    /// Changing only the base seed must change the samples.
    #[test]
    fn test_rand_normal_seeded_different_seeds() {
        let a = rand_normal_seeded(100, 42, "test");
        let b = rand_normal_seeded(100, 123, "test");
        assert_ne!(a, b, "Different seeds must produce different output");
    }

    /// Changing only the name must change the samples.
    #[test]
    fn test_rand_normal_seeded_different_names() {
        let a = rand_normal_seeded(100, 42, "w_q");
        let b = rand_normal_seeded(100, 42, "w_k");
        assert_ne!(a, b, "Different names must produce different output");
    }

    /// Sample mean should be close to 0 and sample std close to
    /// `INITIALIZER_RANGE`.
    #[test]
    fn test_rand_normal_seeded_statistics() {
        let data = rand_normal_seeded(10000, 42, "stats_test");
        let mean: f32 = data.iter().sum::<f32>() / data.len() as f32;
        let variance: f32 =
            data.iter().map(|x| (x - mean).powi(2)).sum::<f32>() / data.len() as f32;
        let std = variance.sqrt();

        assert!(mean.abs() < 0.005, "Mean should be near 0, got {mean}");
        assert!(
            (std - INITIALIZER_RANGE).abs() < 0.005,
            "Std should be near {INITIALIZER_RANGE}, got {std}"
        );
    }

    /// Lag-1 autocorrelation of consecutive samples should be small.
    #[test]
    fn test_rand_normal_seeded_no_sinusoidal_pattern() {
        let data = rand_normal_seeded(1000, 42, "autocorr_test");
        let mean: f32 = data.iter().sum::<f32>() / data.len() as f32;
        let var: f32 = data.iter().map(|x| (x - mean).powi(2)).sum::<f32>() / data.len() as f32;
        let autocorr: f32 = data.windows(2).map(|w| (w[0] - mean) * (w[1] - mean)).sum::<f32>()
            / (data.len() as f32 * var);
        assert!(
            autocorr.abs() < 0.1,
            "Autocorrelation should be < 0.1 (no sinusoidal pattern), got {autocorr}"
        );
    }
}