Trait NormOps 

pub trait NormOps: Send + Sync {
    // Required method
    fn rms_norm(
        &self,
        input: &TensorRef,
        weight: &TensorRef,
        eps: f32,
    ) -> Result<TensorRef>;

    // Provided method
    fn rms_norm_residual(
        &self,
        input: &TensorRef,
        residual: &TensorRef,
        weight: &TensorRef,
        eps: f32,
    ) -> Result<(TensorRef, TensorRef)> { ... }
}
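Because of the `Send + Sync` supertraits, a single backend instance can be shared across threads, for example behind an `Arc<dyn NormOps>`. The sketch below is a simplified, hypothetical stand-in with these changes:

- `TensorRef` is replaced by plain `f32` slices.
- The `Result` wrapper is dropped.
- It assumes rms(x) = sqrt(mean(x²) + eps).

`CpuBackend` is an illustrative name, not a type from this crate.

```rust
use std::sync::Arc;
use std::thread;

// Hypothetical, simplified stand-in for the trait: f32 slices replace the
// crate's `TensorRef`, and error handling is omitted.
trait NormOps: Send + Sync {
    fn rms_norm(&self, input: &[f32], weight: &[f32], eps: f32) -> Vec<f32>;
}

// Hypothetical backend used only for illustration.
struct CpuBackend;

impl NormOps for CpuBackend {
    fn rms_norm(&self, input: &[f32], weight: &[f32], eps: f32) -> Vec<f32> {
        // Assumed convention: rms(x) = sqrt(mean(x^2) + eps).
        let mean_sq = input.iter().map(|x| x * x).sum::<f32>() / input.len() as f32;
        let inv_rms = 1.0 / (mean_sq + eps).sqrt();
        input.iter().zip(weight).map(|(x, w)| x * inv_rms * w).collect()
    }
}

fn main() {
    // The Send + Sync supertraits make `dyn NormOps` Send + Sync, so the
    // Arc can be cloned and moved into worker threads.
    let backend: Arc<dyn NormOps> = Arc::new(CpuBackend);
    let handles: Vec<_> = (0..4)
        .map(|i| {
            let backend = Arc::clone(&backend);
            thread::spawn(move || {
                let row = vec![i as f32 + 1.0; 8];
                backend.rms_norm(&row, &[1.0; 8], 1e-6)
            })
        })
        .collect();
    for h in handles {
        println!("{:?}", h.join().unwrap());
    }
}
```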

Normalization operations.

Required Methods

fn rms_norm(&self, input: &TensorRef, weight: &TensorRef, eps: f32) -> Result<TensorRef>

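RMS normalization: x / rms(x) * weight. Here rms(x) is the root mean square of x, and eps is a small constant that keeps the division numerically stable.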
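For a single row, the formula can be sketched in plain Rust as follows. This assumes rms(x) = sqrt(mean(x²) + eps), the common convention; the docs above do not say exactly where eps enters. `rms_norm_row` is a hypothetical helper that works on slices rather than on `TensorRef`.

```rust
/// Hypothetical single-row RMS norm (not part of this crate), assuming
/// rms(x) = sqrt(mean(x^2) + eps).
fn rms_norm_row(input: &[f32], weight: &[f32], eps: f32) -> Vec<f32> {
    assert!(!input.is_empty(), "row must be non-empty");
    assert_eq!(input.len(), weight.len(), "weight must match the row length");
    // Mean of squares over the row.
    let mean_sq = input.iter().map(|x| x * x).sum::<f32>() / input.len() as f32;
    // Multiply by the reciprocal RMS instead of dividing element by element.
    let inv_rms = 1.0 / (mean_sq + eps).sqrt();
    input
        .iter()
        .zip(weight)
        .map(|(x, w)| x * inv_rms * w)
        .collect()
}

fn main() {
    // For [3, 4]: mean(x^2) = (9 + 16) / 2 = 12.5, so with eps = 0 and a unit
    // weight the output is [3, 4] / sqrt(12.5).
    let out = rms_norm_row(&[3.0, 4.0], &[1.0, 1.0], 0.0);
    println!("{:?}", out);
}
```

With a unit weight, the mean of the squared outputs is 1 (up to eps). The weight then rescales each channel independently.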

Provided Methods

fn rms_norm_residual(&self, input: &TensorRef, residual: &TensorRef, weight: &TensorRef, eps: f32) -> Result<(TensorRef, TensorRef)>

Fused RMS normalization with residual add: normed_output = rms_norm(input + residual, weight, eps). Returns (normed_output, updated_residual).
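The provided method could fall back to the required one by doing an unfused residual add followed by `rms_norm`. This sketch makes several assumptions:

- `updated_residual` is assumed to be the elementwise sum `input + residual`.
- `TensorRef` is replaced by a hypothetical `Arc<Vec<f32>>` row.
- `Result` is replaced by a hypothetical alias with a `String` error.

The crate's actual default body is elided on this page, and `CpuBackend` is an illustrative name.

```rust
use std::sync::Arc;

// Hypothetical stand-ins: the crate's real `TensorRef` and `Result` are not
// shown on this page. Here a tensor is a single flat row of f32 values.
type TensorRef = Arc<Vec<f32>>;
type Result<T> = std::result::Result<T, String>;

trait NormOps: Send + Sync {
    fn rms_norm(&self, input: &TensorRef, weight: &TensorRef, eps: f32) -> Result<TensorRef>;

    // Unfused fallback, assuming updated_residual = input + residual.
    fn rms_norm_residual(
        &self,
        input: &TensorRef,
        residual: &TensorRef,
        weight: &TensorRef,
        eps: f32,
    ) -> Result<(TensorRef, TensorRef)> {
        if input.len() != residual.len() {
            return Err(format!("shape mismatch: {} vs {}", input.len(), residual.len()));
        }
        // Residual add first, then normalize the sum.
        let sum: TensorRef =
            Arc::new(input.iter().zip(residual.iter()).map(|(a, b)| a + b).collect());
        let normed = self.rms_norm(&sum, weight, eps)?;
        Ok((normed, sum))
    }
}

// Hypothetical backend that implements only the required method.
struct CpuBackend;

impl NormOps for CpuBackend {
    fn rms_norm(&self, input: &TensorRef, weight: &TensorRef, eps: f32) -> Result<TensorRef> {
        if input.is_empty() || input.len() != weight.len() {
            return Err("input and weight must be non-empty and the same length".into());
        }
        // Assumed convention: rms(x) = sqrt(mean(x^2) + eps).
        let mean_sq = input.iter().map(|x| x * x).sum::<f32>() / input.len() as f32;
        let inv_rms = 1.0 / (mean_sq + eps).sqrt();
        Ok(Arc::new(input.iter().zip(weight.iter()).map(|(x, w)| x * inv_rms * w).collect()))
    }
}

fn main() -> Result<()> {
    let backend = CpuBackend;
    let input: TensorRef = Arc::new(vec![1.0, 2.0]);
    let residual: TensorRef = Arc::new(vec![2.0, 2.0]);
    let weight: TensorRef = Arc::new(vec![1.0, 1.0]);
    let (normed, updated) = backend.rms_norm_residual(&input, &residual, &weight, 1e-6)?;
    println!("normed = {:?}, residual = {:?}", normed, updated);
    Ok(())
}
```

A backend with a dedicated fused kernel would override `rms_norm_residual` instead. That avoids writing the sum out and reading it back in a separate pass.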

Implementors