use crate::internlm2::config::InternLm2Config;
use std::fmt;
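/// Errors returned by the simplified InternLM-2 forward pass.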
#[derive(Debug)]
pub enum InternLm2Error {
InvalidInput(String),
ForwardError(String),
}
impl fmt::Display for InternLm2Error {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
InternLm2Error::InvalidInput(msg) => write!(f, "InternLM-2 invalid input: {msg}"),
InternLm2Error::ForwardError(msg) => write!(f, "InternLM-2 forward error: {msg}"),
}
}
}
impl std::error::Error for InternLm2Error {}
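/// Rotary position embedding (RoPE) with base `theta` and an optional
/// frequency scaling factor intended for context extension.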
pub struct InternLm2RotaryEmbedding {
theta: f64,
scaling: Option<f64>,
}
impl InternLm2RotaryEmbedding {
pub fn new(theta: f64, scaling: Option<f64>) -> Self {
Self { theta, scaling }
}
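/// Per-dimension inverse frequencies: `freq_i = theta^(-2i / head_dim)` for the
/// first `head_dim / 2` dimensions, optionally rescaled when `scaling` is set.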
fn compute_freqs(&self, head_dim: usize) -> Vec<f64> {
let half = head_dim / 2;
(0..half)
.map(|i| {
let base_freq = 1.0 / self.theta.powf(2.0 * i as f64 / head_dim as f64);
match self.scaling {
// NTK-style context extension lowers each frequency (equivalently, enlarges
// the effective theta), so the scale factor divides rather than multiplies.
Some(scale) => base_freq / scale.powf(2.0 * i as f64 / head_dim as f64),
None => base_freq,
}
})
.collect()
}
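/// Returns rotated copies of `q` and `k`. Within each head, element `i` is
/// paired with element `i + head_dim / 2` and rotated by `pos * freq_i`
/// (the rotate-half convention). Empty inputs, `head_dim < 2`, or a zero
/// `seq_len` pass through unchanged.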
pub fn apply(
&self,
q: &[f32],
k: &[f32],
seq_len: usize,
head_dim: usize,
) -> (Vec<f32>, Vec<f32>) {
if q.is_empty() || k.is_empty() || head_dim < 2 || seq_len == 0 {
return (q.to_vec(), k.to_vec());
}
let freqs = self.compute_freqs(head_dim);
let half = head_dim / 2;
let rotate_single = |src: &[f32]| -> Vec<f32> {
let mut out = src.to_vec();
let num_vectors = src.len() / head_dim;
for vec_idx in 0..num_vectors {
let pos = vec_idx % seq_len;
let base = vec_idx * head_dim;
for i in 0..half {
let angle = pos as f32 * freqs[i] as f32;
let cos_a = angle.cos();
let sin_a = angle.sin();
let x0 = src[base + i];
let x1 = src[base + i + half];
out[base + i] = x0 * cos_a - x1 * sin_a;
out[base + i + half] = x0 * sin_a + x1 * cos_a;
}
}
out
};
(rotate_single(q), rotate_single(k))
}
}
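/// Stateless RMSNorm: `y_i = x_i / rms(x) * w_i` with
/// `rms(x) = sqrt(mean(x^2) + eps)`.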
pub struct InternLm2RmsNorm;
impl InternLm2RmsNorm {
pub fn forward(x: &[f32], weight: &[f32], eps: f64) -> Vec<f32> {
let len = x.len();
if len == 0 {
return Vec::new();
}
let mean_sq: f32 = x.iter().map(|v| v * v).sum::<f32>() / len as f32;
let rms = (mean_sq + eps as f32).sqrt();
x.iter().zip(weight.iter()).map(|(xi, wi)| xi / rms * wi).collect()
}
}
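/// Single attention block with grouped-query head mapping. The projection
/// weights are placeholders (zeros), and the forward pass below is a
/// simplified structural sketch rather than full scaled-dot-product attention.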
pub struct InternLm2Attention {
pub config: InternLm2Config,
pub layer_idx: usize,
#[allow(dead_code)]
q_weight: Vec<f32>,
#[allow(dead_code)]
kv_weight: Vec<f32>,
#[allow(dead_code)]
o_weight: Vec<f32>,
rope: InternLm2RotaryEmbedding,
norm_weight: Vec<f32>,
}
impl InternLm2Attention {
pub fn new(config: InternLm2Config, layer_idx: usize) -> Self {
let h = config.hidden_size;
let norm_weight = vec![1.0_f32; h];
let rope = InternLm2RotaryEmbedding::new(config.rope_theta, config.rope_scaling);
Self {
q_weight: vec![0.0_f32; h * h],
kv_weight: vec![0.0_f32; h * h],
o_weight: vec![0.0_f32; h * h],
rope,
norm_weight,
config,
layer_idx,
}
}
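/// Maps a query head to its shared key/value head under grouped-query
/// attention: consecutive groups of `gqa_ratio()` query heads share one KV head.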
pub fn kv_head_for_q(&self, q_head: usize) -> usize {
let ratio = self.config.gqa_ratio();
q_head / ratio
}
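/// Simplified attention forward pass: per-token RMSNorm, placeholder Q/K
/// projections, RoPE on Q and K, and a per-head value aggregation. The output
/// keeps the `seq_len * hidden_size` layout of the input.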
pub fn forward(&self, hidden_states: &[f32], seq_len: usize) -> Vec<f32> {
let h = self.config.hidden_size;
let normed: Vec<f32> = hidden_states
.chunks(h)
.flat_map(|chunk| {
InternLm2RmsNorm::forward(chunk, &self.norm_weight, self.config.rms_norm_eps)
})
.collect();
let head_dim = self.config.head_dim();
let num_q_heads = self.config.num_attention_heads;
let num_kv_heads = self.config.num_key_value_heads;
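// Placeholder Q/K projections: scale the normalized hidden states instead of
// multiplying by the (all-zero) projection weights.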
let q_proj: Vec<f32> = normed.iter().map(|v| v * 0.1).collect();
let k_proj: Vec<f32> = normed.iter().map(|v| v * 0.1).collect();
let (q_rot, k_rot) = self.rope.apply(&q_proj, &k_proj, seq_len, head_dim);
let scale = (head_dim as f32).sqrt().recip();
let mut attn_out = vec![0.0_f32; seq_len * h];
for pos in 0..seq_len {
for q_head in 0..num_q_heads {
let kv_head = self.kv_head_for_q(q_head);
let q_base = pos * h + q_head * head_dim;
let k_base = pos * h + kv_head * head_dim;
let score: f32 = (0..head_dim)
.map(|i| {
let qi = q_rot.get(q_base + i).copied().unwrap_or(0.0);
let ki = k_rot.get(k_base + i).copied().unwrap_or(0.0);
qi * ki
})
.sum::<f32>()
* scale;
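// Simplified aggregation: the attention weight is computed for illustration
// but not applied; only the current position's values are accumulated,
// scaled by 1/sqrt(head_dim).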
let _softmax_weight = score.exp();
let out_base = pos * h + q_head * head_dim;
let v_base = pos * h + kv_head * head_dim;
for i in 0..head_dim {
let v_val = normed.get(v_base + i).copied().unwrap_or(0.0);
if let Some(slot) = attn_out.get_mut(out_base + i) {
*slot += v_val * scale;
}
}
let _ = num_kv_heads;
}
}
attn_out
}
}
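/// SwiGLU feed-forward block (`silu(gate) * up` followed by a down projection).
/// The stored weights are placeholders (zeros), so this forward pass is a
/// structural sketch.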
pub struct InternLm2MLP {
hidden_size: usize,
intermediate_size: usize,
#[allow(dead_code)]
gate_weight: Vec<f32>,
#[allow(dead_code)]
up_weight: Vec<f32>,
#[allow(dead_code)]
down_weight: Vec<f32>,
norm_weight: Vec<f32>,
rms_norm_eps: f64,
}
impl InternLm2MLP {
pub fn new(config: &InternLm2Config) -> Self {
let h = config.hidden_size;
let i = config.intermediate_size;
Self {
hidden_size: h,
intermediate_size: i,
gate_weight: vec![0.0_f32; i * h],
up_weight: vec![0.0_f32; i * h],
down_weight: vec![0.0_f32; h * i],
norm_weight: vec![1.0_f32; h],
rms_norm_eps: config.rms_norm_eps,
}
}
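/// SiLU activation: `x * sigmoid(x)`.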
#[inline]
fn silu(x: f32) -> f32 {
x / (1.0 + (-x).exp())
}
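/// Per-token forward pass: RMSNorm, placeholder gate/up projections, SwiGLU
/// activation, and a placeholder down-projection back to `hidden_size`.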
pub fn forward(&self, x: &[f32]) -> Vec<f32> {
let total = x.len();
if total == 0 {
return Vec::new();
}
let h = self.hidden_size;
let num_tokens = total / h;
let mut out = vec![0.0_f32; total];
for tok in 0..num_tokens {
let x_tok = &x[tok * h..(tok + 1) * h];
let normed = InternLm2RmsNorm::forward(x_tok, &self.norm_weight, self.rms_norm_eps);
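// Placeholder gate/up projections: the stored (all-zero) weights are not
// applied, so both activations are zero vectors of `intermediate_size`.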
let gate: Vec<f32> = vec![0.0_f32; self.intermediate_size];
let up: Vec<f32> = vec![0.0_f32; self.intermediate_size];
let swiglu: Vec<f32> =
gate.iter().zip(up.iter()).map(|(g, u)| Self::silu(*g) * u).collect();
let out_tok = &mut out[tok * h..(tok + 1) * h];
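// Placeholder down-projection: the normalized input is zeroed out and the
// (currently zero) SwiGLU activation is broadcast back into the hidden dimension.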
for (i, slot) in out_tok.iter_mut().enumerate() {
*slot = normed.get(i).copied().unwrap_or(0.0) * 0.0
+ swiglu.get(i % self.intermediate_size).copied().unwrap_or(0.0);
}
}
out
}
}
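/// One decoder layer: attention and MLP sub-blocks, each wrapped in a residual
/// connection. Normalization happens inside the sub-blocks (pre-norm).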
pub struct InternLm2DecoderLayer {
attention: InternLm2Attention,
mlp: InternLm2MLP,
}
impl InternLm2DecoderLayer {
pub fn new(config: InternLm2Config, layer_idx: usize) -> Self {
let mlp = InternLm2MLP::new(&config);
let attention = InternLm2Attention::new(config, layer_idx);
Self { attention, mlp }
}
pub fn forward(&self, hidden_states: &[f32], seq_len: usize) -> Vec<f32> {
let attn_out = self.attention.forward(hidden_states, seq_len);
let after_attn: Vec<f32> =
hidden_states.iter().zip(attn_out.iter()).map(|(h, a)| h + a).collect();
let mlp_out = self.mlp.forward(&after_attn);
after_attn.iter().zip(mlp_out.iter()).map(|(h, m)| h + m).collect()
}
}
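/// Top-level model: token embedding, a stack of decoder layers, and a final RMSNorm.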
pub struct InternLm2Model {
pub config: InternLm2Config,
pub layers: Vec<InternLm2DecoderLayer>,
final_norm_weight: Vec<f32>,
#[allow(dead_code)]
embed_weight: Vec<f32>,
}
impl InternLm2Model {
pub fn new(config: InternLm2Config) -> Self {
let num_layers = config.num_hidden_layers;
let h = config.hidden_size;
let v = config.vocab_size;
let layers = (0..num_layers)
.map(|idx| InternLm2DecoderLayer::new(config.clone(), idx))
.collect();
Self {
final_norm_weight: vec![1.0_f32; h],
embed_weight: vec![0.0_f32; v * h],
layers,
config,
}
}
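/// Runs the model over `input_ids` and returns the final hidden states with
/// `seq_len * hidden_size` elements. Fails on empty input or out-of-vocabulary
/// token ids.
///
/// A minimal usage sketch (assuming a config built like the test helper below;
/// not compiled as a doctest):
///
/// ```ignore
/// let model = InternLm2Model::new(config);
/// let hidden = model.forward(&[0, 1, 2]).expect("forward should succeed");
/// assert_eq!(hidden.len(), 3 * model.config.hidden_size);
/// ```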
pub fn forward(&self, input_ids: &[u32]) -> Result<Vec<f32>, InternLm2Error> {
let seq_len = input_ids.len();
if seq_len == 0 {
return Err(InternLm2Error::InvalidInput(
"input_ids must not be empty".to_string(),
));
}
let h = self.config.hidden_size;
let v = self.config.vocab_size;
let mut hidden: Vec<f32> = Vec::with_capacity(seq_len * h);
for &tok in input_ids {
let tok_id = tok as usize;
if tok_id >= v {
return Err(InternLm2Error::InvalidInput(format!(
"token id {tok_id} is out of vocabulary range {v}"
)));
}
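// Placeholder embedding: a deterministic function of the token id stands in
// for a lookup into `embed_weight`.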
let embedding: Vec<f32> =
(0..h).map(|dim| (tok_id as f32 * 0.001) * ((dim + 1) as f32 * 0.01)).collect();
hidden.extend_from_slice(&embedding);
}
for layer in &self.layers {
hidden = layer.forward(&hidden, seq_len);
}
hidden = hidden
.chunks(h)
.flat_map(|chunk| {
InternLm2RmsNorm::forward(chunk, &self.final_norm_weight, self.config.rms_norm_eps)
})
.collect();
Ok(hidden)
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::internlm2::config::InternLm2Config;
// 64-bit LCG (Knuth MMIX constants); maps the top 32 bits of the state to [-1.0, 1.0).
fn lcg_next(state: &mut u64) -> f32 {
*state = state.wrapping_mul(6364136223846793005).wrapping_add(1442695040888963407);
((*state >> 32) as u32 as f32) / (u32::MAX as f32) * 2.0 - 1.0
}
fn lcg_vec(n: usize, seed: u64) -> Vec<f32> {
let mut state = seed;
(0..n).map(|_| lcg_next(&mut state)).collect()
}
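/// Tiny configuration (8-dim hidden, 2 layers, 4 query heads / 2 KV heads) used by the unit tests.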
fn tiny_internlm2_config() -> InternLm2Config {
InternLm2Config {
vocab_size: 64,
hidden_size: 8,
num_hidden_layers: 2,
num_attention_heads: 4,
num_key_value_heads: 2,
intermediate_size: 16,
max_position_embeddings: 64,
rope_theta: 1_000_000.0,
rope_scaling: None,
hidden_act: "silu".to_string(),
rms_norm_eps: 1e-5,
tie_word_embeddings: false,
use_cache: true,
}
}
#[test]
fn test_internlm2_rmsnorm_unit_rms() {
let weight = vec![1.0_f32; 4];
let x = vec![3.0_f32, 4.0, 0.0, 0.0];
let output = InternLm2RmsNorm::forward(&x, &weight, 1e-5);
let rms = (output.iter().map(|v| v * v).sum::<f32>() / 4.0).sqrt();
assert!(
(rms - 1.0).abs() < 1e-4,
"RMSNorm output rms must be ~1.0, got {rms}"
);
}
#[test]
fn test_internlm2_rmsnorm_empty_input_returns_empty() {
let output = InternLm2RmsNorm::forward(&[], &[], 1e-5);
assert!(output.is_empty(), "empty input must return empty output");
}
#[test]
fn test_internlm2_rmsnorm_preserves_length() {
let x = lcg_vec(8, 70);
let w = vec![1.0_f32; 8];
let output = InternLm2RmsNorm::forward(&x, &w, 1e-5);
assert_eq!(output.len(), 8, "RMSNorm must preserve input length");
}
#[test]
fn test_internlm2_rope_output_length_matches_input() {
let rope = InternLm2RotaryEmbedding::new(1_000_000.0, None);
let head_dim = 8;
let seq_len = 4;
let q = lcg_vec(seq_len * head_dim, 71);
let k = lcg_vec(seq_len * head_dim, 72);
let (q_out, k_out) = rope.apply(&q, &k, seq_len, head_dim);
assert_eq!(q_out.len(), q.len(), "Q output length must match input");
assert_eq!(k_out.len(), k.len(), "K output length must match input");
}
#[test]
fn test_internlm2_rope_empty_input_passthrough() {
let rope = InternLm2RotaryEmbedding::new(10000.0, None);
let (q_out, k_out) = rope.apply(&[], &[], 0, 8);
assert!(q_out.is_empty(), "empty Q must pass through");
assert!(k_out.is_empty(), "empty K must pass through");
}
#[test]
fn test_internlm2_rope_norm_preserving() {
let rope = InternLm2RotaryEmbedding::new(10000.0, None);
let head_dim = 8;
let q = lcg_vec(head_dim, 73);
let k = lcg_vec(head_dim, 74);
let q_norm_before: f32 = q.iter().map(|x| x * x).sum::<f32>().sqrt();
let (q_out, _) = rope.apply(&q, &k, 1, head_dim);
let q_norm_after: f32 = q_out.iter().map(|x| x * x).sum::<f32>().sqrt();
assert!(
(q_norm_before - q_norm_after).abs() < 1e-4,
"RoPE must preserve norm, before={q_norm_before} after={q_norm_after}",
);
}
#[test]
fn test_internlm2_rope_with_ntk_scaling() {
let rope = InternLm2RotaryEmbedding::new(1_000_000.0, Some(2.0));
let head_dim = 8;
let q = lcg_vec(head_dim, 75);
let k = lcg_vec(head_dim, 76);
let (q_out, k_out) = rope.apply(&q, &k, 1, head_dim);
assert_eq!(q_out.len(), head_dim, "NTK-scaled RoPE Q length must match");
assert_eq!(k_out.len(), head_dim, "NTK-scaled RoPE K length must match");
}
#[test]
fn test_internlm2_config_gqa_ratio() {
let cfg = tiny_internlm2_config();
let ratio = cfg.gqa_ratio();
assert_eq!(
ratio,
cfg.num_attention_heads / cfg.num_key_value_heads,
"gqa_ratio must equal nh/nkv",
);
}
#[test]
fn test_internlm2_config_head_dim() {
let cfg = tiny_internlm2_config();
let hd = cfg.head_dim();
assert_eq!(
hd,
cfg.hidden_size / cfg.num_attention_heads,
"head_dim = hidden_size/num_heads"
);
}
#[test]
fn test_internlm2_attention_kv_head_mapping() {
let cfg = tiny_internlm2_config();
let attn = InternLm2Attention::new(cfg.clone(), 0);
assert_eq!(attn.kv_head_for_q(0), 0, "Q head 0 must map to KV head 0");
assert_eq!(attn.kv_head_for_q(1), 0, "Q head 1 must map to KV head 0");
assert_eq!(attn.kv_head_for_q(2), 1, "Q head 2 must map to KV head 1");
assert_eq!(attn.kv_head_for_q(3), 1, "Q head 3 must map to KV head 1");
}
#[test]
fn test_internlm2_attention_forward_shape() {
let cfg = tiny_internlm2_config();
let attn = InternLm2Attention::new(cfg.clone(), 0);
let hidden = lcg_vec(cfg.hidden_size, 80);
let out = attn.forward(&hidden, 1);
assert_eq!(
out.len(),
cfg.hidden_size,
"attention output must have hidden_size elements"
);
}
#[test]
fn test_internlm2_mlp_output_length() {
let cfg = tiny_internlm2_config();
let mlp = InternLm2MLP::new(&cfg);
let x = lcg_vec(cfg.hidden_size, 81);
let out = mlp.forward(&x);
assert_eq!(
out.len(),
cfg.hidden_size,
"MLP output must have hidden_size elements"
);
}
#[test]
fn test_internlm2_mlp_empty_input_returns_empty() {
let cfg = tiny_internlm2_config();
let mlp = InternLm2MLP::new(&cfg);
let out = mlp.forward(&[]);
assert!(
out.is_empty(),
"MLP with empty input must return empty output"
);
}
#[test]
fn test_internlm2_model_construction() {
let cfg = tiny_internlm2_config();
let model = InternLm2Model::new(cfg);
assert_eq!(model.layers.len(), 2, "model must have 2 layers");
}
#[test]
fn test_internlm2_model_forward_single_token() {
let cfg = tiny_internlm2_config();
let model = InternLm2Model::new(cfg.clone());
let output = model.forward(&[0u32]).expect("forward must succeed");
assert_eq!(
output.len(),
cfg.hidden_size,
"output length must equal hidden_size"
);
}
#[test]
fn test_internlm2_model_forward_multi_token() {
let cfg = tiny_internlm2_config();
let model = InternLm2Model::new(cfg.clone());
let output = model.forward(&[0u32, 1, 2]).expect("multi-token forward must succeed");
assert_eq!(
output.len(),
3 * cfg.hidden_size,
"output length must be seq_len * hidden_size"
);
}
#[test]
fn test_internlm2_model_empty_input_fails() {
let cfg = tiny_internlm2_config();
let model = InternLm2Model::new(cfg);
let result = model.forward(&[]);
assert!(result.is_err(), "empty input must return an error");
}
#[test]
fn test_internlm2_model_out_of_vocab_fails() {
let cfg = tiny_internlm2_config();
let model = InternLm2Model::new(cfg);
let result = model.forward(&[100u32]);
assert!(result.is_err(), "out-of-vocab token must return an error");
}
}