1use candle::{Module, Result, Tensor};
5
6pub struct Sequential {
8 layers: Vec<Box<dyn Module>>,
9}
10
11pub fn seq() -> Sequential {
13 Sequential { layers: vec![] }
14}
15
16impl Sequential {
17 pub fn len(&self) -> i64 {
19 self.layers.len() as i64
20 }
21
22 pub fn is_empty(&self) -> bool {
24 self.layers.is_empty()
25 }
26}
27
28impl Module for Sequential {
29 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
30 let mut xs = xs.clone();
31 for layer in self.layers.iter() {
32 xs = layer.forward(&xs)?
33 }
34 Ok(xs)
35 }
36}
37
38impl Sequential {
39 #[allow(clippy::should_implement_trait)]
41 pub fn add<M: Module + 'static>(mut self, layer: M) -> Self {
42 self.layers.push(Box::new(layer));
43 self
44 }
45
46 pub fn add_fn<F>(self, f: F) -> Self
48 where
49 F: 'static + Fn(&Tensor) -> Result<Tensor> + Send + Sync,
50 {
51 self.add(super::func(f))
52 }
53
54 pub fn forward_all(&self, xs: &Tensor) -> Result<Vec<Tensor>> {
56 let mut vec = Vec::with_capacity(self.layers.len());
57 let mut xs = xs.clone();
58 for layer in self.layers.iter() {
59 xs = layer.forward(&xs)?;
60 vec.push(xs.clone())
61 }
62 Ok(vec)
63 }
64}