ghostflow_nn/
module.rs

//! Base `Module` trait for neural network layers.
2
3use ghostflow_core::Tensor;
4
5/// Base trait for all neural network modules
6pub trait Module: Send + Sync {
7    /// Forward pass
8    fn forward(&self, input: &Tensor) -> Tensor;
9    
10    /// Get all trainable parameters
11    fn parameters(&self) -> Vec<Tensor>;
12    
13    /// Number of trainable parameters
14    fn num_parameters(&self) -> usize {
15        self.parameters().iter().map(|p| p.numel()).sum()
16    }
17    
18    /// Set module to training mode
19    fn train(&mut self);
20    
21    /// Set module to evaluation mode
22    fn eval(&mut self);
23    
24    /// Check if in training mode
25    fn is_training(&self) -> bool;
26}
27
28/// Container for sequential layers
29pub struct Sequential {
30    layers: Vec<Box<dyn Module>>,
31    training: bool,
32}
33
34impl Sequential {
35    pub fn new() -> Self {
36        Sequential {
37            layers: Vec::new(),
38            training: true,
39        }
40    }
41
42    pub fn add<M: Module + 'static>(mut self, layer: M) -> Self {
43        self.layers.push(Box::new(layer));
44        self
45    }
46}
47
48impl Default for Sequential {
49    fn default() -> Self {
50        Self::new()
51    }
52}
53
54impl Module for Sequential {
55    fn forward(&self, input: &Tensor) -> Tensor {
56        let mut x = input.clone();
57        for layer in &self.layers {
58            x = layer.forward(&x);
59        }
60        x
61    }
62
63    fn parameters(&self) -> Vec<Tensor> {
64        self.layers.iter()
65            .flat_map(|l| l.parameters())
66            .collect()
67    }
68
69    fn train(&mut self) {
70        self.training = true;
71        for layer in &mut self.layers {
72            layer.train();
73        }
74    }
75
76    fn eval(&mut self) {
77        self.training = false;
78        for layer in &mut self.layers {
79            layer.eval();
80        }
81    }
82
83    fn is_training(&self) -> bool {
84        self.training
85    }
86}