use crate::autograd::Variable;
use crate::tensor::{Result, Tensor, TensorOptions, DType, Device};
use super::parameter::Parameter;
use super::Module;
/// A lookup table that maps integer indices to dense, learnable vectors.
///
/// Row `i` of [`Embedding::weight`] is the vector returned for index `i`.
/// The constructors in this file allocate it with shape
/// `[num_embeddings, embedding_dim]`.
pub struct Embedding {
    /// The learnable table, `[num_embeddings, embedding_dim]` as created by the constructors.
    pub weight: Parameter,
    /// Number of rows in the table.
    // Recorded at construction but never read by the lookup path, so the
    // dead-code lint is silenced rather than dropping the metadata.
    #[allow(dead_code)]
    num_embeddings: i64,
    /// Width of every embedding vector; appended as the trailing output dimension.
    embedding_dim: i64,
}
impl Embedding {
    /// Builds an embedding table with `num_embeddings` rows of width
    /// `embedding_dim`, placed on the CPU.
    ///
    /// # Errors
    /// Propagates any error returned by [`Embedding::on_device`].
    pub fn new(num_embeddings: i64, embedding_dim: i64) -> Result<Self> {
        let cpu = Device::CPU;
        Self::on_device(num_embeddings, embedding_dim, cpu)
    }

    /// Builds an embedding table on `device`.
    ///
    /// The weight is a `Float32` tensor of shape `[num_embeddings, embedding_dim]`
    /// filled by `Tensor::randn`. It is wrapped in a `Variable` created with the
    /// flag `true` (presumably "requires grad" — confirm against `Variable::new`)
    /// and registered as a parameter named `"weight"`.
    ///
    /// # Errors
    /// Returns any error produced by `Tensor::randn` while allocating the table.
    pub fn on_device(num_embeddings: i64, embedding_dim: i64, device: Device) -> Result<Self> {
        let options = TensorOptions { dtype: DType::Float32, device };
        let shape = [num_embeddings, embedding_dim];

        let table = Tensor::randn(&shape, options)?;
        let variable = Variable::new(table, true);
        let weight = Parameter { variable, name: "weight".into() };

        Ok(Self { weight, num_embeddings, embedding_dim })
    }
}
impl Module for Embedding {
    fn name(&self) -> &str {
        "embedding"
    }

    /// Gathers one row of the weight table for every element of `input`.
    ///
    /// The input is flattened to a 1-D list of `numel` indices:
    /// - an `Int64` tensor is reshaped in place;
    /// - any other dtype is read back with `to_f32_vec` and cast element-wise.
    ///
    /// The gathered rows are reshaped to `input.shape()` with `embedding_dim`
    /// appended as the last dimension.
    ///
    /// # Errors
    /// Propagates errors from reshaping, the host read-back, building the index
    /// tensor, or `index_select`. Presumably out-of-range indices surface there;
    /// confirm in `index_select`.
    fn forward(&self, input: &Variable) -> Result<Variable> {
        let mut out_shape = input.shape();
        let count = input.numel();

        let flat_indices = if input.data().dtype() != DType::Int64 {
            // `as i64` on a float truncates toward zero, saturates at the i64
            // bounds, and maps NaN to 0 — no validation happens here.
            let as_i64: Vec<i64> = input
                .data()
                .to_f32_vec()?
                .iter()
                .map(|&value| value as i64)
                .collect();
            Tensor::from_i64(&as_i64, &[count], input.device())?
        } else {
            input.data().reshape(&[count])?
        };

        let rows = self.weight.variable.index_select(0, &flat_indices)?;
        out_shape.push(self.embedding_dim);
        rows.reshape(&out_shape)
    }

    fn parameters(&self) -> Vec<Parameter> {
        vec![self.weight.clone()]
    }
}