//! ferrite/network/module/linear.rs

1use std::collections::HashMap;
2use std::sync::{Arc, RwLock};
3
4use super::module::*;
5use crate::tensor::*;
6
// Linear layer implementation: y = x · Wᵀ (+ b), mirroring a standard
// fully connected layer.
pub struct Linear {
  // Weight matrix, shape [out_features, in_features] (see `new`). Wrapped in
  // Arc<RwLock<..>> so the layer and external consumers (e.g. an optimizer
  // holding the handle from `parameters()`) can share mutable access.
  weight: Arc<RwLock<Tensor>>,
  // Optional bias vector, shape [out_features]; None when the layer was
  // constructed with `bias = false`.
  bias: Option<Arc<RwLock<Tensor>>>,
  // Train/eval mode flag, toggled via Module::train / Module::eval.
  training: bool,
}
13
14
15impl Linear {
16  pub fn new(in_features: usize, out_features: usize, bias: bool, device: Device) -> Self {
17    let bound = f32::sqrt(1./in_features as f32);
18    let weight = Arc::new(RwLock::new(
19      //Tensor::uniform(-bound, bound, vec![out_features, in_features], Some(true))
20      Tensor::ones(vec![out_features, in_features], device, Some(true))
21    ));
22
23    let bias = if bias {
24      Some(Arc::new(RwLock::new(Tensor::zeros(vec![out_features], device, Some(false)))))
25    } else {
26      None
27    };
28
29    Linear{weight, bias, training: false,}
30  }
31
32  fn visit_parameters(&self, f: &mut dyn FnMut(&str, &Tensor)) {
33    // More efficient direct implementation than using parameters()
34    if let Ok(weight) = self.weight.read() {
35      f("weight", &weight);
36    }
37    if let Some(bias) = &self.bias {
38      if let Ok(bias) = bias.read() {
39        f("bias", &bias);
40      }
41    }
42  }
43}
44
45
46
47impl Module for Linear {
48  fn forward(&mut self, input: &Tensor) -> Tensor {
49    // Get weight parameter and access its tensor
50    let weight = self.weight.read().unwrap();
51
52    // Perform matrix multiplication
53    let mut output = input.matmul(&weight, false, true);
54
55    // Add bias if present
56    if let Some(bias) = &self.bias {
57      let bias = bias.read().unwrap();
58      output = &output + &*bias;
59    } 
60    output
61  }
62
63  fn parameters(&self) -> HashMap<String, Arc<RwLock<Tensor>>> {
64    let mut params = HashMap::new();
65    params.insert("weight".to_string(), self.weight.clone());
66    if let Some(bias) = &self.bias {
67      params.insert("bias".to_string(), bias.clone());
68    }
69    params
70  }
71
72  fn train(&mut self) {
73    self.training = true;
74  }
75
76  fn eval(&mut self) {
77    self.training = false;
78  }
79
80  fn zero_grad(&mut self) {
81    todo!("Implement zero_grad for Linear");
82  }
83
84}