use crate as burn;

use super::Initializer;
use crate::config::Config;
use crate::module::Module;
use crate::module::Param;
use crate::module::{Content, DisplaySettings, ModuleDisplay};
use crate::tensor::backend::Backend;
use crate::tensor::Int;
use crate::tensor::Tensor;

use crate::tensor::module::embedding;

/// Configuration to create an [Embedding](Embedding) layer using the [init function](EmbeddingConfig::init).
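///
/// # Example
///
/// A minimal usage sketch; `MyBackend` is a placeholder for any backend
/// implementation and `device` for its device.
///
/// ```rust,ignore
/// let device = Default::default();
/// // 1000 embedding vectors of size 64, using the default normal initializer.
/// let embedding = EmbeddingConfig::new(1000, 64).init::<MyBackend>(&device);
/// ```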
#[derive(Config)]
pub struct EmbeddingConfig {
    /// The number of embedding vectors.
    pub n_embedding: usize,
    /// The size of each vector.
    pub d_model: usize,
    /// The type of function used to initialize neural network parameters.
    #[config(default = "Initializer::Normal{mean:0.0, std:1.0}")]
    pub initializer: Initializer,
}

/// Lookup table to store a fixed number of vectors.
///
/// Should be created with [EmbeddingConfig].
#[derive(Module, Debug)]
#[module(custom_display)]
pub struct Embedding<B: Backend> {
    /// The learnable weights of the module of shape `[n_embedding, d_model]`, initialized
    /// from a normal distribution `N(0, 1)` by default.
    pub weight: Param<Tensor<B, 2>>,
}

impl<B: Backend> ModuleDisplay for Embedding<B> {
    fn custom_settings(&self) -> Option<DisplaySettings> {
        DisplaySettings::new()
            .with_new_line_after_attribute(false)
            .optional()
    }

    fn custom_content(&self, content: Content) -> Option<Content> {
        let [n_embedding, d_model] = self.weight.shape().dims;
        content
            .add("n_embedding", &n_embedding)
            .add("d_model", &d_model)
            .optional()
    }
}

impl EmbeddingConfig {
    /// Initialize a new [embedding](Embedding) module.
    pub fn init<B: Backend>(&self, device: &B::Device) -> Embedding<B> {
        let weight = self
            .initializer
            .init([self.n_embedding, self.d_model], device);

        Embedding { weight }
    }
}

impl<B: Backend> Embedding<B> {
    /// Applies the forward pass on the input tensor.
    ///
    /// See also [embedding](crate::tensor::module::embedding).
    ///
    /// # Shapes
    ///
    /// - input: `[batch_size, seq_length]`
    /// - output: `[batch_size, seq_length, d_model]`
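    ///
    /// # Example
    ///
    /// A rough sketch; `MyBackend`, `device`, and `embedding` are placeholders
    /// for an actual backend, its device, and an initialized module.
    ///
    /// ```rust,ignore
    /// // Token indices must be smaller than `n_embedding`; shape `[batch_size = 2, seq_length = 3]`.
    /// let input = Tensor::<MyBackend, 2, Int>::from_ints([[0, 1, 2], [3, 4, 5]], &device);
    /// // Output shape: `[2, 3, d_model]`.
    /// let output = embedding.forward(input);
    /// ```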
    pub fn forward(&self, input: Tensor<B, 2, Int>) -> Tensor<B, 3> {
        embedding(self.weight.val(), input)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::tensor::TensorData;
    use crate::TestBackend;

    #[test]
    fn initializer_default() {
        TestBackend::seed(0);

        let config = EmbeddingConfig::new(100, 10);
        let embed = config.init::<TestBackend>(&Default::default());
        let weights = embed.weight.val().reshape([1000]);
        let (var_act, mean_act) = weights.var_mean(0);

        assert_eq!(
            config.initializer,
            Initializer::Normal {
                mean: 0.0,
                std: 1.0
            }
        );
        var_act
            .to_data()
            .assert_approx_eq(&TensorData::from([1.0f32]), 0);
        mean_act
            .to_data()
            .assert_approx_eq(&TensorData::from([0.0f32]), 0);
    }

    #[test]
    fn initializer_zeros() {
        TestBackend::seed(0);

        let config = EmbeddingConfig::new(5, 5).with_initializer(Initializer::Zeros);
        let embed = config.init::<TestBackend>(&Default::default());

        assert_eq!(config.initializer, Initializer::Zeros);
        embed
            .weight
            .to_data()
            .assert_approx_eq(&TensorData::zeros::<f32, _>(embed.weight.shape()), 3);
    }

    #[test]
    fn display() {
        let config = EmbeddingConfig::new(100, 10);
        let embed = config.init::<TestBackend>(&Default::default());

        assert_eq!(
            alloc::format!("{}", embed),
            "Embedding {n_embedding: 100, d_model: 10, params: 1000}"
        );
    }
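
    // Illustrative sketch (not part of the original test suite): checks that `forward`
    // maps `[batch_size, seq_length]` token indices to `[batch_size, seq_length, d_model]`.
    #[test]
    fn forward_shape() {
        use crate::tensor::{Int, Tensor};

        let config = EmbeddingConfig::new(100, 10);
        let embed = config.init::<TestBackend>(&Default::default());

        let input =
            Tensor::<TestBackend, 2, Int>::from_ints([[0, 1, 2], [3, 4, 5]], &Default::default());
        let output = embed.forward(input);

        assert_eq!(output.dims(), [2, 3, 10]);
    }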
}