// axonml_nn/module.rs
1//! `Module` trait — the core interface for all neural network layers.
2//!
//! `Module` requires: `forward(&self, &Variable) -> Variable`,
//! `parameters(&self) -> Vec<Parameter>`, `train(&mut self)`, `eval(&mut self)`,
//! `is_training(&self) -> bool`, `zero_grad(&self)`, `name() -> &'static str`,
//! `to_device(&self, Device)`, `named_parameters() -> HashMap`. Also
//! `ModuleList` (heterogeneous `Vec<Box<dyn Module>>` with forward-sequential,
//! parameter aggregation, and train/eval propagation).
9//!
10//! # File
11//! `crates/axonml-nn/src/module.rs`
12//!
13//! # Author
14//! Andrew Jewell Sr. — AutomataNexus LLC
15//! ORCID: 0009-0005-2158-7060
16//!
17//! # Updated
18//! April 14, 2026 11:15 PM EST
19//!
20//! # Disclaimer
21//! Use at own risk. This software is provided "as is", without warranty of any
22//! kind, express or implied. The author and AutomataNexus shall not be held
23//! liable for any damages arising from the use of this software.
24
25use std::collections::HashMap;
26
27use axonml_autograd::Variable;
28use axonml_core::Device;
29
30use crate::parameter::Parameter;
31
32// =============================================================================
33// Module Trait
34// =============================================================================
35
36/// Core trait for all neural network modules.
37///
38/// Every layer in Axonml implements this trait, which provides:
39/// - Forward pass computation
40/// - Parameter management
41/// - Training/evaluation mode switching
42/// - Module naming
/// Core trait for all neural network modules.
///
/// Every layer in Axonml implements this trait, which provides:
/// - Forward pass computation
/// - Parameter management
/// - Training/evaluation mode switching
/// - Module naming
pub trait Module: Send + Sync {
    /// Performs the forward pass.
    ///
    /// # Arguments
    /// * `input` - Input variable
    ///
    /// # Returns
    /// Output variable after applying this module's transformation.
    fn forward(&self, input: &Variable) -> Variable;

    /// Returns all parameters of this module.
    ///
    /// This includes parameters from all child modules. The default returns
    /// an empty list, which is correct for parameterless modules.
    fn parameters(&self) -> Vec<Parameter> {
        Vec::new()
    }

    /// Returns named parameters of this module.
    ///
    /// Default is an empty map; containers prefix child names with an index
    /// or field name (e.g. `"0.weight"` — see `ModuleList`).
    fn named_parameters(&self) -> HashMap<String, Parameter> {
        HashMap::new()
    }

    /// Returns the number of trainable parameters.
    ///
    /// Counts elements only of parameters where `requires_grad()` is true.
    fn num_parameters(&self) -> usize {
        self.parameters()
            .iter()
            .filter(|p| p.requires_grad())
            .map(|p| p.numel())
            .sum()
    }

    /// Sets the module to training mode.
    ///
    /// Convenience wrapper for `set_training(true)`.
    fn train(&mut self) {
        self.set_training(true);
    }

    /// Sets the module to evaluation mode.
    ///
    /// Convenience wrapper for `set_training(false)`.
    fn eval(&mut self) {
        self.set_training(false);
    }

    /// Sets the training mode.
    ///
    /// Modules with training-dependent behavior (Dropout, BatchNorm) MUST
    /// override this AND `is_training()` to track the mode in an internal field.
    fn set_training(&mut self, _training: bool) {
        // Default: no-op. Stateless modules (Linear, Conv, activations)
        // don't need training mode tracking.
    }

    /// Returns whether the module is in training mode.
    ///
    /// Default returns `true`. Modules that override `set_training()` should
    /// also override this to return their tracked state.
    fn is_training(&self) -> bool {
        true
    }

    /// Zeros all gradients of parameters.
    ///
    /// Delegates to `Parameter::zero_grad` on each parameter returned by
    /// `parameters()`, so container modules get this for free.
    fn zero_grad(&self) {
        for param in self.parameters() {
            param.zero_grad();
        }
    }

    /// Moves all parameters to the specified device.
    ///
    /// **Note:** This only moves `Parameter` tensors. Modules with non-parameter
    /// state (e.g., BatchNorm running_mean/running_var) should override this
    /// method to also move their buffers.
    fn to_device(&self, device: Device) {
        for param in self.parameters() {
            param.to_device(device);
        }
    }

    /// Returns the module name for debugging.
    ///
    /// Defaults to the implementing type's full path via `std::any::type_name`.
    fn name(&self) -> &'static str {
        std::any::type_name::<Self>()
    }
}
125
126// =============================================================================
127// ModuleList
128// =============================================================================
129
130/// A container for holding a list of modules.
/// A container for holding a list of modules.
///
/// Stores heterogeneous layers as boxed trait objects; its `Module`
/// implementation applies them sequentially and aggregates their parameters.
pub struct ModuleList {
    // Owned child modules, applied in insertion order during `forward`.
    modules: Vec<Box<dyn Module>>,
    // Container-level training flag; `set_training` keeps children in sync.
    training: bool,
}
135
136impl ModuleList {
137 /// Creates a new empty ModuleList.
138 pub fn new() -> Self {
139 Self {
140 modules: Vec::new(),
141 training: true,
142 }
143 }
144
145 /// Creates a ModuleList from a vector of modules.
146 pub fn from_vec(modules: Vec<Box<dyn Module>>) -> Self {
147 Self {
148 modules,
149 training: true,
150 }
151 }
152
153 /// Adds a module to the list.
154 pub fn push<M: Module + 'static>(&mut self, module: M) {
155 self.modules.push(Box::new(module));
156 }
157
158 /// Returns the number of modules.
159 pub fn len(&self) -> usize {
160 self.modules.len()
161 }
162
163 /// Returns true if the list is empty.
164 pub fn is_empty(&self) -> bool {
165 self.modules.is_empty()
166 }
167
168 /// Returns an iterator over the modules.
169 pub fn iter(&self) -> impl Iterator<Item = &Box<dyn Module>> {
170 self.modules.iter()
171 }
172
173 /// Returns a mutable iterator over the modules.
174 pub fn iter_mut(&mut self) -> impl Iterator<Item = &mut Box<dyn Module>> {
175 self.modules.iter_mut()
176 }
177
178 /// Gets a module by index.
179 pub fn get(&self, index: usize) -> Option<&dyn Module> {
180 self.modules.get(index).map(|m| m.as_ref())
181 }
182}
183
184impl Default for ModuleList {
185 fn default() -> Self {
186 Self::new()
187 }
188}
189
190impl Module for ModuleList {
191 fn forward(&self, input: &Variable) -> Variable {
192 let mut x = input.clone();
193 for module in &self.modules {
194 x = module.forward(&x);
195 }
196 x
197 }
198
199 fn parameters(&self) -> Vec<Parameter> {
200 self.modules.iter().flat_map(|m| m.parameters()).collect()
201 }
202
203 fn named_parameters(&self) -> HashMap<String, Parameter> {
204 let mut params = HashMap::new();
205 for (i, module) in self.modules.iter().enumerate() {
206 for (name, param) in module.named_parameters() {
207 params.insert(format!("{i}.{name}"), param);
208 }
209 }
210 params
211 }
212
213 fn set_training(&mut self, training: bool) {
214 self.training = training;
215 for module in &mut self.modules {
216 module.set_training(training);
217 }
218 }
219
220 fn is_training(&self) -> bool {
221 self.training
222 }
223
224 fn name(&self) -> &'static str {
225 "ModuleList"
226 }
227}
228
229// =============================================================================
230// Tests
231// =============================================================================
232
#[cfg(test)]
mod tests {
    use super::*;
    use axonml_tensor::Tensor;

    // Minimal pass-through module: exercises the container without
    // depending on any real layer implementation.
    struct Identity;

    impl Module for Identity {
        fn forward(&self, input: &Variable) -> Variable {
            input.clone()
        }

        fn name(&self) -> &'static str {
            "Identity"
        }
    }

    #[test]
    fn test_module_list() {
        let mut layers = ModuleList::new();
        for _ in 0..2 {
            layers.push(Identity);
        }
        assert_eq!(layers.len(), 2);
    }

    #[test]
    fn test_module_list_forward() {
        let mut layers = ModuleList::new();
        layers.push(Identity);

        let data = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
        let input = Variable::new(data, false);
        let output = layers.forward(&input);
        assert_eq!(output.data().to_vec(), vec![1.0, 2.0, 3.0]);
    }
}
268}