1use ghostflow_core::Tensor;
4
5pub trait Module: Send + Sync {
7 fn forward(&self, input: &Tensor) -> Tensor;
9
10 fn parameters(&self) -> Vec<Tensor>;
12
13 fn num_parameters(&self) -> usize {
15 self.parameters().iter().map(|p| p.numel()).sum()
16 }
17
18 fn train(&mut self);
20
21 fn eval(&mut self);
23
24 fn is_training(&self) -> bool;
26}
27
28pub struct Sequential {
30 layers: Vec<Box<dyn Module>>,
31 training: bool,
32}
33
34impl Sequential {
35 pub fn new() -> Self {
36 Sequential {
37 layers: Vec::new(),
38 training: true,
39 }
40 }
41
42 pub fn add_layer<M: Module + 'static>(mut self, layer: M) -> Self {
44 self.layers.push(Box::new(layer));
45 self
46 }
47}
48
49impl Default for Sequential {
50 fn default() -> Self {
51 Self::new()
52 }
53}
54
55impl Module for Sequential {
56 fn forward(&self, input: &Tensor) -> Tensor {
57 let mut x = input.clone();
58 for layer in &self.layers {
59 x = layer.forward(&x);
60 }
61 x
62 }
63
64 fn parameters(&self) -> Vec<Tensor> {
65 self.layers.iter()
66 .flat_map(|l| l.parameters())
67 .collect()
68 }
69
70 fn train(&mut self) {
71 self.training = true;
72 for layer in &mut self.layers {
73 layer.train();
74 }
75 }
76
77 fn eval(&mut self) {
78 self.training = false;
79 for layer in &mut self.layers {
80 layer.eval();
81 }
82 }
83
84 fn is_training(&self) -> bool {
85 self.training
86 }
87}