1use std::collections::HashMap;
18
19use axonml_autograd::Variable;
20
21use crate::module::Module;
22use crate::parameter::Parameter;
23
24pub struct Sequential {
44 modules: Vec<(String, Box<dyn Module>)>,
45 training: bool,
46}
47
48impl Sequential {
49 pub fn new() -> Self {
51 Self {
52 modules: Vec::new(),
53 training: true,
54 }
55 }
56
57 pub fn add<M: Module + 'static>(mut self, module: M) -> Self {
59 let name = format!("{}", self.modules.len());
60 self.modules.push((name, Box::new(module)));
61 self
62 }
63
64 pub fn add_named<M: Module + 'static>(mut self, name: impl Into<String>, module: M) -> Self {
66 self.modules.push((name.into(), Box::new(module)));
67 self
68 }
69
70 pub fn push<M: Module + 'static>(&mut self, module: M) {
72 let name = format!("{}", self.modules.len());
73 self.modules.push((name, Box::new(module)));
74 }
75
76 pub fn push_named<M: Module + 'static>(&mut self, name: impl Into<String>, module: M) {
78 self.modules.push((name.into(), Box::new(module)));
79 }
80
81 pub fn len(&self) -> usize {
83 self.modules.len()
84 }
85
86 pub fn is_empty(&self) -> bool {
88 self.modules.is_empty()
89 }
90
91 pub fn iter(&self) -> impl Iterator<Item = (&str, &dyn Module)> {
93 self.modules.iter().map(|(n, m)| (n.as_str(), m.as_ref()))
94 }
95}
96
97impl Default for Sequential {
98 fn default() -> Self {
99 Self::new()
100 }
101}
102
103impl Module for Sequential {
104 fn forward(&self, input: &Variable) -> Variable {
105 let mut x = input.clone();
106 for (_, module) in &self.modules {
107 x = module.forward(&x);
108 }
109 x
110 }
111
112 fn parameters(&self) -> Vec<Parameter> {
113 self.modules
114 .iter()
115 .flat_map(|(_, m)| m.parameters())
116 .collect()
117 }
118
119 fn named_parameters(&self) -> HashMap<String, Parameter> {
120 let mut params = HashMap::new();
121 for (module_name, module) in &self.modules {
122 for (param_name, param) in module.named_parameters() {
123 params.insert(format!("{module_name}.{param_name}"), param);
124 }
125 }
126 params
127 }
128
129 fn set_training(&mut self, training: bool) {
130 self.training = training;
131 for (_, module) in &mut self.modules {
132 module.set_training(training);
133 }
134 }
135
136 fn is_training(&self) -> bool {
137 self.training
138 }
139
140 fn name(&self) -> &'static str {
141 "Sequential"
142 }
143}
144
#[cfg(test)]
mod tests {
    use super::*;
    use axonml_tensor::Tensor;

    /// Test helper module: returns its input unchanged.
    struct TestIdentity;

    impl Module for TestIdentity {
        fn forward(&self, input: &Variable) -> Variable {
            input.clone()
        }
    }

    /// Test helper module: returns `input + input`.
    struct TestDouble;

    impl Module for TestDouble {
        fn forward(&self, input: &Variable) -> Variable {
            input.add_var(input)
        }
    }

    #[test]
    fn test_sequential_creation() {
        let seq = Sequential::new().add(TestIdentity).add(TestIdentity);
        assert_eq!(seq.len(), 2);
    }

    #[test]
    fn test_sequential_forward() {
        let seq = Sequential::new().add(TestDouble).add(TestDouble);

        let input = Variable::new(Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap(), false);
        let output = seq.forward(&input);

        assert_eq!(output.data().to_vec(), vec![4.0, 8.0]);
    }

    #[test]
    fn test_sequential_named() {
        let seq = Sequential::new()
            .add_named("layer1", TestIdentity)
            .add_named("layer2", TestDouble);

        let names: Vec<&str> = seq.iter().map(|(n, _)| n).collect();
        assert_eq!(names, vec!["layer1", "layer2"]);
    }

    /// An empty container must pass its input through untouched.
    #[test]
    fn test_sequential_empty_forward_is_identity() {
        let seq = Sequential::new();
        assert!(seq.is_empty());
        assert_eq!(seq.len(), 0);

        let input = Variable::new(Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap(), false);
        let output = seq.forward(&input);

        assert_eq!(output.data().to_vec(), vec![1.0, 2.0]);
    }

    /// In-place `push`/`push_named` with positional auto-names (no collisions).
    #[test]
    fn test_sequential_push_auto_names() {
        let mut seq = Sequential::default();
        seq.push(TestIdentity);
        seq.push_named("double", TestDouble);
        seq.push(TestIdentity);

        let names: Vec<&str> = seq.iter().map(|(n, _)| n).collect();
        assert_eq!(names, vec!["0", "double", "2"]);
        assert_eq!(seq.len(), 3);
        assert!(!seq.is_empty());
    }

    /// The training flag starts `true` and follows `set_training`.
    #[test]
    fn test_sequential_set_training() {
        let mut seq = Sequential::new().add(TestIdentity).add(TestDouble);
        assert!(seq.is_training());

        seq.set_training(false);
        assert!(!seq.is_training());

        seq.set_training(true);
        assert!(seq.is_training());
    }

    #[test]
    fn test_sequential_name() {
        assert_eq!(Sequential::new().name(), "Sequential");
    }
}
198}