Skip to main content

axonml_nn/
sequential.rs

//! Sequential - Sequential Container for Modules
//!
//! # File
//! `crates/axonml-nn/src/sequential.rs`
//!
//! # Author
//! Andrew Jewell Sr - AutomataNexus
//!
//! # Updated
//! March 8, 2026
//!
//! # Disclaimer
//! Use at own risk. This software is provided "as is", without warranty of any
//! kind, express or implied. The author and AutomataNexus shall not be held
//! liable for any damages arising from the use of this software.
16
17use std::collections::HashMap;
18
19use axonml_autograd::Variable;
20
21use crate::module::Module;
22use crate::parameter::Parameter;
23
// =============================================================================
// Sequential
// =============================================================================
27
28/// A sequential container that chains modules together.
29///
30/// Modules are added in the order they should be executed.
31/// The forward pass executes each module in order, passing
32/// the output of one as the input to the next.
33///
34/// # Example
35/// ```ignore
36/// let model = Sequential::new()
37///     .add(Linear::new(784, 256))
38///     .add(ReLU)
39///     .add(Linear::new(256, 10));
40///
41/// let output = model.forward(&input);
42/// ```
43pub struct Sequential {
44    modules: Vec<(String, Box<dyn Module>)>,
45    training: bool,
46}
47
48impl Sequential {
49    /// Creates a new empty Sequential container.
50    pub fn new() -> Self {
51        Self {
52            modules: Vec::new(),
53            training: true,
54        }
55    }
56
57    /// Adds a module with an auto-generated name.
58    pub fn add<M: Module + 'static>(mut self, module: M) -> Self {
59        let name = format!("{}", self.modules.len());
60        self.modules.push((name, Box::new(module)));
61        self
62    }
63
64    /// Adds a module with a specific name.
65    pub fn add_named<M: Module + 'static>(mut self, name: impl Into<String>, module: M) -> Self {
66        self.modules.push((name.into(), Box::new(module)));
67        self
68    }
69
70    /// Pushes a module (non-builder pattern).
71    pub fn push<M: Module + 'static>(&mut self, module: M) {
72        let name = format!("{}", self.modules.len());
73        self.modules.push((name, Box::new(module)));
74    }
75
76    /// Pushes a named module (non-builder pattern).
77    pub fn push_named<M: Module + 'static>(&mut self, name: impl Into<String>, module: M) {
78        self.modules.push((name.into(), Box::new(module)));
79    }
80
81    /// Returns the number of modules.
82    pub fn len(&self) -> usize {
83        self.modules.len()
84    }
85
86    /// Returns true if empty.
87    pub fn is_empty(&self) -> bool {
88        self.modules.is_empty()
89    }
90
91    /// Returns an iterator over named modules.
92    pub fn iter(&self) -> impl Iterator<Item = (&str, &dyn Module)> {
93        self.modules.iter().map(|(n, m)| (n.as_str(), m.as_ref()))
94    }
95}
96
97impl Default for Sequential {
98    fn default() -> Self {
99        Self::new()
100    }
101}
102
103impl Module for Sequential {
104    fn forward(&self, input: &Variable) -> Variable {
105        let mut x = input.clone();
106        for (_, module) in &self.modules {
107            x = module.forward(&x);
108        }
109        x
110    }
111
112    fn parameters(&self) -> Vec<Parameter> {
113        self.modules
114            .iter()
115            .flat_map(|(_, m)| m.parameters())
116            .collect()
117    }
118
119    fn named_parameters(&self) -> HashMap<String, Parameter> {
120        let mut params = HashMap::new();
121        for (module_name, module) in &self.modules {
122            for (param_name, param) in module.named_parameters() {
123                params.insert(format!("{module_name}.{param_name}"), param);
124            }
125        }
126        params
127    }
128
129    fn set_training(&mut self, training: bool) {
130        self.training = training;
131        for (_, module) in &mut self.modules {
132            module.set_training(training);
133        }
134    }
135
136    fn is_training(&self) -> bool {
137        self.training
138    }
139
140    fn name(&self) -> &'static str {
141        "Sequential"
142    }
143}
144
// =============================================================================
// Tests
// =============================================================================
148
#[cfg(test)]
mod tests {
    use super::*;
    use axonml_tensor::Tensor;

    /// Test module that returns a clone of its input.
    struct TestIdentity;

    impl Module for TestIdentity {
        fn forward(&self, input: &Variable) -> Variable {
            input.clone()
        }
    }

    /// Test module that adds its input to itself.
    struct TestDouble;

    impl Module for TestDouble {
        fn forward(&self, input: &Variable) -> Variable {
            input.add_var(input)
        }
    }

    /// Two builder insertions yield a container of length two.
    #[test]
    fn test_sequential_creation() {
        let container = Sequential::new().add(TestIdentity).add(TestIdentity);
        let module_count = container.len();
        assert_eq!(module_count, 2);
    }

    /// Chaining two doubling modules multiplies every element by four.
    #[test]
    fn test_sequential_forward() {
        let tensor = Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap();
        let input = Variable::new(tensor, false);

        let doubled_twice = Sequential::new()
            .add(TestDouble)
            .add(TestDouble)
            .forward(&input);

        // Double twice: 1*2*2=4, 2*2*2=8
        let values = doubled_twice.data().to_vec();
        assert_eq!(values, vec![4.0, 8.0]);
    }

    /// Explicit names are reported by `iter` in insertion order.
    #[test]
    fn test_sequential_named() {
        let container = Sequential::new()
            .add_named("layer1", TestIdentity)
            .add_named("layer2", TestDouble);

        let mut names: Vec<&str> = Vec::new();
        for (name, _) in container.iter() {
            names.push(name);
        }
        assert_eq!(names, vec!["layer1", "layer2"]);
    }
}