use ferrotorch_core::rng::with_thread_rng;
use ferrotorch_core::{FerrotorchError, FerrotorchResult, Float, Generator, Tensor, TensorStorage};
use crate::parameter::Parameter;
/// Activation-function kinds used to select an initialization gain.
///
/// The scaling factor associated with each variant is returned by
/// [`NonLinearity::gain`].
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum NonLinearity {
/// Identity / linear layer; gain `1.0`.
Linear,
/// Sigmoid activation; gain `1.0`.
Sigmoid,
/// Tanh activation; gain `5/3`.
Tanh,
/// ReLU activation; gain `sqrt(2)`.
ReLU,
/// Leaky ReLU; the payload is the negative slope used in
/// `sqrt(2 / (1 + slope^2))`.
LeakyReLU(f64),
}
/// Selects which fan value the Kaiming initializers scale by.
///
/// `FanIn` is the default variant.
#[derive(Debug, Clone, Copy, PartialEq, Default)]
pub enum FanMode {
/// Use the first value returned by `compute_fans` (fan-in).
#[default]
FanIn,
/// Use the second value returned by `compute_fans` (fan-out).
FanOut,
}
impl NonLinearity {
    /// Returns the scaling factor used by the Kaiming initializers for this
    /// activation.
    ///
    /// * `Linear`, `Sigmoid` -> `1`
    /// * `Tanh` -> `5 / 3`
    /// * `ReLU` -> `sqrt(2)`
    /// * `LeakyReLU(slope)` -> `sqrt(2 / (1 + slope^2))`
    pub fn gain(&self) -> f64 {
        match *self {
            Self::Linear | Self::Sigmoid => 1.0,
            Self::Tanh => 5.0 / 3.0,
            Self::ReLU => 2.0_f64.sqrt(),
            Self::LeakyReLU(slope) => {
                let denominator = 1.0 + slope * slope;
                (2.0 / denominator).sqrt()
            }
        }
    }
}
/// Computes `(fan_in, fan_out)` for a weight of the given shape.
///
/// * 0 dims: error.
/// * 1 dim: both fans equal the single dimension.
/// * 2 dims: `fan_in = shape[1]`, `fan_out = shape[0]`.
/// * 3+ dims: both are multiplied by the product of the trailing
///   (receptive-field) dimensions.
///
/// # Errors
///
/// Returns `FerrotorchError::InvalidArgument` when `shape` is empty.
fn compute_fans(shape: &[usize]) -> FerrotorchResult<(usize, usize)> {
    match shape {
        [] => Err(FerrotorchError::InvalidArgument {
            message: "cannot compute fan for scalar tensor".into(),
        }),
        [len] => Ok((*len, *len)),
        [out_dim, in_dim] => Ok((*in_dim, *out_dim)),
        [out_dim, in_dim, kernel @ ..] => {
            let receptive_field: usize = kernel.iter().product();
            Ok((in_dim * receptive_field, out_dim * receptive_field))
        }
    }
}
/// Replaces `param` with a new parameter of the same shape whose every
/// element equals `value`.
///
/// The replacement is built from a CPU storage via `Tensor::from_storage`.
///
/// # Errors
///
/// Propagates any error returned by `Tensor::from_storage`.
pub fn constant<T: Float>(param: &mut Parameter<T>, value: T) -> FerrotorchResult<()> {
    let shape = param.shape().to_vec();
    let filled = vec![value; param.numel()];
    let tensor = Tensor::from_storage(TensorStorage::cpu(filled), shape, true)?;
    *param = Parameter::new(tensor);
    Ok(())
}
/// Fills `param` with zeros by delegating to [`constant`].
pub fn zeros<T: Float>(param: &mut Parameter<T>) -> FerrotorchResult<()> {
    let zero = <T as num_traits::Zero>::zero();
    constant(param, zero)
}
/// Fills `param` with ones by delegating to [`constant`].
pub fn ones<T: Float>(param: &mut Parameter<T>) -> FerrotorchResult<()> {
    let one = <T as num_traits::One>::one();
    constant(param, one)
}
/// Fills `param` with uniform samples between `low` and `high`.
///
/// Convenience wrapper that obtains a generator through `with_thread_rng`
/// and forwards to [`uniform_with_generator`].
pub fn uniform<T: Float>(param: &mut Parameter<T>, low: f64, high: f64) -> FerrotorchResult<()> {
    with_thread_rng(|rng| uniform_with_generator(param, low, high, rng))
}
/// Fills `param` with uniform samples between `low` and `high`, drawing
/// from the supplied `generator`.
///
/// `param` is replaced by a new parameter of the same shape backed by CPU
/// storage; one sample is drawn per element.
///
/// # Errors
///
/// Propagates any error returned by `Tensor::from_storage`.
pub fn uniform_with_generator<T: Float>(
    param: &mut Parameter<T>,
    low: f64,
    high: f64,
    generator: &mut Generator,
) -> FerrotorchResult<()> {
    let samples: Vec<T> = sample_uniform_with(generator, param.numel(), low, high);
    let shape = param.shape().to_vec();
    let tensor = Tensor::from_storage(TensorStorage::cpu(samples), shape, true)?;
    *param = Parameter::new(tensor);
    Ok(())
}
/// Fills `param` with normal samples of the given `mean` and `std`.
///
/// Convenience wrapper that obtains a generator through `with_thread_rng`
/// and forwards to [`normal_with_generator`].
pub fn normal<T: Float>(param: &mut Parameter<T>, mean: f64, std: f64) -> FerrotorchResult<()> {
    with_thread_rng(|rng| normal_with_generator(param, mean, std, rng))
}
/// Fills `param` with normal samples of the given `mean` and `std`, drawing
/// from the supplied `generator`.
///
/// `param` is replaced by a new parameter of the same shape backed by CPU
/// storage; one sample is drawn per element.
///
/// # Errors
///
/// Propagates any error returned by `Tensor::from_storage`.
pub fn normal_with_generator<T: Float>(
    param: &mut Parameter<T>,
    mean: f64,
    std: f64,
    generator: &mut Generator,
) -> FerrotorchResult<()> {
    let samples: Vec<T> = sample_normal_with(generator, param.numel(), mean, std);
    let shape = param.shape().to_vec();
    let tensor = Tensor::from_storage(TensorStorage::cpu(samples), shape, true)?;
    *param = Parameter::new(tensor);
    Ok(())
}
/// Xavier (Glorot) uniform initialization.
///
/// Convenience wrapper that obtains a generator through `with_thread_rng`
/// and forwards to [`xavier_uniform_with_generator`].
pub fn xavier_uniform<T: Float>(param: &mut Parameter<T>) -> FerrotorchResult<()> {
    with_thread_rng(|rng| xavier_uniform_with_generator(param, rng))
}
/// Xavier (Glorot) uniform initialization using the supplied `generator`.
///
/// Samples uniformly between `-bound` and `bound`, where
/// `bound = sqrt(6 / (fan_in + fan_out))`.
///
/// # Errors
///
/// Fails when the fans cannot be computed for the parameter's shape, or
/// when the underlying uniform fill fails.
pub fn xavier_uniform_with_generator<T: Float>(
    param: &mut Parameter<T>,
    generator: &mut Generator,
) -> FerrotorchResult<()> {
    let (fan_in, fan_out) = compute_fans(param.shape())?;
    let fan_sum = (fan_in + fan_out) as f64;
    let bound = (6.0 / fan_sum).sqrt();
    uniform_with_generator(param, -bound, bound, generator)
}
/// Xavier (Glorot) normal initialization.
///
/// Convenience wrapper that obtains a generator through `with_thread_rng`
/// and forwards to [`xavier_normal_with_generator`].
pub fn xavier_normal<T: Float>(param: &mut Parameter<T>) -> FerrotorchResult<()> {
    with_thread_rng(|rng| xavier_normal_with_generator(param, rng))
}
/// Xavier (Glorot) normal initialization using the supplied `generator`.
///
/// Samples from a zero-mean normal distribution with
/// `std = sqrt(2 / (fan_in + fan_out))`.
///
/// # Errors
///
/// Fails when the fans cannot be computed for the parameter's shape, or
/// when the underlying normal fill fails.
pub fn xavier_normal_with_generator<T: Float>(
    param: &mut Parameter<T>,
    generator: &mut Generator,
) -> FerrotorchResult<()> {
    let (fan_in, fan_out) = compute_fans(param.shape())?;
    let fan_sum = (fan_in + fan_out) as f64;
    let std = (2.0 / fan_sum).sqrt();
    normal_with_generator(param, 0.0, std, generator)
}
/// Kaiming (He) uniform initialization in fan-in mode.
///
/// Convenience wrapper that obtains a generator through `with_thread_rng`
/// and forwards to [`kaiming_uniform_with_generator`].
pub fn kaiming_uniform<T: Float>(
    param: &mut Parameter<T>,
    nonlinearity: NonLinearity,
) -> FerrotorchResult<()> {
    with_thread_rng(|rng| kaiming_uniform_with_generator(param, nonlinearity, rng))
}
/// Kaiming (He) uniform initialization in fan-in mode using the supplied
/// `generator`.
///
/// Equivalent to [`kaiming_uniform_with_fan_mode_and_generator`] with
/// `FanMode::FanIn`.
pub fn kaiming_uniform_with_generator<T: Float>(
    param: &mut Parameter<T>,
    nonlinearity: NonLinearity,
    generator: &mut Generator,
) -> FerrotorchResult<()> {
    let mode = FanMode::FanIn;
    kaiming_uniform_with_fan_mode_and_generator(param, nonlinearity, mode, generator)
}
/// Kaiming (He) uniform initialization with an explicit fan `mode`.
///
/// Convenience wrapper that obtains a generator through `with_thread_rng`
/// and forwards to [`kaiming_uniform_with_fan_mode_and_generator`].
pub fn kaiming_uniform_with_fan_mode<T: Float>(
    param: &mut Parameter<T>,
    nonlinearity: NonLinearity,
    mode: FanMode,
) -> FerrotorchResult<()> {
    with_thread_rng(|rng| {
        kaiming_uniform_with_fan_mode_and_generator(param, nonlinearity, mode, rng)
    })
}
/// Kaiming (He) uniform initialization with an explicit fan `mode` and
/// `generator`.
///
/// With `std = gain / sqrt(fan)`, samples uniformly between `-bound` and
/// `bound` where `bound = sqrt(3) * std`. `fan` is the fan-in or fan-out of
/// the parameter's shape, as selected by `mode`.
///
/// # Errors
///
/// Fails when the fans cannot be computed for the parameter's shape, or
/// when the underlying uniform fill fails.
pub fn kaiming_uniform_with_fan_mode_and_generator<T: Float>(
    param: &mut Parameter<T>,
    nonlinearity: NonLinearity,
    mode: FanMode,
    generator: &mut Generator,
) -> FerrotorchResult<()> {
    let fan = match (mode, compute_fans(param.shape())?) {
        (FanMode::FanIn, (fan_in, _)) => fan_in,
        (FanMode::FanOut, (_, fan_out)) => fan_out,
    };
    let std = nonlinearity.gain() / (fan as f64).sqrt();
    let bound = 3.0_f64.sqrt() * std;
    uniform_with_generator(param, -bound, bound, generator)
}
/// Kaiming (He) normal initialization in fan-in mode.
///
/// Convenience wrapper that obtains a generator through `with_thread_rng`
/// and forwards to [`kaiming_normal_with_generator`].
pub fn kaiming_normal<T: Float>(
    param: &mut Parameter<T>,
    nonlinearity: NonLinearity,
) -> FerrotorchResult<()> {
    with_thread_rng(|rng| kaiming_normal_with_generator(param, nonlinearity, rng))
}
/// Kaiming (He) normal initialization in fan-in mode using the supplied
/// `generator`.
///
/// Equivalent to [`kaiming_normal_with_fan_mode_and_generator`] with
/// `FanMode::FanIn`.
pub fn kaiming_normal_with_generator<T: Float>(
    param: &mut Parameter<T>,
    nonlinearity: NonLinearity,
    generator: &mut Generator,
) -> FerrotorchResult<()> {
    let mode = FanMode::FanIn;
    kaiming_normal_with_fan_mode_and_generator(param, nonlinearity, mode, generator)
}
/// Kaiming (He) normal initialization with an explicit fan `mode`.
///
/// Convenience wrapper that obtains a generator through `with_thread_rng`
/// and forwards to [`kaiming_normal_with_fan_mode_and_generator`].
pub fn kaiming_normal_with_fan_mode<T: Float>(
    param: &mut Parameter<T>,
    nonlinearity: NonLinearity,
    mode: FanMode,
) -> FerrotorchResult<()> {
    with_thread_rng(|rng| {
        kaiming_normal_with_fan_mode_and_generator(param, nonlinearity, mode, rng)
    })
}
/// Kaiming (He) normal initialization with an explicit fan `mode` and
/// `generator`.
///
/// Samples from a zero-mean normal distribution with
/// `std = gain / sqrt(fan)`, where `fan` is the fan-in or fan-out of the
/// parameter's shape, as selected by `mode`.
///
/// # Errors
///
/// Fails when the fans cannot be computed for the parameter's shape, or
/// when the underlying normal fill fails.
pub fn kaiming_normal_with_fan_mode_and_generator<T: Float>(
    param: &mut Parameter<T>,
    nonlinearity: NonLinearity,
    mode: FanMode,
    generator: &mut Generator,
) -> FerrotorchResult<()> {
    let fan = match (mode, compute_fans(param.shape())?) {
        (FanMode::FanIn, (fan_in, _)) => fan_in,
        (FanMode::FanOut, (_, fan_out)) => fan_out,
    };
    let std = nonlinearity.gain() / (fan as f64).sqrt();
    normal_with_generator(param, 0.0, std, generator)
}
/// Fills `param` with normal samples of the given `mean` and `std`,
/// restricted to the closed interval `[a, b]`.
///
/// Convenience wrapper that obtains a generator through `with_thread_rng`
/// and forwards to [`trunc_normal_with_generator`].
pub fn trunc_normal_<T: Float>(
    param: &mut Parameter<T>,
    mean: f64,
    std: f64,
    a: f64,
    b: f64,
) -> FerrotorchResult<()> {
    with_thread_rng(|rng| trunc_normal_with_generator(param, mean, std, a, b, rng))
}
/// Fills `param` with normal samples of the given `mean` and `std`,
/// restricted to the closed interval `[a, b]`, drawing from `generator`.
///
/// Uses rejection sampling: candidates are drawn in batches of
/// `2 * remaining + 64` and those falling outside `[a, b]` are discarded.
/// The loop ends once enough candidates have been accepted, so an interval
/// lying far in the tail of the distribution can take a very long time.
///
/// # Errors
///
/// Returns `FerrotorchError::InvalidArgument` when
/// * `a` is not strictly less than `b` (including when either is NaN),
/// * `std` is not strictly positive (including NaN), or
/// * `mean` or `std` is not finite.
///
/// The NaN / non-finite cases are rejected because they would otherwise make
/// every candidate fail the bounds test and the sampling loop would never
/// terminate. Errors from `Tensor::from_storage` are propagated.
pub fn trunc_normal_with_generator<T: Float>(
    param: &mut Parameter<T>,
    mean: f64,
    std: f64,
    a: f64,
    b: f64,
    generator: &mut Generator,
) -> FerrotorchResult<()> {
    // Comparisons with NaN are always false, so NaN must be tested explicitly.
    if a.is_nan() || b.is_nan() || a >= b {
        return Err(FerrotorchError::InvalidArgument {
            message: format!("trunc_normal_: a ({a}) must be less than b ({b})"),
        });
    }
    if std.is_nan() || std <= 0.0 {
        return Err(FerrotorchError::InvalidArgument {
            message: format!("trunc_normal_: std ({std}) must be positive"),
        });
    }
    if !mean.is_finite() || !std.is_finite() {
        return Err(FerrotorchError::InvalidArgument {
            message: format!("trunc_normal_: mean ({mean}) and std ({std}) must be finite"),
        });
    }

    let numel = param.numel();
    let mut data: Vec<T> = Vec::with_capacity(numel);
    while data.len() < numel {
        // Over-sample so that one batch usually suffices; the whole batch is
        // drawn even if only part of it ends up being used.
        let batch_size = (numel - data.len()) * 2 + 64;
        let candidates: Vec<T> = sample_normal_with(generator, batch_size, mean, std);
        for v in candidates {
            let f = v.to_f64().unwrap();
            if f >= a && f <= b {
                data.push(v);
                if data.len() == numel {
                    break;
                }
            }
        }
    }

    *param = Parameter::new(Tensor::from_storage(
        TensorStorage::cpu(data),
        param.shape().to_vec(),
        true,
    )?);
    Ok(())
}
/// Fills `param` with an orthogonalized random matrix scaled by `gain`.
///
/// Convenience wrapper that obtains a generator through `with_thread_rng`
/// and forwards to [`orthogonal_with_generator`].
pub fn orthogonal_<T: Float>(param: &mut Parameter<T>, gain: f64) -> FerrotorchResult<()> {
    with_thread_rng(|rng| orthogonal_with_generator(param, gain, rng))
}
/// Fills `param` with an orthogonalized Gaussian matrix scaled by `gain`,
/// drawing from `generator`.
///
/// The parameter is viewed as a `rows x cols` matrix with `rows = shape[0]`
/// and `cols` the product of the remaining dimensions. A standard-normal
/// matrix of that size is orthonormalized column by column with a
/// Gram-Schmidt sweep (on the transposed matrix when `rows < cols`, so the
/// working matrix is never wider than tall), multiplied by `gain`, and
/// written back in the original layout.
///
/// # Errors
///
/// Returns `FerrotorchError::InvalidArgument` when the parameter has fewer
/// than two dimensions; errors from `Tensor::from_storage` are propagated.
pub fn orthogonal_with_generator<T: Float>(
param: &mut Parameter<T>,
gain: f64,
generator: &mut Generator,
) -> FerrotorchResult<()> {
let shape = param.shape().to_vec();
if shape.len() < 2 {
return Err(FerrotorchError::InvalidArgument {
message: "orthogonal_ requires at least a 2D tensor".into(),
});
}
// Flatten every trailing dimension into the column count.
let rows = shape[0];
let cols: usize = shape[1..].iter().product();
let (n, m) = (rows, cols);
// Row-major n x m matrix of standard-normal samples.
let flat: Vec<f64> = sample_normal_with::<f64>(generator, n * m, 0.0, 1.0);
// Work on the transpose when the matrix is wide, so rows_eff >= cols_eff.
let transpose = n < m;
let (rows_eff, cols_eff) = if transpose { (m, n) } else { (n, m) };
let k_eff = rows_eff.min(cols_eff);
// `q` is the rows_eff x cols_eff working matrix, row-major.
let mut q: Vec<f64> = if transpose {
let mut t = vec![0.0; m * n];
for i in 0..n {
for j in 0..m {
t[j * n + i] = flat[i * m + j];
}
}
t
} else {
flat
};
let ce = cols_eff;
// Norm of each column at the moment it is normalized.
let mut r_diag = vec![0.0f64; k_eff];
for j in 0..k_eff {
// Euclidean norm of column j.
let mut norm: f64 = 0.0;
for i in 0..rows_eff {
let v = q[i * ce + j];
norm += v * v;
}
norm = norm.sqrt();
// A numerically zero column is left as-is (neither normalized nor
// projected out of the later columns).
if norm < 1e-15 {
r_diag[j] = 1.0;
continue;
}
r_diag[j] = norm;
for i in 0..rows_eff {
q[i * ce + j] /= norm;
}
// Remove the component along column j from every later column.
for jj in (j + 1)..cols_eff {
let mut dot = 0.0;
for i in 0..rows_eff {
dot += q[i * ce + j] * q[i * ce + jj];
}
for i in 0..rows_eff {
q[i * ce + jj] -= dot * q[i * ce + j];
}
}
}
// NOTE: `r_diag` only ever holds 1.0 or a square-root result, so for
// non-NaN input `sign` is always +1 and this loop just applies `gain`.
for j in 0..k_eff {
let sign = if r_diag[j] >= 0.0 { 1.0 } else { -1.0 };
for i in 0..rows_eff {
q[i * ce + j] *= sign * gain;
}
}
// Convert back to `T`, undoing the transpose if one was applied.
let mut result = vec![T::from(0.0).unwrap(); n * m];
if transpose {
for i in 0..n.min(k_eff) {
for j in 0..m {
result[i * m + j] = T::from(q[j * ce + i]).unwrap();
}
}
} else {
for i in 0..n {
for j in 0..m.min(k_eff) {
result[i * m + j] = T::from(q[i * ce + j]).unwrap();
}
}
}
*param = Parameter::new(Tensor::from_storage(
TensorStorage::cpu(result),
shape,
true,
)?);
Ok(())
}
/// Fills a 2-D `param` with normal samples of standard deviation `std`,
/// zeroing a `sparsity` fraction of the entries in every column.
///
/// Convenience wrapper that obtains a generator through `with_thread_rng`
/// and forwards to [`sparse_with_generator`].
pub fn sparse_<T: Float>(
    param: &mut Parameter<T>,
    sparsity: f64,
    std: f64,
) -> FerrotorchResult<()> {
    with_thread_rng(|rng| sparse_with_generator(param, sparsity, std, rng))
}
/// Fills a 2-D `param` with zero-mean normal samples of standard deviation
/// `std`, zeroing `ceil(rows * sparsity)` randomly chosen entries in every
/// column, drawing from `generator`.
///
/// All `rows * cols` normal values are sampled first; the row indices to
/// zero are then picked column by column with a partial shuffle driven by
/// `Generator::random_u32`.
///
/// # Errors
///
/// Returns `FerrotorchError::InvalidArgument` when the parameter is not 2-D
/// or when `sparsity` is outside `[0, 1)`; errors from
/// `Tensor::from_storage` are propagated.
pub fn sparse_with_generator<T: Float>(
    param: &mut Parameter<T>,
    sparsity: f64,
    std: f64,
    generator: &mut Generator,
) -> FerrotorchResult<()> {
    let dims = param.shape().to_vec();
    if dims.len() != 2 {
        return Err(FerrotorchError::InvalidArgument {
            message: "sparse_ requires a 2D tensor".into(),
        });
    }
    if !(0.0..1.0).contains(&sparsity) {
        return Err(FerrotorchError::InvalidArgument {
            message: format!("sparse_: sparsity ({sparsity}) must be in [0, 1)"),
        });
    }

    let rows = dims[0];
    let cols = dims[1];
    let zeros_per_col = (((rows as f64) * sparsity).ceil() as usize).min(rows);

    let samples: Vec<f64> = sample_normal_with::<f64>(generator, rows * cols, 0.0, std);
    // Start from all zeros; only the surviving entries get overwritten.
    let mut data = vec![T::from(0.0).unwrap(); rows * cols];

    for col in 0..cols {
        // Partial shuffle: the first `zeros_per_col` slots end up holding the
        // rows to zero for this column.
        let mut order: Vec<usize> = (0..rows).collect();
        for k in 0..zeros_per_col {
            let draw = generator.random_u32() as usize;
            order.swap(k, k + draw % (rows - k));
        }

        let mut zeroed = vec![false; rows];
        for &row in &order[..zeros_per_col] {
            zeroed[row] = true;
        }

        for (row, &is_zero) in zeroed.iter().enumerate() {
            if !is_zero {
                let flat = row * cols + col;
                data[flat] = T::from(samples[flat]).unwrap();
            }
        }
    }

    *param = Parameter::new(Tensor::from_storage(TensorStorage::cpu(data), dims, true)?);
    Ok(())
}
/// Fills a convolution weight with Dirac-delta kernels.
///
/// The parameter is interpreted as `(out_ch, in_ch/groups, *kernel_size)`.
/// Everything is zeroed, then for each group `g` and each
/// `d < min(out_ch / groups, in_ch/groups)` a single `1` is written at
/// output channel `g * (out_ch / groups) + d`, input channel `d`, at the
/// spatial position whose coordinate along every kernel axis is
/// `axis_len / 2`.
///
/// The centre is computed per axis, so kernels with even-sized axes get the
/// element at `(k0/2, k1/2, ...)` rather than the midpoint of the flattened
/// kernel. If any kernel axis has length zero the tensor has no elements and
/// nothing is written.
///
/// # Errors
///
/// Returns `FerrotorchError::InvalidArgument` when the parameter has fewer
/// than three dimensions, when `groups` is zero, or when `out_ch` is not
/// divisible by `groups`; errors from `Tensor::from_storage` are propagated.
pub fn dirac_<T: Float>(param: &mut Parameter<T>, groups: usize) -> FerrotorchResult<()> {
    let shape = param.shape().to_vec();
    if shape.len() < 3 {
        return Err(FerrotorchError::InvalidArgument {
            message: "dirac_ requires at least a 3D tensor (out_ch, in_ch/groups, *kernel_size)"
                .into(),
        });
    }
    if groups == 0 {
        return Err(FerrotorchError::InvalidArgument {
            message: "dirac_: groups must be > 0".into(),
        });
    }
    let out_channels = shape[0];
    let in_channels_per_group = shape[1];
    if out_channels % groups != 0 {
        return Err(FerrotorchError::InvalidArgument {
            message: format!(
                "dirac_: out_channels ({out_channels}) must be divisible by groups ({groups})"
            ),
        });
    }

    let out_channels_per_group = out_channels / groups;
    let min_dim = out_channels_per_group.min(in_channels_per_group);

    let kernel_dims = &shape[2..];
    let kernel_size: usize = kernel_dims.iter().product();
    // Row-major flat offset of the multi-index (k0/2, k1/2, ...), built up
    // axis by axis (Horner's scheme).
    let center = kernel_dims
        .iter()
        .fold(0usize, |offset, &k| offset * k + k / 2);

    let mut data = vec![T::from(0.0).unwrap(); param.numel()];
    // With an empty kernel there is no centre element to set, and indexing
    // into the (empty) buffer would panic.
    if kernel_size > 0 {
        let one = T::from(1.0).unwrap();
        let in_stride = kernel_size;
        let out_stride = in_channels_per_group * kernel_size;
        for g in 0..groups {
            let out_offset = g * out_channels_per_group;
            for d in 0..min_dim {
                let out_idx = out_offset + d;
                data[out_idx * out_stride + d * in_stride + center] = one;
            }
        }
    }

    *param = Parameter::new(Tensor::from_storage(TensorStorage::cpu(data), shape, true)?);
    Ok(())
}
/// Fills a 2-D `param` with ones on the main diagonal and zeros elsewhere.
///
/// For non-square shapes only the first `min(rows, cols)` diagonal entries
/// are set.
///
/// # Errors
///
/// Returns `FerrotorchError::InvalidArgument` when the parameter is not
/// 2-D; errors from `Tensor::from_storage` are propagated.
pub fn eye_<T: Float>(param: &mut Parameter<T>) -> FerrotorchResult<()> {
    let dims = param.shape().to_vec();
    let (rows, cols) = match dims.as_slice() {
        &[r, c] => (r, c),
        _ => {
            return Err(FerrotorchError::InvalidArgument {
                message: "eye_ requires a 2D tensor".into(),
            });
        }
    };

    let zero = T::from(0.0).unwrap();
    let one = T::from(1.0).unwrap();
    let mut data = vec![zero; rows * cols];
    // In row-major layout consecutive diagonal entries are `cols + 1` apart.
    for slot in data.iter_mut().step_by(cols + 1).take(rows.min(cols)) {
        *slot = one;
    }

    *param = Parameter::new(Tensor::from_storage(TensorStorage::cpu(data), dims, true)?);
    Ok(())
}
/// Draws `n` values of the form `low + u * (high - low)`, where each `u`
/// comes from `Generator::next_uniform_f64`, converted to `T`.
///
/// Panics if a sample cannot be converted to `T`.
fn sample_uniform_with<T: Float>(
    generator: &mut Generator,
    n: usize,
    low: f64,
    high: f64,
) -> Vec<T> {
    let width = high - low;
    let mut out = Vec::with_capacity(n);
    for _ in 0..n {
        let unit = generator.next_uniform_f64();
        out.push(T::from(low + unit * width).unwrap());
    }
    out
}
/// Draws `n` values of the form `mean + std * z`, where each `z` comes from
/// `Generator::next_normal_f64`, converted to `T`.
///
/// Panics if a sample cannot be converted to `T`.
fn sample_normal_with<T: Float>(
    generator: &mut Generator,
    n: usize,
    mean: f64,
    std: f64,
) -> Vec<T> {
    (0..n)
        .map(|_| {
            let z = generator.next_normal_f64();
            T::from(mean + std * z).unwrap()
        })
        .collect()
}
// Unit tests for the initializers. The randomized tests use the
// thread-RNG entry points and check bounds or loose statistical tolerances
// rather than exact values.
#[cfg(test)]
mod tests {
use super::*;
// `zeros` overwrites every element with 0.
#[test]
fn test_zeros_init() {
let mut p = Parameter::<f32>::ones(&[3, 4]).unwrap();
zeros(&mut p).unwrap();
assert!(p.data().unwrap().iter().all(|&x| x == 0.0));
}
// `ones` overwrites every element with 1.
#[test]
fn test_ones_init() {
let mut p = Parameter::<f32>::zeros(&[2, 3]).unwrap();
ones(&mut p).unwrap();
assert!(p.data().unwrap().iter().all(|&x| x == 1.0));
}
// `constant` writes the requested value everywhere.
#[test]
#[allow(clippy::approx_constant)] fn test_constant_init() {
let mut p = Parameter::<f32>::zeros(&[5]).unwrap();
constant(&mut p, 3.14).unwrap();
assert!(p.data().unwrap().iter().all(|&x| (x - 3.14).abs() < 1e-5));
}
// Uniform samples stay inside the requested interval and average near 0.
#[test]
fn test_uniform_init_bounds() {
let mut p = Parameter::<f32>::zeros(&[10000]).unwrap();
uniform(&mut p, -1.0, 1.0).unwrap();
let data = p.data().unwrap();
assert!(data.iter().all(|&x| (-1.0..=1.0).contains(&x)));
let mean: f32 = data.iter().sum::<f32>() / data.len() as f32;
assert!(mean.abs() < 0.1);
}
// Normal samples have roughly the requested mean and variance.
#[test]
fn test_normal_init_stats() {
let mut p = Parameter::<f32>::zeros(&[10000]).unwrap();
normal(&mut p, 0.0, 1.0).unwrap();
let data = p.data().unwrap();
let mean: f32 = data.iter().sum::<f32>() / data.len() as f32;
let var: f32 = data.iter().map(|&x| (x - mean).powi(2)).sum::<f32>() / data.len() as f32;
assert!(mean.abs() < 0.1, "mean = {mean}");
assert!((var - 1.0).abs() < 0.2, "var = {var}");
}
// Xavier-uniform samples respect the sqrt(6 / (fan_in + fan_out)) bound.
#[test]
fn test_xavier_uniform_stats() {
let mut p = Parameter::<f32>::zeros(&[256, 128]).unwrap();
xavier_uniform(&mut p).unwrap();
let data = p.data().unwrap();
let limit = (6.0_f32 / (128.0 + 256.0)).sqrt();
assert!(data.iter().all(|&x| x.abs() <= limit + 0.01));
}
// Xavier-normal samples have std close to sqrt(2 / (fan_in + fan_out)).
#[test]
fn test_xavier_normal_stats() {
let mut p = Parameter::<f32>::zeros(&[256, 128]).unwrap();
xavier_normal(&mut p).unwrap();
let data = p.data().unwrap();
let expected_std = (2.0_f32 / (128.0 + 256.0)).sqrt();
let mean: f32 = data.iter().sum::<f32>() / data.len() as f32;
let var: f32 = data.iter().map(|&x| (x - mean).powi(2)).sum::<f32>() / data.len() as f32;
assert!(mean.abs() < 0.05, "mean = {mean}");
assert!(
(var.sqrt() - expected_std).abs() < expected_std * 0.15,
"std = {}, expected = {expected_std}",
var.sqrt()
);
}
// Kaiming-uniform (fan-in, ReLU) samples respect the sqrt(3) * std bound.
#[test]
fn test_kaiming_uniform_relu() {
let mut p = Parameter::<f32>::zeros(&[64, 128]).unwrap();
kaiming_uniform(&mut p, NonLinearity::ReLU).unwrap();
let data = p.data().unwrap();
let gain = (2.0f64).sqrt();
let std = gain / (128.0f64).sqrt();
let limit = (3.0f64).sqrt() * std;
assert!(data.iter().all(|&x| (x as f64).abs() <= limit + 0.01));
}
// Same bound check, but with the fan-out (64) selected.
#[test]
fn test_kaiming_uniform_fan_out() {
let mut p = Parameter::<f32>::zeros(&[64, 128]).unwrap();
kaiming_uniform_with_fan_mode(&mut p, NonLinearity::ReLU, FanMode::FanOut).unwrap();
let data = p.data().unwrap();
let gain = (2.0f64).sqrt();
let std = gain / (64.0f64).sqrt();
let limit = (3.0f64).sqrt() * std;
assert!(data.iter().all(|&x| (x as f64).abs() <= limit + 0.01));
}
// Kaiming-normal with fan-out has std close to sqrt(2) / sqrt(64).
#[test]
fn test_kaiming_normal_fan_out() {
let mut p = Parameter::<f32>::zeros(&[64, 128]).unwrap();
kaiming_normal_with_fan_mode(&mut p, NonLinearity::ReLU, FanMode::FanOut).unwrap();
let data = p.data().unwrap();
let expected_std = (2.0f64).sqrt() / (64.0f64).sqrt();
let mean: f32 = data.iter().sum::<f32>() / data.len() as f32;
let var: f32 = data.iter().map(|&x| (x - mean).powi(2)).sum::<f32>() / data.len() as f32;
assert!(mean.abs() < 0.1, "mean = {mean}");
assert!(
((var.sqrt() as f64) - expected_std).abs() < expected_std * 0.2,
"std = {}, expected = {expected_std}",
var.sqrt()
);
}
// `FanMode::default()` is `FanIn`.
#[test]
fn test_fan_mode_default_is_fan_in() {
assert_eq!(FanMode::default(), FanMode::FanIn);
}
// Kaiming-normal with the default fan-in has std close to sqrt(2) / sqrt(128).
#[test]
fn test_kaiming_normal_relu() {
let mut p = Parameter::<f32>::zeros(&[64, 128]).unwrap();
kaiming_normal(&mut p, NonLinearity::ReLU).unwrap();
let data = p.data().unwrap();
let expected_std = (2.0f64).sqrt() / (128.0f64).sqrt();
let mean: f32 = data.iter().sum::<f32>() / data.len() as f32;
let var: f32 = data.iter().map(|&x| (x - mean).powi(2)).sum::<f32>() / data.len() as f32;
assert!(mean.abs() < 0.1, "mean = {mean}");
assert!(
((var.sqrt() as f64) - expected_std).abs() < expected_std * 0.2,
"std = {}, expected = {expected_std}",
var.sqrt()
);
}
// 2-D shape: fan_in is shape[1], fan_out is shape[0].
#[test]
fn test_compute_fans_2d() {
let (fi, fo) = compute_fans(&[64, 128]).unwrap();
assert_eq!(fi, 128);
assert_eq!(fo, 64);
}
// 4-D shape: both fans are scaled by the 3x3 receptive field.
#[test]
fn test_compute_fans_4d() {
let (fi, fo) = compute_fans(&[32, 16, 3, 3]).unwrap();
assert_eq!(fi, 16 * 9);
assert_eq!(fo, 32 * 9);
}
// Spot-check the gain table.
#[test]
fn test_nonlinearity_gain() {
assert!((NonLinearity::ReLU.gain() - (2.0f64).sqrt()).abs() < 1e-10);
assert!((NonLinearity::Linear.gain() - 1.0).abs() < 1e-10);
assert!((NonLinearity::Tanh.gain() - 5.0 / 3.0).abs() < 1e-10);
}
// The parameter still requires grad after being re-initialized.
#[test]
fn test_init_preserves_requires_grad() {
let mut p = Parameter::<f32>::zeros(&[5]).unwrap();
xavier_uniform(&mut p).unwrap();
assert!(p.requires_grad());
}
// Truncated-normal samples never leave [a, b].
#[test]
fn test_trunc_normal_bounds() {
let mut p = Parameter::<f32>::zeros(&[10000]).unwrap();
trunc_normal_(&mut p, 0.0, 1.0, -2.0, 2.0).unwrap();
let data = p.data().unwrap();
assert!(
data.iter().all(|&x| (-2.0..=2.0).contains(&x)),
"all values must be within [-2, 2]"
);
}
// A symmetric truncation keeps the sample mean near 0.
#[test]
fn test_trunc_normal_stats() {
let mut p = Parameter::<f32>::zeros(&[50000]).unwrap();
trunc_normal_(&mut p, 0.0, 1.0, -2.0, 2.0).unwrap();
let data = p.data().unwrap();
let mean: f32 = data.iter().sum::<f32>() / data.len() as f32;
assert!(mean.abs() < 0.05, "mean = {mean}");
}
// a > b is rejected.
#[test]
fn test_trunc_normal_rejects_bad_bounds() {
let mut p = Parameter::<f32>::zeros(&[10]).unwrap();
assert!(trunc_normal_(&mut p, 0.0, 1.0, 2.0, -2.0).is_err());
}
// std == 0 is rejected.
#[test]
fn test_trunc_normal_rejects_zero_std() {
let mut p = Parameter::<f32>::zeros(&[10]).unwrap();
assert!(trunc_normal_(&mut p, 0.0, 0.0, -1.0, 1.0).is_err());
}
// Square case: Q^T Q is the identity.
#[test]
fn test_orthogonal_columns_orthonormal() {
let mut p = Parameter::<f64>::zeros(&[32, 32]).unwrap();
orthogonal_(&mut p, 1.0).unwrap();
let data = p.data().unwrap();
let n = 32;
for i in 0..n {
for j in 0..n {
let mut dot = 0.0;
for k in 0..n {
dot += data[k * n + i] * data[k * n + j];
}
let expected = if i == j { 1.0 } else { 0.0 };
assert!(
(dot - expected).abs() < 1e-6,
"Q^T Q [{i},{j}] = {dot}, expected {expected}"
);
}
}
}
// With gain 2 every column has squared norm 4.
#[test]
fn test_orthogonal_gain() {
let mut p = Parameter::<f64>::zeros(&[16, 16]).unwrap();
orthogonal_(&mut p, 2.0).unwrap();
let data = p.data().unwrap();
let n = 16;
for i in 0..n {
let mut col_norm_sq = 0.0;
for k in 0..n {
let v = data[k * n + i];
col_norm_sq += v * v;
}
assert!(
(col_norm_sq - 4.0).abs() < 1e-5,
"column {i} norm^2 = {col_norm_sq}, expected 4.0"
);
}
}
// Tall case (rows > cols): columns are orthonormal.
#[test]
fn test_orthogonal_tall_matrix() {
let mut p = Parameter::<f64>::zeros(&[64, 16]).unwrap();
orthogonal_(&mut p, 1.0).unwrap();
let data = p.data().unwrap();
let (n, m) = (64, 16);
for i in 0..m {
for j in 0..m {
let mut dot = 0.0;
for k in 0..n {
dot += data[k * m + i] * data[k * m + j];
}
let expected = if i == j { 1.0 } else { 0.0 };
assert!(
(dot - expected).abs() < 1e-5,
"tall Q^T Q [{i},{j}] = {dot}, expected {expected}"
);
}
}
}
// Wide case (rows < cols): rows are orthonormal.
#[test]
fn test_orthogonal_wide_matrix() {
let mut p = Parameter::<f64>::zeros(&[16, 64]).unwrap();
orthogonal_(&mut p, 1.0).unwrap();
let data = p.data().unwrap();
let (n, m) = (16, 64);
for i in 0..n {
for j in 0..n {
let mut dot = 0.0;
for k in 0..m {
dot += data[i * m + k] * data[j * m + k];
}
let expected = if i == j { 1.0 } else { 0.0 };
assert!(
(dot - expected).abs() < 1e-5,
"wide Q Q^T [{i},{j}] = {dot}, expected {expected}"
);
}
}
}
// 1-D input is rejected.
#[test]
fn test_orthogonal_rejects_1d() {
let mut p = Parameter::<f32>::zeros(&[10]).unwrap();
assert!(orthogonal_(&mut p, 1.0).is_err());
}
// The overall fraction of zeros matches the requested sparsity.
#[test]
fn test_sparse_sparsity_ratio() {
let mut p = Parameter::<f32>::zeros(&[100, 50]).unwrap();
sparse_(&mut p, 0.9, 0.01).unwrap();
let data = p.data().unwrap();
let num_zeros = data.iter().filter(|&&x| x == 0.0).count();
let total = data.len();
let actual_sparsity = num_zeros as f64 / total as f64;
assert!(
(actual_sparsity - 0.9).abs() < 0.05,
"sparsity = {actual_sparsity}, expected ~0.9"
);
}
// The surviving entries are non-empty and average near 0.
#[test]
fn test_sparse_nonzero_drawn_from_normal() {
let mut p = Parameter::<f32>::zeros(&[200, 100]).unwrap();
sparse_(&mut p, 0.5, 1.0).unwrap();
let data = p.data().unwrap();
let nonzero: Vec<f64> = data
.iter()
.filter(|&&x| x != 0.0)
.map(|&x| x as f64)
.collect();
assert!(!nonzero.is_empty());
let mean: f64 = nonzero.iter().sum::<f64>() / nonzero.len() as f64;
assert!(mean.abs() < 0.15, "nonzero mean = {mean}");
}
// Non-2-D input is rejected.
#[test]
fn test_sparse_rejects_non_2d() {
let mut p = Parameter::<f32>::zeros(&[10]).unwrap();
assert!(sparse_(&mut p, 0.5, 1.0).is_err());
}
// Sparsity outside [0, 1) is rejected.
#[test]
fn test_sparse_rejects_bad_sparsity() {
let mut p = Parameter::<f32>::zeros(&[10, 10]).unwrap();
assert!(sparse_(&mut p, 1.0, 1.0).is_err());
assert!(sparse_(&mut p, -0.1, 1.0).is_err());
}
// 1-D kernels: a 1 at the kernel centre only where out_ch == in_ch.
#[test]
fn test_dirac_3d_identity() {
let mut p = Parameter::<f32>::zeros(&[4, 4, 3]).unwrap();
dirac_(&mut p, 1).unwrap();
let data = p.data().unwrap();
let center = 1;
for out_ch in 0..4 {
for in_ch in 0..4 {
let val = data[out_ch * 4 * 3 + in_ch * 3 + center];
if out_ch == in_ch {
assert!((val - 1.0).abs() < 1e-6, "diag [{out_ch},{in_ch}] = {val}");
} else {
assert!(val.abs() < 1e-6, "off-diag [{out_ch},{in_ch}] = {val}");
}
}
}
}
// 3x3 kernels: the centre is flat index 4 within each kernel.
#[test]
fn test_dirac_4d_identity() {
let mut p = Parameter::<f32>::zeros(&[2, 2, 3, 3]).unwrap();
dirac_(&mut p, 1).unwrap();
let data = p.data().unwrap();
let _kernel_size = 9;
let center = 4;
for out_ch in 0..2 {
for in_ch in 0..2 {
let val = data[out_ch * 2 * 9 + in_ch * 9 + center];
if out_ch == in_ch {
assert!((val - 1.0).abs() < 1e-6);
} else {
assert!(val.abs() < 1e-6);
}
}
}
}
// Two groups: each group gets its own identity block.
#[test]
fn test_dirac_groups() {
let mut p = Parameter::<f32>::zeros(&[4, 2, 3]).unwrap();
dirac_(&mut p, 2).unwrap();
let data = p.data().unwrap();
let center = 1;
let idx = |oc: usize, ic: usize, k: usize| oc * 6 + ic * 3 + k;
assert!((data[idx(0, 0, center)] - 1.0).abs() < 1e-6);
assert!((data[idx(1, 1, center)] - 1.0).abs() < 1e-6);
assert!((data[idx(2, 0, center)] - 1.0).abs() < 1e-6);
assert!((data[idx(3, 1, center)] - 1.0).abs() < 1e-6);
}
// 2-D input is rejected.
#[test]
fn test_dirac_rejects_2d() {
let mut p = Parameter::<f32>::zeros(&[4, 4]).unwrap();
assert!(dirac_(&mut p, 1).is_err());
}
// Square identity.
#[test]
fn test_eye_square() {
let mut p = Parameter::<f32>::zeros(&[4, 4]).unwrap();
eye_(&mut p).unwrap();
let data = p.data().unwrap();
for i in 0..4 {
for j in 0..4 {
let expected = if i == j { 1.0 } else { 0.0 };
assert!(
(data[i * 4 + j] - expected).abs() < 1e-6,
"eye[{i},{j}] = {}",
data[i * 4 + j]
);
}
}
}
// Tall matrix: ones on the leading diagonal, extra rows are zero.
#[test]
fn test_eye_tall() {
let mut p = Parameter::<f32>::zeros(&[6, 3]).unwrap();
eye_(&mut p).unwrap();
let data = p.data().unwrap();
for i in 0..6 {
for j in 0..3 {
let expected = if i == j { 1.0 } else { 0.0 };
assert!((data[i * 3 + j] - expected).abs() < 1e-6);
}
}
}
// Wide matrix: ones on the leading diagonal, extra columns are zero.
#[test]
fn test_eye_wide() {
let mut p = Parameter::<f32>::zeros(&[3, 6]).unwrap();
eye_(&mut p).unwrap();
let data = p.data().unwrap();
for i in 0..3 {
for j in 0..6 {
let expected = if i == j { 1.0 } else { 0.0 };
assert!((data[i * 6 + j] - expected).abs() < 1e-6);
}
}
}
// Non-2-D input is rejected.
#[test]
fn test_eye_rejects_non_2d() {
let mut p = Parameter::<f32>::zeros(&[4]).unwrap();
assert!(eye_(&mut p).is_err());
}
// The parameter still requires grad after `eye_`.
#[test]
fn test_eye_preserves_requires_grad() {
let mut p = Parameter::<f32>::zeros(&[3, 3]).unwrap();
eye_(&mut p).unwrap();
assert!(p.requires_grad());
}
}