ferrite/network/
module.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
use crate::autograd::tensor::*;
use std::collections::HashMap;
use std::rc::Rc;
use std::cell::RefCell;

use super::parameter::*;

/// A computation step that maps an input [`Tensor`] to an output [`Tensor`].
pub trait Segment {
  /// Consumes `input` and returns the resulting tensor.
  ///
  /// NOTE(review): this is an associated function without a `self` receiver,
  /// so an implementation cannot read per-instance state through it —
  /// confirm that is intended.
  fn forward(input: Tensor) -> Tensor;
}

/// A named collection of parameters and nested child modules.
///
/// Parameters and children are stored behind `Rc<RefCell<..>>`, i.e. they are
/// reference-counted and borrow-checked at runtime.
pub struct Module {
  // Parameters registered directly on this module, keyed by name.
  parameters: HashMap<String, Rc<RefCell<Parameter>>>,
  // Child modules registered on this module, keyed by name.
  modules: HashMap<String, Rc<RefCell<Module>>>,
  // Training-mode flag; `Module::new` initialises it to `false`.
  training: bool
}

impl Module {
  pub fn new() -> Self {
    Module {
      parameters: HashMap::new(),
      modules: HashMap::new(),
      training: false,
    }
  }

  pub fn add_parameter(&mut self, name: &str, parameter: Parameter) {
    self.parameters.insert(name.to_string(), Rc::new(RefCell::new(parameter)));
  }

  pub fn add_module(&mut self, name: &str, module: Module) {
    self.modules.insert(name.to_string(), Rc::new(RefCell::new(module)));
  }

  pub fn visit_parameters<F>(&self, mut f: F)
  where
    F: FnMut(&Parameter)
  {
    // Visit parameters in current module
    for (name, param) in &self.parameters {
      f(&*param.borrow());
    }
    
    // Recursively visit child modules
    for (name, module) in &self.modules {
      module.borrow().visit_parameters(&mut f);
    }
  }

  pub fn train(&self) {

  }

  pub fn eval(&self) {

  }

  pub fn zero_grad(&self) {

  }
}