ferrite/network/module/
module.rs1use crate::tensor::*;
2use std::collections::HashMap;
3use std::sync::{Arc, RwLock};
4
5
6
7pub trait Module {
8 fn forward(&mut self, input: &Tensor) -> Tensor;
9
10 fn parameters(&self) -> HashMap<String, Arc<RwLock<Tensor>>> {
12 HashMap::new()
13 }
14
15 fn train(&mut self) { }
16 fn eval(&mut self) { }
17 fn zero_grad(&mut self) { }
18
19 fn visit_parameters(&self, f: &mut dyn FnMut(&str, &Tensor)) {
21 for (name, param) in self.parameters() {
23 if let Ok(tensor) = param.read() {
24 f(&name, &tensor);
25 }
26 }
27 }
28
29 fn print_parameters(&self, values: bool) {
31 self.visit_parameters(&mut |name, param| {
32 if values {
33 println!("Parameter {}: shape={:?}. values={:?}", name, param.shape(), param);
34 } else {
35 println!("Parameter {}: shape={:?}", name, param.shape());
36 }
37 });
38 }
39}