1use ghostflow_core::Tensor;
4use rand_distr::{Distribution, Normal, Uniform};
5
6pub fn xavier_uniform(shape: &[usize], fan_in: usize, fan_out: usize) -> Tensor {
8 let bound = (6.0 / (fan_in + fan_out) as f32).sqrt();
9 let mut rng = rand::thread_rng();
10 let dist = Uniform::new(-bound, bound);
11
12 let numel: usize = shape.iter().product();
13 let data: Vec<f32> = (0..numel).map(|_| dist.sample(&mut rng)).collect();
14
15 Tensor::from_slice(&data, shape).unwrap()
16}
17
18pub fn xavier_normal(shape: &[usize], fan_in: usize, fan_out: usize) -> Tensor {
20 let std = (2.0 / (fan_in + fan_out) as f32).sqrt();
21 let mut rng = rand::thread_rng();
22 let dist = Normal::new(0.0, std).unwrap();
23
24 let numel: usize = shape.iter().product();
25 let data: Vec<f32> = (0..numel).map(|_| dist.sample(&mut rng)).collect();
26
27 Tensor::from_slice(&data, shape).unwrap()
28}
29
30pub fn kaiming_uniform(shape: &[usize], fan_in: usize) -> Tensor {
32 let bound = (6.0 / fan_in as f32).sqrt();
33 let mut rng = rand::thread_rng();
34 let dist = Uniform::new(-bound, bound);
35
36 let numel: usize = shape.iter().product();
37 let data: Vec<f32> = (0..numel).map(|_| dist.sample(&mut rng)).collect();
38
39 Tensor::from_slice(&data, shape).unwrap()
40}
41
42pub fn kaiming_normal(shape: &[usize], fan_in: usize) -> Tensor {
44 let std = (2.0 / fan_in as f32).sqrt();
45 let mut rng = rand::thread_rng();
46 let dist = Normal::new(0.0, std).unwrap();
47
48 let numel: usize = shape.iter().product();
49 let data: Vec<f32> = (0..numel).map(|_| dist.sample(&mut rng)).collect();
50
51 Tensor::from_slice(&data, shape).unwrap()
52}
53
54pub fn uniform(shape: &[usize], low: f32, high: f32) -> Tensor {
56 let mut rng = rand::thread_rng();
57 let dist = Uniform::new(low, high);
58
59 let numel: usize = shape.iter().product();
60 let data: Vec<f32> = (0..numel).map(|_| dist.sample(&mut rng)).collect();
61
62 Tensor::from_slice(&data, shape).unwrap()
63}
64
65pub fn normal(shape: &[usize], mean: f32, std: f32) -> Tensor {
67 let mut rng = rand::thread_rng();
68 let dist = Normal::new(mean, std).unwrap();
69
70 let numel: usize = shape.iter().product();
71 let data: Vec<f32> = (0..numel).map(|_| dist.sample(&mut rng)).collect();
72
73 Tensor::from_slice(&data, shape).unwrap()
74}
75
76pub fn constant(shape: &[usize], value: f32) -> Tensor {
78 Tensor::full(shape, value)
79}
80
81pub fn zeros(shape: &[usize]) -> Tensor {
83 Tensor::zeros(shape)
84}
85
86pub fn ones(shape: &[usize]) -> Tensor {
88 Tensor::ones(shape)
89}