ghostflow_nn/
linear.rs

//! Linear (fully connected) layer

use ghostflow_core::Tensor;
use crate::module::Module;
use crate::init;
/// Linear transformation: y = xW^T + b
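///
/// The transformation is applied over the last dimension of the input. A
/// minimal usage sketch (shapes and method names follow the tests at the
/// bottom of this file):
///
/// ```ignore
/// let layer = Linear::new(784, 128);
/// let input = Tensor::randn(&[32, 784]);
/// let output = layer.forward(&input);
/// assert_eq!(output.dims(), &[32, 128]);
/// ```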
pub struct Linear {
    weight: Tensor,
    bias: Option<Tensor>,
    in_features: usize,
    out_features: usize,
    training: bool,
}

impl Linear {
    /// Create a new linear layer with a bias term
    pub fn new(in_features: usize, out_features: usize) -> Self {
        Self::with_bias(in_features, out_features, true)
    }

    /// Create a linear layer with optional bias
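    ///
    /// Disabling the bias is common when the layer feeds into a
    /// normalization layer that would cancel it out:
    ///
    /// ```ignore
    /// let layer = Linear::with_bias(256, 128, false);
    /// assert_eq!(layer.parameters().len(), 1); // weight only
    /// ```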
    pub fn with_bias(in_features: usize, out_features: usize, bias: bool) -> Self {
        // Kaiming (He) initialization for the weights; the fan_in argument
        // scales the range to keep activation variance stable across layers
        let weight = init::kaiming_uniform(&[out_features, in_features], in_features);

        let bias = if bias {
            // Bias drawn uniformly from [-1/sqrt(fan_in), 1/sqrt(fan_in)]
            let bound = 1.0 / (in_features as f32).sqrt();
            Some(init::uniform(&[out_features], -bound, bound))
        } else {
            None
        };

        Linear {
            weight,
            bias,
            in_features,
            out_features,
            training: true,
        }
    }

    /// Returns the number of input features
    pub fn in_features(&self) -> usize {
        self.in_features
    }

    /// Returns the number of output features
    pub fn out_features(&self) -> usize {
        self.out_features
    }
}

impl Module for Linear {
    fn forward(&self, input: &Tensor) -> Tensor {
        // input:  [*, in_features]
        // weight: [out_features, in_features]
        // output: [*, out_features]

        // Transpose the weight to [in_features, out_features] so the matmul
        // contracts over in_features
        let weight_t = self.weight.t().unwrap();
        let mut output = input.matmul(&weight_t).unwrap();

        // The [out_features] bias broadcasts over the leading dimensions
        if let Some(ref bias) = self.bias {
            output = output.add(bias).unwrap();
        }

        output
    }

    fn parameters(&self) -> Vec<Tensor> {
        let mut params = vec![self.weight.clone()];
        if let Some(ref bias) = self.bias {
            params.push(bias.clone());
        }
        params
    }

    fn train(&mut self) {
        self.training = true;
    }

    fn eval(&mut self) {
        self.training = false;
    }

    fn is_training(&self) -> bool {
        self.training
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_linear_forward() {
        let linear = Linear::new(10, 5);
        let input = Tensor::randn(&[2, 10]);
        let output = linear.forward(&input);

        assert_eq!(output.dims(), &[2, 5]);
    }

    #[test]
    fn test_linear_no_bias() {
        let linear = Linear::with_bias(10, 5, false);
        let input = Tensor::randn(&[2, 10]);
        let output = linear.forward(&input);

        assert_eq!(output.dims(), &[2, 5]);
    }

    #[test]
    fn test_linear_parameters() {
        let linear = Linear::new(10, 5);
        let params = linear.parameters();

        assert_eq!(params.len(), 2); // weight + bias
        assert_eq!(params[0].numel(), 50); // 10 * 5
        assert_eq!(params[1].numel(), 5);
    }
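
    // With bias disabled, parameters() should report only the weight tensor;
    // this mirrors test_linear_parameters for the no-bias constructor.
    #[test]
    fn test_linear_no_bias_parameters() {
        let linear = Linear::with_bias(10, 5, false);
        let params = linear.parameters();

        assert_eq!(params.len(), 1); // weight only
        assert_eq!(params[0].numel(), 50); // 10 * 5
    }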
}