pub trait SimdOptimizer<T: Float> {
// Required methods
fn simd_sgd_update(
params: &ArrayView1<'_, T>,
gradients: &ArrayView1<'_, T>,
learning_rate: T,
) -> Array1<T>;
fn simd_momentum_update(
params: &ArrayView1<'_, T>,
gradients: &ArrayView1<'_, T>,
velocity: &ArrayView1<'_, T>,
learning_rate: T,
momentum: T,
) -> (Array1<T>, Array1<T>);
fn simd_adam_first_moment(
m: &ArrayView1<'_, T>,
gradients: &ArrayView1<'_, T>,
beta1: T,
) -> Array1<T>;
fn simd_adam_second_moment(
v: &ArrayView1<'_, T>,
gradients: &ArrayView1<'_, T>,
beta2: T,
) -> Array1<T>;
fn simd_adam_update(
params: &ArrayView1<'_, T>,
m_hat: &ArrayView1<'_, T>,
v_hat: &ArrayView1<'_, T>,
learning_rate: T,
epsilon: T,
) -> Array1<T>;
fn simd_weight_decay(
gradients: &ArrayView1<'_, T>,
params: &ArrayView1<'_, T>,
weight_decay: T,
) -> Array1<T>;
fn simd_gradient_norm(gradients: &ArrayView1<'_, T>) -> T;
}Expand description
Trait for SIMD-accelerated optimizer operations
This trait provides high-performance implementations of common operations found in optimization algorithms.
Required Methods
fn simd_sgd_update(
params: &ArrayView1<'_, T>,
gradients: &ArrayView1<'_, T>,
learning_rate: T,
) -> Array1<T>
fn simd_sgd_update( params: &ArrayView1<'_, T>, gradients: &ArrayView1<'_, T>, learning_rate: T, ) -> Array1<T>
fn simd_momentum_update(
params: &ArrayView1<'_, T>,
gradients: &ArrayView1<'_, T>,
velocity: &ArrayView1<'_, T>,
learning_rate: T,
momentum: T,
) -> (Array1<T>, Array1<T>)
fn simd_momentum_update( params: &ArrayView1<'_, T>, gradients: &ArrayView1<'_, T>, velocity: &ArrayView1<'_, T>, learning_rate: T, momentum: T, ) -> (Array1<T>, Array1<T>)
SIMD-accelerated momentum update
velocity = momentum * velocity + learning_rate * gradient
params = params - velocity
Arguments

- `params` — Parameter array
- `gradients` — Gradient array
- `velocity` — Velocity array (momentum state)
- `learning_rate` — Learning rate scalar
- `momentum` — Momentum coefficient

Returns

Tuple of `(updated_params, updated_velocity)`
fn simd_adam_first_moment(
m: &ArrayView1<'_, T>,
gradients: &ArrayView1<'_, T>,
beta1: T,
) -> Array1<T>
fn simd_adam_first_moment( m: &ArrayView1<'_, T>, gradients: &ArrayView1<'_, T>, beta1: T, ) -> Array1<T>
fn simd_adam_second_moment(
v: &ArrayView1<'_, T>,
gradients: &ArrayView1<'_, T>,
beta2: T,
) -> Array1<T>
fn simd_adam_second_moment( v: &ArrayView1<'_, T>, gradients: &ArrayView1<'_, T>, beta2: T, ) -> Array1<T>
fn simd_adam_update(
params: &ArrayView1<'_, T>,
m_hat: &ArrayView1<'_, T>,
v_hat: &ArrayView1<'_, T>,
learning_rate: T,
epsilon: T,
) -> Array1<T>
fn simd_adam_update( params: &ArrayView1<'_, T>, m_hat: &ArrayView1<'_, T>, v_hat: &ArrayView1<'_, T>, learning_rate: T, epsilon: T, ) -> Array1<T>
SIMD-accelerated Adam parameter update
params = params - learning_rate * m_hat / (sqrt(v_hat) + epsilon)
Arguments

- `params` — Parameter array
- `m_hat` — Bias-corrected first moment
- `v_hat` — Bias-corrected second moment
- `learning_rate` — Learning rate scalar
- `epsilon` — Small constant for numerical stability

Returns

Updated parameters
fn simd_weight_decay(
gradients: &ArrayView1<'_, T>,
params: &ArrayView1<'_, T>,
weight_decay: T,
) -> Array1<T>
fn simd_weight_decay( gradients: &ArrayView1<'_, T>, params: &ArrayView1<'_, T>, weight_decay: T, ) -> Array1<T>
fn simd_gradient_norm(gradients: &ArrayView1<'_, T>) -> T
fn simd_gradient_norm(gradients: &ArrayView1<'_, T>) -> T
Dyn Compatibility
This trait is not dyn compatible.
In older versions of Rust, dyn compatibility was called "object safety", so this trait is not object safe.
Implementations on Foreign Types§
impl SimdOptimizer<f32> for f32
Implementation of SIMD optimizer operations for f32
impl SimdOptimizer<f32> for f32
Implementation of SIMD optimizer operations for f32
fn simd_sgd_update( params: &ArrayView1<'_, f32>, gradients: &ArrayView1<'_, f32>, learning_rate: f32, ) -> Array1<f32>
fn simd_momentum_update( params: &ArrayView1<'_, f32>, gradients: &ArrayView1<'_, f32>, velocity: &ArrayView1<'_, f32>, learning_rate: f32, momentum: f32, ) -> (Array1<f32>, Array1<f32>)
fn simd_adam_first_moment( m: &ArrayView1<'_, f32>, gradients: &ArrayView1<'_, f32>, beta1: f32, ) -> Array1<f32>
fn simd_adam_second_moment( v: &ArrayView1<'_, f32>, gradients: &ArrayView1<'_, f32>, beta2: f32, ) -> Array1<f32>
fn simd_adam_update( params: &ArrayView1<'_, f32>, m_hat: &ArrayView1<'_, f32>, v_hat: &ArrayView1<'_, f32>, learning_rate: f32, epsilon: f32, ) -> Array1<f32>
fn simd_weight_decay( gradients: &ArrayView1<'_, f32>, params: &ArrayView1<'_, f32>, weight_decay: f32, ) -> Array1<f32>
fn simd_gradient_norm(gradients: &ArrayView1<'_, f32>) -> f32
impl SimdOptimizer<f64> for f64
Implementation of SIMD optimizer operations for f64
impl SimdOptimizer<f64> for f64
Implementation of SIMD optimizer operations for f64