candle_nn/
sequential.rs

1//! Sequential Layer
2//!
3//! A sequential layer used to chain multiple layers and closures.
4use candle::{Module, Result, Tensor};
5
6/// A sequential layer combining multiple other layers.
7pub struct Sequential {
8    layers: Vec<Box<dyn Module>>,
9}
10
11/// Creates a new empty sequential layer.
12pub fn seq() -> Sequential {
13    Sequential { layers: vec![] }
14}
15
16impl Sequential {
17    /// The number of sub-layers embedded in this layer.
18    pub fn len(&self) -> i64 {
19        self.layers.len() as i64
20    }
21
22    /// Returns true if this layer does not have any sub-layer.
23    pub fn is_empty(&self) -> bool {
24        self.layers.is_empty()
25    }
26}
27
28impl Module for Sequential {
29    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
30        let mut xs = xs.clone();
31        for layer in self.layers.iter() {
32            xs = layer.forward(&xs)?
33        }
34        Ok(xs)
35    }
36}
37
38impl Sequential {
39    /// Appends a layer after all the current layers.
40    #[allow(clippy::should_implement_trait)]
41    pub fn add<M: Module + 'static>(mut self, layer: M) -> Self {
42        self.layers.push(Box::new(layer));
43        self
44    }
45
46    /// Appends a closure after all the current layers.
47    pub fn add_fn<F>(self, f: F) -> Self
48    where
49        F: 'static + Fn(&Tensor) -> Result<Tensor> + Send + Sync,
50    {
51        self.add(super::func(f))
52    }
53
54    /// Applies the forward pass and returns the output for each layer.
55    pub fn forward_all(&self, xs: &Tensor) -> Result<Vec<Tensor>> {
56        let mut vec = Vec::with_capacity(self.layers.len());
57        let mut xs = xs.clone();
58        for layer in self.layers.iter() {
59            xs = layer.forward(&xs)?;
60            vec.push(xs.clone())
61        }
62        Ok(vec)
63    }
64}