burn_core/nn/embedding.rs

1use crate as burn;
2
3use super::Initializer;
4use crate::config::Config;
5use crate::module::Module;
6use crate::module::Param;
7use crate::module::{Content, DisplaySettings, ModuleDisplay};
8use crate::tensor::Int;
9use crate::tensor::Tensor;
10use crate::tensor::backend::Backend;
11
12use crate::tensor::module::embedding;
13
14/// Configuration to create an [Embedding](Embedding) layer using the [init function](EmbeddingConfig::init).
15#[derive(Config)]
16pub struct EmbeddingConfig {
17    /// The number of embedding vectors.
18    pub n_embedding: usize,
19    /// The size of each vector.
20    pub d_model: usize,
21    /// The type of function used to initialize neural network parameters
22    #[config(default = "Initializer::Normal{mean:0.0, std:1.0}")]
23    pub initializer: Initializer,
24}
25
26/// Lookup table to store a fix number of vectors.
27///
28/// Should be created with [EmbeddingConfig].
29#[derive(Module, Debug)]
30#[module(custom_display)]
31pub struct Embedding<B: Backend> {
32    /// The learnable weights of the module of shape `[n_embedding, d_model]` initialized
33    /// from a normal distribution `N(0, 1)`.
34    pub weight: Param<Tensor<B, 2>>,
35}
36
37impl<B: Backend> ModuleDisplay for Embedding<B> {
38    fn custom_settings(&self) -> Option<DisplaySettings> {
39        DisplaySettings::new()
40            .with_new_line_after_attribute(false)
41            .optional()
42    }
43
44    fn custom_content(&self, content: Content) -> Option<Content> {
45        let [n_embedding, d_model] = self.weight.shape().dims();
46        content
47            .add("n_embedding", &n_embedding)
48            .add("d_model", &d_model)
49            .optional()
50    }
51}
52
53impl EmbeddingConfig {
54    /// Initialize a new [embedding](Embedding) module.
55    pub fn init<B: Backend>(&self, device: &B::Device) -> Embedding<B> {
56        let weight = self
57            .initializer
58            .init([self.n_embedding, self.d_model], device);
59
60        Embedding { weight }
61    }
62}
63
64impl<B: Backend> Embedding<B> {
65    /// Applies the forward pass on the input tensor.
66    ///
67    /// See also [embedding](crate::tensor::module::embedding).
68    ///
69    /// # Shapes
70    ///
71    /// - input: `[batch_size, seq_length]`
72    /// - output: `[batch_size, seq_length, d_model]`
73    pub fn forward(&self, input: Tensor<B, 2, Int>) -> Tensor<B, 3> {
74        embedding(self.weight.val(), input)
75    }
76}
77
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestBackend;
    use crate::tensor::TensorData;
    use burn_tensor::{Tolerance, ops::FloatElem};

    /// Float element type of the backend used by these tests.
    type FT = FloatElem<TestBackend>;

    /// The default initializer is `N(0, 1)`, and sampled weights roughly follow it.
    #[test]
    fn initializer_default() {
        TestBackend::seed(0);

        let config = EmbeddingConfig::new(100, 10);
        let layer = config.init::<TestBackend>(&Default::default());
        // Flatten the 100 x 10 table so the statistics cover every weight.
        let flat = layer.weight.val().reshape([1000]);
        let (sample_var, sample_mean) = flat.var_mean(0);

        let expected_initializer = Initializer::Normal {
            mean: 0.0,
            std: 1.0,
        };
        assert_eq!(config.initializer, expected_initializer);

        sample_var
            .to_data()
            .assert_approx_eq::<FT>(&TensorData::from([1.0f32]), Tolerance::relative(5e-2));
        sample_mean
            .to_data()
            .assert_approx_eq::<FT>(&TensorData::from([0.0f32]), Tolerance::absolute(1e-1));
    }

    /// A zeros initializer yields an all-zero weight table.
    #[test]
    fn initializer_zeros() {
        TestBackend::seed(0);

        let config = EmbeddingConfig::new(5, 5).with_initializer(Initializer::Zeros);
        let layer = config.init::<TestBackend>(&Default::default());
        assert_eq!(config.initializer, Initializer::Zeros);

        let expected = TensorData::zeros::<f32, _>(layer.weight.shape());
        layer
            .weight
            .to_data()
            .assert_approx_eq::<FT>(&expected, Tolerance::default());
    }

    /// The display output lists the table dimensions and parameter count on one line.
    #[test]
    fn display() {
        let layer = EmbeddingConfig::new(100, 10).init::<TestBackend>(&Default::default());
        let rendered = alloc::format!("{}", layer);

        assert_eq!(
            rendered,
            "Embedding {n_embedding: 100, d_model: 10, params: 1000}"
        );
    }
}