ghostflow_nn/
init.rs

1//! Weight initialization strategies
2
3use ghostflow_core::Tensor;
4use rand_distr::{Distribution, Normal, Uniform};
5
6/// Initialize tensor with Xavier/Glorot uniform
7pub fn xavier_uniform(shape: &[usize], fan_in: usize, fan_out: usize) -> Tensor {
8    let bound = (6.0 / (fan_in + fan_out) as f32).sqrt();
9    let mut rng = rand::thread_rng();
10    let dist = Uniform::new(-bound, bound);
11    
12    let numel: usize = shape.iter().product();
13    let data: Vec<f32> = (0..numel).map(|_| dist.sample(&mut rng)).collect();
14    
15    Tensor::from_slice(&data, shape).unwrap()
16}
17
18/// Initialize tensor with Xavier/Glorot normal
19pub fn xavier_normal(shape: &[usize], fan_in: usize, fan_out: usize) -> Tensor {
20    let std = (2.0 / (fan_in + fan_out) as f32).sqrt();
21    let mut rng = rand::thread_rng();
22    let dist = Normal::new(0.0, std).unwrap();
23    
24    let numel: usize = shape.iter().product();
25    let data: Vec<f32> = (0..numel).map(|_| dist.sample(&mut rng)).collect();
26    
27    Tensor::from_slice(&data, shape).unwrap()
28}
29
30/// Initialize tensor with Kaiming/He uniform (for ReLU)
31pub fn kaiming_uniform(shape: &[usize], fan_in: usize) -> Tensor {
32    let bound = (6.0 / fan_in as f32).sqrt();
33    let mut rng = rand::thread_rng();
34    let dist = Uniform::new(-bound, bound);
35    
36    let numel: usize = shape.iter().product();
37    let data: Vec<f32> = (0..numel).map(|_| dist.sample(&mut rng)).collect();
38    
39    Tensor::from_slice(&data, shape).unwrap()
40}
41
42/// Initialize tensor with Kaiming/He normal (for ReLU)
43pub fn kaiming_normal(shape: &[usize], fan_in: usize) -> Tensor {
44    let std = (2.0 / fan_in as f32).sqrt();
45    let mut rng = rand::thread_rng();
46    let dist = Normal::new(0.0, std).unwrap();
47    
48    let numel: usize = shape.iter().product();
49    let data: Vec<f32> = (0..numel).map(|_| dist.sample(&mut rng)).collect();
50    
51    Tensor::from_slice(&data, shape).unwrap()
52}
53
54/// Initialize tensor with uniform distribution
55pub fn uniform(shape: &[usize], low: f32, high: f32) -> Tensor {
56    let mut rng = rand::thread_rng();
57    let dist = Uniform::new(low, high);
58    
59    let numel: usize = shape.iter().product();
60    let data: Vec<f32> = (0..numel).map(|_| dist.sample(&mut rng)).collect();
61    
62    Tensor::from_slice(&data, shape).unwrap()
63}
64
65/// Initialize tensor with normal distribution
66pub fn normal(shape: &[usize], mean: f32, std: f32) -> Tensor {
67    let mut rng = rand::thread_rng();
68    let dist = Normal::new(mean, std).unwrap();
69    
70    let numel: usize = shape.iter().product();
71    let data: Vec<f32> = (0..numel).map(|_| dist.sample(&mut rng)).collect();
72    
73    Tensor::from_slice(&data, shape).unwrap()
74}
75
76/// Initialize tensor with constant value
77pub fn constant(shape: &[usize], value: f32) -> Tensor {
78    Tensor::full(shape, value)
79}
80
81/// Initialize tensor with zeros
82pub fn zeros(shape: &[usize]) -> Tensor {
83    Tensor::zeros(shape)
84}
85
86/// Initialize tensor with ones
87pub fn ones(shape: &[usize]) -> Tensor {
88    Tensor::ones(shape)
89}