use crate::error::{CognitionError, Result};
use crate::tensor::Tensor;
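/// Learned token embedding table mapping token ids in `[0, vocab_size)` to
/// dense vectors of length `embed_dim`.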
#[derive(Debug, Clone)]
pub struct Embedding {
pub weights: Tensor,
pub vocab_size: usize,
pub embed_dim: usize,
}
impl Embedding {
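    /// Builds an embedding table with weights drawn via `Tensor::randn` and
    /// scaled by `1 / sqrt(embed_dim)`, keeping row magnitudes roughly
    /// independent of the embedding dimension.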
pub fn new(vocab_size: usize, embed_dim: usize, rng: &mut impl rand::Rng) -> Self {
let scale = 1.0 / (embed_dim as f64).sqrt();
let weights = Tensor::randn(&[vocab_size, embed_dim], rng).scale(scale);
Self {
weights,
vocab_size,
embed_dim,
}
}
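    /// Looks up the embedding vector for a single token id.
    ///
    /// Returns `CognitionError::TokenOutOfRange` if `token_id >= vocab_size`.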
pub fn forward(&self, token_id: usize) -> Result<Tensor> {
if token_id >= self.vocab_size {
return Err(CognitionError::TokenOutOfRange {
id: token_id,
vocab_size: self.vocab_size,
});
}
self.weights.row(token_id)
}
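    /// Looks up embeddings for a slice of token ids and stacks them into a
    /// `[token_ids.len(), embed_dim]` tensor.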
pub fn forward_batch(&self, token_ids: &[usize]) -> Result<Tensor> {
let rows: Vec<Tensor> = token_ids
.iter()
.map(|&id| self.forward(id))
.collect::<Result<Vec<_>>>()?;
Tensor::stack_rows(&rows)
}
}
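/// Fixed sinusoidal positional encodings (Transformer-style): even columns
/// hold `sin(pos / 10000^(2i/d))`, odd columns the matching cosine.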
#[derive(Debug, Clone)]
pub struct PositionalEncoding {
pub table: Tensor,
pub max_seq_len: usize,
pub embed_dim: usize,
}
impl PositionalEncoding {
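    /// Precomputes the `[max_seq_len, embed_dim]` table of sinusoidal
    /// positional encodings.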
pub fn new(max_seq_len: usize, embed_dim: usize) -> Self {
let mut data = vec![0.0; max_seq_len * embed_dim];
for pos in 0..max_seq_len {
for i in 0..(embed_dim / 2) {
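                // PE(pos, 2i)   = sin(pos / 10000^(2i / d))
                // PE(pos, 2i+1) = cos(pos / 10000^(2i / d))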
let angle = pos as f64 / 10000_f64.powf(2.0 * i as f64 / embed_dim as f64);
data[pos * embed_dim + 2 * i] = angle.sin();
data[pos * embed_dim + 2 * i + 1] = angle.cos();
}
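            // For an odd embedding dimension, the final column gets the sine term only.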
if embed_dim % 2 == 1 {
let angle =
pos as f64 / 10000_f64.powf(2.0 * (embed_dim / 2) as f64 / embed_dim as f64);
data[pos * embed_dim + embed_dim - 1] = angle.sin();
}
}
        // `data` always holds exactly `max_seq_len * embed_dim` elements, so this
        // constructor should not fail; the all-zero table is a defensive fallback.
        let table = Tensor::new(data, vec![max_seq_len, embed_dim])
            .unwrap_or_else(|_| Tensor::zeros(&[max_seq_len, embed_dim]));
Self {
table,
max_seq_len,
embed_dim,
}
}
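    /// Adds positional encodings to a `[seq_len, embed_dim]` batch of embeddings.
    ///
    /// Fails with `ShapeMismatch` if the input is not 2-D with width `embed_dim`,
    /// and with `InvalidConfig` if `seq_len > max_seq_len`.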
pub fn forward(&self, embeddings: &Tensor) -> Result<Tensor> {
let shape = embeddings.shape();
        // Avoid indexing `shape[0]` before we know the tensor is 2-D: a scalar or
        // 1-D input should yield a `ShapeMismatch` error rather than a panic.
        if shape.len() != 2 || shape[1] != self.embed_dim {
            return Err(CognitionError::ShapeMismatch {
                expected: vec![shape.first().copied().unwrap_or(0), self.embed_dim],
                got: shape.to_vec(),
                operation: "positional_encoding",
            });
}
let seq_len = shape[0];
if seq_len > self.max_seq_len {
return Err(CognitionError::InvalidConfig(format!(
"sequence length {seq_len} exceeds max positional encoding length {}",
self.max_seq_len
)));
}
let pe_data = &self.table.data()[..seq_len * self.embed_dim];
let pe = Tensor::new(pe_data.to_vec(), vec![seq_len, self.embed_dim])?;
embeddings.add(&pe)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_embedding_lookup() {
let mut rng = rand::rng();
let emb = Embedding::new(100, 16, &mut rng);
let vec = emb.forward(0).unwrap();
assert_eq!(vec.shape(), &[16]);
}
#[test]
fn test_embedding_batch() {
let mut rng = rand::rng();
let emb = Embedding::new(100, 16, &mut rng);
let batch = emb.forward_batch(&[0, 1, 2]).unwrap();
assert_eq!(batch.shape(), &[3, 16]);
}
#[test]
fn test_positional_encoding_shape() {
let pe = PositionalEncoding::new(512, 16);
let mut rng = rand::rng();
let emb = Embedding::new(100, 16, &mut rng);
let embedded = emb.forward_batch(&[0, 1, 2]).unwrap();
let result = pe.forward(&embedded).unwrap();
assert_eq!(result.shape(), &[3, 16]);
}
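
    // A minimal sketch of the failure path: a sequence longer than `max_seq_len`
    // cannot be encoded, so `forward` should return an error rather than panic.
    #[test]
    fn test_positional_encoding_too_long() {
        let pe = PositionalEncoding::new(2, 4);
        let too_long = Tensor::zeros(&[3, 4]);
        assert!(pe.forward(&too_long).is_err());
    }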
#[test]
fn test_out_of_vocab() {
let mut rng = rand::rng();
let emb = Embedding::new(10, 4, &mut rng);
assert!(emb.forward(10).is_err());
}
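
    // Sanity-check the first table row: at position 0 every angle is 0, so sine
    // columns are exactly 0.0 and cosine columns exactly 1.0. This assumes
    // `Tensor::data` exposes the row-major buffer, as `forward` relies on above.
    #[test]
    fn test_positional_encoding_first_row() {
        let pe = PositionalEncoding::new(4, 4);
        let data = pe.table.data();
        assert_eq!(data[0], 0.0); // sin(0)
        assert_eq!(data[1], 1.0); // cos(0)
        assert_eq!(data[2], 0.0); // sin(0) for the second sin/cos pair
        assert_eq!(data[3], 1.0); // cos(0)
    }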
}