1use crate as burn;
2
3use crate::config::Config;
4use crate::module::Param;
5use crate::module::{Content, DisplaySettings, Module, ModuleDisplay};
6use crate::tensor::{Tensor, backend::Backend};
7
8use super::Initializer;
9
/// Configuration to create a [Linear](Linear) layer using the [init function](LinearConfig::init).
#[derive(Config, Debug)]
pub struct LinearConfig {
    /// The size of the input features.
    pub d_input: usize,
    /// The size of the output features.
    pub d_output: usize,
    /// If a bias should be applied during the linear transformation. Defaults to `true`.
    #[config(default = true)]
    pub bias: bool,
    /// The initializer used for both weight and bias parameters.
    /// Defaults to Kaiming-uniform with `gain = 1/sqrt(3)` and `fan_out_only = false`.
    #[config(
        default = "Initializer::KaimingUniform{gain:1.0/num_traits::Float::sqrt(3.0), fan_out_only:false}"
    )]
    pub initializer: Initializer,
}
26
/// Applies a linear transformation to the input tensor.
///
/// Should be created with [LinearConfig].
#[derive(Module, Debug)]
#[module(custom_display)]
pub struct Linear<B: Backend> {
    /// Learnable weight of shape `[d_input, d_output]` (see [LinearConfig::init]).
    pub weight: Param<Tensor<B, 2>>,
    /// Optional learnable bias of shape `[d_output]`; present when `LinearConfig::bias` is `true`.
    pub bias: Option<Param<Tensor<B, 1>>>,
}
42
43impl LinearConfig {
44 pub fn init<B: Backend>(&self, device: &B::Device) -> Linear<B> {
46 let shape = [self.d_input, self.d_output];
47 let weight =
48 self.initializer
49 .init_with(shape, Some(self.d_input), Some(self.d_output), device);
50 let bias = if self.bias {
51 Some(self.initializer.init_with(
52 [self.d_output],
53 Some(self.d_input),
54 Some(self.d_output),
55 device,
56 ))
57 } else {
58 None
59 };
60
61 Linear { weight, bias }
62 }
63}
64
65impl<B: Backend> Linear<B> {
66 pub fn forward<const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
73 if D == 1 {
74 return Self::forward::<2>(self, input.unsqueeze()).flatten(0, 1);
76 }
77
78 let weight = self.weight.val().unsqueeze();
79 let bias = self.bias.as_ref().map(|b| b.val().unsqueeze());
80 let output = input.matmul(weight);
81
82 match bias {
83 Some(bias) => output + bias,
84 None => output,
85 }
86 }
87}
88
89impl<B: Backend> ModuleDisplay for Linear<B> {
90 fn custom_settings(&self) -> Option<DisplaySettings> {
91 DisplaySettings::new()
92 .with_new_line_after_attribute(false)
93 .optional()
94 }
95
96 fn custom_content(&self, content: Content) -> Option<Content> {
97 let [d_input, d_output] = self.weight.shape().dims();
98 content
99 .add("d_input", &d_input)
100 .add("d_output", &d_output)
101 .add("bias", &self.bias.is_some())
102 .optional()
103 }
104}
105
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestBackend;
    use crate::tensor::{Shape, TensorData};
    use burn_tensor::{Tolerance, ops::FloatElem};
    type FT = FloatElem<TestBackend>;

    #[test]
    fn initializer_default() {
        TestBackend::seed(0);

        let device = Default::default();
        let cfg = LinearConfig::new(5, 5);
        let module = cfg.init::<TestBackend>(&device);

        // Default is Kaiming-uniform, which for fan_in = d_input keeps
        // weights within +/- sqrt(1 / d_input).
        let bound = (1.0 / cfg.d_input as f64).sqrt() as f32;

        assert_eq!(
            cfg.initializer,
            Initializer::KaimingUniform {
                gain: 1.0 / 3.0f64.sqrt(),
                fan_out_only: false
            }
        );
        module.weight.to_data().assert_within_range(-bound..bound);
    }

    #[test]
    fn initializer_zeros() {
        TestBackend::seed(0);

        let device = Default::default();
        let cfg = LinearConfig::new(5, 5).with_initializer(Initializer::Zeros);
        let module = cfg.init::<TestBackend>(&device);

        assert_eq!(cfg.initializer, Initializer::Zeros);
        // Every weight entry must be exactly zero.
        module.weight.to_data().assert_approx_eq::<FT>(
            &TensorData::zeros::<f32, _>(module.weight.shape()),
            Tolerance::default(),
        );
    }

    #[test]
    fn test_linear_forward_no_bias() {
        TestBackend::seed(0);

        let device = Default::default();
        let value = 2.;
        let cfg = LinearConfig::new(2, 3)
            .with_initializer(Initializer::Constant { value })
            .with_bias(false);
        let module = cfg.init::<TestBackend>(&device);

        // ones([1, 2]) * constant(2) weights => each output = 2 + 2 = 4.
        let input = Tensor::<TestBackend, 2>::ones(Shape::new([1, 2]), &device);
        let output = module.forward(input);
        let expected = Tensor::<TestBackend, 2>::from_data([[4., 4., 4.]], &device);

        assert_eq!(output.into_data(), expected.into_data());
    }

    #[test]
    fn test_linear_forward_with_bias() {
        TestBackend::seed(0);

        let device = Default::default();
        let value = 2.;
        let cfg = LinearConfig::new(2, 3).with_initializer(Initializer::Constant { value });
        let module = cfg.init::<TestBackend>(&device);

        // Same as the no-bias case plus a constant(2) bias: 4 + 2 = 6.
        let input = Tensor::<TestBackend, 2>::ones(Shape::new([1, 2]), &device);
        let output = module.forward(input);
        let expected = Tensor::<TestBackend, 2>::from_data([[6., 6., 6.]], &device);

        assert_eq!(output.into_data(), expected.into_data());
    }

    #[test]
    fn test_linear_1d() {
        TestBackend::seed(0);

        let device = Default::default();
        let value = 2.;
        let cfg = LinearConfig::new(2, 3).with_initializer(Initializer::Constant { value });
        let module = cfg.init::<TestBackend>(&device);

        // A 1-D input must produce the same values as the equivalent
        // single-row 2-D input.
        let input_1d = Tensor::<TestBackend, 1>::ones(Shape::new([2]), &device);
        let input_2d = Tensor::<TestBackend, 2>::ones(Shape::new([1, 2]), &device);

        let output_1d = module.forward(input_1d).unsqueeze::<2>();
        let output_2d = module.forward(input_2d);

        assert_eq!(output_1d.into_data(), output_2d.into_data());
    }

    #[test]
    fn display() {
        let cfg = LinearConfig::new(3, 5);
        let module = cfg.init::<TestBackend>(&Default::default());

        assert_eq!(
            alloc::format!("{}", module),
            "Linear {d_input: 3, d_output: 5, bias: true, params: 20}"
        );
    }
}