use burn_tensor::module::linear;

use crate as burn;

use crate::config::Config;
use crate::module::Param;
use crate::module::{Content, DisplaySettings, Module, ModuleDisplay};
use crate::tensor::{Tensor, backend::Backend};

use super::Initializer;

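/// Configuration to create a [Linear] layer, built with the [init](LinearConfig::init) method.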
#[derive(Config, Debug)]
pub struct LinearConfig {
    /// The size of the input features.
    pub d_input: usize,
    /// The size of the output features.
    pub d_output: usize,
    /// If a bias should be applied to the output of the linear transformation. Default: `true`.
    #[config(default = true)]
    pub bias: bool,
    /// The initializer used for the weight and bias parameters.
    #[config(
        default = "Initializer::KaimingUniform{gain:1.0/num_traits::Float::sqrt(3.0), fan_out_only:false}"
    )]
    pub initializer: Initializer,
}

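/// Applies a linear transformation to the input tensor, i.e. `output = input * weight + bias`.
///
/// Should be created with [LinearConfig].
///
/// # Example
///
/// A minimal usage sketch; the backend `B`, the `device`, and the `input` tensor are
/// assumed to be provided by the caller, and the feature sizes are illustrative only:
///
/// ```ignore
/// // Map 64 input features to 32 output features, with a bias (the default).
/// let linear = LinearConfig::new(64, 32).init::<B>(&device);
///
/// // `input` has shape `[batch, 64]`; the output has shape `[batch, 32]`.
/// let output = linear.forward(input);
/// ```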
#[derive(Module, Debug)]
#[module(custom_display)]
pub struct Linear<B: Backend> {
    /// The learnable weight matrix of shape `[d_input, d_output]`, created by the configured initializer.
    pub weight: Param<Tensor<B, 2>>,
    /// The optional learnable bias vector of size `d_output`, created by the configured initializer.
    pub bias: Option<Param<Tensor<B, 1>>>,
}

impl LinearConfig {
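    /// Initialize a new [Linear] module on the given device.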
    pub fn init<B: Backend>(&self, device: &B::Device) -> Linear<B> {
        let shape = [self.d_input, self.d_output];
        let weight =
            self.initializer
                .init_with(shape, Some(self.d_input), Some(self.d_output), device);
        let bias = if self.bias {
            Some(self.initializer.init_with(
                [self.d_output],
                Some(self.d_input),
                Some(self.d_output),
                device,
            ))
        } else {
            None
        };

        Linear { weight, bias }
    }
}

impl<B: Backend> Linear<B> {
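    /// Applies the forward pass on the input tensor.
    ///
    /// The transformation is applied over the last dimension of the input.
    ///
    /// # Shapes
    ///
    /// - input: `[..., d_input]`
    /// - output: `[..., d_output]`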
    pub fn forward<const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
        linear(
            input,
            self.weight.val(),
            self.bias.as_ref().map(|b| b.val()),
        )
    }
}

impl<B: Backend> ModuleDisplay for Linear<B> {
    fn custom_settings(&self) -> Option<DisplaySettings> {
        DisplaySettings::new()
            .with_new_line_after_attribute(false)
            .optional()
    }

    fn custom_content(&self, content: Content) -> Option<Content> {
        // d_input and d_output are read back from the weight shape `[d_input, d_output]`.
        let [d_input, d_output] = self.weight.shape().dims();
        content
            .add("d_input", &d_input)
            .add("d_output", &d_output)
            .add("bias", &self.bias.is_some())
            .optional()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestBackend;
    use crate::tensor::{Shape, TensorData};
    use burn_tensor::ElementConversion;
    use burn_tensor::{Tolerance, ops::FloatElem};
    type FT = FloatElem<TestBackend>;

    #[test]
    fn initializer_default() {
        TestBackend::seed(0);

        let config = LinearConfig::new(5, 5);
        // The default Kaiming uniform initializer with gain 1/sqrt(3) and fan_in = d_input
        // gives the bound k = sqrt(1/d_input), so every weight must fall within (-k, k).
        let k = (1.0 / config.d_input as f64).sqrt().elem::<FT>();
        let device = Default::default();
        let linear = config.init::<TestBackend>(&device);

        assert_eq!(
            config.initializer,
            Initializer::KaimingUniform {
                gain: 1.0 / 3.0f64.sqrt(),
                fan_out_only: false
            }
        );
        linear.weight.to_data().assert_within_range(-k..k);
    }

    #[test]
    fn initializer_zeros() {
        TestBackend::seed(0);

        let config = LinearConfig::new(5, 5).with_initializer(Initializer::Zeros);
        let device = Default::default();
        let linear = config.init::<TestBackend>(&device);

        assert_eq!(config.initializer, Initializer::Zeros);
        linear.weight.to_data().assert_approx_eq::<FT>(
            &TensorData::zeros::<f32, _>(linear.weight.shape()),
            Tolerance::default(),
        );
    }

    #[test]
    fn test_linear_forward_no_bias() {
        TestBackend::seed(0);

        let value = 2.;
        let config = LinearConfig::new(2, 3)
            .with_initializer(Initializer::Constant { value })
            .with_bias(false);
        let device = Default::default();
        let linear = config.init::<TestBackend>(&device);

        let input = Tensor::<TestBackend, 2>::ones(Shape::new([1, 2]), &device);
        let result = linear.forward(input);
        // Each output is 2 inputs * weight 2.0 = 4.0, with no bias added.
        let expected_result = Tensor::<TestBackend, 2>::from_data([[4., 4., 4.]], &device);

        assert_eq!(result.into_data(), expected_result.into_data());
    }

    #[test]
    fn test_linear_forward_with_bias() {
        TestBackend::seed(0);

        let device = Default::default();

        let value = 2.;
        let config = LinearConfig::new(2, 3).with_initializer(Initializer::Constant { value });
        let linear = config.init::<TestBackend>(&device);

        let input = Tensor::<TestBackend, 2>::ones(Shape::new([1, 2]), &device);
        let result = linear.forward(input);
        // Each output is 2 inputs * weight 2.0 + bias 2.0 = 6.0.
        let expected_result = Tensor::<TestBackend, 2>::from_data([[6., 6., 6.]], &device);

        assert_eq!(result.into_data(), expected_result.into_data());
    }

    #[test]
    fn test_linear_1d() {
        TestBackend::seed(0);

        let device = Default::default();

        let value = 2.;
        let config = LinearConfig::new(2, 3).with_initializer(Initializer::Constant { value });
        let linear = config.init::<TestBackend>(&device);

        let input_1d = Tensor::<TestBackend, 1>::ones(Shape::new([2]), &device);
        let input_2d = Tensor::<TestBackend, 2>::ones(Shape::new([1, 2]), &device);

        // A 1D input behaves like a single sample: unsqueezing its result must match
        // the forward pass on the equivalent `[1, d_input]` input.
        let result_1d = linear.forward(input_1d).unsqueeze::<2>();
        let result_2d = linear.forward(input_2d);

        assert_eq!(result_1d.into_data(), result_2d.into_data());
    }

    #[test]
    fn display() {
        let config = LinearConfig::new(3, 5);
        let linear = config.init::<TestBackend>(&Default::default());

        assert_eq!(
            alloc::format!("{linear}"),
            "Linear {d_input: 3, d_output: 5, bias: true, params: 20}"
        );
    }
}