1use std::collections::HashMap;
10
11use axonml_autograd::Variable;
12
13use crate::parameter::Parameter;
14
15pub trait Module: Send + Sync {
27 fn forward(&self, input: &Variable) -> Variable;
35
36 fn parameters(&self) -> Vec<Parameter> {
40 Vec::new()
41 }
42
43 fn named_parameters(&self) -> HashMap<String, Parameter> {
45 HashMap::new()
46 }
47
48 fn num_parameters(&self) -> usize {
50 self.parameters()
51 .iter()
52 .filter(|p| p.requires_grad())
53 .map(|p| p.numel())
54 .sum()
55 }
56
57 fn train(&mut self) {
59 self.set_training(true);
60 }
61
62 fn eval(&mut self) {
64 self.set_training(false);
65 }
66
67 fn set_training(&mut self, _training: bool) {
69 }
72
73 fn is_training(&self) -> bool {
75 true }
77
78 fn zero_grad(&self) {
80 for param in self.parameters() {
81 param.zero_grad();
82 }
83 }
84
85 fn name(&self) -> &'static str {
87 std::any::type_name::<Self>()
88 }
89}
90
91pub struct ModuleList {
97 modules: Vec<Box<dyn Module>>,
98 training: bool,
99}
100
101impl ModuleList {
102 pub fn new() -> Self {
104 Self {
105 modules: Vec::new(),
106 training: true,
107 }
108 }
109
110 pub fn from_vec(modules: Vec<Box<dyn Module>>) -> Self {
112 Self {
113 modules,
114 training: true,
115 }
116 }
117
118 pub fn push<M: Module + 'static>(&mut self, module: M) {
120 self.modules.push(Box::new(module));
121 }
122
123 pub fn len(&self) -> usize {
125 self.modules.len()
126 }
127
128 pub fn is_empty(&self) -> bool {
130 self.modules.is_empty()
131 }
132
133 pub fn iter(&self) -> impl Iterator<Item = &Box<dyn Module>> {
135 self.modules.iter()
136 }
137
138 pub fn iter_mut(&mut self) -> impl Iterator<Item = &mut Box<dyn Module>> {
140 self.modules.iter_mut()
141 }
142
143 pub fn get(&self, index: usize) -> Option<&dyn Module> {
145 self.modules.get(index).map(|m| m.as_ref())
146 }
147}
148
149impl Default for ModuleList {
150 fn default() -> Self {
151 Self::new()
152 }
153}
154
155impl Module for ModuleList {
156 fn forward(&self, input: &Variable) -> Variable {
157 let mut x = input.clone();
158 for module in &self.modules {
159 x = module.forward(&x);
160 }
161 x
162 }
163
164 fn parameters(&self) -> Vec<Parameter> {
165 self.modules.iter().flat_map(|m| m.parameters()).collect()
166 }
167
168 fn named_parameters(&self) -> HashMap<String, Parameter> {
169 let mut params = HashMap::new();
170 for (i, module) in self.modules.iter().enumerate() {
171 for (name, param) in module.named_parameters() {
172 params.insert(format!("{i}.{name}"), param);
173 }
174 }
175 params
176 }
177
178 fn set_training(&mut self, training: bool) {
179 self.training = training;
180 for module in &mut self.modules {
181 module.set_training(training);
182 }
183 }
184
185 fn is_training(&self) -> bool {
186 self.training
187 }
188
189 fn name(&self) -> &'static str {
190 "ModuleList"
191 }
192}
193
#[cfg(test)]
mod tests {
    use super::*;
    use axonml_tensor::Tensor;

    /// Stateless module with no parameters that returns its input unchanged.
    struct Identity;

    impl Module for Identity {
        fn forward(&self, input: &Variable) -> Variable {
            input.clone()
        }

        fn name(&self) -> &'static str {
            "Identity"
        }
    }

    #[test]
    fn test_module_list() {
        let mut list = ModuleList::new();
        list.push(Identity);
        list.push(Identity);
        assert_eq!(list.len(), 2);
    }

    #[test]
    fn test_module_list_forward() {
        let mut list = ModuleList::new();
        list.push(Identity);

        let input = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), false);
        let output = list.forward(&input);
        assert_eq!(output.data().to_vec(), vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_module_list_empty() {
        let list = ModuleList::default();
        assert!(list.is_empty());
        assert_eq!(list.len(), 0);
        assert!(list.get(0).is_none());

        // An empty list acts as the identity in `forward`.
        let input = Variable::new(Tensor::from_vec(vec![4.0, 5.0], &[2]).unwrap(), false);
        assert_eq!(list.forward(&input).data().to_vec(), vec![4.0, 5.0]);
    }

    #[test]
    fn test_module_list_get_and_name() {
        let mut list = ModuleList::new();
        list.push(Identity);
        assert_eq!(list.get(0).map(|m| m.name()), Some("Identity"));
        assert!(list.get(1).is_none());
        assert_eq!(list.name(), "ModuleList");
    }

    #[test]
    fn test_module_list_training_mode() {
        let mut list = ModuleList::new();
        list.push(Identity);
        assert!(list.is_training());
        list.eval();
        assert!(!list.is_training());
        list.train();
        assert!(list.is_training());
    }

    #[test]
    fn test_module_list_without_parameters() {
        let mut list = ModuleList::new();
        list.push(Identity);
        assert!(list.parameters().is_empty());
        assert!(list.named_parameters().is_empty());
        assert_eq!(list.num_parameters(), 0);
    }
}