// burn_core/nn/embedding.rs
use crate as burn;

use super::Initializer;
use crate::config::Config;
use crate::module::Module;
use crate::module::Param;
use crate::module::{Content, DisplaySettings, ModuleDisplay};
use crate::tensor::backend::Backend;
use crate::tensor::Int;
use crate::tensor::Tensor;

use crate::tensor::module::embedding;

14#[derive(Config)]
16pub struct EmbeddingConfig {
17 pub n_embedding: usize,
19 pub d_model: usize,
21 #[config(default = "Initializer::Normal{mean:0.0, std:1.0}")]
23 pub initializer: Initializer,
24}
25
26#[derive(Module, Debug)]
30#[module(custom_display)]
31pub struct Embedding<B: Backend> {
32 pub weight: Param<Tensor<B, 2>>,
35}
36
37impl<B: Backend> ModuleDisplay for Embedding<B> {
38 fn custom_settings(&self) -> Option<DisplaySettings> {
39 DisplaySettings::new()
40 .with_new_line_after_attribute(false)
41 .optional()
42 }
43
44 fn custom_content(&self, content: Content) -> Option<Content> {
45 let [n_embedding, d_model] = self.weight.shape().dims();
46 content
47 .add("n_embedding", &n_embedding)
48 .add("d_model", &d_model)
49 .optional()
50 }
51}
52
53impl EmbeddingConfig {
54 pub fn init<B: Backend>(&self, device: &B::Device) -> Embedding<B> {
56 let weight = self
57 .initializer
58 .init([self.n_embedding, self.d_model], device);
59
60 Embedding { weight }
61 }
62}
63
64impl<B: Backend> Embedding<B> {
65 pub fn forward(&self, input: Tensor<B, 2, Int>) -> Tensor<B, 3> {
74 embedding(self.weight.val(), input)
75 }
76}
77
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tensor::TensorData;
    use crate::TestBackend;

    /// The default initializer is `Normal { mean: 0.0, std: 1.0 }`, and the sampled
    /// table roughly matches those statistics.
    #[test]
    fn initializer_default() {
        TestBackend::seed(0);

        let config = EmbeddingConfig::new(100, 10);
        let embed = config.init::<TestBackend>(&Default::default());
        // Flatten the whole table (size derived from the config, not hard-coded)
        // so the statistics cover every weight.
        let weights = embed
            .weight
            .val()
            .reshape([config.n_embedding * config.d_model]);
        let (var_act, mean_act) = weights.var_mean(0);

        assert_eq!(
            config.initializer,
            Initializer::Normal {
                mean: 0.0,
                std: 1.0
            }
        );
        var_act
            .to_data()
            .assert_approx_eq(&TensorData::from([1.0f32]), 0);
        mean_act
            .to_data()
            .assert_approx_eq(&TensorData::from([0.0f32]), 0);
    }

    /// An explicit `Zeros` initializer produces an all-zero table.
    #[test]
    fn initializer_zeros() {
        TestBackend::seed(0);

        let config = EmbeddingConfig::new(5, 5).with_initializer(Initializer::Zeros);
        let embed = config.init::<TestBackend>(&Default::default());

        assert_eq!(config.initializer, Initializer::Zeros);
        embed
            .weight
            .to_data()
            .assert_approx_eq(&TensorData::zeros::<f32, _>(embed.weight.shape()), 3);
    }

    /// `forward` maps a `[batch_size, seq_length]` index tensor to
    /// `[batch_size, seq_length, d_model]`.
    #[test]
    fn forward_output_shape() {
        let device = Default::default();
        let embed = EmbeddingConfig::new(5, 4).init::<TestBackend>(&device);
        let input = Tensor::<TestBackend, 2, Int>::from_ints([[0, 1, 2], [3, 4, 0]], &device);

        let output = embed.forward(input);

        assert_eq!(output.dims(), [2, 3, 4]);
    }

    /// The module renders on a single line with its dimensions and parameter count.
    #[test]
    fn display() {
        let config = EmbeddingConfig::new(100, 10);
        let embed = config.init::<TestBackend>(&Default::default());

        assert_eq!(
            alloc::format!("{}", embed),
            "Embedding {n_embedding: 100, d_model: 10, params: 1000}"
        );
    }
}