Trait SimdOptimizer

pub trait SimdOptimizer<T: Float> {
    // Required methods
    fn simd_sgd_update(
        params: &ArrayView1<'_, T>,
        gradients: &ArrayView1<'_, T>,
        learning_rate: T,
    ) -> Array1<T>;
    fn simd_momentum_update(
        params: &ArrayView1<'_, T>,
        gradients: &ArrayView1<'_, T>,
        velocity: &ArrayView1<'_, T>,
        learning_rate: T,
        momentum: T,
    ) -> (Array1<T>, Array1<T>);
    fn simd_adam_first_moment(
        m: &ArrayView1<'_, T>,
        gradients: &ArrayView1<'_, T>,
        beta1: T,
    ) -> Array1<T>;
    fn simd_adam_second_moment(
        v: &ArrayView1<'_, T>,
        gradients: &ArrayView1<'_, T>,
        beta2: T,
    ) -> Array1<T>;
    fn simd_adam_update(
        params: &ArrayView1<'_, T>,
        m_hat: &ArrayView1<'_, T>,
        v_hat: &ArrayView1<'_, T>,
        learning_rate: T,
        epsilon: T,
    ) -> Array1<T>;
    fn simd_weight_decay(
        gradients: &ArrayView1<'_, T>,
        params: &ArrayView1<'_, T>,
        weight_decay: T,
    ) -> Array1<T>;
    fn simd_gradient_norm(gradients: &ArrayView1<'_, T>) -> T;
}

Trait for SIMD-accelerated optimizer operations

This trait provides high-performance, SIMD-accelerated implementations of the elementwise array operations used by common optimizers such as SGD, SGD with momentum, and Adam: parameter updates, moment updates, weight decay, and gradient norm computation.
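
§Example

A minimal sketch of an update step written generically over the implementing float types. This assumes the `Float` bound comes from `num_traits` and that `SimdOptimizer` is imported from its defining crate:

use ndarray::{Array1, ArrayView1};
use num_traits::Float; // assumed source of the `Float` bound

// Works for any T that implements SimdOptimizer<T>, i.e. f32 and f64 here.
fn sgd_step<T: Float + SimdOptimizer<T>>(
    params: &ArrayView1<'_, T>,
    gradients: &ArrayView1<'_, T>,
    learning_rate: T,
) -> Array1<T> {
    T::simd_sgd_update(params, gradients, learning_rate)
}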

Required Methods§

fn simd_sgd_update( params: &ArrayView1<'_, T>, gradients: &ArrayView1<'_, T>, learning_rate: T, ) -> Array1<T>

SIMD-accelerated SGD parameter update

params = params - learning_rate * gradient

§Arguments
  • params - Parameter array
  • gradients - Gradient array
  • learning_rate - Learning rate scalar
§Returns

Updated parameters
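
§Example

A minimal sketch of calling this associated function; the trait is assumed to be in scope and all values are illustrative:

use ndarray::array;

let params = array![1.0_f32, 2.0, 3.0];   // illustrative values
let gradients = array![0.1_f32, 0.2, 0.3];
let updated = <f32 as SimdOptimizer<f32>>::simd_sgd_update(&params.view(), &gradients.view(), 0.1);
// Elementwise this computes params - 0.1 * gradients, i.e. [0.99, 1.98, 2.97].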

fn simd_momentum_update( params: &ArrayView1<'_, T>, gradients: &ArrayView1<'_, T>, velocity: &ArrayView1<'_, T>, learning_rate: T, momentum: T, ) -> (Array1<T>, Array1<T>)

SIMD-accelerated momentum update

velocity = momentum * velocity + learning_rate * gradient
params = params - velocity

§Arguments
  • params - Parameter array
  • gradients - Gradient array
  • velocity - Velocity array (momentum state)
  • learning_rate - Learning rate scalar
  • momentum - Momentum coefficient
§Returns

Tuple of (updated_params, updated_velocity)
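
§Example

A sketch of one momentum step; the returned velocity would be stored and passed back in on the next step. The trait is assumed to be in scope and all values are illustrative:

use ndarray::Array1;

let params = Array1::<f32>::ones(4);          // illustrative values
let gradients = Array1::<f32>::from_elem(4, 0.5);
let velocity = Array1::<f32>::zeros(4);

let (new_params, new_velocity) = <f32 as SimdOptimizer<f32>>::simd_momentum_update(
    &params.view(),
    &gradients.view(),
    &velocity.view(),
    0.1, // learning_rate
    0.9, // momentum
);
// With zero initial velocity: new_velocity = 0.9*0 + 0.1*0.5 = 0.05, new_params = 1.0 - 0.05 = 0.95.
// `new_velocity` would be passed as `velocity` on the next step.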

fn simd_adam_first_moment( m: &ArrayView1<'_, T>, gradients: &ArrayView1<'_, T>, beta1: T, ) -> Array1<T>

SIMD-accelerated Adam first moment update

m = beta1 * m + (1 - beta1) * gradient

§Arguments
  • m - First moment array
  • gradients - Gradient array
  • beta1 - Exponential decay rate for first moment
§Returns

Updated first moment
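
§Example

A small sketch of the first-moment recurrence; the trait is assumed to be in scope and all values are illustrative:

use ndarray::array;

let m = array![0.0_f32, 0.0];          // illustrative values
let gradients = array![1.0_f32, -2.0];
let m_next = <f32 as SimdOptimizer<f32>>::simd_adam_first_moment(&m.view(), &gradients.view(), 0.9);
// Elementwise: 0.9 * m + 0.1 * gradients = [0.1, -0.2].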

fn simd_adam_second_moment( v: &ArrayView1<'_, T>, gradients: &ArrayView1<'_, T>, beta2: T, ) -> Array1<T>

SIMD-accelerated Adam second moment update

v = beta2 * v + (1 - beta2) * gradient^2

§Arguments
  • v - Second moment array
  • gradients - Gradient array
  • beta2 - Exponential decay rate for second moment
§Returns

Updated second moment
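
§Example

The same shape of recurrence, but on the squared gradient; the trait is assumed to be in scope and all values are illustrative:

use ndarray::array;

let v = array![0.0_f32, 0.0];          // illustrative values
let gradients = array![1.0_f32, -2.0];
let v_next = <f32 as SimdOptimizer<f32>>::simd_adam_second_moment(&v.view(), &gradients.view(), 0.999);
// Elementwise: 0.999 * v + 0.001 * gradients^2 = [0.001, 0.004].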

fn simd_adam_update( params: &ArrayView1<'_, T>, m_hat: &ArrayView1<'_, T>, v_hat: &ArrayView1<'_, T>, learning_rate: T, epsilon: T, ) -> Array1<T>

SIMD-accelerated Adam parameter update

params = params - learning_rate * m_hat / (sqrt(v_hat) + epsilon)

§Arguments
  • params - Parameter array
  • m_hat - Bias-corrected first moment
  • v_hat - Bias-corrected second moment
  • learning_rate - Learning rate scalar
  • epsilon - Small constant for numerical stability
§Returns

Updated parameters
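
§Example

One way a full Adam step could be assembled from the methods on this trait. The bias correction is the caller's responsibility, since this method expects the already-corrected m_hat and v_hat; the step count `t` and all values below are illustrative:

use ndarray::array;

let params = array![1.0_f32, 2.0];     // illustrative values
let gradients = array![0.1_f32, -0.3];
let m = array![0.0_f32, 0.0];
let v = array![0.0_f32, 0.0];
let (lr, beta1, beta2, eps) = (0.001_f32, 0.9_f32, 0.999_f32, 1e-8_f32);
let t = 1; // illustrative step count

let m = <f32 as SimdOptimizer<f32>>::simd_adam_first_moment(&m.view(), &gradients.view(), beta1);
let v = <f32 as SimdOptimizer<f32>>::simd_adam_second_moment(&v.view(), &gradients.view(), beta2);
// Standard Adam bias correction, applied by the caller before the update.
let m_hat = &m / (1.0 - beta1.powi(t));
let v_hat = &v / (1.0 - beta2.powi(t));
let new_params = <f32 as SimdOptimizer<f32>>::simd_adam_update(
    &params.view(), &m_hat.view(), &v_hat.view(), lr, eps,
);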

fn simd_weight_decay( gradients: &ArrayView1<'_, T>, params: &ArrayView1<'_, T>, weight_decay: T, ) -> Array1<T>

SIMD-accelerated weight decay application

gradients = gradients + weight_decay * params

§Arguments
  • gradients - Gradient array
  • params - Parameter array
  • weight_decay - Weight decay coefficient
§Returns

Gradients with weight decay applied
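
§Example

A sketch of applying weight decay to the gradient before an update step; the trait is assumed to be in scope and all values are illustrative:

use ndarray::array;

let params = array![1.0_f32, -2.0];    // illustrative values
let gradients = array![0.1_f32, 0.1];
let decayed = <f32 as SimdOptimizer<f32>>::simd_weight_decay(&gradients.view(), &params.view(), 0.01);
// Elementwise: gradients + 0.01 * params = [0.11, 0.08].
// `decayed` would then be fed to e.g. simd_sgd_update in place of the raw gradients.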

fn simd_gradient_norm(gradients: &ArrayView1<'_, T>) -> T

SIMD-accelerated gradient norm computation

§Arguments
  • gradients - Gradient array
§Returns

L2 norm of gradients
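
§Example

A sketch of using the norm for global gradient clipping; the scaling itself is plain ndarray code, not part of this trait, and all values are illustrative:

use ndarray::array;

let gradients = array![3.0_f32, 4.0]; // illustrative values
let norm = <f32 as SimdOptimizer<f32>>::simd_gradient_norm(&gradients.view());
// L2 norm: sqrt(3^2 + 4^2) = 5.0
let max_norm = 1.0_f32;
let clipped = if norm > max_norm {
    &gradients * (max_norm / norm)
} else {
    gradients.clone()
};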

Dyn Compatibility§

This trait is not dyn compatible.

In older versions of Rust, dyn compatibility was called "object safety", so this trait is not object safe.

Implementations on Foreign Types§

impl SimdOptimizer<f32> for f32

Implementation of SIMD optimizer operations for f32

fn simd_sgd_update( params: &ArrayView1<'_, f32>, gradients: &ArrayView1<'_, f32>, learning_rate: f32, ) -> Array1<f32>

fn simd_momentum_update( params: &ArrayView1<'_, f32>, gradients: &ArrayView1<'_, f32>, velocity: &ArrayView1<'_, f32>, learning_rate: f32, momentum: f32, ) -> (Array1<f32>, Array1<f32>)

fn simd_adam_first_moment( m: &ArrayView1<'_, f32>, gradients: &ArrayView1<'_, f32>, beta1: f32, ) -> Array1<f32>

fn simd_adam_second_moment( v: &ArrayView1<'_, f32>, gradients: &ArrayView1<'_, f32>, beta2: f32, ) -> Array1<f32>

fn simd_adam_update( params: &ArrayView1<'_, f32>, m_hat: &ArrayView1<'_, f32>, v_hat: &ArrayView1<'_, f32>, learning_rate: f32, epsilon: f32, ) -> Array1<f32>

fn simd_weight_decay( gradients: &ArrayView1<'_, f32>, params: &ArrayView1<'_, f32>, weight_decay: f32, ) -> Array1<f32>

fn simd_gradient_norm(gradients: &ArrayView1<'_, f32>) -> f32

impl SimdOptimizer<f64> for f64

Implementation of SIMD optimizer operations for f64

fn simd_sgd_update( params: &ArrayView1<'_, f64>, gradients: &ArrayView1<'_, f64>, learning_rate: f64, ) -> Array1<f64>

fn simd_momentum_update( params: &ArrayView1<'_, f64>, gradients: &ArrayView1<'_, f64>, velocity: &ArrayView1<'_, f64>, learning_rate: f64, momentum: f64, ) -> (Array1<f64>, Array1<f64>)

fn simd_adam_first_moment( m: &ArrayView1<'_, f64>, gradients: &ArrayView1<'_, f64>, beta1: f64, ) -> Array1<f64>

fn simd_adam_second_moment( v: &ArrayView1<'_, f64>, gradients: &ArrayView1<'_, f64>, beta2: f64, ) -> Array1<f64>

fn simd_adam_update( params: &ArrayView1<'_, f64>, m_hat: &ArrayView1<'_, f64>, v_hat: &ArrayView1<'_, f64>, learning_rate: f64, epsilon: f64, ) -> Array1<f64>

fn simd_weight_decay( gradients: &ArrayView1<'_, f64>, params: &ArrayView1<'_, f64>, weight_decay: f64, ) -> Array1<f64>

fn simd_gradient_norm(gradients: &ArrayView1<'_, f64>) -> f64

Implementors§