// burn_nn/modules/linear.rs

use burn_core as burn;

use burn::config::Config;
use burn::module::Param;
use burn::module::{Content, DisplaySettings, Initializer, Module, ModuleDisplay};
use burn::tensor::module::linear;
use burn::tensor::{Tensor, backend::Backend};

/// Configuration to create a [`Linear`] layer using the [init function](LinearConfig::init).
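///
/// # Example
///
/// A minimal usage sketch (not part of the original source); `MyBackend` and
/// `device` stand in for any `Backend` implementation and its device:
///
/// ```rust,ignore
/// // 64 input features, 128 output features, without a bias term.
/// let config = LinearConfig::new(64, 128).with_bias(false);
/// let linear: Linear<MyBackend> = config.init(&device);
/// ```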
#[derive(Config, Debug)]
pub struct LinearConfig {
    /// The size of the input features.
    pub d_input: usize,
    /// The size of the output features.
    pub d_output: usize,
    /// Whether a bias should be applied during the linear transformation.
    #[config(default = true)]
    pub bias: bool,
    /// The type of function used to initialize neural network parameters.
    #[config(
        default = "Initializer::KaimingUniform{gain:1.0/num_traits::Float::sqrt(3.0), fan_out_only:false}"
    )]
    pub initializer: Initializer,
    /// The layout in which the linear parameters are stored.
    #[config(default = "LinearLayout::Row")]
    pub layout: LinearLayout,
}

#[derive(Config, Debug, Copy)]
/// The layout in which the linear parameters are stored.
///
/// This can have performance impacts.
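///
/// # Example
///
/// A sketch (not part of the original source) selecting the column-major
/// layout; `MyBackend` and `device` are placeholders:
///
/// ```rust,ignore
/// let linear = LinearConfig::new(256, 512)
///     .with_layout(LinearLayout::Col)
///     .init::<MyBackend>(&device);
/// ```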
pub enum LinearLayout {
    /// Parameters are stored in row-major order.
    Row,
    /// Parameters are stored in column-major order.
    Col,
}

/// Applies a linear transformation to the input tensor.
///
/// Should be created with [LinearConfig].
///
/// `O = IW + b`
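///
/// # Example
///
/// A minimal sketch (not part of the original source) of the transformation;
/// `MyBackend` and `device` are placeholders:
///
/// ```rust,ignore
/// let linear = LinearConfig::new(3, 5).init::<MyBackend>(&device);
/// let input = Tensor::<MyBackend, 2>::ones([2, 3], &device);
/// let output = linear.forward(input); // `[2, 3] x [3, 5] + [5]` -> `[2, 5]`
/// ```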
#[derive(Module, Debug)]
#[module(custom_display)]
pub struct Linear<B: Backend> {
    /// Matrix of shape `[d_input, d_output]` initialized from a uniform distribution:
    ///     `U(-k, k)`, where `k = sqrt(1 / d_input)`
    pub weight: Param<Tensor<B, 2>>,
    /// Vector of size `d_output` initialized from a uniform distribution:
    ///     `U(-k, k)`, where `k = sqrt(1 / d_input)`
    pub bias: Option<Param<Tensor<B, 1>>>,
}

impl LinearConfig {
    /// Initialize a new [`Linear`] module.
    pub fn init<B: Backend>(&self, device: &B::Device) -> Linear<B> {
        let weight = match self.layout {
            LinearLayout::Row => {
                let shape = [self.d_input, self.d_output];
                self.initializer
                    .init_with(shape, Some(self.d_input), Some(self.d_output), device)
            }
            LinearLayout::Col => {
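                // The tensor is allocated as [d_output, d_input], then the
                // `init_mapper` below transposes it, so the live parameter
                // exposes the logical [d_input, d_output] shape on top of
                // column-major storage. The save/load mappers transpose in the
                // opposite direction, so records hold [d_output, d_input].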
                let shape = [self.d_output, self.d_input];

                self.initializer
                    .init_with(shape, Some(self.d_output), Some(self.d_input), device)
                    // The param is already transposed at init. We re-transpose
                    // to have [d_output, d_input] when saving.
                    .save_mapper(move |tensor| {
                        B::sync(&tensor.device()).unwrap();
                        let tensor = tensor.transpose();
                        B::sync(&tensor.device()).unwrap();
                        tensor
                    })
                    // When loading from a record, we have to transpose.
                    .load_mapper(move |tensor| {
                        B::sync(&tensor.device()).unwrap();
                        let tensor = tensor.transpose();
                        B::sync(&tensor.device()).unwrap();
                        tensor
                    })
                    // When loading from initialization, we have to transpose.
                    .init_mapper(|tensor| {
                        B::sync(&tensor.device()).unwrap();
                        let tensor = tensor.transpose();
                        B::sync(&tensor.device()).unwrap();
                        tensor
                    })
            }
        };
        let bias = if self.bias {
            Some(self.initializer.init_with(
                [self.d_output],
                Some(self.d_input),
                Some(self.d_output),
                device,
            ))
        } else {
            None
        };

        Linear { weight, bias }
    }
}

impl<B: Backend> Linear<B> {
    /// Applies the forward pass on the input tensor.
    ///
    /// # Arguments
    ///
    /// - `input` - The input tensor of shape `[..., d_input]`.
    ///
    /// # Shapes
    ///
    /// - input: `[..., d_input]`
    /// - output: `[..., d_output]`
    ///
    /// # Returns
    ///
    /// The transformed tensor of shape `[..., d_output]`.
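    ///
    /// # Example
    ///
    /// A sketch (not part of the original source) showing that leading
    /// dimensions are preserved; `MyBackend` and `device` are placeholders:
    ///
    /// ```rust,ignore
    /// let linear = LinearConfig::new(8, 4).init::<MyBackend>(&device);
    /// let input = Tensor::<MyBackend, 3>::ones([2, 10, 8], &device);
    /// let output = linear.forward(input); // shape [2, 10, 4]
    /// ```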
    pub fn forward<const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
        linear(
            input,
            self.weight.val(),
            self.bias.as_ref().map(|b| b.val()),
        )
    }
}

impl<B: Backend> ModuleDisplay for Linear<B> {
    fn custom_settings(&self) -> Option<DisplaySettings> {
        DisplaySettings::new()
            .with_new_line_after_attribute(false)
            .optional()
    }

    fn custom_content(&self, content: Content) -> Option<Content> {
        let [d_input, d_output] = self.weight.shape().dims();
        content
            .add("d_input", &d_input)
            .add("d_output", &d_output)
            .add("bias", &self.bias.is_some())
            .optional()
    }
}
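// With the custom content above, a `Linear` built from `LinearConfig::new(3, 5)`
// displays as `Linear {d_input: 3, d_output: 5, bias: true, params: 20}`
// (see the `display` test below).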

#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestBackend;
    use burn::module::ParamId;
    use burn::record::{BinBytesRecorder, FullPrecisionSettings, Recorder};
    use burn::tensor::ElementConversion;
    use burn::tensor::{Shape, TensorData};
    use burn::tensor::{Tolerance, ops::FloatElem};
    type FT = FloatElem<TestBackend>;

    #[test]
    fn initializer_default() {
        let device = Default::default();
        TestBackend::seed(&device, 0);

        let config = LinearConfig::new(5, 5);
        let k = (1.0 / config.d_input as f64).sqrt().elem::<FT>();
        let linear = config.init::<TestBackend>(&device);

        assert_eq!(
            config.initializer,
            Initializer::KaimingUniform {
                gain: 1.0 / 3.0f64.sqrt(),
                fan_out_only: false
            }
        );
        linear.weight.to_data().assert_within_range(-k..k);
    }

    #[test]
    fn initializer_zeros() {
        let device = Default::default();
        TestBackend::seed(&device, 0);

        let config = LinearConfig::new(5, 5).with_initializer(Initializer::Zeros);
        let linear = config.init::<TestBackend>(&device);

        assert_eq!(config.initializer, Initializer::Zeros);
        linear.weight.to_data().assert_approx_eq::<FT>(
            &TensorData::zeros::<f32, _>(linear.weight.shape()),
            Tolerance::default(),
        );
    }

    #[test]
    fn test_linear_forward_no_bias() {
        let device = Default::default();
        TestBackend::seed(&device, 0);

        let value = 2.;
        let config = LinearConfig::new(2, 3)
            .with_initializer(Initializer::Constant { value })
            .with_bias(false);
        let linear = config.init::<TestBackend>(&device);

        let input = Tensor::<TestBackend, 2>::ones(Shape::new([1, 2]), &device);
        let result = linear.forward(input);
        let expected_result = Tensor::<TestBackend, 2>::from_data([[4., 4., 4.]], &device);

        assert_eq!(result.into_data(), expected_result.into_data());
    }

    #[test]
    fn test_linear_forward_with_bias() {
        let device = Default::default();
        TestBackend::seed(&device, 0);

        let value = 2.;
        let config = LinearConfig::new(2, 3).with_initializer(Initializer::Constant { value });
        let linear = config.init::<TestBackend>(&device);

        let input = Tensor::<TestBackend, 2>::ones(Shape::new([1, 2]), &device);
        let result = linear.forward(input);
        let expected_result = Tensor::<TestBackend, 2>::from_data([[6., 6., 6.]], &device);

        assert_eq!(result.into_data(), expected_result.into_data());
    }

    #[test]
    fn test_linear_1d() {
        let device = Default::default();
        TestBackend::seed(&device, 0);

        let value = 2.;
        let config = LinearConfig::new(2, 3).with_initializer(Initializer::Constant { value });
        let linear = config.init::<TestBackend>(&device);

        let input_1d = Tensor::<TestBackend, 1>::ones(Shape::new([2]), &device);
        let input_2d = Tensor::<TestBackend, 2>::ones(Shape::new([1, 2]), &device);

        let result_1d = linear.forward(input_1d).unsqueeze::<2>();
        let result_2d = linear.forward(input_2d);

        assert_eq!(result_1d.into_data(), result_2d.into_data());
    }

    #[test]
    fn display() {
        let config = LinearConfig::new(3, 5);
        let linear = config.init::<TestBackend>(&Default::default());

        assert_eq!(
            alloc::format!("{linear}"),
            "Linear {d_input: 3, d_output: 5, bias: true, params: 20}"
        );
    }

    #[test]
    fn layout() {
        let device = Default::default();
        let config = LinearConfig::new(6, 12).with_layout(LinearLayout::Col);
        let linear = config.init::<TestBackend>(&device);

        assert_eq!(linear.weight.dims(), [6, 12], "Shape is as configured");

        let recorder = BinBytesRecorder::<FullPrecisionSettings>::new();

        // We go through serialization to trigger the mappers.
        let record = linear.into_record();
        let data = recorder.record(record, ()).unwrap();
        let record = recorder.load(data.clone(), &device).unwrap();

        let config = LinearConfig::new(12, 6).with_layout(LinearLayout::Row);
        let linear_row = config.init::<TestBackend>(&device).load_record(record);

        assert_eq!(
            linear_row.weight.dims(),
            [12, 6],
            "Shape should be transposed"
        );

        let record = recorder.load(data.clone(), &device).unwrap();
        let config = LinearConfig::new(6, 12).with_layout(LinearLayout::Col);
        let linear_col = config.init::<TestBackend>(&device).load_record(record);

        assert_eq!(
            linear_col.weight.dims(),
            [6, 12],
            "Shape should be as configured"
        );

        // We go through serialization to trigger the mappers.
        //
        // The test will fail if the mapper is not correctly given to the
        // module after loading a record.
        let record = linear_col.into_record();
        let data = recorder.record(record, ()).unwrap();

        let record = recorder.load(data, &device).unwrap();
        let config = LinearConfig::new(6, 12).with_layout(LinearLayout::Col);
        let linear_col = config.init::<TestBackend>(&device).load_record(record);

        assert_eq!(
            linear_col.weight.dims(),
            [6, 12],
            "Shape should be as configured"
        );
    }

    #[test]
    fn col_row_same_result() {
        let device = Default::default();
        let config_col = LinearConfig::new(6, 12).with_layout(LinearLayout::Col);
        let linear_col = config_col.init::<TestBackend>(&device);
        let signal = Tensor::<_, 2>::random([8, 6], burn::tensor::Distribution::Default, &device);
        let value = linear_col.forward(signal.clone());

        let data_1 = value.into_data();

        // Rebuild an equivalent layer whose weights are materialized from the
        // column-layout parameter, then check both produce the same output.
        let weights = linear_col.weight.val().into_data();
        let weights = Tensor::from_data(weights, &device);

        let linear = Linear {
            weight: Param::initialized(ParamId::new(), weights),
            bias: linear_col
                .bias
                .map(|b| Param::initialized(ParamId::new(), b.val())),
        };

        let value = linear.forward(signal);
        let data_2 = value.into_data();

        data_1.assert_approx_eq::<f32>(&data_2, Default::default());
    }
}