use crate::tensor::{Device, Result, Tensor, TensorOptions};
/// Kaiming (He) uniform initializer.
///
/// Builds a `Float32` tensor of the given `shape` on `device` whose values are
/// produced as `rand * (2 * bound) - bound`, where
///
/// ```text
/// gain  = sqrt(2 / (1 + a^2))
/// std   = gain / sqrt(fan_in)
/// bound = sqrt(3) * std
/// ```
///
/// Assuming [`Tensor::rand`] samples uniformly from `[0, 1)`, the result is
/// uniform over `[-bound, bound)`.
///
/// # Parameters
///
/// * `shape`  - dimensions of the tensor to create.
/// * `fan_in` - fan-in used to scale the distribution. It is not validated
///   here; a non-positive value yields a non-finite `bound`.
/// * `a`      - slope term entering the gain (presumably the negative slope of
///   the following leaky-ReLU, as in the usual Kaiming formulation — confirm
///   with callers).
/// * `device` - device the tensor is allocated on.
///
/// # Errors
///
/// Propagates any error reported by `Tensor::rand`, `mul_scalar` or
/// `add_scalar`.
pub fn kaiming_uniform(shape: &[i64], fan_in: i64, a: f64, device: Device) -> Result<Tensor> {
    let options = TensorOptions {
        dtype: crate::tensor::DType::Float32,
        device,
    };

    let gain = f64::sqrt(2.0 / (1.0 + a * a));
    let sigma = gain / f64::sqrt(fan_in as f64);
    let limit = f64::sqrt(3.0) * sigma;

    // Stretch the raw samples to a width of `2 * limit`, then centre on zero.
    let width = 2.0 * limit;
    let offset = -limit;
    Tensor::rand(shape, options)?
        .mul_scalar(width)?
        .add_scalar(offset)
}
/// Kaiming (He) normal initializer.
///
/// Builds a `Float32` tensor of the given `shape` on `device` by drawing from
/// [`Tensor::randn`] and multiplying every element by
///
/// ```text
/// std = sqrt(2 / (1 + a^2)) / sqrt(fan_in)
/// ```
///
/// Assuming `Tensor::randn` samples a standard normal distribution, the result
/// is zero-mean with standard deviation `std`.
///
/// # Parameters
///
/// * `shape`  - dimensions of the tensor to create.
/// * `fan_in` - fan-in used to scale the distribution. It is not validated
///   here; a non-positive value yields a non-finite `std`.
/// * `a`      - slope term entering the gain (presumably the negative slope of
///   the following leaky-ReLU — confirm with callers).
/// * `device` - device the tensor is allocated on.
///
/// # Errors
///
/// Propagates any error reported by `Tensor::randn` or `mul_scalar`.
pub fn kaiming_normal(shape: &[i64], fan_in: i64, a: f64, device: Device) -> Result<Tensor> {
    let options = TensorOptions {
        dtype: crate::tensor::DType::Float32,
        device,
    };

    let gain = f64::sqrt(2.0 / (1.0 + a * a));
    let sigma = gain / f64::sqrt(fan_in as f64);

    Tensor::randn(shape, options)?.mul_scalar(sigma)
}
/// Xavier (Glorot) uniform initializer.
///
/// Builds a `Float32` tensor of the given `shape` on `device` whose values are
/// produced as `rand * (2 * bound) - bound`, with
///
/// ```text
/// bound = sqrt(6 / (fan_in + fan_out))
/// ```
///
/// Assuming [`Tensor::rand`] samples uniformly from `[0, 1)`, the result is
/// uniform over `[-bound, bound)`.
///
/// # Parameters
///
/// * `shape`   - dimensions of the tensor to create.
/// * `fan_in`  - fan-in term of the scale.
/// * `fan_out` - fan-out term of the scale.
/// * `device`  - device the tensor is allocated on.
///
/// `fan_in` and `fan_out` are summed as integers before conversion to `f64`
/// and are not validated; a non-positive sum yields a non-finite `bound`.
///
/// # Errors
///
/// Propagates any error reported by `Tensor::rand`, `mul_scalar` or
/// `add_scalar`.
pub fn xavier_uniform(shape: &[i64], fan_in: i64, fan_out: i64, device: Device) -> Result<Tensor> {
    let options = TensorOptions {
        dtype: crate::tensor::DType::Float32,
        device,
    };

    let fan_total = (fan_in + fan_out) as f64;
    let limit = f64::sqrt(6.0 / fan_total);

    // Stretch the raw samples to a width of `2 * limit`, then centre on zero.
    let width = 2.0 * limit;
    let offset = -limit;
    Tensor::rand(shape, options)?
        .mul_scalar(width)?
        .add_scalar(offset)
}
/// Xavier (Glorot) normal initializer.
///
/// Builds a `Float32` tensor of the given `shape` on `device` by drawing from
/// [`Tensor::randn`] and multiplying every element by
///
/// ```text
/// std = sqrt(2 / (fan_in + fan_out))
/// ```
///
/// Assuming `Tensor::randn` samples a standard normal distribution, the result
/// is zero-mean with standard deviation `std`.
///
/// # Parameters
///
/// * `shape`   - dimensions of the tensor to create.
/// * `fan_in`  - fan-in term of the scale.
/// * `fan_out` - fan-out term of the scale.
/// * `device`  - device the tensor is allocated on.
///
/// `fan_in` and `fan_out` are summed as integers before conversion to `f64`
/// and are not validated; a non-positive sum yields a non-finite `std`.
///
/// # Errors
///
/// Propagates any error reported by `Tensor::randn` or `mul_scalar`.
pub fn xavier_normal(shape: &[i64], fan_in: i64, fan_out: i64, device: Device) -> Result<Tensor> {
    let options = TensorOptions {
        dtype: crate::tensor::DType::Float32,
        device,
    };

    let fan_total = (fan_in + fan_out) as f64;
    let sigma = f64::sqrt(2.0 / fan_total);

    Tensor::randn(shape, options)?.mul_scalar(sigma)
}
/// Uniform bias initializer scaled by fan-in.
///
/// Builds a `Float32` tensor of the given `shape` on `device` whose values are
/// produced as `rand * (2 * bound) - bound`, with
///
/// ```text
/// bound = 1 / sqrt(fan_in)
/// ```
///
/// Assuming [`Tensor::rand`] samples uniformly from `[0, 1)`, the result is
/// uniform over `[-bound, bound)`.
///
/// Note that, unlike the other initializers in this file, `fan_in` comes
/// *before* `shape` in the parameter list.
///
/// # Parameters
///
/// * `fan_in` - fan-in used to scale the distribution. It is not validated
///   here; a non-positive value yields a non-finite `bound`.
/// * `shape`  - dimensions of the tensor to create.
/// * `device` - device the tensor is allocated on.
///
/// # Errors
///
/// Propagates any error reported by `Tensor::rand`, `mul_scalar` or
/// `add_scalar`.
pub fn uniform_bias(fan_in: i64, shape: &[i64], device: Device) -> Result<Tensor> {
    let options = TensorOptions {
        dtype: crate::tensor::DType::Float32,
        device,
    };

    let fan_root = f64::sqrt(fan_in as f64);
    let limit = 1.0 / fan_root;

    // Stretch the raw samples to a width of `2 * limit`, then centre on zero.
    let width = 2.0 * limit;
    let offset = -limit;
    Tensor::rand(shape, options)?
        .mul_scalar(width)?
        .add_scalar(offset)
}