use super::module::Module;
use crate::autograd::Tensor;
use std::collections::HashMap;
/// Container that runs its child modules one after another.
///
/// The `Module::forward` implementation in this file feeds the input through
/// every stored module in insertion order, handing each output to the next.
pub struct Sequential {
/// Child modules, in the order they were added.
modules: Vec<Box<dyn Module>>,
/// Training-mode flag: `true` after `new`/`train`, `false` after `eval`.
training: bool,
}
impl Sequential {
    /// Creates an empty container that starts out in training mode.
    #[must_use]
    pub fn new() -> Self {
        Self {
            training: true,
            modules: Vec::new(),
        }
    }

    /// Boxes `module`, appends it to the end of the chain and hands the
    /// container back so calls can be chained builder-style.
    #[allow(clippy::should_implement_trait)]
    pub fn add<M: Module + 'static>(self, module: M) -> Self {
        self.add_boxed(Box::new(module))
    }

    /// Appends an already boxed module to the end of the chain and returns
    /// the container.
    #[must_use]
    pub fn add_boxed(self, module: Box<dyn Module>) -> Self {
        let mut chain = self;
        chain.modules.push(module);
        chain
    }

    /// Number of modules currently stored.
    #[must_use]
    pub fn len(&self) -> usize {
        let Self { modules, .. } = self;
        modules.len()
    }

    /// `true` when no module has been added yet.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        let Self { modules, .. } = self;
        modules.is_empty()
    }
}
impl Default for Sequential {
fn default() -> Self {
Self::new()
}
}
impl Module for Sequential {
    /// Runs `input` through every child in insertion order.
    ///
    /// The chain is seeded with a clone of `input`, so an empty container
    /// returns that clone unchanged.
    fn forward(&self, input: &Tensor) -> Tensor {
        let mut activation = input.clone();
        for layer in &self.modules {
            activation = layer.forward(&activation);
        }
        activation
    }

    /// Collects the parameters of all children, child by child, in
    /// insertion order.
    fn parameters(&self) -> Vec<&Tensor> {
        let mut collected = Vec::new();
        for layer in &self.modules {
            collected.extend(layer.parameters());
        }
        collected
    }

    /// Mutable counterpart of `parameters`, in the same child order.
    fn parameters_mut(&mut self) -> Vec<&mut Tensor> {
        let mut collected = Vec::new();
        for layer in &mut self.modules {
            collected.extend(layer.parameters_mut());
        }
        collected
    }

    /// Marks the container as training and forwards `train` to every child.
    fn train(&mut self) {
        self.training = true;
        self.modules.iter_mut().for_each(|layer| layer.train());
    }

    /// Marks the container as evaluating and forwards `eval` to every child.
    fn eval(&mut self) {
        self.training = false;
        self.modules.iter_mut().for_each(|layer| layer.eval());
    }

    /// Returns the container's own training flag.
    fn training(&self) -> bool {
        self.training
    }
}
impl std::fmt::Debug for Sequential {
    /// Prints the module count and the training flag; the boxed modules
    /// themselves are not printed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let num_modules = self.modules.len();
        let mut out = f.debug_struct("Sequential");
        out.field("num_modules", &num_modules);
        out.field("training", &self.training);
        out.finish()
    }
}
/// Ordered collection of modules addressed by position.
///
/// This file gives it index-based access, iteration, parameter collection
/// and train/eval switching, but no `Module` implementation of its own.
pub struct ModuleList {
/// Stored modules, in the order they were added.
modules: Vec<Box<dyn Module>>,
/// Training-mode flag: `true` after `new`/`train`, `false` after `eval`.
training: bool,
}
impl ModuleList {
    /// Creates an empty list that starts out in training mode.
    #[must_use]
    pub fn new() -> Self {
        Self {
            training: true,
            modules: Vec::new(),
        }
    }

    /// Boxes `module`, appends it and returns the list for chaining.
    #[allow(clippy::should_implement_trait)]
    pub fn add<M: Module + 'static>(self, module: M) -> Self {
        self.add_boxed(Box::new(module))
    }

    /// Appends an already boxed module and returns the list for chaining.
    #[must_use]
    pub fn add_boxed(self, module: Box<dyn Module>) -> Self {
        let mut list = self;
        list.modules.push(module);
        list
    }

    /// Borrows the module at `index`, or `None` when `index` is out of range.
    pub fn get(&self, index: usize) -> Option<&dyn Module> {
        let boxed = self.modules.get(index)?;
        Some(boxed.as_ref())
    }

    /// Mutably borrows the boxed module at `index`, or `None` when `index`
    /// is out of range.
    pub fn get_mut(&mut self, index: usize) -> Option<&mut Box<dyn Module>> {
        let slots = &mut self.modules;
        slots.get_mut(index)
    }

    /// Number of modules currently stored.
    #[must_use]
    pub fn len(&self) -> usize {
        let Self { modules, .. } = self;
        modules.len()
    }

    /// `true` when the list holds no modules.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        let Self { modules, .. } = self;
        modules.is_empty()
    }

    /// Iterates over the stored modules in insertion order.
    pub fn iter(&self) -> impl Iterator<Item = &dyn Module> {
        self.modules.iter().map(|boxed| boxed.as_ref())
    }
}
impl Default for ModuleList {
fn default() -> Self {
Self::new()
}
}
impl ModuleList {
    /// Collects the parameters of every stored module, module by module, in
    /// insertion order.
    #[must_use]
    pub fn parameters(&self) -> Vec<&Tensor> {
        self.modules.iter().flat_map(|m| m.parameters()).collect()
    }

    /// Mutable counterpart of [`ModuleList::parameters`], in the same order.
    pub fn parameters_mut(&mut self) -> Vec<&mut Tensor> {
        self.modules
            .iter_mut()
            .flat_map(|m| m.parameters_mut())
            .collect()
    }

    /// Marks the list as training and forwards `train` to every module.
    pub fn train(&mut self) {
        self.training = true;
        for module in &mut self.modules {
            module.train();
        }
    }

    /// Marks the list as evaluating and forwards `eval` to every module.
    pub fn eval(&mut self) {
        self.training = false;
        for module in &mut self.modules {
            module.eval();
        }
    }

    /// Returns the list's own training flag.
    ///
    /// `train` and `eval` already maintain this flag; the accessor makes it
    /// readable, matching the `training` accessors on the other containers
    /// in this file.
    #[must_use]
    pub fn training(&self) -> bool {
        self.training
    }
}
impl std::fmt::Debug for ModuleList {
    /// Prints the module count and the training flag; the boxed modules
    /// themselves are not printed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let num_modules = self.modules.len();
        let mut out = f.debug_struct("ModuleList");
        out.field("num_modules", &num_modules);
        out.field("training", &self.training);
        out.finish()
    }
}
/// Collection of modules addressed by string name.
///
/// Lookup goes through the hash map; the separate `keys` vector remembers
/// the order in which names were first inserted, and `keys`, `iter` and
/// `parameters` walk that vector.
pub struct ModuleDict {
/// Name-to-module storage used for lookups.
modules: HashMap<String, Box<dyn Module>>,
/// Names in first-insertion order; kept in sync with `modules` by
/// `insert`, `insert_boxed` and `remove`.
keys: Vec<String>,
/// Training-mode flag: `true` after `new`/`train`, `false` after `eval`.
training: bool,
}
impl ModuleDict {
    /// Creates an empty dictionary that starts out in training mode.
    #[must_use]
    pub fn new() -> Self {
        Self {
            modules: HashMap::new(),
            keys: Vec::new(),
            training: true,
        }
    }

    /// Boxes `module` and stores it under `name`, returning the dictionary
    /// for chaining.
    ///
    /// See [`ModuleDict::insert_boxed`] for how an existing name is handled.
    pub fn insert<S: Into<String>, M: Module + 'static>(self, name: S, module: M) -> Self {
        self.insert_boxed(name, Box::new(module))
    }

    /// Stores an already boxed module under `name`, returning the dictionary
    /// for chaining.
    ///
    /// A new name is appended to the insertion order. Re-using an existing
    /// name replaces the stored module but keeps the name's original
    /// position.
    pub fn insert_boxed<S: Into<String>>(mut self, name: S, module: Box<dyn Module>) -> Self {
        // The entry API resolves the name with a single hash lookup and only
        // clones it when it is genuinely new.
        match self.modules.entry(name.into()) {
            std::collections::hash_map::Entry::Occupied(mut slot) => {
                slot.insert(module);
            }
            std::collections::hash_map::Entry::Vacant(slot) => {
                self.keys.push(slot.key().clone());
                slot.insert(module);
            }
        }
        self
    }

    /// Borrows the module stored under `name`, if any.
    pub fn get(&self, name: &str) -> Option<&dyn Module> {
        self.modules.get(name).map(AsRef::as_ref)
    }

    /// Mutably borrows the boxed module stored under `name`, if any.
    pub fn get_mut(&mut self, name: &str) -> Option<&mut Box<dyn Module>> {
        self.modules.get_mut(name)
    }

    /// `true` when a module is stored under `name`.
    #[must_use]
    pub fn contains(&self, name: &str) -> bool {
        self.modules.contains_key(name)
    }

    /// Removes and returns the module stored under `name`, also dropping the
    /// name from the insertion order. Returns `None` for an unknown name.
    pub fn remove(&mut self, name: &str) -> Option<Box<dyn Module>> {
        let removed = self.modules.remove(name)?;
        self.keys.retain(|k| k != name);
        Some(removed)
    }

    /// Number of modules currently stored.
    #[must_use]
    pub fn len(&self) -> usize {
        self.modules.len()
    }

    /// `true` when the dictionary holds no modules.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.modules.is_empty()
    }

    /// Iterates over the names in insertion order.
    pub fn keys(&self) -> impl Iterator<Item = &str> {
        self.keys.iter().map(String::as_str)
    }

    /// Iterates over `(name, module)` pairs in insertion order.
    ///
    /// # Panics
    ///
    /// Panics if a recorded name has no entry in the map. `insert`,
    /// `insert_boxed` and `remove` keep the two in sync, so this would
    /// indicate a broken internal invariant.
    pub fn iter(&self) -> impl Iterator<Item = (&str, &dyn Module)> {
        self.keys.iter().map(|k| {
            let module = self
                .modules
                .get(k)
                .map(AsRef::as_ref)
                .expect("key must exist");
            (k.as_str(), module)
        })
    }

    /// Collects the parameters of every module, module by module, in
    /// insertion order.
    #[must_use]
    pub fn parameters(&self) -> Vec<&Tensor> {
        self.keys
            .iter()
            .filter_map(|k| self.modules.get(k))
            .flat_map(|m| m.parameters())
            .collect()
    }

    /// Mutable counterpart of [`ModuleDict::parameters`].
    ///
    /// The tensors come back in the same insertion order as `parameters`,
    /// so positions in the two vectors refer to the same tensor. Walking the
    /// hash map directly would yield an arbitrary order.
    pub fn parameters_mut(&mut self) -> Vec<&mut Tensor> {
        // Take every mutable borrow out of the map once, keyed by name, so
        // they can then be handed out following `keys`.
        let mut by_name: HashMap<&str, &mut Box<dyn Module>> = self
            .modules
            .iter_mut()
            .map(|(name, module)| (name.as_str(), module))
            .collect();
        self.keys
            .iter()
            .filter_map(|name| by_name.remove(name.as_str()))
            .flat_map(|module| module.parameters_mut())
            .collect()
    }

    /// Marks the dictionary as training and forwards `train` to every module.
    pub fn train(&mut self) {
        self.training = true;
        for module in self.modules.values_mut() {
            module.train();
        }
    }

    /// Marks the dictionary as evaluating and forwards `eval` to every module.
    pub fn eval(&mut self) {
        self.training = false;
        for module in self.modules.values_mut() {
            module.eval();
        }
    }

    /// Returns the dictionary's own training flag.
    #[must_use]
    pub fn training(&self) -> bool {
        self.training
    }
}
impl Default for ModuleDict {
fn default() -> Self {
Self::new()
}
}
impl std::fmt::Debug for ModuleDict {
    /// Prints the names (in insertion order) and the training flag; the
    /// stored modules are omitted, which the trailing `..` signals.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("ModuleDict");
        out.field("keys", &self.keys);
        out.field("training", &self.training);
        out.finish_non_exhaustive()
    }
}
#[cfg(test)]
#[path = "container_tests.rs"]
mod tests;