// ferrite/network/module/module.rs

use crate::tensor::*;
use std::collections::HashMap;
use std::sync::{Arc, PoisonError, RwLock};

7pub trait Module {
8  fn forward(&mut self, input: &Tensor) -> Tensor;
9  
10  // Optional methods with defaults
11  fn parameters(&self) -> HashMap<String, Arc<RwLock<Tensor>>> {
12    HashMap::new()
13  }
14  
15  fn train(&mut self) { }
16  fn eval(&mut self) { }
17  fn zero_grad(&mut self) { }
18
19  /// Visit all parameters with a callback function
20  fn visit_parameters(&self, f: &mut dyn FnMut(&str, &Tensor)) {
21    // Default implementation uses parameters()
22    for (name, param) in self.parameters() {
23      if let Ok(tensor) = param.read() {
24        f(&name, &tensor);
25      }
26    }
27  }
28
29  /// Print all parameters and their shapes
30  fn print_parameters(&self, values: bool) {
31    self.visit_parameters(&mut |name, param| {
32      if values {
33        println!("Parameter {}: shape={:?}. values={:?}", name, param.shape(), param);
34      } else {
35        println!("Parameter {}: shape={:?}", name, param.shape());
36      }
37    });
38  }
39}