// crates/axonml-nn/src/layers/linear.rs
//! Linear Layer - Fully Connected Layer
//!
//! # File
//! `crates/axonml-nn/src/layers/linear.rs`
//!
//! # Author
//! Andrew Jewell Sr - AutomataNexus
//!
//! # Updated
//! March 8, 2026
//!
//! # Disclaimer
//! Use at own risk. This software is provided "as is", without warranty of any
//! kind, express or implied. The author and AutomataNexus shall not be held
//! liable for any damages arising from the use of this software.

17use std::collections::HashMap;
18
19use axonml_autograd::Variable;
20use axonml_tensor::Tensor;
21
22use crate::init::{kaiming_uniform, zeros};
23use crate::module::Module;
24use crate::parameter::Parameter;
25
26// =============================================================================
27// Linear
28// =============================================================================
29
30/// Applies a linear transformation to the input.
31///
32/// y = xW^T + b
33///
34/// # Arguments
35/// * `in_features` - Size of each input sample
36/// * `out_features` - Size of each output sample
37/// * `bias` - If true, adds a learnable bias (default: true)
38///
39/// # Shape
40/// - Input: (*, in_features) where * means any number of dimensions
41/// - Output: (*, out_features)
42///
43/// # Example
44/// ```ignore
45/// let linear = Linear::new(20, 30);
46/// let input = Variable::new(randn(&[128, 20]), true);
47/// let output = linear.forward(&input);  // Shape: [128, 30]
48/// ```
49pub struct Linear {
50    /// Weight matrix of shape (out_features, in_features).
51    pub weight: Parameter,
52    /// Bias vector of shape (out_features).
53    pub bias: Option<Parameter>,
54    /// Input features.
55    in_features: usize,
56    /// Output features.
57    out_features: usize,
58}
59
60impl Linear {
61    /// Creates a new Linear layer with bias.
62    pub fn new(in_features: usize, out_features: usize) -> Self {
63        Self::with_bias(in_features, out_features, true)
64    }
65
66    /// Creates a new Linear layer with optional bias.
67    pub fn with_bias(in_features: usize, out_features: usize, bias: bool) -> Self {
68        // Initialize weights using Kaiming uniform
69        let weight_data = kaiming_uniform(out_features, in_features);
70        let weight = Parameter::named("weight", weight_data, true);
71
72        let bias_param = if bias {
73            // Initialize bias to zeros
74            let bias_data = zeros(&[out_features]);
75            Some(Parameter::named("bias", bias_data, true))
76        } else {
77            None
78        };
79
80        Self {
81            weight,
82            bias: bias_param,
83            in_features,
84            out_features,
85        }
86    }
87
88    /// Creates a Linear layer from existing weight and bias tensors.
89    pub fn from_weights(weight: Tensor<f32>, bias: Option<Tensor<f32>>) -> Self {
90        let out_features = weight.shape()[0];
91        let in_features = weight.shape()[1];
92
93        Self {
94            weight: Parameter::named("weight", weight, true),
95            bias: bias.map(|b| Parameter::named("bias", b, true)),
96            in_features,
97            out_features,
98        }
99    }
100
101    /// Returns the input feature dimension.
102    pub fn in_features(&self) -> usize {
103        self.in_features
104    }
105
106    /// Returns the output feature dimension.
107    pub fn out_features(&self) -> usize {
108        self.out_features
109    }
110}
111
112impl Module for Linear {
113    fn forward(&self, input: &Variable) -> Variable {
114        // Get input shape
115        let input_shape = input.shape();
116        let batch_dims: Vec<usize> = input_shape[..input_shape.len() - 1].to_vec();
117
118        // Reshape to 2D: (batch, in_features)
119        let total_batch: usize = batch_dims.iter().product();
120        let input_2d = if input_shape.len() > 2 {
121            // Use autograd-tracked reshape to maintain gradient flow
122            input.reshape(&[total_batch, self.in_features])
123        } else {
124            input.clone()
125        };
126
127        // y = xW^T
128        // x: (batch, in_features), W: (out_features, in_features)
129        // We need x @ W^T = (batch, out_features)
130        let weight_var = self.weight.variable();
131        // Use autograd-tracked transpose to maintain gradient flow
132        let weight_t = weight_var.transpose(0, 1);
133        let mut output = input_2d.matmul(&weight_t);
134
135        // Add bias if present
136        if let Some(ref bias) = self.bias {
137            let bias_var = bias.variable();
138            output = output.add_var(&bias_var);
139        }
140
141        // Reshape back to original batch dimensions
142        if batch_dims.len() > 1 || (batch_dims.len() == 1 && input_shape.len() > 2) {
143            let mut output_shape: Vec<usize> = batch_dims.clone();
144            output_shape.push(self.out_features);
145            output.reshape(&output_shape)
146        } else {
147            output
148        }
149    }
150
151    fn parameters(&self) -> Vec<Parameter> {
152        let mut params = vec![self.weight.clone()];
153        if let Some(ref bias) = self.bias {
154            params.push(bias.clone());
155        }
156        params
157    }
158
159    fn named_parameters(&self) -> HashMap<String, Parameter> {
160        let mut params = HashMap::new();
161        params.insert("weight".to_string(), self.weight.clone());
162        if let Some(ref bias) = self.bias {
163            params.insert("bias".to_string(), bias.clone());
164        }
165        params
166    }
167
168    fn name(&self) -> &'static str {
169        "Linear"
170    }
171}
172
173impl std::fmt::Debug for Linear {
174    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
175        f.debug_struct("Linear")
176            .field("in_features", &self.in_features)
177            .field("out_features", &self.out_features)
178            .field("bias", &self.bias.is_some())
179            .finish()
180    }
181}
182
// =============================================================================
// Tests
// =============================================================================

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_linear_creation() {
        let layer = Linear::new(10, 5);
        assert_eq!(layer.in_features(), 10);
        assert_eq!(layer.out_features(), 5);
        assert!(layer.bias.is_some());
    }

    #[test]
    fn test_linear_no_bias() {
        let layer = Linear::with_bias(10, 5, false);
        assert!(layer.bias.is_none());
    }

    #[test]
    fn test_linear_forward() {
        let layer = Linear::new(3, 2);
        let x = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0], &[1, 3]).unwrap(),
            false,
        );
        // A single sample maps (1, 3) -> (1, 2).
        assert_eq!(layer.forward(&x).shape(), vec![1, 2]);
    }

    #[test]
    fn test_linear_batch_forward() {
        let layer = Linear::new(4, 2);
        let x = Variable::new(Tensor::from_vec(vec![1.0; 12], &[3, 4]).unwrap(), false);
        // A batch of 3 samples maps (3, 4) -> (3, 2).
        assert_eq!(layer.forward(&x).shape(), vec![3, 2]);
    }

    #[test]
    fn test_linear_parameters() {
        // With bias enabled: weight + bias.
        assert_eq!(Linear::new(10, 5).parameters().len(), 2);
        // Without bias: weight only.
        assert_eq!(Linear::with_bias(10, 5, false).parameters().len(), 1);
    }

    #[test]
    fn test_linear_num_parameters() {
        // weight: 10 * 5 = 50, bias: 5, total: 55.
        assert_eq!(Linear::new(10, 5).num_parameters(), 55);
    }
}