Skip to main content

axonml_nn/
module.rs

1//! Module Trait - Neural Network Module Interface
2//!
3//! # File
4//! `crates/axonml-nn/src/module.rs`
5//!
6//! # Author
7//! Andrew Jewell Sr - AutomataNexus
8//!
9//! # Updated
10//! March 8, 2026
11//!
12//! # Disclaimer
13//! Use at own risk. This software is provided "as is", without warranty of any
14//! kind, express or implied. The author and AutomataNexus shall not be held
15//! liable for any damages arising from the use of this software.
16
17use std::collections::HashMap;
18
19use axonml_autograd::Variable;
20use axonml_core::Device;
21
22use crate::parameter::Parameter;
23
24// =============================================================================
25// Module Trait
26// =============================================================================
27
28/// Core trait for all neural network modules.
29///
30/// Every layer in Axonml implements this trait, which provides:
31/// - Forward pass computation
32/// - Parameter management
33/// - Training/evaluation mode switching
34/// - Module naming
35pub trait Module: Send + Sync {
36    /// Performs the forward pass.
37    ///
38    /// # Arguments
39    /// * `input` - Input variable
40    ///
41    /// # Returns
42    /// Output variable after applying this module's transformation.
43    fn forward(&self, input: &Variable) -> Variable;
44
45    /// Returns all parameters of this module.
46    ///
47    /// This includes parameters from all child modules.
48    fn parameters(&self) -> Vec<Parameter> {
49        Vec::new()
50    }
51
52    /// Returns named parameters of this module.
53    fn named_parameters(&self) -> HashMap<String, Parameter> {
54        HashMap::new()
55    }
56
57    /// Returns the number of trainable parameters.
58    fn num_parameters(&self) -> usize {
59        self.parameters()
60            .iter()
61            .filter(|p| p.requires_grad())
62            .map(|p| p.numel())
63            .sum()
64    }
65
66    /// Sets the module to training mode.
67    fn train(&mut self) {
68        self.set_training(true);
69    }
70
71    /// Sets the module to evaluation mode.
72    fn eval(&mut self) {
73        self.set_training(false);
74    }
75
76    /// Sets the training mode.
77    /// Sets the training mode.
78    ///
79    /// Modules with training-dependent behavior (Dropout, BatchNorm) MUST
80    /// override this AND `is_training()` to track the mode in an internal field.
81    fn set_training(&mut self, _training: bool) {
82        // Default: no-op. Stateless modules (Linear, Conv, activations)
83        // don't need training mode tracking.
84    }
85
86    /// Returns whether the module is in training mode.
87    ///
88    /// Default returns `true`. Modules that override `set_training()` should
89    /// also override this to return their tracked state.
90    fn is_training(&self) -> bool {
91        true
92    }
93
94    /// Zeros all gradients of parameters.
95    fn zero_grad(&self) {
96        for param in self.parameters() {
97            param.zero_grad();
98        }
99    }
100
101    /// Moves all parameters to the specified device.
102    ///
103    /// **Note:** This only moves `Parameter` tensors. Modules with non-parameter
104    /// state (e.g., BatchNorm running_mean/running_var) should override this
105    /// method to also move their buffers.
106    fn to_device(&self, device: Device) {
107        for param in self.parameters() {
108            param.to_device(device);
109        }
110    }
111
112    /// Returns the module name for debugging.
113    fn name(&self) -> &'static str {
114        std::any::type_name::<Self>()
115    }
116}
117
118// =============================================================================
119// ModuleList
120// =============================================================================
121
122/// A container for holding a list of modules.
123pub struct ModuleList {
124    modules: Vec<Box<dyn Module>>,
125    training: bool,
126}
127
128impl ModuleList {
129    /// Creates a new empty ModuleList.
130    pub fn new() -> Self {
131        Self {
132            modules: Vec::new(),
133            training: true,
134        }
135    }
136
137    /// Creates a ModuleList from a vector of modules.
138    pub fn from_vec(modules: Vec<Box<dyn Module>>) -> Self {
139        Self {
140            modules,
141            training: true,
142        }
143    }
144
145    /// Adds a module to the list.
146    pub fn push<M: Module + 'static>(&mut self, module: M) {
147        self.modules.push(Box::new(module));
148    }
149
150    /// Returns the number of modules.
151    pub fn len(&self) -> usize {
152        self.modules.len()
153    }
154
155    /// Returns true if the list is empty.
156    pub fn is_empty(&self) -> bool {
157        self.modules.is_empty()
158    }
159
160    /// Returns an iterator over the modules.
161    pub fn iter(&self) -> impl Iterator<Item = &Box<dyn Module>> {
162        self.modules.iter()
163    }
164
165    /// Returns a mutable iterator over the modules.
166    pub fn iter_mut(&mut self) -> impl Iterator<Item = &mut Box<dyn Module>> {
167        self.modules.iter_mut()
168    }
169
170    /// Gets a module by index.
171    pub fn get(&self, index: usize) -> Option<&dyn Module> {
172        self.modules.get(index).map(|m| m.as_ref())
173    }
174}
175
176impl Default for ModuleList {
177    fn default() -> Self {
178        Self::new()
179    }
180}
181
182impl Module for ModuleList {
183    fn forward(&self, input: &Variable) -> Variable {
184        let mut x = input.clone();
185        for module in &self.modules {
186            x = module.forward(&x);
187        }
188        x
189    }
190
191    fn parameters(&self) -> Vec<Parameter> {
192        self.modules.iter().flat_map(|m| m.parameters()).collect()
193    }
194
195    fn named_parameters(&self) -> HashMap<String, Parameter> {
196        let mut params = HashMap::new();
197        for (i, module) in self.modules.iter().enumerate() {
198            for (name, param) in module.named_parameters() {
199                params.insert(format!("{i}.{name}"), param);
200            }
201        }
202        params
203    }
204
205    fn set_training(&mut self, training: bool) {
206        self.training = training;
207        for module in &mut self.modules {
208            module.set_training(training);
209        }
210    }
211
212    fn is_training(&self) -> bool {
213        self.training
214    }
215
216    fn name(&self) -> &'static str {
217        "ModuleList"
218    }
219}
220
221// =============================================================================
222// Tests
223// =============================================================================
224
#[cfg(test)]
mod tests {
    use super::*;
    use axonml_tensor::Tensor;

    /// Pass-through module: returns a clone of its input unchanged.
    struct Identity;

    impl Module for Identity {
        fn forward(&self, input: &Variable) -> Variable {
            input.clone()
        }

        fn name(&self) -> &'static str {
            "Identity"
        }
    }

    #[test]
    fn test_module_list() {
        let mut list = ModuleList::new();
        for _ in 0..2 {
            list.push(Identity);
        }
        assert_eq!(list.len(), 2);
    }

    #[test]
    fn test_module_list_forward() {
        let mut list = ModuleList::new();
        list.push(Identity);

        let values = vec![1.0, 2.0, 3.0];
        let tensor = Tensor::from_vec(values.clone(), &[3]).unwrap();
        let input = Variable::new(tensor, false);

        let output = list.forward(&input);
        assert_eq!(output.data().to_vec(), values);
    }
}