// axonml_nn/module.rs

//! Module Trait - Neural Network Module Interface
//!
//! Defines the core `Module` trait that all neural network layers implement,
//! together with `ModuleList`, an ordered container of modules.
//! This is the foundation of the neural network abstraction in Axonml.
//!
//! @version 0.1.0
//! @author AutomataNexus Development Team

9use std::collections::HashMap;
10
11use axonml_autograd::Variable;
12
13use crate::parameter::Parameter;
14
// =============================================================================
// Module Trait
// =============================================================================

19/// Core trait for all neural network modules.
20///
21/// Every layer in Axonml implements this trait, which provides:
22/// - Forward pass computation
23/// - Parameter management
24/// - Training/evaluation mode switching
25/// - Module naming
26pub trait Module: Send + Sync {
27    /// Performs the forward pass.
28    ///
29    /// # Arguments
30    /// * `input` - Input variable
31    ///
32    /// # Returns
33    /// Output variable after applying this module's transformation.
34    fn forward(&self, input: &Variable) -> Variable;
35
36    /// Returns all parameters of this module.
37    ///
38    /// This includes parameters from all child modules.
39    fn parameters(&self) -> Vec<Parameter> {
40        Vec::new()
41    }
42
43    /// Returns named parameters of this module.
44    fn named_parameters(&self) -> HashMap<String, Parameter> {
45        HashMap::new()
46    }
47
48    /// Returns the number of trainable parameters.
49    fn num_parameters(&self) -> usize {
50        self.parameters()
51            .iter()
52            .filter(|p| p.requires_grad())
53            .map(|p| p.numel())
54            .sum()
55    }
56
57    /// Sets the module to training mode.
58    fn train(&mut self) {
59        self.set_training(true);
60    }
61
62    /// Sets the module to evaluation mode.
63    fn eval(&mut self) {
64        self.set_training(false);
65    }
66
67    /// Sets the training mode.
68    fn set_training(&mut self, _training: bool) {
69        // Default implementation does nothing
70        // Submodules override this if they have training-specific behavior
71    }
72
73    /// Returns whether the module is in training mode.
74    fn is_training(&self) -> bool {
75        true // Default to training mode
76    }
77
78    /// Zeros all gradients of parameters.
79    fn zero_grad(&self) {
80        for param in self.parameters() {
81            param.zero_grad();
82        }
83    }
84
85    /// Returns the module name for debugging.
86    fn name(&self) -> &'static str {
87        std::any::type_name::<Self>()
88    }
89}

// =============================================================================
// ModuleList
// =============================================================================

95/// A container for holding a list of modules.
96pub struct ModuleList {
97    modules: Vec<Box<dyn Module>>,
98    training: bool,
99}
100
101impl ModuleList {
102    /// Creates a new empty ModuleList.
103    pub fn new() -> Self {
104        Self {
105            modules: Vec::new(),
106            training: true,
107        }
108    }
109
110    /// Creates a ModuleList from a vector of modules.
111    pub fn from_vec(modules: Vec<Box<dyn Module>>) -> Self {
112        Self {
113            modules,
114            training: true,
115        }
116    }
117
118    /// Adds a module to the list.
119    pub fn push<M: Module + 'static>(&mut self, module: M) {
120        self.modules.push(Box::new(module));
121    }
122
123    /// Returns the number of modules.
124    pub fn len(&self) -> usize {
125        self.modules.len()
126    }
127
128    /// Returns true if the list is empty.
129    pub fn is_empty(&self) -> bool {
130        self.modules.is_empty()
131    }
132
133    /// Returns an iterator over the modules.
134    pub fn iter(&self) -> impl Iterator<Item = &Box<dyn Module>> {
135        self.modules.iter()
136    }
137
138    /// Returns a mutable iterator over the modules.
139    pub fn iter_mut(&mut self) -> impl Iterator<Item = &mut Box<dyn Module>> {
140        self.modules.iter_mut()
141    }
142
143    /// Gets a module by index.
144    pub fn get(&self, index: usize) -> Option<&dyn Module> {
145        self.modules.get(index).map(|m| m.as_ref())
146    }
147}
148
149impl Default for ModuleList {
150    fn default() -> Self {
151        Self::new()
152    }
153}
154
155impl Module for ModuleList {
156    fn forward(&self, input: &Variable) -> Variable {
157        let mut x = input.clone();
158        for module in &self.modules {
159            x = module.forward(&x);
160        }
161        x
162    }
163
164    fn parameters(&self) -> Vec<Parameter> {
165        self.modules.iter().flat_map(|m| m.parameters()).collect()
166    }
167
168    fn named_parameters(&self) -> HashMap<String, Parameter> {
169        let mut params = HashMap::new();
170        for (i, module) in self.modules.iter().enumerate() {
171            for (name, param) in module.named_parameters() {
172                params.insert(format!("{i}.{name}"), param);
173            }
174        }
175        params
176    }
177
178    fn set_training(&mut self, training: bool) {
179        self.training = training;
180        for module in &mut self.modules {
181            module.set_training(training);
182        }
183    }
184
185    fn is_training(&self) -> bool {
186        self.training
187    }
188
189    fn name(&self) -> &'static str {
190        "ModuleList"
191    }
192}

// =============================================================================
// Tests
// =============================================================================

#[cfg(test)]
mod tests {
    use super::*;
    use axonml_tensor::Tensor;

    /// Stateless pass-through module used to exercise the container.
    struct Identity;

    impl Module for Identity {
        fn forward(&self, input: &Variable) -> Variable {
            input.clone()
        }

        fn name(&self) -> &'static str {
            "Identity"
        }
    }

    /// Pass-through module that records its training flag, so tests can check
    /// that `ModuleList` propagates train/eval to its children.
    struct ModeTracker {
        training: bool,
    }

    impl Module for ModeTracker {
        fn forward(&self, input: &Variable) -> Variable {
            input.clone()
        }

        fn set_training(&mut self, training: bool) {
            self.training = training;
        }

        fn is_training(&self) -> bool {
            self.training
        }
    }

    /// Small fixed input shared by the forward tests.
    fn sample_input() -> Variable {
        Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), false)
    }

    #[test]
    fn test_module_list() {
        let mut list = ModuleList::new();
        assert!(list.is_empty());
        list.push(Identity);
        list.push(Identity);
        assert_eq!(list.len(), 2);
        assert!(!list.is_empty());
        assert_eq!(list.get(0).map(|m| m.name()), Some("Identity"));
        assert!(list.get(2).is_none());
    }

    #[test]
    fn test_module_list_forward() {
        let mut list = ModuleList::new();
        list.push(Identity);

        let output = list.forward(&sample_input());
        assert_eq!(output.data().to_vec(), vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_empty_module_list_forward_returns_input() {
        let list = ModuleList::default();
        let output = list.forward(&sample_input());
        assert_eq!(output.data().to_vec(), vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_module_list_mode_propagation() {
        let mut list = ModuleList::new();
        list.push(ModeTracker { training: true });
        assert!(list.is_training());

        list.eval();
        assert!(!list.is_training());
        assert!(!list.get(0).unwrap().is_training());

        list.train();
        assert!(list.is_training());
        assert!(list.get(0).unwrap().is_training());
    }

    #[test]
    fn test_module_list_without_parameters() {
        let mut list = ModuleList::new();
        list.push(Identity);
        assert!(list.parameters().is_empty());
        assert!(list.named_parameters().is_empty());
        assert_eq!(list.num_parameters(), 0);
        assert_eq!(list.name(), "ModuleList");
    }
}