use ferrotorch_core::{FerrotorchError, FerrotorchResult, Float, Tensor, TensorStorage};
use crate::parameter::Parameter;
/// Activation functions for which a recommended initialisation gain is known.
///
/// The gains returned by [`NonLinearity::gain`] are the same values as
/// PyTorch's `torch.nn.init.calculate_gain` for these activations.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum NonLinearity {
    /// Identity activation (gain 1).
    Linear,
    /// Logistic sigmoid (gain 1).
    Sigmoid,
    /// Hyperbolic tangent (gain 5/3).
    Tanh,
    /// Rectified linear unit (gain √2).
    ReLU,
    /// Leaky ReLU; the payload is the negative slope.
    LeakyReLU(f64),
}

impl NonLinearity {
    /// Returns the gain used to scale initialisation bounds / standard deviations.
    ///
    /// * `Linear`, `Sigmoid` → `1`
    /// * `Tanh` → `5 / 3`
    /// * `ReLU` → `√2`
    /// * `LeakyReLU(s)` → `√(2 / (1 + s²))`
    pub fn gain(&self) -> f64 {
        match *self {
            NonLinearity::ReLU => std::f64::consts::SQRT_2,
            NonLinearity::Tanh => 5.0 / 3.0,
            NonLinearity::LeakyReLU(slope) => {
                let denom = 1.0 + slope * slope;
                (2.0 / denom).sqrt()
            }
            NonLinearity::Linear | NonLinearity::Sigmoid => 1.0,
        }
    }
}
/// Computes `(fan_in, fan_out)` for a weight of the given `shape`.
///
/// * 1-D `[n]` → `(n, n)`
/// * 2-D `[out, in]` → `(in, out)`
/// * N-D `[out, in, k0, k1, ...]` → `(in * Πk, out * Πk)` where `Πk` is the
///   receptive-field size (product of the trailing kernel dimensions).
///
/// # Errors
///
/// Returns `InvalidArgument` for a 0-D (scalar) shape.
fn compute_fans(shape: &[usize]) -> FerrotorchResult<(usize, usize)> {
    match shape {
        [] => Err(FerrotorchError::InvalidArgument {
            message: "cannot compute fan for scalar tensor".into(),
        }),
        [n] => Ok((*n, *n)),
        [fan_out, fan_in] => Ok((*fan_in, *fan_out)),
        [out_ch, in_ch, kernel @ ..] => {
            let receptive: usize = kernel.iter().product();
            Ok((in_ch * receptive, out_ch * receptive))
        }
    }
}
/// Replaces `param` with a tensor of the same shape filled with `value`.
///
/// The replacement uses CPU storage and is created with `requires_grad = true`.
///
/// # Errors
///
/// Propagates any error from `Tensor::from_storage`.
pub fn constant<T: Float>(param: &mut Parameter<T>, value: T) -> FerrotorchResult<()> {
    let shape = param.shape().to_vec();
    let filled = vec![value; param.numel()];
    let tensor = Tensor::from_storage(TensorStorage::cpu(filled), shape, true)?;
    *param = Parameter::new(tensor);
    Ok(())
}
/// Fills `param` with zeros (see [`constant`]).
pub fn zeros<T: Float>(param: &mut Parameter<T>) -> FerrotorchResult<()> {
    let zero: T = num_traits::Zero::zero();
    constant(param, zero)
}

/// Fills `param` with ones (see [`constant`]).
pub fn ones<T: Float>(param: &mut Parameter<T>) -> FerrotorchResult<()> {
    let one: T = num_traits::One::one();
    constant(param, one)
}
/// Replaces `param` with values `low + u * (high - low)`, where `u` comes from
/// the module's xorshift-based `simple_uniform` generator; the shape is kept.
///
/// The replacement uses CPU storage and is created with `requires_grad = true`.
/// `low <= high` is not checked.
///
/// # Errors
///
/// Propagates any error from `Tensor::from_storage`.
pub fn uniform<T: Float>(param: &mut Parameter<T>, low: f64, high: f64) -> FerrotorchResult<()> {
    let shape = param.shape().to_vec();
    let samples: Vec<T> = simple_uniform(param.numel(), low, high);
    let tensor = Tensor::from_storage(TensorStorage::cpu(samples), shape, true)?;
    *param = Parameter::new(tensor);
    Ok(())
}
/// Replaces `param` with Box–Muller normal samples (`mean`, `std`) drawn by
/// the module's `simple_normal` generator; the shape is kept.
///
/// The replacement uses CPU storage and is created with `requires_grad = true`.
///
/// # Errors
///
/// Propagates any error from `Tensor::from_storage`.
pub fn normal<T: Float>(param: &mut Parameter<T>, mean: f64, std: f64) -> FerrotorchResult<()> {
    let shape = param.shape().to_vec();
    let samples: Vec<T> = simple_normal(param.numel(), mean, std);
    let tensor = Tensor::from_storage(TensorStorage::cpu(samples), shape, true)?;
    *param = Parameter::new(tensor);
    Ok(())
}
/// Glorot/Xavier uniform initialisation: samples from `U(-b, b)` with
/// `b = sqrt(6 / (fan_in + fan_out))`.
///
/// # Errors
///
/// Fails for a 0-D `param` (see `compute_fans`) or if tensor creation fails.
pub fn xavier_uniform<T: Float>(param: &mut Parameter<T>) -> FerrotorchResult<()> {
    let (fan_in, fan_out) = compute_fans(param.shape())?;
    let fan_total = (fan_in + fan_out) as f64;
    let bound = (6.0 / fan_total).sqrt();
    uniform(param, -bound, bound)
}

/// Glorot/Xavier normal initialisation: samples from `N(0, s²)` with
/// `s = sqrt(2 / (fan_in + fan_out))`.
///
/// # Errors
///
/// Fails for a 0-D `param` (see `compute_fans`) or if tensor creation fails.
pub fn xavier_normal<T: Float>(param: &mut Parameter<T>) -> FerrotorchResult<()> {
    let (fan_in, fan_out) = compute_fans(param.shape())?;
    let fan_total = (fan_in + fan_out) as f64;
    let sigma = (2.0 / fan_total).sqrt();
    normal(param, 0.0, sigma)
}
/// He/Kaiming uniform initialisation (fan-in mode): samples from `U(-b, b)`
/// with `b = sqrt(3) * gain / sqrt(fan_in)`, `gain` taken from `nonlinearity`.
///
/// # Errors
///
/// Fails for a 0-D `param` (see `compute_fans`) or if tensor creation fails.
pub fn kaiming_uniform<T: Float>(
    param: &mut Parameter<T>,
    nonlinearity: NonLinearity,
) -> FerrotorchResult<()> {
    let (fan_in, _fan_out) = compute_fans(param.shape())?;
    let sigma = nonlinearity.gain() / (fan_in as f64).sqrt();
    let bound = 3.0f64.sqrt() * sigma;
    uniform(param, -bound, bound)
}

/// He/Kaiming normal initialisation (fan-in mode): samples from `N(0, s²)`
/// with `s = gain / sqrt(fan_in)`, `gain` taken from `nonlinearity`.
///
/// # Errors
///
/// Fails for a 0-D `param` (see `compute_fans`) or if tensor creation fails.
pub fn kaiming_normal<T: Float>(
    param: &mut Parameter<T>,
    nonlinearity: NonLinearity,
) -> FerrotorchResult<()> {
    let (fan_in, _fan_out) = compute_fans(param.shape())?;
    let sigma = nonlinearity.gain() / (fan_in as f64).sqrt();
    normal(param, 0.0, sigma)
}
pub fn trunc_normal_<T: Float>(
param: &mut Parameter<T>,
mean: f64,
std: f64,
a: f64,
b: f64,
) -> FerrotorchResult<()> {
if a >= b {
return Err(FerrotorchError::InvalidArgument {
message: format!("trunc_normal_: a ({a}) must be less than b ({b})"),
});
}
if std <= 0.0 {
return Err(FerrotorchError::InvalidArgument {
message: format!("trunc_normal_: std ({std}) must be positive"),
});
}
let numel = param.numel();
let mut data: Vec<T> = Vec::with_capacity(numel);
let mut remaining = numel;
while remaining > 0 {
let batch_size = remaining * 2 + 64;
let candidates: Vec<T> = simple_normal(batch_size, mean, std);
for v in candidates {
let f = v.to_f64().unwrap();
if f >= a && f <= b {
data.push(v);
remaining -= 1;
if remaining == 0 {
break;
}
}
}
}
data.truncate(numel);
*param = Parameter::new(Tensor::from_storage(
TensorStorage::cpu(data),
param.shape().to_vec(),
true,
)?);
Ok(())
}
/// Fills `param` with a (semi-)orthogonal matrix scaled by `gain`.
///
/// The tensor is viewed as a `rows x cols` matrix, with `rows = shape[0]` and
/// `cols` the product of the remaining dimensions. A standard-normal matrix
/// of that size is orthonormalised with modified Gram–Schmidt:
///
/// * its columns are orthonormalised when `rows >= cols`;
/// * its rows are orthonormalised otherwise.
///
/// Every entry is then multiplied by `gain`. Gram–Schmidt already produces
/// an `R` factor with non-negative diagonal, so no extra sign correction is
/// applied. A column whose norm falls below `1e-15` is left untouched.
///
/// # Errors
///
/// Returns `InvalidArgument` for tensors with fewer than 2 dimensions, and
/// propagates any error from `Tensor::from_storage`.
pub fn orthogonal_<T: Float>(param: &mut Parameter<T>, gain: f64) -> FerrotorchResult<()> {
    let shape = param.shape().to_vec();
    if shape.len() < 2 {
        return Err(FerrotorchError::InvalidArgument {
            message: "orthogonal_ requires at least a 2D tensor".into(),
        });
    }
    let rows = shape[0];
    let cols: usize = shape[1..].iter().product();
    let gaussian: Vec<f64> = simple_normal::<f64>(rows * cols, 0.0, 1.0);

    // Work on the tall orientation (height >= width), stored column-major so
    // each column is a contiguous `height`-long slice.
    let wide = rows < cols;
    let (height, width) = if wide { (cols, rows) } else { (rows, cols) };
    let mut basis: Vec<f64> = if wide {
        // Rows of a wide matrix are the columns of its transpose: already contiguous.
        gaussian
    } else {
        let mut by_column = Vec::with_capacity(rows * cols);
        for c in 0..cols {
            by_column.extend((0..rows).map(|r| gaussian[r * cols + c]));
        }
        by_column
    };

    // Modified Gram–Schmidt: normalise column j, then remove its component
    // from every later column.
    for j in 0..width.min(height) {
        let (head, tail) = basis.split_at_mut((j + 1) * height);
        let pivot = &mut head[j * height..];
        let norm = pivot.iter().fold(0.0_f64, |acc, &v| acc + v * v).sqrt();
        if norm < 1e-15 {
            continue;
        }
        for v in pivot.iter_mut() {
            *v /= norm;
        }
        for column in tail.chunks_exact_mut(height) {
            let dot = pivot
                .iter()
                .zip(column.iter())
                .fold(0.0_f64, |acc, (&p, &c)| acc + p * c);
            for (c, &p) in column.iter_mut().zip(pivot.iter()) {
                *c -= dot * p;
            }
        }
    }

    // Back to row-major `rows x cols`, applying the gain on the way out.
    let data: Vec<T> = if wide {
        // The result is the transpose of the tall basis, whose column-major
        // layout is exactly the row-major layout of the wide result.
        basis.iter().map(|&v| T::from(v * gain).unwrap()).collect()
    } else {
        let mut row_major = Vec::with_capacity(rows * cols);
        for r in 0..rows {
            row_major.extend((0..cols).map(|c| T::from(basis[c * rows + r] * gain).unwrap()));
        }
        row_major
    };

    *param = Parameter::new(Tensor::from_storage(
        TensorStorage::cpu(data),
        shape,
        true,
    )?);
    Ok(())
}
/// Fills the 2-D `param` as a sparse matrix.
///
/// In every column, `ceil(sparsity * rows)` randomly chosen rows are set to
/// zero. The remaining entries come from `simple_normal` with mean 0 and
/// standard deviation `std`.
///
/// # Errors
///
/// Returns `InvalidArgument` if `param` is not 2-D or `sparsity` lies
/// outside `[0, 1)`, and propagates any error from `Tensor::from_storage`.
pub fn sparse_<T: Float>(
    param: &mut Parameter<T>,
    sparsity: f64,
    std: f64,
) -> FerrotorchResult<()> {
    let shape = param.shape().to_vec();
    if shape.len() != 2 {
        return Err(FerrotorchError::InvalidArgument {
            message: "sparse_ requires a 2D tensor".into(),
        });
    }
    if !(0.0..1.0).contains(&sparsity) {
        return Err(FerrotorchError::InvalidArgument {
            message: format!("sparse_: sparsity ({sparsity}) must be in [0, 1)"),
        });
    }
    let rows = shape[0];
    let cols = shape[1];
    let zeros_per_col = (((rows as f64) * sparsity).ceil() as usize).min(rows);
    let gaussian: Vec<f64> = simple_normal::<f64>(rows * cols, 0.0, std);
    let zero = T::from(0.0).unwrap();
    let mut data = vec![zero; rows * cols];
    let mut rng = xorshift_seed();

    // Buffers reused across columns instead of reallocating per column.
    let mut order: Vec<usize> = Vec::with_capacity(rows);
    let mut zeroed = vec![false; rows];
    for col in 0..cols {
        // Partial Fisher–Yates shuffle of a fresh identity permutation: the
        // first `zeros_per_col` entries become the rows zeroed in this column.
        order.clear();
        order.extend(0..rows);
        for k in 0..zeros_per_col {
            rng = xorshift_step(rng);
            let pick = k + (rng as usize) % (rows - k);
            order.swap(k, pick);
        }
        zeroed.fill(false);
        for &row in &order[..zeros_per_col] {
            zeroed[row] = true;
        }
        // `data` starts as zeros, so only surviving entries need writing.
        for (row, &is_zero) in zeroed.iter().enumerate() {
            if !is_zero {
                let idx = row * cols + col;
                data[idx] = T::from(gaussian[idx]).unwrap();
            }
        }
    }
    *param = Parameter::new(Tensor::from_storage(TensorStorage::cpu(data), shape, true)?);
    Ok(())
}
/// Fills `param` with a Dirac-delta convolution kernel that passes input
/// channels through unchanged.
///
/// `param` is interpreted as `(out_channels, in_channels / groups, *kernel)`.
/// For each group `g` and each `d < min(out_channels / groups, shape[1])`:
///
/// * the element `[g * (out_channels / groups) + d, d, k0 / 2, k1 / 2, ...]`
///   is set to one;
/// * every other element is zero.
///
/// The centre is taken per kernel dimension, so even-sized kernels get the
/// same centre as in PyTorch's `dirac_`.
///
/// # Errors
///
/// Returns `InvalidArgument` if:
///
/// * `param` has fewer than 3 dimensions;
/// * `groups` is zero;
/// * `out_channels` is not divisible by `groups`.
///
/// Also propagates any error from `Tensor::from_storage`.
pub fn dirac_<T: Float>(param: &mut Parameter<T>, groups: usize) -> FerrotorchResult<()> {
    let shape = param.shape().to_vec();
    if shape.len() < 3 {
        return Err(FerrotorchError::InvalidArgument {
            message: "dirac_ requires at least a 3D tensor (out_ch, in_ch/groups, *kernel_size)"
                .into(),
        });
    }
    if groups == 0 {
        return Err(FerrotorchError::InvalidArgument {
            message: "dirac_: groups must be > 0".into(),
        });
    }
    let out_channels = shape[0];
    let in_channels_per_group = shape[1];
    if out_channels % groups != 0 {
        return Err(FerrotorchError::InvalidArgument {
            message: format!(
                "dirac_: out_channels ({out_channels}) must be divisible by groups ({groups})"
            ),
        });
    }
    let out_per_group = out_channels / groups;
    let min_dim = out_per_group.min(in_channels_per_group);
    let kernel_dims = &shape[2..];
    let kernel_size: usize = kernel_dims.iter().product();
    // Row-major flat offset of the centre element [k0 / 2, k1 / 2, ...].
    // Using `kernel_size / 2` is only correct when every extent is odd; for
    // a 2x2 kernel the centre (1, 1) is offset 3, not 2.
    let center = kernel_dims.iter().fold(0usize, |acc, &k| acc * k + k / 2);
    let numel = param.numel();
    let mut data = vec![T::from(0.0).unwrap(); numel];
    // A zero-sized kernel dimension leaves nothing to set (and `data` would
    // be empty, so indexing it would panic).
    if kernel_size > 0 {
        let one = T::from(1.0).unwrap();
        let in_stride = kernel_size;
        let out_stride = in_channels_per_group * kernel_size;
        for g in 0..groups {
            let out_offset = g * out_per_group;
            for d in 0..min_dim {
                data[(out_offset + d) * out_stride + d * in_stride + center] = one;
            }
        }
    }
    *param = Parameter::new(Tensor::from_storage(TensorStorage::cpu(data), shape, true)?);
    Ok(())
}
/// Fills the 2-D `param` with the identity pattern.
///
/// Entry `[i, j]` is one when `i == j` and zero otherwise. Non-square
/// shapes are supported.
///
/// # Errors
///
/// Returns `InvalidArgument` if `param` is not 2-D, and propagates any error
/// from `Tensor::from_storage`.
pub fn eye_<T: Float>(param: &mut Parameter<T>) -> FerrotorchResult<()> {
    let shape = param.shape().to_vec();
    let (rows, cols) = match *shape.as_slice() {
        [r, c] => (r, c),
        _ => {
            return Err(FerrotorchError::InvalidArgument {
                message: "eye_ requires a 2D tensor".into(),
            })
        }
    };
    let zero = T::from(0.0).unwrap();
    let one = T::from(1.0).unwrap();
    // Flat index k maps to (k / cols, k % cols); the range is empty when
    // cols == 0, so the division is never evaluated then.
    let data: Vec<T> = (0..rows * cols)
        .map(|k| if k / cols == k % cols { one } else { zero })
        .collect();
    *param = Parameter::new(Tensor::from_storage(TensorStorage::cpu(data), shape, true)?);
    Ok(())
}
/// Returns a fresh, non-zero seed for the xorshift64 generator.
///
/// The seed hashes four inputs:
///
/// * a tag unique to this function;
/// * the wall-clock time;
/// * the current thread id;
/// * a process-wide call counter.
///
/// The counter gives back-to-back calls on the same thread distinct hash
/// inputs even when the system clock has not advanced between them (coarse
/// clocks). Without it, two such calls would return the same seed. The tag
/// separates this stream from other generators that hash the same
/// time/thread inputs.
fn xorshift_seed() -> u64 {
    use std::collections::hash_map::DefaultHasher;
    use std::hash::{Hash, Hasher};
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::time::SystemTime;

    // Only uniqueness of the counter values matters, so Relaxed is sufficient.
    static CALLS: AtomicUsize = AtomicUsize::new(0);

    let mut hasher = DefaultHasher::new();
    "xorshift_seed".hash(&mut hasher);
    SystemTime::now().hash(&mut hasher);
    std::thread::current().id().hash(&mut hasher);
    CALLS.fetch_add(1, Ordering::Relaxed).hash(&mut hasher);
    let state = hasher.finish();
    // xorshift maps 0 to 0 forever, so a zero seed must be avoided.
    if state == 0 { 0xdeadbeefcafe } else { state }
}
/// Advances a xorshift64 state by one step using the (13, 7, 17) shift triple.
///
/// Each step is a bijection on `u64` that fixes 0, so a non-zero state
/// never becomes zero.
fn xorshift_step(state: u64) -> u64 {
    let x = state ^ (state << 13);
    let x = x ^ (x >> 7);
    x ^ (x << 17)
}
/// Returns `n` values `low + u * (high - low)`, where `u` in `(0, 1]` comes
/// from a xorshift64 stream.
///
/// The stream is seeded by hashing four inputs:
///
/// * a tag unique to this function;
/// * the wall-clock time;
/// * the current thread id;
/// * a process-wide call counter.
///
/// The counter keeps consecutive calls independent even when the clock has
/// not advanced between them. Previously, two calls within the clock's
/// resolution returned identical samples. The tag keeps this stream apart
/// from the module's other time-seeded generators.
pub(crate) fn simple_uniform<T: Float>(n: usize, low: f64, high: f64) -> Vec<T> {
    use std::collections::hash_map::DefaultHasher;
    use std::hash::{Hash, Hasher};
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::time::SystemTime;

    // Only uniqueness of the counter values matters, so Relaxed is sufficient.
    static CALLS: AtomicUsize = AtomicUsize::new(0);

    let mut hasher = DefaultHasher::new();
    "simple_uniform".hash(&mut hasher);
    SystemTime::now().hash(&mut hasher);
    std::thread::current().id().hash(&mut hasher);
    CALLS.fetch_add(1, Ordering::Relaxed).hash(&mut hasher);
    let mut state = hasher.finish();
    // xorshift maps 0 to 0 forever, so a zero seed must be avoided.
    if state == 0 {
        state = 0xdeadbeefcafe;
    }
    let range = high - low;
    (0..n)
        .map(|_| {
            state ^= state << 13;
            state ^= state >> 7;
            state ^= state << 17;
            let u = (state as f64) / (u64::MAX as f64);
            T::from(low + u * range).unwrap()
        })
        .collect()
}
/// Returns `n` samples of `N(mean, std²)` generated with the Box–Muller
/// transform over a xorshift64 uniform stream.
///
/// The stream is seeded by hashing four inputs:
///
/// * a tag unique to this function;
/// * the wall-clock time;
/// * the current thread id;
/// * a process-wide call counter.
///
/// The counter keeps consecutive calls independent even when the clock has
/// not advanced between them. Previously, two calls within the clock's
/// resolution returned identical samples. The tag keeps this stream apart
/// from the module's other time-seeded generators.
pub(crate) fn simple_normal<T: Float>(n: usize, mean: f64, std: f64) -> Vec<T> {
    use std::collections::hash_map::DefaultHasher;
    use std::hash::{Hash, Hasher};
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::time::SystemTime;

    // Only uniqueness of the counter values matters, so Relaxed is sufficient.
    static CALLS: AtomicUsize = AtomicUsize::new(0);

    let mut hasher = DefaultHasher::new();
    "simple_normal".hash(&mut hasher);
    SystemTime::now().hash(&mut hasher);
    std::thread::current().id().hash(&mut hasher);
    CALLS.fetch_add(1, Ordering::Relaxed).hash(&mut hasher);
    let mut state = hasher.finish();
    // xorshift maps 0 to 0 forever, so a zero seed must be avoided.
    if state == 0 {
        state = 0xdeadbeefcafe;
    }
    let mut next_uniform = || -> f64 {
        state ^= state << 13;
        state ^= state >> 7;
        state ^= state << 17;
        // Keep u1 away from 0 so ln(u1) below stays finite.
        ((state as f64) / (u64::MAX as f64)).max(1e-300)
    };
    let mut data = Vec::with_capacity(n);
    let mut i = 0;
    while i < n {
        let u1 = next_uniform();
        let u2 = next_uniform();
        let r = (-2.0 * u1.ln()).sqrt();
        let theta = 2.0 * std::f64::consts::PI * u2;
        data.push(T::from(mean + std * r * theta.cos()).unwrap());
        // Each Box–Muller pair yields two independent samples; use both.
        if i + 1 < n {
            data.push(T::from(mean + std * r * theta.sin()).unwrap());
        }
        i += 2;
    }
    data.truncate(n);
    data
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_zeros_init() {
        let mut p = Parameter::<f32>::ones(&[3, 4]).unwrap();
        zeros(&mut p).unwrap();
        assert!(p.data().unwrap().iter().all(|&x| x == 0.0));
    }

    #[test]
    fn test_ones_init() {
        let mut p = Parameter::<f32>::zeros(&[2, 3]).unwrap();
        ones(&mut p).unwrap();
        assert!(p.data().unwrap().iter().all(|&x| x == 1.0));
    }

    #[test]
    fn test_constant_init() {
        let mut p = Parameter::<f32>::zeros(&[5]).unwrap();
        constant(&mut p, 2.5).unwrap();
        assert!(p.data().unwrap().iter().all(|&x| (x - 2.5).abs() < 1e-5));
    }

    #[test]
    fn test_uniform_init_bounds() {
        let mut p = Parameter::<f32>::zeros(&[10000]).unwrap();
        uniform(&mut p, -1.0, 1.0).unwrap();
        let data = p.data().unwrap();
        assert!(data.iter().all(|&x| (-1.0..=1.0).contains(&x)));
        let mean: f32 = data.iter().sum::<f32>() / data.len() as f32;
        assert!(mean.abs() < 0.1);
    }

    #[test]
    fn test_normal_init_stats() {
        let mut p = Parameter::<f32>::zeros(&[10000]).unwrap();
        normal(&mut p, 0.0, 1.0).unwrap();
        let data = p.data().unwrap();
        let mean: f32 = data.iter().sum::<f32>() / data.len() as f32;
        let var: f32 = data.iter().map(|&x| (x - mean).powi(2)).sum::<f32>() / data.len() as f32;
        assert!(mean.abs() < 0.1, "mean = {mean}");
        assert!((var - 1.0).abs() < 0.2, "var = {var}");
    }

    #[test]
    fn test_xavier_uniform_stats() {
        let mut p = Parameter::<f32>::zeros(&[256, 128]).unwrap();
        xavier_uniform(&mut p).unwrap();
        let data = p.data().unwrap();
        let limit = (6.0f64 / (128.0 + 256.0)).sqrt() as f32;
        assert!(data.iter().all(|&x| x.abs() <= limit + 0.01));
    }

    #[test]
    fn test_xavier_normal_stats() {
        let mut p = Parameter::<f32>::zeros(&[256, 128]).unwrap();
        xavier_normal(&mut p).unwrap();
        let data = p.data().unwrap();
        let expected_std = (2.0f64 / (128.0 + 256.0)).sqrt() as f32;
        let mean: f32 = data.iter().sum::<f32>() / data.len() as f32;
        let var: f32 = data.iter().map(|&x| (x - mean).powi(2)).sum::<f32>() / data.len() as f32;
        assert!(mean.abs() < 0.05, "mean = {mean}");
        assert!(
            (var.sqrt() - expected_std).abs() < expected_std * 0.15,
            "std = {}, expected = {expected_std}",
            var.sqrt()
        );
    }

    #[test]
    fn test_kaiming_uniform_relu() {
        let mut p = Parameter::<f32>::zeros(&[64, 128]).unwrap();
        kaiming_uniform(&mut p, NonLinearity::ReLU).unwrap();
        let data = p.data().unwrap();
        let gain = (2.0f64).sqrt();
        let std = gain / (128.0f64).sqrt();
        let limit = (3.0f64).sqrt() * std;
        assert!(data.iter().all(|&x| (x as f64).abs() <= limit + 0.01));
    }

    #[test]
    fn test_kaiming_normal_relu() {
        let mut p = Parameter::<f32>::zeros(&[64, 128]).unwrap();
        kaiming_normal(&mut p, NonLinearity::ReLU).unwrap();
        let data = p.data().unwrap();
        let expected_std = (2.0f64).sqrt() / (128.0f64).sqrt();
        let mean: f32 = data.iter().sum::<f32>() / data.len() as f32;
        let var: f32 = data.iter().map(|&x| (x - mean).powi(2)).sum::<f32>() / data.len() as f32;
        assert!(mean.abs() < 0.1, "mean = {mean}");
        assert!(
            ((var.sqrt() as f64) - expected_std).abs() < expected_std * 0.2,
            "std = {}, expected = {expected_std}",
            var.sqrt()
        );
    }

    #[test]
    fn test_compute_fans_2d() {
        let (fi, fo) = compute_fans(&[64, 128]).unwrap();
        assert_eq!(fi, 128);
        assert_eq!(fo, 64);
    }

    #[test]
    fn test_compute_fans_4d() {
        let (fi, fo) = compute_fans(&[32, 16, 3, 3]).unwrap();
        assert_eq!(fi, 16 * 9);
        assert_eq!(fo, 32 * 9);
    }

    #[test]
    fn test_nonlinearity_gain() {
        assert!((NonLinearity::ReLU.gain() - (2.0f64).sqrt()).abs() < 1e-10);
        assert!((NonLinearity::Linear.gain() - 1.0).abs() < 1e-10);
        assert!((NonLinearity::Tanh.gain() - 5.0 / 3.0).abs() < 1e-10);
    }

    #[test]
    fn test_init_preserves_requires_grad() {
        let mut p = Parameter::<f32>::zeros(&[5]).unwrap();
        xavier_uniform(&mut p).unwrap();
        assert!(p.requires_grad());
    }

    #[test]
    fn test_trunc_normal_bounds() {
        let mut p = Parameter::<f32>::zeros(&[10000]).unwrap();
        trunc_normal_(&mut p, 0.0, 1.0, -2.0, 2.0).unwrap();
        let data = p.data().unwrap();
        assert!(
            data.iter().all(|&x| (-2.0..=2.0).contains(&x)),
            "all values must be within [-2, 2]"
        );
    }

    #[test]
    fn test_trunc_normal_stats() {
        let mut p = Parameter::<f32>::zeros(&[50000]).unwrap();
        trunc_normal_(&mut p, 0.0, 1.0, -2.0, 2.0).unwrap();
        let data = p.data().unwrap();
        let mean: f32 = data.iter().sum::<f32>() / data.len() as f32;
        assert!(mean.abs() < 0.05, "mean = {mean}");
    }

    #[test]
    fn test_trunc_normal_rejects_bad_bounds() {
        let mut p = Parameter::<f32>::zeros(&[10]).unwrap();
        assert!(trunc_normal_(&mut p, 0.0, 1.0, 2.0, -2.0).is_err());
    }

    #[test]
    fn test_trunc_normal_rejects_zero_std() {
        let mut p = Parameter::<f32>::zeros(&[10]).unwrap();
        assert!(trunc_normal_(&mut p, 0.0, 0.0, -1.0, 1.0).is_err());
    }

    #[test]
    fn test_orthogonal_columns_orthonormal() {
        let mut p = Parameter::<f64>::zeros(&[32, 32]).unwrap();
        orthogonal_(&mut p, 1.0).unwrap();
        let data = p.data().unwrap();
        let n = 32;
        for i in 0..n {
            for j in 0..n {
                let mut dot = 0.0;
                for k in 0..n {
                    dot += data[k * n + i] * data[k * n + j];
                }
                let expected = if i == j { 1.0 } else { 0.0 };
                assert!(
                    (dot - expected).abs() < 1e-6,
                    "Q^T Q [{i},{j}] = {dot}, expected {expected}"
                );
            }
        }
    }

    #[test]
    fn test_orthogonal_gain() {
        let mut p = Parameter::<f64>::zeros(&[16, 16]).unwrap();
        orthogonal_(&mut p, 2.0).unwrap();
        let data = p.data().unwrap();
        let n = 16;
        for i in 0..n {
            let mut col_norm_sq = 0.0;
            for k in 0..n {
                let v = data[k * n + i];
                col_norm_sq += v * v;
            }
            assert!(
                (col_norm_sq - 4.0).abs() < 1e-5,
                "column {i} norm^2 = {col_norm_sq}, expected 4.0"
            );
        }
    }

    #[test]
    fn test_orthogonal_tall_matrix() {
        let mut p = Parameter::<f64>::zeros(&[64, 16]).unwrap();
        orthogonal_(&mut p, 1.0).unwrap();
        let data = p.data().unwrap();
        let (n, m) = (64, 16);
        for i in 0..m {
            for j in 0..m {
                let mut dot = 0.0;
                for k in 0..n {
                    dot += data[k * m + i] * data[k * m + j];
                }
                let expected = if i == j { 1.0 } else { 0.0 };
                assert!(
                    (dot - expected).abs() < 1e-5,
                    "tall Q^T Q [{i},{j}] = {dot}, expected {expected}"
                );
            }
        }
    }

    #[test]
    fn test_orthogonal_wide_matrix() {
        let mut p = Parameter::<f64>::zeros(&[16, 64]).unwrap();
        orthogonal_(&mut p, 1.0).unwrap();
        let data = p.data().unwrap();
        let (n, m) = (16, 64);
        for i in 0..n {
            for j in 0..n {
                let mut dot = 0.0;
                for k in 0..m {
                    dot += data[i * m + k] * data[j * m + k];
                }
                let expected = if i == j { 1.0 } else { 0.0 };
                assert!(
                    (dot - expected).abs() < 1e-5,
                    "wide Q Q^T [{i},{j}] = {dot}, expected {expected}"
                );
            }
        }
    }

    #[test]
    fn test_orthogonal_rejects_1d() {
        let mut p = Parameter::<f32>::zeros(&[10]).unwrap();
        assert!(orthogonal_(&mut p, 1.0).is_err());
    }

    #[test]
    fn test_sparse_sparsity_ratio() {
        let mut p = Parameter::<f32>::zeros(&[100, 50]).unwrap();
        sparse_(&mut p, 0.9, 0.01).unwrap();
        let data = p.data().unwrap();
        let num_zeros = data.iter().filter(|&&x| x == 0.0).count();
        let total = data.len();
        let actual_sparsity = num_zeros as f64 / total as f64;
        assert!(
            (actual_sparsity - 0.9).abs() < 0.05,
            "sparsity = {actual_sparsity}, expected ~0.9"
        );
    }

    #[test]
    fn test_sparse_nonzero_drawn_from_normal() {
        let mut p = Parameter::<f32>::zeros(&[200, 100]).unwrap();
        sparse_(&mut p, 0.5, 1.0).unwrap();
        let data = p.data().unwrap();
        let nonzero: Vec<f64> = data
            .iter()
            .filter(|&&x| x != 0.0)
            .map(|&x| x as f64)
            .collect();
        assert!(!nonzero.is_empty());
        let mean: f64 = nonzero.iter().sum::<f64>() / nonzero.len() as f64;
        assert!(mean.abs() < 0.15, "nonzero mean = {mean}");
    }

    #[test]
    fn test_sparse_rejects_non_2d() {
        let mut p = Parameter::<f32>::zeros(&[10]).unwrap();
        assert!(sparse_(&mut p, 0.5, 1.0).is_err());
    }

    #[test]
    fn test_sparse_rejects_bad_sparsity() {
        let mut p = Parameter::<f32>::zeros(&[10, 10]).unwrap();
        assert!(sparse_(&mut p, 1.0, 1.0).is_err());
        assert!(sparse_(&mut p, -0.1, 1.0).is_err());
    }

    #[test]
    fn test_dirac_3d_identity() {
        let mut p = Parameter::<f32>::zeros(&[4, 4, 3]).unwrap();
        dirac_(&mut p, 1).unwrap();
        let data = p.data().unwrap();
        let center = 1;
        for out_ch in 0..4 {
            for in_ch in 0..4 {
                let val = data[out_ch * 4 * 3 + in_ch * 3 + center];
                if out_ch == in_ch {
                    assert!((val - 1.0).abs() < 1e-6, "diag [{out_ch},{in_ch}] = {val}");
                } else {
                    assert!(val.abs() < 1e-6, "off-diag [{out_ch},{in_ch}] = {val}");
                }
            }
        }
    }

    #[test]
    fn test_dirac_4d_identity() {
        let mut p = Parameter::<f32>::zeros(&[2, 2, 3, 3]).unwrap();
        dirac_(&mut p, 1).unwrap();
        let data = p.data().unwrap();
        let kernel_size = 9;
        let center = 4;
        for out_ch in 0..2 {
            for in_ch in 0..2 {
                let val = data[out_ch * 2 * kernel_size + in_ch * kernel_size + center];
                if out_ch == in_ch {
                    assert!((val - 1.0).abs() < 1e-6);
                } else {
                    assert!(val.abs() < 1e-6);
                }
            }
        }
    }

    #[test]
    fn test_dirac_groups() {
        let mut p = Parameter::<f32>::zeros(&[4, 2, 3]).unwrap();
        dirac_(&mut p, 2).unwrap();
        let data = p.data().unwrap();
        let center = 1;
        let idx = |oc: usize, ic: usize, k: usize| oc * 6 + ic * 3 + k;
        assert!((data[idx(0, 0, center)] - 1.0).abs() < 1e-6);
        assert!((data[idx(1, 1, center)] - 1.0).abs() < 1e-6);
        assert!((data[idx(2, 0, center)] - 1.0).abs() < 1e-6);
        assert!((data[idx(3, 1, center)] - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_dirac_rejects_2d() {
        let mut p = Parameter::<f32>::zeros(&[4, 4]).unwrap();
        assert!(dirac_(&mut p, 1).is_err());
    }

    #[test]
    fn test_eye_square() {
        let mut p = Parameter::<f32>::zeros(&[4, 4]).unwrap();
        eye_(&mut p).unwrap();
        let data = p.data().unwrap();
        for i in 0..4 {
            for j in 0..4 {
                let expected = if i == j { 1.0 } else { 0.0 };
                assert!(
                    (data[i * 4 + j] - expected).abs() < 1e-6,
                    "eye[{i},{j}] = {}",
                    data[i * 4 + j]
                );
            }
        }
    }

    #[test]
    fn test_eye_tall() {
        let mut p = Parameter::<f32>::zeros(&[6, 3]).unwrap();
        eye_(&mut p).unwrap();
        let data = p.data().unwrap();
        for i in 0..6 {
            for j in 0..3 {
                let expected = if i == j { 1.0 } else { 0.0 };
                assert!((data[i * 3 + j] - expected).abs() < 1e-6);
            }
        }
    }

    #[test]
    fn test_eye_wide() {
        let mut p = Parameter::<f32>::zeros(&[3, 6]).unwrap();
        eye_(&mut p).unwrap();
        let data = p.data().unwrap();
        for i in 0..3 {
            for j in 0..6 {
                let expected = if i == j { 1.0 } else { 0.0 };
                assert!((data[i * 6 + j] - expected).abs() < 1e-6);
            }
        }
    }

    #[test]
    fn test_eye_rejects_non_2d() {
        let mut p = Parameter::<f32>::zeros(&[4]).unwrap();
        assert!(eye_(&mut p).is_err());
    }

    #[test]
    fn test_eye_preserves_requires_grad() {
        let mut p = Parameter::<f32>::zeros(&[3, 3]).unwrap();
        eye_(&mut p).unwrap();
        assert!(p.requires_grad());
    }
}