use crate::autograd::{self, Variable};
use crate::tensor::{Result, Tensor, TensorOptions, DType, Device};
use super::parameter::Parameter;
use super::Module;
/// A learnable lookup table mapping integer indices to dense vectors
/// (rows of `weight`), in the spirit of `torch.nn.Embedding`.
pub struct Embedding {
/// Trainable table of shape `[num_embeddings, embedding_dim]` (see `on_device`).
pub weight: Parameter,
// Retained for introspection/debugging only; not read after construction.
#[allow(dead_code)]
num_embeddings: i64,
// Width of each embedding row; appended to the input shape in `forward`.
embedding_dim: i64,
}
impl Embedding {
    /// Creates an embedding table on the CPU.
    ///
    /// Convenience wrapper around [`Embedding::on_device`].
    pub fn new(num_embeddings: i64, embedding_dim: i64) -> Result<Self> {
        Self::on_device(num_embeddings, embedding_dim, Device::CPU)
    }

    /// Creates a `[num_embeddings, embedding_dim]` table on `device`,
    /// initialized from `Tensor::randn` and marked as requiring gradients.
    ///
    /// Returns `Err` if the weight tensor cannot be allocated.
    pub fn on_device(num_embeddings: i64, embedding_dim: i64, device: Device) -> Result<Self> {
        let options = TensorOptions { dtype: DType::Float32, device };
        let table = Tensor::randn(&[num_embeddings, embedding_dim], options)?;
        let variable = Variable::new(table, true);
        Ok(Embedding {
            weight: Parameter { variable, name: "weight".into() },
            num_embeddings,
            embedding_dim,
        })
    }
}
impl Module for Embedding {
    fn name(&self) -> &str { "embedding" }

    /// Looks up one weight row per index in `input`.
    ///
    /// Output shape is the input shape with `embedding_dim` appended.
    /// Non-Int64 inputs are truncated to i64 via `as`; NOTE(review):
    /// negative or fractional indices pass through silently — confirm
    /// `index_select` rejects out-of-range values.
    fn forward(&self, input: &Variable) -> Result<Variable> {
        let numel = input.numel();
        // Normalize the indices to a flat 1-D Int64 tensor.
        let flat_indices = match input.data().dtype() {
            DType::Int64 => input.data().reshape(&[numel])?,
            _ => {
                let cast: Vec<i64> = input
                    .data()
                    .to_f32_vec()?
                    .into_iter()
                    .map(|v| v as i64)
                    .collect();
                Tensor::from_i64(&cast, &[numel], input.device())?
            }
        };
        let rows = self.weight.variable.index_select(0, &flat_indices)?;
        // Restore the caller's shape, with the embedding width as a new last axis.
        let mut out_shape = input.shape();
        out_shape.push(self.embedding_dim);
        rows.reshape(&out_shape)
    }

    fn parameters(&self) -> Vec<Parameter> {
        vec![self.weight.clone()]
    }
}
/// An embedding table whose looked-up rows are reduced per "bag"
/// (sum / mean / max), analogous to `torch.nn.EmbeddingBag`.
pub struct EmbeddingBag {
/// Trainable table of shape `[num_embeddings, embedding_dim]`.
pub weight: Parameter,
// Retained for introspection/debugging only; not read after construction.
#[allow(dead_code)]
num_embeddings: i64,
#[allow(dead_code)]
embedding_dim: i64,
// Reduction mode passed to `autograd::embedding_bag`; expected to be one of
// `SUM` (0), `MEAN` (1), `MAX` (2). NOTE(review): not validated at
// construction — an out-of-range value is only caught downstream, if at all.
mode: i64,
}
impl EmbeddingBag {
    /// Reduction mode: element-wise sum over each bag's rows.
    pub const SUM: i64 = 0;
    /// Reduction mode: element-wise mean over each bag's rows.
    pub const MEAN: i64 = 1;
    /// Reduction mode: element-wise max over each bag's rows.
    pub const MAX: i64 = 2;

    /// Creates an embedding bag on the CPU.
    ///
    /// Convenience wrapper around [`EmbeddingBag::on_device`].
    pub fn new(num_embeddings: i64, embedding_dim: i64, mode: i64) -> Result<Self> {
        Self::on_device(num_embeddings, embedding_dim, mode, Device::CPU)
    }

    /// Creates a `[num_embeddings, embedding_dim]` table on `device`,
    /// initialized from `Tensor::randn` and marked as requiring gradients.
    /// `mode` selects the per-bag reduction (`SUM`/`MEAN`/`MAX`).
    pub fn on_device(
        num_embeddings: i64, embedding_dim: i64, mode: i64, device: Device,
    ) -> Result<Self> {
        let options = TensorOptions { dtype: DType::Float32, device };
        let variable = Variable::new(
            Tensor::randn(&[num_embeddings, embedding_dim], options)?,
            true,
        );
        Ok(EmbeddingBag {
            weight: Parameter { variable, name: "weight".into() },
            num_embeddings,
            embedding_dim,
            mode,
        })
    }

    /// Reduces weight rows selected by `indices`, grouped into bags that
    /// begin at the positions listed in `offsets`, using this bag's `mode`.
    pub fn forward_bag(&self, indices: &Tensor, offsets: &Tensor) -> Result<Variable> {
        autograd::embedding_bag(&self.weight.variable, indices, offsets, self.mode)
    }
}
impl Module for EmbeddingBag {
    fn name(&self) -> &str { "embedding_bag" }

    /// Treats a `[num_bags, bag_size]` index matrix as `num_bags` fixed-width
    /// bags and reduces each bag according to `mode`, producing
    /// `[num_bags, embedding_dim]`.
    ///
    /// Returns an error for any input that is not 2-D.
    fn forward(&self, input: &Variable) -> Result<Variable> {
        let dims = input.shape();
        if dims.len() != 2 {
            return Err(crate::tensor::TensorError::new(&format!(
                "EmbeddingBag::forward expects 2-D input [num_bags, bag_size], got {:?}",
                dims,
            )));
        }
        let bags = dims[0];
        let width = dims[1];
        let total = bags * width;
        let dev = input.device();
        // Normalize the indices to a flat 1-D Int64 tensor.
        let flat_indices = match input.data().dtype() {
            DType::Int64 => input.data().reshape(&[total])?,
            _ => {
                let cast: Vec<i64> = input
                    .data()
                    .to_f32_vec()?
                    .into_iter()
                    .map(|v| v as i64)
                    .collect();
                Tensor::from_i64(&cast, &[total], dev)?
            }
        };
        // Fixed-width bags: bag i starts at i * width.
        let starts: Vec<i64> = (0..bags).map(|b| b * width).collect();
        let offsets = Tensor::from_i64(&starts, &[bags], dev)?;
        autograd::embedding_bag(&self.weight.variable, &flat_indices, &offsets, self.mode)
    }

    fn parameters(&self) -> Vec<Parameter> {
        vec![self.weight.clone()]
    }
}
#[cfg(test)]
mod tests {
use super::*;
use crate::tensor::test_device;
// SUM mode: each output row equals the element-wise sum of the weight rows
// selected by that bag's indices. Bags are [0,1,2] and [3,4].
#[test]
// Allows keep the `0 * 3 + d` style index math uniform across rows.
#[allow(clippy::identity_op, clippy::erasing_op)]
fn embedding_bag_sum_known_values() {
let dev = test_device();
let eb = EmbeddingBag::on_device(5, 3, EmbeddingBag::SUM, dev).unwrap();
// Read the randomly initialized weights so expectations can be computed.
let w = eb.weight.variable.data().to_f32_vec().unwrap();
let indices = Tensor::from_i64(&[0, 1, 2, 3, 4], &[5], dev).unwrap();
let offsets = Tensor::from_i64(&[0, 3], &[2], dev).unwrap();
let out = eb.forward_bag(&indices, &offsets).unwrap();
assert_eq!(out.shape(), vec![2, 3]);
let vals = out.data().to_f32_vec().unwrap();
// Bag 0 = rows 0+1+2, checked per dimension.
for d in 0..3 {
let expected = w[0 * 3 + d] + w[1 * 3 + d] + w[2 * 3 + d];
assert!((vals[0 * 3 + d] - expected).abs() < 1e-5,
"bag0 dim {d}: got {}, expected {}", vals[0 * 3 + d], expected);
}
// Bag 1 = rows 3+4.
for d in 0..3 {
let expected = w[3 * 3 + d] + w[4 * 3 + d];
assert!((vals[1 * 3 + d] - expected).abs() < 1e-5,
"bag1 dim {d}: got {}, expected {}", vals[1 * 3 + d], expected);
}
}
// MEAN mode: each output row is the average of its bag's weight rows.
#[test]
#[allow(clippy::identity_op, clippy::erasing_op)]
fn embedding_bag_mean() {
let dev = test_device();
let eb = EmbeddingBag::on_device(4, 2, EmbeddingBag::MEAN, dev).unwrap();
let w = eb.weight.variable.data().to_f32_vec().unwrap();
let indices = Tensor::from_i64(&[0, 1, 2, 3], &[4], dev).unwrap();
let offsets = Tensor::from_i64(&[0, 2], &[2], dev).unwrap();
let out = eb.forward_bag(&indices, &offsets).unwrap();
assert_eq!(out.shape(), vec![2, 2]);
let vals = out.data().to_f32_vec().unwrap();
for d in 0..2 {
let expected = (w[0 * 2 + d] + w[1 * 2 + d]) / 2.0;
assert!((vals[0 * 2 + d] - expected).abs() < 1e-5);
}
for d in 0..2 {
let expected = (w[2 * 2 + d] + w[3 * 2 + d]) / 2.0;
assert!((vals[1 * 2 + d] - expected).abs() < 1e-5);
}
}
// Module::forward on a [3, 2] index matrix must agree with forward_bag
// called with the equivalent flat indices and stride-2 offsets.
#[test]
fn embedding_bag_2d_forward() {
let dev = test_device();
let eb = EmbeddingBag::on_device(10, 4, EmbeddingBag::SUM, dev).unwrap();
let input = Variable::new(
Tensor::from_i64(&[0, 1, 2, 3, 4, 5], &[3, 2], dev).unwrap(),
false,
);
let out = eb.forward(&input).unwrap();
assert_eq!(out.shape(), vec![3, 4]);
let flat_idx = Tensor::from_i64(&[0, 1, 2, 3, 4, 5], &[6], dev).unwrap();
let offsets = Tensor::from_i64(&[0, 2, 4], &[3], dev).unwrap();
let out_bag = eb.forward_bag(&flat_idx, &offsets).unwrap();
let v1 = out.data().to_f32_vec().unwrap();
let v2 = out_bag.data().to_f32_vec().unwrap();
for (a, b) in v1.iter().zip(v2.iter()) {
assert!((a - b).abs() < 1e-6, "forward vs forward_bag mismatch: {a} != {b}");
}
}
// Backward through SUM: the weight gets a gradient, and a row whose index
// never appears in the input (row 4) receives exactly zero.
#[test]
fn embedding_bag_gradient() {
let dev = test_device();
let eb = EmbeddingBag::on_device(5, 3, EmbeddingBag::SUM, dev).unwrap();
let indices = Tensor::from_i64(&[0, 1, 2, 3], &[4], dev).unwrap();
let offsets = Tensor::from_i64(&[0, 2], &[2], dev).unwrap();
let out = eb.forward_bag(&indices, &offsets).unwrap();
let loss = out.sum().unwrap();
loss.backward().unwrap();
let grad = eb.weight.variable.grad();
assert!(grad.is_some(), "weight should have gradient after backward");
let g = grad.unwrap();
assert_eq!(g.shape(), vec![5, 3]);
let gv = g.to_f32_vec().unwrap();
// Exact equality is intentional: an untouched row is written as 0.0,
// not accumulated, so no float tolerance is needed.
let row4_sum: f32 = gv[4 * 3..5 * 3].iter().sum();
assert_eq!(row4_sum, 0.0, "unused index should have zero gradient");
}
// MAX mode with a single bag spanning all four rows: each output element
// is the maximum over that dimension across all selected rows.
#[test]
fn embedding_bag_max() {
let dev = test_device();
let eb = EmbeddingBag::on_device(4, 2, EmbeddingBag::MAX, dev).unwrap();
let w = eb.weight.variable.data().to_f32_vec().unwrap();
let indices = Tensor::from_i64(&[0, 1, 2, 3], &[4], dev).unwrap();
let offsets = Tensor::from_i64(&[0], &[1], dev).unwrap();
let out = eb.forward_bag(&indices, &offsets).unwrap();
assert_eq!(out.shape(), vec![1, 2]);
let vals = out.data().to_f32_vec().unwrap();
for d in 0..2 {
let expected = (0..4)
.map(|i| w[i * 2 + d])
.fold(f32::NEG_INFINITY, f32::max);
assert!((vals[d] - expected).abs() < 1e-5,
"max dim {d}: got {}, expected {}", vals[d], expected);
}
}
// The module exposes exactly one parameter: the weight table.
#[test]
fn embedding_bag_parameters() {
let dev = test_device();
let eb = EmbeddingBag::on_device(10, 8, EmbeddingBag::MEAN, dev).unwrap();
let params = eb.parameters();
assert_eq!(params.len(), 1);
assert_eq!(params[0].name, "weight");
assert_eq!(params[0].variable.shape(), vec![10, 8]);
}
}