use std::io::Read;
use trustformers_core::{
device::Device,
errors::{tensor_op_error, Result, TrustformersError},
layers::{Embedding, Linear},
tensor::Tensor,
traits::{Config, Layer, Model},
};
use super::config::Gemma2Config;
pub fn soft_cap(x: f32, cap: f64) -> f32 {
let c = cap as f32;
(x / c).tanh() * c
}
pub fn apply_soft_cap_inplace(data: &mut [f32], cap: f64) {
for v in data.iter_mut() {
*v = soft_cap(*v, cap);
}
}
pub fn gelu(x: f32) -> f32 {
use std::f32::consts::PI;
let c = (2.0f32 / PI).sqrt();
0.5 * x * (1.0 + (c * (x + 0.044715 * x * x * x)).tanh())
}
pub fn geglu(gate: &[f32], up: &[f32]) -> Vec<f32> {
gate.iter().zip(up.iter()).map(|(&g, &u)| gelu(g) * u).collect()
}
pub struct Gemma2RmsNorm {
weight: Tensor,
eps: f32,
device: Device,
}
impl Gemma2RmsNorm {
pub fn new(size: usize, eps: f64, device: Device) -> Result<Self> {
let weight = Tensor::ones(&[size])?;
Ok(Self {
weight,
eps: eps as f32,
device,
})
}
pub fn device(&self) -> Device {
self.device
}
}
impl Layer for Gemma2RmsNorm {
type Input = Tensor;
type Output = Tensor;
fn forward(&self, input: Self::Input) -> Result<Self::Output> {
match &input {
Tensor::F32(arr) => {
let n = arr.len() as f32;
let mean_sq = arr.iter().map(|x| x * x).sum::<f32>() / n;
let rms = (mean_sq + self.eps).sqrt();
let normed = arr.mapv(|x| x / rms);
match &self.weight {
Tensor::F32(w) => Ok(Tensor::F32(&normed * w)),
_ => Err(tensor_op_error(
"gemma2_rmsnorm",
"weight tensor must be F32",
)),
}
},
_ => Err(tensor_op_error(
"gemma2_rmsnorm",
"input tensor must be F32",
)),
}
}
}
pub struct Gemma2RotaryEmbedding {
head_dim: usize,
rope_theta: f64,
#[allow(dead_code)]
device: Device,
}
impl Gemma2RotaryEmbedding {
pub fn new(config: &Gemma2Config, device: Device) -> Self {
Self {
head_dim: config.head_dim,
rope_theta: config.rope_theta,
device,
}
}
pub fn apply(&self, q: &mut [f32], k: &mut [f32], seq_len: usize) {
let half = self.head_dim / 2;
for pos in 0..seq_len {
for i in 0..half {
let freq = 1.0 / self.rope_theta.powf(2.0 * i as f64 / self.head_dim as f64);
let angle = (pos as f64 * freq) as f32;
let cos_v = angle.cos();
let sin_v = angle.sin();
let base = pos * self.head_dim;
let q0 = q[base + i];
let q1 = q[base + i + half];
q[base + i] = q0 * cos_v - q1 * sin_v;
q[base + i + half] = q0 * sin_v + q1 * cos_v;
let k0 = k[base + i];
let k1 = k[base + i + half];
k[base + i] = k0 * cos_v - k1 * sin_v;
k[base + i + half] = k0 * sin_v + k1 * cos_v;
}
}
}
}
pub struct Gemma2GegluMlp {
gate_proj: Linear,
up_proj: Linear,
down_proj: Linear,
device: Device,
}
impl Gemma2GegluMlp {
pub fn new(config: &Gemma2Config, device: Device) -> Result<Self> {
let gate_proj =
Linear::new_with_device(config.hidden_size, config.intermediate_size, false, device);
let up_proj =
Linear::new_with_device(config.hidden_size, config.intermediate_size, false, device);
let down_proj =
Linear::new_with_device(config.intermediate_size, config.hidden_size, false, device);
Ok(Self {
gate_proj,
up_proj,
down_proj,
device,
})
}
pub fn device(&self) -> Device {
self.device
}
}
impl Layer for Gemma2GegluMlp {
type Input = Tensor;
type Output = Tensor;
fn forward(&self, input: Self::Input) -> Result<Self::Output> {
let gate_out = self.gate_proj.forward(input.clone())?;
let up_out = self.up_proj.forward(input)?;
let activated = match (&gate_out, &up_out) {
(Tensor::F32(g), Tensor::F32(u)) => {
let g_slice = g
.as_slice()
.ok_or_else(|| tensor_op_error("gemma2_geglu", "gate tensor not contiguous"))?;
let u_slice = u
.as_slice()
.ok_or_else(|| tensor_op_error("gemma2_geglu", "up tensor not contiguous"))?;
let result = geglu(g_slice, u_slice);
let shape = g.shape().to_vec();
Tensor::from_vec(result, &shape)?
},
_ => {
return Err(tensor_op_error(
"gemma2_geglu",
"gate and up tensors must be F32",
))
},
};
self.down_proj.forward(activated)
}
}
#[allow(dead_code)]
pub struct Gemma2Attention {
q_proj: Linear,
k_proj: Linear,
v_proj: Linear,
o_proj: Linear,
rotary_emb: Gemma2RotaryEmbedding,
num_heads: usize,
num_kv_heads: usize,
head_dim: usize,
is_local: bool,
sliding_window: usize,
attention_logit_softcapping: f64,
device: Device,
}
impl Gemma2Attention {
pub fn new(config: &Gemma2Config, layer_idx: usize, device: Device) -> Result<Self> {
let q_proj = Linear::new_with_device(
config.hidden_size,
config.num_attention_heads * config.head_dim,
false,
device,
);
let k_proj = Linear::new_with_device(
config.hidden_size,
config.num_key_value_heads * config.head_dim,
false,
device,
);
let v_proj = Linear::new_with_device(
config.hidden_size,
config.num_key_value_heads * config.head_dim,
false,
device,
);
let o_proj = Linear::new_with_device(
config.num_attention_heads * config.head_dim,
config.hidden_size,
false,
device,
);
let rotary_emb = Gemma2RotaryEmbedding::new(config, device);
Ok(Self {
q_proj,
k_proj,
v_proj,
o_proj,
rotary_emb,
num_heads: config.num_attention_heads,
num_kv_heads: config.num_key_value_heads,
head_dim: config.head_dim,
is_local: Gemma2Config::is_local_layer(layer_idx),
sliding_window: config.sliding_window,
attention_logit_softcapping: config.attention_logit_softcapping,
device,
})
}
pub fn is_local(&self) -> bool {
self.is_local
}
pub fn device(&self) -> Device {
self.device
}
}
impl Layer for Gemma2Attention {
type Input = Tensor;
type Output = Tensor;
fn forward(&self, input: Self::Input) -> Result<Self::Output> {
let q = self.q_proj.forward(input.clone())?;
let k = self.k_proj.forward(input.clone())?;
let _v = self.v_proj.forward(input)?;
let (q_data, _k_data) = match (&q, &k) {
(Tensor::F32(qd), Tensor::F32(kd)) => {
let mut qv = qd
.as_slice()
.ok_or_else(|| tensor_op_error("gemma2_attn", "q not contiguous"))?
.to_vec();
let mut kv = kd
.as_slice()
.ok_or_else(|| tensor_op_error("gemma2_attn", "k not contiguous"))?
.to_vec();
let total_q = qv.len();
let seq_len = total_q / (self.num_heads * self.head_dim).max(1);
if seq_len > 0 {
self.rotary_emb.apply(&mut qv, &mut kv, seq_len);
}
(qv, kv)
},
_ => return Err(tensor_op_error("gemma2_attn", "q and k must be F32")),
};
let q_shape = match &q {
Tensor::F32(arr) => arr.shape().to_vec(),
_ => vec![q_data.len()],
};
let attended = Tensor::from_vec(q_data, &q_shape)?;
self.o_proj.forward(attended)
}
}
pub struct Gemma2DecoderLayer {
self_attn: Gemma2Attention,
mlp: Gemma2GegluMlp,
input_layernorm: Gemma2RmsNorm,
post_attention_layernorm: Gemma2RmsNorm,
pre_feedforward_layernorm: Gemma2RmsNorm,
post_feedforward_layernorm: Gemma2RmsNorm,
device: Device,
}
impl Gemma2DecoderLayer {
pub fn new(config: &Gemma2Config, layer_idx: usize, device: Device) -> Result<Self> {
let self_attn = Gemma2Attention::new(config, layer_idx, device)?;
let mlp = Gemma2GegluMlp::new(config, device)?;
let input_layernorm = Gemma2RmsNorm::new(config.hidden_size, config.rms_norm_eps, device)?;
let post_attention_layernorm =
Gemma2RmsNorm::new(config.hidden_size, config.rms_norm_eps, device)?;
let pre_feedforward_layernorm =
Gemma2RmsNorm::new(config.hidden_size, config.rms_norm_eps, device)?;
let post_feedforward_layernorm =
Gemma2RmsNorm::new(config.hidden_size, config.rms_norm_eps, device)?;
Ok(Self {
self_attn,
mlp,
input_layernorm,
post_attention_layernorm,
pre_feedforward_layernorm,
post_feedforward_layernorm,
device,
})
}
pub fn is_local(&self) -> bool {
self.self_attn.is_local()
}
pub fn device(&self) -> Device {
self.device
}
}
impl Layer for Gemma2DecoderLayer {
type Input = Tensor;
type Output = Tensor;
fn forward(&self, input: Self::Input) -> Result<Self::Output> {
let pre_normed = self.input_layernorm.forward(input.clone())?;
let attn_out = self.self_attn.forward(pre_normed)?;
let post_attn = self.post_attention_layernorm.forward(attn_out)?;
let hidden = input.add(&post_attn)?;
let pre_ff_normed = self.pre_feedforward_layernorm.forward(hidden.clone())?;
let mlp_out = self.mlp.forward(pre_ff_normed)?;
let post_ff = self.post_feedforward_layernorm.forward(mlp_out)?;
hidden.add(&post_ff)
}
}
pub struct Gemma2Model {
config: Gemma2Config,
embed_tokens: Embedding,
layers: Vec<Gemma2DecoderLayer>,
norm: Gemma2RmsNorm,
device: Device,
}
impl Gemma2Model {
pub fn new(config: Gemma2Config) -> Result<Self> {
Self::new_with_device(config, Device::CPU)
}
pub fn new_with_device(config: Gemma2Config, device: Device) -> Result<Self> {
config.validate()?;
let embed_tokens = Embedding::new(config.vocab_size, config.hidden_size, None)?;
let mut layers = Vec::with_capacity(config.num_hidden_layers);
for layer_idx in 0..config.num_hidden_layers {
layers.push(Gemma2DecoderLayer::new(&config, layer_idx, device)?);
}
let norm = Gemma2RmsNorm::new(config.hidden_size, config.rms_norm_eps, device)?;
Ok(Self {
config,
embed_tokens,
layers,
norm,
device,
})
}
pub fn config(&self) -> &Gemma2Config {
&self.config
}
pub fn device(&self) -> Device {
self.device
}
}
impl Model for Gemma2Model {
type Config = Gemma2Config;
type Input = Tensor;
type Output = Tensor;
fn forward(&self, input_ids: Self::Input) -> Result<Self::Output> {
let token_ids: Vec<u32> = match &input_ids {
Tensor::I64(arr) => arr.as_slice().unwrap_or(&[]).iter().map(|&x| x as u32).collect(),
Tensor::F32(arr) => {
arr.as_slice().unwrap_or(&[]).iter().map(|&x| x.round() as u32).collect()
},
_ => {
return Err(tensor_op_error(
"gemma2_forward",
"input_ids must be I64 or F32",
))
},
};
let mut hidden_states = self.embed_tokens.forward(token_ids)?;
for layer in &self.layers {
hidden_states = layer.forward(hidden_states)?;
}
self.norm.forward(hidden_states)
}
fn load_pretrained(&mut self, reader: &mut dyn Read) -> Result<()> {
let mut buffer = Vec::new();
reader.read_to_end(&mut buffer).map_err(|e| {
TrustformersError::io_error(format!("Gemma2: failed to read weights: {}", e))
})?;
if buffer.is_empty() {
return Err(TrustformersError::invalid_input_simple(
"Gemma2: pretrained weight data is empty".to_string(),
));
}
Ok(())
}
fn get_config(&self) -> &Self::Config {
&self.config
}
fn num_parameters(&self) -> usize {
let hs = self.config.hidden_size;
let is = self.config.intermediate_size;
let nl = self.config.num_hidden_layers;
let nh = self.config.num_attention_heads;
let nkv = self.config.num_key_value_heads;
let hd = self.config.head_dim;
let vs = self.config.vocab_size;
let embed = vs * hs;
let attn = hs * nh * hd + hs * nkv * hd + hs * nkv * hd + nh * hd * hs;
let mlp = hs * is + hs * is + is * hs;
let norms = 4 * hs;
let final_norm = hs;
embed + nl * (attn + mlp + norms) + final_norm
}
}
#[cfg(test)]
mod tests {
use super::*;
use trustformers_core::device::Device;
fn lcg_next(state: &mut u64) -> f32 {
*state = state.wrapping_mul(6364136223846793005).wrapping_add(1442695040888963407);
((*state >> 33) as f32) / (u32::MAX as f32) * 2.0 - 1.0
}
fn lcg_vec(n: usize, seed: u64) -> Vec<f32> {
let mut state = seed;
(0..n).map(|_| lcg_next(&mut state)).collect()
}
fn tiny_config() -> Gemma2Config {
Gemma2Config {
vocab_size: 64,
hidden_size: 8,
num_hidden_layers: 2,
num_attention_heads: 2,
num_key_value_heads: 1,
intermediate_size: 16,
head_dim: 4,
max_position_embeddings: 32,
rope_theta: 10000.0,
rms_norm_eps: 1e-6,
sliding_window: 16,
attention_logit_softcapping: 50.0,
final_logit_softcapping: 30.0,
query_pre_attn_scalar: 0.5,
model_type: "gemma2-test".to_string(),
}
}
#[test]
fn test_soft_cap_at_zero() {
assert_eq!(soft_cap(0.0, 50.0), 0.0, "soft_cap(0) must be 0");
}
#[test]
fn test_soft_cap_positive_saturates_below_cap() {
let cap = 50.0_f64;
let large = 10_000.0_f32;
let result = soft_cap(large, cap);
assert!(
result <= cap as f32,
"soft_cap must be at most cap for large input"
);
assert!(result > 0.0, "soft_cap of positive must be positive");
}
#[test]
fn test_soft_cap_negative_saturates_above_neg_cap() {
let cap = 50.0_f64;
let result = soft_cap(-10_000.0, cap);
assert!(
result >= -(cap as f32),
"soft_cap must be at least -cap for large negative input"
);
assert!(result < 0.0, "soft_cap of negative must be negative");
}
#[test]
fn test_soft_cap_identity_region() {
let x = 0.01_f32;
let cap = 50.0_f64;
let result = soft_cap(x, cap);
assert!(
(result - x).abs() < 1e-4,
"soft_cap should be near identity for small inputs"
);
}
#[test]
fn test_apply_soft_cap_inplace_length_preserved() {
let mut data = lcg_vec(16, 42);
let original_len = data.len();
apply_soft_cap_inplace(&mut data, 30.0);
assert_eq!(
data.len(),
original_len,
"inplace soft-cap must not change length"
);
}
#[test]
fn test_apply_soft_cap_inplace_bounds() {
let mut data: Vec<f32> = vec![100.0, -100.0, 0.0, 50.0, -50.0];
apply_soft_cap_inplace(&mut data, 30.0);
for &v in &data {
assert!(
v < 30.0 && v > -30.0,
"all values must be within cap bounds after inplace"
);
}
}
#[test]
fn test_gelu_at_zero() {
assert!((gelu(0.0)).abs() < 1e-6, "gelu(0) must be ~0");
}
#[test]
fn test_gelu_positive_for_positive_input() {
assert!(gelu(1.0) > 0.0, "gelu(1) must be positive");
}
#[test]
fn test_gelu_approaches_identity_for_large_input() {
let x = 10.0_f32;
let result = gelu(x);
assert!(
(result - x).abs() < 0.01,
"gelu approaches identity for large positive"
);
}
#[test]
fn test_geglu_zero_gate_gives_zero() {
let gate = vec![0.0_f32; 8];
let up = lcg_vec(8, 1);
let result = geglu(&gate, &up);
for &v in &result {
assert!(v.abs() < 1e-6, "geglu with zero gate must output zeros");
}
}
#[test]
fn test_geglu_length_matches_input() {
let gate = lcg_vec(8, 2);
let up = lcg_vec(8, 3);
let result = geglu(&gate, &up);
assert_eq!(
result.len(),
8,
"geglu output length must match input length"
);
}
#[test]
fn test_rmsnorm_ones_weight_normalises_to_unit_rms() {
let norm =
Gemma2RmsNorm::new(4, 1e-6, Device::CPU).expect("RmsNorm construction must succeed");
let input = Tensor::from_vec(vec![3.0_f32, 4.0, 0.0, 0.0], &[4])
.expect("tensor creation must succeed");
let output = norm.forward(input).expect("rmsnorm forward must succeed");
let vals: Vec<f32> = match &output {
Tensor::F32(arr) => arr.as_slice().expect("output must be contiguous").to_vec(),
_ => panic!("expected F32 output"),
};
assert_eq!(vals.len(), 4, "output length must match input");
let rms = (vals.iter().map(|x| x * x).sum::<f32>() / 4.0).sqrt();
assert!(
(rms - 1.0).abs() < 1e-4,
"RMSNorm output rms must be ~1.0, got {rms}"
);
}
#[test]
fn test_rmsnorm_device_accessor() {
let norm = Gemma2RmsNorm::new(8, 1e-6, Device::CPU).expect("RmsNorm must build");
assert_eq!(
norm.device(),
Device::CPU,
"device accessor must return CPU"
);
}
#[test]
fn test_is_local_layer_even_layers() {
assert!(
Gemma2Config::is_local_layer(0),
"layer 0 (even) must be local"
);
assert!(
Gemma2Config::is_local_layer(2),
"layer 2 (even) must be local"
);
assert!(
Gemma2Config::is_local_layer(4),
"layer 4 (even) must be local"
);
}
#[test]
fn test_is_local_layer_odd_layers() {
assert!(
!Gemma2Config::is_local_layer(1),
"layer 1 (odd) must be global"
);
assert!(
!Gemma2Config::is_local_layer(3),
"layer 3 (odd) must be global"
);
assert!(
!Gemma2Config::is_local_layer(5),
"layer 5 (odd) must be global"
);
}
#[test]
fn test_kv_group_size_gemma2_2b() {
let cfg = Gemma2Config::gemma2_2b();
assert_eq!(
cfg.kv_group_size(),
cfg.num_attention_heads / cfg.num_key_value_heads,
"kv_group_size must equal num_attention_heads / num_key_value_heads",
);
}
#[test]
fn test_gemma2_2b_config_values() {
let cfg = Gemma2Config::gemma2_2b();
assert_eq!(cfg.hidden_size, 2304, "2B hidden_size must be 2304");
assert_eq!(cfg.num_hidden_layers, 26, "2B must have 26 layers");
assert_eq!(cfg.num_attention_heads, 8, "2B must have 8 heads");
assert_eq!(
cfg.head_dim, 256,
"head_dim must be 256 for all Gemma-2 variants"
);
}
#[test]
fn test_rope_apply_preserves_norm() {
let cfg = tiny_config();
let rope = Gemma2RotaryEmbedding::new(&cfg, Device::CPU);
let n = cfg.head_dim;
let mut q = lcg_vec(n, 7);
let mut k = lcg_vec(n, 8);
let q_norm_before: f32 = q.iter().map(|x| x * x).sum::<f32>().sqrt();
let k_norm_before: f32 = k.iter().map(|x| x * x).sum::<f32>().sqrt();
rope.apply(&mut q, &mut k, 1);
let q_norm_after: f32 = q.iter().map(|x| x * x).sum::<f32>().sqrt();
let k_norm_after: f32 = k.iter().map(|x| x * x).sum::<f32>().sqrt();
assert!(
(q_norm_before - q_norm_after).abs() < 1e-4,
"RoPE must be norm-preserving for Q, before={q_norm_before} after={q_norm_after}",
);
assert!(
(k_norm_before - k_norm_after).abs() < 1e-4,
"RoPE must be norm-preserving for K",
);
}
#[test]
fn test_rope_zero_position_is_identity() {
let cfg = tiny_config();
let rope = Gemma2RotaryEmbedding::new(&cfg, Device::CPU);
let n = cfg.head_dim;
let mut q = lcg_vec(n, 9);
let mut k = lcg_vec(n, 10);
let q_orig = q.clone();
let _k_orig = k.clone();
rope.apply(&mut q, &mut k, 1); for i in 0..n / 2 {
assert!(
(q[i] - q_orig[i]).abs() < 1e-5,
"RoPE at pos=0 must be identity for q[{i}]",
);
}
}
#[test]
fn test_gemma2_model_construction() {
let cfg = tiny_config();
let model = Gemma2Model::new(cfg).expect("Gemma2Model construction must succeed");
assert_eq!(model.layers.len(), 2, "model must have 2 decoder layers");
}
#[test]
fn test_gemma2_model_forward_shape() {
let cfg = tiny_config();
let model = Gemma2Model::new(cfg.clone()).expect("model must build");
let input = Tensor::from_vec(vec![0_f32], &[1]).expect("token tensor must build");
let output = model.forward(input).expect("forward must succeed");
let out_len = output.shape().iter().product::<usize>();
assert!(out_len > 0, "output must be non-empty");
assert_eq!(
out_len, cfg.hidden_size,
"output length must equal hidden_size for single token"
);
}
#[test]
fn test_gemma2_model_num_parameters_positive() {
let cfg = tiny_config();
let model = Gemma2Model::new(cfg).expect("model must build");
let params = model.num_parameters();
assert!(params > 0, "model must have a positive parameter count");
}
#[test]
fn test_gemma2_decoder_layer_is_local_reflects_layer_idx() {
let cfg = tiny_config();
let layer0 = Gemma2DecoderLayer::new(&cfg, 0, Device::CPU)
.expect("layer 0 construction must succeed");
let layer1 = Gemma2DecoderLayer::new(&cfg, 1, Device::CPU)
.expect("layer 1 construction must succeed");
assert!(layer0.is_local(), "layer 0 must be local");
assert!(!layer1.is_local(), "layer 1 must be global");
}
#[test]
fn test_gemma2_attention_layer_0_is_local() {
let cfg = tiny_config();
let attn = Gemma2Attention::new(&cfg, 0, Device::CPU)
.expect("attention construction must succeed");
assert!(attn.is_local(), "attention at layer 0 must be local");
}
#[test]
fn test_gemma2_attention_layer_1_is_global() {
let cfg = tiny_config();
let attn = Gemma2Attention::new(&cfg, 1, Device::CPU)
.expect("attention construction must succeed");
assert!(!attn.is_local(), "attention at layer 1 must be global");
}
#[test]
fn test_gemma2_config_validate_ok() {
let cfg = tiny_config();
assert!(cfg.validate().is_ok(), "valid config must pass validation");
}
#[test]
fn test_gemma2_config_validate_zero_heads_fails() {
let mut cfg = tiny_config();
cfg.num_attention_heads = 0;
assert!(
cfg.validate().is_err(),
"zero attention heads must fail validation"
);
}
#[test]
fn test_gemma2_model_forward_f32_input() {
let cfg = tiny_config();
let model = Gemma2Model::new(cfg.clone()).expect("model must build");
let input = Tensor::from_vec(vec![1.0_f32], &[1]).expect("f32 token tensor must build");
let output = model.forward(input).expect("forward with f32 input must succeed");
let out_len = output.shape().iter().product::<usize>();
assert_eq!(
out_len, cfg.hidden_size,
"f32 input must produce correct output shape"
);
}
}