use scivex_core::random::Rng;
use scivex_core::{Float, Tensor};
use crate::error::{NnError, Result};
use crate::init;
use crate::variable::Variable;
use super::Layer;
/// A trainable lookup table that maps integer indices to dense vectors.
///
/// The table is held as a single weight variable of shape
/// `[num_embeddings, dim]`; row `i` is the vector returned for index `i`.
pub struct Embedding<T: Float> {
    /// Lookup table, shape `[num_embeddings, dim]`.
    weight: Variable<T>,
    /// Number of rows in the table; valid indices are `0..num_embeddings`.
    num_embeddings: usize,
    /// Length of every embedding vector (columns of the table).
    dim: usize,
}
impl<T: Float> Embedding<T> {
    /// Builds an embedding layer with `num_embeddings` rows of length `dim`.
    ///
    /// The table is initialised with `init::xavier_uniform` using `rng`, and
    /// wrapped via `Variable::new(.., true)` (the flag presumably marks the
    /// table as requiring gradients — confirm against `Variable::new`).
    pub fn new(num_embeddings: usize, dim: usize, rng: &mut Rng) -> Self {
        Self {
            weight: Variable::new(
                init::xavier_uniform::<T>(&[num_embeddings, dim], rng),
                true,
            ),
            num_embeddings,
            dim,
        }
    }

    /// Borrows the `[num_embeddings, dim]` weight table.
    pub fn weight(&self) -> &Variable<T> {
        &self.weight
    }

    /// Number of rows (distinct indices) in the table.
    pub fn num_embeddings(&self) -> usize {
        self.num_embeddings
    }

    /// Length of each embedding vector.
    pub fn dim(&self) -> usize {
        self.dim
    }
}
impl<T: Float> Layer<T> for Embedding<T> {
    /// Looks up one embedding row for every index in `x`.
    ///
    /// `x` must be rank 2, `[batch, seq_len]`, and hold float-encoded row
    /// indices; each element is rounded to the nearest integer before the
    /// lookup. The result has shape `[batch, seq_len * dim]`: the `seq_len`
    /// looked-up rows of each batch entry are laid out back to back.
    ///
    /// In the backward pass the upstream gradient is scatter-added into a
    /// `[num_embeddings, dim]` weight gradient (rows used several times
    /// accumulate), and a zero tensor is returned for `x`.
    ///
    /// # Errors
    ///
    /// - [`NnError::ShapeMismatch`] if `x` is not rank 2.
    /// - [`NnError::IndexOutOfBounds`] if a rounded index is NaN, negative,
    ///   or `>= num_embeddings`. NaN / negative indices are reported with
    ///   `index: 0` because the field is unsigned.
    fn forward(&self, x: &Variable<T>) -> Result<Variable<T>> {
        let shape = x.shape();
        if shape.len() != 2 {
            // The layer has no concrete expected dims; `[0, 0]` only conveys
            // "a rank-2 tensor was expected".
            return Err(NnError::ShapeMismatch {
                expected: vec![0, 0],
                got: shape,
            });
        }
        let (batch, seq_len) = (shape[0], shape[1]);
        let dim = self.dim;
        let num_emb = self.num_embeddings;

        let xd = x.data();
        let indices = xd
            .as_slice()
            .iter()
            .map(|&v| embedding_row_index(v, num_emb))
            .collect::<Result<Vec<usize>>>()?;

        let wd = self.weight.data();
        let table = wd.as_slice();
        let mut out = Vec::with_capacity(indices.len() * dim);
        for &row in &indices {
            out.extend_from_slice(&table[row * dim..(row + 1) * dim]);
        }
        let out_tensor =
            Tensor::from_vec(out, vec![batch, seq_len * dim]).expect("valid shape");

        let grad_fn = Box::new(move |g: &Tensor<T>| {
            let gd = g.as_slice();
            let mut gw = vec![T::zero(); num_emb * dim];
            // Scatter-add: a row looked up more than once receives the sum of
            // the gradients of every position that used it.
            for (i, &row) in indices.iter().enumerate() {
                let dst = &mut gw[row * dim..(row + 1) * dim];
                let src = &gd[i * dim..(i + 1) * dim];
                for (d, &s) in dst.iter_mut().zip(src) {
                    *d += s;
                }
            }
            vec![
                // Indices are not differentiable; their gradient is zero.
                Tensor::zeros(vec![batch, seq_len]),
                Tensor::from_vec(gw, vec![num_emb, dim]).expect("valid shape"),
            ]
        });
        // Gradients are returned in the same order as these inputs
        // (`x`, then the weight table).
        Ok(Variable::from_op(
            out_tensor,
            vec![x.clone(), self.weight.clone()],
            grad_fn,
        ))
    }

    /// The only trainable parameter is the weight table.
    fn parameters(&self) -> Vec<Variable<T>> {
        vec![self.weight.clone()]
    }

    /// No-op: the lookup behaves the same in training and evaluation.
    fn train(&mut self) {}

    /// No-op: the lookup behaves the same in training and evaluation.
    fn eval(&mut self) {}
}

/// Converts one float-encoded index into a row number in `0..len`.
///
/// The value is rounded first. In this file `Float` is only seen to provide
/// `round` and `Display`, so the rounded value is read back through its
/// decimal text. Parsing it as `f64` — not `usize` — accepts the `"-0"` that
/// `Display` emits for negative zero (e.g. `-0.3` rounded), which
/// `usize::from_str` would reject.
fn embedding_row_index<T: Float>(v: T, len: usize) -> Result<usize> {
    let out_of_bounds = |index: usize| NnError::IndexOutOfBounds { index, len };
    let rounded = v.round();
    let value: f64 = format!("{rounded}")
        .parse()
        .map_err(|_| out_of_bounds(0))?;
    // `-0.0 < 0.0` is false, so negative zero is accepted as row 0.
    if value.is_nan() || value < 0.0 {
        return Err(out_of_bounds(0));
    }
    // Float-to-int `as` saturates, so huge values and +inf become usize::MAX
    // and fail the bounds check below with an informative index.
    let index = value as usize;
    if index >= len {
        return Err(out_of_bounds(index));
    }
    Ok(index)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `[batch, seq_len]` indices produce a `[batch, seq_len * dim]` output.
    #[test]
    fn test_embedding_output_shape() {
        let mut rng = Rng::new(42);
        let emb = Embedding::<f64>::new(10, 8, &mut rng);
        let x = Variable::new(
            Tensor::from_vec(vec![0.0, 1.0, 2.0, 3.0, 4.0, 5.0], vec![2, 3]).unwrap(),
            false,
        );
        let y = emb.forward(&x).unwrap();
        assert_eq!(y.shape(), vec![2, 24]);
    }

    /// Each output chunk is a verbatim copy of the indexed weight row.
    #[test]
    fn test_embedding_values() {
        let mut rng = Rng::new(42);
        let emb = Embedding::<f64>::new(5, 3, &mut rng);
        let x = Variable::new(Tensor::from_vec(vec![0.0, 2.0], vec![1, 2]).unwrap(), false);
        let y = emb.forward(&x).unwrap();
        let yd = y.data();
        let ys = yd.as_slice();
        let wd = emb.weight().data();
        let ws = wd.as_slice();
        assert_eq!(&ys[0..3], &ws[0..3]);
        assert_eq!(&ys[3..6], &ws[6..9]);
    }

    /// Indices at or past `num_embeddings` are rejected.
    #[test]
    fn test_embedding_out_of_bounds() {
        let mut rng = Rng::new(42);
        let emb = Embedding::<f64>::new(5, 3, &mut rng);
        let x = Variable::new(Tensor::from_vec(vec![10.0], vec![1, 1]).unwrap(), false);
        assert!(emb.forward(&x).is_err());
    }

    /// Negative indices are rejected rather than wrapping or panicking.
    #[test]
    fn test_embedding_negative_index_is_error() {
        let mut rng = Rng::new(42);
        let emb = Embedding::<f64>::new(5, 3, &mut rng);
        let x = Variable::new(Tensor::from_vec(vec![-1.0], vec![1, 1]).unwrap(), false);
        assert!(emb.forward(&x).is_err());
    }

    /// Inputs that are not rank 2 are rejected.
    #[test]
    fn test_embedding_rejects_non_rank2_input() {
        let mut rng = Rng::new(42);
        let emb = Embedding::<f64>::new(5, 3, &mut rng);
        let x = Variable::new(Tensor::from_vec(vec![0.0, 1.0], vec![1, 2, 1]).unwrap(), false);
        assert!(emb.forward(&x).is_err());
    }

    /// Only the looked-up rows receive gradient from a sum loss.
    #[test]
    fn test_embedding_backward() {
        let mut rng = Rng::new(42);
        let emb = Embedding::<f64>::new(5, 4, &mut rng);
        let x = Variable::new(Tensor::from_vec(vec![1.0, 3.0], vec![1, 2]).unwrap(), false);
        let y = emb.forward(&x).unwrap();
        let loss = crate::ops::sum(&y);
        loss.backward();
        let gw = emb.weight().grad().unwrap();
        assert_eq!(gw.shape(), &[5, 4]);
        let gs = gw.as_slice();
        for j in 0..4 {
            assert!((gs[4 + j] - 1.0).abs() < f64::EPSILON);
            assert!((gs[3 * 4 + j] - 1.0).abs() < f64::EPSILON);
            assert!(gs[j].abs() < f64::EPSILON);
        }
    }

    /// A row used twice accumulates both upstream gradients.
    #[test]
    fn test_embedding_backward_accumulates_repeated_indices() {
        let mut rng = Rng::new(42);
        let emb = Embedding::<f64>::new(5, 3, &mut rng);
        let x = Variable::new(Tensor::from_vec(vec![2.0, 2.0], vec![1, 2]).unwrap(), false);
        let y = emb.forward(&x).unwrap();
        let loss = crate::ops::sum(&y);
        loss.backward();
        let gw = emb.weight().grad().unwrap();
        let gs = gw.as_slice();
        for j in 0..3 {
            assert!((gs[2 * 3 + j] - 2.0).abs() < f64::EPSILON);
            assert!(gs[j].abs() < f64::EPSILON);
        }
    }

    /// The weight table is the layer's single parameter.
    #[test]
    fn test_embedding_parameters() {
        let mut rng = Rng::new(42);
        let emb = Embedding::<f64>::new(10, 8, &mut rng);
        assert_eq!(emb.parameters().len(), 1);
    }

    /// Getters report the constructor arguments and the table shape.
    #[test]
    fn test_embedding_accessors() {
        let mut rng = Rng::new(42);
        let emb = Embedding::<f64>::new(10, 8, &mut rng);
        assert_eq!(emb.num_embeddings(), 10);
        assert_eq!(emb.dim(), 8);
        assert_eq!(emb.weight().shape(), vec![10, 8]);
    }
}