use axonml_tensor::Tensor;
use rand::Rng;
/// Creates an `f32` tensor of the given `shape` with every element set to zero.
///
/// Thin convenience wrapper that forwards to `axonml_tensor::zeros`.
pub fn zeros(shape: &[usize]) -> Tensor<f32> {
    axonml_tensor::zeros(shape)
}
/// Creates an `f32` tensor of the given `shape` with every element set to one.
///
/// Thin convenience wrapper that forwards to `axonml_tensor::ones`.
pub fn ones(shape: &[usize]) -> Tensor<f32> {
    axonml_tensor::ones(shape)
}
/// Creates an `f32` tensor of the given `shape` filled with `value`.
///
/// Forwards `shape` and `value` unchanged to `axonml_tensor::full`.
pub fn constant(shape: &[usize], value: f32) -> Tensor<f32> {
    axonml_tensor::full(shape, value)
}
/// Creates a randomly initialised `f32` tensor of the given `shape`.
///
/// Forwards to `axonml_tensor::rand`.
/// NOTE(review): presumably samples uniformly from `[0, 1)` — confirm against
/// the `axonml_tensor::rand` documentation.
pub fn uniform(shape: &[usize]) -> Tensor<f32> {
    axonml_tensor::rand(shape)
}
/// Creates an `f32` tensor of the given `shape` with elements drawn uniformly
/// from the half-open interval `[low, high)`.
///
/// When `low == high` the interval is degenerate and every element is set to
/// that single value instead of panicking inside the random number generator.
///
/// # Panics
///
/// * If the tensor is non-empty and `low..high` is not a valid sampling range
///   (for example `low > high` or a NaN bound); `Rng::gen_range` rejects it.
/// * If `Tensor::from_vec` returns an error for the generated buffer.
pub fn uniform_range(shape: &[usize], low: f32, high: f32) -> Tensor<f32> {
    let numel: usize = shape.iter().product();
    let data: Vec<f32> = if low == high {
        // `gen_range` panics on an empty range; a zero-width interval has exactly
        // one admissible value, so fill with it directly.
        vec![low; numel]
    } else {
        let mut rng = rand::thread_rng();
        (0..numel).map(|_| rng.gen_range(low..high)).collect()
    };
    Tensor::from_vec(data, shape).unwrap()
}
/// Creates a randomly initialised `f32` tensor of the given `shape`.
///
/// Forwards to `axonml_tensor::randn`.
/// NOTE(review): presumably samples from the standard normal distribution
/// (mean 0, standard deviation 1) — confirm against `axonml_tensor::randn`.
pub fn randn(shape: &[usize]) -> Tensor<f32> {
    axonml_tensor::randn(shape)
}
/// Creates an `f32` tensor of the given `shape` by sampling with
/// `axonml_tensor::randn`, multiplying every element by `std` and then adding
/// `mean`.
///
/// Assuming `randn` yields standard-normal samples, the result follows a normal
/// distribution with the requested mean and standard deviation.
pub fn normal(shape: &[usize], mean: f32, std: f32) -> Tensor<f32> {
    let scaled = axonml_tensor::randn(shape).mul_scalar(std);
    scaled.add_scalar(mean)
}
/// Xavier (Glorot) uniform initialisation.
///
/// Returns a tensor of shape `[fan_out, fan_in]` whose elements are drawn via
/// [`uniform_range`] from `[-a, a)` with `a = sqrt(6 / (fan_in + fan_out))`.
pub fn xavier_uniform(fan_in: usize, fan_out: usize) -> Tensor<f32> {
    let fan_sum = (fan_in + fan_out) as f32;
    let limit = f32::sqrt(6.0 / fan_sum);
    uniform_range(&[fan_out, fan_in], -limit, limit)
}
/// Xavier (Glorot) normal initialisation.
///
/// Returns a tensor of shape `[fan_out, fan_in]` produced by [`normal`] with
/// mean `0` and standard deviation `sqrt(2 / (fan_in + fan_out))`.
pub fn xavier_normal(fan_in: usize, fan_out: usize) -> Tensor<f32> {
    let fan_sum = (fan_in + fan_out) as f32;
    let sigma = f32::sqrt(2.0 / fan_sum);
    normal(&[fan_out, fan_in], 0.0, sigma)
}
/// Alias for [`xavier_uniform`]; the arguments are forwarded unchanged.
pub fn glorot_uniform(fan_in: usize, fan_out: usize) -> Tensor<f32> {
    xavier_uniform(fan_in, fan_out)
}
/// Alias for [`xavier_normal`]; the arguments are forwarded unchanged.
pub fn glorot_normal(fan_in: usize, fan_out: usize) -> Tensor<f32> {
    xavier_normal(fan_in, fan_out)
}
/// Kaiming (He) uniform initialisation.
///
/// Returns a tensor of shape `[fan_out, fan_in]` whose elements are drawn via
/// [`uniform_range`] from `[-b, b)` with `b = sqrt(6 / fan_in)`.
///
/// Note the argument order is `(fan_out, fan_in)`, the reverse of
/// [`xavier_uniform`].
pub fn kaiming_uniform(fan_out: usize, fan_in: usize) -> Tensor<f32> {
    let limit = f32::sqrt(6.0 / fan_in as f32);
    uniform_range(&[fan_out, fan_in], -limit, limit)
}
/// Kaiming (He) normal initialisation.
///
/// Returns a tensor of shape `[fan_out, fan_in]` produced by [`normal`] with
/// mean `0` and standard deviation `sqrt(2 / fan_in)`.
///
/// Note the argument order is `(fan_out, fan_in)`, the reverse of
/// [`xavier_normal`].
pub fn kaiming_normal(fan_out: usize, fan_in: usize) -> Tensor<f32> {
    let sigma = f32::sqrt(2.0 / fan_in as f32);
    normal(&[fan_out, fan_in], 0.0, sigma)
}
/// Alias for [`kaiming_uniform`]; the arguments are forwarded unchanged.
pub fn he_uniform(fan_out: usize, fan_in: usize) -> Tensor<f32> {
    kaiming_uniform(fan_out, fan_in)
}
/// Alias for [`kaiming_normal`]; the arguments are forwarded unchanged.
pub fn he_normal(fan_out: usize, fan_in: usize) -> Tensor<f32> {
    kaiming_normal(fan_out, fan_in)
}
/// Creates a `rows x cols` (semi-)orthogonal matrix scaled by `gain`.
///
/// Random vectors with entries drawn uniformly from `[-1, 1)` are
/// orthonormalised with the Gram-Schmidt process and then multiplied by `gain`:
///
/// * when `rows <= cols`, the rows of the result are mutually orthogonal and
///   each has Euclidean norm `|gain|`;
/// * when `rows > cols`, the columns of the result have that property instead
///   (at most `min(rows, cols)` vectors of that length can be orthogonal).
///
/// # Panics
///
/// Panics if `Tensor::from_vec` returns an error for the generated buffer.
pub fn orthogonal(rows: usize, cols: usize, gain: f32) -> Tensor<f32> {
    let mut rng = rand::thread_rng();

    // Always orthonormalise `count` vectors of length `dim` with `count <= dim`;
    // for a tall matrix those vectors are its columns.
    let tall = rows > cols;
    let (count, dim) = if tall { (cols, rows) } else { (rows, cols) };

    // `basis` stores the `count` vectors back to back, `dim` values each.
    let mut basis = vec![0.0f32; count * dim];
    for k in 0..count {
        let (done, rest) = basis.split_at_mut(k * dim);
        let current = &mut rest[..dim];
        loop {
            for x in current.iter_mut() {
                *x = rng.gen_range(-1.0..1.0);
            }
            // Remove the components along the vectors already fixed. A second
            // pass mops up the rounding error left behind by the first one.
            for _ in 0..2 {
                for prev in done.chunks_exact(dim) {
                    let proj: f32 = prev.iter().zip(current.iter()).map(|(p, c)| p * c).sum();
                    for (c, p) in current.iter_mut().zip(prev) {
                        *c -= proj * p;
                    }
                }
            }
            let norm = current.iter().map(|x| x * x).sum::<f32>().sqrt();
            // A (near-)zero residual means the draw was almost linearly dependent
            // on the previous vectors; redraw instead of amplifying noise.
            if norm > 1e-6 {
                for c in current.iter_mut() {
                    *c /= norm;
                }
                break;
            }
        }
    }

    let data: Vec<f32> = if tall {
        // The basis vectors are the columns of the result, so transpose them
        // into the row-major output layout.
        (0..rows * cols)
            .map(|idx| basis[(idx % cols) * dim + idx / cols] * gain)
            .collect()
    } else {
        basis.into_iter().map(|x| x * gain).collect()
    };
    Tensor::from_vec(data, &[rows, cols]).unwrap()
}
/// Creates a `rows x cols` tensor in which only a random subset of the entries
/// of each column is non-zero.
///
/// For every column, `ceil(rows * sparsity)` distinct rows (capped at `rows`)
/// are picked with a partial Fisher-Yates shuffle, and each picked entry is set
/// to a uniform sample from `[-1, 1)` multiplied by `std`. All other entries
/// stay `0.0`.
///
/// NOTE(review): as implemented, `sparsity` is the fraction of entries per
/// column that become NON-zero, and `std` scales a uniform sample rather than
/// acting as a standard deviation — confirm that this is the intended contract.
///
/// # Panics
///
/// Panics if `Tensor::from_vec` returns an error for the generated buffer.
pub fn sparse(rows: usize, cols: usize, sparsity: f32, std: f32) -> Tensor<f32> {
    let mut rng = rand::thread_rng();
    let requested = (rows as f32 * sparsity).ceil() as usize;
    // Never pick more entries than a column has.
    let picks = requested.min(rows);
    let mut buffer = vec![0.0f32; rows * cols];

    for c in 0..cols {
        // Partial Fisher-Yates: after the loop the first `picks` slots hold a
        // uniformly random selection of distinct row indices.
        let mut order: Vec<usize> = (0..rows).collect();
        for slot in 0..picks {
            let other = rng.gen_range(slot..rows);
            order.swap(slot, other);
        }
        for &r in &order[..picks] {
            let unit: f32 = rng.r#gen::<f32>() * 2.0 - 1.0;
            buffer[r * cols + c] = unit * std;
        }
    }

    Tensor::from_vec(buffer, &[rows, cols]).unwrap()
}
/// Creates the `size x size` identity matrix (ones on the main diagonal, zeros
/// elsewhere) by forwarding to `axonml_tensor::eye`.
pub fn eye(size: usize) -> Tensor<f32> {
    axonml_tensor::eye(size)
}
/// Creates a square matrix with `values` on the main diagonal and zeros
/// everywhere else.
///
/// The result has shape `[n, n]` where `n = values.len()`.
///
/// # Panics
///
/// Panics if `Tensor::from_vec` returns an error for the generated buffer.
pub fn diag(values: &[f32]) -> Tensor<f32> {
    let side = values.len();
    let mut buffer = vec![0.0f32; side * side];
    // In a row-major square matrix consecutive diagonal cells are `side + 1`
    // positions apart, starting at index 0.
    buffer
        .iter_mut()
        .step_by(side + 1)
        .zip(values)
        .for_each(|(cell, &v)| *cell = v);
    Tensor::from_vec(buffer, &[side, side]).unwrap()
}
/// Selects which initialisation routine [`InitMode::init`] uses to build a
/// `[fan_out, fan_in]` weight tensor.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum InitMode {
    /// Dispatches to [`zeros`].
    Zeros,
    /// Dispatches to [`ones`].
    Ones,
    /// Dispatches to [`constant`] with the carried value.
    Constant(f32),
    /// Dispatches to [`uniform`].
    Uniform,
    /// Dispatches to [`uniform_range`]; the fields are `(low, high)`.
    UniformRange(f32, f32),
    /// Dispatches to [`normal`]; the fields are `(mean, std)`.
    Normal(f32, f32),
    /// Dispatches to [`xavier_uniform`].
    XavierUniform,
    /// Dispatches to [`xavier_normal`].
    XavierNormal,
    /// Dispatches to [`kaiming_uniform`].
    KaimingUniform,
    /// Dispatches to [`kaiming_normal`].
    KaimingNormal,
    /// Dispatches to [`orthogonal`]; the field is the gain.
    Orthogonal(f32),
}
impl InitMode {
    /// Builds a `[fan_out, fan_in]` tensor using the initialisation routine
    /// selected by `self`.
    ///
    /// The Xavier variants receive their arguments as `(fan_in, fan_out)` to
    /// match the signatures of `xavier_uniform` / `xavier_normal`; every other
    /// routine is given `fan_out` first.
    pub fn init(&self, fan_out: usize, fan_in: usize) -> Tensor<f32> {
        let shape = [fan_out, fan_in];
        match *self {
            Self::Zeros => zeros(&shape),
            Self::Ones => ones(&shape),
            Self::Constant(value) => constant(&shape, value),
            Self::Uniform => uniform(&shape),
            Self::UniformRange(low, high) => uniform_range(&shape, low, high),
            Self::Normal(mean, std) => normal(&shape, mean, std),
            Self::XavierUniform => xavier_uniform(fan_in, fan_out),
            Self::XavierNormal => xavier_normal(fan_in, fan_out),
            Self::KaimingUniform => kaiming_uniform(fan_out, fan_in),
            Self::KaimingNormal => kaiming_normal(fan_out, fan_in),
            Self::Orthogonal(gain) => orthogonal(fan_out, fan_in, gain),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `zeros` yields the requested shape with every element equal to 0.
    #[test]
    fn test_zeros() {
        let tensor = zeros(&[2, 3]);
        assert_eq!(tensor.shape(), &[2, 3]);
        let values = tensor.to_vec();
        assert!(values.iter().all(|&v| v == 0.0));
    }

    /// `ones` yields the requested shape with every element equal to 1.
    #[test]
    fn test_ones() {
        let tensor = ones(&[2, 3]);
        assert_eq!(tensor.shape(), &[2, 3]);
        let values = tensor.to_vec();
        assert!(values.iter().all(|&v| v == 1.0));
    }

    /// Every sample of `uniform_range` lies inside the half-open interval.
    #[test]
    fn test_uniform_range() {
        let samples = uniform_range(&[100], 0.0, 1.0).to_vec();
        for sample in &samples {
            assert!((0.0..1.0).contains(sample));
        }
    }

    /// Xavier-uniform samples stay within the theoretical bound (10% slack).
    #[test]
    fn test_xavier_uniform() {
        let tensor = xavier_uniform(100, 100);
        assert_eq!(tensor.shape(), &[100, 100]);
        let limit = (6.0 / 200.0_f32).sqrt() * 1.1;
        let values = tensor.to_vec();
        assert!(values.iter().all(|&v| v.abs() <= limit));
    }

    /// Kaiming-uniform produces a `[fan_out, fan_in]` tensor.
    #[test]
    fn test_kaiming_uniform() {
        let tensor = kaiming_uniform(100, 100);
        assert_eq!(tensor.shape(), &[100, 100]);
    }

    /// `eye(3)` is the 3x3 identity in row-major order.
    #[test]
    fn test_eye() {
        let identity = eye(3);
        assert_eq!(identity.shape(), &[3, 3]);
        let expected = vec![
            1.0, 0.0, 0.0, //
            0.0, 1.0, 0.0, //
            0.0, 0.0, 1.0,
        ];
        assert_eq!(identity.to_vec(), expected);
    }

    /// `InitMode::init` produces a `[fan_out, fan_in]` tensor.
    #[test]
    fn test_init_mode() {
        let tensor = InitMode::KaimingUniform.init(10, 5);
        assert_eq!(tensor.shape(), &[10, 5]);
    }
}