ghostflow_nn/
module.rs

//! Base `Module` trait for neural network layers, plus the `Sequential` container.
3use ghostflow_core::Tensor;
4
5/// Base trait for all neural network modules
6pub trait Module: Send + Sync {
7    /// Forward pass
8    fn forward(&self, input: &Tensor) -> Tensor;
9    
10    /// Get all trainable parameters
11    fn parameters(&self) -> Vec<Tensor>;
12    
13    /// Number of trainable parameters
14    fn num_parameters(&self) -> usize {
15        self.parameters().iter().map(|p| p.numel()).sum()
16    }
17    
18    /// Set module to training mode
19    fn train(&mut self);
20    
21    /// Set module to evaluation mode
22    fn eval(&mut self);
23    
24    /// Check if in training mode
25    fn is_training(&self) -> bool;
26}
27
28/// Container for sequential layers
29pub struct Sequential {
30    layers: Vec<Box<dyn Module>>,
31    training: bool,
32}
33
34impl Sequential {
35    pub fn new() -> Self {
36        Sequential {
37            layers: Vec::new(),
38            training: true,
39        }
40    }
41
42    /// Add a layer to the sequential model
43    pub fn add_layer<M: Module + 'static>(mut self, layer: M) -> Self {
44        self.layers.push(Box::new(layer));
45        self
46    }
47}
48
49impl Default for Sequential {
50    fn default() -> Self {
51        Self::new()
52    }
53}
54
55impl Module for Sequential {
56    fn forward(&self, input: &Tensor) -> Tensor {
57        let mut x = input.clone();
58        for layer in &self.layers {
59            x = layer.forward(&x);
60        }
61        x
62    }
63
64    fn parameters(&self) -> Vec<Tensor> {
65        self.layers.iter()
66            .flat_map(|l| l.parameters())
67            .collect()
68    }
69
70    fn train(&mut self) {
71        self.training = true;
72        for layer in &mut self.layers {
73            layer.train();
74        }
75    }
76
77    fn eval(&mut self) {
78        self.training = false;
79        for layer in &mut self.layers {
80            layer.eval();
81        }
82    }
83
84    fn is_training(&self) -> bool {
85        self.training
86    }
87}