1use ghostflow_core::Tensor;
4use crate::module::Module;
5use crate::init;
6
7pub struct Linear {
9 weight: Tensor,
10 bias: Option<Tensor>,
11 in_features: usize,
12 out_features: usize,
13 training: bool,
14}
15
16impl Linear {
17 pub fn new(in_features: usize, out_features: usize) -> Self {
19 Self::with_bias(in_features, out_features, true)
20 }
21
22 pub fn with_bias(in_features: usize, out_features: usize, bias: bool) -> Self {
24 let weight = init::kaiming_uniform(&[out_features, in_features], in_features);
26
27 let bias = if bias {
28 let bound = 1.0 / (in_features as f32).sqrt();
30 Some(init::uniform(&[out_features], -bound, bound))
31 } else {
32 None
33 };
34
35 Linear {
36 weight,
37 bias,
38 in_features,
39 out_features,
40 training: true,
41 }
42 }
43
44 pub fn in_features(&self) -> usize {
46 self.in_features
47 }
48
49 pub fn out_features(&self) -> usize {
51 self.out_features
52 }
53}
54
55impl Module for Linear {
56 fn forward(&self, input: &Tensor) -> Tensor {
57 let weight_t = self.weight.t().unwrap();
62 let mut output = input.matmul(&weight_t).unwrap();
63
64 if let Some(ref bias) = self.bias {
65 output = output.add(bias).unwrap();
66 }
67
68 output
69 }
70
71 fn parameters(&self) -> Vec<Tensor> {
72 let mut params = vec![self.weight.clone()];
73 if let Some(ref bias) = self.bias {
74 params.push(bias.clone());
75 }
76 params
77 }
78
79 fn train(&mut self) {
80 self.training = true;
81 }
82
83 fn eval(&mut self) {
84 self.training = false;
85 }
86
87 fn is_training(&self) -> bool {
88 self.training
89 }
90}
91
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_linear_forward() {
        let linear = Linear::new(10, 5);
        let input = Tensor::randn(&[2, 10]);
        let output = linear.forward(&input);

        assert_eq!(output.dims(), &[2, 5]);
    }

    #[test]
    fn test_linear_no_bias() {
        let linear = Linear::with_bias(10, 5, false);
        let input = Tensor::randn(&[2, 10]);
        let output = linear.forward(&input);

        assert_eq!(output.dims(), &[2, 5]);
        // Without a bias, only the weight is a learnable parameter.
        assert_eq!(linear.parameters().len(), 1);
    }

    #[test]
    fn test_linear_parameters() {
        let linear = Linear::new(10, 5);
        let params = linear.parameters();

        assert_eq!(params.len(), 2);
        assert_eq!(params[0].numel(), 50);
        assert_eq!(params[1].numel(), 5);
    }

    #[test]
    fn test_linear_feature_accessors() {
        let linear = Linear::new(7, 3);

        assert_eq!(linear.in_features(), 7);
        assert_eq!(linear.out_features(), 3);
    }

    #[test]
    fn test_linear_train_eval_toggle() {
        let mut linear = Linear::new(4, 2);
        assert!(linear.is_training());

        linear.eval();
        assert!(!linear.is_training());

        linear.train();
        assert!(linear.is_training());
    }
}