use super::module::Module;
use crate::nn::init::Initializer;
use crate::tensor::{GraphContext, Tensor};
use std::cell::RefCell;
use std::rc::Rc;
/// Embedding layer: one weight tensor together with the dimensions it was declared with.
#[derive(Debug, Clone)]
pub struct Embedding {
    /// First dimension of the weight shape requested in [`Embedding::new`].
    pub num_embeddings: usize,
    /// Second dimension of the weight shape requested in [`Embedding::new`].
    pub embedding_dim: usize,
    /// Weight tensor handed to `Tensor::embedding` by `forward` and returned by `parameters`.
    pub weight: Tensor,
}
impl Embedding {
pub fn new(
context: &Rc<RefCell<GraphContext>>,
num_embeddings: usize,
embedding_dim: usize,
name: &str,
) -> Self {
let weight = Tensor::new_parameter_with_shape(
context,
&format!("{}_weight", name),
vec![num_embeddings, embedding_dim],
Initializer::Normal {
mean: 0.0,
std: 0.02,
},
);
Self {
num_embeddings,
embedding_dim,
weight,
}
}
pub fn from_weight(weight: Tensor, num_embeddings: usize, embedding_dim: usize) -> Self {
Self {
num_embeddings,
embedding_dim,
weight,
}
}
}
impl Module for Embedding {
    /// Applies the layer by calling `embedding` on `input` with this layer's weight tensor.
    ///
    /// The unit tests in this file exercise it as a row lookup: an index `i` in the input
    /// selects row `i` of the weight table.
    fn forward(&self, input: &Tensor) -> Tensor {
        let table = &self.weight;
        input.embedding(table)
    }

    /// Returns the layer's parameters: a one-element vector holding a clone of `weight`.
    fn parameters(&self) -> Vec<Tensor> {
        Vec::from([self.weight.clone()])
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::asg::{NodeType, Value};
    use crate::runtime::{backend::Backend, cpu_backend::CpuBackend};
    use ndarray::{arr2, ArrayD, IxDyn};
    use std::collections::HashMap;

    /// Absolute tolerance used when comparing looked-up values with the expected ones.
    const TOLERANCE: f32 = 1e-5;

    /// Builds an `Embedding` graph sized after `weight_data`, runs it on the CPU backend with
    /// `indices_data` as the `"indices"` input, and returns the first output tensor.
    ///
    /// `weight_data` must be two-dimensional (`[num_embeddings, embedding_dim]`); it is bound
    /// to the layer's `"emb_weight"` parameter.
    ///
    /// # Panics
    ///
    /// Panics if loading the data or running the graph fails, or if the first output is not
    /// a `Value::Tensor`.
    fn run_embedding(weight_data: ArrayD<f32>, indices_data: ArrayD<f32>) -> ArrayD<f32> {
        let num_embeddings = weight_data.shape()[0];
        let embedding_dim = weight_data.shape()[1];

        let context = Rc::new(RefCell::new(GraphContext::new()));
        let embedding = Embedding::new(&context, num_embeddings, embedding_dim, "emb");
        let indices = Tensor::new_input(&context, "indices");
        let output = embedding.forward(&indices);
        context
            .borrow_mut()
            .main_graph_mut()
            .set_output(output.node_id);

        // The graph is complete at this point, so one snapshot serves both the
        // name -> node id lookup below and the backend run.
        let graph = context.borrow().main_graph().clone();

        let mut inputs = HashMap::new();
        inputs.insert("indices".to_string(), Value::Tensor(indices_data));
        inputs.insert("emb_weight".to_string(), Value::Tensor(weight_data));

        let backend = CpuBackend::new();
        let device_data = backend.load_data(&inputs).unwrap();

        // Seed the evaluation memo: every loaded value is attached to the input or
        // parameter node carrying the same name. Names without a matching node are skipped.
        let mut memo = HashMap::new();
        for (name, value) in device_data {
            let node_id = graph
                .nodes
                .iter()
                .find(|(_, node)| {
                    matches!(
                        &node.node_type,
                        NodeType::Input { name: n } | NodeType::Parameter { name: n }
                            if n == &name
                    )
                })
                .map(|(id, _)| *id);
            if let Some(id) = node_id {
                memo.insert((0, id), value);
            }
        }

        let (results, _) = backend.run(&graph, memo).unwrap();
        match &results[0] {
            Value::Tensor(arr) => arr.clone(),
            _ => panic!("Expected tensor output"),
        }
    }

    #[test]
    fn test_embedding_creation() {
        let context = Rc::new(RefCell::new(GraphContext::new()));
        let embedding = Embedding::new(&context, 1000, 64, "emb");
        assert_eq!(embedding.num_embeddings, 1000);
        assert_eq!(embedding.embedding_dim, 64);
        assert_eq!(embedding.parameters().len(), 1);
    }

    #[test]
    fn test_embedding_forward() {
        let weight_data = arr2(&[
            [1.0, 2.0, 3.0],
            [4.0, 5.0, 6.0],
            [7.0, 8.0, 9.0],
            [10.0, 11.0, 12.0],
            [13.0, 14.0, 15.0],
        ])
        .into_dyn();
        let indices_data = ArrayD::from_shape_vec(IxDyn(&[2]), vec![0.0, 2.0]).unwrap();

        let arr = run_embedding(weight_data, indices_data);

        assert_eq!(arr.shape(), &[2, 3]);
        // Indices 0 and 2 must select the first and third rows of the table.
        let expected: [[f32; 3]; 2] = [[1.0, 2.0, 3.0], [7.0, 8.0, 9.0]];
        for (i, row) in expected.iter().enumerate() {
            for (j, want) in row.iter().enumerate() {
                let got = arr[[i, j]];
                assert!(
                    (got - *want).abs() < TOLERANCE,
                    "output[{}, {}] = {}, expected {}",
                    i,
                    j,
                    got,
                    want
                );
            }
        }
    }

    #[test]
    fn test_embedding_2d_indices() {
        let weight_data =
            ArrayD::from_shape_vec(IxDyn(&[5, 4]), (0..20).map(|x| x as f32).collect()).unwrap();
        let indices_data =
            ArrayD::from_shape_vec(IxDyn(&[2, 3]), vec![0.0, 1.0, 2.0, 3.0, 4.0, 0.0]).unwrap();

        let arr = run_embedding(weight_data, indices_data);

        assert_eq!(arr.shape(), &[2, 3, 4]);
        // Row `r` of the table holds 4r, 4r + 1, 4r + 2, 4r + 3, so every output vector can
        // be checked against the row its index refers to.
        let looked_up_rows: [[usize; 3]; 2] = [[0, 1, 2], [3, 4, 0]];
        for (i, batch) in looked_up_rows.iter().enumerate() {
            for (j, &row) in batch.iter().enumerate() {
                for k in 0..4 {
                    let want = (row * 4 + k) as f32;
                    let got = arr[[i, j, k]];
                    assert!(
                        (got - want).abs() < TOLERANCE,
                        "output[{}, {}, {}] = {}, expected {}",
                        i,
                        j,
                        k,
                        got,
                        want
                    );
                }
            }
        }
    }
}