burn_core/nn/linear.rs

use crate as burn;

use crate::config::Config;
use crate::module::Param;
use crate::module::{Content, DisplaySettings, Module, ModuleDisplay};
use crate::tensor::{Tensor, backend::Backend};

use super::Initializer;

/// Configuration to create a [Linear](Linear) layer using the [init function](LinearConfig::init).
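///
/// A minimal usage sketch; `MyBackend` and `device` below are placeholders rather than items
/// from this module:
///
/// ```rust,ignore
/// // Build a 256 -> 512 layer without a bias; unset fields keep their defaults.
/// let config = LinearConfig::new(256, 512).with_bias(false);
/// let linear = config.init::<MyBackend>(&device);
/// ```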
#[derive(Config, Debug)]
pub struct LinearConfig {
    /// The size of the input features.
    pub d_input: usize,
    /// The size of the output features.
    pub d_output: usize,
    /// If a bias should be applied during the linear transformation.
    #[config(default = true)]
    pub bias: bool,
    /// The type of function used to initialize neural network parameters.
    #[config(
        default = "Initializer::KaimingUniform{gain:1.0/num_traits::Float::sqrt(3.0), fan_out_only:false}"
    )]
    pub initializer: Initializer,
}

/// Applies a linear transformation to the input tensor:
///
/// `O = IW + b`
///
/// Should be created with [LinearConfig].
#[derive(Module, Debug)]
#[module(custom_display)]
pub struct Linear<B: Backend> {
    /// Matrix of shape `[d_input, d_output]` initialized from a uniform distribution:
    ///     `U(-k, k)`, where `k = sqrt(1 / d_input)`
    pub weight: Param<Tensor<B, 2>>,
    /// Vector of size `d_output` initialized from a uniform distribution:
    ///     `U(-k, k)`, where `k = sqrt(1 / d_input)`
    pub bias: Option<Param<Tensor<B, 1>>>,
}

impl LinearConfig {
    /// Initialize a new [linear](Linear) module.
    pub fn init<B: Backend>(&self, device: &B::Device) -> Linear<B> {
        let shape = [self.d_input, self.d_output];
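        // The input/output sizes are also passed to the initializer so fan-based schemes
        // (such as the default Kaiming uniform) can scale their bounds accordingly.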
        let weight =
            self.initializer
                .init_with(shape, Some(self.d_input), Some(self.d_output), device);
        let bias = if self.bias {
            Some(self.initializer.init_with(
                [self.d_output],
                Some(self.d_input),
                Some(self.d_output),
                device,
            ))
        } else {
            None
        };

        Linear { weight, bias }
    }
}

impl<B: Backend> Linear<B> {
    /// Applies the forward pass on the input tensor.
    ///
    /// # Shapes
    ///
    /// - input: `[..., d_input]`
    /// - output: `[..., d_output]`
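    ///
    /// For example, a `[batch, seq, d_input]` input yields a `[batch, seq, d_output]` output.
    /// A minimal sketch (`MyBackend` and `device` are placeholders):
    ///
    /// ```rust,ignore
    /// let linear = LinearConfig::new(64, 128).init::<MyBackend>(&device);
    /// let input = Tensor::<MyBackend, 3>::ones(Shape::new([8, 10, 64]), &device);
    /// let output = linear.forward(input); // shape: [8, 10, 128]
    /// ```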
    pub fn forward<const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
        if D == 1 {
            // Insert and remove an extra batch dimension for the batch matmul to work.
            return Self::forward::<2>(self, input.unsqueeze()).flatten(0, 1);
        }

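        // Broadcast the `[d_input, d_output]` weight (and the `[d_output]` bias, if any) to the
        // input's rank `D` by adding leading singleton dimensions before the batched matmul.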
        let weight = self.weight.val().unsqueeze();
        let bias = self.bias.as_ref().map(|b| b.val().unsqueeze());
        let output = input.matmul(weight);

        match bias {
            Some(bias) => output + bias,
            None => output,
        }
    }
}

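// Custom display: formatting a `Linear` prints e.g.
// "Linear {d_input: 3, d_output: 5, bias: true, params: 20}" (see the `display` test below).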
impl<B: Backend> ModuleDisplay for Linear<B> {
    fn custom_settings(&self) -> Option<DisplaySettings> {
        DisplaySettings::new()
            .with_new_line_after_attribute(false)
            .optional()
    }

    fn custom_content(&self, content: Content) -> Option<Content> {
        let [d_input, d_output] = self.weight.shape().dims();
        content
            .add("d_input", &d_input)
            .add("d_output", &d_output)
            .add("bias", &self.bias.is_some())
            .optional()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestBackend;
    use crate::tensor::{Shape, TensorData};
    use burn_tensor::{Tolerance, ops::FloatElem};
    type FT = FloatElem<TestBackend>;

    #[test]
    fn initializer_default() {
        TestBackend::seed(0);

        let config = LinearConfig::new(5, 5);
        let k = (1.0 / config.d_input as f64).sqrt() as f32;
        let device = Default::default();
        let linear = config.init::<TestBackend>(&device);

        assert_eq!(
            config.initializer,
            Initializer::KaimingUniform {
                gain: 1.0 / 3.0f64.sqrt(),
                fan_out_only: false
            }
        );
        linear.weight.to_data().assert_within_range(-k..k);
    }

    #[test]
    fn initializer_zeros() {
        TestBackend::seed(0);

        let config = LinearConfig::new(5, 5).with_initializer(Initializer::Zeros);
        let device = Default::default();
        let linear = config.init::<TestBackend>(&device);

        assert_eq!(config.initializer, Initializer::Zeros);
        linear.weight.to_data().assert_approx_eq::<FT>(
            &TensorData::zeros::<f32, _>(linear.weight.shape()),
            Tolerance::default(),
        );
    }

    #[test]
    fn test_linear_forward_no_bias() {
        TestBackend::seed(0);

        let value = 2.;
        let config = LinearConfig::new(2, 3)
            .with_initializer(Initializer::Constant { value })
            .with_bias(false);
        let device = Default::default();
        let linear = config.init::<TestBackend>(&device);

        let input = Tensor::<TestBackend, 2>::ones(Shape::new([1, 2]), &device);
        let result = linear.forward(input);
        let expected_result = Tensor::<TestBackend, 2>::from_data([[4., 4., 4.]], &device);

        assert_eq!(result.into_data(), expected_result.into_data());
    }

    #[test]
    fn test_linear_forward_with_bias() {
        TestBackend::seed(0);

        let device = Default::default();

        let value = 2.;
        let config = LinearConfig::new(2, 3).with_initializer(Initializer::Constant { value });
        let linear = config.init::<TestBackend>(&device);

        let input = Tensor::<TestBackend, 2>::ones(Shape::new([1, 2]), &device);
        let result = linear.forward(input);
        let expected_result = Tensor::<TestBackend, 2>::from_data([[6., 6., 6.]], &device);

        assert_eq!(result.into_data(), expected_result.into_data());
    }

    #[test]
    fn test_linear_1d() {
        TestBackend::seed(0);

        let device = Default::default();

        let value = 2.;
        let config = LinearConfig::new(2, 3).with_initializer(Initializer::Constant { value });
        let linear = config.init::<TestBackend>(&device);

        let input_1d = Tensor::<TestBackend, 1>::ones(Shape::new([2]), &device);
        let input_2d = Tensor::<TestBackend, 2>::ones(Shape::new([1, 2]), &device);

        let result_1d = linear.forward(input_1d).unsqueeze::<2>();
        let result_2d = linear.forward(input_2d);

        assert_eq!(result_1d.into_data(), result_2d.into_data());
    }

    #[test]
    fn display() {
        let config = LinearConfig::new(3, 5);
        let linear = config.init::<TestBackend>(&Default::default());

        assert_eq!(
            alloc::format!("{}", linear),
            "Linear {d_input: 3, d_output: 5, bias: true, params: 20}"
        );
    }
}