Skip to main content

axonml_nn/layers/
linear.rs

1//! `Linear` — fully connected layer (y = xW^T + b).
2//!
//! 245 lines. `Linear::new(in_features, out_features)` with Kaiming-uniform
//! weight init and optional bias. Implements `Module` (forward, parameters,
//! named_parameters, name; train/eval, zero_grad, to_device are not overridden
//! here). Also `Linear::with_bias(in, out, false)` for bias-free variants used
//! in some LLM architectures.
7//!
8//! # File
9//! `crates/axonml-nn/src/layers/linear.rs`
10//!
11//! # Author
12//! Andrew Jewell Sr. — AutomataNexus LLC
13//! ORCID: 0009-0005-2158-7060
14//!
15//! # Updated
16//! April 14, 2026 11:15 PM EST
17//!
18//! # Disclaimer
19//! Use at own risk. This software is provided "as is", without warranty of any
20//! kind, express or implied. The author and AutomataNexus shall not be held
21//! liable for any damages arising from the use of this software.
22
23use std::collections::HashMap;
24
25use axonml_autograd::Variable;
26use axonml_tensor::Tensor;
27
28use crate::init::{kaiming_uniform, zeros};
29use crate::module::Module;
30use crate::parameter::Parameter;
31
32// =============================================================================
33// Linear
34// =============================================================================
35
36/// Applies a linear transformation to the input.
37///
38/// y = xW^T + b
39///
40/// # Arguments
41/// * `in_features` - Size of each input sample
42/// * `out_features` - Size of each output sample
43/// * `bias` - If true, adds a learnable bias (default: true)
44///
45/// # Shape
46/// - Input: (*, in_features) where * means any number of dimensions
47/// - Output: (*, out_features)
48///
49/// # Example
50/// ```ignore
51/// let linear = Linear::new(20, 30);
52/// let input = Variable::new(randn(&[128, 20]), true);
53/// let output = linear.forward(&input);  // Shape: [128, 30]
54/// ```
55pub struct Linear {
56    /// Weight matrix of shape (out_features, in_features).
57    pub weight: Parameter,
58    /// Bias vector of shape (out_features).
59    pub bias: Option<Parameter>,
60    /// Input features.
61    in_features: usize,
62    /// Output features.
63    out_features: usize,
64}
65
66impl Linear {
67    /// Creates a new Linear layer with bias.
68    pub fn new(in_features: usize, out_features: usize) -> Self {
69        Self::with_bias(in_features, out_features, true)
70    }
71
72    /// Creates a new Linear layer with optional bias.
73    pub fn with_bias(in_features: usize, out_features: usize, bias: bool) -> Self {
74        // Initialize weights using Kaiming uniform
75        let weight_data = kaiming_uniform(out_features, in_features);
76        let weight = Parameter::named("weight", weight_data, true);
77
78        let bias_param = if bias {
79            // Initialize bias to zeros
80            let bias_data = zeros(&[out_features]);
81            Some(Parameter::named("bias", bias_data, true))
82        } else {
83            None
84        };
85
86        Self {
87            weight,
88            bias: bias_param,
89            in_features,
90            out_features,
91        }
92    }
93
94    /// Creates a Linear layer from existing weight and bias tensors.
95    pub fn from_weights(weight: Tensor<f32>, bias: Option<Tensor<f32>>) -> Self {
96        let out_features = weight.shape()[0];
97        let in_features = weight.shape()[1];
98
99        Self {
100            weight: Parameter::named("weight", weight, true),
101            bias: bias.map(|b| Parameter::named("bias", b, true)),
102            in_features,
103            out_features,
104        }
105    }
106
107    /// Returns the input feature dimension.
108    pub fn in_features(&self) -> usize {
109        self.in_features
110    }
111
112    /// Returns the output feature dimension.
113    pub fn out_features(&self) -> usize {
114        self.out_features
115    }
116}
117
118impl Module for Linear {
119    fn forward(&self, input: &Variable) -> Variable {
120        // Get input shape
121        let input_shape = input.shape();
122        let batch_dims: Vec<usize> = input_shape[..input_shape.len() - 1].to_vec();
123
124        // Reshape to 2D: (batch, in_features)
125        let total_batch: usize = batch_dims.iter().product();
126        let input_2d = if input_shape.len() > 2 {
127            // Use autograd-tracked reshape to maintain gradient flow
128            input.reshape(&[total_batch, self.in_features])
129        } else {
130            input.clone()
131        };
132
133        // y = xW^T
134        // x: (batch, in_features), W: (out_features, in_features)
135        // We need x @ W^T = (batch, out_features)
136        let weight_var = self.weight.variable();
137        // Use autograd-tracked transpose to maintain gradient flow
138        let weight_t = weight_var.transpose(0, 1);
139        let mut output = input_2d.matmul(&weight_t);
140
141        // Add bias if present
142        if let Some(ref bias) = self.bias {
143            let bias_var = bias.variable();
144            output = output.add_var(&bias_var);
145        }
146
147        // Reshape back to original batch dimensions
148        if batch_dims.len() > 1 || (batch_dims.len() == 1 && input_shape.len() > 2) {
149            let mut output_shape: Vec<usize> = batch_dims.clone();
150            output_shape.push(self.out_features);
151            output.reshape(&output_shape)
152        } else {
153            output
154        }
155    }
156
157    fn parameters(&self) -> Vec<Parameter> {
158        let mut params = vec![self.weight.clone()];
159        if let Some(ref bias) = self.bias {
160            params.push(bias.clone());
161        }
162        params
163    }
164
165    fn named_parameters(&self) -> HashMap<String, Parameter> {
166        let mut params = HashMap::new();
167        params.insert("weight".to_string(), self.weight.clone());
168        if let Some(ref bias) = self.bias {
169            params.insert("bias".to_string(), bias.clone());
170        }
171        params
172    }
173
174    fn name(&self) -> &'static str {
175        "Linear"
176    }
177}
178
179impl std::fmt::Debug for Linear {
180    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
181        f.debug_struct("Linear")
182            .field("in_features", &self.in_features)
183            .field("out_features", &self.out_features)
184            .field("bias", &self.bias.is_some())
185            .finish()
186    }
187}
188
189// =============================================================================
190// Tests
191// =============================================================================
192
#[cfg(test)]
mod tests {
    //! Unit tests for `Linear`: construction, forward output shapes
    //! (2-D and N-D inputs), `from_weights`, and parameter bookkeeping.

    use super::*;

    #[test]
    fn test_linear_creation() {
        let linear = Linear::new(10, 5);
        assert_eq!(linear.in_features(), 10);
        assert_eq!(linear.out_features(), 5);
        assert!(linear.bias.is_some());
    }

    #[test]
    fn test_linear_no_bias() {
        let linear = Linear::with_bias(10, 5, false);
        assert!(linear.bias.is_none());
    }

    #[test]
    fn test_linear_forward() {
        let linear = Linear::new(3, 2);

        let input = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0], &[1, 3]).unwrap(),
            false,
        );
        let output = linear.forward(&input);

        assert_eq!(output.shape(), vec![1, 2]);
    }

    #[test]
    fn test_linear_batch_forward() {
        let linear = Linear::new(4, 2);

        let input = Variable::new(Tensor::from_vec(vec![1.0; 12], &[3, 4]).unwrap(), false);
        let output = linear.forward(&input);

        assert_eq!(output.shape(), vec![3, 2]);
    }

    // Exercises the flatten -> matmul -> unflatten path taken for inputs with
    // more than two dimensions.
    #[test]
    fn test_linear_nd_forward() {
        let linear = Linear::new(4, 2);

        let input = Variable::new(Tensor::from_vec(vec![1.0; 24], &[2, 3, 4]).unwrap(), false);
        let output = linear.forward(&input);

        assert_eq!(output.shape(), vec![2, 3, 2]);
    }

    #[test]
    fn test_linear_from_weights() {
        // Weight is (out_features, in_features) = (2, 3).
        let weight = Tensor::from_vec(vec![0.5; 6], &[2, 3]).unwrap();
        let linear = Linear::from_weights(weight, None);

        assert_eq!(linear.in_features(), 3);
        assert_eq!(linear.out_features(), 2);
        assert!(linear.bias.is_none());
        assert_eq!(linear.parameters().len(), 1);
    }

    // A 1-D weight cannot describe (out_features, in_features).
    #[test]
    #[should_panic]
    fn test_linear_from_weights_rejects_non_2d() {
        let weight = Tensor::from_vec(vec![0.5; 3], &[3]).unwrap();
        let _ = Linear::from_weights(weight, None);
    }

    #[test]
    fn test_linear_parameters() {
        let linear = Linear::new(10, 5);
        let params = linear.parameters();
        assert_eq!(params.len(), 2); // weight + bias

        let linear_no_bias = Linear::with_bias(10, 5, false);
        let params_no_bias = linear_no_bias.parameters();
        assert_eq!(params_no_bias.len(), 1); // weight only
    }

    #[test]
    fn test_linear_named_parameters() {
        let named = Linear::new(10, 5).named_parameters();
        assert_eq!(named.len(), 2);
        assert!(named.contains_key("weight"));
        assert!(named.contains_key("bias"));

        let named_no_bias = Linear::with_bias(10, 5, false).named_parameters();
        assert_eq!(named_no_bias.len(), 1);
        assert!(named_no_bias.contains_key("weight"));
        assert!(!named_no_bias.contains_key("bias"));
    }

    #[test]
    fn test_linear_num_parameters() {
        let linear = Linear::new(10, 5);
        // weight: 10*5 = 50, bias: 5, total: 55
        assert_eq!(linear.num_parameters(), 55);
    }
}
251}