use candle_core::{Result, Tensor};
use candle_nn::VarBuilder;
use std::collections::HashMap;
use crate::kernels::fused_ops::rms_norm_cuda;
/// Root-mean-square layer normalisation with a learned per-feature scale.
///
/// The normalisation itself is delegated to `rms_norm_cuda` (see [`RMSNorm::forward`]).
#[derive(Debug, Clone)]
pub struct RMSNorm {
    /// Per-feature scale; the constructors in this module store it as an `f32` tensor of shape `(dim,)`.
    pub weight: Tensor,
    /// Epsilon passed through to the kernel, presumably added to the mean square for
    /// numerical stability — TODO confirm against `rms_norm_cuda`.
    pub eps: f64,
}
impl RMSNorm {
pub fn load_direct(
tensors: &HashMap<String, Tensor>,
key: &str,
dim: usize,
eps: f64,
device: &candle_core::Device,
) -> Result<Self> {
let weight = tensors
.get(key)
.ok_or_else(|| candle_core::Error::Msg(format!("RMSNorm weight not found: {}", key)))?;
let weight_f32 = weight.to_dtype(candle_core::DType::F32)?;
let weight = if device.is_cpu() {
let data = weight_f32.to_vec1::<f32>()?;
Tensor::from_vec(data, (dim,), device)?
} else {
weight_f32.to_device(device)?
};
Ok(Self { weight, eps })
}
pub fn load(
dim: usize,
eps: f64,
vb: VarBuilder,
device: &candle_core::Device,
) -> Result<Self> {
let weight =
vb.get_with_hints((dim,), "weight", candle_nn::init::DEFAULT_KAIMING_NORMAL)?;
let weight = if device.is_cpu() {
let data = weight.to_vec1::<f32>()?;
Tensor::from_vec(data, weight.shape(), device)?
} else {
weight.to_device(device)?
};
Ok(Self { weight, eps })
}
pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
let weight = if self.weight.device().same_device(x.device()) {
self.weight.clone()
} else {
self.weight.to_device(x.device())?
};
rms_norm_cuda(x, &weight, self.eps)
}
}