1use std::collections::HashMap;
10
11use axonml_autograd::Variable;
12
13use crate::module::Module;
14use crate::parameter::Parameter;
15
16pub struct Sequential {
36 modules: Vec<(String, Box<dyn Module>)>,
37 training: bool,
38}
39
40impl Sequential {
41 pub fn new() -> Self {
43 Self {
44 modules: Vec::new(),
45 training: true,
46 }
47 }
48
49 pub fn add<M: Module + 'static>(mut self, module: M) -> Self {
51 let name = format!("{}", self.modules.len());
52 self.modules.push((name, Box::new(module)));
53 self
54 }
55
56 pub fn add_named<M: Module + 'static>(mut self, name: impl Into<String>, module: M) -> Self {
58 self.modules.push((name.into(), Box::new(module)));
59 self
60 }
61
62 pub fn push<M: Module + 'static>(&mut self, module: M) {
64 let name = format!("{}", self.modules.len());
65 self.modules.push((name, Box::new(module)));
66 }
67
68 pub fn push_named<M: Module + 'static>(&mut self, name: impl Into<String>, module: M) {
70 self.modules.push((name.into(), Box::new(module)));
71 }
72
73 pub fn len(&self) -> usize {
75 self.modules.len()
76 }
77
78 pub fn is_empty(&self) -> bool {
80 self.modules.is_empty()
81 }
82
83 pub fn iter(&self) -> impl Iterator<Item = (&str, &dyn Module)> {
85 self.modules.iter().map(|(n, m)| (n.as_str(), m.as_ref()))
86 }
87}
88
89impl Default for Sequential {
90 fn default() -> Self {
91 Self::new()
92 }
93}
94
95impl Module for Sequential {
96 fn forward(&self, input: &Variable) -> Variable {
97 let mut x = input.clone();
98 for (_, module) in &self.modules {
99 x = module.forward(&x);
100 }
101 x
102 }
103
104 fn parameters(&self) -> Vec<Parameter> {
105 self.modules
106 .iter()
107 .flat_map(|(_, m)| m.parameters())
108 .collect()
109 }
110
111 fn named_parameters(&self) -> HashMap<String, Parameter> {
112 let mut params = HashMap::new();
113 for (module_name, module) in &self.modules {
114 for (param_name, param) in module.named_parameters() {
115 params.insert(format!("{module_name}.{param_name}"), param);
116 }
117 }
118 params
119 }
120
121 fn set_training(&mut self, training: bool) {
122 self.training = training;
123 for (_, module) in &mut self.modules {
124 module.set_training(training);
125 }
126 }
127
128 fn is_training(&self) -> bool {
129 self.training
130 }
131
132 fn name(&self) -> &'static str {
133 "Sequential"
134 }
135}
136
#[cfg(test)]
mod tests {
    use super::*;
    use axonml_tensor::Tensor;

    /// Test module that returns its input unchanged.
    struct TestIdentity;

    impl Module for TestIdentity {
        fn forward(&self, input: &Variable) -> Variable {
            input.clone()
        }
    }

    /// Test module that returns `input + input`.
    struct TestDouble;

    impl Module for TestDouble {
        fn forward(&self, input: &Variable) -> Variable {
            input.add_var(input)
        }
    }

    #[test]
    fn test_sequential_creation() {
        let seq = Sequential::new().add(TestIdentity).add(TestIdentity);
        assert_eq!(seq.len(), 2);
    }

    #[test]
    fn test_sequential_empty_and_default() {
        assert!(Sequential::new().is_empty());

        let seq = Sequential::default();
        assert!(seq.is_empty());
        assert_eq!(seq.len(), 0);
        assert!(seq.is_training());
    }

    #[test]
    fn test_sequential_forward() {
        let seq = Sequential::new().add(TestDouble).add(TestDouble);

        let input = Variable::new(Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap(), false);
        let output = seq.forward(&input);

        assert_eq!(output.data().to_vec(), vec![4.0, 8.0]);
    }

    #[test]
    fn test_sequential_empty_forward_is_identity() {
        let seq = Sequential::new();

        let input = Variable::new(Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap(), false);
        let output = seq.forward(&input);

        assert_eq!(output.data().to_vec(), vec![1.0, 2.0]);
    }

    #[test]
    fn test_sequential_named() {
        let seq = Sequential::new()
            .add_named("layer1", TestIdentity)
            .add_named("layer2", TestDouble);

        let names: Vec<&str> = seq.iter().map(|(n, _)| n).collect();
        assert_eq!(names, vec!["layer1", "layer2"]);
    }

    #[test]
    fn test_sequential_push_auto_names_use_index() {
        let mut seq = Sequential::new();
        seq.push(TestIdentity);
        seq.push_named("mid", TestDouble);
        seq.push(TestIdentity);

        // Auto-generated names are the insertion index; named children still count.
        let names: Vec<&str> = seq.iter().map(|(n, _)| n).collect();
        assert_eq!(names, vec!["0", "mid", "2"]);
    }

    #[test]
    fn test_sequential_set_training() {
        let mut seq = Sequential::new().add(TestIdentity);
        assert!(seq.is_training());

        seq.set_training(false);
        assert!(!seq.is_training());

        seq.set_training(true);
        assert!(seq.is_training());
    }
}