axonml_nn/module.rs
1//! Module Trait - Neural Network Module Interface
2//!
3//! # File
4//! `crates/axonml-nn/src/module.rs`
5//!
6//! # Author
7//! Andrew Jewell Sr - AutomataNexus
8//!
9//! # Updated
10//! March 8, 2026
11//!
12//! # Disclaimer
13//! Use at own risk. This software is provided "as is", without warranty of any
14//! kind, express or implied. The author and AutomataNexus shall not be held
15//! liable for any damages arising from the use of this software.
16
17use std::collections::HashMap;
18
19use axonml_autograd::Variable;
20use axonml_core::Device;
21
22use crate::parameter::Parameter;
23
24// =============================================================================
25// Module Trait
26// =============================================================================
27
28/// Core trait for all neural network modules.
29///
30/// Every layer in Axonml implements this trait, which provides:
31/// - Forward pass computation
32/// - Parameter management
33/// - Training/evaluation mode switching
34/// - Module naming
35pub trait Module: Send + Sync {
36 /// Performs the forward pass.
37 ///
38 /// # Arguments
39 /// * `input` - Input variable
40 ///
41 /// # Returns
42 /// Output variable after applying this module's transformation.
43 fn forward(&self, input: &Variable) -> Variable;
44
45 /// Returns all parameters of this module.
46 ///
47 /// This includes parameters from all child modules.
48 fn parameters(&self) -> Vec<Parameter> {
49 Vec::new()
50 }
51
52 /// Returns named parameters of this module.
53 fn named_parameters(&self) -> HashMap<String, Parameter> {
54 HashMap::new()
55 }
56
57 /// Returns the number of trainable parameters.
58 fn num_parameters(&self) -> usize {
59 self.parameters()
60 .iter()
61 .filter(|p| p.requires_grad())
62 .map(|p| p.numel())
63 .sum()
64 }
65
66 /// Sets the module to training mode.
67 fn train(&mut self) {
68 self.set_training(true);
69 }
70
71 /// Sets the module to evaluation mode.
72 fn eval(&mut self) {
73 self.set_training(false);
74 }
75
76 /// Sets the training mode.
77 /// Sets the training mode.
78 ///
79 /// Modules with training-dependent behavior (Dropout, BatchNorm) MUST
80 /// override this AND `is_training()` to track the mode in an internal field.
81 fn set_training(&mut self, _training: bool) {
82 // Default: no-op. Stateless modules (Linear, Conv, activations)
83 // don't need training mode tracking.
84 }
85
86 /// Returns whether the module is in training mode.
87 ///
88 /// Default returns `true`. Modules that override `set_training()` should
89 /// also override this to return their tracked state.
90 fn is_training(&self) -> bool {
91 true
92 }
93
94 /// Zeros all gradients of parameters.
95 fn zero_grad(&self) {
96 for param in self.parameters() {
97 param.zero_grad();
98 }
99 }
100
101 /// Moves all parameters to the specified device.
102 ///
103 /// **Note:** This only moves `Parameter` tensors. Modules with non-parameter
104 /// state (e.g., BatchNorm running_mean/running_var) should override this
105 /// method to also move their buffers.
106 fn to_device(&self, device: Device) {
107 for param in self.parameters() {
108 param.to_device(device);
109 }
110 }
111
112 /// Returns the module name for debugging.
113 fn name(&self) -> &'static str {
114 std::any::type_name::<Self>()
115 }
116}
117
118// =============================================================================
119// ModuleList
120// =============================================================================
121
122/// A container for holding a list of modules.
123pub struct ModuleList {
124 modules: Vec<Box<dyn Module>>,
125 training: bool,
126}
127
128impl ModuleList {
129 /// Creates a new empty ModuleList.
130 pub fn new() -> Self {
131 Self {
132 modules: Vec::new(),
133 training: true,
134 }
135 }
136
137 /// Creates a ModuleList from a vector of modules.
138 pub fn from_vec(modules: Vec<Box<dyn Module>>) -> Self {
139 Self {
140 modules,
141 training: true,
142 }
143 }
144
145 /// Adds a module to the list.
146 pub fn push<M: Module + 'static>(&mut self, module: M) {
147 self.modules.push(Box::new(module));
148 }
149
150 /// Returns the number of modules.
151 pub fn len(&self) -> usize {
152 self.modules.len()
153 }
154
155 /// Returns true if the list is empty.
156 pub fn is_empty(&self) -> bool {
157 self.modules.is_empty()
158 }
159
160 /// Returns an iterator over the modules.
161 pub fn iter(&self) -> impl Iterator<Item = &Box<dyn Module>> {
162 self.modules.iter()
163 }
164
165 /// Returns a mutable iterator over the modules.
166 pub fn iter_mut(&mut self) -> impl Iterator<Item = &mut Box<dyn Module>> {
167 self.modules.iter_mut()
168 }
169
170 /// Gets a module by index.
171 pub fn get(&self, index: usize) -> Option<&dyn Module> {
172 self.modules.get(index).map(|m| m.as_ref())
173 }
174}
175
176impl Default for ModuleList {
177 fn default() -> Self {
178 Self::new()
179 }
180}
181
182impl Module for ModuleList {
183 fn forward(&self, input: &Variable) -> Variable {
184 let mut x = input.clone();
185 for module in &self.modules {
186 x = module.forward(&x);
187 }
188 x
189 }
190
191 fn parameters(&self) -> Vec<Parameter> {
192 self.modules.iter().flat_map(|m| m.parameters()).collect()
193 }
194
195 fn named_parameters(&self) -> HashMap<String, Parameter> {
196 let mut params = HashMap::new();
197 for (i, module) in self.modules.iter().enumerate() {
198 for (name, param) in module.named_parameters() {
199 params.insert(format!("{i}.{name}"), param);
200 }
201 }
202 params
203 }
204
205 fn set_training(&mut self, training: bool) {
206 self.training = training;
207 for module in &mut self.modules {
208 module.set_training(training);
209 }
210 }
211
212 fn is_training(&self) -> bool {
213 self.training
214 }
215
216 fn name(&self) -> &'static str {
217 "ModuleList"
218 }
219}
220
221// =============================================================================
222// Tests
223// =============================================================================
224
#[cfg(test)]
mod tests {
    use super::*;
    use axonml_tensor::Tensor;

    /// Pass-through module that relies on every trait default except `name`.
    struct Identity;

    impl Module for Identity {
        fn forward(&self, input: &Variable) -> Variable {
            input.clone()
        }

        fn name(&self) -> &'static str {
            "Identity"
        }
    }

    /// Pass-through module that tracks its training flag, used to observe
    /// whether a container forwards mode changes to its children.
    struct ModeProbe {
        training: bool,
    }

    impl Module for ModeProbe {
        fn forward(&self, input: &Variable) -> Variable {
            input.clone()
        }

        fn set_training(&mut self, training: bool) {
            self.training = training;
        }

        fn is_training(&self) -> bool {
            self.training
        }
    }

    /// Builds the three-element input shared by the forward tests.
    fn sample_input() -> Variable {
        Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), false)
    }

    #[test]
    fn test_module_list() {
        let mut list = ModuleList::new();
        list.push(Identity);
        list.push(Identity);
        assert_eq!(list.len(), 2);
    }

    #[test]
    fn test_module_list_forward() {
        let mut list = ModuleList::new();
        list.push(Identity);

        let input = sample_input();
        let output = list.forward(&input);
        assert_eq!(output.data().to_vec(), vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_empty_module_list_forward_is_identity() {
        // With no children the container must hand back the input unchanged.
        let list = ModuleList::new();
        let input = sample_input();
        let output = list.forward(&input);
        assert_eq!(output.data().to_vec(), vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_module_list_empty_and_get() {
        let mut list = ModuleList::default();
        assert!(list.is_empty());
        assert_eq!(list.len(), 0);
        assert!(list.get(0).is_none());

        list.push(Identity);
        assert!(!list.is_empty());
        assert_eq!(list.get(0).map(|m| m.name()), Some("Identity"));
        assert!(list.get(1).is_none());

        // Identity exposes no parameters, so the container reports none.
        assert_eq!(list.name(), "ModuleList");
        assert!(list.parameters().is_empty());
        assert!(list.named_parameters().is_empty());
        assert_eq!(list.num_parameters(), 0);
    }

    #[test]
    fn test_module_list_training_mode_propagates() {
        let mut list = ModuleList::new();
        list.push(ModeProbe { training: true });
        list.push(ModeProbe { training: true });
        assert!(list.is_training());

        list.eval();
        assert!(!list.is_training());
        assert!(list.iter().all(|m| !m.is_training()));

        list.train();
        assert!(list.is_training());
        assert!(list.iter().all(|m| m.is_training()));
    }
}
260}