gradient_centralization

Function gradient_centralization 

pub fn gradient_centralization<A, D>(
    gradients: &mut Array<A, D>,
) -> &mut Array<A, D>

Applies gradient centralization to a gradient array in place.

Gradient centralization subtracts the mean of a gradient tensor from each of its elements, so the centralized gradient has zero mean. The technique is intended to improve training stability.
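
The computation can be sketched on a plain slice as follows. This is a minimal illustration, not the optirs_core implementation: the helper name `centralize_in_place` is hypothetical, and it assumes the mean is taken over all elements of the buffer, which matches the example on this page. How multi-dimensional arrays are handled is not stated here.

```rust
/// Hypothetical helper (not part of optirs_core): subtracts the mean of all
/// elements from each element of a flat gradient buffer, in place.
fn centralize_in_place(gradients: &mut [f64]) -> &mut [f64] {
    if gradients.is_empty() {
        // Nothing to centralize; returning early avoids dividing by zero.
        return gradients;
    }

    // Mean over every element of the buffer (assumption: one global mean).
    let mean = gradients.iter().sum::<f64>() / gradients.len() as f64;

    // Shift each element so the buffer has zero mean afterwards.
    for g in gradients.iter_mut() {
        *g -= mean;
    }

    // Return the same buffer, mirroring the in-place style of the signature above.
    gradients
}
```

Returning the mutable reference lets a caller keep working with the centralized buffer in the same expression, while the original binding also reflects the change.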

§Arguments

  • gradients - The gradient array to centralize; it is modified in place

§Returns

A mutable reference to the same array passed in, with its elements centralized in place. No new array is allocated for the result.

§Examples

use scirs2_core::ndarray::Array1;
use optirs_core::utils::gradient_centralization;

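// The mean of the four elements is 2.0.
let mut gradients = Array1::from_vec(vec![1.0, 2.0, 3.0, 2.0]);
// Centralize in place; the returned reference is not needed here.
gradient_centralization(&mut gradients);
// Each element has been shifted by -2.0, so the result has zero mean.
assert_eq!(gradients, Array1::from_vec(vec![-1.0, 0.0, 1.0, 0.0]));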