1use ghostflow_core::Tensor;
4
5pub trait Module: Send + Sync {
7 fn forward(&self, input: &Tensor) -> Tensor;
9
10 fn parameters(&self) -> Vec<Tensor>;
12
13 fn num_parameters(&self) -> usize {
15 self.parameters().iter().map(|p| p.numel()).sum()
16 }
17
18 fn train(&mut self);
20
21 fn eval(&mut self);
23
24 fn is_training(&self) -> bool;
26}
27
28pub struct Sequential {
30 layers: Vec<Box<dyn Module>>,
31 training: bool,
32}
33
34impl Sequential {
35 pub fn new() -> Self {
36 Sequential {
37 layers: Vec::new(),
38 training: true,
39 }
40 }
41
42 pub fn add<M: Module + 'static>(mut self, layer: M) -> Self {
43 self.layers.push(Box::new(layer));
44 self
45 }
46}
47
48impl Default for Sequential {
49 fn default() -> Self {
50 Self::new()
51 }
52}
53
54impl Module for Sequential {
55 fn forward(&self, input: &Tensor) -> Tensor {
56 let mut x = input.clone();
57 for layer in &self.layers {
58 x = layer.forward(&x);
59 }
60 x
61 }
62
63 fn parameters(&self) -> Vec<Tensor> {
64 self.layers.iter()
65 .flat_map(|l| l.parameters())
66 .collect()
67 }
68
69 fn train(&mut self) {
70 self.training = true;
71 for layer in &mut self.layers {
72 layer.train();
73 }
74 }
75
76 fn eval(&mut self) {
77 self.training = false;
78 for layer in &mut self.layers {
79 layer.eval();
80 }
81 }
82
83 fn is_training(&self) -> bool {
84 self.training
85 }
86}