use crate::tensor::*;
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
/// A component that maps an input [`Tensor`] to an output [`Tensor`] and may
/// expose named, shared, lock-protected parameters.
///
/// Only [`Module::forward`] is required; every other method has a default.
pub trait Module {
    /// Runs the module on `input` and returns the resulting tensor.
    fn forward(&mut self, input: &Tensor) -> Tensor;

    /// Returns the module's parameters keyed by name.
    ///
    /// The default returns an empty map (a module with no parameters).
    fn parameters(&self) -> HashMap<String, Arc<RwLock<Tensor>>> {
        HashMap::new()
    }

    /// Switches the module to training mode. The default is a no-op.
    fn train(&mut self) {}

    /// Switches the module to evaluation mode. The default is a no-op.
    fn eval(&mut self) {}

    /// Clears accumulated gradients. The default is a no-op.
    fn zero_grad(&mut self) {}

    /// Calls `f(name, tensor)` once for every entry of [`Module::parameters`],
    /// in ascending order of parameter name.
    ///
    /// `f` runs while a read lock on that parameter is held, so it must not try
    /// to write-lock the same parameter (`RwLock` would deadlock or panic).
    fn visit_parameters(&self, f: &mut dyn FnMut(&str, &Tensor)) {
        // HashMap iteration order is unspecified; sort so output is reproducible.
        let mut entries: Vec<_> = self.parameters().into_iter().collect();
        // Keys of a HashMap are unique, so an unstable sort is fine.
        entries.sort_unstable_by(|(a, _), (b, _)| a.cmp(b));
        for (name, param) in entries {
            // A poisoned lock still guards valid data; recover it rather than
            // silently dropping the parameter from the visit.
            let tensor = param
                .read()
                .unwrap_or_else(std::sync::PoisonError::into_inner);
            f(&name, &tensor);
        }
    }

    /// Prints each parameter's name and shape to stdout, one per line, in
    /// name order. When `values` is true the tensor's `Debug` output is also
    /// printed.
    fn print_parameters(&self, values: bool) {
        self.visit_parameters(&mut |name, param| {
            if values {
                println!(
                    "Parameter {}: shape={:?}, values={:?}",
                    name,
                    param.shape(),
                    param
                );
            } else {
                println!("Parameter {}: shape={:?}", name, param.shape());
            }
        });
    }
}