1use burn_core as burn;
2
3use burn::config::Config;
4use burn::module::Param;
5use burn::module::{Content, DisplaySettings, Initializer, Module, ModuleDisplay};
6use burn::tensor::module::linear;
7use burn::tensor::{Tensor, backend::Backend};
8
/// Configuration to create a [Linear](Linear) layer using the
/// [init function](LinearConfig::init).
#[derive(Config, Debug)]
pub struct LinearConfig {
    /// The size of the input features.
    pub d_input: usize,
    /// The size of the output features.
    pub d_output: usize,
    /// If a bias should be applied during the linear transformation. Defaults to `true`.
    #[config(default = true)]
    pub bias: bool,
    /// The initializer used to create the weight (and bias) parameters.
    /// Defaults to Kaiming-uniform with gain `1/sqrt(3)`.
    #[config(
        default = "Initializer::KaimingUniform{gain:1.0/num_traits::Float::sqrt(3.0), fan_out_only:false}"
    )]
    pub initializer: Initializer,
    /// The memory layout used for the weight matrix. Defaults to [`LinearLayout::Row`].
    #[config(default = "LinearLayout::Row")]
    pub layout: LinearLayout,
}
28
/// Memory layout of the [Linear] weight matrix.
#[derive(Config, Debug, Copy)]
pub enum LinearLayout {
    /// Weight data is created with shape `[d_input, d_output]` (the default).
    Row,
    /// Weight data is created with shape `[d_output, d_input]` and transposed
    /// when initialized, saved, and loaded (see [LinearConfig::init]), so its
    /// records stay interchangeable with the row layout.
    Col,
}
39
/// Applies a linear transformation to the input tensor.
///
/// Should be created with [LinearConfig].
#[derive(Module, Debug)]
#[module(custom_display)]
pub struct Linear<B: Backend> {
    /// Weight matrix of shape `[d_input, d_output]`, created by the configured
    /// [Initializer].
    pub weight: Param<Tensor<B, 2>>,
    /// Optional bias vector of size `d_output`; present when
    /// [LinearConfig::bias] is `true`.
    pub bias: Option<Param<Tensor<B, 1>>>,
}
55
impl LinearConfig {
    /// Initialize a new [linear](Linear) module on the given device.
    ///
    /// For [`LinearLayout::Row`], the weight is created directly with shape
    /// `[d_input, d_output]`. For [`LinearLayout::Col`], the data is created
    /// with the transposed shape `[d_output, d_input]` and transposed back by
    /// the init/save/load mappers, keeping records compatible across layouts.
    pub fn init<B: Backend>(&self, device: &B::Device) -> Linear<B> {
        let weight = match self.layout {
            LinearLayout::Row => {
                let shape = [self.d_input, self.d_output];
                self.initializer
                    .init_with(shape, Some(self.d_input), Some(self.d_output), device)
            }
            LinearLayout::Col => {
                // Fan-in/fan-out arguments are swapped along with the shape.
                let shape = [self.d_output, self.d_input];

                self.initializer
                    .init_with(shape, Some(self.d_output), Some(self.d_input), device)
                    // Transpose when saving so the record matches the row layout.
                    .save_mapper(move |tensor| {
                        // NOTE(review): the syncs before/after each transpose look
                        // intentional (presumably to force materialization on lazy
                        // backends) — confirm before removing.
                        B::sync(&tensor.device()).unwrap();
                        let tensor = tensor.transpose();
                        B::sync(&tensor.device()).unwrap();
                        tensor
                    })
                    // Transpose row-layout records back into column layout on load.
                    .load_mapper(move |tensor| {
                        B::sync(&tensor.device()).unwrap();
                        let tensor = tensor.transpose();
                        B::sync(&tensor.device()).unwrap();

                        tensor
                    })
                    // Transpose after initialization so the exposed shape is
                    // `[d_input, d_output]` even though the data was created
                    // transposed.
                    .init_mapper(|tensor| {
                        B::sync(&tensor.device()).unwrap();
                        let tensor = tensor.transpose();
                        B::sync(&tensor.device()).unwrap();
                        tensor
                    })
            }
        };
        // Bias is a vector of size `d_output`, initialized with the same
        // fan-in/fan-out as the row-layout weight.
        let bias = if self.bias {
            Some(self.initializer.init_with(
                [self.d_output],
                Some(self.d_input),
                Some(self.d_output),
                device,
            ))
        } else {
            None
        };

        Linear { weight, bias }
    }
}
109
110impl<B: Backend> Linear<B> {
111 pub fn forward<const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
126 linear(
127 input,
128 self.weight.val(),
129 self.bias.as_ref().map(|b| b.val()),
130 )
131 }
132}
133
134impl<B: Backend> ModuleDisplay for Linear<B> {
135 fn custom_settings(&self) -> Option<DisplaySettings> {
136 DisplaySettings::new()
137 .with_new_line_after_attribute(false)
138 .optional()
139 }
140
141 fn custom_content(&self, content: Content) -> Option<Content> {
142 let [d_input, d_output] = self.weight.shape().dims();
143 content
144 .add("d_input", &d_input)
145 .add("d_output", &d_output)
146 .add("bias", &self.bias.is_some())
147 .optional()
148 }
149}
150
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestBackend;
    use burn::module::ParamId;
    use burn::record::{BinBytesRecorder, FullPrecisionSettings, Recorder};
    use burn::tensor::ElementConversion;
    use burn::tensor::{Shape, TensorData};
    use burn::tensor::{Tolerance, ops::FloatElem};
    type FT = FloatElem<TestBackend>;

    #[test]
    fn initializer_default() {
        let device = Default::default();
        TestBackend::seed(&device, 0);

        let config = LinearConfig::new(5, 5);
        // Default Kaiming-uniform bound: 1/sqrt(fan_in).
        let k = (1.0 / config.d_input as f64).sqrt().elem::<FT>();
        let linear = config.init::<TestBackend>(&device);

        assert_eq!(
            config.initializer,
            Initializer::KaimingUniform {
                gain: 1.0 / 3.0f64.sqrt(),
                fan_out_only: false
            }
        );
        linear.weight.to_data().assert_within_range(-k..k);
    }

    #[test]
    fn initializer_zeros() {
        let device = Default::default();
        TestBackend::seed(&device, 0);

        let config = LinearConfig::new(5, 5).with_initializer(Initializer::Zeros);
        let linear = config.init::<TestBackend>(&device);

        assert_eq!(config.initializer, Initializer::Zeros);
        linear.weight.to_data().assert_approx_eq::<FT>(
            &TensorData::zeros::<f32, _>(linear.weight.shape()),
            Tolerance::default(),
        );
    }

    #[test]
    fn test_linear_forward_no_bias() {
        let device = Default::default();
        TestBackend::seed(&device, 0);

        let value = 2.;
        let config = LinearConfig::new(2, 3)
            .with_initializer(Initializer::Constant { value })
            .with_bias(false);
        let linear = config.init::<TestBackend>(&device);

        // Constant weights (2.0) and an input of ones: each output is 2 * 2.0 = 4.
        let input = Tensor::<TestBackend, 2>::ones(Shape::new([1, 2]), &device);
        let result = linear.forward(input);
        let expected_result = Tensor::<TestBackend, 2>::from_data([[4., 4., 4.]], &device);

        assert_eq!(result.into_data(), expected_result.into_data());
    }

    #[test]
    fn test_linear_forward_with_bias() {
        // Fixed: a duplicate, shadowing `let device = Default::default();` was removed.
        let device = Default::default();
        TestBackend::seed(&device, 0);

        let value = 2.;
        let config = LinearConfig::new(2, 3).with_initializer(Initializer::Constant { value });
        let linear = config.init::<TestBackend>(&device);

        // Constant weights and bias (2.0): each output is 2 * 2.0 + 2.0 = 6.
        let input = Tensor::<TestBackend, 2>::ones(Shape::new([1, 2]), &device);
        let result = linear.forward(input);
        let expected_result = Tensor::<TestBackend, 2>::from_data([[6., 6., 6.]], &device);

        assert_eq!(result.into_data(), expected_result.into_data());
    }

    #[test]
    fn test_linear_1d() {
        // Fixed: a duplicate, shadowing `let device = Default::default();` was removed.
        let device = Default::default();
        TestBackend::seed(&device, 0);

        let value = 2.;
        let config = LinearConfig::new(2, 3).with_initializer(Initializer::Constant { value });
        let linear = config.init::<TestBackend>(&device);

        // A 1-D input must behave like the equivalent single-row 2-D input.
        let input_1d = Tensor::<TestBackend, 1>::ones(Shape::new([2]), &device);
        let input_2d = Tensor::<TestBackend, 2>::ones(Shape::new([1, 2]), &device);

        let result_1d = linear.forward(input_1d).unsqueeze::<2>();
        let result_2d = linear.forward(input_2d);

        assert_eq!(result_1d.into_data(), result_2d.into_data());
    }

    #[test]
    fn display() {
        let config = LinearConfig::new(3, 5);
        let linear = config.init::<TestBackend>(&Default::default());

        // params: 3 * 5 weights + 5 biases = 20.
        assert_eq!(
            alloc::format!("{linear}"),
            "Linear {d_input: 3, d_output: 5, bias: true, params: 20}"
        );
    }

    #[test]
    fn layout() {
        let device = Default::default();
        let config = LinearConfig::new(6, 12).with_layout(LinearLayout::Col);
        let linear = config.init::<TestBackend>(&device);

        assert_eq!(linear.weight.dims(), [6, 12], "Shape is as configured");

        let recorder = BinBytesRecorder::<FullPrecisionSettings>::new();

        let record = linear.into_record();
        let data = recorder.record(record, ()).unwrap();
        let record = recorder.load(data.clone(), &device).unwrap();

        // A record saved from a col-layout module loads into a row-layout
        // module with the transposed shape.
        let config = LinearConfig::new(12, 6).with_layout(LinearLayout::Row);
        let linear_row = config.init::<TestBackend>(&device).load_record(record);

        assert_eq!(
            linear_row.weight.dims(),
            [12, 6],
            "Shape should be transposed"
        );

        // Fixed: `data` is not used afterwards, so the redundant `.clone()` was removed.
        let record = recorder.load(data, &device).unwrap();
        let config = LinearConfig::new(6, 12).with_layout(LinearLayout::Col);
        let linear_col = config.init::<TestBackend>(&device).load_record(record);

        assert_eq!(
            linear_col.weight.dims(),
            [6, 12],
            "Shape should be as configured"
        );

        // Round-trip a col-layout module through save + load.
        let record = linear_col.into_record();
        let data = recorder.record(record, ()).unwrap();

        let record = recorder.load(data, &device).unwrap();
        let config = LinearConfig::new(6, 12).with_layout(LinearLayout::Col);
        let linear_col = config.init::<TestBackend>(&device).load_record(record);

        assert_eq!(
            linear_col.weight.dims(),
            [6, 12],
            "Shape should be as configured"
        );
    }

    #[test]
    fn col_row_same_result() {
        let device = Default::default();
        let config_col = LinearConfig::new(6, 12).with_layout(LinearLayout::Col);
        let linear_col = config_col.init::<TestBackend>(&device);
        let signal = Tensor::<_, 2>::random([8, 6], burn::tensor::Distribution::Default, &device);
        let value = linear_col.forward(signal.clone());

        let data_1 = value.into_data();

        // Rebuild the same weights as plain (row-layout) parameters and check
        // the forward pass is unaffected by the storage layout.
        let weights = linear_col.weight.val().into_data();
        let weights = Tensor::from_data(weights, &device);

        let linear = Linear {
            weight: Param::initialized(ParamId::new(), weights),
            bias: linear_col
                .bias
                .map(|b| Param::initialized(ParamId::new(), b.val())),
        };

        let value = linear.forward(signal);
        let data_2 = value.into_data();

        data_1.assert_approx_eq::<f32>(&data_2, Default::default());
    }
}