//! `Module` trait — the core interface for all neural network layers.
//!
//! `Module` requires `forward(&self, &Variable) -> Variable` and provides
//! defaults for `parameters(&self) -> Vec<Parameter>`, `named_parameters()
//! -> HashMap`, `num_parameters()`, `train(&mut self)`, `eval(&mut self)`,
//! `set_training(&mut self, bool)`, `is_training(&self) -> bool`,
//! `zero_grad(&self)`, `to_device(&self, Device)`, and `name() -> &'static str`.
//! Also defines `ModuleList` (heterogeneous `Vec<Box<dyn Module>>` with
//! sequential forward, parameter aggregation, and train/eval propagation).
//!
//! # File
//! `crates/axonml-nn/src/module.rs`
//!
//! # Author
//! Andrew Jewell Sr. — AutomataNexus LLC
//! ORCID: 0009-0005-2158-7060
//!
//! # Updated
//! April 14, 2026 11:15 PM EST
//!
//! # Disclaimer
//! Use at own risk. This software is provided "as is", without warranty of any
//! kind, express or implied. The author and AutomataNexus shall not be held
//! liable for any damages arising from the use of this software.
25use std::collections::HashMap;
26
27use axonml_autograd::Variable;
28use axonml_core::Device;
29
30use crate::parameter::Parameter;
31
32// =============================================================================
33// Module Trait
34// =============================================================================
35
/// Core trait for all neural network modules.
///
/// Every layer in Axonml implements this trait, which provides:
/// - Forward pass computation
/// - Parameter management
/// - Training/evaluation mode switching
/// - Module naming
pub trait Module: Send + Sync {
    /// Performs the forward pass.
    ///
    /// # Arguments
    /// * `input` - Input variable
    ///
    /// # Returns
    /// Output variable after applying this module's transformation.
    fn forward(&self, input: &Variable) -> Variable;

    /// Returns all parameters of this module.
    ///
    /// This includes parameters from all child modules. Default: empty,
    /// suitable for parameterless modules (e.g. activations).
    fn parameters(&self) -> Vec<Parameter> {
        Vec::new()
    }

    /// Returns named parameters of this module.
    ///
    /// Default: empty map. Containers prefix child names with an index —
    /// see `ModuleList::named_parameters`.
    fn named_parameters(&self) -> HashMap<String, Parameter> {
        HashMap::new()
    }

    /// Returns the number of trainable parameters: the sum of `numel()`
    /// over parameters for which `requires_grad()` is true.
    fn num_parameters(&self) -> usize {
        self.parameters()
            .iter()
            .filter(|p| p.requires_grad())
            .map(|p| p.numel())
            .sum()
    }

    /// Sets the module to training mode (shorthand for `set_training(true)`).
    fn train(&mut self) {
        self.set_training(true);
    }

    /// Sets the module to evaluation mode (shorthand for `set_training(false)`).
    fn eval(&mut self) {
        self.set_training(false);
    }

    /// Sets the training mode.
    ///
    /// Modules with training-dependent behavior (Dropout, BatchNorm) MUST
    /// override this AND `is_training()` to track the mode in an internal field.
    fn set_training(&mut self, _training: bool) {
        // Default: no-op. Stateless modules (Linear, Conv, activations)
        // don't need training mode tracking.
    }

    /// Returns whether the module is in training mode.
    ///
    /// Default returns `true`. Modules that override `set_training()` should
    /// also override this to return their tracked state.
    fn is_training(&self) -> bool {
        true
    }

    /// Zeros all gradients of parameters.
    ///
    /// NOTE(review): takes `&self`, so gradient clearing must go through
    /// interior mutability inside `Parameter` — confirm in `parameter.rs`.
    fn zero_grad(&self) {
        for param in self.parameters() {
            param.zero_grad();
        }
    }

    /// Moves all parameters to the specified device.
    ///
    /// **Note:** This only moves `Parameter` tensors. Modules with non-parameter
    /// state (e.g., BatchNorm running_mean/running_var) should override this
    /// method to also move their buffers.
    fn to_device(&self, device: Device) {
        for param in self.parameters() {
            param.to_device(device);
        }
    }

    /// Returns the module name for debugging.
    ///
    /// Defaults to the compiler-generated type name of the implementor.
    fn name(&self) -> &'static str {
        std::any::type_name::<Self>()
    }
}
125
126// =============================================================================
127// ModuleList
128// =============================================================================
129
/// A container for holding a list of modules.
pub struct ModuleList {
    /// Owned, heterogeneous child modules, applied in insertion order.
    modules: Vec<Box<dyn Module>>,
    /// Current training-mode flag, propagated to children via `set_training`.
    training: bool,
}
135
136impl ModuleList {
137    /// Creates a new empty ModuleList.
138    pub fn new() -> Self {
139        Self {
140            modules: Vec::new(),
141            training: true,
142        }
143    }
144
145    /// Creates a ModuleList from a vector of modules.
146    pub fn from_vec(modules: Vec<Box<dyn Module>>) -> Self {
147        Self {
148            modules,
149            training: true,
150        }
151    }
152
153    /// Adds a module to the list.
154    pub fn push<M: Module + 'static>(&mut self, module: M) {
155        self.modules.push(Box::new(module));
156    }
157
158    /// Returns the number of modules.
159    pub fn len(&self) -> usize {
160        self.modules.len()
161    }
162
163    /// Returns true if the list is empty.
164    pub fn is_empty(&self) -> bool {
165        self.modules.is_empty()
166    }
167
168    /// Returns an iterator over the modules.
169    pub fn iter(&self) -> impl Iterator<Item = &Box<dyn Module>> {
170        self.modules.iter()
171    }
172
173    /// Returns a mutable iterator over the modules.
174    pub fn iter_mut(&mut self) -> impl Iterator<Item = &mut Box<dyn Module>> {
175        self.modules.iter_mut()
176    }
177
178    /// Gets a module by index.
179    pub fn get(&self, index: usize) -> Option<&dyn Module> {
180        self.modules.get(index).map(|m| m.as_ref())
181    }
182}
183
184impl Default for ModuleList {
185    fn default() -> Self {
186        Self::new()
187    }
188}
189
190impl Module for ModuleList {
191    fn forward(&self, input: &Variable) -> Variable {
192        let mut x = input.clone();
193        for module in &self.modules {
194            x = module.forward(&x);
195        }
196        x
197    }
198
199    fn parameters(&self) -> Vec<Parameter> {
200        self.modules.iter().flat_map(|m| m.parameters()).collect()
201    }
202
203    fn named_parameters(&self) -> HashMap<String, Parameter> {
204        let mut params = HashMap::new();
205        for (i, module) in self.modules.iter().enumerate() {
206            for (name, param) in module.named_parameters() {
207                params.insert(format!("{i}.{name}"), param);
208            }
209        }
210        params
211    }
212
213    fn set_training(&mut self, training: bool) {
214        self.training = training;
215        for module in &mut self.modules {
216            module.set_training(training);
217        }
218    }
219
220    fn is_training(&self) -> bool {
221        self.training
222    }
223
224    fn name(&self) -> &'static str {
225        "ModuleList"
226    }
227}
228
229// =============================================================================
230// Tests
231// =============================================================================
232
#[cfg(test)]
mod tests {
    use super::*;
    use axonml_tensor::Tensor;

    /// Minimal pass-through module used to exercise `ModuleList`.
    struct Identity;

    impl Module for Identity {
        fn forward(&self, input: &Variable) -> Variable {
            input.clone()
        }

        fn name(&self) -> &'static str {
            "Identity"
        }
    }

    #[test]
    fn test_module_list() {
        // Pushing two modules should be reflected by `len`.
        let mut layers = ModuleList::new();
        layers.push(Identity);
        layers.push(Identity);
        assert_eq!(layers.len(), 2);
    }

    #[test]
    fn test_module_list_forward() {
        // A single identity layer must pass data through unchanged.
        let mut layers = ModuleList::new();
        layers.push(Identity);

        let tensor = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
        let input = Variable::new(tensor, false);
        let output = layers.forward(&input);
        assert_eq!(output.data().to_vec(), vec![1.0, 2.0, 3.0]);
    }
}