Skip to main content

entrenar/transformer/
init.rs

//! Weight initialization utilities (C-INIT-001).
//!
//! Provides `rand_normal_seeded` for random normal weight initialization,
//! matching HuggingFace LLaMA's `normal(0, initializer_range)`.
//!
//! Replaces the legacy sinusoidal `sin(i * const) * scale` placeholder,
//! which caused a 16x convergence gap vs PyTorch (entrenar#309).
//!
//! References:
//! - Touvron et al. (2023) LLaMA: arXiv 2302.13971
//! - He et al. (2015) Kaiming init: arXiv 1502.01852
//! - HuggingFace LlamaPreTrainedModel._init_weights
13
14use rand::rngs::SmallRng;
15use rand::SeedableRng;
16
/// Standard deviation used for random weight initialization.
///
/// Mirrors the default `initializer_range` of HuggingFace LLaMA configs:
/// weights are drawn from `normal(0, INITIALIZER_RANGE)`.
pub const INITIALIZER_RANGE: f32 = 2.0e-2;
19
20/// Generate `n` random normal values with mean=0 and std=`INITIALIZER_RANGE`.
21///
22/// Uses a deterministic seed derived from `base_seed` and `name` for
23/// reproducibility (C-INIT-001, FALSIFY-INIT-003).
24///
25/// # Contract (C-INIT-001)
26///
27/// - `E[result[i]] ≈ 0`
28/// - `Var[result[i]] ≈ INITIALIZER_RANGE^2`
29/// - Same `(base_seed, name)` → identical output
30/// - Different `(base_seed, name)` → different output
31pub fn rand_normal_seeded(n: usize, base_seed: u64, name: &str) -> Vec<f32> {
32    // Derive per-parameter seed from base_seed + name hash
33    let name_hash = hash_name(name);
34    let seed = base_seed.wrapping_add(name_hash);
35    let mut rng = SmallRng::seed_from_u64(seed);
36
37    let std_dev = INITIALIZER_RANGE;
38    (0..n)
39        .map(|_| {
40            // Box-Muller transform: two uniform → one normal
41            let u1: f32 = rand::Rng::random::<f32>(&mut rng).max(1e-7);
42            let u2: f32 = rand::Rng::random::<f32>(&mut rng);
43            ((-2.0 * u1.ln()).sqrt() * (2.0 * std::f32::consts::PI * u2).cos()) * std_dev
44        })
45        .collect()
46}
47
/// 64-bit FNV-1a hash of `name`'s UTF-8 bytes, used to derive per-parameter seeds.
///
/// Not cryptographic. It only needs to be deterministic, so that a given
/// parameter name always maps to the same seed offset.
fn hash_name(name: &str) -> u64 {
    const FNV_OFFSET_BASIS: u64 = 0xcbf2_9ce4_8422_2325;
    const FNV_PRIME: u64 = 0x0100_0000_01b3;

    // FNV-1a: XOR the byte in first, then multiply (mod 2^64).
    name.bytes()
        .fold(FNV_OFFSET_BASIS, |acc, byte| (acc ^ u64::from(byte)).wrapping_mul(FNV_PRIME))
}
57
/// Value `INIT_SEED` holds until someone overrides it.
const DEFAULT_INIT_SEED: u64 = 42;

/// Process-wide seed read during weight initialization (set from the training config).
///
/// Starts at `DEFAULT_INIT_SEED` (42). Override it with `set_init_seed()`
/// (or `lock_init_seed()`) before `Transformer::new()` runs.
static INIT_SEED: std::sync::atomic::AtomicU64 =
    std::sync::atomic::AtomicU64::new(DEFAULT_INIT_SEED);

/// Serializes the "set seed, then build weights" critical section.
///
/// Without it, two concurrent callers could interleave: one stores its seed,
/// then the other overwrites it before the first finishes the
/// `rand_normal_seeded` calls inside `Transformer::new`. Holding this lock
/// across both steps keeps GATE-TRAIN-006 seed reproducibility intact when
/// tests run in parallel.
static INIT_SEED_LOCK: std::sync::Mutex<()> = std::sync::Mutex::new(());
69
70/// Set the global initialization seed (called from training config).
71pub fn set_init_seed(seed: u64) {
72    INIT_SEED.store(seed, std::sync::atomic::Ordering::SeqCst);
73}
74
75/// Get the current initialization seed.
76pub fn get_init_seed() -> u64 {
77    INIT_SEED.load(std::sync::atomic::Ordering::SeqCst)
78}
79
80/// Lock the init-seed critical section and set the seed atomically.
81///
82/// Returns a `MutexGuard` that MUST be held for the full lifetime of
83/// the weight-init work (i.e., the entire `Transformer::new` call).
84/// Dropping the guard before init completes reopens the race window.
85///
86/// Poisoned mutexes are recovered transparently — a poisoned seed
87/// lock does not mean the seed value is corrupt (seed is a simple
88/// `u64` atomic), only that a prior holder panicked.
89#[must_use = "the returned guard must be held until weight init finishes"]
90pub fn lock_init_seed(seed: u64) -> std::sync::MutexGuard<'static, ()> {
91    let guard = INIT_SEED_LOCK.lock().unwrap_or_else(|poisoned| poisoned.into_inner());
92    set_init_seed(seed);
93    guard
94}
95
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_rand_normal_seeded_deterministic() {
        let a = rand_normal_seeded(100, 42, "test");
        let b = rand_normal_seeded(100, 42, "test");
        assert_eq!(a.len(), 100, "Output length must equal n");
        assert_eq!(a, b, "Same seed+name must produce identical output");
    }

    #[test]
    fn test_rand_normal_seeded_empty() {
        // n == 0 must yield an empty vector, not panic.
        assert!(rand_normal_seeded(0, 42, "empty").is_empty());
    }

    #[test]
    fn test_rand_normal_seeded_different_seeds() {
        let a = rand_normal_seeded(100, 42, "test");
        let b = rand_normal_seeded(100, 123, "test");
        assert_ne!(a, b, "Different seeds must produce different output");
    }

    #[test]
    fn test_rand_normal_seeded_different_names() {
        let a = rand_normal_seeded(100, 42, "w_q");
        let b = rand_normal_seeded(100, 42, "w_k");
        assert_ne!(a, b, "Different names must produce different output");
    }

    #[test]
    fn test_rand_normal_seeded_statistics() {
        let data = rand_normal_seeded(10000, 42, "stats_test");
        let mean: f32 = data.iter().sum::<f32>() / data.len() as f32;
        let variance: f32 =
            data.iter().map(|x| (x - mean).powi(2)).sum::<f32>() / data.len() as f32;
        let std = variance.sqrt();

        assert!(mean.abs() < 0.005, "Mean should be near 0, got {mean}");
        assert!(
            (std - INITIALIZER_RANGE).abs() < 0.005,
            "Std should be near {INITIALIZER_RANGE}, got {std}"
        );
    }

    #[test]
    fn test_rand_normal_seeded_no_sinusoidal_pattern() {
        // FALSIFY-INIT-001: autocorrelation at lag 1 should be < 0.1
        let data = rand_normal_seeded(1000, 42, "autocorr_test");
        let mean: f32 = data.iter().sum::<f32>() / data.len() as f32;
        let var: f32 = data.iter().map(|x| (x - mean).powi(2)).sum::<f32>() / data.len() as f32;
        let autocorr: f32 = data.windows(2).map(|w| (w[0] - mean) * (w[1] - mean)).sum::<f32>()
            / (data.len() as f32 * var);
        assert!(
            autocorr.abs() < 0.1,
            "Autocorrelation should be < 0.1 (no sinusoidal pattern), got {autocorr}"
        );
    }

    #[test]
    fn test_hash_name_matches_fnv1a_vectors() {
        // Published 64-bit FNV-1a test vectors. A typo in the offset basis or
        // prime would silently shift every derived seed while the determinism
        // tests above still pass, so pin the hash explicitly.
        assert_eq!(hash_name(""), 0xcbf2_9ce4_8422_2325);
        assert_eq!(hash_name("a"), 0xaf63_dc4c_8601_ec8c);
        assert_eq!(hash_name("foobar"), 0x8594_4171_f739_67e8);
    }
}