Skip to main content

axonml_nn/
sequential.rs

1//! `Sequential` — ordered container that chains Module forward passes.
2//!
//! Provides `Sequential::new()`, `.add(module)` (generic over any
//! `M: Module + 'static`; the container boxes it internally), `.forward()`
//! (feeds each layer's output as the next layer's input), and the full
//! `Module` trait impl (parameters, train/eval, zero_grad, to_device
//! all propagate to children). Builder pattern:
//! `Sequential::new().add(Linear::new(784, 128)).add(ReLU).add(...)`.
8//!
9//! # File
10//! `crates/axonml-nn/src/sequential.rs`
11//!
12//! # Author
13//! Andrew Jewell Sr. — AutomataNexus LLC
14//! ORCID: 0009-0005-2158-7060
15//!
16//! # Updated
17//! April 14, 2026 11:15 PM EST
18//!
19//! # Disclaimer
20//! Use at own risk. This software is provided "as is", without warranty of any
21//! kind, express or implied. The author and AutomataNexus shall not be held
22//! liable for any damages arising from the use of this software.
23
24use std::collections::HashMap;
25
26use axonml_autograd::Variable;
27
28use crate::module::Module;
29use crate::parameter::Parameter;
30
31// =============================================================================
32// Sequential
33// =============================================================================
34
35/// A sequential container that chains modules together.
36///
37/// Modules are added in the order they should be executed.
38/// The forward pass executes each module in order, passing
39/// the output of one as the input to the next.
40///
41/// # Example
42/// ```ignore
43/// let model = Sequential::new()
44///     .add(Linear::new(784, 256))
45///     .add(ReLU)
46///     .add(Linear::new(256, 10));
47///
48/// let output = model.forward(&input);
49/// ```
50pub struct Sequential {
51    modules: Vec<(String, Box<dyn Module>)>,
52    training: bool,
53}
54
55impl Sequential {
56    /// Creates a new empty Sequential container.
57    pub fn new() -> Self {
58        Self {
59            modules: Vec::new(),
60            training: true,
61        }
62    }
63
64    /// Adds a module with an auto-generated name.
65    pub fn add<M: Module + 'static>(mut self, module: M) -> Self {
66        let name = format!("{}", self.modules.len());
67        self.modules.push((name, Box::new(module)));
68        self
69    }
70
71    /// Adds a module with a specific name.
72    pub fn add_named<M: Module + 'static>(mut self, name: impl Into<String>, module: M) -> Self {
73        self.modules.push((name.into(), Box::new(module)));
74        self
75    }
76
77    /// Pushes a module (non-builder pattern).
78    pub fn push<M: Module + 'static>(&mut self, module: M) {
79        let name = format!("{}", self.modules.len());
80        self.modules.push((name, Box::new(module)));
81    }
82
83    /// Pushes a named module (non-builder pattern).
84    pub fn push_named<M: Module + 'static>(&mut self, name: impl Into<String>, module: M) {
85        self.modules.push((name.into(), Box::new(module)));
86    }
87
88    /// Returns the number of modules.
89    pub fn len(&self) -> usize {
90        self.modules.len()
91    }
92
93    /// Returns true if empty.
94    pub fn is_empty(&self) -> bool {
95        self.modules.is_empty()
96    }
97
98    /// Returns an iterator over named modules.
99    pub fn iter(&self) -> impl Iterator<Item = (&str, &dyn Module)> {
100        self.modules.iter().map(|(n, m)| (n.as_str(), m.as_ref()))
101    }
102}
103
104impl Default for Sequential {
105    fn default() -> Self {
106        Self::new()
107    }
108}
109
110impl Module for Sequential {
111    fn forward(&self, input: &Variable) -> Variable {
112        let mut x = input.clone();
113        for (_, module) in &self.modules {
114            x = module.forward(&x);
115        }
116        x
117    }
118
119    fn parameters(&self) -> Vec<Parameter> {
120        self.modules
121            .iter()
122            .flat_map(|(_, m)| m.parameters())
123            .collect()
124    }
125
126    fn named_parameters(&self) -> HashMap<String, Parameter> {
127        let mut params = HashMap::new();
128        for (module_name, module) in &self.modules {
129            for (param_name, param) in module.named_parameters() {
130                params.insert(format!("{module_name}.{param_name}"), param);
131            }
132        }
133        params
134    }
135
136    fn set_training(&mut self, training: bool) {
137        self.training = training;
138        for (_, module) in &mut self.modules {
139            module.set_training(training);
140        }
141    }
142
143    fn is_training(&self) -> bool {
144        self.training
145    }
146
147    fn name(&self) -> &'static str {
148        "Sequential"
149    }
150}
151
152// =============================================================================
153// Tests
154// =============================================================================
155
#[cfg(test)]
mod tests {
    use super::*;
    use axonml_tensor::Tensor;

    /// Test module that returns a clone of its input.
    struct TestIdentity;

    impl Module for TestIdentity {
        fn forward(&self, input: &Variable) -> Variable {
            input.clone()
        }
    }

    /// Test module that adds its input to itself (doubles every element).
    struct TestDouble;

    impl Module for TestDouble {
        fn forward(&self, input: &Variable) -> Variable {
            input.add_var(input)
        }
    }

    #[test]
    fn test_sequential_creation() {
        let seq = Sequential::new().add(TestIdentity).add(TestIdentity);
        assert_eq!(seq.len(), 2);
    }

    #[test]
    fn test_sequential_forward() {
        let seq = Sequential::new().add(TestDouble).add(TestDouble);

        let input = Variable::new(Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap(), false);
        let output = seq.forward(&input);

        // Double twice: 1*2*2=4, 2*2*2=8
        assert_eq!(output.data().to_vec(), vec![4.0, 8.0]);
    }

    #[test]
    fn test_sequential_named() {
        let seq = Sequential::new()
            .add_named("layer1", TestIdentity)
            .add_named("layer2", TestDouble);

        let names: Vec<&str> = seq.iter().map(|(n, _)| n).collect();
        assert_eq!(names, vec!["layer1", "layer2"]);
    }

    #[test]
    fn test_sequential_empty_forward_returns_input() {
        let seq = Sequential::new();
        assert!(seq.is_empty());
        assert_eq!(seq.len(), 0);

        let input = Variable::new(Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap(), false);
        let output = seq.forward(&input);

        // With no children the input passes through unchanged.
        assert_eq!(output.data().to_vec(), vec![1.0, 2.0]);
    }

    #[test]
    fn test_sequential_auto_names() {
        // Auto-generated names are the insertion index, even when mixed
        // with explicitly named modules.
        let seq = Sequential::new()
            .add(TestIdentity)
            .add_named("mid", TestDouble)
            .add(TestIdentity);

        let names: Vec<&str> = seq.iter().map(|(n, _)| n).collect();
        assert_eq!(names, vec!["0", "mid", "2"]);
    }

    #[test]
    fn test_sequential_push() {
        let mut seq = Sequential::new();
        seq.push(TestDouble);
        seq.push_named("tail", TestIdentity);

        assert_eq!(seq.len(), 2);
        assert!(!seq.is_empty());

        let names: Vec<&str> = seq.iter().map(|(n, _)| n).collect();
        assert_eq!(names, vec!["0", "tail"]);
    }

    #[test]
    fn test_sequential_default_is_empty() {
        let seq = Sequential::default();
        assert!(seq.is_empty());
    }

    #[test]
    fn test_sequential_training_flag() {
        let mut seq = Sequential::new().add(TestIdentity);
        assert!(seq.is_training());

        seq.set_training(false);
        assert!(!seq.is_training());

        seq.set_training(true);
        assert!(seq.is_training());
    }
}