//! Weight initialization strategies (Kaiming, Xavier, orthogonal).
//!
//! These match PyTorch's `torch.nn.init` functions and produce deterministic
//! results given the same seed.

use yscv_tensor::Tensor;

use crate::ModelError;
/// Simple xorshift64 PRNG (deterministic, no external deps).
struct Rng(u64);

impl Rng {
    /// Builds a generator; a zero seed is remapped because xorshift64
    /// has an all-zero fixed point that would never advance.
    fn new(seed: u64) -> Self {
        match seed {
            0 => Self(0xDEAD_BEEF),
            s => Self(s),
        }
    }

    /// Advances the state one xorshift64 step and returns it.
    fn next_u64(&mut self) -> u64 {
        self.0 ^= self.0 << 13;
        self.0 ^= self.0 >> 7;
        self.0 ^= self.0 << 17;
        self.0
    }

    /// Uniform in [0, 1): the top 24 bits of the state scaled into the unit interval.
    fn uniform(&mut self) -> f32 {
        const SCALE: f32 = (1u64 << 24) as f32;
        let bits = self.next_u64() >> 40;
        bits as f32 / SCALE
    }

    /// Uniform in [lo, hi).
    fn uniform_range(&mut self, lo: f32, hi: f32) -> f32 {
        let t = self.uniform();
        (hi - lo) * t + lo
    }

    /// Standard normal via the Box-Muller transform (cosine branch).
    fn normal(&mut self) -> f32 {
        // Clamp away from zero so ln() stays finite.
        let u1 = self.uniform().max(1e-10);
        let u2 = self.uniform();
        let radius = (-2.0 * u1.ln()).sqrt();
        let angle = 2.0 * std::f32::consts::PI * u2;
        radius * angle.cos()
    }
}
45/// Kaiming (He) uniform initialization.
46///
47/// Fills with values from U(-bound, bound) where bound = sqrt(6 / fan_in).
48pub fn kaiming_uniform(shape: Vec<usize>, fan_in: usize, seed: u64) -> Result<Tensor, ModelError> {
49    let bound = (6.0 / fan_in as f32).sqrt();
50    let n: usize = shape.iter().product();
51    let mut rng = Rng::new(seed);
52    let data: Vec<f32> = (0..n).map(|_| rng.uniform_range(-bound, bound)).collect();
53    Ok(Tensor::from_vec(shape, data)?)
54}
56/// Kaiming (He) normal initialization.
57///
58/// Fills with values from N(0, std) where std = sqrt(2 / fan_in).
59pub fn kaiming_normal(shape: Vec<usize>, fan_in: usize, seed: u64) -> Result<Tensor, ModelError> {
60    let std = (2.0 / fan_in as f32).sqrt();
61    let n: usize = shape.iter().product();
62    let mut rng = Rng::new(seed);
63    let data: Vec<f32> = (0..n).map(|_| rng.normal() * std).collect();
64    Ok(Tensor::from_vec(shape, data)?)
65}
67/// Xavier (Glorot) uniform initialization.
68///
69/// Fills with values from U(-bound, bound) where bound = sqrt(6 / (fan_in + fan_out)).
70pub fn xavier_uniform(
71    shape: Vec<usize>,
72    fan_in: usize,
73    fan_out: usize,
74    seed: u64,
75) -> Result<Tensor, ModelError> {
76    let bound = (6.0 / (fan_in + fan_out) as f32).sqrt();
77    let n: usize = shape.iter().product();
78    let mut rng = Rng::new(seed);
79    let data: Vec<f32> = (0..n).map(|_| rng.uniform_range(-bound, bound)).collect();
80    Ok(Tensor::from_vec(shape, data)?)
81}
83/// Xavier (Glorot) normal initialization.
84///
85/// Fills with values from N(0, std) where std = sqrt(2 / (fan_in + fan_out)).
86pub fn xavier_normal(
87    shape: Vec<usize>,
88    fan_in: usize,
89    fan_out: usize,
90    seed: u64,
91) -> Result<Tensor, ModelError> {
92    let std = (2.0 / (fan_in + fan_out) as f32).sqrt();
93    let n: usize = shape.iter().product();
94    let mut rng = Rng::new(seed);
95    let data: Vec<f32> = (0..n).map(|_| rng.normal() * std).collect();
96    Ok(Tensor::from_vec(shape, data)?)
97}
99/// Orthogonal initialization via QR decomposition (simplified Gram-Schmidt).
100///
101/// Creates a matrix of shape `[rows, cols]` with orthonormal rows (or columns).
102pub fn orthogonal(rows: usize, cols: usize, seed: u64) -> Result<Tensor, ModelError> {
103    let n = rows.max(cols);
104    let mut rng = Rng::new(seed);
105
106    // Generate random matrix
107    let mut mat: Vec<Vec<f32>> = (0..n)
108        .map(|_| (0..n).map(|_| rng.normal()).collect())
109        .collect();
110
111    // Modified Gram-Schmidt QR
112    for i in 0..n {
113        // Normalize column i
114        let norm: f32 = (0..n).map(|r| mat[r][i] * mat[r][i]).sum::<f32>().sqrt();
115        if norm > 1e-10 {
116            for r in 0..n {
117                mat[r][i] /= norm;
118            }
119        }
120        // Orthogonalize remaining columns
121        for j in (i + 1)..n {
122            let dot: f32 = (0..n).map(|r| mat[r][i] * mat[r][j]).sum();
123            for r in 0..n {
124                mat[r][j] -= dot * mat[r][i];
125            }
126        }
127    }
128
129    // Extract [rows, cols] submatrix from Q
130    let mut data = Vec::with_capacity(rows * cols);
131    for r in 0..rows {
132        for c in 0..cols {
133            data.push(mat[r][c]);
134        }
135    }
136    Ok(Tensor::from_vec(vec![rows, cols], data)?)
137}
139/// Fill a tensor with a constant value.
140pub fn constant(shape: Vec<usize>, value: f32) -> Result<Tensor, ModelError> {
141    Ok(Tensor::filled(shape, value)?)
142}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn kaiming_uniform_shape_and_bounds() {
        let t = kaiming_uniform(vec![64, 32], 32, 42).unwrap();
        assert_eq!(t.shape(), &[64, 32]);
        let bound = (6.0f32 / 32.0).sqrt();
        assert!(
            t.data().iter().all(|v| v.abs() <= bound),
            "a value exceeds the bound"
        );
    }

    #[test]
    fn kaiming_normal_shape() {
        let t = kaiming_normal(vec![128, 64], 64, 42).unwrap();
        assert_eq!(t.shape(), &[128, 64]);
        // The empirical mean should sit near zero.
        let sum: f32 = t.data().iter().sum();
        let mean = sum / t.data().len() as f32;
        assert!(mean.abs() < 0.1, "mean {mean} too far from 0");
    }

    #[test]
    fn xavier_uniform_shape_and_bounds() {
        let t = xavier_uniform(vec![100, 50], 100, 50, 42).unwrap();
        assert_eq!(t.shape(), &[100, 50]);
        let bound = (6.0f32 / 150.0).sqrt();
        for &v in t.data() {
            assert!((-bound..=bound).contains(&v));
        }
    }

    #[test]
    fn xavier_normal_shape() {
        let t = xavier_normal(vec![100, 50], 100, 50, 42).unwrap();
        assert_eq!(t.shape(), &[100, 50]);
        let mean = t.data().iter().sum::<f32>() / t.data().len() as f32;
        assert!(mean.abs() < 0.1);
    }

    #[test]
    fn orthogonal_produces_orthonormal_columns() {
        let t = orthogonal(4, 4, 42).unwrap();
        let d = t.data();
        let column = |c: usize| (0..4).map(move |r| d[r * 4 + c]);
        // Column 0 should have unit norm.
        let norm = column(0).map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 1e-4, "col 0 norm = {norm}");
        // Columns 0 and 1 should be orthogonal.
        let dot: f32 = column(0).zip(column(1)).map(|(a, b)| a * b).sum();
        assert!(dot.abs() < 1e-4, "dot = {dot}");
    }

    #[test]
    fn constant_fills_with_value() {
        let t = constant(vec![2, 3], 7.0).unwrap();
        assert_eq!(t.data(), &[7.0; 6]);
    }
}