use crate::tensor::{Device, Result, Tensor, TensorOptions};
/// Kaiming (He) uniform initializer.
///
/// Draws from U(-bound, bound) where `bound = sqrt(3) * gain / sqrt(fan_in)`
/// and `gain = sqrt(2 / (1 + a^2))` is the leaky-ReLU gain for negative
/// slope `a` (a = 0 yields the plain ReLU gain of sqrt(2)).
pub fn kaiming_uniform(shape: &[i64], fan_in: i64, a: f64, device: Device) -> Result<Tensor> {
    let gain = (2.0 / (1.0 + a * a)).sqrt();
    let std = gain / (fan_in as f64).sqrt();
    let bound = 3.0_f64.sqrt() * std;
    let opts = TensorOptions {
        dtype: crate::tensor::DType::Float32,
        device,
    };
    // Shift-and-scale U[0, 1) samples onto U[-bound, bound).
    let base = Tensor::rand(shape, opts)?;
    base.mul_scalar(2.0 * bound)?.add_scalar(-bound)
}
/// Kaiming (He) normal initializer: N(0, std^2) with
/// `std = gain / sqrt(fan_in)` and `gain = sqrt(2 / (1 + a^2))`
/// (leaky-ReLU gain for negative slope `a`).
pub fn kaiming_normal(shape: &[i64], fan_in: i64, a: f64, device: Device) -> Result<Tensor> {
    let gain = (2.0 / (1.0 + a * a)).sqrt();
    let std = gain / (fan_in as f64).sqrt();
    let opts = TensorOptions {
        dtype: crate::tensor::DType::Float32,
        device,
    };
    // Scale standard-normal samples up to the target deviation.
    let z = Tensor::randn(shape, opts)?;
    z.mul_scalar(std)
}
/// Xavier (Glorot) uniform initializer.
///
/// Samples U(-limit, limit) with `limit = sqrt(6 / (fan_in + fan_out))`,
/// which keeps activation variance roughly constant across layers.
pub fn xavier_uniform(shape: &[i64], fan_in: i64, fan_out: i64, device: Device) -> Result<Tensor> {
    let bound = (6.0 / (fan_in + fan_out) as f64).sqrt();
    let opts = TensorOptions {
        dtype: crate::tensor::DType::Float32,
        device,
    };
    // Affine-map U[0, 1) onto U[-bound, bound).
    let base = Tensor::rand(shape, opts)?;
    base.mul_scalar(2.0 * bound)?.add_scalar(-bound)
}
/// Xavier (Glorot) normal initializer: N(0, std^2) with
/// `std = sqrt(2 / (fan_in + fan_out))`.
pub fn xavier_normal(shape: &[i64], fan_in: i64, fan_out: i64, device: Device) -> Result<Tensor> {
    let std = (2.0 / (fan_in + fan_out) as f64).sqrt();
    let opts = TensorOptions {
        dtype: crate::tensor::DType::Float32,
        device,
    };
    // Scale standard-normal samples to the Glorot deviation.
    let z = Tensor::randn(shape, opts)?;
    z.mul_scalar(std)
}
/// Conventional bias initializer: U(-1/sqrt(fan_in), 1/sqrt(fan_in)),
/// the usual default for linear/convolution bias vectors.
pub fn uniform_bias(fan_in: i64, shape: &[i64], device: Device) -> Result<Tensor> {
    let bound = 1.0 / (fan_in as f64).sqrt();
    let opts = TensorOptions {
        dtype: crate::tensor::DType::Float32,
        device,
    };
    // Affine-map U[0, 1) onto U[-bound, bound).
    let base = Tensor::rand(shape, opts)?;
    base.mul_scalar(2.0 * bound)?.add_scalar(-bound)
}
/// Uniform initializer over the half-open interval [low, high).
pub fn uniform(shape: &[i64], low: f64, high: f64, device: Device) -> Result<Tensor> {
    let opts = TensorOptions {
        dtype: crate::tensor::DType::Float32,
        device,
    };
    // Stretch U[0, 1) by the interval width, then shift to the lower bound.
    let unit = Tensor::rand(shape, opts)?;
    unit.mul_scalar(high - low)?.add_scalar(low)
}
/// Normal initializer: samples N(mean, std^2).
pub fn normal(shape: &[i64], mean: f64, std: f64, device: Device) -> Result<Tensor> {
    let opts = TensorOptions {
        dtype: crate::tensor::DType::Float32,
        device,
    };
    // Scale-then-shift standard-normal samples: z * std + mean.
    let z = Tensor::randn(shape, opts)?;
    z.mul_scalar(std)?.add_scalar(mean)
}
/// (Semi-)orthogonal initializer for 2D weight matrices, scaled by `gain`.
///
/// Builds an n x n orthogonal matrix (n = max(rows, cols)) by applying
/// modified Gram-Schmidt to the rows of a random Gaussian matrix, then
/// slices out the leading `rows x cols` sub-matrix. For a tall result this
/// gives orthonormal columns; for a wide result, orthonormal rows.
///
/// # Panics
/// Panics if `shape` is not 2-dimensional.
pub fn orthogonal(shape: &[i64], gain: f64, device: Device) -> Result<Tensor> {
    assert!(shape.len() == 2, "orthogonal init requires a 2D shape");
    let rows = shape[0];
    let cols = shape[1];
    let opts = TensorOptions {
        dtype: crate::tensor::DType::Float32,
        device,
    };
    // Work on the square super-matrix so both the tall and wide cases can
    // be cut from a single orthogonal factor.
    let n = rows.max(cols) as usize;
    let a = Tensor::randn(&[n as i64, n as i64], opts)?;
    // Row-major host copy; orthogonalization runs on the CPU with f64
    // accumulators to limit rounding error, storing back as f32.
    let mut data = a.to_f32_vec()?;
    for i in 0..n {
        let row_start = i * n;
        // Modified Gram-Schmidt: subtract the projection of row i onto each
        // already-orthonormalized row j < i. Because row j has unit norm by
        // the time it is used here, the projection coefficient is just the
        // dot product.
        for j in 0..i {
            let prev_start = j * n;
            let mut dot = 0.0f64;
            for k in 0..n {
                dot += data[row_start + k] as f64 * data[prev_start + k] as f64;
            }
            for k in 0..n {
                data[row_start + k] -= (dot * data[prev_start + k] as f64) as f32;
            }
        }
        // Normalize the residual row to unit length.
        let mut norm = 0.0f64;
        for k in 0..n {
            norm += data[row_start + k] as f64 * data[row_start + k] as f64;
        }
        // Floor guards against dividing by ~0 if a random row were linearly
        // dependent (vanishingly unlikely for Gaussian input); in that case
        // the row would be huge rather than orthonormal.
        let norm = norm.sqrt().max(1e-10);
        for k in 0..n {
            data[row_start + k] = (data[row_start + k] as f64 / norm) as f32;
        }
    }
    let q = Tensor::from_f32(&data, &[n as i64, n as i64], device)?;
    // Take the leading rows x cols block of the orthogonal factor.
    let result = q.narrow(0, 0, rows)?.narrow(1, 0, cols)?.contiguous()?;
    // Skip the extra scaling op when gain is effectively 1.
    if (gain - 1.0).abs() > 1e-10 {
        result.mul_scalar(gain)
    } else {
        Ok(result)
    }
}
/// Approximate truncated-normal initializer: draws N(mean, std^2) samples
/// and clamps them to `[mean + a*std, mean + b*std]`.
///
/// NOTE: clamping is not a true truncated normal — probability mass falling
/// outside the cutoffs piles up exactly at the bounds instead of being
/// redistributed. `a` and `b` are expressed in units of `std`, not as
/// absolute values.
pub fn trunc_normal(
    shape: &[i64],
    mean: f64,
    std: f64,
    a: f64,
    b: f64,
    device: Device,
) -> Result<Tensor> {
    let opts = TensorOptions {
        dtype: crate::tensor::DType::Float32,
        device,
    };
    // Convert the std-relative cutoffs into absolute clamp bounds.
    let lo = mean + a * std;
    let hi = mean + b * std;
    let samples = Tensor::randn(shape, opts)?;
    samples.mul_scalar(std)?.add_scalar(mean)?.clamp(lo, hi)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that the leading `n x n` block of `data` (row-major) is the
    /// identity matrix within an absolute tolerance of 0.01. Shared by the
    /// orthogonality tests below, which previously duplicated this loop.
    fn assert_near_identity(data: &[f32], n: usize, label: &str) {
        for i in 0..n {
            for j in 0..n {
                let expected = if i == j { 1.0 } else { 0.0 };
                assert!(
                    (data[i * n + j] - expected).abs() < 0.01,
                    "{}[{},{}] = {}, expected {}",
                    label, i, j, data[i * n + j], expected
                );
            }
        }
    }

    #[test]
    fn test_uniform_range() {
        let t = uniform(&[1000], -2.0, 3.0, crate::tensor::test_device()).unwrap();
        let data = t.to_f32_vec().unwrap();
        for &v in &data {
            assert!((-2.0..=3.0).contains(&v), "value {} out of range [-2, 3]", v);
        }
        // Expected mean of U[-2, 3) is 0.5; loose tolerance for 1000 samples.
        let mean: f32 = data.iter().sum::<f32>() / data.len() as f32;
        assert!((mean - 0.5).abs() < 0.2, "mean {} too far from 0.5", mean);
    }

    #[test]
    fn test_normal_stats() {
        let t = normal(&[10000], 5.0, 0.1, crate::tensor::test_device()).unwrap();
        let data = t.to_f32_vec().unwrap();
        let mean: f32 = data.iter().sum::<f32>() / data.len() as f32;
        assert!((mean - 5.0).abs() < 0.05, "mean {} too far from 5.0", mean);
        let var: f32 = data.iter().map(|&v| (v - mean) * (v - mean)).sum::<f32>() / data.len() as f32;
        let std_dev = var.sqrt();
        assert!((std_dev - 0.1).abs() < 0.02, "std {} too far from 0.1", std_dev);
    }

    #[test]
    fn test_orthogonal_square() {
        let t = orthogonal(&[4, 4], 1.0, crate::tensor::test_device()).unwrap();
        assert_eq!(t.shape(), vec![4, 4]);
        let qt = t.transpose(0, 1).unwrap();
        let qqt = t.matmul(&qt).unwrap();
        assert_near_identity(&qqt.to_f32_vec().unwrap(), 4, "Q@Q^T");
    }

    #[test]
    fn test_orthogonal_tall() {
        // 6x4: columns should be orthonormal, so Q^T @ Q == I_4.
        let t = orthogonal(&[6, 4], 1.0, crate::tensor::test_device()).unwrap();
        assert_eq!(t.shape(), vec![6, 4]);
        let qt = t.transpose(0, 1).unwrap();
        let qtq = qt.matmul(&t).unwrap();
        assert_near_identity(&qtq.to_f32_vec().unwrap(), 4, "Q^T@Q");
    }

    #[test]
    fn test_orthogonal_wide() {
        // 4x6: rows should be orthonormal, so Q @ Q^T == I_4.
        let t = orthogonal(&[4, 6], 1.0, crate::tensor::test_device()).unwrap();
        assert_eq!(t.shape(), vec![4, 6]);
        let qt = t.transpose(0, 1).unwrap();
        let qqt = t.matmul(&qt).unwrap();
        assert_near_identity(&qqt.to_f32_vec().unwrap(), 4, "Q@Q^T");
    }

    #[test]
    fn test_orthogonal_gain() {
        // With gain = 2, every row of the square result should have norm ~2.
        let t = orthogonal(&[4, 4], 2.0, crate::tensor::test_device()).unwrap();
        let data = t.to_f32_vec().unwrap();
        for i in 0..4 {
            let norm: f32 = (0..4).map(|j| data[i * 4 + j] * data[i * 4 + j]).sum::<f32>().sqrt();
            assert!(
                (norm - 2.0).abs() < 0.1,
                "row {} norm = {}, expected ~2.0", i, norm
            );
        }
    }

    #[test]
    fn test_trunc_normal_bounds() {
        let t = trunc_normal(&[10000], 0.0, 1.0, -2.0, 2.0, crate::tensor::test_device()).unwrap();
        let data = t.to_f32_vec().unwrap();
        for &v in &data {
            assert!((-2.0..=2.0).contains(&v), "value {} out of [-2, 2]", v);
        }
    }

    #[test]
    fn test_trunc_normal_centered() {
        // a/b are in std units: bounds are mean + a*std = 2.0, mean + b*std = 4.0.
        let t = trunc_normal(&[10000], 3.0, 0.5, -2.0, 2.0, crate::tensor::test_device()).unwrap();
        let data = t.to_f32_vec().unwrap();
        let mean: f32 = data.iter().sum::<f32>() / data.len() as f32;
        assert!((mean - 3.0).abs() < 0.1, "mean {} too far from 3.0", mean);
        for &v in &data {
            assert!((2.0..=4.0).contains(&v), "value {} out of [2.0, 4.0]", v);
        }
    }

    #[test]
    fn test_kaiming_uniform_shape() {
        let t = kaiming_uniform(&[3, 4], 4, 0.0, crate::tensor::test_device()).unwrap();
        assert_eq!(t.shape(), vec![3, 4]);
    }

    #[test]
    fn test_xavier_uniform_shape() {
        let t = xavier_uniform(&[3, 4], 3, 4, crate::tensor::test_device()).unwrap();
        assert_eq!(t.shape(), vec![3, 4]);
    }
}