use std::any::Any;
use std::collections::HashMap;
use axonml_autograd::{GradFn, GradientFunction, Variable};
use axonml_tensor::Tensor;
use crate::init::normal;
use crate::module::Module;
use crate::parameter::Parameter;
/// Backward node for an embedding lookup: scatter-adds the incoming output
/// gradient rows into a dense gradient over the full weight matrix.
#[derive(Debug)]
struct EmbeddingBackward {
    // Upstream gradient functions; a single slot holding the weight's grad_fn.
    next_fns: Vec<Option<GradFn>>,
    // Row indices used in the forward lookup (already clamped to range there).
    indices: Vec<usize>,
    // Number of rows in the weight matrix.
    num_embeddings: usize,
    // Width of each embedding row.
    embedding_dim: usize,
}
impl GradientFunction for EmbeddingBackward {
    /// Scatter-adds each row of `grad_output` into the weight-gradient row
    /// selected by the corresponding forward-pass index, producing a dense
    /// `[num_embeddings, embedding_dim]` gradient tensor.
    fn apply(&self, grad_output: &Tensor<f32>) -> Vec<Option<Tensor<f32>>> {
        // GPU fast path: delegate the scatter-add to the CUDA kernel.
        #[cfg(feature = "cuda")]
        if grad_output.device().is_gpu() {
            let indices_u32: Vec<u32> = self.indices.iter().map(|&i| i as u32).collect();
            let grad_tensor = grad_output.embedding_scatter_add_cuda(
                &indices_u32,
                self.num_embeddings,
                self.embedding_dim,
            );
            return vec![Some(grad_tensor)];
        }
        // CPU path: accumulate row-by-row into a flat dense buffer.
        let grad_data = grad_output.to_vec();
        let dim = self.embedding_dim;
        let mut weight_grad = vec![0.0f32; self.num_embeddings * dim];
        for (row, &idx) in self.indices.iter().enumerate() {
            // Indices were clamped in the forward pass; guard defensively anyway.
            if idx >= self.num_embeddings {
                continue;
            }
            let src = row * dim;
            let dst = idx * dim;
            for d in 0..dim {
                weight_grad[dst + d] += grad_data[src + d];
            }
        }
        let grad_tensor = Tensor::from_vec(weight_grad, &[self.num_embeddings, self.embedding_dim])
            .expect("tensor creation failed");
        vec![Some(grad_tensor)]
    }

    fn name(&self) -> &'static str {
        "EmbeddingBackward"
    }

    fn next_functions(&self) -> &[Option<GradFn>] {
        &self.next_fns
    }

    fn as_any(&self) -> &dyn Any {
        self
    }
}
/// A lookup table mapping integer indices to dense vectors: row `i` of
/// `weight` is the embedding for index `i`.
pub struct Embedding {
    /// The `[num_embeddings, embedding_dim]` weight matrix.
    pub weight: Parameter,
    num_embeddings: usize,
    embedding_dim: usize,
    /// Row zero-initialized at construction, if any. NOTE(review): only the
    /// initial value is zeroed; the backward pass does not suppress gradients
    /// for this row (EmbeddingBackward carries no padding information).
    padding_idx: Option<usize>,
}
impl Embedding {
    /// Creates an embedding table of `num_embeddings` rows of width
    /// `embedding_dim`, initialized from a standard normal distribution.
    pub fn new(num_embeddings: usize, embedding_dim: usize) -> Self {
        Self::with_options(num_embeddings, embedding_dim, None)
    }

    /// Creates an embedding table; if `padding_idx` is given, that row is
    /// zero-initialized (commonly used for padding tokens). Note that only
    /// the initial value is zeroed — gradients for that row are not
    /// suppressed during training.
    ///
    /// # Panics
    /// Panics if `padding_idx` is out of range for `num_embeddings`.
    pub fn with_options(
        num_embeddings: usize,
        embedding_dim: usize,
        padding_idx: Option<usize>,
    ) -> Self {
        let mut weight_data = normal(&[num_embeddings, embedding_dim], 0.0, 1.0);
        if let Some(pad_idx) = padding_idx {
            // Validate up front so the failure is a clear message rather than
            // an opaque out-of-bounds panic inside the zeroing loop below.
            assert!(
                pad_idx < num_embeddings,
                "padding_idx {} out of range for {} embeddings",
                pad_idx,
                num_embeddings
            );
            let mut data = weight_data.to_vec();
            for i in 0..embedding_dim {
                data[pad_idx * embedding_dim + i] = 0.0;
            }
            weight_data = Tensor::from_vec(data, &[num_embeddings, embedding_dim])
                .expect("tensor creation failed");
        }
        Self {
            weight: Parameter::named("weight", weight_data, true),
            num_embeddings,
            embedding_dim,
            padding_idx,
        }
    }

    /// Builds an embedding layer from an existing `[num_embeddings,
    /// embedding_dim]` weight tensor. If `freeze` is true, the weights are
    /// excluded from gradient updates.
    ///
    /// # Panics
    /// Panics if `weights` is not 2-dimensional.
    pub fn from_pretrained(weights: Tensor<f32>, freeze: bool) -> Self {
        let shape = weights.shape();
        // Fail with a clear message instead of an opaque index panic on
        // `shape[1]` when a non-matrix tensor is passed.
        assert!(
            shape.len() == 2,
            "from_pretrained expects a 2-D weight tensor, got shape {:?}",
            shape
        );
        let num_embeddings = shape[0];
        let embedding_dim = shape[1];
        Self {
            weight: Parameter::named("weight", weights, !freeze),
            num_embeddings,
            embedding_dim,
            padding_idx: None,
        }
    }

    /// Number of rows in the embedding table.
    pub fn num_embeddings(&self) -> usize {
        self.num_embeddings
    }

    /// Width of each embedding vector.
    pub fn embedding_dim(&self) -> usize {
        self.embedding_dim
    }

    /// Gathers embedding rows for the given float-encoded indices.
    ///
    /// The output shape is the index shape with `embedding_dim` appended.
    /// Out-of-range indices are clamped to row 0 (with a warning in debug
    /// builds) rather than panicking; negative float values also saturate
    /// to 0 under the `as usize` cast.
    pub fn lookup(&self, indices: &Variable) -> Variable {
        let indices_data = indices.data();
        let indices_vec = indices_data.to_vec();
        let indices_shape = indices_data.shape().to_vec();
        // Output shape = index shape + [embedding_dim].
        let mut output_shape = indices_shape.clone();
        output_shape.push(self.embedding_dim);
        let output_size: usize = output_shape.iter().product();
        // Clamp out-of-range indices, and expand each row index into the
        // flat element offsets of that row inside the weight matrix.
        let mut safe_indices = Vec::with_capacity(indices_vec.len());
        let mut gather_idx = Vec::with_capacity(output_size);
        for &idx_f in &indices_vec {
            let idx = idx_f as usize;
            let safe_idx = if idx >= self.num_embeddings {
                #[cfg(debug_assertions)]
                eprintln!(
                    "Warning: embedding index {} out of range (max {}), clamping to index 0",
                    idx,
                    // saturating_sub avoids underflow for an empty table.
                    self.num_embeddings.saturating_sub(1)
                );
                0
            } else {
                idx
            };
            safe_indices.push(safe_idx);
            let base = safe_idx * self.embedding_dim;
            for d in 0..self.embedding_dim {
                gather_idx.push((base + d) as u32);
            }
        }
        let weight_data = self.weight.data();
        // GPU fast path: gather directly on device when the weights live there.
        #[cfg(feature = "cuda")]
        let weight_device = weight_data.device();
        #[cfg(feature = "cuda")]
        let output_tensor = if weight_device.is_gpu() {
            weight_data.embedding_gather_cuda(&gather_idx, &output_shape)
        } else {
            let weight_vec = weight_data.to_vec();
            let output_data: Vec<f32> =
                gather_idx.iter().map(|&i| weight_vec[i as usize]).collect();
            Tensor::from_vec(output_data, &output_shape).expect("tensor creation failed")
        };
        #[cfg(not(feature = "cuda"))]
        let output_tensor = {
            let weight_vec = weight_data.to_vec();
            let output_data: Vec<f32> =
                gather_idx.iter().map(|&i| weight_vec[i as usize]).collect();
            Tensor::from_vec(output_data, &output_shape).expect("tensor creation failed")
        };
        // Attach a backward node only when the weight actually needs gradients.
        if self.weight.requires_grad() {
            let grad_fn = GradFn::new(EmbeddingBackward {
                next_fns: vec![self.weight.variable().grad_fn().cloned()],
                indices: safe_indices,
                num_embeddings: self.num_embeddings,
                embedding_dim: self.embedding_dim,
            });
            Variable::from_operation(output_tensor, grad_fn, true)
        } else {
            Variable::new(output_tensor, false)
        }
    }
}
impl Module for Embedding {
    /// Forward pass: delegates to [`Embedding::lookup`].
    fn forward(&self, input: &Variable) -> Variable {
        self.lookup(input)
    }

    /// The single trainable parameter: the weight matrix.
    fn parameters(&self) -> Vec<Parameter> {
        vec![self.weight.clone()]
    }

    /// The weight matrix keyed by its canonical name.
    fn named_parameters(&self) -> HashMap<String, Parameter> {
        std::iter::once(("weight".to_string(), self.weight.clone())).collect()
    }

    fn name(&self) -> &'static str {
        "Embedding"
    }
}
impl std::fmt::Debug for Embedding {
    /// Manual Debug impl: reports only the layer configuration, deliberately
    /// omitting the (potentially huge) weight tensor.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("Embedding");
        builder.field("num_embeddings", &self.num_embeddings);
        builder.field("embedding_dim", &self.embedding_dim);
        builder.field("padding_idx", &self.padding_idx);
        builder.finish()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_embedding_creation() {
        let layer = Embedding::new(1000, 128);
        assert_eq!(layer.num_embeddings(), 1000);
        assert_eq!(layer.embedding_dim(), 128);
    }

    #[test]
    fn test_embedding_lookup() {
        // A 1-D index tensor of length 3 gathers a [3, 4] output.
        let layer = Embedding::new(10, 4);
        let ids = Tensor::from_vec(vec![0.0, 1.0, 2.0], &[3]).expect("tensor creation failed");
        let out = layer.forward(&Variable::new(ids, false));
        assert_eq!(out.shape(), vec![3, 4]);
    }

    #[test]
    fn test_embedding_batch() {
        // A [2, 3] index tensor gathers a [2, 3, 4] output.
        let layer = Embedding::new(10, 4);
        let ids = Tensor::from_vec(vec![0.0, 1.0, 2.0, 3.0, 4.0, 5.0], &[2, 3])
            .expect("tensor creation failed");
        let out = layer.forward(&Variable::new(ids, false));
        assert_eq!(out.shape(), vec![2, 3, 4]);
    }

    #[test]
    fn test_embedding_parameters() {
        let layer = Embedding::new(100, 64);
        assert_eq!(layer.parameters().len(), 1);
        assert_eq!(layer.num_parameters(), 100 * 64);
    }

    #[test]
    fn test_embedding_with_padding() {
        // Looking up the padding row must yield an all-zero vector.
        let layer = Embedding::with_options(10, 4, Some(0));
        let ids = Tensor::from_vec(vec![0.0], &[1]).expect("tensor creation failed");
        let out = layer.forward(&Variable::new(ids, false));
        assert!(out.data().to_vec().iter().all(|&x| x == 0.0));
    }
}