use crate as burn;

use crate::config::Config;
use crate::module::Param;
use crate::module::{Content, DisplaySettings, Module, ModuleDisplay};
use crate::tensor::{backend::Backend, Tensor};

use super::Initializer;

/// Configuration to create a [Linear] layer using the [init function](LinearConfig::init).
#[derive(Config, Debug)]
pub struct LinearConfig {
    /// The size of the input features.
    pub d_input: usize,
    /// The size of the output features.
    pub d_output: usize,
    /// If a bias should be applied during the linear transformation. Default: `true`.
    #[config(default = true)]
    pub bias: bool,
    /// The type of function used to initialize the parameters.
    ///
    /// The default, Kaiming uniform with `gain = 1 / sqrt(3)` and `fan_out_only = false`,
    /// amounts to sampling from `U(-k, k)` with `k = 1 / sqrt(d_input)`.
    #[config(
        default = "Initializer::KaimingUniform{gain:1.0/num_traits::Float::sqrt(3.0), fan_out_only:false}"
    )]
    pub initializer: Initializer,
}

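/// Applies a linear transformation to the input tensor:
///
/// `O = IW + b`
///
/// Should be created with [LinearConfig].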
#[derive(Module, Debug)]
#[module(custom_display)]
pub struct Linear<B: Backend> {
    /// Matrix of shape `[d_input, d_output]` initialized from the configured [Initializer].
    pub weight: Param<Tensor<B, 2>>,
    /// Vector of size `d_output` initialized from the configured [Initializer].
    ///
    /// `None` when the configuration sets `bias` to `false`.
    pub bias: Option<Param<Tensor<B, 1>>>,
}

impl LinearConfig {
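    /// Initialize a new [Linear] module from this configuration.
    ///
    /// A minimal usage sketch (illustrative only; `MyBackend` stands in for any
    /// [Backend] implementation and is not defined in this crate):
    ///
    /// ```ignore
    /// let device = Default::default();
    /// let config = LinearConfig::new(4, 8).with_bias(false);
    /// let linear: Linear<MyBackend> = config.init(&device);
    /// ```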
    pub fn init<B: Backend>(&self, device: &B::Device) -> Linear<B> {
        let shape = [self.d_input, self.d_output];
        // Pass fan-in/fan-out hints so fan-dependent initializers (e.g. Kaiming) scale correctly.
        let weight =
            self.initializer
                .init_with(shape, Some(self.d_input), Some(self.d_output), device);
        let bias = if self.bias {
            Some(self.initializer.init_with(
                [self.d_output],
                Some(self.d_input),
                Some(self.d_output),
                device,
            ))
        } else {
            None
        };

        Linear { weight, bias }
    }
}

impl<B: Backend> Linear<B> {
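    /// Applies the forward pass on the input tensor.
    ///
    /// # Shapes
    ///
    /// - input: `[..., any, d_input]`
    /// - output: `[..., any, d_output]`
    ///
    /// A 1-D input of shape `[d_input]` is also accepted and produces an output of
    /// shape `[d_output]`. Illustrative sketch (`MyBackend` is a placeholder for any
    /// [Backend] implementation):
    ///
    /// ```ignore
    /// let device = Default::default();
    /// let linear = LinearConfig::new(4, 8).init::<MyBackend>(&device);
    ///
    /// // A batch of two vectors with 4 features each.
    /// let input = Tensor::<MyBackend, 2>::ones(Shape::new([2, 4]), &device);
    /// let output = linear.forward(input); // shape: [2, 8]
    /// ```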
    pub fn forward<const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
        if D == 1 {
            // Insert and remove an extra batch dimension so the batched matmul works.
            return Self::forward::<2>(self, input.unsqueeze()).flatten(0, 1);
        }

        // Broadcast the `[d_input, d_output]` weight (and the bias, if any) over the
        // leading dimensions of the input.
        let weight = self.weight.val().unsqueeze();
        let bias = self.bias.as_ref().map(|b| b.val().unsqueeze());

        let output = input.matmul(weight);

        match bias {
            Some(bias) => output + bias,
            None => output,
        }
    }
}

impl<B: Backend> ModuleDisplay for Linear<B> {
    fn custom_settings(&self) -> Option<DisplaySettings> {
        DisplaySettings::new()
            .with_new_line_after_attribute(false)
            .optional()
    }

    fn custom_content(&self, content: Content) -> Option<Content> {
        let [d_input, d_output] = self.weight.shape().dims();
        content
            .add("d_input", &d_input)
            .add("d_output", &d_output)
            .add("bias", &self.bias.is_some())
            .optional()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::tensor::{Shape, TensorData};
    use crate::TestBackend;

    #[test]
    fn initializer_default() {
        TestBackend::seed(0);

        let config = LinearConfig::new(5, 5);
        let k = (1.0 / config.d_input as f64).sqrt() as f32;
        let device = Default::default();
        let linear = config.init::<TestBackend>(&device);

        assert_eq!(
            config.initializer,
            Initializer::KaimingUniform {
                gain: 1.0 / 3.0f64.sqrt(),
                fan_out_only: false
            }
        );
        linear.weight.to_data().assert_within_range(-k..k);
    }

    #[test]
    fn initializer_zeros() {
        TestBackend::seed(0);

        let config = LinearConfig::new(5, 5).with_initializer(Initializer::Zeros);
        let device = Default::default();
        let linear = config.init::<TestBackend>(&device);

        assert_eq!(config.initializer, Initializer::Zeros);
        linear
            .weight
            .to_data()
            .assert_approx_eq(&TensorData::zeros::<f32, _>(linear.weight.shape()), 3);
    }

    #[test]
    fn test_linear_forward_no_bias() {
        TestBackend::seed(0);

        let value = 2.;
        let config = LinearConfig::new(2, 3)
            .with_initializer(Initializer::Constant { value })
            .with_bias(false);
        let device = Default::default();
        let linear = config.init::<TestBackend>(&device);

        let input = Tensor::<TestBackend, 2>::ones(Shape::new([1, 2]), &device);
        let result = linear.forward(input);
        let expected_result = Tensor::<TestBackend, 2>::from_data([[4., 4., 4.]], &device);

        assert_eq!(result.into_data(), expected_result.into_data());
    }

    #[test]
    fn test_linear_forward_with_bias() {
        TestBackend::seed(0);

        let device = Default::default();

        let value = 2.;
        let config = LinearConfig::new(2, 3).with_initializer(Initializer::Constant { value });
        let linear = config.init::<TestBackend>(&device);

        let input = Tensor::<TestBackend, 2>::ones(Shape::new([1, 2]), &device);
        let result = linear.forward(input);
        let expected_result = Tensor::<TestBackend, 2>::from_data([[6., 6., 6.]], &device);

        assert_eq!(result.into_data(), expected_result.into_data());
    }

    #[test]
    fn test_linear_1d() {
        TestBackend::seed(0);

        let device = Default::default();

        let value = 2.;
        let config = LinearConfig::new(2, 3).with_initializer(Initializer::Constant { value });
        let linear = config.init::<TestBackend>(&device);

        let input_1d = Tensor::<TestBackend, 1>::ones(Shape::new([2]), &device);
        let input_2d = Tensor::<TestBackend, 2>::ones(Shape::new([1, 2]), &device);

        let result_1d = linear.forward(input_1d).unsqueeze::<2>();
        let result_2d = linear.forward(input_2d);

        assert_eq!(result_1d.into_data(), result_2d.into_data());
    }

    #[test]
    fn display() {
        let config = LinearConfig::new(3, 5);
        let linear = config.init::<TestBackend>(&Default::default());

        assert_eq!(
            alloc::format!("{}", linear),
            "Linear {d_input: 3, d_output: 5, bias: true, params: 20}"
        );
    }
}