// zyx-nn 0.15.3
//
// Zyx nn modules — documentation
// Copyright (C) 2025 zk4x
// SPDX-License-Identifier: LGPL-3.0-only

use zyx::{DType, Tensor, ZyxError};
use zyx_derive::Module;

/// RMS norm layer
///
/// Root-mean-square normalization: [`RMSNorm::forward`] multiplies its input by
/// `rsqrt(mean(x^2 over the last axis) + eps)` and then by `scale`.
#[derive(Debug, Module)]
#[cfg_attr(feature = "py", pyo3::pyclass)]
pub struct RMSNorm {
    /// Weight (scale) multiplied into the normalized output; [`RMSNorm::new`]
    /// initializes it to ones.
    pub scale: Tensor,
    /// Small value added to the mean of squares to avoid division by zero;
    /// [`RMSNorm::new`] sets it to `1e-6`.
    pub eps: f64,
}

impl RMSNorm {
    /// Builds an [`RMSNorm`] layer.
    ///
    /// `scale` starts as a tensor of ones created with `Tensor::ones(dim, dtype)`,
    /// and `eps` starts at `1e-6`.
    pub fn new(dim: u64, dtype: DType) -> RMSNorm {
        let scale = Tensor::ones(dim, dtype);
        RMSNorm { scale, eps: 1e-6 }
    }

    /// Applies RMS normalization to `x`.
    ///
    /// Returns `x * rsqrt(mean_keepdim(x^2, last axis) + eps) * scale`, where
    /// `eps` is cast to the dtype of `x` before being added.
    ///
    /// # Errors
    ///
    /// Propagates the [`ZyxError`] returned by `pow` or `mean_keepdim`, if any.
    pub fn forward(&self, x: impl Into<Tensor>) -> Result<Tensor, ZyxError> {
        let input: Tensor = x.into();
        let dtype = input.dtype();
        // Mean of squares along the last axis, with that axis kept.
        // NOTE(review): there is no upcast here, so for half-precision inputs the
        // square is presumably computed in that dtype and could overflow — confirm.
        let mean_square = input.pow(2)?.mean_keepdim([-1])?;
        // Match eps to the input dtype before adding it.
        let inv_rms = (mean_square + Tensor::from(self.eps).cast(dtype)).rsqrt();
        let normalized = &input * inv_rms;
        Ok(normalized * &self.scale)
    }
}