axonml_nn/init.rs

//! Weight Initialization - Parameter Initialization Strategies
//!
//! Provides various weight initialization strategies for neural networks.
//! Proper initialization is crucial for training deep networks.
//!
//! @version 0.1.0
//! @author AutomataNexus Development Team

use axonml_tensor::Tensor;
use rand::Rng;

// =============================================================================
// Basic Initializers
// =============================================================================

/// Creates a tensor filled with zeros.
pub fn zeros(shape: &[usize]) -> Tensor<f32> {
    axonml_tensor::zeros(shape)
}

/// Creates a tensor filled with ones.
pub fn ones(shape: &[usize]) -> Tensor<f32> {
    axonml_tensor::ones(shape)
}

/// Creates a tensor filled with a constant value.
pub fn constant(shape: &[usize], value: f32) -> Tensor<f32> {
    axonml_tensor::full(shape, value)
}

// =============================================================================
// Random Initializers
// =============================================================================

/// Creates a tensor with uniform random values in [0, 1).
pub fn uniform(shape: &[usize]) -> Tensor<f32> {
    axonml_tensor::rand(shape)
}

/// Creates a tensor with uniform random values in [low, high).
pub fn uniform_range(shape: &[usize], low: f32, high: f32) -> Tensor<f32> {
    let mut rng = rand::thread_rng();
    let numel: usize = shape.iter().product();
    let data: Vec<f32> = (0..numel).map(|_| rng.gen_range(low..high)).collect();
    Tensor::from_vec(data, shape).unwrap()
}

/// Creates a tensor with standard normal random values (mean=0, std=1).
pub fn randn(shape: &[usize]) -> Tensor<f32> {
    axonml_tensor::randn(shape)
}

/// Creates a tensor with normal random values (specified mean and std).
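///
/// # Examples
///
/// A minimal sketch; the `axonml_nn::init` module path is assumed from this
/// file's location.
///
/// ```no_run
/// use axonml_nn::init::normal;
///
/// // Standard normal samples rescaled to mean 5.0, std 0.1.
/// let t = normal(&[2, 3], 5.0, 0.1);
/// assert_eq!(t.shape(), &[2, 3]);
/// ```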
pub fn normal(shape: &[usize], mean: f32, std: f32) -> Tensor<f32> {
    let base = axonml_tensor::randn(shape);
    base.mul_scalar(std).add_scalar(mean)
}

// =============================================================================
// Xavier/Glorot Initialization
// =============================================================================

/// Xavier uniform initialization.
///
/// Designed for layers with tanh or sigmoid activations.
/// Samples from U(-a, a) where a = sqrt(6 / (fan_in + fan_out)).
///
/// # Arguments
/// * `fan_in` - Number of input units
/// * `fan_out` - Number of output units
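///
/// # Examples
///
/// A minimal sketch; note that the returned tensor has shape
/// `[fan_out, fan_in]`. The `axonml_nn::init` module path is assumed.
///
/// ```no_run
/// use axonml_nn::init::xavier_uniform;
///
/// // Weight matrix for a layer mapping 128 inputs to 64 outputs.
/// let w = xavier_uniform(128, 64);
/// assert_eq!(w.shape(), &[64, 128]);
/// ```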
pub fn xavier_uniform(fan_in: usize, fan_out: usize) -> Tensor<f32> {
    let a = (6.0 / (fan_in + fan_out) as f32).sqrt();
    uniform_range(&[fan_out, fan_in], -a, a)
}

/// Xavier normal initialization.
///
/// Designed for layers with tanh or sigmoid activations.
/// Samples from N(0, std) where std = sqrt(2 / (fan_in + fan_out)).
///
/// # Arguments
/// * `fan_in` - Number of input units
/// * `fan_out` - Number of output units
pub fn xavier_normal(fan_in: usize, fan_out: usize) -> Tensor<f32> {
    let std = (2.0 / (fan_in + fan_out) as f32).sqrt();
    normal(&[fan_out, fan_in], 0.0, std)
}

/// Alias for `xavier_uniform`.
pub fn glorot_uniform(fan_in: usize, fan_out: usize) -> Tensor<f32> {
    xavier_uniform(fan_in, fan_out)
}

/// Alias for `xavier_normal`.
pub fn glorot_normal(fan_in: usize, fan_out: usize) -> Tensor<f32> {
    xavier_normal(fan_in, fan_out)
}

// =============================================================================
// Kaiming/He Initialization
// =============================================================================

/// Kaiming uniform initialization.
///
/// Designed for layers with ReLU activations.
/// Samples from U(-bound, bound) where bound = sqrt(6 / fan_in).
///
/// # Arguments
/// * `fan_out` - Number of output units
/// * `fan_in` - Number of input units
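///
/// # Examples
///
/// A minimal sketch; note that, unlike `xavier_uniform`, the parameters here
/// are ordered `(fan_out, fan_in)`. The `axonml_nn::init` module path is
/// assumed.
///
/// ```no_run
/// use axonml_nn::init::kaiming_uniform;
///
/// // Weight matrix for a ReLU layer mapping 128 inputs to 64 outputs.
/// let w = kaiming_uniform(64, 128);
/// assert_eq!(w.shape(), &[64, 128]);
/// ```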
pub fn kaiming_uniform(fan_out: usize, fan_in: usize) -> Tensor<f32> {
    let bound = (6.0 / fan_in as f32).sqrt();
    uniform_range(&[fan_out, fan_in], -bound, bound)
}

/// Kaiming normal initialization.
///
/// Designed for layers with ReLU activations.
/// Samples from N(0, std) where std = sqrt(2 / fan_in).
///
/// # Arguments
/// * `fan_out` - Number of output units
/// * `fan_in` - Number of input units
pub fn kaiming_normal(fan_out: usize, fan_in: usize) -> Tensor<f32> {
    let std = (2.0 / fan_in as f32).sqrt();
    normal(&[fan_out, fan_in], 0.0, std)
}

/// Alias for `kaiming_uniform`.
pub fn he_uniform(fan_out: usize, fan_in: usize) -> Tensor<f32> {
    kaiming_uniform(fan_out, fan_in)
}

/// Alias for `kaiming_normal`.
pub fn he_normal(fan_out: usize, fan_in: usize) -> Tensor<f32> {
    kaiming_normal(fan_out, fan_in)
}

// =============================================================================
// Other Initializers
// =============================================================================

/// Orthogonal initialization.
///
/// Creates a (semi-)orthogonal matrix by orthonormalizing a random matrix
/// with modified Gram-Schmidt; a full implementation would use the QR
/// decomposition of a Gaussian matrix. Commonly used for recurrent weights
/// in RNNs.
///
/// # Arguments
/// * `rows` - Number of rows
/// * `cols` - Number of columns
/// * `gain` - Multiplicative scaling factor (use 1.0 for no scaling)
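///
/// # Examples
///
/// A minimal sketch checking that two rows come out (near-)orthogonal; the
/// `axonml_nn::init` module path is assumed.
///
/// ```no_run
/// use axonml_nn::init::orthogonal;
///
/// let w = orthogonal(4, 8, 1.0);
/// let data = w.to_vec();
/// // The dot product of the first two rows should be ~0.
/// let dot: f32 = (0..8).map(|k| data[k] * data[8 + k]).sum();
/// assert!(dot.abs() < 1e-4);
/// ```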
pub fn orthogonal(rows: usize, cols: usize, gain: f32) -> Tensor<f32> {
    // Start from a random matrix and orthonormalize its rows with modified
    // Gram-Schmidt. A full implementation would QR-decompose a Gaussian matrix.
    let mut data = vec![0.0f32; rows * cols];
    let mut rng = rand::thread_rng();

    // Generate the random starting matrix.
    for val in data.iter_mut() {
        *val = rng.gen_range(-1.0..1.0);
    }

    // Only the first min(rows, cols) rows can be mutually orthogonal.
    for i in 0..rows.min(cols) {
        // Subtract the projections onto the previous (already unit-norm) rows.
        for j in 0..i {
            let dot: f32 = (0..cols).map(|k| data[i * cols + k] * data[j * cols + k]).sum();
            for k in 0..cols {
                data[i * cols + k] -= dot * data[j * cols + k];
            }
        }

        // Normalize the row.
        let norm: f32 = (0..cols).map(|k| data[i * cols + k] * data[i * cols + k]).sum::<f32>().sqrt();
        if norm > 1e-8 {
            for k in 0..cols {
                data[i * cols + k] /= norm;
            }
        }
    }

    // Apply the gain after orthonormalization.
    for val in data.iter_mut() {
        *val *= gain;
    }

    Tensor::from_vec(data, &[rows, cols]).unwrap()
}

/// Sparse initialization.
///
/// Creates a matrix where, in each column, a `sparsity` fraction of the
/// entries is non-zero and the rest are zero.
///
/// # Arguments
/// * `rows` - Number of rows
/// * `cols` - Number of columns
/// * `sparsity` - Fraction of non-zero entries per column
/// * `std` - Standard deviation of the non-zero entries
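///
/// # Examples
///
/// A minimal sketch; the `axonml_nn::init` module path is assumed.
///
/// ```no_run
/// use axonml_nn::init::sparse;
///
/// // 10x4 matrix with ceil(10 * 0.25) = 3 non-zero entries per column.
/// let w = sparse(10, 4, 0.25, 0.01);
/// let nonzero = w.to_vec().iter().filter(|&&x| x != 0.0).count();
/// assert!(nonzero <= 12);
/// ```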
pub fn sparse(rows: usize, cols: usize, sparsity: f32, std: f32) -> Tensor<f32> {
    let mut data = vec![0.0f32; rows * cols];
    let mut rng = rand::thread_rng();

    let num_nonzero = ((rows as f32 * sparsity).ceil() as usize).min(rows);

    for col in 0..cols {
        // Randomly select which rows will be non-zero (partial Fisher-Yates shuffle).
        let mut indices: Vec<usize> = (0..rows).collect();
        for i in 0..num_nonzero {
            let j = rng.gen_range(i..rows);
            indices.swap(i, j);
        }

        // Fill the selected entries with N(0, std) samples via the Box-Muller transform.
        for &row in indices.iter().take(num_nonzero) {
            let u1: f32 = rng.gen::<f32>().max(f32::MIN_POSITIVE);
            let u2: f32 = rng.gen::<f32>();
            let val = (-2.0 * u1.ln()).sqrt() * (2.0 * std::f32::consts::PI * u2).cos();
            data[row * cols + col] = val * std;
        }
    }

    Tensor::from_vec(data, &[rows, cols]).unwrap()
}

/// Identity matrix initialization.
///
/// Creates a square identity matrix of the given size.
pub fn eye(size: usize) -> Tensor<f32> {
    axonml_tensor::eye(size)
}

/// Diagonal initialization.
///
/// Creates a square matrix with the given values on the diagonal and zeros elsewhere.
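///
/// # Examples
///
/// A minimal sketch; the `axonml_nn::init` module path is assumed.
///
/// ```no_run
/// use axonml_nn::init::diag;
///
/// let t = diag(&[1.0, 2.0]);
/// assert_eq!(t.shape(), &[2, 2]);
/// assert_eq!(t.to_vec(), vec![1.0, 0.0, 0.0, 2.0]);
/// ```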
pub fn diag(values: &[f32]) -> Tensor<f32> {
    let n = values.len();
    let mut data = vec![0.0f32; n * n];
    for (i, &val) in values.iter().enumerate() {
        data[i * n + i] = val;
    }
    Tensor::from_vec(data, &[n, n]).unwrap()
}

// =============================================================================
// Initialization Mode Enum
// =============================================================================

/// Initialization strategies as an enum for dynamic selection.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum InitMode {
    /// Zeros initialization.
    Zeros,
    /// Ones initialization.
    Ones,
    /// Constant value initialization.
    Constant(f32),
    /// Uniform random initialization.
    Uniform,
    /// Uniform random in range.
    UniformRange(f32, f32),
    /// Normal distribution.
    Normal(f32, f32), // mean, std
    /// Xavier/Glorot uniform.
    XavierUniform,
    /// Xavier/Glorot normal.
    XavierNormal,
    /// Kaiming/He uniform.
    KaimingUniform,
    /// Kaiming/He normal.
    KaimingNormal,
    /// Orthogonal.
    Orthogonal(f32), // gain
}

impl InitMode {
    /// Initializes a tensor using this mode.
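    ///
    /// # Examples
    ///
    /// A minimal sketch; the `axonml_nn::init` module path is assumed.
    ///
    /// ```no_run
    /// use axonml_nn::init::InitMode;
    ///
    /// // Select a strategy at runtime; the result has shape [fan_out, fan_in].
    /// let w = InitMode::XavierNormal.init(64, 128);
    /// assert_eq!(w.shape(), &[64, 128]);
    /// ```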
    pub fn init(&self, fan_out: usize, fan_in: usize) -> Tensor<f32> {
        match self {
            InitMode::Zeros => zeros(&[fan_out, fan_in]),
            InitMode::Ones => ones(&[fan_out, fan_in]),
            InitMode::Constant(val) => constant(&[fan_out, fan_in], *val),
            InitMode::Uniform => uniform(&[fan_out, fan_in]),
            InitMode::UniformRange(low, high) => uniform_range(&[fan_out, fan_in], *low, *high),
            InitMode::Normal(mean, std) => normal(&[fan_out, fan_in], *mean, *std),
            InitMode::XavierUniform => xavier_uniform(fan_in, fan_out),
            InitMode::XavierNormal => xavier_normal(fan_in, fan_out),
            InitMode::KaimingUniform => kaiming_uniform(fan_out, fan_in),
            InitMode::KaimingNormal => kaiming_normal(fan_out, fan_in),
            InitMode::Orthogonal(gain) => orthogonal(fan_out, fan_in, *gain),
        }
    }
}

// =============================================================================
// Tests
// =============================================================================

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_zeros() {
        let t = zeros(&[2, 3]);
        assert_eq!(t.shape(), &[2, 3]);
        assert!(t.to_vec().iter().all(|&x| x == 0.0));
    }

    #[test]
    fn test_ones() {
        let t = ones(&[2, 3]);
        assert_eq!(t.shape(), &[2, 3]);
        assert!(t.to_vec().iter().all(|&x| x == 1.0));
    }

    #[test]
    fn test_uniform_range() {
        let t = uniform_range(&[100], 0.0, 1.0);
        let data = t.to_vec();
        assert!(data.iter().all(|&x| (0.0..1.0).contains(&x)));
    }

    #[test]
    fn test_xavier_uniform() {
        let t = xavier_uniform(100, 100);
        assert_eq!(t.shape(), &[100, 100]);
        let bound = (6.0 / 200.0_f32).sqrt();
        let data = t.to_vec();
        assert!(data.iter().all(|&x| x.abs() <= bound * 1.1)); // Small margin
    }

    #[test]
    fn test_kaiming_uniform() {
        let t = kaiming_uniform(100, 100);
        assert_eq!(t.shape(), &[100, 100]);
    }

    #[test]
    fn test_eye() {
        let t = eye(3);
        assert_eq!(t.shape(), &[3, 3]);
        let data = t.to_vec();
        assert_eq!(data, vec![1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0]);
    }

    #[test]
    fn test_init_mode() {
        let mode = InitMode::KaimingUniform;
        let t = mode.init(10, 5);
        assert_eq!(t.shape(), &[10, 5]);
    }
}