candle_nn/
embedding.rs

1//! Embedding Layer.
2use candle::{Result, Tensor};
3
4#[derive(Clone, Debug)]
5pub struct Embedding {
6    embeddings: Tensor,
7    hidden_size: usize,
8}
9
10impl Embedding {
11    pub fn new(embeddings: Tensor, hidden_size: usize) -> Self {
12        Self {
13            embeddings,
14            hidden_size,
15        }
16    }
17
18    pub fn embeddings(&self) -> &Tensor {
19        &self.embeddings
20    }
21
22    /// Get the hidden size of the embedding matrix
23    pub fn hidden_size(&self) -> usize {
24        self.hidden_size
25    }
26}
27
28impl crate::Module for Embedding {
29    fn forward(&self, indexes: &Tensor) -> Result<Tensor> {
30        let mut final_dims = indexes.dims().to_vec();
31        final_dims.push(self.hidden_size);
32        let indexes = indexes.flatten_all()?;
33        let values = self.embeddings.index_select(&indexes, 0)?;
34        let values = values.reshape(final_dims)?;
35        Ok(values)
36    }
37}
38
39pub fn embedding(in_size: usize, out_size: usize, vb: crate::VarBuilder) -> Result<Embedding> {
40    let embeddings = vb.get_with_hints(
41        (in_size, out_size),
42        "weight",
43        crate::Init::Randn {
44            mean: 0.,
45            stdev: 1.,
46        },
47    )?;
48    Ok(Embedding::new(embeddings, out_size))
49}