use tch::Tensor;
use super::{NNModule};
/// A container that holds a sequence of [`NNModule`] layers and applies them
/// in insertion order during the forward pass. An empty container acts as the
/// identity function (see the `NNModule::forward` impl).
#[derive(Debug)]
pub struct Sequential {
    // The layers to run, in order. Public so callers can inspect or rebuild
    // the pipeline directly.
    pub layers: Vec<Box<dyn NNModule>>,
}
impl Default for Sequential {
fn default() -> Self {
Sequential {
layers: vec![]
}
}
}
impl Sequential {
    /// Creates a container with no layers.
    pub fn new() -> Self {
        Self { layers: Vec::new() }
    }

    /// Builds a container from an already-boxed list of layers.
    pub fn from_layers(layers: Vec<Box<dyn NNModule>>) -> Self {
        Self { layers }
    }

    /// Returns the number of layers (as `i64`).
    pub fn len(&self) -> i64 {
        self.layers.len() as i64
    }

    /// Returns `true` when the container holds no layers.
    pub fn is_empty(&self) -> bool {
        self.layers.is_empty()
    }
}
impl Sequential {
    /// Appends a layer to the end of the pipeline.
    ///
    /// Deliberately named `add` without implementing `std::ops::Add`, hence
    /// the clippy allow.
    #[allow(clippy::should_implement_trait)]
    pub fn add<M: NNModule + 'static>(&mut self, layer: M) {
        self.layers.push(Box::new(layer));
    }

    /// Appends a closure as a layer, wrapped through `super::func`.
    ///
    /// The closure receives the input tensor and a `bool` — presumably a
    /// training-mode flag; confirm against the `super::func` definition.
    pub fn add_fn<F: 'static + Fn(&Tensor, bool) -> Tensor + Send>(&mut self, f: F) {
        self.add(super::func(f))
    }

    /// Runs the forward pass and returns the intermediate output of each of
    /// the first `n` layers (all layers when `n` is `None`), in order. The
    /// last element is the output of layer `n - 1`.
    ///
    /// On an empty container this is the identity: a single shallow clone of
    /// the input. NOTE(review): the first layer is applied unconditionally,
    /// so `n == Some(0)` still returns layer 0's output rather than an empty
    /// vec — confirm this is the intended contract for callers.
    pub fn forward_all(&self, xs: &Tensor, n: Option<usize>) -> Vec<Tensor> {
        if self.layers.is_empty() {
            vec![xs.shallow_clone()]
        } else {
            let n = n.unwrap_or_else(|| self.layers.len());
            // Apply layer 0 outside the fold so the accumulator invariant is
            // always "output of the previous layer".
            let xs = self.layers[0].forward(xs);
            let mut vec = vec![];
            // Each fold step records the previous layer's output *before*
            // producing the next one; pushing `out` at the end completes the
            // list with the final layer's output.
            let out = self.layers.iter().take(n).skip(1).fold(xs, |xs, layer| {
                let out = layer.forward(&xs);
                vec.push(xs);
                out
            });
            vec.push(out);
            vec
        }
    }
}
impl NNModule for Sequential {
    /// Switches every contained layer into training mode.
    fn train(&mut self) {
        self.layers.iter_mut().for_each(|layer| layer.train());
    }

    /// Switches every contained layer into evaluation mode.
    fn eval(&mut self) {
        self.layers.iter_mut().for_each(|layer| layer.eval());
    }

    /// Feeds `x` through the layers in order; an empty container returns a
    /// shallow clone of the input (identity).
    fn forward(&self, x: &tch::Tensor) -> tch::Tensor {
        let mut layers = self.layers.iter();
        match layers.next() {
            None => x.shallow_clone(),
            Some(first) => layers.fold(first.forward(x), |acc, layer| layer.forward(&acc)),
        }
    }
}
#[macro_use]
mod sequential_macro {
    /// Builds a `Sequential` from a comma-separated list of layers:
    /// `sequential![layer_a, layer_b]`.
    ///
    /// Also accepts a trailing comma (`sequential![a, b, c,]`), matching the
    /// behavior of `vec!` and friends. `Sequential` must be in scope at the
    /// call site, since the expansion refers to it unqualified.
    #[macro_export]
    macro_rules! sequential {
        ( $( $x:expr ),* $(,)? ) => {
            {
                let mut seq = Sequential::new();
                $(
                    seq.add($x);
                )*
                seq
            }
        };
    }
}
pub use sequential_macro::*;