use crate::error::{Error, Result};
use numr::dtype::DType;
use numr::runtime::Runtime;
use numr::tensor::Tensor;
/// Parameter-initialization schemes for tensors (weights, biases, embeddings).
///
/// Conventions used by the fan-based variants: the last dimension is treated
/// as fan-out and the product of the remaining dimensions as fan-in, i.e.
/// weights are laid out `[..fan_in, fan_out]`.
#[derive(Debug, Clone, Copy)]
pub enum Init {
    /// Fill with zeros.
    Zeros,
    /// Fill with ones.
    Ones,
    /// Fill every element with the given constant.
    Const(f32),
    /// Uniform samples in `[-bound, bound]`.
    Uniform(f32),
    /// PyTorch's default `nn.Linear` init: uniform in
    /// `[-1/sqrt(fan_in), 1/sqrt(fan_in)]`, with `fan_in` taken from the
    /// first dimension of the shape.
    PyTorchLinear,
    /// Embedding-weight init in the style of PyTorch's `nn.Embedding`;
    /// see [`Init::init_tensor`] for the exact distribution drawn.
    PyTorchEmbedding,
    /// Kaiming/He normal init: `N(0, sqrt(2 / fan_in))`.
    Kaiming,
    /// Xavier/Glorot normal init: `N(0, sqrt(2 / (fan_in + fan_out)))`.
    Xavier,
    /// Normal distribution with the given mean and standard deviation.
    Randn { mean: f64, stdev: f64 },
    /// Truncated-normal approximation: standard-normal samples clamped to
    /// `[-2, 2]`, then scaled by `stdev` and shifted by `mean`.
    TruncatedNormal { mean: f64, stdev: f64 },
}
impl Init {
pub fn init_tensor<R, C>(
&self,
shape: &[usize],
dtype: DType,
device: &R::Device,
client: &C,
) -> Result<Tensor<R>>
where
R: Runtime<DType = DType>,
C: numr::runtime::RuntimeClient<R>
+ numr::ops::RandomOps<R>
+ numr::ops::ScalarOps<R>
+ numr::ops::BinaryOps<R>
+ numr::ops::CompareOps<R>
+ numr::ops::TensorOps<R>,
{
match *self {
Init::Zeros => Ok(Tensor::zeros(shape, dtype, device)),
Init::Ones => Ok(Tensor::ones(shape, dtype, device)),
Init::Const(val) => {
let t = Tensor::zeros(shape, dtype, device);
client.add_scalar(&t, val as f64).map_err(Error::Numr)
}
Init::Uniform(bound) => {
let r = client.rand(shape, dtype).map_err(Error::Numr)?;
let scaled = client
.mul_scalar(&r, 2.0 * bound as f64)
.map_err(Error::Numr)?;
client
.add_scalar(&scaled, -(bound as f64))
.map_err(Error::Numr)
}
Init::PyTorchLinear => {
let fan_in = shape[0];
let bound = 1.0 / (fan_in as f64).sqrt();
let r = client.rand(shape, dtype).map_err(Error::Numr)?;
let scaled = client.mul_scalar(&r, 2.0 * bound).map_err(Error::Numr)?;
client.add_scalar(&scaled, -bound).map_err(Error::Numr)
}
Init::PyTorchEmbedding => {
let r = client.rand(shape, dtype).map_err(Error::Numr)?;
let scaled = client.mul_scalar(&r, 2.0).map_err(Error::Numr)?;
client.add_scalar(&scaled, -1.0).map_err(Error::Numr)
}
Init::Kaiming => {
let fan_in = if shape.len() >= 2 {
shape[..shape.len() - 1].iter().product::<usize>()
} else {
shape[0]
};
let std = (2.0 / fan_in as f64).sqrt();
let r = client.randn(shape, dtype).map_err(Error::Numr)?;
client.mul_scalar(&r, std).map_err(Error::Numr)
}
Init::Xavier => {
let (fan_in, fan_out) = if shape.len() >= 2 {
let fi = shape[..shape.len() - 1].iter().product::<usize>();
let fo = shape[shape.len() - 1];
(fi, fo)
} else {
(shape[0], shape[0])
};
let std = (2.0 / (fan_in + fan_out) as f64).sqrt();
let r = client.randn(shape, dtype).map_err(Error::Numr)?;
client.mul_scalar(&r, std).map_err(Error::Numr)
}
Init::Randn { mean, stdev } => {
let r = client.randn(shape, dtype).map_err(Error::Numr)?;
let scaled = client.mul_scalar(&r, stdev).map_err(Error::Numr)?;
if mean != 0.0 {
client.add_scalar(&scaled, mean).map_err(Error::Numr)
} else {
Ok(scaled)
}
}
Init::TruncatedNormal { mean, stdev } => {
let r = client.randn(shape, dtype).map_err(Error::Numr)?;
let clamped = client.clamp(&r, -2.0, 2.0).map_err(Error::Numr)?;
let scaled = client.mul_scalar(&clamped, stdev).map_err(Error::Numr)?;
if mean != 0.0 {
client.add_scalar(&scaled, mean).map_err(Error::Numr)
} else {
Ok(scaled)
}
}
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use numr::runtime::cpu::{CpuDevice, CpuRuntime};

    /// Construct a fresh CPU device.
    fn cpu_device() -> CpuDevice {
        CpuDevice::new()
    }

    /// Construct a default CPU client (backed by its own device).
    fn cpu_client() -> numr::runtime::cpu::CpuClient {
        let dev = cpu_device();
        CpuRuntime::default_client(&dev)
    }

    /// Arithmetic mean of a slice of samples.
    fn sample_mean(xs: &[f32]) -> f32 {
        xs.iter().sum::<f32>() / xs.len() as f32
    }

    #[test]
    fn test_init_zeros() {
        let dev = cpu_device();
        let cl = cpu_client();
        let tensor = Init::Zeros
            .init_tensor(&[2, 3], DType::F32, &dev, &cl)
            .unwrap();
        assert_eq!(tensor.shape(), &[2, 3]);
        let values: Vec<f32> = tensor.to_vec();
        assert!(values.iter().all(|&x| x == 0.0));
    }

    #[test]
    fn test_init_kaiming() {
        let dev = cpu_device();
        let cl = cpu_client();
        let tensor = Init::Kaiming
            .init_tensor(&[64, 128], DType::F32, &dev, &cl)
            .unwrap();
        assert_eq!(tensor.shape(), &[64, 128]);
        let values: Vec<f32> = tensor.to_vec();
        let mean = sample_mean(&values);
        assert!(mean.abs() < 0.1, "Kaiming mean too large: {mean}");
        // Sample std should be close to sqrt(2 / fan_in) with fan_in = 64.
        let var = values
            .iter()
            .fold(0.0f32, |acc, x| acc + (x - mean).powi(2))
            / values.len() as f32;
        let std = var.sqrt();
        let expected_std = (2.0f32 / 64.0).sqrt();
        assert!(
            (std - expected_std).abs() < 0.05,
            "Kaiming std {std} vs expected {expected_std}"
        );
    }

    #[test]
    fn test_init_xavier() {
        let dev = cpu_device();
        let cl = cpu_client();
        let tensor = Init::Xavier
            .init_tensor(&[256, 512], DType::F32, &dev, &cl)
            .unwrap();
        assert_eq!(tensor.shape(), &[256, 512]);
        let values: Vec<f32> = tensor.to_vec();
        let mean = sample_mean(&values);
        assert!(mean.abs() < 0.05, "Xavier mean too large: {mean}");
    }

    #[test]
    fn test_init_randn() {
        let dev = cpu_device();
        let cl = cpu_client();
        let init = Init::Randn {
            mean: 5.0,
            stdev: 0.1,
        };
        let tensor = init.init_tensor(&[1000], DType::F32, &dev, &cl).unwrap();
        let values: Vec<f32> = tensor.to_vec();
        let mean = sample_mean(&values);
        assert!((mean - 5.0).abs() < 0.1, "Randn mean {mean} should be ~5.0");
    }

    #[test]
    fn test_init_truncated_normal() {
        let dev = cpu_device();
        let cl = cpu_client();
        let init = Init::TruncatedNormal {
            mean: 0.0,
            stdev: 0.02,
        };
        let tensor = init.init_tensor(&[10000], DType::F32, &dev, &cl).unwrap();
        // Clamping at +/- 2 sigma with stdev 0.02 bounds samples to +/- 0.04.
        let values: Vec<f32> = tensor.to_vec();
        for &v in &values {
            assert!(
                (-0.04..=0.04).contains(&v),
                "Truncated normal value {v} out of range [-0.04, 0.04]"
            );
        }
    }
}