use scirs2_core::random::{quick::random_f32, thread_rng};
use scirs2_core::slice_random::shuffle;
use torsh_core::error::{Result, TorshError};
use torsh_tensor::{creation::*, Tensor};
/// Produces freshly initialized tensors for a requested shape.
///
/// [`InitMethod`] implements this trait by dispatching to the free-function
/// initializers in this module.
pub trait Initializer {
    /// Creates a new tensor of the given `shape`, filled according to this initializer.
    ///
    /// # Errors
    /// Returns an error when the concrete initializer does not support `shape`
    /// or when the tensor cannot be created.
    fn initialize(&self, shape: &[usize]) -> Result<Tensor>;
}
/// Declarative description of a weight-initialization scheme.
///
/// Call [`Initializer::initialize`] on a value of this type to materialize a
/// tensor of a given shape. Fan-in/fan-out are computed by
/// [`calculate_fan_in_fan_out`] (`shape[1]` / `shape[0]` times the product of
/// any trailing dimensions).
#[derive(Debug, Clone)]
pub enum InitMethod {
    /// Xavier/Glorot uniform: `U(-b, b)` with `b = gain * sqrt(6 / (fan_in + fan_out))`.
    XavierUniform { gain: f32 },
    /// Xavier/Glorot normal: `N(0, std)` with `std = gain * sqrt(2 / (fan_in + fan_out))`.
    XavierNormal { gain: f32 },
    /// Kaiming/He uniform: `U(-b, b)` with `b = nonlinearity.gain() * sqrt(3 / fan)`,
    /// where `fan` is selected by `mode`.
    KaimingUniform {
        mode: FanMode,
        nonlinearity: Nonlinearity,
    },
    /// Kaiming/He normal: `N(0, std)` with `std = nonlinearity.gain() / sqrt(fan)`,
    /// where `fan` is selected by `mode`.
    KaimingNormal {
        mode: FanMode,
        nonlinearity: Nonlinearity,
    },
    /// Uniform samples in `[low, high)`; `low` must be strictly less than `high`.
    Uniform { low: f32, high: f32 },
    /// Normal samples with the given mean and (strictly positive) standard deviation.
    Normal { mean: f32, std: f32 },
    /// All elements set to `0`.
    Zeros,
    /// All elements set to `1`.
    Ones,
    /// All elements set to `value`.
    Constant { value: f32 },
    /// (Semi-)orthogonal matrix multiplied by `gain`; requires at least a 2-D shape.
    Orthogonal { gain: f32 },
    /// 2-D only: a `sparsity` fraction of entries is zeroed, the rest are `N(0, std)`.
    Sparse { sparsity: f32, std: f32 },
    /// Identity matrix (see [`eye_init_tensor`]).
    Eye,
    /// LeCun uniform: `U(-b, b)` with `b = sqrt(3 / fan_in)`.
    LecunUniform,
    /// LeCun normal: `N(0, std)` with `std = sqrt(1 / fan_in)`.
    LecunNormal,
    /// Normal samples restricted to the absolute interval `[a, b]` (the bounds are
    /// values, not multiples of `std`).
    TruncatedNormal { mean: f32, std: f32, a: f32, b: f32 },
    /// Keras-style variance scaling with variance `scale / fan` (see [`variance_scaling`]).
    VarianceScaling {
        scale: f32,
        mode: FanMode,
        distribution: Distribution,
    },
    /// Identity-preserving convolution kernel; requires at least a 3-D shape.
    Dirac,
    /// SIREN initialization parameterized by `c` and `w0` (see [`siren_init`]).
    SIREN { c: f32, w0: f32 },
}
/// Sampling distribution used by [`InitMethod::VarianceScaling`] / [`variance_scaling`].
#[derive(Debug, Clone, Copy)]
pub enum Distribution {
    /// Symmetric uniform distribution around zero.
    Uniform,
    /// Zero-mean normal distribution.
    Normal,
    /// Zero-mean normal distribution truncated at two standard deviations.
    TruncatedNormal,
}
/// Which fan (number of connections) to scale an initializer by.
#[derive(Debug, Clone, Copy)]
pub enum FanMode {
    /// `shape[1]` times the receptive-field size (product of trailing dimensions).
    FanIn,
    /// `shape[0]` times the receptive-field size.
    FanOut,
    /// Average of fan-in and fan-out.
    FanAvg,
}
/// Activation function following a layer; determines the gain returned by
/// [`Nonlinearity::gain`] and used by the Kaiming initializers.
#[derive(Debug, Clone, Copy)]
pub enum Nonlinearity {
    ReLU,
    /// Leaky ReLU; `negative_slope` is the slope applied to negative inputs.
    LeakyReLU { negative_slope: f32 },
    Tanh,
    Sigmoid,
    SELU,
    ELU,
    Swish,
    /// Identity activation (no nonlinearity).
    Linear,
}
impl Nonlinearity {
    /// Returns the recommended gain (multiplier on the standard deviation) for
    /// this nonlinearity, as consumed by the Kaiming initializers.
    ///
    /// The values for `Linear`, `Sigmoid`, `Tanh`, `ReLU`, `LeakyReLU` and `SELU`
    /// follow PyTorch's `torch.nn.init.calculate_gain`. `ELU` and `Swish` have no
    /// PyTorch entry; the values used for them here are heuristics.
    pub fn gain(&self) -> f32 {
        match self {
            Nonlinearity::Linear | Nonlinearity::Sigmoid => 1.0,
            // PyTorch uses 5/3 for tanh itself, not its square root.
            Nonlinearity::Tanh => 5.0 / 3.0,
            Nonlinearity::ReLU => 2.0_f32.sqrt(),
            Nonlinearity::LeakyReLU { negative_slope } => {
                (2.0 / (1.0 + negative_slope.powi(2))).sqrt()
            }
            // PyTorch uses 3/4 for SELU itself, not its square root.
            Nonlinearity::SELU => 3.0 / 4.0,
            // Heuristic value (not part of PyTorch's gain table).
            Nonlinearity::ELU => (5.0_f32 / 3.0_f32).sqrt(),
            // Heuristic: treated like ReLU.
            Nonlinearity::Swish => 2.0_f32.sqrt(),
        }
    }
}
impl InitMethod {
pub fn xavier_uniform() -> Self {
InitMethod::XavierUniform { gain: 1.0 }
}
pub fn xavier_normal() -> Self {
InitMethod::XavierNormal { gain: 1.0 }
}
pub fn kaiming_uniform() -> Self {
InitMethod::KaimingUniform {
mode: FanMode::FanIn,
nonlinearity: Nonlinearity::ReLU,
}
}
pub fn kaiming_normal() -> Self {
InitMethod::KaimingNormal {
mode: FanMode::FanIn,
nonlinearity: Nonlinearity::ReLU,
}
}
pub fn uniform_range(low: f32, high: f32) -> Self {
InitMethod::Uniform { low, high }
}
pub fn normal_dist(mean: f32, std: f32) -> Self {
InitMethod::Normal { mean, std }
}
pub fn zeros() -> Self {
InitMethod::Zeros
}
pub fn ones() -> Self {
InitMethod::Ones
}
pub fn constant(value: f32) -> Self {
InitMethod::Constant { value }
}
pub fn orthogonal() -> Self {
InitMethod::Orthogonal { gain: 1.0 }
}
pub fn lecun_uniform() -> Self {
InitMethod::LecunUniform
}
pub fn lecun_normal() -> Self {
InitMethod::LecunNormal
}
pub fn dirac() -> Self {
InitMethod::Dirac
}
pub fn siren_first_layer() -> Self {
InitMethod::SIREN { c: 1.0, w0: 30.0 }
}
pub fn siren_hidden_layer() -> Self {
InitMethod::SIREN { c: 6.0, w0: 1.0 }
}
pub fn with_gain(self, gain: f32) -> Self {
match self {
InitMethod::XavierUniform { .. } => InitMethod::XavierUniform { gain },
InitMethod::XavierNormal { .. } => InitMethod::XavierNormal { gain },
InitMethod::Orthogonal { .. } => InitMethod::Orthogonal { gain },
other => other,
}
}
pub fn with_fan_mode(self, mode: FanMode) -> Self {
match self {
InitMethod::KaimingUniform {
nonlinearity,
mode: _,
} => InitMethod::KaimingUniform { mode, nonlinearity },
InitMethod::KaimingNormal {
nonlinearity,
mode: _,
} => InitMethod::KaimingNormal { mode, nonlinearity },
InitMethod::VarianceScaling {
scale,
distribution,
mode: _,
} => InitMethod::VarianceScaling {
scale,
mode,
distribution,
},
other => other,
}
}
pub fn with_nonlinearity(self, nonlinearity: Nonlinearity) -> Self {
match self {
InitMethod::KaimingUniform { mode, .. } => {
InitMethod::KaimingUniform { mode, nonlinearity }
}
InitMethod::KaimingNormal { mode, .. } => {
InitMethod::KaimingNormal { mode, nonlinearity }
}
other => other,
}
}
pub fn name(&self) -> &str {
match self {
InitMethod::XavierUniform { .. } => "Xavier Uniform",
InitMethod::XavierNormal { .. } => "Xavier Normal",
InitMethod::KaimingUniform { .. } => "Kaiming Uniform",
InitMethod::KaimingNormal { .. } => "Kaiming Normal",
InitMethod::Uniform { .. } => "Uniform",
InitMethod::Normal { .. } => "Normal",
InitMethod::Zeros => "Zeros",
InitMethod::Ones => "Ones",
InitMethod::Constant { .. } => "Constant",
InitMethod::Orthogonal { .. } => "Orthogonal",
InitMethod::Sparse { .. } => "Sparse",
InitMethod::Eye => "Eye/Identity",
InitMethod::LecunUniform => "LeCun Uniform",
InitMethod::LecunNormal => "LeCun Normal",
InitMethod::TruncatedNormal { .. } => "Truncated Normal",
InitMethod::VarianceScaling { .. } => "Variance Scaling",
InitMethod::Dirac => "Dirac",
InitMethod::SIREN { .. } => "SIREN",
}
}
}
impl Initializer for InitMethod {
fn initialize(&self, shape: &[usize]) -> Result<Tensor> {
match self {
InitMethod::XavierUniform { gain } => xavier_uniform_with_gain(shape, *gain),
InitMethod::XavierNormal { gain } => xavier_normal_with_gain(shape, *gain),
InitMethod::KaimingUniform { mode, nonlinearity } => {
kaiming_uniform_with_nonlinearity(shape, *mode, *nonlinearity)
}
InitMethod::KaimingNormal { mode, nonlinearity } => {
kaiming_normal_with_nonlinearity(shape, *mode, *nonlinearity)
}
InitMethod::Uniform { low, high } => uniform(shape, *low, *high),
InitMethod::Normal { mean, std } => normal(shape, *mean, *std),
InitMethod::Zeros => zeros(shape),
InitMethod::Ones => ones(shape),
InitMethod::Constant { value } => constant(shape, *value),
InitMethod::Orthogonal { gain } => orthogonal_init(shape, *gain),
InitMethod::Sparse { sparsity, std } => sparse_init(shape, *sparsity, *std),
InitMethod::Eye => eye_init_tensor(shape),
InitMethod::LecunUniform => lecun_uniform(shape),
InitMethod::LecunNormal => lecun_normal(shape),
InitMethod::TruncatedNormal { mean, std, a, b } => {
truncated_normal(shape, *mean, *std, *a, *b)
}
InitMethod::VarianceScaling {
scale,
mode,
distribution,
} => variance_scaling(shape, *scale, *mode, *distribution),
InitMethod::Dirac => dirac_init(shape),
InitMethod::SIREN { c, w0 } => siren_init(shape, *c, *w0),
}
}
}
/// Creates a tensor of `shape` with every element equal to `value`.
///
/// # Errors
/// Returns `TorshError::RuntimeError` if the tensor cannot be constructed.
pub fn constant(shape: &[usize], value: f32) -> Result<Tensor> {
    let element_count: usize = shape.iter().product();
    let filled: Vec<f32> = std::iter::repeat(value).take(element_count).collect();
    Tensor::from_vec(filled, shape).map_err(|err| {
        TorshError::RuntimeError(format!("Failed to create constant tensor: {}", err))
    })
}
/// Wraps an [`InitMethod`] as an opaque `impl Initializer`.
///
/// This is an identity conversion: the returned initializer is `method` itself.
pub fn init(method: InitMethod) -> impl Initializer {
    method
}
/// Computes `(fan_in, fan_out)` for a weight of the given shape.
///
/// `shape[0]` is the number of output feature maps and `shape[1]` the number
/// of input feature maps; any trailing dimensions form the receptive field
/// (e.g. a convolution kernel) and multiply both fans.
///
/// # Errors
/// Returns `TorshError::InvalidArgument` for shapes with fewer than 2 dimensions.
pub fn calculate_fan_in_fan_out(shape: &[usize]) -> Result<(usize, usize)> {
    if shape.len() < 2 {
        return Err(TorshError::InvalidArgument(
            "Fan in and fan out can not be computed for tensor with fewer than 2 dimensions"
                .to_string(),
        ));
    }
    // Empty product (plain 2-D matrix) is 1.
    let receptive_field: usize = shape[2..].iter().product();
    let (out_maps, in_maps) = (shape[0], shape[1]);
    Ok((in_maps * receptive_field, out_maps * receptive_field))
}
/// Returns the fan selected by `mode` for a weight of the given shape.
///
/// # Errors
/// Propagates the error from [`calculate_fan_in_fan_out`] for shapes with
/// fewer than 2 dimensions.
pub fn calculate_fan(shape: &[usize], mode: FanMode) -> Result<usize> {
    calculate_fan_in_fan_out(shape).map(|(fan_in, fan_out)| match mode {
        FanMode::FanIn => fan_in,
        FanMode::FanOut => fan_out,
        // Integer mean: truncates when fan_in + fan_out is odd.
        FanMode::FanAvg => (fan_in + fan_out) / 2,
    })
}
/// Xavier/Glorot uniform initialization with unit gain.
///
/// # Errors
/// See [`xavier_uniform_with_gain`].
pub fn xavier_uniform(shape: &[usize]) -> Result<Tensor> {
    const UNIT_GAIN: f32 = 1.0;
    xavier_uniform_with_gain(shape, UNIT_GAIN)
}
/// Xavier/Glorot uniform initialization: samples `U(-b, b)` where
/// `b = gain * sqrt(2 / (fan_in + fan_out)) * sqrt(3)`.
///
/// # Errors
/// Fails for shapes with fewer than 2 dimensions, or if [`uniform`] fails.
pub fn xavier_uniform_with_gain(shape: &[usize], gain: f32) -> Result<Tensor> {
    let (fan_in, fan_out) = calculate_fan_in_fan_out(shape)?;
    let fan_sum = (fan_in + fan_out) as f32;
    let std = gain * (2.0 / fan_sum).sqrt();
    // U(-b, b) has standard deviation b / sqrt(3).
    let half_width = std * 3.0_f32.sqrt();
    uniform(shape, -half_width, half_width)
}
/// Xavier/Glorot normal initialization with unit gain.
///
/// # Errors
/// See [`xavier_normal_with_gain`].
pub fn xavier_normal(shape: &[usize]) -> Result<Tensor> {
    const UNIT_GAIN: f32 = 1.0;
    xavier_normal_with_gain(shape, UNIT_GAIN)
}
/// Xavier/Glorot normal initialization: samples `N(0, std)` where
/// `std = gain * sqrt(2 / (fan_in + fan_out))`.
///
/// # Errors
/// Fails for shapes with fewer than 2 dimensions, or if [`normal`] fails.
pub fn xavier_normal_with_gain(shape: &[usize], gain: f32) -> Result<Tensor> {
    let (fan_in, fan_out) = calculate_fan_in_fan_out(shape)?;
    let fan_sum = (fan_in + fan_out) as f32;
    normal(shape, 0.0, gain * (2.0 / fan_sum).sqrt())
}
/// Kaiming/He uniform initialization with the ReLU gain.
///
/// `mode` must be one of `"fan_in"`, `"fan_out"` or `"fan_avg"`.
///
/// # Errors
/// Returns `TorshError::InvalidArgument` for an unknown `mode`; otherwise see
/// [`kaiming_uniform_with_nonlinearity`].
pub fn kaiming_uniform(shape: &[usize], mode: &str) -> Result<Tensor> {
    let fan_mode = if mode == "fan_in" {
        FanMode::FanIn
    } else if mode == "fan_out" {
        FanMode::FanOut
    } else if mode == "fan_avg" {
        FanMode::FanAvg
    } else {
        return Err(TorshError::InvalidArgument(format!(
            "Mode {} not supported, please use one of 'fan_in', 'fan_out', or 'fan_avg'.",
            mode
        )));
    };
    kaiming_uniform_with_nonlinearity(shape, fan_mode, Nonlinearity::ReLU)
}
/// Kaiming/He uniform initialization: samples `U(-b, b)` with
/// `b = nonlinearity.gain() / sqrt(fan) * sqrt(3)`, where `fan` is chosen by `mode`.
///
/// # Errors
/// Fails for shapes with fewer than 2 dimensions, or if [`uniform`] fails.
pub fn kaiming_uniform_with_nonlinearity(
    shape: &[usize],
    mode: FanMode,
    nonlinearity: Nonlinearity,
) -> Result<Tensor> {
    let fan = calculate_fan(shape, mode)? as f32;
    let std = nonlinearity.gain() / fan.sqrt();
    let half_width = std * 3.0_f32.sqrt();
    uniform(shape, -half_width, half_width)
}
/// Kaiming/He normal initialization with the ReLU gain.
///
/// `mode` must be one of `"fan_in"`, `"fan_out"` or `"fan_avg"`.
///
/// # Errors
/// Returns `TorshError::InvalidArgument` for an unknown `mode`; otherwise see
/// [`kaiming_normal_with_nonlinearity`].
pub fn kaiming_normal(shape: &[usize], mode: &str) -> Result<Tensor> {
    let parsed = match mode {
        "fan_in" => Some(FanMode::FanIn),
        "fan_out" => Some(FanMode::FanOut),
        "fan_avg" => Some(FanMode::FanAvg),
        _ => None,
    };
    let fan_mode = parsed.ok_or_else(|| {
        TorshError::InvalidArgument(format!(
            "Mode {} not supported, please use one of 'fan_in', 'fan_out', or 'fan_avg'.",
            mode
        ))
    })?;
    kaiming_normal_with_nonlinearity(shape, fan_mode, Nonlinearity::ReLU)
}
/// Kaiming/He normal initialization: samples `N(0, std)` with
/// `std = nonlinearity.gain() / sqrt(fan)`, where `fan` is chosen by `mode`.
///
/// # Errors
/// Fails for shapes with fewer than 2 dimensions, or if [`normal`] fails.
pub fn kaiming_normal_with_nonlinearity(
    shape: &[usize],
    mode: FanMode,
    nonlinearity: Nonlinearity,
) -> Result<Tensor> {
    let fan = calculate_fan(shape, mode)? as f32;
    normal(shape, 0.0, nonlinearity.gain() / fan.sqrt())
}
/// Creates a tensor of `shape` with samples `low + u * (high - low)`, where
/// `u` comes from `random_f32()`.
///
/// # Errors
/// Returns `TorshError::InvalidArgument` when `low >= high`, and
/// `TorshError::RuntimeError` if the tensor cannot be constructed.
pub fn uniform(shape: &[usize], low: f32, high: f32) -> Result<Tensor> {
    if low >= high {
        return Err(TorshError::InvalidArgument(
            "Low bound must be less than high bound for uniform initialization".to_string(),
        ));
    }
    let span = high - low;
    let count: usize = shape.iter().product();
    let mut samples = Vec::with_capacity(count);
    for _ in 0..count {
        samples.push(low + random_f32() * span);
    }
    Tensor::from_vec(samples, shape).map_err(|err| {
        TorshError::RuntimeError(format!("Failed to create uniform tensor: {}", err))
    })
}
/// Creates a tensor of `shape` with samples from `N(mean, std)` using the
/// Box-Muller transform (only the cosine branch is used).
///
/// # Errors
/// Returns `TorshError::InvalidArgument` when `std` is not strictly positive
/// (including NaN), and `TorshError::RuntimeError` if the tensor cannot be
/// constructed.
pub fn normal(shape: &[usize], mean: f32, std: f32) -> Result<Tensor> {
    // `!(std > 0.0)` also rejects NaN, which `std <= 0.0` would let through.
    if !(std > 0.0) {
        return Err(TorshError::InvalidArgument(
            "Standard deviation must be positive for normal initialization".to_string(),
        ));
    }
    let size: usize = shape.iter().product();
    let values: Vec<f32> = (0..size)
        .map(|_| {
            // Clamp away from 0: ln(0) = -inf would produce an infinite/NaN sample.
            let u1 = random_f32().max(f32::MIN_POSITIVE);
            let u2 = random_f32();
            let z0 = (-2.0 * u1.ln()).sqrt() * (std::f32::consts::TAU * u2).cos();
            mean + z0 * std
        })
        .collect();
    Tensor::from_vec(values, shape)
        .map_err(|e| TorshError::RuntimeError(format!("Failed to create normal tensor: {}", e)))
}
/// LeCun uniform initialization: samples `U(-b, b)` with `b = sqrt(3 / fan_in)`.
///
/// # Errors
/// Fails for shapes with fewer than 2 dimensions, or if [`uniform`] fails.
pub fn lecun_uniform(shape: &[usize]) -> Result<Tensor> {
    let fan_in = calculate_fan(shape, FanMode::FanIn)? as f32;
    let half_width = (3.0 / fan_in).sqrt();
    uniform(shape, -half_width, half_width)
}
/// LeCun normal initialization: samples `N(0, sqrt(1 / fan_in))`.
///
/// # Errors
/// Fails for shapes with fewer than 2 dimensions, or if [`normal`] fails.
pub fn lecun_normal(shape: &[usize]) -> Result<Tensor> {
    let inverse_fan_in = 1.0 / calculate_fan(shape, FanMode::FanIn)? as f32;
    normal(shape, 0.0, inverse_fan_in.sqrt())
}
/// Creates a tensor of `shape` with samples from `N(mean, std)` restricted to
/// the absolute interval `[a, b]`, via rejection sampling on Box-Muller draws.
///
/// # Errors
/// - `TorshError::InvalidArgument` when `std` is not strictly positive or
///   `a < b` does not hold (both checks also reject NaN, which previously led
///   to an endless rejection loop).
/// - `TorshError::RuntimeError` when no accepted sample is found within
///   `MAX_ATTEMPTS_PER_ELEMENT` draws (interval far in the tail), or if the
///   tensor cannot be constructed.
pub fn truncated_normal(shape: &[usize], mean: f32, std: f32, a: f32, b: f32) -> Result<Tensor> {
    // Caps rejection sampling so a vanishingly small acceptance region errors out
    // instead of hanging.
    const MAX_ATTEMPTS_PER_ELEMENT: usize = 1_000_000;

    if !(std > 0.0) {
        return Err(TorshError::InvalidArgument(
            "Standard deviation must be positive for truncated normal initialization".to_string(),
        ));
    }
    if !(a < b) {
        return Err(TorshError::InvalidArgument(
            "Lower bound must be less than upper bound for truncated normal initialization"
                .to_string(),
        ));
    }
    let size: usize = shape.iter().product();
    let mut values = Vec::with_capacity(size);
    for _ in 0..size {
        let sample = (0..MAX_ATTEMPTS_PER_ELEMENT)
            .map(|_| {
                // Clamp away from 0: ln(0) = -inf would produce an infinite/NaN draw.
                let u1 = random_f32().max(f32::MIN_POSITIVE);
                let u2 = random_f32();
                let z0 = (-2.0 * u1.ln()).sqrt() * (std::f32::consts::TAU * u2).cos();
                mean + z0 * std
            })
            .find(|candidate| (a..=b).contains(candidate))
            .ok_or_else(|| {
                TorshError::RuntimeError(format!(
                    "Truncated normal sampling found no value in [{}, {}] after {} attempts (mean = {}, std = {})",
                    a, b, MAX_ATTEMPTS_PER_ELEMENT, mean, std
                ))
            })?;
        values.push(sample);
    }
    Tensor::from_vec(values, shape).map_err(|e| {
        TorshError::RuntimeError(format!("Failed to create truncated normal tensor: {}", e))
    })
}
/// Creates an `n x n` identity tensor.
///
/// # Errors
/// Returns `TorshError::RuntimeError` wrapping any failure from `eye`.
pub fn eye_init(n: usize) -> Result<Tensor> {
    match eye(n) {
        Ok(identity) => Ok(identity),
        Err(e) => Err(TorshError::RuntimeError(format!(
            "Failed to create eye tensor: {}",
            e
        ))),
    }
}
pub fn eye_init_tensor(shape: &[usize]) -> Result<Tensor> {
if shape.len() < 2 {
return Err(TorshError::InvalidArgument(
"Eye initialization requires at least 2D tensor".to_string(),
));
}
let rows = shape[0];
let cols = shape[1];
if rows != cols {
return Err(TorshError::InvalidArgument(
"Eye initialization requires square matrices (rows == cols)".to_string(),
));
}
eye_init(rows)
}
/// Fills a (semi-)orthogonal matrix, scaled by `gain`, following PyTorch's
/// `orthogonal_`.
///
/// The first dimension is the number of rows and all remaining dimensions are
/// flattened into columns, so a convolution kernel `(out, in, k...)` is treated
/// as an `out x (in * k...)` matrix and the result keeps the original `shape`.
/// When `rows < cols` the rows are orthonormal; otherwise the columns are.
///
/// # Errors
/// Returns `TorshError::InvalidArgument` for shapes with fewer than 2
/// dimensions; propagates errors from sampling, QR decomposition and tensor
/// construction.
pub fn orthogonal_init(shape: &[usize], gain: f32) -> Result<Tensor> {
    if shape.len() < 2 {
        return Err(TorshError::InvalidArgument(
            "Orthogonal initialization requires at least 2D tensor".to_string(),
        ));
    }
    let num_rows = shape[0];
    let num_cols: usize = shape[1..].iter().product();
    // Nothing to orthogonalize; avoid a QR decomposition of an empty matrix.
    if num_rows == 0 || num_cols == 0 {
        return Tensor::from_vec(Vec::<f32>::new(), shape);
    }

    // QR is performed on a tall matrix; for wide targets factor the transpose.
    let transposed = num_rows < num_cols;
    let (qr_rows, qr_cols) = if transposed {
        (num_cols, num_rows)
    } else {
        (num_rows, num_cols)
    };
    let random_tensor = normal(&[qr_rows, qr_cols], 0.0, 1.0)?;
    let (q, r) = torsh_linalg::decomposition::qr(&random_tensor)?;

    // Multiply column j of Q by sign(R[j, j]) so Q is uniformly distributed
    // (as PyTorch does); a zero diagonal entry keeps the column as-is.
    let mut signs = Vec::with_capacity(qr_cols);
    for j in 0..qr_cols {
        let diag: f32 = r.get(&[j, j])?;
        signs.push(if diag < 0.0 { -1.0 } else { 1.0 });
    }

    let mut values = Vec::with_capacity(num_rows * num_cols);
    for row in 0..num_rows {
        for col in 0..num_cols {
            // Read Q^T for wide targets (a true transpose, not a reshape).
            let (i, j) = if transposed { (col, row) } else { (row, col) };
            let entry: f32 = q.get(&[i, j])?;
            values.push(entry * signs[j] * gain);
        }
    }
    Tensor::from_vec(values, shape)
}
/// Creates a 2-D tensor where a `sparsity` fraction of all entries (chosen
/// uniformly at random over the whole matrix) is zero and the rest are drawn
/// from `N(0, std)` via Box-Muller.
///
/// NOTE(review): PyTorch's `sparse_` zeroes `ceil(sparsity * rows)` entries in
/// *each column*; this function applies the sparsity globally instead.
///
/// # Errors
/// Returns `TorshError::InvalidArgument` if `shape` is not 2-D or `sparsity`
/// is outside `[0, 1]`, and `TorshError::RuntimeError` if the tensor cannot be
/// constructed.
pub fn sparse_init(shape: &[usize], sparsity: f32, std: f32) -> Result<Tensor> {
    if shape.len() != 2 {
        return Err(TorshError::InvalidArgument(
            "Only tensors with 2 dimensions are supported for sparse initialization".to_string(),
        ));
    }
    if !(0.0..=1.0).contains(&sparsity) {
        return Err(TorshError::InvalidArgument(
            "Sparsity must be between 0.0 and 1.0".to_string(),
        ));
    }
    let rows = shape[0];
    let cols = shape[1];
    let total_elements = rows * cols;
    let num_zeros = (total_elements as f32 * sparsity) as usize;
    // NOTE(review): unused; kept only because `thread_rng` is imported at file level.
    let _rng = thread_rng();
    let mut values: Vec<f32> = (0..total_elements)
        .map(|_| {
            // Clamp away from 0: ln(0) = -inf would produce an infinite/NaN sample.
            let u1 = random_f32().max(f32::MIN_POSITIVE);
            let u2 = random_f32();
            let z0 = (-2.0 * u1.ln()).sqrt() * (std::f32::consts::TAU * u2).cos();
            z0 * std
        })
        .collect();
    let mut indices: Vec<usize> = (0..total_elements).collect();
    shuffle(&mut indices);
    for &idx in indices.iter().take(num_zeros) {
        values[idx] = 0.0;
    }
    Tensor::from_vec(values, shape)
        .map_err(|e| TorshError::RuntimeError(format!("Failed to create sparse tensor: {}", e)))
}
/// Re-initializes `tensor` in place using a method selected by name.
///
/// `gain` (default `1.0`) is used by the Xavier and orthogonal methods; `mode`
/// (default `"fan_in"`) is used by the Kaiming methods. Aliases:
/// `glorot_*` for `xavier_*`, `he_*` for `kaiming_*`.
///
/// # Errors
/// Returns `TorshError::InvalidArgument` for an unknown `method`, and
/// propagates any error from the chosen initializer (in which case `tensor`
/// is left unchanged).
pub fn init_tensor(
    tensor: &mut Tensor,
    method: &str,
    gain: Option<f32>,
    mode: Option<&str>,
) -> Result<()> {
    let gain = gain.unwrap_or(1.0);
    let mode = mode.unwrap_or("fan_in");
    let current_shape = tensor.shape();
    let dims = current_shape.dims();
    let fresh = match method {
        "xavier_uniform" | "glorot_uniform" => xavier_uniform_with_gain(dims, gain),
        "xavier_normal" | "glorot_normal" => xavier_normal_with_gain(dims, gain),
        "kaiming_uniform" | "he_uniform" => kaiming_uniform(dims, mode),
        "kaiming_normal" | "he_normal" => kaiming_normal(dims, mode),
        "orthogonal" => orthogonal_init(dims, gain),
        "lecun_uniform" => lecun_uniform(dims),
        "lecun_normal" => lecun_normal(dims),
        "zeros" => zeros(dims),
        "ones" => ones(dims),
        "eye" => eye_init_tensor(dims),
        unknown => {
            return Err(TorshError::InvalidArgument(format!(
                "Unknown initialization method: {}",
                unknown
            )))
        }
    }?;
    *tensor = fresh;
    Ok(())
}
/// Types that can re-initialize their own parameters.
///
/// No implementation is visible in this module; the exact initialization
/// scheme is up to each implementor.
pub trait Initializable {
    /// Resets the parameters of `self` (it takes `&mut self` and returns nothing).
    fn reset_parameters(&mut self);
}
pub fn variance_scaling(
shape: &[usize],
scale: f32,
mode: FanMode,
distribution: Distribution,
) -> Result<Tensor> {
let fan = calculate_fan(shape, mode)?;
let variance = scale / fan as f32;
match distribution {
Distribution::Uniform => {
let limit = (3.0 * variance).sqrt();
uniform(shape, -limit, limit)
}
Distribution::Normal => {
let std = variance.sqrt();
normal(shape, 0.0, std)
}
Distribution::TruncatedNormal => {
let std = variance.sqrt();
truncated_normal(shape, 0.0, std, -2.0 * std, 2.0 * std)
}
}
}
/// Dirac initialization for convolution kernels `(out, in, spatial...)`: sets
/// `w[i, i, centre] = 1` for `i < min(out, in)` and zero everywhere else, so the
/// layer initially passes the first `min(out, in)` channels through unchanged.
///
/// # Errors
/// Returns `TorshError::InvalidArgument` for shapes with fewer than 3
/// dimensions, and `TorshError::RuntimeError` if the tensor cannot be
/// constructed.
pub fn dirac_init(shape: &[usize]) -> Result<Tensor> {
    let (out_channels, in_channels, spatial) = match shape {
        [out, inp, spatial @ ..] if !spatial.is_empty() => (*out, *inp, spatial),
        _ => {
            return Err(TorshError::InvalidArgument(
                "Dirac initialization requires at least 3D tensor (out_channels, in_channels, kernel_size)".to_string(),
            ))
        }
    };
    let total_size: usize = shape.iter().product();
    let kernel_spatial_size: usize = spatial.iter().product();
    // Row-major flat offset of the kernel centre (Horner's scheme over spatial dims).
    let center_offset = spatial
        .iter()
        .fold(0_usize, |acc, &dim| acc * dim + dim / 2);

    let mut values = vec![0.0_f32; total_size];
    for channel in 0..out_channels.min(in_channels) {
        let idx = channel * in_channels * kernel_spatial_size
            + channel * kernel_spatial_size
            + center_offset;
        // Out of range only when a spatial dimension is zero (empty tensor).
        if let Some(slot) = values.get_mut(idx) {
            *slot = 1.0;
        }
    }
    Tensor::from_vec(values, shape)
        .map_err(|e| TorshError::RuntimeError(format!("Failed to create Dirac tensor: {}", e)))
}
/// SIREN-style uniform initialization.
///
/// When `w0` is within `1e-6` of 1, weights are `U(-sqrt(c / fan_in),
/// sqrt(c / fan_in))`. Otherwise weights are `U(-1 / fan_in, 1 / fan_in)`
/// multiplied by `w0`; `c` is not used in that case.
///
/// # Errors
/// Returns `TorshError::InvalidArgument` for shapes with fewer than 2
/// dimensions; propagates sampling and tensor-construction errors.
pub fn siren_init(shape: &[usize], c: f32, w0: f32) -> Result<Tensor> {
    if shape.len() < 2 {
        return Err(TorshError::InvalidArgument(
            "SIREN initialization requires at least 2D tensor".to_string(),
        ));
    }
    let fan_in = calculate_fan(shape, FanMode::FanIn)? as f32;
    let w0_offset = (w0 - 1.0).abs();
    let bound = if w0_offset < 1e-6 {
        (c / fan_in).sqrt()
    } else {
        1.0 / fan_in
    };
    let base = uniform(shape, -bound, bound)?;
    if w0_offset > 1e-6 {
        let scaled: Vec<f32> = base.to_vec()?.into_iter().map(|v| v * w0).collect();
        Tensor::from_vec(scaled, shape)
    } else {
        Ok(base)
    }
}
/// Fixup-style initialization: Kaiming normal (fan-in, ReLU gain), optionally
/// rescaled for residual branches.
///
/// When `is_residual_branch` is true and `num_residual_blocks > 1`, weights are
/// multiplied by `(2 * num_layers)^(-1 / (2 * num_residual_blocks - 2))`.
///
/// NOTE(review): the Fixup paper rescales residual-branch weights by
/// `L^(-1 / (2m - 2))`, with `L` = number of residual branches and `m` =
/// layers per branch. Here the exponent uses `num_residual_blocks` and the base
/// uses `2 * num_layers`; confirm this parameter mapping is intended.
///
/// # Errors
/// Propagates errors from the Kaiming sampler and tensor construction.
pub fn fixup_init(
    shape: &[usize],
    num_layers: usize,
    num_residual_blocks: usize,
    is_residual_branch: bool,
) -> Result<Tensor> {
    let base = kaiming_normal_with_nonlinearity(shape, FanMode::FanIn, Nonlinearity::ReLU)?;
    let needs_rescale = is_residual_branch && num_residual_blocks > 1;
    if !needs_rescale {
        return Ok(base);
    }
    let exponent = -1.0 / (2.0 * num_residual_blocks as f32 - 2.0);
    let factor = (2.0 * num_layers as f32).powf(exponent);
    let scaled: Vec<f32> = base.to_vec()?.into_iter().map(|w| w * factor).collect();
    Tensor::from_vec(scaled, shape)
        .map_err(|e| TorshError::RuntimeError(format!("Failed to create Fixup tensor: {}", e)))
}
/// Weight initialization for ReZero layers: Kaiming normal with fan-in and the
/// ReLU gain. Pair with [`rezero_alpha_init`] for the residual scale.
///
/// # Errors
/// See [`kaiming_normal_with_nonlinearity`].
pub fn rezero_init(shape: &[usize]) -> Result<Tensor> {
    let (mode, activation) = (FanMode::FanIn, Nonlinearity::ReLU);
    kaiming_normal_with_nonlinearity(shape, mode, activation)
}
/// Creates the ReZero residual scale: a single-element tensor holding `0.0`.
///
/// # Errors
/// Returns `TorshError::RuntimeError` if the tensor cannot be constructed.
pub fn rezero_alpha_init() -> Result<Tensor> {
    let alpha = vec![0.0_f32; 1];
    Tensor::from_vec(alpha, &[1]).map_err(|err| {
        TorshError::RuntimeError(format!("Failed to create ReZero alpha: {}", err))
    })
}
pub fn delta_orthogonal_init(shape: &[usize], gain: f32) -> Result<Tensor> {
if shape.len() < 2 {
return Err(TorshError::InvalidArgument(
"Delta-Orthogonal initialization requires at least 2D tensor".to_string(),
));
}
orthogonal_init(shape, gain)
}
/// Sparse symmetric-uniform initialization.
///
/// Each element is zero with probability `sparsity`; otherwise it is
/// `±scale * u` with a random sign and `u` from `random_f32()`.
///
/// NOTE(review): despite the name, this does not run the gradient-based
/// MetaInit procedure; it only draws from the distribution described above.
///
/// # Errors
/// Returns `TorshError::InvalidArgument` when `sparsity` is outside `[0, 1)`
/// or `scale <= 0`, and `TorshError::RuntimeError` if the tensor cannot be
/// constructed.
pub fn metainit(shape: &[usize], sparsity: f32, scale: f32) -> Result<Tensor> {
    let sparsity_out_of_range = sparsity < 0.0 || sparsity >= 1.0;
    if sparsity_out_of_range {
        return Err(TorshError::InvalidArgument(format!(
            "Sparsity must be in [0, 1), got {}",
            sparsity
        )));
    }
    if scale <= 0.0 {
        return Err(TorshError::InvalidArgument(format!(
            "Scale must be positive, got {}",
            scale
        )));
    }
    let count: usize = shape.iter().product();
    let values: Vec<f32> = (0..count)
        .map(|_| {
            if random_f32() < sparsity {
                return 0.0;
            }
            // Sign is drawn before the magnitude, matching the RNG call order.
            let sign = if random_f32() < 0.5 { -1.0 } else { 1.0 };
            sign * scale * random_f32()
        })
        .collect();
    Tensor::from_vec(values, shape)
        .map_err(|e| TorshError::RuntimeError(format!("Failed to create MetaInit tensor: {}", e)))
}
/// Pre-initialization step of LSUV: an orthogonal matrix with unit gain.
///
/// NOTE(review): this performs only the orthogonal step; the data-driven
/// per-layer variance normalization of LSUV is not done here.
///
/// # Errors
/// See [`orthogonal_init`].
pub fn lsuv_init(shape: &[usize]) -> Result<Tensor> {
    const UNIT_GAIN: f32 = 1.0;
    orthogonal_init(shape, UNIT_GAIN)
}
/// Draws weights from `N(0, sqrt(target_variance))`.
///
/// # Errors
/// Returns `TorshError::InvalidArgument` when `target_variance <= 0`; otherwise
/// see [`normal`].
pub fn zero_centered_variance_init(shape: &[usize], target_variance: f32) -> Result<Tensor> {
    match target_variance {
        v if v <= 0.0 => Err(TorshError::InvalidArgument(format!(
            "Target variance must be positive, got {}",
            v
        ))),
        v => normal(shape, 0.0, v.sqrt()),
    }
}
pub fn gan_balanced_init(shape: &[usize], is_generator: bool) -> Result<Tensor> {
let gain = if is_generator { 0.5 } else { 1.0 };
let fan_in = calculate_fan(shape, FanMode::FanIn)?;
let fan_out = calculate_fan(shape, FanMode::FanOut)?;
let fan_avg = (fan_in + fan_out) / 2;
let std = gain * (2.0 / fan_avg as f32).sqrt();
normal(shape, 0.0, std)
}
/// Normal initialization for coordinate-based MLPs:
/// `N(0, 1 / sqrt(fan_in * omega_0))`.
///
/// # Errors
/// Returns `TorshError::InvalidArgument` for shapes with fewer than 2
/// dimensions or when `omega_0` is not a positive finite number (zero,
/// negative or NaN values would yield an infinite or NaN standard deviation);
/// otherwise see [`normal`].
pub fn coordinate_mlp_init(shape: &[usize], omega_0: f32) -> Result<Tensor> {
    if shape.len() < 2 {
        return Err(TorshError::InvalidArgument(
            "Coordinate MLP initialization requires at least 2D tensor".to_string(),
        ));
    }
    if !(omega_0 > 0.0 && omega_0.is_finite()) {
        return Err(TorshError::InvalidArgument(format!(
            "omega_0 must be a positive, finite value for coordinate MLP initialization, got {}",
            omega_0
        )));
    }
    let fan_in = calculate_fan(shape, FanMode::FanIn)?;
    let std = 1.0 / (fan_in as f32 * omega_0).sqrt();
    normal(shape, 0.0, std)
}
/// Coarse description of the network family, used by [`auto_init`] and
/// [`recommend_init_method`] to pick an initializer.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ArchitectureHint {
    /// Plain fully connected (MLP) layers.
    Feedforward,
    /// Convolutional layers.
    Convolutional,
    /// Recurrent layers (RNN/LSTM/GRU).
    Recurrent,
    /// Transformer/attention layers.
    Transformer,
    /// Residual networks.
    Residual,
    /// Very deep networks (with ReLU, `auto_init` may apply Fixup scaling).
    VeryDeep,
    /// Generative adversarial networks.
    GAN,
    /// Coordinate-based MLPs (inputs are spatial coordinates).
    CoordinateBased,
    /// Networks with periodic activations.
    Periodic,
    /// Autoencoders (mapped like `Feedforward` in `auto_init`).
    Autoencoder,
}
/// Activation used after the layer being initialized; consumed by
/// [`auto_init`] and [`recommend_init_method`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ActivationHint {
    ReLU,
    Tanh,
    Sigmoid,
    SELU,
    Swish,
    GELU,
    /// Sine activation; always selects SIREN initialization.
    Sine,
    /// Identity activation.
    Linear,
}
pub fn auto_init(
shape: &[usize],
arch: ArchitectureHint,
activation: ActivationHint,
layer_depth: Option<usize>,
) -> Result<Tensor> {
match (arch, activation) {
(ArchitectureHint::Periodic, ActivationHint::Sine) | (_, ActivationHint::Sine) => {
let is_first_layer = layer_depth.unwrap_or(0) == 0;
if is_first_layer {
siren_init(shape, 1.0, 30.0)
} else {
siren_init(shape, 6.0, 1.0)
}
}
(ArchitectureHint::CoordinateBased, _) => coordinate_mlp_init(shape, 1.0),
(ArchitectureHint::VeryDeep, ActivationHint::ReLU) => {
if let Some(depth) = layer_depth {
fixup_init(shape, depth, depth / 2, true)
} else {
kaiming_normal_with_nonlinearity(shape, FanMode::FanIn, Nonlinearity::ReLU)
}
}
(ArchitectureHint::Residual, _) => {
if layer_depth.is_some() {
rezero_init(shape)
} else {
kaiming_normal_with_nonlinearity(shape, FanMode::FanIn, Nonlinearity::ReLU)
}
}
(ArchitectureHint::Recurrent, _) => orthogonal_init(shape, 1.0),
(ArchitectureHint::Transformer, _) => xavier_uniform(shape),
(ArchitectureHint::GAN, _) => {
gan_balanced_init(shape, true)
}
(ArchitectureHint::Convolutional, ActivationHint::ReLU) => {
kaiming_normal_with_nonlinearity(shape, FanMode::FanIn, Nonlinearity::ReLU)
}
(ArchitectureHint::Convolutional, ActivationHint::Tanh) => xavier_normal(shape),
(ArchitectureHint::Convolutional, ActivationHint::Sigmoid) => xavier_normal(shape),
(ArchitectureHint::Convolutional, ActivationHint::SELU) => lecun_normal(shape),
(ArchitectureHint::Convolutional, ActivationHint::Swish)
| (ArchitectureHint::Convolutional, ActivationHint::GELU) => {
kaiming_normal_with_nonlinearity(shape, FanMode::FanIn, Nonlinearity::Swish)
}
(ArchitectureHint::Feedforward, ActivationHint::ReLU)
| (ArchitectureHint::Autoencoder, ActivationHint::ReLU) => {
kaiming_uniform_with_nonlinearity(shape, FanMode::FanIn, Nonlinearity::ReLU)
}
(ArchitectureHint::Feedforward, ActivationHint::Tanh)
| (ArchitectureHint::Autoencoder, ActivationHint::Tanh) => xavier_uniform(shape),
(ArchitectureHint::Feedforward, ActivationHint::Sigmoid)
| (ArchitectureHint::Autoencoder, ActivationHint::Sigmoid) => xavier_uniform(shape),
(ArchitectureHint::Feedforward, ActivationHint::SELU)
| (ArchitectureHint::Autoencoder, ActivationHint::SELU) => lecun_uniform(shape),
(ArchitectureHint::Feedforward, ActivationHint::Swish)
| (ArchitectureHint::Feedforward, ActivationHint::GELU)
| (ArchitectureHint::Autoencoder, ActivationHint::Swish)
| (ArchitectureHint::Autoencoder, ActivationHint::GELU) => {
kaiming_uniform_with_nonlinearity(shape, FanMode::FanIn, Nonlinearity::Swish)
}
(ArchitectureHint::Feedforward, ActivationHint::Linear) | (_, ActivationHint::Linear) => {
xavier_uniform(shape)
}
_ => xavier_uniform(shape),
}
}
pub fn recommend_init_method(
arch: ArchitectureHint,
activation: ActivationHint,
layer_depth: Option<usize>,
) -> InitMethod {
match (arch, activation) {
(ArchitectureHint::Periodic, ActivationHint::Sine) | (_, ActivationHint::Sine) => {
let is_first_layer = layer_depth.unwrap_or(0) == 0;
if is_first_layer {
InitMethod::SIREN { c: 1.0, w0: 30.0 }
} else {
InitMethod::SIREN { c: 6.0, w0: 1.0 }
}
}
(ArchitectureHint::VeryDeep, ActivationHint::ReLU)
| (ArchitectureHint::Residual, ActivationHint::ReLU) => InitMethod::KaimingNormal {
mode: FanMode::FanIn,
nonlinearity: Nonlinearity::ReLU,
},
(ArchitectureHint::Recurrent, _) => InitMethod::Orthogonal { gain: 1.0 },
(ArchitectureHint::Transformer, _) => InitMethod::XavierUniform { gain: 1.0 },
(ArchitectureHint::Convolutional, ActivationHint::ReLU) => InitMethod::KaimingNormal {
mode: FanMode::FanIn,
nonlinearity: Nonlinearity::ReLU,
},
(_, ActivationHint::SELU) => InitMethod::LecunNormal,
(_, ActivationHint::Tanh) | (_, ActivationHint::Sigmoid) => {
InitMethod::XavierUniform { gain: 1.0 }
}
_ => InitMethod::XavierUniform { gain: 1.0 },
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_fan_calculation() {
        let (fan_in, fan_out) = calculate_fan_in_fan_out(&[64, 32, 3, 3]).unwrap();
        assert_eq!(fan_in, 32 * 3 * 3);
        assert_eq!(fan_out, 64 * 3 * 3);
        assert!(calculate_fan_in_fan_out(&[5]).is_err());
        assert_eq!(calculate_fan(&[10, 4], FanMode::FanAvg).unwrap(), 7);
    }

    #[test]
    fn test_sampler_validation() {
        assert!(uniform(&[2, 2], 1.0, 0.0).is_err());
        assert!(uniform(&[2, 2], 1.0, 1.0).is_err());
        assert!(normal(&[2, 2], 0.0, 0.0).is_err());
        assert!(normal(&[2, 2], 0.0, -1.0).is_err());
        assert!(truncated_normal(&[2, 2], 0.0, 1.0, 1.0, -1.0).is_err());
        assert!(truncated_normal(&[2, 2], 0.0, -1.0, -1.0, 1.0).is_err());
    }

    #[test]
    fn test_truncated_normal_respects_bounds() {
        let tensor = truncated_normal(&[8, 8], 0.0, 1.0, -0.5, 0.5).unwrap();
        let values: Vec<f32> = tensor
            .to_vec()
            .expect("tensor to vec conversion should succeed");
        assert!(values.iter().all(|&v| (-0.5..=0.5).contains(&v)));
    }

    #[test]
    fn test_xavier_uniform() {
        let tensor = xavier_uniform(&[10, 5]).unwrap();
        assert_eq!(tensor.shape().dims(), &[10, 5]);
    }

    #[test]
    fn test_init_method_enum() {
        let method = InitMethod::XavierUniform { gain: 1.0 };
        let tensor = method.initialize(&[5, 3]).unwrap();
        assert_eq!(tensor.shape().dims(), &[5, 3]);
    }

    #[test]
    fn test_nonlinearity_gains() {
        assert!((Nonlinearity::ReLU.gain() - (2.0_f32).sqrt()).abs() < 1e-6);
        assert!((Nonlinearity::Linear.gain() - 1.0).abs() < 1e-6);
        let leaky = Nonlinearity::LeakyReLU {
            negative_slope: 0.01,
        };
        assert!((leaky.gain() - (2.0 / (1.0 + 0.01_f32.powi(2))).sqrt()).abs() < 1e-6);
    }

    #[test]
    fn test_sparse_initialization() {
        let tensor = sparse_init(&[10, 10], 0.5, 1.0).unwrap();
        assert_eq!(tensor.shape().dims(), &[10, 10]);
        assert!(sparse_init(&[10, 10], 1.5, 1.0).is_err());
        assert!(sparse_init(&[10, 10], -0.1, 1.0).is_err());
    }

    #[test]
    fn test_variance_scaling() {
        let tensor =
            variance_scaling(&[10, 5], 2.0, FanMode::FanIn, Distribution::Uniform).unwrap();
        assert_eq!(tensor.shape().dims(), &[10, 5]);
        let tensor = variance_scaling(&[10, 5], 2.0, FanMode::FanIn, Distribution::Normal).unwrap();
        assert_eq!(tensor.shape().dims(), &[10, 5]);
        let tensor =
            variance_scaling(&[10, 5], 2.0, FanMode::FanIn, Distribution::TruncatedNormal).unwrap();
        assert_eq!(tensor.shape().dims(), &[10, 5]);
    }

    #[test]
    fn test_dirac_initialization() {
        let tensor = dirac_init(&[16, 16, 3]).unwrap();
        assert_eq!(tensor.shape().dims(), &[16, 16, 3]);
        assert!(dirac_init(&[10, 10]).is_err());
    }

    #[test]
    fn test_siren_initialization() {
        let tensor = siren_init(&[10, 5], 1.0, 30.0).unwrap();
        assert_eq!(tensor.shape().dims(), &[10, 5]);
        let tensor = siren_init(&[10, 5], 6.0, 1.0).unwrap();
        assert_eq!(tensor.shape().dims(), &[10, 5]);
        assert!(siren_init(&[10], 6.0, 1.0).is_err());
    }

    #[test]
    fn test_init_method_builders() {
        let method = InitMethod::xavier_uniform();
        assert_eq!(method.name(), "Xavier Uniform");
        let method = InitMethod::kaiming_normal().with_fan_mode(FanMode::FanOut);
        assert_eq!(method.name(), "Kaiming Normal");
        let method = InitMethod::orthogonal().with_gain(2.0);
        assert_eq!(method.name(), "Orthogonal");
        let method = InitMethod::siren_first_layer();
        assert_eq!(method.name(), "SIREN");
        let method = InitMethod::dirac();
        assert_eq!(method.name(), "Dirac");
    }

    #[test]
    fn test_init_method_enum_variants() {
        let method = InitMethod::VarianceScaling {
            scale: 2.0,
            mode: FanMode::FanIn,
            distribution: Distribution::Normal,
        };
        let tensor = method.initialize(&[10, 5]).unwrap();
        assert_eq!(tensor.shape().dims(), &[10, 5]);
        let method = InitMethod::Dirac;
        let tensor = method.initialize(&[8, 8, 3]).unwrap();
        assert_eq!(tensor.shape().dims(), &[8, 8, 3]);
        let method = InitMethod::SIREN { c: 6.0, w0: 1.0 };
        let tensor = method.initialize(&[10, 5]).unwrap();
        assert_eq!(tensor.shape().dims(), &[10, 5]);
    }

    #[test]
    fn test_fixup_initialization() {
        let tensor = fixup_init(&[10, 10], 50, 10, true).unwrap();
        assert_eq!(tensor.shape().dims(), &[10, 10]);
        let tensor = fixup_init(&[10, 10], 50, 10, false).unwrap();
        assert_eq!(tensor.shape().dims(), &[10, 10]);
        let tensor = fixup_init(&[5, 5], 2, 1, true).unwrap();
        assert_eq!(tensor.shape().dims(), &[5, 5]);
    }

    #[test]
    fn test_rezero_initialization() {
        let tensor = rezero_init(&[10, 5]).unwrap();
        assert_eq!(tensor.shape().dims(), &[10, 5]);
        let alpha = rezero_alpha_init().unwrap();
        assert_eq!(alpha.shape().dims(), &[1]);
        let alpha_val: Vec<f32> = alpha
            .to_vec()
            .expect("tensor to vec conversion should succeed");
        assert_eq!(alpha_val[0], 0.0);
    }

    #[test]
    fn test_delta_orthogonal_initialization() {
        let tensor = delta_orthogonal_init(&[10, 10], 1.0).unwrap();
        assert_eq!(tensor.shape().dims(), &[10, 10]);
        let tensor = delta_orthogonal_init(&[8, 8], 2.0).unwrap();
        assert_eq!(tensor.shape().dims(), &[8, 8]);
        assert!(delta_orthogonal_init(&[10], 1.0).is_err());
    }

    #[test]
    fn test_metainit() {
        let tensor = metainit(&[10, 10], 0.8, 0.05).unwrap();
        assert_eq!(tensor.shape().dims(), &[10, 10]);
        let values: Vec<f32> = tensor
            .to_vec()
            .expect("tensor to vec conversion should succeed");
        let zero_count = values.iter().filter(|&&v| v == 0.0).count();
        let sparsity_ratio = zero_count as f32 / values.len() as f32;
        assert!(sparsity_ratio > 0.6 && sparsity_ratio < 0.95);
        assert!(metainit(&[10, 10], 1.5, 0.05).is_err());
        assert!(metainit(&[10, 10], -0.1, 0.05).is_err());
        assert!(metainit(&[10, 10], 0.8, -0.05).is_err());
    }

    #[test]
    fn test_lsuv_initialization() {
        let tensor = lsuv_init(&[10, 10]).unwrap();
        assert_eq!(tensor.shape().dims(), &[10, 10]);
        let tensor = lsuv_init(&[64, 32]).unwrap();
        assert_eq!(tensor.shape().dims(), &[64, 32]);
    }

    #[test]
    fn test_zero_centered_variance_init() {
        let tensor = zero_centered_variance_init(&[10, 10], 1.0).unwrap();
        assert_eq!(tensor.shape().dims(), &[10, 10]);
        let tensor = zero_centered_variance_init(&[20, 20], 0.5).unwrap();
        assert_eq!(tensor.shape().dims(), &[20, 20]);
        assert!(zero_centered_variance_init(&[10, 10], 0.0).is_err());
        assert!(zero_centered_variance_init(&[10, 10], -1.0).is_err());
    }

    #[test]
    fn test_gan_balanced_initialization() {
        let gen_tensor = gan_balanced_init(&[10, 10], true).unwrap();
        assert_eq!(gen_tensor.shape().dims(), &[10, 10]);
        let disc_tensor = gan_balanced_init(&[10, 10], false).unwrap();
        assert_eq!(disc_tensor.shape().dims(), &[10, 10]);
    }

    #[test]
    fn test_coordinate_mlp_initialization() {
        let tensor = coordinate_mlp_init(&[10, 3], 1.0).unwrap();
        assert_eq!(tensor.shape().dims(), &[10, 3]);
        let tensor = coordinate_mlp_init(&[64, 32], 30.0).unwrap();
        assert_eq!(tensor.shape().dims(), &[64, 32]);
        assert!(coordinate_mlp_init(&[10], 1.0).is_err());
    }

    #[test]
    fn test_auto_init() {
        let tensor = auto_init(
            &[10, 5],
            ArchitectureHint::Feedforward,
            ActivationHint::ReLU,
            None,
        )
        .unwrap();
        assert_eq!(tensor.shape().dims(), &[10, 5]);
        let tensor = auto_init(
            &[64, 32, 3, 3],
            ArchitectureHint::Convolutional,
            ActivationHint::ReLU,
            None,
        )
        .unwrap();
        assert_eq!(tensor.shape().dims(), &[64, 32, 3, 3]);
        let tensor = auto_init(
            &[128, 256],
            ArchitectureHint::Recurrent,
            ActivationHint::Tanh,
            None,
        )
        .unwrap();
        assert_eq!(tensor.shape().dims(), &[128, 256]);
        let tensor = auto_init(
            &[512, 512],
            ArchitectureHint::Transformer,
            ActivationHint::GELU,
            None,
        )
        .unwrap();
        assert_eq!(tensor.shape().dims(), &[512, 512]);
        let tensor = auto_init(
            &[32, 16],
            ArchitectureHint::Periodic,
            ActivationHint::Sine,
            Some(0),
        )
        .unwrap();
        assert_eq!(tensor.shape().dims(), &[32, 16]);
        let tensor = auto_init(
            &[256, 256],
            ArchitectureHint::VeryDeep,
            ActivationHint::ReLU,
            Some(100),
        )
        .unwrap();
        assert_eq!(tensor.shape().dims(), &[256, 256]);
        let tensor = auto_init(
            &[100, 784],
            ArchitectureHint::GAN,
            ActivationHint::ReLU,
            None,
        )
        .unwrap();
        assert_eq!(tensor.shape().dims(), &[100, 784]);
        let tensor = auto_init(
            &[64, 3],
            ArchitectureHint::CoordinateBased,
            ActivationHint::ReLU,
            None,
        )
        .unwrap();
        assert_eq!(tensor.shape().dims(), &[64, 3]);
    }

    #[test]
    fn test_recommend_init_method() {
        // `matches!` only evaluates to a bool; it must be wrapped in `assert!`
        // for the test to check anything.
        let method =
            recommend_init_method(ArchitectureHint::Convolutional, ActivationHint::ReLU, None);
        assert!(matches!(method, InitMethod::KaimingNormal { .. }));
        let method = recommend_init_method(ArchitectureHint::VeryDeep, ActivationHint::ReLU, None);
        assert!(matches!(method, InitMethod::KaimingNormal { .. }));
        let method =
            recommend_init_method(ArchitectureHint::Transformer, ActivationHint::GELU, None);
        assert!(matches!(method, InitMethod::XavierUniform { .. }));
        let method = recommend_init_method(ArchitectureHint::Recurrent, ActivationHint::Tanh, None);
        assert!(matches!(method, InitMethod::Orthogonal { .. }));
        let method =
            recommend_init_method(ArchitectureHint::Periodic, ActivationHint::Sine, Some(0));
        assert!(matches!(method, InitMethod::SIREN { .. }));
        let method =
            recommend_init_method(ArchitectureHint::Feedforward, ActivationHint::SELU, None);
        assert!(matches!(method, InitMethod::LecunNormal));
    }

    #[test]
    fn test_architecture_hints() {
        assert_ne!(
            ArchitectureHint::Feedforward,
            ArchitectureHint::Convolutional
        );
        assert_ne!(ArchitectureHint::Recurrent, ArchitectureHint::Transformer);
        assert_ne!(ArchitectureHint::Residual, ArchitectureHint::VeryDeep);
        assert_ne!(ArchitectureHint::GAN, ArchitectureHint::CoordinateBased);
        assert_ne!(ArchitectureHint::Periodic, ArchitectureHint::Autoencoder);
    }

    #[test]
    fn test_activation_hints() {
        assert_ne!(ActivationHint::ReLU, ActivationHint::Tanh);
        assert_ne!(ActivationHint::Sigmoid, ActivationHint::SELU);
        assert_ne!(ActivationHint::Swish, ActivationHint::GELU);
        assert_ne!(ActivationHint::Sine, ActivationHint::Linear);
    }
}