use alloc::string::{String, ToString};
use alloc::vec::Vec;

use crate::tensor::Tensor;
use crate::Scalar;

use super::Parameter;
pub trait Module<S: Scalar> {
fn forward(&mut self, input: &Tensor<S>) -> Tensor<S>;
fn backward(&mut self, grad_output: &Tensor<S>) -> Tensor<S>;
fn parameters(&self) -> Vec<&Parameter<S>>;
fn parameters_mut(&mut self) -> Vec<&mut Parameter<S>>;
fn named_parameters(&self) -> Vec<(String, &Parameter<S>)> {
self.parameters()
.into_iter()
.enumerate()
.map(|(i, p)| (alloc::format!("{}", i), p))
.collect()
}
fn named_parameters_mut(&mut self) -> Vec<(String, &mut Parameter<S>)> {
self.parameters_mut()
.into_iter()
.enumerate()
.map(|(i, p)| (alloc::format!("{}", i), p))
.collect()
}
fn state_dict(&self) -> Vec<(String, Tensor<S>)> {
self.named_parameters()
.into_iter()
.map(|(name, param)| (name, param.data.clone()))
.collect()
}
fn load_state_dict(&mut self, state: &[(String, Tensor<S>)]) {
for (name, param) in self.named_parameters_mut() {
if let Some((_, tensor)) = state.iter().find(|(n, _)| n == &name) {
param.data = tensor.clone();
}
}
}
fn set_training(&mut self, _training: bool) {}
fn zero_grad(&mut self) {
for p in self.parameters_mut() {
p.zero_grad();
}
}
}