Skip to main content

axonml_nn/
module.rs

1//! Module Trait - Neural Network Module Interface
2//!
3//! # File
4//! `crates/axonml-nn/src/module.rs`
5//!
6//! # Author
7//! Andrew Jewell Sr - AutomataNexus
8//!
9//! # Updated
10//! March 8, 2026
11//!
12//! # Disclaimer
13//! Use at own risk. This software is provided "as is", without warranty of any
14//! kind, express or implied. The author and AutomataNexus shall not be held
15//! liable for any damages arising from the use of this software.
16
17use std::collections::HashMap;
18
19use axonml_autograd::Variable;
20use axonml_core::Device;
21
22use crate::parameter::Parameter;
23
24// =============================================================================
25// Module Trait
26// =============================================================================
27
28/// Core trait for all neural network modules.
29///
30/// Every layer in Axonml implements this trait, which provides:
31/// - Forward pass computation
32/// - Parameter management
33/// - Training/evaluation mode switching
34/// - Module naming
35pub trait Module: Send + Sync {
36    /// Performs the forward pass.
37    ///
38    /// # Arguments
39    /// * `input` - Input variable
40    ///
41    /// # Returns
42    /// Output variable after applying this module's transformation.
43    fn forward(&self, input: &Variable) -> Variable;
44
45    /// Returns all parameters of this module.
46    ///
47    /// This includes parameters from all child modules.
48    fn parameters(&self) -> Vec<Parameter> {
49        Vec::new()
50    }
51
52    /// Returns named parameters of this module.
53    fn named_parameters(&self) -> HashMap<String, Parameter> {
54        HashMap::new()
55    }
56
57    /// Returns the number of trainable parameters.
58    fn num_parameters(&self) -> usize {
59        self.parameters()
60            .iter()
61            .filter(|p| p.requires_grad())
62            .map(|p| p.numel())
63            .sum()
64    }
65
66    /// Sets the module to training mode.
67    fn train(&mut self) {
68        self.set_training(true);
69    }
70
71    /// Sets the module to evaluation mode.
72    fn eval(&mut self) {
73        self.set_training(false);
74    }
75
76    /// Sets the training mode.
77    fn set_training(&mut self, _training: bool) {
78        // Default implementation does nothing
79        // Submodules override this if they have training-specific behavior
80    }
81
82    /// Returns whether the module is in training mode.
83    fn is_training(&self) -> bool {
84        true // Default to training mode
85    }
86
87    /// Zeros all gradients of parameters.
88    fn zero_grad(&self) {
89        for param in self.parameters() {
90            param.zero_grad();
91        }
92    }
93
94    /// Moves all parameters to the specified device.
95    fn to_device(&self, device: Device) {
96        for param in self.parameters() {
97            param.to_device(device);
98        }
99    }
100
101    /// Returns the module name for debugging.
102    fn name(&self) -> &'static str {
103        std::any::type_name::<Self>()
104    }
105}
106
107// =============================================================================
108// ModuleList
109// =============================================================================
110
111/// A container for holding a list of modules.
112pub struct ModuleList {
113    modules: Vec<Box<dyn Module>>,
114    training: bool,
115}
116
117impl ModuleList {
118    /// Creates a new empty ModuleList.
119    pub fn new() -> Self {
120        Self {
121            modules: Vec::new(),
122            training: true,
123        }
124    }
125
126    /// Creates a ModuleList from a vector of modules.
127    pub fn from_vec(modules: Vec<Box<dyn Module>>) -> Self {
128        Self {
129            modules,
130            training: true,
131        }
132    }
133
134    /// Adds a module to the list.
135    pub fn push<M: Module + 'static>(&mut self, module: M) {
136        self.modules.push(Box::new(module));
137    }
138
139    /// Returns the number of modules.
140    pub fn len(&self) -> usize {
141        self.modules.len()
142    }
143
144    /// Returns true if the list is empty.
145    pub fn is_empty(&self) -> bool {
146        self.modules.is_empty()
147    }
148
149    /// Returns an iterator over the modules.
150    pub fn iter(&self) -> impl Iterator<Item = &Box<dyn Module>> {
151        self.modules.iter()
152    }
153
154    /// Returns a mutable iterator over the modules.
155    pub fn iter_mut(&mut self) -> impl Iterator<Item = &mut Box<dyn Module>> {
156        self.modules.iter_mut()
157    }
158
159    /// Gets a module by index.
160    pub fn get(&self, index: usize) -> Option<&dyn Module> {
161        self.modules.get(index).map(|m| m.as_ref())
162    }
163}
164
165impl Default for ModuleList {
166    fn default() -> Self {
167        Self::new()
168    }
169}
170
171impl Module for ModuleList {
172    fn forward(&self, input: &Variable) -> Variable {
173        let mut x = input.clone();
174        for module in &self.modules {
175            x = module.forward(&x);
176        }
177        x
178    }
179
180    fn parameters(&self) -> Vec<Parameter> {
181        self.modules.iter().flat_map(|m| m.parameters()).collect()
182    }
183
184    fn named_parameters(&self) -> HashMap<String, Parameter> {
185        let mut params = HashMap::new();
186        for (i, module) in self.modules.iter().enumerate() {
187            for (name, param) in module.named_parameters() {
188                params.insert(format!("{i}.{name}"), param);
189            }
190        }
191        params
192    }
193
194    fn set_training(&mut self, training: bool) {
195        self.training = training;
196        for module in &mut self.modules {
197            module.set_training(training);
198        }
199    }
200
201    fn is_training(&self) -> bool {
202        self.training
203    }
204
205    fn name(&self) -> &'static str {
206        "ModuleList"
207    }
208}
209
210// =============================================================================
211// Tests
212// =============================================================================
213
#[cfg(test)]
mod tests {
    use super::*;
    use axonml_tensor::Tensor;

    /// Minimal fixture module whose forward pass returns a clone of its input.
    struct Identity;

    impl Module for Identity {
        fn forward(&self, input: &Variable) -> Variable {
            input.clone()
        }

        fn name(&self) -> &'static str {
            "Identity"
        }
    }

    /// Pushing two modules yields a list of length two.
    #[test]
    fn test_module_list() {
        let mut list = ModuleList::new();
        for _ in 0..2 {
            list.push(Identity);
        }
        assert_eq!(list.len(), 2);
    }

    /// A list holding a single `Identity` returns its input data unchanged.
    #[test]
    fn test_module_list_forward() {
        let mut list = ModuleList::new();
        list.push(Identity);

        let tensor = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
        let input = Variable::new(tensor, false);

        let output = list.forward(&input);
        let expected = vec![1.0, 2.0, 3.0];
        assert_eq!(output.data().to_vec(), expected);
    }
}
249}