/// Renders `$self.input_shape` as a human-readable shape string.
///
/// Expands to an expression of type `String`:
/// - `"Unknown"` when `input_shape` is empty;
/// - otherwise `"(d0, d1, ...)"`, i.e. the `to_string()` of every element of
///   `input_shape`, joined by `", "` and wrapped in parentheses.
///
/// `$self` must expose an `input_shape` field offering `is_empty()` and
/// `iter()`. The expression is expanded more than once, so pass a plain place
/// expression such as `self`.
macro_rules! normalization_layer_output_shape {
    ($self:expr) => {
        if $self.input_shape.is_empty() {
            String::from("Unknown")
        } else {
            let dims: Vec<String> = $self
                .input_shape
                .iter()
                .map(|dim| dim.to_string())
                .collect();
            format!("({})", dims.join(", "))
        }
    };
}
/// Applies one plain gradient-descent step to `gamma` and `beta`.
///
/// For each of `$self.grad_gamma` / `$self.grad_beta` that is `Some`, the
/// matching parameter is replaced by `param - grad * $lr`. A gradient that is
/// `None` leaves its parameter untouched.
///
/// `$self` and `$lr` are expanded more than once; pass side-effect-free
/// expressions.
macro_rules! normalization_layer_update_parameters_sgd {
    ($self:expr, $lr:expr) => {{
        if let Some(gamma_gradient) = $self.grad_gamma.as_ref() {
            let gamma_step = gamma_gradient * $lr;
            $self.gamma = &$self.gamma - &gamma_step;
        }

        if let Some(beta_gradient) = $self.grad_beta.as_ref() {
            let beta_step = beta_gradient * $lr;
            $self.beta = &$self.beta - &beta_step;
        }
    }};
}
/// Applies one Adam update to `gamma` and `beta`.
///
/// On first use (`$self.optimizer_cache.adam_states` is `None`) the moment
/// buffers `m_gamma`, `v_gamma`, `m_beta`, `v_beta` are created with
/// `Tensor::zeros`, using the `raw_dim()` of the matching parameter. This
/// happens even when both gradients are `None`.
///
/// Then, independently for each parameter whose gradient is `Some`:
/// 1. `m = m * $beta1 + grad * (1 - $beta1)`
/// 2. `v = v * $beta2 + grad * grad * (1 - $beta2)`
/// 3. `m_hat = m / (1 - $beta1.powi($t))`, `v_hat = v / (1 - $beta2.powi($t))`
/// 4. `param = param - m_hat / (sqrt(v_hat) + $epsilon) * $lr`
///
/// `Tensor` and `AdamStatesNormalizationLayer` are resolved at the call site.
/// `$t` is converted with `as i32` before `powi`.
/// NOTE(review): `$t == 0` makes both divisors in step 3 zero; callers
/// presumably start the step counter at 1 — confirm.
macro_rules! normalization_layer_update_parameters_adam {
    ($self:expr, $lr:expr, $beta1:expr, $beta2:expr, $epsilon:expr, $t:expr) => {{
        // First use: allocate zeroed moment buffers shaped like the parameters.
        if $self.optimizer_cache.adam_states.is_none() {
            let fresh_states = AdamStatesNormalizationLayer {
                m_gamma: Tensor::zeros($self.gamma.raw_dim()),
                v_gamma: Tensor::zeros($self.gamma.raw_dim()),
                m_beta: Tensor::zeros($self.beta.raw_dim()),
                v_beta: Tensor::zeros($self.beta.raw_dim()),
            };
            $self.optimizer_cache.adam_states = Some(fresh_states);
        }

        if let Some(gamma_gradient) = $self.grad_gamma.as_ref() {
            // Always `Some` here: the buffers were created above if missing.
            let states = $self.optimizer_cache.adam_states.as_mut().unwrap();

            // Running averages of the gradient and of its element-wise square.
            let decayed_first = states.m_gamma.clone() * $beta1;
            let fresh_first = gamma_gradient * (1.0 - $beta1);
            states.m_gamma = decayed_first + fresh_first;

            let decayed_second = states.v_gamma.clone() * $beta2;
            let fresh_second = &(gamma_gradient * gamma_gradient) * (1.0 - $beta2);
            states.v_gamma = decayed_second + fresh_second;

            // Rescale both averages by `1 - beta^t`.
            let first_hat = states.m_gamma.clone() / (1.0 - $beta1.powi($t as i32));
            let second_hat = states.v_gamma.clone() / (1.0 - $beta2.powi($t as i32));

            let denominator = second_hat.mapv(|value| value.sqrt()) + $epsilon;
            let step = &first_hat / &denominator * $lr;
            $self.gamma = &$self.gamma - &step;
        }

        if let Some(beta_gradient) = $self.grad_beta.as_ref() {
            // Always `Some` here: the buffers were created above if missing.
            let states = $self.optimizer_cache.adam_states.as_mut().unwrap();

            let decayed_first = states.m_beta.clone() * $beta1;
            let fresh_first = beta_gradient * (1.0 - $beta1);
            states.m_beta = decayed_first + fresh_first;

            let decayed_second = states.v_beta.clone() * $beta2;
            let fresh_second = &(beta_gradient * beta_gradient) * (1.0 - $beta2);
            states.v_beta = decayed_second + fresh_second;

            let first_hat = states.m_beta.clone() / (1.0 - $beta1.powi($t as i32));
            let second_hat = states.v_beta.clone() / (1.0 - $beta2.powi($t as i32));

            let denominator = second_hat.mapv(|value| value.sqrt()) + $epsilon;
            let step = &first_hat / &denominator * $lr;
            $self.beta = &$self.beta - &step;
        }
    }};
}
/// Applies one RMSprop update to `gamma` and `beta`.
///
/// On first use (`$self.optimizer_cache.rmsprop_cache` is `None`) the buffers
/// `cache_gamma` and `cache_beta` are created with `Tensor::zeros`, using the
/// `raw_dim()` of the matching parameter. This happens even when both
/// gradients are `None`.
///
/// Then, independently for each parameter whose gradient is `Some`:
/// 1. `cache = cache * $rho + grad * grad * (1 - $rho)`
/// 2. `param = param - grad / (sqrt(cache) + $epsilon) * $lr`
///
/// `Tensor` and `RMSpropCacheNormalizationLayer` are resolved at the call
/// site. The arguments are expanded more than once; pass side-effect-free
/// expressions.
macro_rules! normalization_layer_update_parameters_rmsprop {
    ($self:expr, $lr:expr, $rho:expr, $epsilon:expr) => {{
        // First use: allocate zeroed caches shaped like the parameters.
        if $self.optimizer_cache.rmsprop_cache.is_none() {
            let fresh_cache = RMSpropCacheNormalizationLayer {
                cache_gamma: Tensor::zeros($self.gamma.raw_dim()),
                cache_beta: Tensor::zeros($self.beta.raw_dim()),
            };
            $self.optimizer_cache.rmsprop_cache = Some(fresh_cache);
        }

        if let Some(gamma_gradient) = $self.grad_gamma.as_ref() {
            // Always `Some` here: the cache was created above if missing.
            let cache = $self.optimizer_cache.rmsprop_cache.as_mut().unwrap();

            // Running average of the element-wise squared gradient.
            let decayed = cache.cache_gamma.clone() * $rho;
            let fresh = &(gamma_gradient * gamma_gradient) * (1.0 - $rho);
            cache.cache_gamma = decayed + fresh;

            let denominator = cache.cache_gamma.mapv(|value| value.sqrt()) + $epsilon;
            let step = gamma_gradient / &denominator * $lr;
            $self.gamma = &$self.gamma - &step;
        }

        if let Some(beta_gradient) = $self.grad_beta.as_ref() {
            // Always `Some` here: the cache was created above if missing.
            let cache = $self.optimizer_cache.rmsprop_cache.as_mut().unwrap();

            let decayed = cache.cache_beta.clone() * $rho;
            let fresh = &(beta_gradient * beta_gradient) * (1.0 - $rho);
            cache.cache_beta = decayed + fresh;

            let denominator = cache.cache_beta.mapv(|value| value.sqrt()) + $epsilon;
            let step = beta_gradient / &denominator * $lr;
            $self.beta = &$self.beta - &step;
        }
    }};
}
/// Applies one AdaGrad update to `gamma` and `beta`.
///
/// On first use (`$self.optimizer_cache.ada_grad_cache` is `None`) the
/// accumulators `acc_grad_gamma` and `acc_grad_beta` are created with
/// `Tensor::zeros`, using the `raw_dim()` of the matching parameter. This
/// happens even when both gradients are `None`.
///
/// Then, independently for each parameter whose gradient is `Some`:
/// 1. `acc = acc + grad * grad`
/// 2. `param = param - grad / (sqrt(acc) + $epsilon) * $lr`
///
/// `Tensor` and `AdaGradStatesNormalizationLayer` are resolved at the call
/// site. The arguments are expanded more than once; pass side-effect-free
/// expressions.
macro_rules! normalization_layer_update_parameters_ada_grad {
    ($self:expr, $lr:expr, $epsilon:expr) => {{
        // First use: allocate zeroed accumulators shaped like the parameters.
        if $self.optimizer_cache.ada_grad_cache.is_none() {
            let fresh_cache = AdaGradStatesNormalizationLayer {
                acc_grad_gamma: Tensor::zeros($self.gamma.raw_dim()),
                acc_grad_beta: Tensor::zeros($self.beta.raw_dim()),
            };
            $self.optimizer_cache.ada_grad_cache = Some(fresh_cache);
        }

        if let Some(gamma_gradient) = $self.grad_gamma.as_ref() {
            // Always `Some` here: the accumulators were created above if missing.
            let cache = $self.optimizer_cache.ada_grad_cache.as_mut().unwrap();

            // The accumulator only ever grows: add this step's squared gradient.
            let squared = gamma_gradient * gamma_gradient;
            cache.acc_grad_gamma = cache.acc_grad_gamma.clone() + &squared;

            let denominator = cache.acc_grad_gamma.mapv(|value| value.sqrt()) + $epsilon;
            let step = gamma_gradient / &denominator * $lr;
            $self.gamma = &$self.gamma - &step;
        }

        if let Some(beta_gradient) = $self.grad_beta.as_ref() {
            // Always `Some` here: the accumulators were created above if missing.
            let cache = $self.optimizer_cache.ada_grad_cache.as_mut().unwrap();

            let squared = beta_gradient * beta_gradient;
            cache.acc_grad_beta = cache.acc_grad_beta.clone() + &squared;

            let denominator = cache.acc_grad_beta.mapv(|value| value.sqrt()) + $epsilon;
            let step = beta_gradient / &denominator * $lr;
            $self.beta = &$self.beta - &step;
        }
    }};
}
/// Computes the per-channel gradients of the scale (`gamma`) and shift
/// (`beta`) parameters and writes them into `$grad_gamma` / `$grad_beta`.
///
/// For every channel `c` in `0..$num_channels`:
/// - `grad_gamma[c]` = sum of `grad_output[i] * x_normalized[i]`
/// - `grad_beta[c]`  = sum of `grad_output[i]`
///
/// where `i` runs over
/// `(b * $num_channels + c) * $spatial_size .. + $spatial_size` for every
/// `b` in `0..$batch_size`, i.e. the inputs are read as flat `f32` buffers
/// indexed `[batch][channel][spatial]`.
///
/// # Arguments
/// - `$grad_gamma`, `$grad_beta`: outputs; must offer `as_slice_mut()`
///   returning `Option<&mut [f32]>` with at least `$num_channels` elements
///   (exactly `$num_channels` in the parallel branch, which uses
///   `copy_from_slice`).
/// - `$grad_output`, `$x_normalized`: inputs; must offer `as_slice()`
///   returning `Option<&[f32]>`.
/// - `$batch_size`, `$num_channels`, `$spatial_size`: `usize` extents,
///   evaluated once.
/// - `$total_elements`, `$parallel_threshold`: when
///   `$total_elements >= $parallel_threshold` the channels are processed
///   through `into_par_iter()` (presumably rayon's — the trait must be in
///   scope at the call site); otherwise a plain loop is used.
///
/// Both branches run the same reduction in the same order, so they produce
/// identical values.
///
/// # Panics
/// Panics if any `as_slice()` / `as_slice_mut()` call returns `None`, or if
/// a computed index range falls outside an input or output slice.
macro_rules! compute_normalization_layer_parameter_gradients {
    (
        $grad_gamma:expr,
        $grad_beta:expr,
        $grad_output:expr,
        $x_normalized:expr,
        $batch_size:expr,
        $num_channels:expr,
        $spatial_size:expr,
        $total_elements:expr,
        $parallel_threshold:expr
    ) => {{
        // Evaluate the extent arguments once instead of at every expansion site.
        let batch_size = $batch_size;
        let num_channels = $num_channels;
        let spatial_size = $spatial_size;

        // Shared reduction for one channel. It only captures the extents, so it
        // never holds a borrow of the input or output buffers between calls.
        let channel_sums =
            |grad_out: &[f32], x_norm: &[f32], channel_idx: usize| -> (f32, f32) {
                let mut gamma_sum = 0.0f32;
                let mut beta_sum = 0.0f32;
                for batch_idx in 0..batch_size {
                    let start = (batch_idx * num_channels + channel_idx) * spatial_size;
                    let end = start + spatial_size;
                    for (&grad, &x) in grad_out[start..end].iter().zip(&x_norm[start..end]) {
                        gamma_sum += grad * x;
                        beta_sum += grad;
                    }
                }
                (gamma_sum, beta_sum)
            };

        if $total_elements >= $parallel_threshold {
            // The slices are fetched once per channel rather than once per
            // element; the results are collected first and copied out afterwards.
            let (grad_gamma_vec, grad_beta_vec): (Vec<f32>, Vec<f32>) = (0..num_channels)
                .into_par_iter()
                .map(|channel_idx| {
                    channel_sums(
                        $grad_output
                            .as_slice()
                            .expect("`as_slice()` returned `None` for grad_output"),
                        $x_normalized
                            .as_slice()
                            .expect("`as_slice()` returned `None` for x_normalized"),
                        channel_idx,
                    )
                })
                .unzip();
            $grad_gamma
                .as_slice_mut()
                .expect("`as_slice_mut()` returned `None` for grad_gamma")
                .copy_from_slice(&grad_gamma_vec);
            $grad_beta
                .as_slice_mut()
                .expect("`as_slice_mut()` returned `None` for grad_beta")
                .copy_from_slice(&grad_beta_vec);
        } else {
            for channel_idx in 0..num_channels {
                // The input borrows end with this statement, before the writes.
                let (gamma_sum, beta_sum) = channel_sums(
                    $grad_output
                        .as_slice()
                        .expect("`as_slice()` returned `None` for grad_output"),
                    $x_normalized
                        .as_slice()
                        .expect("`as_slice()` returned `None` for x_normalized"),
                    channel_idx,
                );
                $grad_gamma
                    .as_slice_mut()
                    .expect("`as_slice_mut()` returned `None` for grad_gamma")[channel_idx] =
                    gamma_sum;
                $grad_beta
                    .as_slice_mut()
                    .expect("`as_slice_mut()` returned `None` for grad_beta")[channel_idx] =
                    beta_sum;
            }
        }
    }};
}
pub mod batch_normalization;
pub mod group_normalization;
pub mod instance_normalization;
pub mod layer_normalization;
pub use batch_normalization::BatchNormalization;
pub use group_normalization::GroupNormalization;
pub use instance_normalization::InstanceNormalization;
pub use layer_normalization::{LayerNormalization, LayerNormalizationAxis};