use crate::error::{LmError, LmResult};
use crate::weights::WeightTensor;
/// Lookup table that maps token ids to embedding rows.
///
/// `weight.data` is read as a row-major `[vocab_size, embed_dim]` matrix: the
/// row for token `t` is `weight.data[t * embed_dim..(t + 1) * embed_dim]`.
#[derive(Debug, Clone)]
pub struct TokenEmbedding {
    /// Number of rows in the table; valid token ids are `0..vocab_size`.
    pub vocab_size: usize,
    /// Number of values in each embedding row.
    pub embed_dim: usize,
    /// Backing storage for the table.
    pub weight: WeightTensor,
}

impl TokenEmbedding {
    /// Creates a `[vocab_size, embed_dim]` table allocated with `WeightTensor::zeros`.
    ///
    /// # Errors
    /// Returns [`LmError::InvalidConfig`] if either dimension is zero.
    pub fn new(vocab_size: usize, embed_dim: usize) -> LmResult<Self> {
        if vocab_size == 0 || embed_dim == 0 {
            return Err(LmError::InvalidConfig {
                msg: "TokenEmbedding: vocab_size and embed_dim must be > 0".into(),
            });
        }
        let weight = WeightTensor::zeros(&[vocab_size, embed_dim]);
        Ok(Self {
            vocab_size,
            embed_dim,
            weight,
        })
    }

    /// Wraps an existing weight tensor, taking `vocab_size` and `embed_dim`
    /// from its shape.
    ///
    /// # Errors
    /// * [`LmError::DimensionMismatch`] if the tensor is not rank 2
    ///   (`expected: 2`), or if `weight.data` holds fewer values than the
    ///   shape requires (`expected` is then the required element count).
    /// * [`LmError::InvalidConfig`] if either dimension is zero.
    pub fn from_weight(weight: WeightTensor) -> LmResult<Self> {
        if weight.shape.len() != 2 {
            return Err(LmError::DimensionMismatch {
                expected: 2,
                got: weight.shape.len(),
            });
        }
        let vocab_size = weight.shape[0];
        let embed_dim = weight.shape[1];
        if vocab_size == 0 || embed_dim == 0 {
            return Err(LmError::InvalidConfig {
                msg: "TokenEmbedding weight must be non-empty".into(),
            });
        }
        // Reject a buffer that cannot back every row up front, so `forward`
        // does not discover the problem halfway through a lookup. Saturating
        // keeps an absurd shape from wrapping around to a small requirement.
        let required = vocab_size.saturating_mul(embed_dim);
        if weight.data.len() < required {
            return Err(LmError::DimensionMismatch {
                expected: required,
                got: weight.data.len(),
            });
        }
        Ok(Self {
            vocab_size,
            embed_dim,
            weight,
        })
    }

    /// Gathers the embedding rows for `token_ids`.
    ///
    /// The result is the concatenation of the selected rows in input order,
    /// so it has `token_ids.len() * embed_dim` values.
    ///
    /// # Errors
    /// * [`LmError::EmptyInput`] if `token_ids` is empty.
    /// * [`LmError::OutOfVocab`] for the first id that is `>= vocab_size`.
    /// * [`LmError::DimensionMismatch`] if `weight.data` is too short to
    ///   contain a requested row (the public fields can be changed after
    ///   construction, so this is re-checked here instead of panicking).
    pub fn forward(&self, token_ids: &[u32]) -> LmResult<Vec<f32>> {
        if token_ids.is_empty() {
            return Err(LmError::EmptyInput {
                context: "token_ids",
            });
        }
        let dim = self.embed_dim;
        let table = &self.weight.data;
        let mut out = Vec::with_capacity(token_ids.len() * dim);
        for &tid in token_ids {
            let row = tid as usize;
            if row >= self.vocab_size {
                return Err(LmError::OutOfVocab { token: tid });
            }
            let start = row * dim;
            let values = table
                .get(start..start + dim)
                .ok_or_else(|| LmError::DimensionMismatch {
                    expected: self.vocab_size.saturating_mul(dim),
                    got: table.len(),
                })?;
            out.extend_from_slice(values);
        }
        Ok(out)
    }
}
/// Table of per-position embedding rows, looked up by absolute position.
///
/// `weight.data` is read as a row-major `[max_positions, embed_dim]` matrix:
/// the row for position `p` is `weight.data[p * embed_dim..(p + 1) * embed_dim]`.
#[derive(Debug, Clone)]
pub struct LearnedPositionalEmbedding {
    /// Number of rows in the table; valid positions are `0..max_positions`.
    pub max_positions: usize,
    /// Number of values in each embedding row.
    pub embed_dim: usize,
    /// Backing storage for the table.
    pub weight: WeightTensor,
}

impl LearnedPositionalEmbedding {
    /// Creates a `[max_positions, embed_dim]` table allocated with `WeightTensor::zeros`.
    ///
    /// # Errors
    /// Returns [`LmError::InvalidConfig`] if either dimension is zero.
    pub fn new(max_positions: usize, embed_dim: usize) -> LmResult<Self> {
        if max_positions == 0 || embed_dim == 0 {
            return Err(LmError::InvalidConfig {
                msg: "LearnedPositionalEmbedding: max_positions and embed_dim must be > 0".into(),
            });
        }
        let weight = WeightTensor::zeros(&[max_positions, embed_dim]);
        Ok(Self {
            max_positions,
            embed_dim,
            weight,
        })
    }

    /// Wraps an existing weight tensor, taking `max_positions` and
    /// `embed_dim` from its shape.
    ///
    /// Applies the same validation as [`LearnedPositionalEmbedding::new`], so
    /// both constructors uphold the same invariants.
    ///
    /// # Errors
    /// * [`LmError::DimensionMismatch`] if the tensor is not rank 2
    ///   (`expected: 2`), or if `weight.data` holds fewer values than the
    ///   shape requires (`expected` is then the required element count).
    /// * [`LmError::InvalidConfig`] if either dimension is zero.
    pub fn from_weight(weight: WeightTensor) -> LmResult<Self> {
        if weight.shape.len() != 2 {
            return Err(LmError::DimensionMismatch {
                expected: 2,
                got: weight.shape.len(),
            });
        }
        let max_positions = weight.shape[0];
        let embed_dim = weight.shape[1];
        if max_positions == 0 || embed_dim == 0 {
            return Err(LmError::InvalidConfig {
                msg: "LearnedPositionalEmbedding weight must be non-empty".into(),
            });
        }
        // Saturating keeps an absurd shape from wrapping to a small requirement.
        let required = max_positions.saturating_mul(embed_dim);
        if weight.data.len() < required {
            return Err(LmError::DimensionMismatch {
                expected: required,
                got: weight.data.len(),
            });
        }
        Ok(Self {
            max_positions,
            embed_dim,
            weight,
        })
    }

    /// Returns the rows for positions `offset..offset + seq_len`, concatenated.
    ///
    /// The result has `seq_len * embed_dim` values; `seq_len == 0` yields an
    /// empty vector.
    ///
    /// # Errors
    /// * [`LmError::SequenceTooLong`] if `offset + seq_len` exceeds
    ///   `max_positions`. If the sum overflows `usize`, `total_len` is
    ///   reported as `usize::MAX`.
    /// * [`LmError::DimensionMismatch`] if `weight.data` is too short to
    ///   contain the requested rows.
    pub fn forward(&self, seq_len: usize, offset: usize) -> LmResult<Vec<f32>> {
        // A wrapped sum could slip under `max_positions`, so overflow is
        // treated as "too long" rather than computed with plain `+`.
        let end_pos = match offset.checked_add(seq_len) {
            Some(end) if end <= self.max_positions => end,
            too_long => {
                return Err(LmError::SequenceTooLong {
                    total_len: too_long.unwrap_or(usize::MAX),
                    max_pos: self.max_positions,
                });
            }
        };
        let dim = self.embed_dim;
        let table = &self.weight.data;
        // Consecutive positions are adjacent in the row-major table, so the
        // whole window is one contiguous slice.
        let window = offset.saturating_mul(dim)..end_pos.saturating_mul(dim);
        table
            .get(window)
            .map(|rows| rows.to_vec())
            .ok_or_else(|| LmError::DimensionMismatch {
                expected: self.max_positions.saturating_mul(dim),
                got: table.len(),
            })
    }
}
/// Rotary position embedding with cosine/sine tables precomputed at
/// construction.
///
/// Both tables are row-major `[max_positions, head_dim / 2]`. Entry
/// `(pos, i)` holds the cosine / sine of `pos * theta^(-2i / head_dim)`.
#[derive(Debug, Clone)]
pub struct RotaryEmbedding {
    /// Width of one attention head; even and non-zero when built via `new`.
    pub head_dim: usize,
    /// Number of positions covered by the tables.
    pub max_positions: usize,
    /// Base of the geometric frequency progression.
    pub theta: f32,
    cos_table: Vec<f32>,
    sin_table: Vec<f32>,
}

impl RotaryEmbedding {
    /// Builds the cosine and sine tables for `max_positions` positions.
    ///
    /// # Errors
    /// Returns [`LmError::InvalidConfig`] if `head_dim` is zero or odd, if
    /// `max_positions` is zero, or if `theta` is NaN, infinite, or not
    /// strictly positive.
    pub fn new(head_dim: usize, max_positions: usize, theta: f32) -> LmResult<Self> {
        if head_dim == 0 || head_dim % 2 != 0 {
            return Err(LmError::InvalidConfig {
                msg: format!("RotaryEmbedding: head_dim={head_dim} must be even and > 0"),
            });
        }
        if max_positions == 0 {
            return Err(LmError::InvalidConfig {
                msg: "RotaryEmbedding: max_positions must be > 0".into(),
            });
        }
        // NaN compares false against everything, so `theta <= 0.0` alone
        // would let it (and +inf) through and produce unusable tables.
        if !theta.is_finite() {
            return Err(LmError::InvalidConfig {
                msg: "RotaryEmbedding: theta must be finite".into(),
            });
        }
        if theta <= 0.0 {
            return Err(LmError::InvalidConfig {
                msg: "RotaryEmbedding: theta must be > 0".into(),
            });
        }

        let half_dim = head_dim / 2;
        // The per-pair frequency does not depend on the position, so compute
        // each one once instead of once per table entry.
        let freqs: Vec<f32> = (0..half_dim)
            .map(|i| theta.powf(-((2 * i) as f32) / head_dim as f32))
            .collect();

        let entries = max_positions * half_dim;
        let mut cos_table = Vec::with_capacity(entries);
        let mut sin_table = Vec::with_capacity(entries);
        for pos in 0..max_positions {
            let pos = pos as f32;
            for &freq in &freqs {
                let angle = pos * freq;
                cos_table.push(angle.cos());
                sin_table.push(angle.sin());
            }
        }

        Ok(Self {
            head_dim,
            max_positions,
            theta,
            cos_table,
            sin_table,
        })
    }

    /// Rotates `x` in place.
    ///
    /// `x` is read as a row-major `[n_tokens, n_heads, head_dim]` buffer.
    /// Token `t` is treated as absolute position `offset + t`. Within every
    /// head, the adjacent elements `(2i, 2i + 1)` are rotated by that
    /// position's angle for pair `i`.
    ///
    /// # Errors
    /// * [`LmError::SequenceTooLong`] if `offset + n_tokens` exceeds
    ///   `max_positions`. If the sum overflows `usize`, `total_len` is
    ///   reported as `usize::MAX`.
    /// * [`LmError::DimensionMismatch`] if `x.len()` is not
    ///   `n_tokens * n_heads * head_dim`. If that product overflows `usize`,
    ///   `expected` is reported as `usize::MAX`.
    ///
    /// `x` is left untouched whenever an error is returned.
    pub fn apply(
        &self,
        x: &mut [f32],
        n_heads: usize,
        n_tokens: usize,
        offset: usize,
    ) -> LmResult<()> {
        // Checked arithmetic: a wrapped sum or product could pass the
        // validation below and lead to out-of-range table or buffer indexing.
        match offset.checked_add(n_tokens) {
            Some(end) if end <= self.max_positions => {}
            too_long => {
                return Err(LmError::SequenceTooLong {
                    total_len: too_long.unwrap_or(usize::MAX),
                    max_pos: self.max_positions,
                });
            }
        }
        let expected = n_tokens
            .checked_mul(n_heads)
            .and_then(|per_batch| per_batch.checked_mul(self.head_dim));
        if expected != Some(x.len()) {
            return Err(LmError::DimensionMismatch {
                expected: expected.unwrap_or(usize::MAX),
                got: x.len(),
            });
        }
        // An empty buffer means one of the three factors is zero: nothing to
        // rotate. Returning here also guarantees the chunk sizes used below
        // are non-zero and that `n_heads * head_dim` cannot overflow.
        if x.is_empty() {
            return Ok(());
        }

        let half_dim = self.head_dim / 2;
        let token_stride = n_heads * self.head_dim;
        for (t, token) in x.chunks_exact_mut(token_stride).enumerate() {
            let row = (offset + t) * half_dim;
            let cos_row = &self.cos_table[row..row + half_dim];
            let sin_row = &self.sin_table[row..row + half_dim];
            for head in token.chunks_exact_mut(self.head_dim) {
                let pairs = head.chunks_exact_mut(2).zip(cos_row).zip(sin_row);
                for ((pair, &cos), &sin) in pairs {
                    let (x0, x1) = (pair[0], pair[1]);
                    pair[0] = x0 * cos - x1 * sin;
                    pair[1] = x0 * sin + x1 * cos;
                }
            }
        }
        Ok(())
    }

    /// Returns the cached cosine for position `pos` and pair index `i`.
    ///
    /// # Panics
    /// Panics if the computed index lies outside the table.
    pub fn cos_at(&self, pos: usize, i: usize) -> f32 {
        self.cos_table[pos * (self.head_dim / 2) + i]
    }

    /// Returns the cached sine for position `pos` and pair index `i`.
    ///
    /// # Panics
    /// Panics if the computed index lies outside the table.
    pub fn sin_at(&self, pos: usize, i: usize) -> f32 {
        self.sin_table[pos * (self.head_dim / 2) + i]
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // ---- TokenEmbedding ---------------------------------------------------

    /// A single id returns exactly its row of the table.
    #[test]
    fn token_embedding_lookup() {
        let mut emb = TokenEmbedding::new(4, 3).unwrap();
        // Row 2 of a [4, 3] table starts at flat index 2 * 3 = 6.
        emb.weight.data[6] = 1.0;
        emb.weight.data[7] = 2.0;
        emb.weight.data[8] = 3.0;
        let out = emb.forward(&[2]).unwrap();
        assert_eq!(out, vec![1.0_f32, 2.0, 3.0]);
    }

    /// Several ids return their rows concatenated in input order.
    #[test]
    fn token_embedding_multi_token() {
        let mut emb = TokenEmbedding::new(3, 2).unwrap();
        emb.weight.data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let out = emb.forward(&[0, 2]).unwrap();
        assert_eq!(out, vec![1.0_f32, 2.0, 5.0, 6.0]);
    }

    /// An id at or beyond `vocab_size` is reported with the offending token.
    #[test]
    fn token_embedding_out_of_vocab_error() {
        let emb = TokenEmbedding::new(4, 3).unwrap();
        assert!(matches!(
            emb.forward(&[5]),
            Err(LmError::OutOfVocab { token: 5 })
        ));
    }

    /// An empty id list is rejected.
    #[test]
    fn token_embedding_empty_error() {
        let emb = TokenEmbedding::new(4, 3).unwrap();
        assert!(matches!(emb.forward(&[]), Err(LmError::EmptyInput { .. })));
    }

    /// `from_weight` takes both dimensions from the tensor shape.
    #[test]
    fn token_embedding_from_weight() {
        let w = WeightTensor::zeros(&[10, 4]);
        let emb = TokenEmbedding::from_weight(w).unwrap();
        assert_eq!(emb.vocab_size, 10);
        assert_eq!(emb.embed_dim, 4);
    }

    // ---- LearnedPositionalEmbedding ---------------------------------------

    /// With no offset the first `seq_len` rows are returned.
    #[test]
    fn pos_embedding_lookup() {
        let mut pe = LearnedPositionalEmbedding::new(4, 2).unwrap();
        // Row 1 of a [4, 2] table occupies flat indices 2..4.
        pe.weight.data[2] = 3.0;
        pe.weight.data[3] = 4.0;
        let out = pe.forward(2, 0).unwrap();
        assert_eq!(out, vec![0.0_f32, 0.0, 3.0, 4.0]);
    }

    /// `offset` shifts the window: rows 4 and 5 live at flat indices 8..12.
    #[test]
    fn pos_embedding_with_offset() {
        let mut pe = LearnedPositionalEmbedding::new(8, 2).unwrap();
        for i in 8..12 {
            pe.weight.data[i] = 10.0;
        }
        let out = pe.forward(2, 4).unwrap();
        assert_eq!(out.len(), 4);
        assert!(out.iter().all(|&v| v == 10.0));
    }

    /// Asking for more positions than the table holds is rejected.
    #[test]
    fn pos_embedding_too_long_error() {
        let pe = LearnedPositionalEmbedding::new(4, 2).unwrap();
        assert!(matches!(
            pe.forward(5, 0),
            Err(LmError::SequenceTooLong { .. })
        ));
    }

    // ---- RotaryEmbedding --------------------------------------------------

    /// Position 0 has angle 0 for every pair, so the input is unchanged.
    #[test]
    fn rope_pos0_is_identity() {
        let rope = RotaryEmbedding::new(4, 16, 10_000.0).unwrap();
        let mut x = vec![1.0_f32, 2.0, 3.0, 4.0];
        rope.apply(&mut x, 1, 1, 0).unwrap();
        assert!((x[0] - 1.0).abs() < 1e-5, "x[0]={}", x[0]);
        assert!((x[1] - 2.0).abs() < 1e-5, "x[1]={}", x[1]);
        assert!((x[2] - 3.0).abs() < 1e-5, "x[2]={}", x[2]);
        assert!((x[3] - 4.0).abs() < 1e-5, "x[3]={}", x[3]);
    }

    /// A rotation must not change the vector's Euclidean norm.
    #[test]
    fn rope_rotation_preserves_norm() {
        let rope = RotaryEmbedding::new(4, 32, 10_000.0).unwrap();
        let original = vec![1.0_f32, 2.0, 3.0, 4.0];
        let mut x = original.clone();
        rope.apply(&mut x, 1, 1, 5).unwrap();
        let norm_before: f32 = original.iter().map(|&v| v * v).sum::<f32>().sqrt();
        let norm_after: f32 = x.iter().map(|&v| v * v).sum::<f32>().sqrt();
        assert!(
            (norm_before - norm_after).abs() < 1e-4,
            "norm {norm_before} ≠ {norm_after}"
        );
    }

    /// Every head of a token receives the rotation for that token's position.
    #[test]
    fn rope_multi_head_multi_token() {
        let (n_heads, n_tokens, head_dim) = (3, 2, 4);
        let rope = RotaryEmbedding::new(head_dim, 32, 10_000.0).unwrap();
        let mut x = vec![1.0_f32; n_tokens * n_heads * head_dim];
        rope.apply(&mut x, n_heads, n_tokens, 0).unwrap();

        let (token0, token1) = x.split_at(n_heads * head_dim);

        // Token 0 sits at position 0: identity rotation for all heads.
        assert!(token0.iter().all(|&v| (v - 1.0).abs() < 1e-6), "{token0:?}");

        // Token 1 sits at position 1. Pair 0 has frequency theta^0 = 1, so it
        // is rotated by exactly one radian.
        let reference = &token1[..head_dim];
        let (sin, cos) = 1.0_f32.sin_cos();
        assert!((reference[0] - (cos - sin)).abs() < 1e-5, "{reference:?}");
        assert!((reference[1] - (sin + cos)).abs() < 1e-5, "{reference:?}");

        // All heads of the same token share the rotation and keep their norm.
        for head in token1.chunks(head_dim) {
            for (got, want) in head.iter().zip(reference) {
                assert!((got - want).abs() < 1e-6, "{head:?} vs {reference:?}");
            }
            let norm_sq: f32 = head.iter().map(|&v| v * v).sum();
            assert!((norm_sq - head_dim as f32).abs() < 1e-4, "norm²={norm_sq}");
        }
    }

    /// An odd `head_dim` cannot be split into pairs and is rejected.
    #[test]
    fn rope_odd_head_dim_error() {
        assert!(RotaryEmbedding::new(3, 16, 10_000.0).is_err());
    }

    /// `offset + n_tokens` beyond the table is rejected.
    #[test]
    fn rope_sequence_too_long_error() {
        let rope = RotaryEmbedding::new(4, 4, 10_000.0).unwrap();
        let mut x = vec![0.0_f32; 4];
        assert!(matches!(
            rope.apply(&mut x, 1, 2, 3),
            Err(LmError::SequenceTooLong { .. })
        ));
    }

    /// At position 0 the cached values are cos(0) = 1 and sin(0) = 0.
    #[test]
    fn rope_cos_sin_tables_at_zero() {
        let rope = RotaryEmbedding::new(4, 8, 10_000.0).unwrap();
        assert!((rope.cos_at(0, 0) - 1.0).abs() < 1e-6);
        assert!(rope.sin_at(0, 0).abs() < 1e-6);
    }

    /// Both tables hold one entry per (position, pair).
    #[test]
    fn rope_tables_have_correct_dimensions() {
        let head_dim = 8;
        let max_pos = 16;
        let rope = RotaryEmbedding::new(head_dim, max_pos, 10_000.0).unwrap();
        assert_eq!(rope.cos_table.len(), max_pos * (head_dim / 2));
        assert_eq!(rope.sin_table.len(), max_pos * (head_dim / 2));
    }
}