// ferrite/network/module/linear.rs

use std::collections::HashMap;
use std::sync::{Arc, RwLock};

use super::module::*;
use crate::tensor::*;
/// A fully-connected (affine) layer: `output = input · weightᵀ (+ bias)`.
pub struct Linear {
    // Weight matrix of shape [out_features, in_features] (see `Linear::new`).
    // Wrapped in Arc<RwLock<_>> so it can be shared with an optimizer via
    // `parameters()` and mutated behind the lock.
    weight: Arc<RwLock<Tensor>>,
    // Optional bias vector of shape [out_features]; `None` when the layer is
    // constructed with `bias == false`.
    bias: Option<Arc<RwLock<Tensor>>>,
    // Train/eval mode flag, toggled by `Module::train` / `Module::eval`.
    training: bool,
}
13
14
15impl Linear {
16 pub fn new(in_features: usize, out_features: usize, bias: bool, device: Device) -> Self {
17 let bound = f32::sqrt(1./in_features as f32);
18 let weight = Arc::new(RwLock::new(
19 Tensor::ones(vec![out_features, in_features], device, Some(true))
21 ));
22
23 let bias = if bias {
24 Some(Arc::new(RwLock::new(Tensor::zeros(vec![out_features], device, Some(false)))))
25 } else {
26 None
27 };
28
29 Linear{weight, bias, training: false,}
30 }
31
32 fn visit_parameters(&self, f: &mut dyn FnMut(&str, &Tensor)) {
33 if let Ok(weight) = self.weight.read() {
35 f("weight", &weight);
36 }
37 if let Some(bias) = &self.bias {
38 if let Ok(bias) = bias.read() {
39 f("bias", &bias);
40 }
41 }
42 }
43}
44
45
46
47impl Module for Linear {
48 fn forward(&mut self, input: &Tensor) -> Tensor {
49 let weight = self.weight.read().unwrap();
51
52 let mut output = input.matmul(&weight, false, true);
54
55 if let Some(bias) = &self.bias {
57 let bias = bias.read().unwrap();
58 output = &output + &*bias;
59 }
60 output
61 }
62
63 fn parameters(&self) -> HashMap<String, Arc<RwLock<Tensor>>> {
64 let mut params = HashMap::new();
65 params.insert("weight".to_string(), self.weight.clone());
66 if let Some(bias) = &self.bias {
67 params.insert("bias".to_string(), bias.clone());
68 }
69 params
70 }
71
72 fn train(&mut self) {
73 self.training = true;
74 }
75
76 fn eval(&mut self) {
77 self.training = false;
78 }
79
80 fn zero_grad(&mut self) {
81 todo!("Implement zero_grad for Linear");
82 }
83
84}