axonml_nn/layers/linear.rs

//! Linear Layer - Fully Connected Layer
//!
//! Applies a linear transformation: y = xW^T + b
//!
//! @version 0.1.0
//! @author AutomataNexus Development Team

use std::collections::HashMap;

use axonml_autograd::Variable;
use axonml_tensor::Tensor;

use crate::init::{kaiming_uniform, zeros};
use crate::module::Module;
use crate::parameter::Parameter;

// =============================================================================
// Linear
// =============================================================================

/// Applies a linear transformation to the input.
///
/// y = xW^T + b
///
/// # Arguments
/// * `in_features` - Size of each input sample
/// * `out_features` - Size of each output sample
/// * `bias` - If true, adds a learnable bias (default: true)
///
/// # Shape
/// - Input: (*, in_features) where * means any number of dimensions
/// - Output: (*, out_features)
///
/// # Example
/// ```ignore
/// let linear = Linear::new(20, 30);
/// let input = Variable::new(randn(&[128, 20]), true);
/// let output = linear.forward(&input);  // Shape: [128, 30]
/// ```
pub struct Linear {
    /// Weight matrix of shape (out_features, in_features).
    pub weight: Parameter,
    /// Bias vector of shape (out_features).
    pub bias: Option<Parameter>,
    /// Input features.
    in_features: usize,
    /// Output features.
    out_features: usize,
}

impl Linear {
    /// Creates a new Linear layer with bias.
    pub fn new(in_features: usize, out_features: usize) -> Self {
        Self::with_bias(in_features, out_features, true)
    }

    /// Creates a new Linear layer with optional bias.
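    ///
    /// # Example
    /// ```ignore
    /// // A minimal sketch: a layer without a learnable bias.
    /// let linear = Linear::with_bias(20, 30, false);
    /// assert!(linear.bias.is_none());
    /// ```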
    pub fn with_bias(in_features: usize, out_features: usize, bias: bool) -> Self {
        // Initialize weights using Kaiming uniform
        let weight_data = kaiming_uniform(out_features, in_features);
        let weight = Parameter::named("weight", weight_data, true);

        let bias_param = if bias {
            // Initialize bias to zeros
            let bias_data = zeros(&[out_features]);
            Some(Parameter::named("bias", bias_data, true))
        } else {
            None
        };

        Self {
            weight,
            bias: bias_param,
            in_features,
            out_features,
        }
    }

    /// Creates a Linear layer from existing weight and bias tensors.
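    ///
    /// The weight is expected in shape (out_features, in_features), matching
    /// the layout produced by `with_bias`.
    ///
    /// # Example
    /// ```ignore
    /// // A sketch, assuming the same `randn` helper as the struct-level example.
    /// let layer = Linear::from_weights(randn(&[5, 10]), None);
    /// assert_eq!(layer.in_features(), 10);
    /// assert_eq!(layer.out_features(), 5);
    /// ```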
    pub fn from_weights(weight: Tensor<f32>, bias: Option<Tensor<f32>>) -> Self {
        let out_features = weight.shape()[0];
        let in_features = weight.shape()[1];

        Self {
            weight: Parameter::named("weight", weight, true),
            bias: bias.map(|b| Parameter::named("bias", b, true)),
            in_features,
            out_features,
        }
    }

    /// Returns the input feature dimension.
    pub fn in_features(&self) -> usize {
        self.in_features
    }

    /// Returns the output feature dimension.
    pub fn out_features(&self) -> usize {
        self.out_features
    }
}

impl Module for Linear {
    fn forward(&self, input: &Variable) -> Variable {
        // Get input shape
        let input_shape = input.shape();
        let batch_dims: Vec<usize> = input_shape[..input_shape.len() - 1].to_vec();

        // Reshape to 2D: (batch, in_features)
        let total_batch: usize = batch_dims.iter().product();
        let input_2d = if input_shape.len() > 2 {
            // Use autograd-tracked reshape to maintain gradient flow
            input.reshape(&[total_batch, self.in_features])
        } else {
            input.clone()
        };

        // y = xW^T
        // x: (batch, in_features), W: (out_features, in_features)
        // We need x @ W^T = (batch, out_features)
        let weight_var = self.weight.variable();
        // Use autograd-tracked transpose to maintain gradient flow
        let weight_t = weight_var.transpose(0, 1);
        let mut output = input_2d.matmul(&weight_t);

        // Add bias if present
        if let Some(ref bias) = self.bias {
            let bias_var = bias.variable();
            output = output.add_var(&bias_var);
        }

        // Reshape back to the original batch dimensions. Only needed when the
        // input had more than two dimensions; a 2D input is already
        // (batch, out_features).
        if input_shape.len() > 2 {
            let mut output_shape: Vec<usize> = batch_dims.clone();
            output_shape.push(self.out_features);
            output.reshape(&output_shape)
        } else {
            output
        }
    }

    fn parameters(&self) -> Vec<Parameter> {
        let mut params = vec![self.weight.clone()];
        if let Some(ref bias) = self.bias {
            params.push(bias.clone());
        }
        params
    }

    fn named_parameters(&self) -> HashMap<String, Parameter> {
        let mut params = HashMap::new();
        params.insert("weight".to_string(), self.weight.clone());
        if let Some(ref bias) = self.bias {
            params.insert("bias".to_string(), bias.clone());
        }
        params
    }

    fn name(&self) -> &'static str {
        "Linear"
    }
}

impl std::fmt::Debug for Linear {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Linear")
            .field("in_features", &self.in_features)
            .field("out_features", &self.out_features)
            .field("bias", &self.bias.is_some())
            .finish()
    }
}

// =============================================================================
// Tests
// =============================================================================

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_linear_creation() {
        let linear = Linear::new(10, 5);
        assert_eq!(linear.in_features(), 10);
        assert_eq!(linear.out_features(), 5);
        assert!(linear.bias.is_some());
    }

    #[test]
    fn test_linear_no_bias() {
        let linear = Linear::with_bias(10, 5, false);
        assert!(linear.bias.is_none());
    }

    #[test]
    fn test_linear_forward() {
        let linear = Linear::new(3, 2);

        let input = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0], &[1, 3]).unwrap(),
            false,
        );
        let output = linear.forward(&input);

        assert_eq!(output.shape(), vec![1, 2]);
    }

    #[test]
    fn test_linear_batch_forward() {
        let linear = Linear::new(4, 2);

        let input = Variable::new(Tensor::from_vec(vec![1.0; 12], &[3, 4]).unwrap(), false);
        let output = linear.forward(&input);

        assert_eq!(output.shape(), vec![3, 2]);
    }
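
    // Exercises the N-D path documented on `forward`: leading batch
    // dimensions are flattened to (total_batch, in_features) for the matmul,
    // then restored on the output.
    #[test]
    fn test_linear_3d_forward() {
        let linear = Linear::new(4, 2);

        let input = Variable::new(Tensor::from_vec(vec![1.0; 24], &[2, 3, 4]).unwrap(), false);
        let output = linear.forward(&input);

        // (2, 3, 4) -> (6, 4) -> matmul -> (6, 2) -> (2, 3, 2)
        assert_eq!(output.shape(), vec![2, 3, 2]);
    }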

    #[test]
    fn test_linear_parameters() {
        let linear = Linear::new(10, 5);
        let params = linear.parameters();
        assert_eq!(params.len(), 2); // weight + bias

        let linear_no_bias = Linear::with_bias(10, 5, false);
        let params_no_bias = linear_no_bias.parameters();
        assert_eq!(params_no_bias.len(), 1); // weight only
    }
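
    // A minimal check of `from_weights`: the (out_features, in_features)
    // layout means a [2, 3] weight yields in_features = 3, out_features = 2.
    #[test]
    fn test_linear_from_weights() {
        let weight = Tensor::from_vec(vec![0.5; 6], &[2, 3]).unwrap();
        let linear = Linear::from_weights(weight, None);

        assert_eq!(linear.in_features(), 3);
        assert_eq!(linear.out_features(), 2);
        assert!(linear.bias.is_none());
    }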

    #[test]
    fn test_linear_num_parameters() {
        let linear = Linear::new(10, 5);
        // weight: 10*5 = 50, bias: 5, total: 55
        assert_eq!(linear.num_parameters(), 55);
    }
}