use std::io::Read;
use crate::nemotron::config::{NemotronConfig, NormType};
use trustformers_core::{
device::Device,
errors::{tensor_op_error, Result, TrustformersError},
layers::{Embedding, Linear},
tensor::Tensor,
traits::{Config, Layer, Model},
};
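/// Squared-ReLU ("relu2") activation used by the Nemotron MLP: `max(x, 0)^2`.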
#[inline]
pub fn squared_relu(x: f32) -> f32 {
let r = x.max(0.0);
r * r
}
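/// Root-mean-square normalization with a learnable scale and no bias.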
pub struct NemotronRmsNorm {
weight: Tensor,
eps: f32,
}
impl NemotronRmsNorm {
pub fn new(dim: usize, eps: f64) -> Result<Self> {
let weight = Tensor::ones(&[dim])?;
Ok(Self {
weight,
eps: eps as f32,
})
}
}
impl Layer for NemotronRmsNorm {
type Input = Tensor;
type Output = Tensor;
    fn forward(&self, input: Tensor) -> Result<Tensor> {
        match (&input, &self.weight) {
            (Tensor::F32(arr), Tensor::F32(w)) => {
                // Normalize each row (last dimension) independently so that a
                // [seq_len, hidden] input is scaled per token rather than over
                // the flattened tensor.
                let dim = w.len().max(1);
                let data = arr.as_slice().unwrap_or(&[]);
                let w_s = w.as_slice().unwrap_or(&[]);
                let mut out = Vec::with_capacity(data.len());
                for row in data.chunks(dim) {
                    let mean_sq =
                        row.iter().map(|x| x * x).sum::<f32>() / row.len().max(1) as f32;
                    let inv_rms = 1.0 / (mean_sq + self.eps).sqrt();
                    for (i, &x) in row.iter().enumerate() {
                        out.push(x * inv_rms * w_s.get(i).copied().unwrap_or(1.0));
                    }
                }
                Tensor::from_vec(out, arr.shape())
            },
            _ => Err(tensor_op_error(
                "tensor_operation",
                "Unsupported dtype for NemotronRmsNorm",
            )),
        }
    }
}
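/// Standard layer normalization with learnable scale and bias.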
pub struct NemotronLayerNorm {
weight: Tensor,
bias: Tensor,
eps: f32,
}
impl NemotronLayerNorm {
pub fn new(dim: usize, eps: f64) -> Result<Self> {
let weight = Tensor::ones(&[dim])?;
let bias = Tensor::zeros(&[dim])?;
Ok(Self {
weight,
bias,
eps: eps as f32,
})
}
}
impl Layer for NemotronLayerNorm {
type Input = Tensor;
type Output = Tensor;
    fn forward(&self, input: Tensor) -> Result<Tensor> {
        match (&input, &self.weight, &self.bias) {
            (Tensor::F32(arr), Tensor::F32(w), Tensor::F32(b)) => {
                // Normalize each row (last dimension) independently so that a
                // [seq_len, hidden] input is normalized per token.
                let dim = w.len().max(1);
                let data = arr.as_slice().unwrap_or(&[]);
                let w_s = w.as_slice().unwrap_or(&[]);
                let b_s = b.as_slice().unwrap_or(&[]);
                let mut out = Vec::with_capacity(data.len());
                for row in data.chunks(dim) {
                    let n = row.len().max(1) as f32;
                    let mean = row.iter().sum::<f32>() / n;
                    let var = row.iter().map(|x| (x - mean).powi(2)).sum::<f32>() / n;
                    let std_inv = 1.0 / (var + self.eps).sqrt();
                    for (i, &x) in row.iter().enumerate() {
                        let wi = w_s.get(i).copied().unwrap_or(1.0);
                        let bi = b_s.get(i).copied().unwrap_or(0.0);
                        out.push((x - mean) * std_inv * wi + bi);
                    }
                }
                Tensor::from_vec(out, arr.shape())
            },
            _ => Err(tensor_op_error(
                "tensor_operation",
                "Unsupported dtype for NemotronLayerNorm",
            )),
        }
    }
}
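/// Normalization variant selected by the `norm_type` field of the config.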
pub enum NemotronNorm {
Rms(NemotronRmsNorm),
Layer(NemotronLayerNorm),
}
impl NemotronNorm {
pub fn new(dim: usize, eps: f64, norm_type: &NormType) -> Result<Self> {
match norm_type {
NormType::RmsNorm => Ok(NemotronNorm::Rms(NemotronRmsNorm::new(dim, eps)?)),
NormType::LayerNorm => Ok(NemotronNorm::Layer(NemotronLayerNorm::new(dim, eps)?)),
}
}
}
impl Layer for NemotronNorm {
type Input = Tensor;
type Output = Tensor;
fn forward(&self, input: Tensor) -> Result<Tensor> {
match self {
NemotronNorm::Rms(n) => n.forward(input),
NemotronNorm::Layer(n) => n.forward(input),
}
}
}
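/// Rotary position embedding applied to only a fraction of each head's
/// dimensions ("partial RoPE"), controlled by `partial_rotary_factor`.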
pub struct NemotronPartialRotaryEmbedding {
head_dim: usize,
rotary_dim: usize,
rope_theta: f64,
}
impl NemotronPartialRotaryEmbedding {
pub fn new(config: &NemotronConfig) -> Self {
Self {
head_dim: config.head_dim,
rotary_dim: config.rotary_dim(),
rope_theta: config.rope_theta,
}
}
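    /// Applies partial RoPE to the first `rotary_dim` components of every
    /// head. `q` and `k` are flat `[seq_len, heads * head_dim]` buffers; the
    /// remaining `head_dim - rotary_dim` components pass through unchanged.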
    pub fn apply(&self, q: &[f32], k: &[f32], seq_len: usize) -> (Vec<f32>, Vec<f32>) {
        let half_rot = self.rotary_dim / 2;
        let mut q_out = q.to_vec();
        let mut k_out = k.to_vec();
        // Each slice is laid out as [seq_len, heads * head_dim]; the head
        // count is recovered from the slice length so that q (attention
        // heads) and k (key/value heads) may differ under GQA.
        let rotate = |buf: &mut [f32]| {
            let per_pos = if seq_len > 0 { buf.len() / seq_len } else { 0 };
            let heads = if self.head_dim > 0 { per_pos / self.head_dim } else { 0 };
            for pos in 0..seq_len {
                for h in 0..heads {
                    let base = pos * per_pos + h * self.head_dim;
                    for i in 0..half_rot {
                        let freq =
                            1.0 / self.rope_theta.powf(2.0 * i as f64 / self.rotary_dim as f64);
                        let angle = (pos as f64 * freq) as f32;
                        let (sin_v, cos_v) = angle.sin_cos();
                        let (lo, hi) = (base + i, base + i + half_rot);
                        if let (Some(&x0), Some(&x1)) = (buf.get(lo), buf.get(hi)) {
                            buf[lo] = x0 * cos_v - x1 * sin_v;
                            buf[hi] = x0 * sin_v + x1 * cos_v;
                        }
                    }
                }
            }
        };
        rotate(&mut q_out);
        rotate(&mut k_out);
        (q_out, k_out)
    }
}
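/// Gated MLP: `down_proj(relu2(gate_proj(x)) * up_proj(x))`.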
pub struct NemotronMLP {
gate_proj: Linear,
up_proj: Linear,
down_proj: Linear,
}
impl NemotronMLP {
pub fn new(config: &NemotronConfig) -> Self {
Self {
gate_proj: Linear::new(
config.hidden_size,
config.intermediate_size,
config.mlp_bias,
),
up_proj: Linear::new(
config.hidden_size,
config.intermediate_size,
config.mlp_bias,
),
down_proj: Linear::new(
config.intermediate_size,
config.hidden_size,
config.mlp_bias,
),
}
}
}
impl Layer for NemotronMLP {
type Input = Tensor;
type Output = Tensor;
fn forward(&self, input: Tensor) -> Result<Tensor> {
let gate = self.gate_proj.forward(input.clone())?;
let up = self.up_proj.forward(input)?;
let gated = match (&gate, &up) {
            (Tensor::F32(g), Tensor::F32(u)) => {
                // relu2 gating: the squared-ReLU of the gate projection scales
                // the up projection elementwise.
                let data: Vec<f32> = g
                    .as_slice()
                    .unwrap_or(&[])
                    .iter()
                    .zip(u.as_slice().unwrap_or(&[]).iter())
                    .map(|(&gi, &ui)| squared_relu(gi) * ui)
                    .collect();
                Tensor::from_vec(data, g.shape())?
            },
_ => {
return Err(tensor_op_error(
"tensor_operation",
"Type mismatch in NemotronMLP",
))
},
};
self.down_proj.forward(gated)
}
}
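/// Causal self-attention with grouped-query attention (GQA) and partial RoPE.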
pub struct NemotronAttention {
q_proj: Linear,
k_proj: Linear,
v_proj: Linear,
o_proj: Linear,
rotary_emb: NemotronPartialRotaryEmbedding,
num_heads: usize,
num_kv_heads: usize,
head_dim: usize,
}
impl NemotronAttention {
pub fn new(config: &NemotronConfig) -> Self {
let head_dim = config.head_dim;
let num_kv_heads = config.num_key_value_heads;
let bias = config.attention_bias;
Self {
q_proj: Linear::new(
config.hidden_size,
config.num_attention_heads * head_dim,
bias,
),
k_proj: Linear::new(config.hidden_size, num_kv_heads * head_dim, bias),
v_proj: Linear::new(config.hidden_size, num_kv_heads * head_dim, bias),
o_proj: Linear::new(
config.num_attention_heads * head_dim,
config.hidden_size,
bias,
),
rotary_emb: NemotronPartialRotaryEmbedding::new(config),
num_heads: config.num_attention_heads,
num_kv_heads,
head_dim,
}
}
    fn gqa_attention(&self, q: &[f32], k: &[f32], v: &[f32], seq_len: usize) -> Vec<f32> {
        // Grouped-query attention over flat slices. Layouts match the linear
        // projections: q is [seq_len, num_heads * head_dim], k and v are
        // [seq_len, num_kv_heads * head_dim], and each group of `kv_group`
        // query heads shares one key/value head.
        let kv_group = (self.num_heads / self.num_kv_heads.max(1)).max(1);
        let q_stride = self.num_heads * self.head_dim;
        let kv_stride = self.num_kv_heads * self.head_dim;
        let scale = 1.0 / (self.head_dim as f32).sqrt();
        let mut output = vec![0.0f32; seq_len * q_stride];
        for h in 0..self.num_heads {
            let kv_h = h / kv_group;
            for qi in 0..seq_len {
                // Causal mask: a query attends only to itself and earlier positions.
                let mut scores = vec![f32::NEG_INFINITY; seq_len];
                for kj in 0..=qi {
                    let mut dot = 0.0f32;
                    for d in 0..self.head_dim {
                        let qv = q
                            .get(qi * q_stride + h * self.head_dim + d)
                            .copied()
                            .unwrap_or(0.0);
                        let kv = k
                            .get(kj * kv_stride + kv_h * self.head_dim + d)
                            .copied()
                            .unwrap_or(0.0);
                        dot += qv * kv;
                    }
                    scores[kj] = dot * scale;
                }
                // Numerically stable softmax: shift by the row maximum before exponentiating.
                let max_s = scores.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
                let mut exp_sum = 0.0f32;
                let mut exp_scores: Vec<f32> = scores
                    .iter()
                    .map(|&s| {
                        let e = (s - max_s).exp();
                        exp_sum += e;
                        e
                    })
                    .collect();
                if exp_sum > 0.0 {
                    for es in &mut exp_scores {
                        *es /= exp_sum;
                    }
                }
                // Weighted sum of values, written back in [seq_len, num_heads * head_dim] layout.
                for d in 0..self.head_dim {
                    let mut weighted = 0.0f32;
                    for kj in 0..seq_len {
                        let vv = v
                            .get(kj * kv_stride + kv_h * self.head_dim + d)
                            .copied()
                            .unwrap_or(0.0);
                        weighted += exp_scores[kj] * vv;
                    }
                    output[qi * q_stride + h * self.head_dim + d] = weighted;
                }
            }
        }
        output
    }
}
impl Layer for NemotronAttention {
type Input = Tensor;
type Output = Tensor;
fn forward(&self, input: Tensor) -> Result<Tensor> {
let q_t = self.q_proj.forward(input.clone())?;
let k_t = self.k_proj.forward(input.clone())?;
let v_t = self.v_proj.forward(input)?;
match (&q_t, &k_t, &v_t) {
(Tensor::F32(q_arr), Tensor::F32(k_arr), Tensor::F32(v_arr)) => {
let total_q = q_arr.len();
let seq_len =
total_q.checked_div(self.num_heads * self.head_dim).unwrap_or(1).max(1);
let q_s = q_arr.as_slice().unwrap_or(&[]);
let k_s = k_arr.as_slice().unwrap_or(&[]);
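                // Rotate queries and keys with partial RoPE before the causal
                // grouped-query attention below.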
let (q_rot, k_rot) = self.rotary_emb.apply(q_s, k_s, seq_len);
let v_s = v_arr.as_slice().unwrap_or(&[]);
let context = self.gqa_attention(&q_rot, &k_rot, v_s, seq_len);
                let context_tensor =
                    Tensor::from_vec(context, &[seq_len, self.num_heads * self.head_dim])?;
self.o_proj.forward(context_tensor)
},
_ => Err(tensor_op_error(
"tensor_operation",
"Unsupported dtype in NemotronAttention",
)),
}
}
}
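/// One pre-norm decoder block: attention and MLP sub-layers, each wrapped in
/// a normalization and a residual connection.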
pub struct NemotronDecoderLayer {
self_attn: NemotronAttention,
mlp: NemotronMLP,
input_layernorm: NemotronNorm,
post_attention_layernorm: NemotronNorm,
}
impl NemotronDecoderLayer {
pub fn new(config: &NemotronConfig) -> Result<Self> {
Ok(Self {
self_attn: NemotronAttention::new(config),
mlp: NemotronMLP::new(config),
input_layernorm: NemotronNorm::new(
config.hidden_size,
config.rms_norm_eps,
&config.norm_type,
)?,
post_attention_layernorm: NemotronNorm::new(
config.hidden_size,
config.rms_norm_eps,
&config.norm_type,
)?,
})
}
}
impl Layer for NemotronDecoderLayer {
type Input = Tensor;
type Output = Tensor;
fn forward(&self, input: Tensor) -> Result<Tensor> {
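        // Pre-norm residual wiring: x + Attn(Norm(x)), then x + MLP(Norm(x)).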
let normed = self.input_layernorm.forward(input.clone())?;
let attn_out = self.self_attn.forward(normed)?;
let residual = input.add(&attn_out)?;
let normed2 = self.post_attention_layernorm.forward(residual.clone())?;
let mlp_out = self.mlp.forward(normed2)?;
residual.add(&mlp_out)
}
}
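/// Nemotron decoder stack: token embedding, `num_hidden_layers` decoder
/// layers, and a final normalization (no language-model head).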
pub struct NemotronModel {
config: NemotronConfig,
embed_tokens: Embedding,
layers: Vec<NemotronDecoderLayer>,
norm: NemotronNorm,
}
impl NemotronModel {
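    /// Builds a model after validating the config.
    ///
    /// Illustrative sketch (not compiled as a doctest; it assumes the presets
    /// and tensor constructors exercised in the tests below):
    ///
    /// ```ignore
    /// let config = NemotronConfig::nemotron_4_22b();
    /// let model = NemotronModel::new(config).expect("config must be valid");
    /// let ids = Tensor::from_vec(vec![1.0_f32, 2.0, 3.0], &[3]).expect("ids must build");
    /// let hidden = model.forward(ids).expect("forward must succeed");
    /// ```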
pub fn new(config: NemotronConfig) -> Result<Self> {
Config::validate(&config)?;
let embed_tokens = Embedding::new(config.vocab_size, config.hidden_size, None)?;
let mut layers = Vec::with_capacity(config.num_hidden_layers);
for _ in 0..config.num_hidden_layers {
layers.push(NemotronDecoderLayer::new(&config)?);
}
let norm = NemotronNorm::new(config.hidden_size, config.rms_norm_eps, &config.norm_type)?;
Ok(Self {
config,
embed_tokens,
layers,
norm,
})
}
pub fn config(&self) -> &NemotronConfig {
&self.config
}
}
impl Model for NemotronModel {
type Config = NemotronConfig;
type Input = Tensor;
type Output = Tensor;
fn forward(&self, input_ids: Tensor) -> Result<Tensor> {
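        // Token ids may arrive as I64 or as F32 (rounded to the nearest id).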
let token_ids: Vec<u32> = match &input_ids {
Tensor::I64(arr) => arr.as_slice().unwrap_or(&[]).iter().map(|&x| x as u32).collect(),
Tensor::F32(arr) => {
arr.as_slice().unwrap_or(&[]).iter().map(|&x| x.round() as u32).collect()
},
_ => {
return Err(tensor_op_error(
"tensor_operation",
"Unsupported input dtype for NemotronModel",
))
},
};
let mut hidden = self.embed_tokens.forward(token_ids)?;
for layer in &self.layers {
hidden = layer.forward(hidden)?;
}
self.norm.forward(hidden)
}
fn load_pretrained(&mut self, _reader: &mut dyn Read) -> Result<()> {
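        // Weight loading is not implemented yet; the freshly constructed
        // parameters are left untouched.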
Ok(())
}
fn get_config(&self) -> &NemotronConfig {
&self.config
}
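    /// Approximate parameter count: embedding, per-layer attention, MLP and
    /// norm weights, plus the final norm; projection biases are not counted.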
fn num_parameters(&self) -> usize {
let c = &self.config;
let embed = c.vocab_size * c.hidden_size;
let attn = (c.num_attention_heads + 2 * c.num_key_value_heads) * c.head_dim * c.hidden_size
+ c.num_attention_heads * c.head_dim * c.hidden_size;
let mlp = 3 * c.hidden_size * c.intermediate_size;
let norms = 2 * c.hidden_size;
let layer = attn + mlp + norms;
embed + c.num_hidden_layers * layer + c.hidden_size
}
}
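// Touch otherwise-unused imports so they do not trigger compiler warnings.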
#[allow(dead_code)]
fn _assert_imports() {
let _ = Device::CPU;
let _ = TrustformersError::io_error("unused".to_string());
}
#[cfg(test)]
mod tests {
use super::*;
use crate::nemotron::config::{NemotronConfig, NormType};
    // Deterministic pseudo-random values in [-1, 1] from a 64-bit LCG
    // (Knuth's MMIX constants); the high 32 bits of the state are the output.
    fn lcg_next(state: &mut u64) -> f32 {
        *state = state.wrapping_mul(6364136223846793005).wrapping_add(1442695040888963407);
        ((*state >> 32) as f32) / (u32::MAX as f32) * 2.0 - 1.0
    }
fn lcg_vec(n: usize, seed: u64) -> Vec<f32> {
let mut state = seed;
(0..n).map(|_| lcg_next(&mut state)).collect()
}
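    /// Minimal Nemotron config, small enough for fast unit tests.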
fn tiny_nemotron_config() -> NemotronConfig {
NemotronConfig {
vocab_size: 64,
hidden_size: 8,
intermediate_size: 16,
num_hidden_layers: 2,
num_attention_heads: 4,
num_key_value_heads: 2,
head_dim: 2,
max_position_embeddings: 32,
rms_norm_eps: 1e-5,
rope_theta: 10000.0,
partial_rotary_factor: 0.5,
hidden_act: "relu2".to_string(),
tie_word_embeddings: false,
norm_type: NormType::RmsNorm,
attention_bias: false,
mlp_bias: false,
}
}
#[test]
fn test_squared_relu_negative_is_zero() {
assert_eq!(
squared_relu(-1.0),
0.0,
"squared_relu of negative must be 0"
);
assert_eq!(
squared_relu(-100.0),
0.0,
"squared_relu of large negative must be 0"
);
}
#[test]
fn test_squared_relu_zero_is_zero() {
assert_eq!(squared_relu(0.0), 0.0, "squared_relu(0) must be 0");
}
#[test]
fn test_squared_relu_positive_is_squared() {
assert!(
(squared_relu(2.0) - 4.0).abs() < 1e-6,
"squared_relu(2.0) must be 4.0"
);
assert!(
(squared_relu(3.0) - 9.0).abs() < 1e-6,
"squared_relu(3.0) must be 9.0"
);
}
#[test]
fn test_squared_relu_one_is_one() {
assert!(
(squared_relu(1.0) - 1.0).abs() < 1e-6,
"squared_relu(1.0) must be 1.0"
);
}
#[test]
fn test_nemotron_rmsnorm_forward_unit_rms() {
let norm =
NemotronRmsNorm::new(4, 1e-5).expect("NemotronRmsNorm construction must succeed");
let input =
Tensor::from_vec(vec![3.0_f32, 4.0, 0.0, 0.0], &[4]).expect("tensor must build");
let output = norm.forward(input).expect("forward must succeed");
let vals: Vec<f32> = match &output {
Tensor::F32(arr) => arr.as_slice().expect("must be contiguous").to_vec(),
_ => panic!("expected F32 output"),
};
let rms = (vals.iter().map(|x| x * x).sum::<f32>() / 4.0).sqrt();
assert!(
(rms - 1.0).abs() < 1e-4,
"RMSNorm output rms must be ~1.0, got {rms}"
);
}
#[test]
fn test_nemotron_layernorm_zero_mean() {
let norm = NemotronLayerNorm::new(4, 1e-5).expect("NemotronLayerNorm must build");
let input =
Tensor::from_vec(vec![1.0_f32, 2.0, 3.0, 4.0], &[4]).expect("tensor must build");
let output = norm.forward(input).expect("forward must succeed");
let vals: Vec<f32> = match &output {
Tensor::F32(arr) => arr.as_slice().expect("must be contiguous").to_vec(),
_ => panic!("expected F32 output"),
};
let mean = vals.iter().sum::<f32>() / vals.len() as f32;
assert!(
mean.abs() < 1e-4,
"LayerNorm output mean must be ~0, got {mean}"
);
}
#[test]
fn test_nemotron_norm_rmstype_dispatches_correctly() {
let norm =
NemotronNorm::new(4, 1e-5, &NormType::RmsNorm).expect("NemotronNorm(Rms) must build");
        assert!(
            matches!(norm, NemotronNorm::Rms(_)),
            "NormType::RmsNorm must dispatch to the Rms variant"
        );
}
#[test]
fn test_nemotron_norm_layertype_dispatches_correctly() {
let norm = NemotronNorm::new(4, 1e-5, &NormType::LayerNorm)
.expect("NemotronNorm(Layer) must build");
        assert!(
            matches!(norm, NemotronNorm::Layer(_)),
            "NormType::LayerNorm must dispatch to the Layer variant"
        );
}
#[test]
fn test_partial_rope_output_length_matches_input() {
let cfg = tiny_nemotron_config();
let rope = NemotronPartialRotaryEmbedding::new(&cfg);
let n = cfg.head_dim;
let q = lcg_vec(n, 11);
let k = lcg_vec(n, 12);
let (q_out, k_out) = rope.apply(&q, &k, 1);
assert_eq!(
q_out.len(),
n,
"partial RoPE output Q length must match input"
);
assert_eq!(
k_out.len(),
n,
"partial RoPE output K length must match input"
);
}
#[test]
fn test_partial_rope_non_rotary_portion_unchanged() {
let cfg = NemotronConfig {
head_dim: 8,
partial_rotary_factor: 0.5, ..tiny_nemotron_config()
};
let rope = NemotronPartialRotaryEmbedding::new(&cfg);
let q: Vec<f32> = (0..8).map(|i| i as f32 * 0.1).collect();
let k: Vec<f32> = (0..8).map(|i| i as f32 * 0.2).collect();
let (q_out, _k_out) = rope.apply(&q, &k, 1);
for i in 4..8 {
assert!(
(q_out[i] - q[i]).abs() < 1e-6,
"non-rotary portion q[{i}] must be unchanged after partial RoPE",
);
}
}
#[test]
fn test_partial_rope_rotary_dim() {
let cfg = tiny_nemotron_config();
let rotary_dim = cfg.rotary_dim();
let expected = (cfg.head_dim as f32 * cfg.partial_rotary_factor) as usize;
assert_eq!(
rotary_dim, expected,
"rotary_dim must equal head_dim * partial_rotary_factor"
);
}
#[test]
fn test_nemotron_mlp_output_length() {
let cfg = tiny_nemotron_config();
let mlp = NemotronMLP::new(&cfg);
let input = Tensor::from_vec(lcg_vec(cfg.hidden_size, 20), &[1, cfg.hidden_size])
.expect("input tensor must build");
let output = mlp.forward(input).expect("MLP forward must succeed");
let out_len = output.shape().iter().product::<usize>();
assert_eq!(
out_len, cfg.hidden_size,
"MLP output must have hidden_size elements"
);
}
#[test]
fn test_nemotron_config_validate_ok() {
let cfg = tiny_nemotron_config();
assert!(
Config::validate(&cfg).is_ok(),
"valid tiny nemotron config must pass validation"
);
}
#[test]
fn test_nemotron_config_rotary_dim_fraction() {
let cfg = NemotronConfig {
head_dim: 128,
partial_rotary_factor: 0.5,
..NemotronConfig::default()
};
assert_eq!(
cfg.rotary_dim(),
64,
"50% partial rotary on dim=128 must give 64"
);
}
#[test]
fn test_nemotron_4_22b_config_values() {
let cfg = NemotronConfig::nemotron_4_22b();
assert_eq!(cfg.hidden_size, 6144, "22B hidden_size must be 6144");
assert_eq!(cfg.num_hidden_layers, 40, "22B must have 40 layers");
assert_eq!(cfg.num_attention_heads, 48, "22B must have 48 heads");
assert_eq!(cfg.num_key_value_heads, 8, "22B must have 8 kv heads (GQA)");
assert!(!cfg.mlp_bias, "nemotron must have no MLP bias");
assert!(!cfg.attention_bias, "nemotron must have no attention bias");
}
#[test]
fn test_nemotron_model_construction() {
let cfg = tiny_nemotron_config();
let model = NemotronModel::new(cfg).expect("NemotronModel must build");
assert_eq!(
model.config().num_hidden_layers,
2,
"model must have 2 layers"
);
}
#[test]
fn test_nemotron_model_forward_single_token() {
let cfg = tiny_nemotron_config();
let model = NemotronModel::new(cfg.clone()).expect("model must build");
let input = Tensor::from_vec(vec![0_f32], &[1]).expect("f32 token must build");
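        // Smoke test: the forward pass must complete without panicking.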
let _ = model.forward(input);
}
#[test]
fn test_nemotron_model_num_parameters_positive() {
let cfg = tiny_nemotron_config();
let model = NemotronModel::new(cfg).expect("model must build");
assert!(
model.num_parameters() > 0,
"num_parameters must be positive"
);
}
#[test]
fn test_nemotron_model_forward_f32_input() {
let cfg = tiny_nemotron_config();
let model = NemotronModel::new(cfg).expect("model must build");
let input = Tensor::from_vec(vec![1.0_f32], &[1]).expect("f32 token must build");
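        // Smoke test: f32 token ids must be accepted via the rounding path.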
let _ = model.forward(input);
}
}