pub mod attention;
pub mod cache;
pub mod layer;
pub mod mlp;
pub mod model;
pub mod rope;
pub use attention::MiniCpmAttention;
pub use cache::StaticKvCache;
pub use layer::MiniCpmDecoderLayer;
pub use mlp::MiniCpmMlp;
pub use model::MiniCpmModel;
pub use rope::MiniCpmLongRope;
use burn::module::Param;
use burn::prelude::*;
#[derive(Module, Debug)]
pub struct MiniCpmRmsNorm<B: Backend> {
pub weight: Param<Tensor<B, 1>>,
pub eps: f64,
}
impl<B: Backend> MiniCpmRmsNorm<B> {
pub fn new(hidden_size: usize, eps: f64, device: &B::Device) -> Self {
Self {
weight: Param::from_tensor(Tensor::ones([hidden_size], device)),
eps,
}
}
pub fn forward<const D: usize>(&self, x: Tensor<B, D>) -> Tensor<B, D> {
let x_sq = x.clone() * x.clone();
let variance = x_sq.mean_dim(D - 1);
let inv_rms = variance.add_scalar(self.eps).sqrt().recip();
let x = x * inv_rms;
x * self.weight.val().unsqueeze()
}
}