Skip to main content

axonml_nn/
sequential.rs

1//! Sequential - Sequential Container for Modules
2//!
3//! A container that runs modules in sequence, passing the output
4//! of each module as input to the next.
5//!
6//! @version 0.1.0
7//! @author AutomataNexus Development Team
8
9use std::collections::HashMap;
10
11use axonml_autograd::Variable;
12
13use crate::module::Module;
14use crate::parameter::Parameter;
15
16// =============================================================================
17// Sequential
18// =============================================================================
19
20/// A sequential container that chains modules together.
21///
22/// Modules are added in the order they should be executed.
23/// The forward pass executes each module in order, passing
24/// the output of one as the input to the next.
25///
26/// # Example
27/// ```ignore
28/// let model = Sequential::new()
29///     .add(Linear::new(784, 256))
30///     .add(ReLU)
31///     .add(Linear::new(256, 10));
32///
33/// let output = model.forward(&input);
34/// ```
35pub struct Sequential {
36    modules: Vec<(String, Box<dyn Module>)>,
37    training: bool,
38}
39
40impl Sequential {
41    /// Creates a new empty Sequential container.
42    pub fn new() -> Self {
43        Self {
44            modules: Vec::new(),
45            training: true,
46        }
47    }
48
49    /// Adds a module with an auto-generated name.
50    pub fn add<M: Module + 'static>(mut self, module: M) -> Self {
51        let name = format!("{}", self.modules.len());
52        self.modules.push((name, Box::new(module)));
53        self
54    }
55
56    /// Adds a module with a specific name.
57    pub fn add_named<M: Module + 'static>(mut self, name: impl Into<String>, module: M) -> Self {
58        self.modules.push((name.into(), Box::new(module)));
59        self
60    }
61
62    /// Pushes a module (non-builder pattern).
63    pub fn push<M: Module + 'static>(&mut self, module: M) {
64        let name = format!("{}", self.modules.len());
65        self.modules.push((name, Box::new(module)));
66    }
67
68    /// Pushes a named module (non-builder pattern).
69    pub fn push_named<M: Module + 'static>(&mut self, name: impl Into<String>, module: M) {
70        self.modules.push((name.into(), Box::new(module)));
71    }
72
73    /// Returns the number of modules.
74    pub fn len(&self) -> usize {
75        self.modules.len()
76    }
77
78    /// Returns true if empty.
79    pub fn is_empty(&self) -> bool {
80        self.modules.is_empty()
81    }
82
83    /// Returns an iterator over named modules.
84    pub fn iter(&self) -> impl Iterator<Item = (&str, &dyn Module)> {
85        self.modules.iter().map(|(n, m)| (n.as_str(), m.as_ref()))
86    }
87}
88
89impl Default for Sequential {
90    fn default() -> Self {
91        Self::new()
92    }
93}
94
95impl Module for Sequential {
96    fn forward(&self, input: &Variable) -> Variable {
97        let mut x = input.clone();
98        for (_, module) in &self.modules {
99            x = module.forward(&x);
100        }
101        x
102    }
103
104    fn parameters(&self) -> Vec<Parameter> {
105        self.modules
106            .iter()
107            .flat_map(|(_, m)| m.parameters())
108            .collect()
109    }
110
111    fn named_parameters(&self) -> HashMap<String, Parameter> {
112        let mut params = HashMap::new();
113        for (module_name, module) in &self.modules {
114            for (param_name, param) in module.named_parameters() {
115                params.insert(format!("{module_name}.{param_name}"), param);
116            }
117        }
118        params
119    }
120
121    fn set_training(&mut self, training: bool) {
122        self.training = training;
123        for (_, module) in &mut self.modules {
124            module.set_training(training);
125        }
126    }
127
128    fn is_training(&self) -> bool {
129        self.training
130    }
131
132    fn name(&self) -> &'static str {
133        "Sequential"
134    }
135}
136
137// =============================================================================
138// Tests
139// =============================================================================
140
#[cfg(test)]
mod tests {
    use super::*;
    use axonml_tensor::Tensor;

    // Test identity module
    struct TestIdentity;

    impl Module for TestIdentity {
        fn forward(&self, input: &Variable) -> Variable {
            input.clone()
        }
    }

    // Test doubling module
    struct TestDouble;

    impl Module for TestDouble {
        fn forward(&self, input: &Variable) -> Variable {
            input.add_var(input)
        }
    }

    // Identity module that records the training flag its container pushes down.
    struct TestModeProbe {
        training: bool,
    }

    impl Module for TestModeProbe {
        fn forward(&self, input: &Variable) -> Variable {
            input.clone()
        }

        fn set_training(&mut self, training: bool) {
            self.training = training;
        }

        fn is_training(&self) -> bool {
            self.training
        }
    }

    #[test]
    fn test_sequential_creation() {
        let seq = Sequential::new().add(TestIdentity).add(TestIdentity);
        assert_eq!(seq.len(), 2);
    }

    #[test]
    fn test_sequential_forward() {
        let seq = Sequential::new().add(TestDouble).add(TestDouble);

        let input = Variable::new(Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap(), false);
        let output = seq.forward(&input);

        // Double twice: 1*2*2=4, 2*2*2=8
        assert_eq!(output.data().to_vec(), vec![4.0, 8.0]);
    }

    #[test]
    fn test_sequential_named() {
        let seq = Sequential::new()
            .add_named("layer1", TestIdentity)
            .add_named("layer2", TestDouble);

        let names: Vec<&str> = seq.iter().map(|(n, _)| n).collect();
        assert_eq!(names, vec!["layer1", "layer2"]);
    }

    #[test]
    fn test_sequential_push_naming() {
        let mut seq = Sequential::new();
        assert!(seq.is_empty());

        seq.push(TestIdentity);
        seq.push_named("head", TestDouble);
        seq.push(TestIdentity);

        // Auto-generated names are the module's position at insertion time.
        let names: Vec<&str> = seq.iter().map(|(n, _)| n).collect();
        assert_eq!(names, vec!["0", "head", "2"]);
        assert_eq!(seq.len(), 3);
    }

    #[test]
    fn test_sequential_empty_is_identity() {
        let seq = Sequential::default();

        let input = Variable::new(Tensor::from_vec(vec![3.0, -1.0], &[2]).unwrap(), false);
        let output = seq.forward(&input);

        assert_eq!(output.data().to_vec(), vec![3.0, -1.0]);
    }

    #[test]
    fn test_sequential_set_training_propagates() {
        let mut seq = Sequential::new()
            .add(TestModeProbe { training: true })
            .add(TestModeProbe { training: true });
        assert!(seq.is_training());

        seq.set_training(false);
        assert!(!seq.is_training());
        assert!(seq.iter().all(|(_, m)| !m.is_training()));

        seq.set_training(true);
        assert!(seq.is_training());
        assert!(seq.iter().all(|(_, m)| m.is_training()));
    }
}