use scivex_core::Float;
use crate::error::Result;
use crate::variable::Variable;
use super::Layer;
/// A container holding an ordered list of boxed [`Layer`] objects.
///
/// Layers are kept in the order in which they were supplied to the
/// constructor or appended afterwards, and that order is the one used
/// whenever the container walks over its layers.
pub struct Sequential<T>
where
    T: Float,
{
    /// The contained layers, stored in application order.
    layers: Vec<Box<dyn Layer<T>>>,
}
impl<T: Float> Sequential<T> {
    /// Creates a container from an already ordered list of layers.
    ///
    /// The vector is taken as-is: element `0` becomes the first layer and the
    /// last element becomes the final layer. An empty vector is accepted.
    #[must_use]
    pub fn new(layers: Vec<Box<dyn Layer<T>>>) -> Self {
        Self { layers }
    }

    /// Appends `layer` after all layers currently in the container.
    pub fn push(&mut self, layer: Box<dyn Layer<T>>) {
        self.layers.push(layer);
    }

    /// Returns the number of layers currently held.
    #[must_use]
    pub fn len(&self) -> usize {
        self.layers.len()
    }

    /// Returns `true` when the container holds no layers.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.layers.is_empty()
    }
}
impl<T: Float> Layer<T> for Sequential<T> {
    /// Runs `x` through every contained layer, in order.
    ///
    /// Each layer receives the output of the previous one; the first layer
    /// receives a clone of `x`. With no layers, that clone is returned
    /// unchanged.
    ///
    /// # Errors
    ///
    /// Returns the first error produced by any layer's `forward`; layers
    /// after the failing one are not invoked.
    fn forward(&self, x: &Variable<T>) -> Result<Variable<T>> {
        // `try_fold` threads the running activation through the layers and
        // stops at the first `Err`.
        self.layers
            .iter()
            .try_fold(x.clone(), |activation, layer| layer.forward(&activation))
    }

    /// Collects the parameters of every contained layer into one vector,
    /// preserving layer order.
    fn parameters(&self) -> Vec<Variable<T>> {
        let mut collected = Vec::new();
        for layer in &self.layers {
            collected.extend(layer.parameters());
        }
        collected
    }

    /// Forwards the `train` call to every contained layer, in order.
    fn train(&mut self) {
        self.layers.iter_mut().for_each(|layer| layer.train());
    }

    /// Forwards the `eval` call to every contained layer, in order.
    fn eval(&mut self) {
        self.layers.iter_mut().for_each(|layer| layer.eval());
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::layer::{Linear, ReLU};
    use scivex_core::Tensor;
    use scivex_core::random::Rng;

    /// Builds the `Linear`, `ReLU`, `Linear` model shared by the tests below.
    ///
    /// The RNG is seeded with a fixed value so every test constructs the
    /// model from the same random stream.
    fn build_model() -> Sequential<f64> {
        let mut rng = Rng::new(42);
        Sequential::new(vec![
            Box::new(Linear::new(4, 3, true, &mut rng)),
            Box::new(ReLU),
            Box::new(Linear::new(3, 2, true, &mut rng)),
        ])
    }

    /// A `[2, 4]` input pushed through the model yields a `[2, 2]` output.
    #[test]
    fn test_sequential_forward() {
        let model = build_model();
        let x = Variable::new(Tensor::ones(vec![2, 4]), true);
        let y = model.forward(&x).unwrap();
        assert_eq!(y.shape(), vec![2, 2]);
    }

    /// The model reports four parameter variables in total.
    #[test]
    fn test_sequential_parameters() {
        let model = build_model();
        assert_eq!(model.parameters().len(), 4);
    }

    /// `len` and `is_empty` track layers given to `new` and added via `push`.
    #[test]
    fn test_sequential_push_and_len() {
        let mut model: Sequential<f64> = Sequential::new(Vec::new());
        assert!(model.is_empty());
        assert_eq!(model.len(), 0);

        model.push(Box::new(ReLU));
        assert!(!model.is_empty());
        assert_eq!(model.len(), 1);

        assert_eq!(build_model().len(), 3);
    }

    /// With no layers, `forward` hands back the input, so the shape is unchanged.
    #[test]
    fn test_sequential_empty_forward_keeps_shape() {
        let model: Sequential<f64> = Sequential::new(Vec::new());
        let x = Variable::new(Tensor::ones(vec![2, 4]), true);
        let y = model.forward(&x).unwrap();
        assert_eq!(y.shape(), vec![2, 4]);
    }
}