axonml_nn/layers/
linear.rs1use std::collections::HashMap;
9
10use axonml_autograd::Variable;
11use axonml_tensor::Tensor;
12
13use crate::init::{kaiming_uniform, zeros};
14use crate::module::Module;
15use crate::parameter::Parameter;
16
17pub struct Linear {
41 pub weight: Parameter,
43 pub bias: Option<Parameter>,
45 in_features: usize,
47 out_features: usize,
49}
50
51impl Linear {
52 pub fn new(in_features: usize, out_features: usize) -> Self {
54 Self::with_bias(in_features, out_features, true)
55 }
56
57 pub fn with_bias(in_features: usize, out_features: usize, bias: bool) -> Self {
59 let weight_data = kaiming_uniform(out_features, in_features);
61 let weight = Parameter::named("weight", weight_data, true);
62
63 let bias_param = if bias {
64 let bias_data = zeros(&[out_features]);
66 Some(Parameter::named("bias", bias_data, true))
67 } else {
68 None
69 };
70
71 Self {
72 weight,
73 bias: bias_param,
74 in_features,
75 out_features,
76 }
77 }
78
79 pub fn from_weights(weight: Tensor<f32>, bias: Option<Tensor<f32>>) -> Self {
81 let out_features = weight.shape()[0];
82 let in_features = weight.shape()[1];
83
84 Self {
85 weight: Parameter::named("weight", weight, true),
86 bias: bias.map(|b| Parameter::named("bias", b, true)),
87 in_features,
88 out_features,
89 }
90 }
91
92 pub fn in_features(&self) -> usize {
94 self.in_features
95 }
96
97 pub fn out_features(&self) -> usize {
99 self.out_features
100 }
101}
102
103impl Module for Linear {
104 fn forward(&self, input: &Variable) -> Variable {
105 let input_shape = input.shape();
107 let batch_dims: Vec<usize> = input_shape[..input_shape.len() - 1].to_vec();
108
109 let total_batch: usize = batch_dims.iter().product();
111 let input_2d = if input_shape.len() > 2 {
112 input.reshape(&[total_batch, self.in_features])
114 } else {
115 input.clone()
116 };
117
118 let weight_var = self.weight.variable();
122 let weight_t = weight_var.transpose(0, 1);
124 let mut output = input_2d.matmul(&weight_t);
125
126 if let Some(ref bias) = self.bias {
128 let bias_var = bias.variable();
129 output = output.add_var(&bias_var);
130 }
131
132 if batch_dims.len() > 1 || (batch_dims.len() == 1 && input_shape.len() > 2) {
134 let mut output_shape: Vec<usize> = batch_dims.clone();
135 output_shape.push(self.out_features);
136 output.reshape(&output_shape)
137 } else {
138 output
139 }
140 }
141
142 fn parameters(&self) -> Vec<Parameter> {
143 let mut params = vec![self.weight.clone()];
144 if let Some(ref bias) = self.bias {
145 params.push(bias.clone());
146 }
147 params
148 }
149
150 fn named_parameters(&self) -> HashMap<String, Parameter> {
151 let mut params = HashMap::new();
152 params.insert("weight".to_string(), self.weight.clone());
153 if let Some(ref bias) = self.bias {
154 params.insert("bias".to_string(), bias.clone());
155 }
156 params
157 }
158
159 fn name(&self) -> &'static str {
160 "Linear"
161 }
162}
163
164impl std::fmt::Debug for Linear {
165 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
166 f.debug_struct("Linear")
167 .field("in_features", &self.in_features)
168 .field("out_features", &self.out_features)
169 .field("bias", &self.bias.is_some())
170 .finish()
171 }
172}
173
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_linear_creation() {
        let linear = Linear::new(10, 5);
        assert_eq!(linear.in_features(), 10);
        assert_eq!(linear.out_features(), 5);
        assert!(linear.bias.is_some());
    }

    #[test]
    fn test_linear_no_bias() {
        let linear = Linear::with_bias(10, 5, false);
        assert!(linear.bias.is_none());
    }

    #[test]
    fn test_linear_forward() {
        let linear = Linear::new(3, 2);

        let input = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0], &[1, 3]).unwrap(),
            false,
        );
        let output = linear.forward(&input);

        assert_eq!(output.shape(), vec![1, 2]);
    }

    #[test]
    fn test_linear_batch_forward() {
        let linear = Linear::new(4, 2);

        let input = Variable::new(Tensor::from_vec(vec![1.0; 12], &[3, 4]).unwrap(), false);
        let output = linear.forward(&input);

        assert_eq!(output.shape(), vec![3, 2]);
    }

    #[test]
    fn test_linear_higher_rank_forward() {
        // Leading dims [2, 5] must be preserved around the matmul.
        let linear = Linear::new(4, 3);

        let input = Variable::new(Tensor::from_vec(vec![1.0; 40], &[2, 5, 4]).unwrap(), false);
        let output = linear.forward(&input);

        assert_eq!(output.shape(), vec![2, 5, 3]);
    }

    #[test]
    fn test_linear_from_weights() {
        let weight = Tensor::from_vec(vec![1.0, 0.0, 0.0, 0.0, 1.0, 0.0], &[2, 3]).unwrap();
        let bias = Tensor::from_vec(vec![0.5, -0.5], &[2]).unwrap();
        let linear = Linear::from_weights(weight, Some(bias));

        assert_eq!(linear.in_features(), 3);
        assert_eq!(linear.out_features(), 2);
        assert!(linear.bias.is_some());

        let input = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0], &[1, 3]).unwrap(),
            false,
        );
        assert_eq!(linear.forward(&input).shape(), vec![1, 2]);
    }

    #[test]
    fn test_linear_parameters() {
        let linear = Linear::new(10, 5);
        let params = linear.parameters();
        assert_eq!(params.len(), 2);

        let linear_no_bias = Linear::with_bias(10, 5, false);
        let params_no_bias = linear_no_bias.parameters();
        assert_eq!(params_no_bias.len(), 1);
    }

    #[test]
    fn test_linear_named_parameters() {
        let named = Linear::new(3, 2).named_parameters();
        assert_eq!(named.len(), 2);
        assert!(named.contains_key("weight"));
        assert!(named.contains_key("bias"));

        let named_no_bias = Linear::with_bias(3, 2, false).named_parameters();
        assert_eq!(named_no_bias.len(), 1);
        assert!(!named_no_bias.contains_key("bias"));
    }

    #[test]
    fn test_linear_num_parameters() {
        let linear = Linear::new(10, 5);
        // 10 * 5 weights + 5 biases.
        assert_eq!(linear.num_parameters(), 55);
    }

    #[test]
    fn test_linear_debug() {
        let linear = Linear::with_bias(3, 2, false);
        assert_eq!(
            format!("{:?}", linear),
            "Linear { in_features: 3, out_features: 2, bias: false }"
        );
    }
}