use super::{Module, ModuleT};
use crate::Tensor;
/// An ordered list of layers; the `Module` implementation feeds each layer
/// the output of the previous one.
#[derive(Debug)]
pub struct Sequential {
    // Layers in the order they are applied; appended by `add` / `add_fn`.
    layers: Vec<Box<dyn Module>>,
}
pub fn seq() -> Sequential {
Sequential { layers: vec![] }
}
impl Sequential {
    /// Number of layers currently in the sequence (as an `i64`).
    pub fn len(&self) -> i64 {
        let count: usize = self.layers.len();
        count as i64
    }

    /// Returns `true` when no layer has been added.
    pub fn is_empty(&self) -> bool {
        let layers = &self.layers;
        layers.is_empty()
    }
}
impl Module for Sequential {
    /// Applies every layer in order, threading each output into the next layer.
    ///
    /// With no layers, returns `xs.shallow_clone()`.
    fn forward(&self, xs: &Tensor) -> Tensor {
        let mut remaining = self.layers.iter();
        match remaining.next() {
            None => xs.shallow_clone(),
            Some(first) => {
                let head = first.forward(xs);
                remaining.fold(head, |acc, layer| layer.forward(&acc))
            }
        }
    }
}
impl Sequential {
    /// Appends `layer` to the end of the sequence and returns the sequence,
    /// so calls can be chained builder-style.
    // `add` is a builder method, not arithmetic; the name merely coincides with
    // `std::ops::Add::add`, which is what this lint would otherwise flag.
    #[allow(clippy::should_implement_trait)]
    pub fn add<M: Module + 'static>(mut self, layer: M) -> Self {
        self.layers.push(Box::new(layer));
        self
    }

    /// Appends the closure `f` as a layer (wrapped via `super::func`).
    pub fn add_fn<F>(self, f: F) -> Self
    where
        F: 'static + Fn(&Tensor) -> Tensor + Send,
    {
        self.add(super::func(f))
    }

    /// Runs the first `n` layers (every layer when `n` is `None`) and returns
    /// the output of each executed layer, in order.
    ///
    /// A value of `n` larger than the number of layers is clamped. When no layer
    /// runs — the sequence is empty or `n == Some(0)` — the result is the single
    /// element `xs.shallow_clone()`, so the returned vector is never empty.
    pub fn forward_all(&self, xs: &Tensor, n: Option<usize>) -> Vec<Tensor> {
        let n = n.map_or(self.layers.len(), |n| n.min(self.layers.len()));
        if n == 0 {
            // Zero layers requested: mirror the empty-sequence result rather than
            // executing the first layer anyway.
            return vec![xs.shallow_clone()];
        }
        let mut outputs: Vec<Tensor> = Vec::with_capacity(n);
        for layer in &self.layers[..n] {
            // The first layer consumes `xs`; each later layer consumes the
            // previous layer's output.
            let out = layer.forward(outputs.last().unwrap_or(xs));
            outputs.push(out);
        }
        outputs
    }
}
/// An ordered list of train/eval-aware layers; the `ModuleT` implementation
/// feeds each layer the output of the previous one, passing the same `train`
/// flag to every layer.
#[derive(Debug)]
pub struct SequentialT {
    // Layers in the order they are applied; appended by `add` / `add_fn` / `add_fn_t`.
    layers: Vec<Box<dyn ModuleT>>,
}
pub fn seq_t() -> SequentialT {
SequentialT { layers: vec![] }
}
impl SequentialT {
    /// Number of layers currently in the sequence (as an `i64`).
    pub fn len(&self) -> i64 {
        let count: usize = self.layers.len();
        count as i64
    }

    /// Returns `true` when no layer has been added.
    pub fn is_empty(&self) -> bool {
        let layers = &self.layers;
        layers.is_empty()
    }
}
impl ModuleT for SequentialT {
    /// Applies every layer in order with the same `train` flag, threading each
    /// output into the next layer.
    ///
    /// With no layers, returns `xs.shallow_clone()`.
    fn forward_t(&self, xs: &Tensor, train: bool) -> Tensor {
        let mut remaining = self.layers.iter();
        let mut acc = match remaining.next() {
            Some(first) => first.forward_t(xs, train),
            None => return xs.shallow_clone(),
        };
        for layer in remaining {
            acc = layer.forward_t(&acc, train);
        }
        acc
    }
}
impl SequentialT {
    /// Appends `layer` to the end of the sequence and returns the sequence,
    /// so calls can be chained builder-style.
    // `add` is a builder method, not arithmetic; the name merely coincides with
    // `std::ops::Add::add`, which is what this lint would otherwise flag.
    #[allow(clippy::should_implement_trait)]
    pub fn add<M: ModuleT + 'static>(mut self, layer: M) -> Self {
        self.layers.push(Box::new(layer));
        self
    }

    /// Appends the closure `f` as a layer (wrapped via `super::func`); `f`
    /// does not receive the `train` flag.
    pub fn add_fn<F>(self, f: F) -> Self
    where
        F: 'static + Fn(&Tensor) -> Tensor + Send,
    {
        self.add(super::func(f))
    }

    /// Appends the closure `f` as a layer (wrapped via `super::func_t`); `f`
    /// receives the input tensor and the `train` flag.
    pub fn add_fn_t<F>(self, f: F) -> Self
    where
        F: 'static + Fn(&Tensor, bool) -> Tensor + Send,
    {
        self.add(super::func_t(f))
    }

    /// Runs the first `n` layers (every layer when `n` is `None`) with the given
    /// `train` flag and returns the output of each executed layer, in order.
    ///
    /// A value of `n` larger than the number of layers is clamped. When no layer
    /// runs — the sequence is empty or `n == Some(0)` — the result is the single
    /// element `xs.shallow_clone()`, so the returned vector is never empty.
    pub fn forward_all_t(&self, xs: &Tensor, train: bool, n: Option<usize>) -> Vec<Tensor> {
        let n = n.map_or(self.layers.len(), |n| n.min(self.layers.len()));
        if n == 0 {
            // Zero layers requested: mirror the empty-sequence result rather than
            // executing the first layer anyway.
            return vec![xs.shallow_clone()];
        }
        let mut outputs: Vec<Tensor> = Vec::with_capacity(n);
        for layer in &self.layers[..n] {
            // The first layer consumes `xs`; each later layer consumes the
            // previous layer's output.
            let out = layer.forward_t(outputs.last().unwrap_or(xs), train);
            outputs.push(out);
        }
        outputs
    }
}