1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
//! RMSNorm (Root Mean Square Layer Normalization).
//!
//! Used as pre-norm before attention and FFN in each Transformer block.
//! `output[i] = weight[i] * (input[i] / rms(input))`
//! where `rms(x) = sqrt(mean(x^2) + eps)`.
use crate::error::ModelResult;
/// Root-mean-square normalization layer with a learnable per-element scale.
///
/// Holds the scale vector and the numerical-stability epsilon; the
/// normalization itself is performed by [`RmsNorm::forward`].
#[derive(Debug)]
pub struct RmsNorm {
    /// Per-element scale factors; the vector length is the hidden size.
    weight: Vec<f32>,
    /// Stability constant added to the mean of squares before the square root.
    eps: f32,
}
impl RmsNorm {
    /// Create a new RMSNorm layer.
    ///
    /// - `weight`: Per-element scale weights; its length defines the hidden size.
    /// - `eps`: Small constant added to the mean of squares for numerical stability.
    pub fn new(weight: Vec<f32>, eps: f32) -> Self {
        Self { weight, eps }
    }

    /// Apply RMSNorm to `input`, writing the normalized values into `output`.
    ///
    /// `output[i] = weight[i] * input[i] / rms(input)`, where
    /// `rms(x) = sqrt(mean(x^2) + eps)`. `input` is only read; `output` must
    /// hold at least `input.len()` elements.
    ///
    /// Delegates to the SIMD-accelerated implementation in `oxibonsai_kernels`.
    ///
    /// # Panics
    ///
    /// Panics if `input.len()` differs from [`hidden_size`](Self::hidden_size),
    /// or if `output` is shorter than `input`.
    ///
    /// # Errors
    ///
    /// This implementation always returns `Ok(())`.
    pub fn forward(&self, input: &[f32], output: &mut [f32]) -> ModelResult<()> {
        let n = input.len();
        // Checked in release builds as well: the kernel receives raw slices,
        // so a shape mismatch is caught here instead of being left to
        // whatever the kernel does with mismatched lengths. The cost is two
        // integer comparisons against an O(n) kernel.
        assert_eq!(
            n,
            self.weight.len(),
            "RmsNorm::forward: input length {n} does not match hidden size {}",
            self.weight.len()
        );
        assert!(
            output.len() >= n,
            "RmsNorm::forward: output length {} is shorter than input length {n}",
            output.len()
        );
        oxibonsai_kernels::rms_norm_simd(input, &self.weight, output, self.eps);
        Ok(())
    }

    /// Hidden size (dimension of weight vector).
    pub fn hidden_size(&self) -> usize {
        self.weight.len()
    }

    /// Access the raw weight vector (for batch GPU dispatch).
    pub fn weight(&self) -> &[f32] {
        &self.weight
    }

    /// Access the epsilon value (for batch GPU dispatch).
    pub fn eps(&self) -> f32 {
        self.eps
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Reference RMS from the module docs: `sqrt(mean(x^2) + eps)`.
    fn reference_rms(input: &[f32], eps: f32) -> f32 {
        let mean_sq = input.iter().map(|x| x * x).sum::<f32>() / input.len() as f32;
        (mean_sq + eps).sqrt()
    }

    #[test]
    fn rms_norm_unit_weights() {
        let eps = 1e-6;
        let norm = RmsNorm::new(vec![1.0; 4], eps);
        let input = vec![1.0, 2.0, 3.0, 4.0];
        let mut output = vec![0.0; 4];
        norm.forward(&input, &mut output)
            .expect("rms norm forward should succeed");
        // RMS = sqrt((1+4+9+16)/4 + eps) ≈ sqrt(7.5) ≈ 2.7386
        let rms = reference_rms(&input, eps);
        for (i, (&x, &got)) in input.iter().zip(&output).enumerate() {
            let expected = x / rms;
            assert!(
                (got - expected).abs() < 1e-5,
                "at {i}: expected {expected}, got {got}"
            );
        }
    }

    #[test]
    fn rms_norm_with_scale() {
        let norm = RmsNorm::new(vec![2.0; 4], 1e-6);
        let input = vec![1.0, 1.0, 1.0, 1.0];
        let mut output = vec![0.0; 4];
        norm.forward(&input, &mut output)
            .expect("rms norm forward should succeed");
        // RMS = sqrt(1) = 1.0, so output = 2.0 * 1.0 / 1.0 = 2.0
        for &v in &output {
            assert!((v - 2.0).abs() < 1e-5);
        }
    }

    #[test]
    fn rms_norm_zero_input_stays_finite() {
        // With all-zero input, eps keeps the denominator non-zero, so every
        // output element must be overwritten with zero rather than NaN.
        let norm = RmsNorm::new(vec![3.0; 8], 1e-6);
        let input = vec![0.0; 8];
        let mut output = vec![f32::NAN; 8];
        norm.forward(&input, &mut output)
            .expect("rms norm forward should succeed");
        assert!(
            output.iter().all(|v| v.abs() < 1e-12),
            "expected all zeros, got {output:?}"
        );
    }

    #[test]
    fn accessors_report_construction_values() {
        let norm = RmsNorm::new(vec![0.5, 1.5, 2.5], 1e-5);
        assert_eq!(norm.hidden_size(), 3);
        assert_eq!(norm.weight(), &[0.5f32, 1.5, 2.5][..]);
        assert_eq!(norm.eps(), 1e-5);
    }
}