axonml_nn/layers/
linear.rs1use std::collections::HashMap;
24
25use axonml_autograd::Variable;
26use axonml_tensor::Tensor;
27
28use crate::init::{kaiming_uniform, zeros};
29use crate::module::Module;
30use crate::parameter::Parameter;
31
32pub struct Linear {
56 pub weight: Parameter,
58 pub bias: Option<Parameter>,
60 in_features: usize,
62 out_features: usize,
64}
65
66impl Linear {
67 pub fn new(in_features: usize, out_features: usize) -> Self {
69 Self::with_bias(in_features, out_features, true)
70 }
71
72 pub fn with_bias(in_features: usize, out_features: usize, bias: bool) -> Self {
74 let weight_data = kaiming_uniform(out_features, in_features);
76 let weight = Parameter::named("weight", weight_data, true);
77
78 let bias_param = if bias {
79 let bias_data = zeros(&[out_features]);
81 Some(Parameter::named("bias", bias_data, true))
82 } else {
83 None
84 };
85
86 Self {
87 weight,
88 bias: bias_param,
89 in_features,
90 out_features,
91 }
92 }
93
94 pub fn from_weights(weight: Tensor<f32>, bias: Option<Tensor<f32>>) -> Self {
96 let out_features = weight.shape()[0];
97 let in_features = weight.shape()[1];
98
99 Self {
100 weight: Parameter::named("weight", weight, true),
101 bias: bias.map(|b| Parameter::named("bias", b, true)),
102 in_features,
103 out_features,
104 }
105 }
106
107 pub fn in_features(&self) -> usize {
109 self.in_features
110 }
111
112 pub fn out_features(&self) -> usize {
114 self.out_features
115 }
116}
117
118impl Module for Linear {
119 fn forward(&self, input: &Variable) -> Variable {
120 let input_shape = input.shape();
122 let batch_dims: Vec<usize> = input_shape[..input_shape.len() - 1].to_vec();
123
124 let total_batch: usize = batch_dims.iter().product();
126 let input_2d = if input_shape.len() > 2 {
127 input.reshape(&[total_batch, self.in_features])
129 } else {
130 input.clone()
131 };
132
133 let weight_var = self.weight.variable();
137 let weight_t = weight_var.transpose(0, 1);
139 let mut output = input_2d.matmul(&weight_t);
140
141 if let Some(ref bias) = self.bias {
143 let bias_var = bias.variable();
144 output = output.add_var(&bias_var);
145 }
146
147 if batch_dims.len() > 1 || (batch_dims.len() == 1 && input_shape.len() > 2) {
149 let mut output_shape: Vec<usize> = batch_dims.clone();
150 output_shape.push(self.out_features);
151 output.reshape(&output_shape)
152 } else {
153 output
154 }
155 }
156
157 fn parameters(&self) -> Vec<Parameter> {
158 let mut params = vec![self.weight.clone()];
159 if let Some(ref bias) = self.bias {
160 params.push(bias.clone());
161 }
162 params
163 }
164
165 fn named_parameters(&self) -> HashMap<String, Parameter> {
166 let mut params = HashMap::new();
167 params.insert("weight".to_string(), self.weight.clone());
168 if let Some(ref bias) = self.bias {
169 params.insert("bias".to_string(), bias.clone());
170 }
171 params
172 }
173
174 fn name(&self) -> &'static str {
175 "Linear"
176 }
177}
178
179impl std::fmt::Debug for Linear {
180 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
181 f.debug_struct("Linear")
182 .field("in_features", &self.in_features)
183 .field("out_features", &self.out_features)
184 .field("bias", &self.bias.is_some())
185 .finish()
186 }
187}
188
#[cfg(test)]
mod tests {
    use super::*;

    /// `new` records both feature sizes and attaches a bias.
    #[test]
    fn test_linear_creation() {
        let layer = Linear::new(10, 5);

        assert_eq!(layer.in_features(), 10);
        assert_eq!(layer.out_features(), 5);
        assert!(layer.bias.is_some());
    }

    /// `with_bias(.., false)` leaves the bias unset.
    #[test]
    fn test_linear_no_bias() {
        let layer = Linear::with_bias(10, 5, false);

        assert!(layer.bias.is_none());
    }

    /// A single `[1, 3]` row maps to a `[1, 2]` output.
    #[test]
    fn test_linear_forward() {
        let layer = Linear::new(3, 2);

        let data = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[1, 3]).unwrap();
        let x = Variable::new(data, false);
        let y = layer.forward(&x);

        assert_eq!(y.shape(), vec![1, 2]);
    }

    /// A `[3, 4]` batch maps to a `[3, 2]` output.
    #[test]
    fn test_linear_batch_forward() {
        let layer = Linear::new(4, 2);

        let data = Tensor::from_vec(vec![1.0; 12], &[3, 4]).unwrap();
        let x = Variable::new(data, false);
        let y = layer.forward(&x);

        assert_eq!(y.shape(), vec![3, 2]);
    }

    /// `parameters` yields weight and bias, or the weight alone without a bias.
    #[test]
    fn test_linear_parameters() {
        let with_bias = Linear::new(10, 5);
        let both = with_bias.parameters();
        assert_eq!(both.len(), 2);

        let without_bias = Linear::with_bias(10, 5, false);
        let weight_only = without_bias.parameters();
        assert_eq!(weight_only.len(), 1);
    }

    /// 10 * 5 weight entries plus 5 bias entries gives 55 scalars in total.
    #[test]
    fn test_linear_num_parameters() {
        let layer = Linear::new(10, 5);

        assert_eq!(layer.num_parameters(), 55);
    }
}