axonml_nn/layers/
linear.rs1use std::collections::HashMap;
18
19use axonml_autograd::Variable;
20use axonml_tensor::Tensor;
21
22use crate::init::{kaiming_uniform, zeros};
23use crate::module::Module;
24use crate::parameter::Parameter;
25
26pub struct Linear {
50 pub weight: Parameter,
52 pub bias: Option<Parameter>,
54 in_features: usize,
56 out_features: usize,
58}
59
60impl Linear {
61 pub fn new(in_features: usize, out_features: usize) -> Self {
63 Self::with_bias(in_features, out_features, true)
64 }
65
66 pub fn with_bias(in_features: usize, out_features: usize, bias: bool) -> Self {
68 let weight_data = kaiming_uniform(out_features, in_features);
70 let weight = Parameter::named("weight", weight_data, true);
71
72 let bias_param = if bias {
73 let bias_data = zeros(&[out_features]);
75 Some(Parameter::named("bias", bias_data, true))
76 } else {
77 None
78 };
79
80 Self {
81 weight,
82 bias: bias_param,
83 in_features,
84 out_features,
85 }
86 }
87
88 pub fn from_weights(weight: Tensor<f32>, bias: Option<Tensor<f32>>) -> Self {
90 let out_features = weight.shape()[0];
91 let in_features = weight.shape()[1];
92
93 Self {
94 weight: Parameter::named("weight", weight, true),
95 bias: bias.map(|b| Parameter::named("bias", b, true)),
96 in_features,
97 out_features,
98 }
99 }
100
101 pub fn in_features(&self) -> usize {
103 self.in_features
104 }
105
106 pub fn out_features(&self) -> usize {
108 self.out_features
109 }
110}
111
112impl Module for Linear {
113 fn forward(&self, input: &Variable) -> Variable {
114 let input_shape = input.shape();
116 let batch_dims: Vec<usize> = input_shape[..input_shape.len() - 1].to_vec();
117
118 let total_batch: usize = batch_dims.iter().product();
120 let input_2d = if input_shape.len() > 2 {
121 input.reshape(&[total_batch, self.in_features])
123 } else {
124 input.clone()
125 };
126
127 let weight_var = self.weight.variable();
131 let weight_t = weight_var.transpose(0, 1);
133 let mut output = input_2d.matmul(&weight_t);
134
135 if let Some(ref bias) = self.bias {
137 let bias_var = bias.variable();
138 output = output.add_var(&bias_var);
139 }
140
141 if batch_dims.len() > 1 || (batch_dims.len() == 1 && input_shape.len() > 2) {
143 let mut output_shape: Vec<usize> = batch_dims.clone();
144 output_shape.push(self.out_features);
145 output.reshape(&output_shape)
146 } else {
147 output
148 }
149 }
150
151 fn parameters(&self) -> Vec<Parameter> {
152 let mut params = vec![self.weight.clone()];
153 if let Some(ref bias) = self.bias {
154 params.push(bias.clone());
155 }
156 params
157 }
158
159 fn named_parameters(&self) -> HashMap<String, Parameter> {
160 let mut params = HashMap::new();
161 params.insert("weight".to_string(), self.weight.clone());
162 if let Some(ref bias) = self.bias {
163 params.insert("bias".to_string(), bias.clone());
164 }
165 params
166 }
167
168 fn name(&self) -> &'static str {
169 "Linear"
170 }
171}
172
173impl std::fmt::Debug for Linear {
174 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
175 f.debug_struct("Linear")
176 .field("in_features", &self.in_features)
177 .field("out_features", &self.out_features)
178 .field("bias", &self.bias.is_some())
179 .finish()
180 }
181}
182
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_linear_creation() {
        let linear = Linear::new(10, 5);
        assert_eq!(linear.in_features(), 10);
        assert_eq!(linear.out_features(), 5);
        assert!(linear.bias.is_some());
    }

    #[test]
    fn test_linear_no_bias() {
        let linear = Linear::with_bias(10, 5, false);
        assert!(linear.bias.is_none());
    }

    #[test]
    fn test_linear_forward() {
        let linear = Linear::new(3, 2);

        let input = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0], &[1, 3]).unwrap(),
            false,
        );
        let output = linear.forward(&input);

        assert_eq!(output.shape(), vec![1, 2]);
    }

    #[test]
    fn test_linear_batch_forward() {
        let linear = Linear::new(4, 2);

        let input = Variable::new(Tensor::from_vec(vec![1.0; 12], &[3, 4]).unwrap(), false);
        let output = linear.forward(&input);

        assert_eq!(output.shape(), vec![3, 2]);
    }

    #[test]
    fn test_linear_nd_forward() {
        // Rank > 2 inputs go through the flatten/reshape path in `forward`.
        let linear = Linear::new(4, 2);

        let input = Variable::new(Tensor::from_vec(vec![1.0; 24], &[2, 3, 4]).unwrap(), false);
        let output = linear.forward(&input);

        assert_eq!(output.shape(), vec![2, 3, 2]);
    }

    #[test]
    fn test_linear_from_weights() {
        let weight = Tensor::from_vec(vec![0.0; 6], &[2, 3]).unwrap();
        let bias = Tensor::from_vec(vec![0.0; 2], &[2]).unwrap();
        let linear = Linear::from_weights(weight, Some(bias));

        assert_eq!(linear.in_features(), 3);
        assert_eq!(linear.out_features(), 2);
        assert!(linear.bias.is_some());
    }

    #[test]
    fn test_linear_parameters() {
        let linear = Linear::new(10, 5);
        let params = linear.parameters();
        assert_eq!(params.len(), 2);

        let linear_no_bias = Linear::with_bias(10, 5, false);
        let params_no_bias = linear_no_bias.parameters();
        assert_eq!(params_no_bias.len(), 1);
    }

    #[test]
    fn test_linear_named_parameters() {
        let named = Linear::new(10, 5).named_parameters();
        assert_eq!(named.len(), 2);
        assert!(named.contains_key("weight"));
        assert!(named.contains_key("bias"));

        let named_no_bias = Linear::with_bias(10, 5, false).named_parameters();
        assert_eq!(named_no_bias.len(), 1);
        assert!(!named_no_bias.contains_key("bias"));
    }

    #[test]
    fn test_linear_num_parameters() {
        let linear = Linear::new(10, 5);
        // 10 * 5 weights + 5 biases.
        assert_eq!(linear.num_parameters(), 55);
    }
}