//! Linear layer (`burn_core/nn/linear.rs`).

1use burn_tensor::module::linear;
2
3use crate as burn;
4
5use crate::config::Config;
6use crate::module::Param;
7use crate::module::{Content, DisplaySettings, Module, ModuleDisplay};
8use crate::tensor::{Tensor, backend::Backend};
9
10use super::Initializer;
11
/// Configuration to create a [Linear](Linear) layer using the [init function](LinearConfig::init).
#[derive(Config, Debug)]
pub struct LinearConfig {
    /// The size of the input features.
    pub d_input: usize,
    /// The size of the output features.
    pub d_output: usize,
    /// If a bias should be applied during the linear transformation.
    #[config(default = true)]
    pub bias: bool,
    /// The type of function used to initialize neural network parameters.
    ///
    /// Defaults to Kaiming uniform with `gain = 1 / sqrt(3)`, which is equivalent to
    /// sampling from `U(-k, k)` with `k = sqrt(1 / d_input)`.
    #[config(
        default = "Initializer::KaimingUniform{gain:1.0/num_traits::Float::sqrt(3.0), fan_out_only:false}"
    )]
    pub initializer: Initializer,
}
28
/// Applies a linear transformation to the input tensor.
///
/// Should be created with [LinearConfig]
///
/// `O = IW + b`
#[derive(Module, Debug)]
#[module(custom_display)]
pub struct Linear<B: Backend> {
    /// Matrix of shape `[d_input, d_output]`. With the default initializer it is drawn
    /// from a uniform distribution: `U(-k, k)`, where `k = sqrt(1 / d_input)`.
    pub weight: Param<Tensor<B, 2>>,
    /// Optional vector of size `d_output` (present iff the config's `bias` flag is set).
    /// With the default initializer it is drawn from a uniform distribution:
    /// `U(-k, k)`, where `k = sqrt(1 / d_input)`.
    pub bias: Option<Param<Tensor<B, 1>>>,
}
44
45impl LinearConfig {
46    /// Initialize a new [linear](Linear) module.
47    pub fn init<B: Backend>(&self, device: &B::Device) -> Linear<B> {
48        let shape = [self.d_input, self.d_output];
49        let weight =
50            self.initializer
51                .init_with(shape, Some(self.d_input), Some(self.d_output), device);
52        let bias = if self.bias {
53            Some(self.initializer.init_with(
54                [self.d_output],
55                Some(self.d_input),
56                Some(self.d_output),
57                device,
58            ))
59        } else {
60            None
61        };
62
63        Linear { weight, bias }
64    }
65}
66
67impl<B: Backend> Linear<B> {
68    /// Applies the forward pass on the input tensor.
69    ///
70    /// # Arguments
71    ///
72    /// - `input` - The input tensor of shape `[..., d_input]`.
73    ///
74    /// # Shapes
75    ///
76    /// - input: `[..., d_input]`
77    /// - output: `[..., d_output]`
78    ///
79    /// # Returns
80    ///
81    /// The transformed tensor of shape `[..., d_output]`.
82    pub fn forward<const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
83        linear(
84            input,
85            self.weight.val(),
86            self.bias.as_ref().map(|b| b.val()),
87        )
88    }
89}
90
91impl<B: Backend> ModuleDisplay for Linear<B> {
92    fn custom_settings(&self) -> Option<DisplaySettings> {
93        DisplaySettings::new()
94            .with_new_line_after_attribute(false)
95            .optional()
96    }
97
98    fn custom_content(&self, content: Content) -> Option<Content> {
99        let [d_input, d_output] = self.weight.shape().dims();
100        content
101            .add("d_input", &d_input)
102            .add("d_output", &d_output)
103            .add("bias", &self.bias.is_some())
104            .optional()
105    }
106}
107
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestBackend;
    use crate::tensor::{Shape, TensorData};
    use burn_tensor::ElementConversion;
    use burn_tensor::{Tolerance, ops::FloatElem};
    type FT = FloatElem<TestBackend>;

    #[test]
    fn initializer_default() {
        TestBackend::seed(0);

        let device = Default::default();
        let config = LinearConfig::new(5, 5);
        let module = config.init::<TestBackend>(&device);

        // The default initializer is Kaiming uniform with gain 1/sqrt(3).
        assert_eq!(
            config.initializer,
            Initializer::KaimingUniform {
                gain: 1.0 / 3.0f64.sqrt(),
                fan_out_only: false
            }
        );

        // Which is equivalent to drawing from U(-k, k) with k = sqrt(1 / d_input).
        let bound = (1.0 / config.d_input as f64).sqrt().elem::<FT>();
        module.weight.to_data().assert_within_range(-bound..bound);
    }

    #[test]
    fn initializer_zeros() {
        TestBackend::seed(0);

        let device = Default::default();
        let config = LinearConfig::new(5, 5).with_initializer(Initializer::Zeros);
        let module = config.init::<TestBackend>(&device);

        assert_eq!(config.initializer, Initializer::Zeros);

        // Every weight must be exactly zero.
        let expected = TensorData::zeros::<f32, _>(module.weight.shape());
        module
            .weight
            .to_data()
            .assert_approx_eq::<FT>(&expected, Tolerance::default());
    }

    #[test]
    fn test_linear_forward_no_bias() {
        TestBackend::seed(0);

        let device = Default::default();
        let config = LinearConfig::new(2, 3)
            .with_initializer(Initializer::Constant { value: 2. })
            .with_bias(false);
        let module = config.init::<TestBackend>(&device);

        // Two inputs of 1.0 against constant weights of 2.0 -> 4.0 per output.
        let output = module.forward(Tensor::<TestBackend, 2>::ones(Shape::new([1, 2]), &device));
        let expected = Tensor::<TestBackend, 2>::from_data([[4., 4., 4.]], &device);

        assert_eq!(output.into_data(), expected.into_data());
    }

    #[test]
    fn test_linear_forward_with_bias() {
        TestBackend::seed(0);

        let device = Default::default();
        let config = LinearConfig::new(2, 3).with_initializer(Initializer::Constant { value: 2. });
        let module = config.init::<TestBackend>(&device);

        // 1*2 + 1*2 (matmul) + 2 (bias) = 6 per output.
        let output = module.forward(Tensor::<TestBackend, 2>::ones(Shape::new([1, 2]), &device));
        let expected = Tensor::<TestBackend, 2>::from_data([[6., 6., 6.]], &device);

        assert_eq!(output.into_data(), expected.into_data());
    }

    #[test]
    fn test_linear_1d() {
        TestBackend::seed(0);

        let device = Default::default();
        let config = LinearConfig::new(2, 3).with_initializer(Initializer::Constant { value: 2. });
        let module = config.init::<TestBackend>(&device);

        // A 1D input must behave like the equivalent single-row 2D input.
        let out_1d = module
            .forward(Tensor::<TestBackend, 1>::ones(Shape::new([2]), &device))
            .unsqueeze::<2>();
        let out_2d = module.forward(Tensor::<TestBackend, 2>::ones(Shape::new([1, 2]), &device));

        assert_eq!(out_1d.into_data(), out_2d.into_data());
    }

    #[test]
    fn display() {
        let linear = LinearConfig::new(3, 5).init::<TestBackend>(&Default::default());

        assert_eq!(
            alloc::format!("{linear}"),
            "Linear {d_input: 3, d_output: 5, bias: true, params: 20}"
        );
    }
}