1use candle::{Result, Tensor};
3
4#[derive(Clone, Debug)]
5pub struct Embedding {
6 embeddings: Tensor,
7 hidden_size: usize,
8}
9
10impl Embedding {
11 pub fn new(embeddings: Tensor, hidden_size: usize) -> Self {
12 Self {
13 embeddings,
14 hidden_size,
15 }
16 }
17
18 pub fn embeddings(&self) -> &Tensor {
19 &self.embeddings
20 }
21
22 pub fn hidden_size(&self) -> usize {
24 self.hidden_size
25 }
26}
27
28impl crate::Module for Embedding {
29 fn forward(&self, indexes: &Tensor) -> Result<Tensor> {
30 let mut final_dims = indexes.dims().to_vec();
31 final_dims.push(self.hidden_size);
32 let indexes = indexes.flatten_all()?;
33 let values = self.embeddings.index_select(&indexes, 0)?;
34 let values = values.reshape(final_dims)?;
35 Ok(values)
36 }
37}
38
39pub fn embedding(in_size: usize, out_size: usize, vb: crate::VarBuilder) -> Result<Embedding> {
40 let embeddings = vb.get_with_hints(
41 (in_size, out_size),
42 "weight",
43 crate::Init::Randn {
44 mean: 0.,
45 stdev: 1.,
46 },
47 )?;
48 Ok(Embedding::new(embeddings, out_size))
49}