use std::collections::HashMap;
use axonml_autograd::Variable;
use axonml_core::Device;
use crate::parameter::Parameter;
/// A trainable computation unit that maps one [`Variable`] to another.
///
/// Only [`Module::forward`] is required; every other method has a default
/// implementation. Implementors that own parameters or track a training flag
/// should override the corresponding defaults.
pub trait Module: Send + Sync {
    /// Runs the module on `input` and returns the resulting variable.
    fn forward(&self, input: &Variable) -> Variable;

    /// Returns this module's parameters. The default returns an empty `Vec`.
    fn parameters(&self) -> Vec<Parameter> {
        Vec::new()
    }

    /// Returns this module's parameters keyed by name.
    /// The default returns an empty map.
    fn named_parameters(&self) -> HashMap<String, Parameter> {
        HashMap::new()
    }

    /// Total element count over the parameters whose `requires_grad()` is true.
    ///
    /// Frozen parameters (`requires_grad() == false`) are NOT counted, so this
    /// is effectively the number of trainable scalars.
    fn num_parameters(&self) -> usize {
        let mut total = 0usize;
        for param in &self.parameters() {
            if param.requires_grad() {
                total += param.numel();
            }
        }
        total
    }

    /// Switches to training mode; shorthand for `set_training(true)`.
    fn train(&mut self) {
        self.set_training(true)
    }

    /// Switches to evaluation mode; shorthand for `set_training(false)`.
    fn eval(&mut self) {
        self.set_training(false)
    }

    /// Sets the training flag. The default ignores the argument and does nothing.
    fn set_training(&mut self, _training: bool) {}

    /// Reports whether the module is in training mode.
    ///
    /// The default always answers `true`. Because the default `set_training` is
    /// a no-op, a module that overrides neither method still reports `true`
    /// after `eval()`.
    fn is_training(&self) -> bool {
        true
    }

    /// Calls `zero_grad` on every parameter returned by [`Module::parameters`].
    ///
    /// NOTE(review): this only affects the module's real parameters if the
    /// returned `Parameter` values share state with them (presumably a shared
    /// handle — TODO confirm).
    fn zero_grad(&self) {
        self.parameters().into_iter().for_each(|p| {
            p.zero_grad();
        });
    }

    /// Calls `to_device(device)` on every parameter returned by
    /// [`Module::parameters`]. The same sharing caveat as `zero_grad` applies.
    fn to_device(&self, device: Device) {
        self.parameters().into_iter().for_each(|p| {
            p.to_device(device);
        });
    }

    /// Human-readable module name.
    ///
    /// The default is `std::any::type_name::<Self>()`, whose exact format the
    /// standard library does not guarantee; treat it as diagnostic text only.
    fn name(&self) -> &'static str {
        std::any::type_name::<Self>()
    }
}
/// An ordered, heterogeneous list of boxed sub-modules.
///
/// Used as a [`Module`], it applies its children in sequence.
pub struct ModuleList {
    /// Training flag reported by this list's `is_training`.
    training: bool,
    /// Child modules, in the order they are applied by `forward`.
    modules: Vec<Box<dyn Module>>,
}
impl ModuleList {
    /// Creates an empty list in training mode.
    pub fn new() -> Self {
        Self::from_vec(Vec::new())
    }

    /// Wraps already-boxed modules, keeping their order.
    ///
    /// The list itself starts in training mode. The children's own training
    /// flags are left untouched.
    pub fn from_vec(modules: Vec<Box<dyn Module>>) -> Self {
        Self {
            training: true,
            modules,
        }
    }

    /// Boxes `module` and appends it to the end of the list.
    ///
    /// The appended module's training flag is not synchronized with the
    /// list's current flag.
    pub fn push<M: Module + 'static>(&mut self, module: M) {
        let boxed: Box<dyn Module> = Box::new(module);
        self.modules.push(boxed);
    }

    /// Number of child modules.
    pub fn len(&self) -> usize {
        self.modules.len()
    }

    /// `true` when the list holds no modules.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Iterates over the children in order.
    pub fn iter(&self) -> impl Iterator<Item = &Box<dyn Module>> {
        self.modules.iter()
    }

    /// Iterates mutably over the children in order.
    pub fn iter_mut(&mut self) -> impl Iterator<Item = &mut Box<dyn Module>> {
        self.modules.iter_mut()
    }

    /// Borrows the child at `index`, or returns `None` if it is out of range.
    pub fn get(&self, index: usize) -> Option<&dyn Module> {
        self.modules.get(index).map(|boxed| &**boxed)
    }
}
impl Default for ModuleList {
fn default() -> Self {
Self::new()
}
}
impl Module for ModuleList {
    /// Feeds `input` through every child in order.
    ///
    /// Each child's output becomes the next child's input. An empty list
    /// returns a clone of `input`.
    fn forward(&self, input: &Variable) -> Variable {
        self.modules
            .iter()
            .fold(input.clone(), |acc, layer| layer.forward(&acc))
    }

    /// Concatenates the children's parameters, child by child, in list order.
    fn parameters(&self) -> Vec<Parameter> {
        let mut collected = Vec::new();
        for child in &self.modules {
            collected.extend(child.parameters());
        }
        collected
    }

    /// Children's named parameters, each key prefixed with `"{index}."`.
    ///
    /// For example, parameter `weight` of the first child becomes `0.weight`.
    fn named_parameters(&self) -> HashMap<String, Parameter> {
        self.modules
            .iter()
            .enumerate()
            .flat_map(|(idx, child)| {
                child
                    .named_parameters()
                    .into_iter()
                    .map(move |(key, param)| (format!("{idx}.{key}"), param))
            })
            .collect()
    }

    /// Records the flag on the list, then forwards it to every child.
    fn set_training(&mut self, training: bool) {
        self.training = training;
        self.modules
            .iter_mut()
            .for_each(|child| child.set_training(training));
    }

    /// Returns the flag last stored by `set_training`; it starts as `true`.
    fn is_training(&self) -> bool {
        self.training
    }

    fn name(&self) -> &'static str {
        "ModuleList"
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use axonml_tensor::Tensor;

    /// Test double: returns its input unchanged and owns no parameters.
    struct Identity;

    impl Module for Identity {
        fn name(&self) -> &'static str {
            "Identity"
        }

        fn forward(&self, input: &Variable) -> Variable {
            input.clone()
        }
    }

    #[test]
    fn test_module_list() {
        let mut modules = ModuleList::new();
        for _ in 0..2 {
            modules.push(Identity);
        }
        assert_eq!(modules.len(), 2);
    }

    #[test]
    fn test_module_list_forward() {
        let mut modules = ModuleList::new();
        modules.push(Identity);

        let data = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
        let x = Variable::new(data, false);
        let y = modules.forward(&x);

        assert_eq!(y.data().to_vec(), vec![1.0, 2.0, 3.0]);
    }
}