1use axonml_tensor::Tensor;
10use rand::Rng;
11
12pub fn zeros(shape: &[usize]) -> Tensor<f32> {
18 axonml_tensor::zeros(shape)
19}
20
21pub fn ones(shape: &[usize]) -> Tensor<f32> {
23 axonml_tensor::ones(shape)
24}
25
26pub fn constant(shape: &[usize], value: f32) -> Tensor<f32> {
28 axonml_tensor::full(shape, value)
29}
30
31pub fn uniform(shape: &[usize]) -> Tensor<f32> {
37 axonml_tensor::rand(shape)
38}
39
40pub fn uniform_range(shape: &[usize], low: f32, high: f32) -> Tensor<f32> {
42 let mut rng = rand::thread_rng();
43 let numel: usize = shape.iter().product();
44 let data: Vec<f32> = (0..numel).map(|_| rng.gen_range(low..high)).collect();
45 Tensor::from_vec(data, shape).unwrap()
46}
47
48pub fn randn(shape: &[usize]) -> Tensor<f32> {
50 axonml_tensor::randn(shape)
51}
52
53pub fn normal(shape: &[usize], mean: f32, std: f32) -> Tensor<f32> {
55 let base = axonml_tensor::randn(shape);
56 base.mul_scalar(std).add_scalar(mean)
57}
58
59pub fn xavier_uniform(fan_in: usize, fan_out: usize) -> Tensor<f32> {
72 let a = (6.0 / (fan_in + fan_out) as f32).sqrt();
73 uniform_range(&[fan_out, fan_in], -a, a)
74}
75
76pub fn xavier_normal(fan_in: usize, fan_out: usize) -> Tensor<f32> {
85 let std = (2.0 / (fan_in + fan_out) as f32).sqrt();
86 normal(&[fan_out, fan_in], 0.0, std)
87}
88
89pub fn glorot_uniform(fan_in: usize, fan_out: usize) -> Tensor<f32> {
91 xavier_uniform(fan_in, fan_out)
92}
93
94pub fn glorot_normal(fan_in: usize, fan_out: usize) -> Tensor<f32> {
96 xavier_normal(fan_in, fan_out)
97}
98
99pub fn kaiming_uniform(fan_out: usize, fan_in: usize) -> Tensor<f32> {
112 let bound = (6.0 / fan_in as f32).sqrt();
113 uniform_range(&[fan_out, fan_in], -bound, bound)
114}
115
116pub fn kaiming_normal(fan_out: usize, fan_in: usize) -> Tensor<f32> {
125 let std = (2.0 / fan_in as f32).sqrt();
126 normal(&[fan_out, fan_in], 0.0, std)
127}
128
129pub fn he_uniform(fan_out: usize, fan_in: usize) -> Tensor<f32> {
131 kaiming_uniform(fan_out, fan_in)
132}
133
134pub fn he_normal(fan_out: usize, fan_in: usize) -> Tensor<f32> {
136 kaiming_normal(fan_out, fan_in)
137}
138
139pub fn orthogonal(rows: usize, cols: usize, gain: f32) -> Tensor<f32> {
153 let mut data = vec![0.0f32; rows * cols];
156 let mut rng = rand::thread_rng();
157
158 for val in data.iter_mut() {
160 *val = rng.gen_range(-1.0..1.0);
161 }
162
163 for i in 0..rows.min(cols) {
166 let start = i * cols;
167 let end = start + cols;
168 let row = &mut data[start..end];
169
170 let norm: f32 = row.iter().map(|x| x * x).sum::<f32>().sqrt();
172 if norm > 1e-8 {
173 for val in row.iter_mut() {
174 *val = (*val / norm) * gain;
175 }
176 }
177 }
178
179 Tensor::from_vec(data, &[rows, cols]).unwrap()
180}
181
182pub fn sparse(rows: usize, cols: usize, sparsity: f32, std: f32) -> Tensor<f32> {
192 let mut data = vec![0.0f32; rows * cols];
193 let mut rng = rand::thread_rng();
194
195 let num_nonzero = (rows as f32 * sparsity).ceil() as usize;
196
197 for col in 0..cols {
198 let mut indices: Vec<usize> = (0..rows).collect();
200 for i in 0..num_nonzero.min(rows) {
201 let j = rng.gen_range(i..rows);
202 indices.swap(i, j);
203 }
204
205 for &row in indices.iter().take(num_nonzero) {
207 let val: f32 = rng.gen::<f32>() * 2.0 - 1.0; data[row * cols + col] = val * std;
209 }
210 }
211
212 Tensor::from_vec(data, &[rows, cols]).unwrap()
213}
214
215pub fn eye(size: usize) -> Tensor<f32> {
219 axonml_tensor::eye(size)
220}
221
222pub fn diag(values: &[f32]) -> Tensor<f32> {
226 let n = values.len();
227 let mut data = vec![0.0f32; n * n];
228 for (i, &val) in values.iter().enumerate() {
229 data[i * n + i] = val;
230 }
231 Tensor::from_vec(data, &[n, n]).unwrap()
232}
233
/// Selectable weight-initialisation scheme.
///
/// [`InitMode::init`] applies it to produce a `[fan_out, fan_in]` tensor.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum InitMode {
    /// Every element `0.0` (via `zeros`).
    Zeros,
    /// Every element `1.0` (via `ones`).
    Ones,
    /// Every element equal to the given value (via `constant`).
    Constant(f32),
    /// Samples via `uniform`, i.e. `axonml_tensor::rand` (interval defined there).
    Uniform,
    /// Uniform samples in `[low, high)`; fields are `(low, high)`.
    UniformRange(f32, f32),
    /// Normal samples; fields are `(mean, std)`.
    Normal(f32, f32),
    /// Xavier/Glorot uniform: bound `sqrt(6 / (fan_in + fan_out))`.
    XavierUniform,
    /// Xavier/Glorot normal: std `sqrt(2 / (fan_in + fan_out))`.
    XavierNormal,
    /// Kaiming/He uniform: bound `sqrt(6 / fan_in)`.
    KaimingUniform,
    /// Kaiming/He normal: std `sqrt(2 / fan_in)`.
    KaimingNormal,
    /// Built by `orthogonal(fan_out, fan_in, gain)`; the field is the `gain`.
    Orthogonal(f32),
}
264
265impl InitMode {
266 pub fn init(&self, fan_out: usize, fan_in: usize) -> Tensor<f32> {
268 match self {
269 InitMode::Zeros => zeros(&[fan_out, fan_in]),
270 InitMode::Ones => ones(&[fan_out, fan_in]),
271 InitMode::Constant(val) => constant(&[fan_out, fan_in], *val),
272 InitMode::Uniform => uniform(&[fan_out, fan_in]),
273 InitMode::UniformRange(low, high) => uniform_range(&[fan_out, fan_in], *low, *high),
274 InitMode::Normal(mean, std) => normal(&[fan_out, fan_in], *mean, *std),
275 InitMode::XavierUniform => xavier_uniform(fan_in, fan_out),
276 InitMode::XavierNormal => xavier_normal(fan_in, fan_out),
277 InitMode::KaimingUniform => kaiming_uniform(fan_out, fan_in),
278 InitMode::KaimingNormal => kaiming_normal(fan_out, fan_in),
279 InitMode::Orthogonal(gain) => orthogonal(fan_out, fan_in, *gain),
280 }
281 }
282}
283
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_zeros() {
        let t = zeros(&[2, 3]);
        assert_eq!(t.shape(), &[2, 3]);
        assert!(t.to_vec().iter().all(|&x| x == 0.0));
    }

    #[test]
    fn test_ones() {
        let t = ones(&[2, 3]);
        assert_eq!(t.shape(), &[2, 3]);
        assert!(t.to_vec().iter().all(|&x| x == 1.0));
    }

    #[test]
    fn test_uniform_range() {
        let t = uniform_range(&[100], 0.0, 1.0);
        let data = t.to_vec();
        assert!(data.iter().all(|&x| (0.0..1.0).contains(&x)));
    }

    #[test]
    fn test_xavier_uniform() {
        let t = xavier_uniform(100, 100);
        assert_eq!(t.shape(), &[100, 100]);
        let bound = (6.0 / 200.0_f32).sqrt();
        let data = t.to_vec();
        assert!(data.iter().all(|&x| x.abs() <= bound * 1.1));
    }

    #[test]
    fn test_kaiming_uniform() {
        let t = kaiming_uniform(100, 100);
        assert_eq!(t.shape(), &[100, 100]);
        // Samples come from [-b, b) with b = sqrt(6 / fan_in).
        let bound = (6.0 / 100.0_f32).sqrt();
        assert!(t.to_vec().iter().all(|&x| x.abs() <= bound));
    }

    #[test]
    fn test_eye() {
        let t = eye(3);
        assert_eq!(t.shape(), &[3, 3]);
        let data = t.to_vec();
        assert_eq!(data, vec![1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0]);
    }

    #[test]
    fn test_diag() {
        let t = diag(&[1.0, 2.0, 3.0]);
        assert_eq!(t.shape(), &[3, 3]);
        assert_eq!(
            t.to_vec(),
            vec![1.0, 0.0, 0.0, 0.0, 2.0, 0.0, 0.0, 0.0, 3.0]
        );
    }

    #[test]
    fn test_sparse_nonzeros_per_column() {
        let (rows, cols) = (10, 4);
        let t = sparse(rows, cols, 0.5, 0.1);
        assert_eq!(t.shape(), &[rows, cols]);
        let data = t.to_vec();
        // ceil(10 * 0.5) = 5 entries per column may be non-zero.
        for col in 0..cols {
            let nonzero = (0..rows).filter(|&r| data[r * cols + col] != 0.0).count();
            assert!(nonzero <= 5);
        }
        assert!(data.iter().all(|&x| x.abs() <= 0.1));
    }

    #[test]
    fn test_orthogonal_rows_have_unit_norm() {
        let (rows, cols) = (3, 5);
        let t = orthogonal(rows, cols, 1.0);
        assert_eq!(t.shape(), &[rows, cols]);
        let data = t.to_vec();
        for row in data.chunks(cols) {
            let norm = row.iter().map(|x| x * x).sum::<f32>().sqrt();
            assert!((norm - 1.0).abs() < 1e-4);
        }
    }

    #[test]
    fn test_init_mode() {
        let mode = InitMode::KaimingUniform;
        let t = mode.init(10, 5);
        assert_eq!(t.shape(), &[10, 5]);
    }

    #[test]
    fn test_init_mode_orthogonal_shape() {
        let t = InitMode::Orthogonal(1.0).init(4, 6);
        assert_eq!(t.shape(), &[4, 6]);
    }
}