1use std::collections::HashMap;
25
26use axonml_autograd::Variable;
27
28use crate::module::Module;
29use crate::parameter::Parameter;
30
31pub struct Sequential {
51 modules: Vec<(String, Box<dyn Module>)>,
52 training: bool,
53}
54
55impl Sequential {
56 pub fn new() -> Self {
58 Self {
59 modules: Vec::new(),
60 training: true,
61 }
62 }
63
64 pub fn add<M: Module + 'static>(mut self, module: M) -> Self {
66 let name = format!("{}", self.modules.len());
67 self.modules.push((name, Box::new(module)));
68 self
69 }
70
71 pub fn add_named<M: Module + 'static>(mut self, name: impl Into<String>, module: M) -> Self {
73 self.modules.push((name.into(), Box::new(module)));
74 self
75 }
76
77 pub fn push<M: Module + 'static>(&mut self, module: M) {
79 let name = format!("{}", self.modules.len());
80 self.modules.push((name, Box::new(module)));
81 }
82
83 pub fn push_named<M: Module + 'static>(&mut self, name: impl Into<String>, module: M) {
85 self.modules.push((name.into(), Box::new(module)));
86 }
87
88 pub fn len(&self) -> usize {
90 self.modules.len()
91 }
92
93 pub fn is_empty(&self) -> bool {
95 self.modules.is_empty()
96 }
97
98 pub fn iter(&self) -> impl Iterator<Item = (&str, &dyn Module)> {
100 self.modules.iter().map(|(n, m)| (n.as_str(), m.as_ref()))
101 }
102}
103
104impl Default for Sequential {
105 fn default() -> Self {
106 Self::new()
107 }
108}
109
110impl Module for Sequential {
111 fn forward(&self, input: &Variable) -> Variable {
112 let mut x = input.clone();
113 for (_, module) in &self.modules {
114 x = module.forward(&x);
115 }
116 x
117 }
118
119 fn parameters(&self) -> Vec<Parameter> {
120 self.modules
121 .iter()
122 .flat_map(|(_, m)| m.parameters())
123 .collect()
124 }
125
126 fn named_parameters(&self) -> HashMap<String, Parameter> {
127 let mut params = HashMap::new();
128 for (module_name, module) in &self.modules {
129 for (param_name, param) in module.named_parameters() {
130 params.insert(format!("{module_name}.{param_name}"), param);
131 }
132 }
133 params
134 }
135
136 fn set_training(&mut self, training: bool) {
137 self.training = training;
138 for (_, module) in &mut self.modules {
139 module.set_training(training);
140 }
141 }
142
143 fn is_training(&self) -> bool {
144 self.training
145 }
146
147 fn name(&self) -> &'static str {
148 "Sequential"
149 }
150}
151
#[cfg(test)]
mod tests {
    use super::*;
    use axonml_tensor::Tensor;

    /// Test module whose forward pass returns a clone of its input.
    struct TestIdentity;

    impl Module for TestIdentity {
        fn forward(&self, input: &Variable) -> Variable {
            input.clone()
        }
    }

    /// Test module whose forward pass adds the input to itself.
    struct TestDouble;

    impl Module for TestDouble {
        fn forward(&self, input: &Variable) -> Variable {
            input.add_var(input)
        }
    }

    /// Two chained `add` calls yield a container of length two.
    #[test]
    fn test_sequential_creation() {
        let model = Sequential::new().add(TestIdentity).add(TestIdentity);
        assert_eq!(model.len(), 2);
    }

    /// Two doubling modules in sequence quadruple every element.
    #[test]
    fn test_sequential_forward() {
        let tensor = Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap();
        let input = Variable::new(tensor, false);

        let model = Sequential::new().add(TestDouble).add(TestDouble);
        let result = model.forward(&input);

        assert_eq!(result.data().to_vec(), vec![4.0, 8.0]);
    }

    /// Explicit names come back from `iter` in insertion order.
    #[test]
    fn test_sequential_named() {
        let model = Sequential::new()
            .add_named("layer1", TestIdentity)
            .add_named("layer2", TestDouble);

        let mut names: Vec<&str> = Vec::new();
        for (name, _) in model.iter() {
            names.push(name);
        }
        assert_eq!(names, vec!["layer1", "layer2"]);
    }
}