use std::sync::Arc;
use ferrotorch_core::autograd::no_grad::is_grad_enabled;
use ferrotorch_core::tensor::GradFn;
use ferrotorch_core::{Float, FerrotorchError, FerrotorchResult, Tensor, TensorStorage};
use crate::init;
use crate::module::Module;
use crate::parameter::Parameter;
/// Autograd node recorded by `Embedding::forward`.
///
/// Its backward pass scatter-adds each row of the upstream gradient into the
/// weight row that was looked up for it, producing a dense
/// `[num_embeddings, embedding_dim]` gradient for `weight`.
#[derive(Debug)]
pub struct EmbeddingBackward<T: Float> {
/// The embedding weight; the single input this node reports via `inputs()`.
weight: Tensor<T>,
/// Row index looked up for each output row, in output order (may repeat).
indices: Vec<usize>,
/// Number of rows of the produced weight gradient.
num_embeddings: usize,
/// Number of columns of the produced weight gradient (length of one embedding).
embedding_dim: usize,
/// Row whose gradient is forced to zero after accumulation, if any.
padding_idx: Option<usize>,
}
impl<T: Float> GradFn<T> for EmbeddingBackward<T> {
fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
if !is_grad_enabled() {
return Ok(vec![None]);
}
let cpu_go = if grad_output.is_cuda() { grad_output.cpu()? } else { grad_output.clone() };
let go_data = cpu_go.data()?;
let dim = self.embedding_dim;
let mut grad_weight =
vec![<T as num_traits::Zero>::zero(); self.num_embeddings * dim];
for (i, &idx) in self.indices.iter().enumerate() {
let go_row = &go_data[i * dim..(i + 1) * dim];
let gw_row = &mut grad_weight[idx * dim..(idx + 1) * dim];
for (gw, &go) in gw_row.iter_mut().zip(go_row.iter()) {
*gw += go;
}
}
if let Some(pad_idx) = self.padding_idx {
let start = pad_idx * dim;
for v in &mut grad_weight[start..start + dim] {
*v = <T as num_traits::Zero>::zero();
}
}
let grad_tensor = Tensor::from_storage(
TensorStorage::cpu(grad_weight),
vec![self.num_embeddings, dim],
false,
)?;
Ok(vec![Some(grad_tensor)])
}
fn inputs(&self) -> Vec<&Tensor<T>> {
vec![&self.weight]
}
fn name(&self) -> &'static str {
"EmbeddingBackward"
}
}
/// Lookup table that maps integer indices to rows of a learnable weight matrix.
///
/// `forward` takes a 1-D tensor whose values are row indices (stored as `T`)
/// and returns an `[n, embedding_dim]` tensor whose row `i` is row `input[i]`
/// of `weight`.
#[derive(Debug)]
pub struct Embedding<T: Float> {
/// Embedding table of shape `[num_embeddings, embedding_dim]`.
pub weight: Parameter<T>,
/// Number of rows in `weight`; valid indices are `0..num_embeddings`.
pub num_embeddings: usize,
/// Length of each embedding vector (number of columns of `weight`).
pub embedding_dim: usize,
/// Optional padding row: lookups of it yield zeros in `forward` and its
/// gradient is zeroed in backward. `new` also initialises this row to zero.
pub padding_idx: Option<usize>,
/// Training-mode flag toggled by `Module::train` / `Module::eval`.
training: bool,
}
impl<T: Float> Embedding<T> {
    /// Creates a freshly initialised embedding table.
    ///
    /// The weight has shape `[num_embeddings, embedding_dim]` and is filled via
    /// `init::normal(&mut weight, 0.0, 1.0)`. If `padding_idx` is given, that row
    /// is reset to zero and the weight is rebuilt as a CPU parameter with
    /// `requires_grad = true`.
    ///
    /// # Errors
    ///
    /// Returns `FerrotorchError::InvalidArgument` if `padding_idx >= num_embeddings`.
    /// Errors from allocation or initialisation are propagated.
    pub fn new(
        num_embeddings: usize,
        embedding_dim: usize,
        padding_idx: Option<usize>,
    ) -> FerrotorchResult<Self> {
        Self::validate_padding_idx(padding_idx, num_embeddings)?;

        let mut weight = Parameter::zeros(&[num_embeddings, embedding_dim])?;
        init::normal(&mut weight, 0.0, 1.0)?;

        if let Some(idx) = padding_idx {
            let mut data = weight.data()?.to_vec();
            let start = idx * embedding_dim;
            data[start..start + embedding_dim].fill(<T as num_traits::Zero>::zero());
            weight = Parameter::new(Tensor::from_storage(
                TensorStorage::cpu(data),
                vec![num_embeddings, embedding_dim],
                true,
            )?);
        }

        Ok(Self {
            weight,
            num_embeddings,
            embedding_dim,
            padding_idx,
            training: true,
        })
    }

    /// Wraps an existing 2-D weight tensor as an embedding table.
    ///
    /// `num_embeddings` and `embedding_dim` are taken from `weight.shape()`.
    /// The tensor is wrapped with `Parameter::new` as-is. In particular, the
    /// `padding_idx` row is *not* zeroed here; `forward` still returns zeros
    /// for it.
    ///
    /// # Errors
    ///
    /// Returns `FerrotorchError::InvalidArgument` if `weight` is not 2-D or if
    /// `padding_idx >= num_embeddings`.
    pub fn from_pretrained(
        weight: Tensor<T>,
        padding_idx: Option<usize>,
    ) -> FerrotorchResult<Self> {
        if weight.ndim() != 2 {
            return Err(FerrotorchError::InvalidArgument {
                message: format!(
                    "Embedding weight must be 2-D, got shape {:?}",
                    weight.shape()
                ),
            });
        }
        let num_embeddings = weight.shape()[0];
        let embedding_dim = weight.shape()[1];
        Self::validate_padding_idx(padding_idx, num_embeddings)?;

        Ok(Self {
            weight: Parameter::new(weight),
            num_embeddings,
            embedding_dim,
            padding_idx,
            training: true,
        })
    }

    /// Checks that `padding_idx`, when present, names an existing row.
    ///
    /// Shared by both constructors so they report the identical error.
    fn validate_padding_idx(
        padding_idx: Option<usize>,
        num_embeddings: usize,
    ) -> FerrotorchResult<()> {
        match padding_idx {
            Some(idx) if idx >= num_embeddings => Err(FerrotorchError::InvalidArgument {
                message: format!(
                    "padding_idx {idx} is out of range for num_embeddings {num_embeddings}"
                ),
            }),
            _ => Ok(()),
        }
    }
}
/// Converts the value at `position` of an `Embedding` input into a row index.
///
/// Index values are stored as `T` in the input tensor, so they are validated
/// here before being used as a row offset into the weight.
///
/// # Errors
///
/// * `FerrotorchError::InvalidArgument` if `val` cannot be converted to `usize`
///   (e.g. `-1.0`), or if it is not a whole number. Examples are `2.5` and
///   `-0.5`, which `ToPrimitive::to_usize` would otherwise silently truncate to
///   2 and 0.
/// * `FerrotorchError::IndexOutOfBounds` if the index is `>= num_embeddings`.
fn embedding_row_index<T: Float>(
    position: usize,
    val: T,
    num_embeddings: usize,
) -> FerrotorchResult<usize> {
    let idx = num_traits::ToPrimitive::to_usize(&val).ok_or_else(|| {
        FerrotorchError::InvalidArgument {
            message: format!(
                "Embedding index at position {position} cannot be converted to usize: {val:?}"
            ),
        }
    })?;
    // `to_usize` truncates toward zero, so a round trip detects fractional input.
    if num_traits::ToPrimitive::to_f64(&val) != Some(idx as f64) {
        return Err(FerrotorchError::InvalidArgument {
            message: format!(
                "Embedding index at position {position} must be a whole number, got {val:?}"
            ),
        });
    }
    if idx >= num_embeddings {
        return Err(FerrotorchError::IndexOutOfBounds {
            index: idx,
            axis: 0,
            size: num_embeddings,
        });
    }
    Ok(idx)
}

impl<T: Float> Module<T> for Embedding<T> {
    /// Looks up one embedding row per entry of a 1-D index tensor.
    ///
    /// For an input of length `n` the result has shape `[n, embedding_dim]`.
    /// Rows whose index equals `padding_idx` are zeros. The result is built in
    /// CPU storage. When `weight` requires grad and grad mode is enabled, the
    /// result is attached to an `EmbeddingBackward` node.
    ///
    /// # Errors
    ///
    /// * `InvalidArgument` if `input` is not 1-D or an entry is not a
    ///   non-negative whole number.
    /// * `IndexOutOfBounds` if an entry is `>= num_embeddings`.
    fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
        if input.ndim() != 1 {
            return Err(FerrotorchError::InvalidArgument {
                message: format!(
                    "Embedding input must be 1-D, got shape {:?}",
                    input.shape()
                ),
            });
        }
        let input_data = input.data_vec()?;

        // The gather runs on host memory; a CUDA weight is copied to the CPU first.
        let weight = self.weight.tensor();
        let cpu_weight = if weight.is_cuda() { weight.cpu()? } else { weight.clone() };
        let weight_data = cpu_weight.data()?;
        let dim = self.embedding_dim;
        let n = input_data.len();

        // Validate every index before building any output; stops at the first bad one.
        let indices = input_data
            .iter()
            .enumerate()
            .map(|(position, &val)| embedding_row_index(position, val, self.num_embeddings))
            .collect::<FerrotorchResult<Vec<usize>>>()?;

        let zero = <T as num_traits::Zero>::zero();
        let mut output_data = Vec::with_capacity(n * dim);
        for &idx in &indices {
            if self.padding_idx == Some(idx) {
                // Padding lookups always yield zeros, whatever the stored row holds.
                output_data.resize(output_data.len() + dim, zero);
            } else {
                let row_start = idx * dim;
                output_data.extend_from_slice(&weight_data[row_start..row_start + dim]);
            }
        }

        let output_shape = vec![n, dim];
        if self.weight.requires_grad() && is_grad_enabled() {
            let grad_fn = Arc::new(EmbeddingBackward {
                weight: self.weight.tensor().clone(),
                indices,
                num_embeddings: self.num_embeddings,
                embedding_dim: dim,
                padding_idx: self.padding_idx,
            });
            Tensor::from_operation(TensorStorage::cpu(output_data), output_shape, grad_fn)
        } else {
            Tensor::from_storage(TensorStorage::cpu(output_data), output_shape, false)
        }
    }

    /// The embedding owns a single parameter: its weight table.
    fn parameters(&self) -> Vec<&Parameter<T>> {
        vec![&self.weight]
    }

    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
        vec![&mut self.weight]
    }

    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
        vec![("weight".to_string(), &self.weight)]
    }

    fn train(&mut self) {
        self.training = true;
    }

    fn eval(&mut self) {
        self.training = false;
    }

    fn is_training(&self) -> bool {
        self.training
    }
}
// Unit tests for `Embedding`: forward lookup, padding handling, input
// validation, `EmbeddingBackward` gradients, and `Module` plumbing.
#[cfg(test)]
mod tests {
use super::*;
use ferrotorch_core::autograd::graph::backward;
use ferrotorch_core::storage::TensorStorage;
// Builds a 1-D, non-grad f32 tensor holding the given index values.
fn index_tensor(indices: &[f32]) -> Tensor<f32> {
Tensor::from_storage(
TensorStorage::cpu(indices.to_vec()),
vec![indices.len()],
false,
)
.unwrap()
}
// n indices -> [n, embedding_dim] output.
#[test]
fn test_forward_shape() {
let emb = Embedding::<f32>::new(10, 4, None).unwrap();
let indices = index_tensor(&[0.0, 3.0, 7.0]);
let output = emb.forward(&indices).unwrap();
assert_eq!(output.shape(), &[3, 4]);
}
// Weight rows are [0,1,2], [3,4,5], [6,7,8], [9,10,11]; looking up 2 then 0
// must yield [6,7,8, 0,1,2].
#[test]
fn test_forward_correct_values() {
let weight_data: Vec<f32> = (0..12).map(|i| i as f32).collect();
let weight = Tensor::from_storage(
TensorStorage::cpu(weight_data),
vec![4, 3],
true,
)
.unwrap();
let emb = Embedding::from_pretrained(weight, None).unwrap();
let indices = index_tensor(&[2.0, 0.0]);
let output = emb.forward(&indices).unwrap();
let data = output.data().unwrap();
assert_eq!(data.len(), 6);
assert!((data[0] - 6.0).abs() < 1e-6);
assert!((data[1] - 7.0).abs() < 1e-6);
assert!((data[2] - 8.0).abs() < 1e-6);
assert!((data[3] - 0.0).abs() < 1e-6);
assert!((data[4] - 1.0).abs() < 1e-6);
assert!((data[5] - 2.0).abs() < 1e-6);
}
#[test]
fn test_forward_single_index() {
let emb = Embedding::<f32>::new(5, 8, None).unwrap();
let indices = index_tensor(&[3.0]);
let output = emb.forward(&indices).unwrap();
assert_eq!(output.shape(), &[1, 8]);
}
// `new` must zero the padding row of the weight, and looking it up must
// produce a zero output row.
#[test]
fn test_padding_idx_zeros() {
let emb = Embedding::<f32>::new(5, 3, Some(2)).unwrap();
let w_data = emb.weight.data().unwrap();
let pad_start = 2 * 3;
for i in 0..3 {
assert!(
(w_data[pad_start + i] - 0.0).abs() < 1e-6,
"padding row weight[2][{i}] should be 0, got {}",
w_data[pad_start + i]
);
}
let indices = index_tensor(&[2.0]);
let output = emb.forward(&indices).unwrap();
let data = output.data().unwrap();
for i in 0..3 {
assert!(
(data[i] - 0.0).abs() < 1e-6,
"padding output[0][{i}] should be 0, got {}",
data[i]
);
}
}
// Weight rows: [1,2], [0,0] (padding, index 1), [5,6]. Non-padding rows pass
// through unchanged; the padding row comes out as zeros.
#[test]
fn test_padding_idx_mixed() {
let weight_data: Vec<f32> = vec![
1.0, 2.0, 0.0, 0.0, 5.0, 6.0, ];
let weight = Tensor::from_storage(
TensorStorage::cpu(weight_data),
vec![3, 2],
true,
)
.unwrap();
let emb = Embedding::from_pretrained(weight, Some(1)).unwrap();
let indices = index_tensor(&[0.0, 1.0, 2.0]);
let output = emb.forward(&indices).unwrap();
let data = output.data().unwrap();
assert!((data[0] - 1.0).abs() < 1e-6);
assert!((data[1] - 2.0).abs() < 1e-6);
assert!((data[2] - 0.0).abs() < 1e-6);
assert!((data[3] - 0.0).abs() < 1e-6);
assert!((data[4] - 5.0).abs() < 1e-6);
assert!((data[5] - 6.0).abs() < 1e-6);
}
#[test]
fn test_padding_idx_out_of_range() {
let result = Embedding::<f32>::new(5, 3, Some(10));
assert!(result.is_err());
}
// Index 5 is one past the last row of a 5-row table.
#[test]
fn test_out_of_bounds_error() {
let emb = Embedding::<f32>::new(5, 3, None).unwrap();
let indices = index_tensor(&[0.0, 5.0]); let result = emb.forward(&indices);
assert!(result.is_err());
}
// -1.0 cannot be converted to usize, so the lookup must fail.
#[test]
fn test_negative_index_error() {
let emb = Embedding::<f32>::new(5, 3, None).unwrap();
let indices = index_tensor(&[-1.0]); let result = emb.forward(&indices);
assert!(result.is_err());
}
// Only 1-D index tensors are accepted.
#[test]
fn test_non_1d_input_error() {
let emb = Embedding::<f32>::new(5, 3, None).unwrap();
let input = Tensor::from_storage(
TensorStorage::cpu(vec![0.0f32, 1.0, 2.0, 3.0]),
vec![2, 2],
false,
)
.unwrap();
let result = emb.forward(&input);
assert!(result.is_err());
}
// Looking up rows 0 and 2 with an all-ones upstream gradient gives gradient
// 1 on rows 0 and 2 and 0 on the untouched row 1.
#[test]
fn test_backward_simple() {
let weight_data: Vec<f32> = vec![
10.0, 20.0, 30.0, 40.0, 50.0, 60.0, ];
let weight = Tensor::from_storage(
TensorStorage::cpu(weight_data),
vec![3, 2],
true,
)
.unwrap();
let emb = Embedding::from_pretrained(weight, None).unwrap();
let indices = index_tensor(&[0.0, 2.0]);
let output = emb.forward(&indices).unwrap();
assert!(output.requires_grad());
assert_eq!(output.grad_fn().unwrap().name(), "EmbeddingBackward");
let grad_output = Tensor::from_storage(
TensorStorage::cpu(vec![1.0f32; 4]),
vec![2, 2],
false,
)
.unwrap();
let grad_fn = output.grad_fn().unwrap();
let grads = grad_fn.backward(&grad_output).unwrap();
let grad_weight = grads[0].as_ref().unwrap();
assert_eq!(grad_weight.shape(), &[3, 2]);
let gd = grad_weight.data().unwrap();
assert!((gd[0] - 1.0).abs() < 1e-6);
assert!((gd[1] - 1.0).abs() < 1e-6);
assert!((gd[2] - 0.0).abs() < 1e-6);
assert!((gd[3] - 0.0).abs() < 1e-6);
assert!((gd[4] - 1.0).abs() < 1e-6);
assert!((gd[5] - 1.0).abs() < 1e-6);
}
// Indices [1, 1, 0, 1] with grad rows [1,2], [3,4], [5,6], [7,8]:
// row 0 gets [5,6]; row 1 accumulates [1+3+7, 2+4+8] = [11,14]; row 2 stays 0.
#[test]
fn test_backward_duplicate_indices() {
let weight_data: Vec<f32> = vec![
10.0, 20.0, 30.0, 40.0, 50.0, 60.0, ];
let weight = Tensor::from_storage(
TensorStorage::cpu(weight_data),
vec![3, 2],
true,
)
.unwrap();
let emb = Embedding::from_pretrained(weight, None).unwrap();
let indices = index_tensor(&[1.0, 1.0, 0.0, 1.0]);
let output = emb.forward(&indices).unwrap();
let grad_output = Tensor::from_storage(
TensorStorage::cpu(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]),
vec![4, 2],
false,
)
.unwrap();
let grad_fn = output.grad_fn().unwrap();
let grads = grad_fn.backward(&grad_output).unwrap();
let grad_weight = grads[0].as_ref().unwrap();
let gd = grad_weight.data().unwrap();
assert!((gd[0] - 5.0).abs() < 1e-6, "gd[0] = {}, expected 5", gd[0]);
assert!((gd[1] - 6.0).abs() < 1e-6, "gd[1] = {}, expected 6", gd[1]);
assert!(
(gd[2] - 11.0).abs() < 1e-6,
"gd[2] = {}, expected 11",
gd[2]
);
assert!(
(gd[3] - 14.0).abs() < 1e-6,
"gd[3] = {}, expected 14",
gd[3]
);
assert!((gd[4] - 0.0).abs() < 1e-6, "gd[4] = {}, expected 0", gd[4]);
assert!((gd[5] - 0.0).abs() < 1e-6, "gd[5] = {}, expected 0", gd[5]);
}
// The padding row (index 1) is looked up, yet its gradient must stay zero.
#[test]
fn test_backward_padding_idx_zeroed() {
let weight_data: Vec<f32> = vec![
1.0, 2.0, 0.0, 0.0, 5.0, 6.0, ];
let weight = Tensor::from_storage(
TensorStorage::cpu(weight_data),
vec![3, 2],
true,
)
.unwrap();
let emb = Embedding::from_pretrained(weight, Some(1)).unwrap();
let indices = index_tensor(&[0.0, 1.0, 2.0]);
let output = emb.forward(&indices).unwrap();
let grad_output = Tensor::from_storage(
TensorStorage::cpu(vec![1.0f32; 6]),
vec![3, 2],
false,
)
.unwrap();
let grad_fn = output.grad_fn().unwrap();
let grads = grad_fn.backward(&grad_output).unwrap();
let grad_weight = grads[0].as_ref().unwrap();
let gd = grad_weight.data().unwrap();
assert!((gd[0] - 1.0).abs() < 1e-6);
assert!((gd[1] - 1.0).abs() < 1e-6);
assert!((gd[2] - 0.0).abs() < 1e-6, "padding grad[1][0] should be 0");
assert!((gd[3] - 0.0).abs() < 1e-6, "padding grad[1][1] should be 0");
assert!((gd[4] - 1.0).abs() < 1e-6);
assert!((gd[5] - 1.0).abs() < 1e-6);
}
// Runs the full autograd engine: loss = sum(emb([1, 0])), so rows 0 and 1
// receive gradient 1 and row 2 receives 0.
#[test]
fn test_backward_end_to_end() {
let weight_data: Vec<f32> = vec![
1.0, 2.0, 3.0, 4.0, 5.0, 6.0, ];
let weight = Tensor::from_storage(
TensorStorage::cpu(weight_data),
vec![3, 2],
true,
)
.unwrap();
let emb = Embedding::from_pretrained(weight, None).unwrap();
let indices = index_tensor(&[1.0, 0.0]);
let output = emb.forward(&indices).unwrap();
let out_data = output.data().unwrap();
let total: f32 = out_data.iter().sum();
// Minimal stand-in for a sum op: broadcasts the scalar upstream gradient
// to every element of `input`.
#[derive(Debug)]
struct SumBackward<T: Float> {
input: Tensor<T>,
}
impl<T: Float> GradFn<T> for SumBackward<T> {
fn backward(
&self,
grad_output: &Tensor<T>,
) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
let go_val = grad_output.data()?[0];
let grad = vec![go_val; self.input.numel()];
let t = Tensor::from_storage(
TensorStorage::cpu(grad),
self.input.shape().to_vec(),
false,
)?;
Ok(vec![Some(t)])
}
fn inputs(&self) -> Vec<&Tensor<T>> {
vec![&self.input]
}
fn name(&self) -> &'static str {
"SumBackward"
}
}
let loss = Tensor::from_operation(
TensorStorage::cpu(vec![total]),
vec![],
Arc::new(SumBackward {
input: output.clone(),
}),
)
.unwrap();
backward(&loss).unwrap();
let grad = emb.weight.tensor().grad().unwrap().unwrap();
let gd = grad.data().unwrap();
assert_eq!(gd.len(), 6);
assert!((gd[0] - 1.0).abs() < 1e-6, "grad[0][0] = {}", gd[0]);
assert!((gd[1] - 1.0).abs() < 1e-6, "grad[0][1] = {}", gd[1]);
assert!((gd[2] - 1.0).abs() < 1e-6, "grad[1][0] = {}", gd[2]);
assert!((gd[3] - 1.0).abs() < 1e-6, "grad[1][1] = {}", gd[3]);
assert!((gd[4] - 0.0).abs() < 1e-6, "grad[2][0] = {}", gd[4]);
assert!((gd[5] - 0.0).abs() < 1e-6, "grad[2][1] = {}", gd[5]);
}
#[test]
fn test_module_parameters() {
let emb = Embedding::<f32>::new(10, 4, None).unwrap();
assert_eq!(emb.parameters().len(), 1);
assert_eq!(emb.parameters()[0].shape(), &[10, 4]);
}
#[test]
fn test_module_named_parameters() {
let emb = Embedding::<f32>::new(5, 3, None).unwrap();
let named = emb.named_parameters();
assert_eq!(named.len(), 1);
assert_eq!(named[0].0, "weight");
}
// A new module starts in training mode; eval/train toggle the flag.
#[test]
fn test_module_train_eval() {
let mut emb = Embedding::<f32>::new(5, 3, None).unwrap();
assert!(emb.is_training());
emb.eval();
assert!(!emb.is_training());
emb.train();
assert!(emb.is_training());
}
// Compile-time check: both precisions can be shared across threads.
#[test]
fn test_embedding_is_send_sync() {
fn assert_send_sync<T: Send + Sync>() {}
assert_send_sync::<Embedding<f32>>();
assert_send_sync::<Embedding<f64>>();
}
#[test]
fn test_f64_embedding() {
let emb = Embedding::<f64>::new(5, 3, None).unwrap();
let indices = Tensor::from_storage(
TensorStorage::cpu(vec![0.0f64, 2.0, 4.0]),
vec![3],
false,
)
.unwrap();
let output = emb.forward(&indices).unwrap();
assert_eq!(output.shape(), &[3, 3]);
}
}