use std::collections::HashMap;
use crate::autograd::AutogradError;
use crate::nn::Parameter;
use crate::tensor::Tensor;
/// A model component that owns parameters and can be saved to / restored from
/// a flat map of named tensors.
///
/// Implementors must provide [`Module::parameters`], [`Module::state_dict`] and
/// [`Module::load_state_dict`]; the mode switches [`Module::train`] and
/// [`Module::eval`] default to doing nothing.
pub trait Module {
    /// Returns the parameters held by this module.
    ///
    /// NOTE(review): values are returned by value; whether they alias the
    /// module's own storage depends on `Parameter`'s semantics — confirm.
    fn parameters(&self) -> Vec<Parameter>;

    /// Builds a map from parameter/state names to tensors.
    ///
    /// `prefix` is passed through to the implementor; presumably it is
    /// prepended to every key so nested modules get unique names — this is
    /// an implementor contract, confirm against existing impls.
    fn state_dict(&self, prefix: &str) -> HashMap<String, Tensor>;

    /// Restores this module's state from `dict`, using the same `prefix`
    /// convention as [`Module::state_dict`].
    ///
    /// # Errors
    ///
    /// Returns an [`AutogradError`] when the implementor cannot load the
    /// state (the exact conditions are defined by each implementation).
    fn load_state_dict(
        &mut self,
        dict: &HashMap<String, Tensor>,
        prefix: &str,
    ) -> Result<(), AutogradError>;

    /// Switches the module into training mode.
    ///
    /// The provided implementation is a no-op; modules with mode-dependent
    /// behavior override it.
    fn train(&mut self) {
        // Intentionally empty: stateless modules need no mode switch.
    }

    /// Switches the module into evaluation mode.
    ///
    /// The provided implementation is a no-op; modules with mode-dependent
    /// behavior override it.
    fn eval(&mut self) {
        // Intentionally empty: stateless modules need no mode switch.
    }
}
/// Extension trait for moving a [`Module`]'s parameters onto the GPU.
///
/// Only compiled when the `gpu` cargo feature is enabled.
#[cfg(feature = "gpu")]
pub trait ModuleToGpu: Module {
    /// Invokes `to_gpu()` on the `tensor` of every parameter returned by
    /// [`Module::parameters`], in the order that method yields them.
    ///
    /// NOTE(review): the result of each `Tensor::to_gpu` call is not
    /// inspected here; if that call can fail, the failure is not surfaced —
    /// confirm against `Tensor::to_gpu`'s signature.
    fn to_gpu(&self) {
        self.parameters().into_iter().for_each(|param| {
            param.tensor.to_gpu();
        });
    }
}
/// Blanket implementation: every [`Module`] automatically gets `to_gpu`.
///
/// The `?Sized` relaxation matters: without it the generic `T` is implicitly
/// `Sized`, so trait objects such as `dyn Module` (used via `&dyn Module` or
/// `Box<dyn Module>`) would not receive this impl. `Module` is object-safe,
/// so `dyn Module: Module` holds and the impl now covers it as well.
#[cfg(feature = "gpu")]
impl<T: Module + ?Sized> ModuleToGpu for T {}