1use std::collections::HashMap;
18
19use axonml_autograd::Variable;
20use axonml_core::Device;
21
22use crate::parameter::Parameter;
23
24pub trait Module: Send + Sync {
36 fn forward(&self, input: &Variable) -> Variable;
44
45 fn parameters(&self) -> Vec<Parameter> {
49 Vec::new()
50 }
51
52 fn named_parameters(&self) -> HashMap<String, Parameter> {
54 HashMap::new()
55 }
56
57 fn num_parameters(&self) -> usize {
59 self.parameters()
60 .iter()
61 .filter(|p| p.requires_grad())
62 .map(|p| p.numel())
63 .sum()
64 }
65
66 fn train(&mut self) {
68 self.set_training(true);
69 }
70
71 fn eval(&mut self) {
73 self.set_training(false);
74 }
75
76 fn set_training(&mut self, _training: bool) {
78 }
81
82 fn is_training(&self) -> bool {
84 true }
86
87 fn zero_grad(&self) {
89 for param in self.parameters() {
90 param.zero_grad();
91 }
92 }
93
94 fn to_device(&self, device: Device) {
96 for param in self.parameters() {
97 param.to_device(device);
98 }
99 }
100
101 fn name(&self) -> &'static str {
103 std::any::type_name::<Self>()
104 }
105}
106
107pub struct ModuleList {
113 modules: Vec<Box<dyn Module>>,
114 training: bool,
115}
116
117impl ModuleList {
118 pub fn new() -> Self {
120 Self {
121 modules: Vec::new(),
122 training: true,
123 }
124 }
125
126 pub fn from_vec(modules: Vec<Box<dyn Module>>) -> Self {
128 Self {
129 modules,
130 training: true,
131 }
132 }
133
134 pub fn push<M: Module + 'static>(&mut self, module: M) {
136 self.modules.push(Box::new(module));
137 }
138
139 pub fn len(&self) -> usize {
141 self.modules.len()
142 }
143
144 pub fn is_empty(&self) -> bool {
146 self.modules.is_empty()
147 }
148
149 pub fn iter(&self) -> impl Iterator<Item = &Box<dyn Module>> {
151 self.modules.iter()
152 }
153
154 pub fn iter_mut(&mut self) -> impl Iterator<Item = &mut Box<dyn Module>> {
156 self.modules.iter_mut()
157 }
158
159 pub fn get(&self, index: usize) -> Option<&dyn Module> {
161 self.modules.get(index).map(|m| m.as_ref())
162 }
163}
164
165impl Default for ModuleList {
166 fn default() -> Self {
167 Self::new()
168 }
169}
170
171impl Module for ModuleList {
172 fn forward(&self, input: &Variable) -> Variable {
173 let mut x = input.clone();
174 for module in &self.modules {
175 x = module.forward(&x);
176 }
177 x
178 }
179
180 fn parameters(&self) -> Vec<Parameter> {
181 self.modules.iter().flat_map(|m| m.parameters()).collect()
182 }
183
184 fn named_parameters(&self) -> HashMap<String, Parameter> {
185 let mut params = HashMap::new();
186 for (i, module) in self.modules.iter().enumerate() {
187 for (name, param) in module.named_parameters() {
188 params.insert(format!("{i}.{name}"), param);
189 }
190 }
191 params
192 }
193
194 fn set_training(&mut self, training: bool) {
195 self.training = training;
196 for module in &mut self.modules {
197 module.set_training(training);
198 }
199 }
200
201 fn is_training(&self) -> bool {
202 self.training
203 }
204
205 fn name(&self) -> &'static str {
206 "ModuleList"
207 }
208}
209
#[cfg(test)]
mod tests {
    use super::*;
    use axonml_tensor::Tensor;

    /// Parameter-free module that returns its input unchanged.
    struct Identity;

    impl Module for Identity {
        fn forward(&self, input: &Variable) -> Variable {
            input.clone()
        }

        fn name(&self) -> &'static str {
            "Identity"
        }
    }

    /// Pass-through module that stores its training flag.
    ///
    /// Lets the tests observe whether `ModuleList` propagates train/eval to
    /// its children.
    struct ModeProbe {
        training: bool,
    }

    impl Module for ModeProbe {
        fn forward(&self, input: &Variable) -> Variable {
            input.clone()
        }

        fn set_training(&mut self, training: bool) {
            self.training = training;
        }

        fn is_training(&self) -> bool {
            self.training
        }
    }

    #[test]
    fn test_module_list() {
        let mut list = ModuleList::new();
        list.push(Identity);
        list.push(Identity);
        assert_eq!(list.len(), 2);
        assert!(!list.is_empty());
    }

    #[test]
    fn test_module_list_default_is_empty() {
        let list = ModuleList::default();
        assert!(list.is_empty());
        assert_eq!(list.len(), 0);
        assert!(list.get(0).is_none());
        assert!(list.is_training());
        assert_eq!(list.name(), "ModuleList");
    }

    #[test]
    fn test_module_list_get() {
        let mut list = ModuleList::new();
        list.push(Identity);
        assert_eq!(list.get(0).map(|m| m.name()), Some("Identity"));
        // Out-of-range indices must return None rather than panic.
        assert!(list.get(1).is_none());
    }

    #[test]
    fn test_module_list_forward() {
        let mut list = ModuleList::new();
        list.push(Identity);

        let input = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), false);
        let output = list.forward(&input);
        assert_eq!(output.data().to_vec(), vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_module_list_forward_chains_modules() {
        let mut list = ModuleList::new();
        list.push(Identity);
        list.push(Identity);
        list.push(Identity);

        let input = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), false);
        let output = list.forward(&input);
        assert_eq!(output.data().to_vec(), vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_module_list_training_mode_propagates() {
        let mut list = ModuleList::new();
        list.push(ModeProbe { training: true });
        list.push(ModeProbe { training: true });
        assert!(list.is_training());

        list.eval();
        assert!(!list.is_training());
        assert!(list.iter().all(|m| !m.is_training()));

        list.train();
        assert!(list.is_training());
        assert!(list.iter().all(|m| m.is_training()));
    }

    #[test]
    fn test_module_list_without_parameters() {
        let mut list = ModuleList::new();
        list.push(Identity);
        assert!(list.parameters().is_empty());
        assert!(list.named_parameters().is_empty());
        assert_eq!(list.num_parameters(), 0);
    }
}