1use crate::module::{
2 AutodiffModule, Content, Module, ModuleDisplay, ModuleDisplayDefault, ModuleMapper,
3 ModuleVisitor,
4};
5
6use alloc::{format, string::ToString, vec::Vec};
7
8use burn_tensor::{
9 backend::{AutodiffBackend, Backend},
10 ops::Device,
11};
12use core::fmt::Debug;
13
14impl<T, B> Module<B> for Option<T>
15where
16 T: Module<B> + Debug + Send + Clone,
17 B: Backend,
18{
19 type Record = Option<T::Record>;
20
21 fn visit<V: ModuleVisitor<B>>(&self, visitor: &mut V) {
22 if let Some(module) = self {
23 module.visit(visitor)
24 }
25 }
26
27 fn map<M: ModuleMapper<B>>(self, mapper: &mut M) -> Self {
28 self.map(|module| module.map(mapper))
29 }
30
31 fn load_record(self, record: Self::Record) -> Self {
32 let is_constant = self.num_params() == 0;
33
34 if is_constant {
35 return self;
36 }
37
38 self.zip(record)
39 .map(|(module, record)| module.load_record(record))
40 }
41
42 fn into_record(self) -> Self::Record {
43 self.map(Module::into_record)
44 }
45
46 fn to_device(self, device: &Device<B>) -> Self {
47 self.map(|module| module.to_device(device))
48 }
49
50 fn fork(self, device: &Device<B>) -> Self {
51 self.map(|module| module.fork(device))
52 }
53
54 fn collect_devices(&self, mut devices: Vec<B::Device>) -> Vec<B::Device> {
55 if let Some(module) = self.as_ref() {
56 devices = module.collect_devices(devices);
57 }
58
59 devices
60 }
61}
62
63impl<T: ModuleDisplay> ModuleDisplayDefault for Option<T> {
64 fn content(&self, content: Content) -> Option<Content> {
65 match self {
66 Some(module) => content.add_single(module).optional(),
67 None => content.add_single("None").optional(),
68 }
69 }
70}
71
72impl<T: ModuleDisplay> ModuleDisplay for Option<T> {}
73
74impl<T, B> AutodiffModule<B> for Option<T>
75where
76 T: AutodiffModule<B> + Debug + Send + Clone,
77 B: AutodiffBackend,
78{
79 type InnerModule = Option<T::InnerModule>;
80
81 fn valid(&self) -> Self::InnerModule {
82 self.as_ref().map(|module| module.valid())
83 }
84
85 fn from_inner(module: Self::InnerModule) -> Self {
86 module.map(|module| T::from_inner(module))
87 }
88}
89
90impl<T, B> Module<B> for Vec<T>
91where
92 T: Module<B> + Debug + Send + Clone,
93 B: Backend,
94{
95 type Record = Vec<T::Record>;
96
97 fn num_params(&self) -> usize {
98 let mut num_params = 0;
99 for module in self.iter() {
100 num_params += module.num_params();
101 }
102
103 num_params
104 }
105
106 fn visit<V: ModuleVisitor<B>>(&self, visitor: &mut V) {
107 for (i, module) in self.iter().enumerate() {
108 let index_str = alloc::format!("{}", i);
109 visitor.enter_module(&index_str, "Vec");
110 module.visit(visitor);
111 visitor.exit_module(&index_str, "Vec");
112 }
113 }
114
115 fn map<M: ModuleMapper<B>>(self, mapper: &mut M) -> Self {
116 self.into_iter()
117 .enumerate()
118 .map(|(i, module)| {
119 let index_str = alloc::format!("{}", i);
120 mapper.enter_module(&index_str, "Vec");
121 let mapped = module.map(mapper);
122 mapper.exit_module(&index_str, "Vec");
123 mapped
124 })
125 .collect()
126 }
127
128 fn into_record(self) -> Self::Record {
129 self.into_iter().map(Module::into_record).collect()
130 }
131
132 fn load_record(self, record: Self::Record) -> Self {
133 assert_eq!(
134 self.len(),
135 record.len(),
136 r#"[Load Record Error] The vec record does not the same length as the module.
137 Make sure you module initialization is compatible with the record being loaded.
138 "#,
139 );
140
141 self.into_iter()
142 .zip(record)
143 .map(|(module, record)| module.load_record(record))
144 .collect()
145 }
146
147 fn to_device(self, device: &Device<B>) -> Self {
148 self.into_iter()
149 .map(|module| module.to_device(device))
150 .collect()
151 }
152
153 fn fork(self, device: &Device<B>) -> Self {
154 self.into_iter().map(|module| module.fork(device)).collect()
155 }
156
157 fn collect_devices(&self, mut devices: Vec<B::Device>) -> Vec<B::Device> {
158 for module in self.iter() {
159 devices = module.collect_devices(devices);
160 }
161
162 devices
163 }
164}
165
166impl<T: ModuleDisplay> ModuleDisplayDefault for Vec<T> {
167 fn content(&self, content: Content) -> Option<Content> {
168 self.iter()
169 .enumerate()
170 .fold(content, |acc, (i, module)| {
171 let index = format!("{i}");
172 acc.add(&index, module)
173 })
174 .set_top_level_type(format!("Vec<0..{}>", self.len()).as_str())
175 .optional()
176 }
177}
178
179impl<T: ModuleDisplay> ModuleDisplay for Vec<T> {}
180
181impl<T, B> AutodiffModule<B> for Vec<T>
182where
183 T: AutodiffModule<B> + Debug + Send + Clone,
184 B: AutodiffBackend,
185{
186 type InnerModule = Vec<T::InnerModule>;
187
188 fn valid(&self) -> Self::InnerModule {
189 self.iter().map(|module| module.valid()).collect()
190 }
191
192 fn from_inner(module: Self::InnerModule) -> Self {
193 module
194 .into_iter()
195 .map(|module| T::from_inner(module))
196 .collect()
197 }
198}
199
200impl<const N: usize, T, B> Module<B> for [T; N]
201where
202 T: Module<B> + Debug + Send + Clone,
203 B: Backend,
204{
205 type Record = [T::Record; N];
206
207 fn collect_devices(&self, mut devices: Vec<B::Device>) -> Vec<B::Device> {
208 for module in self.iter() {
209 devices = module.collect_devices(devices);
210 }
211
212 devices
213 }
214
215 fn num_params(&self) -> usize {
216 let mut num_params = 0;
217 for module in self.iter() {
218 num_params += module.num_params();
219 }
220
221 num_params
222 }
223
224 fn visit<V: ModuleVisitor<B>>(&self, visitor: &mut V) {
225 for (i, module) in self.iter().enumerate() {
226 let index_str = alloc::format!("{}", i);
227 visitor.enter_module(&index_str, "Array");
228 module.visit(visitor);
229 visitor.exit_module(&index_str, "Array");
230 }
231 }
232
233 fn map<M: ModuleMapper<B>>(self, mapper: &mut M) -> Self {
234 let mut result = Vec::with_capacity(N);
235 for (i, module) in IntoIterator::into_iter(self).enumerate() {
236 let index_str = alloc::format!("{}", i);
237 mapper.enter_module(&index_str, "Array");
238 let mapped = module.map(mapper);
239 mapper.exit_module(&index_str, "Array");
240 result.push(mapped);
241 }
242 result
243 .try_into()
244 .unwrap_or_else(|v: Vec<T>| panic!("Expected array of length {}, got {}", N, v.len()))
245 }
246
247 fn load_record(self, record: Self::Record) -> Self {
248 self.into_iter()
249 .zip(record)
250 .map(|(module, record)| module.load_record(record))
251 .collect::<Vec<_>>()
252 .try_into()
253 .unwrap()
254 }
255
256 fn into_record(self) -> Self::Record {
257 self.map(Module::into_record)
258 }
259
260 fn to_device(self, device: &Device<B>) -> Self {
261 self.map(|module| module.to_device(device))
262 }
263
264 fn fork(self, device: &Device<B>) -> Self {
265 self.map(|module| module.fork(device))
266 }
267}
268
269impl<const N: usize, T: ModuleDisplay> ModuleDisplayDefault for [T; N] {
270 fn content(&self, content: Content) -> Option<Content> {
271 self.iter()
272 .enumerate()
273 .fold(content, |acc, (i, module)| {
274 let index = format!("{i}");
275 acc.add(&index, module)
276 })
277 .set_top_level_type(format!("[0..{}]", self.len()).as_str())
278 .optional()
279 }
280}
281
282impl<const N: usize, T: ModuleDisplay> ModuleDisplay for [T; N] {}
283
284impl<const N: usize, T, B> AutodiffModule<B> for [T; N]
285where
286 T: AutodiffModule<B> + Debug + Send + Clone,
287 T::InnerModule: Debug,
288 B: AutodiffBackend,
289{
290 type InnerModule = [T::InnerModule; N];
291
292 fn valid(&self) -> Self::InnerModule {
293 self.clone().map(|module| module.valid())
294 }
295
296 fn from_inner(module: Self::InnerModule) -> Self {
297 module.map(|module| T::from_inner(module))
298 }
299}
300
301macro_rules! impl_module_tuple {
307 ([$($l:ident),*][$($i:tt),*]) => {
310 impl<B, $($l,)*> Module<B> for ($($l,)*)
311 where
312 B: Backend,
313 $($l: Module<B> + Debug + Send + Clone,)*
314 {
315 type Record = ($($l::Record),*);
316
317 fn collect_devices(&self, mut devices: Vec<B::Device>) -> Vec<B::Device> {
318 $(devices = self.$i.collect_devices(devices);)*
319 devices
320 }
321
322 fn fork(self, device: &Device<B>) -> Self {
323 ($(self.$i.fork(device),)*)
324 }
325
326 fn to_device(self, device: &Device<B>) -> Self {
327 ($(self.$i.to_device(device),)*)
328 }
329
330 fn visit<V: ModuleVisitor<B>>(&self, visitor: &mut V) {
331 $(
332 let index_str = $i.to_string();
333 visitor.enter_module(&index_str, "Tuple");
334 self.$i.visit(visitor);
335 visitor.exit_module(&index_str, "Tuple");
336 )*
337 }
338
339 fn map<M: ModuleMapper<B>>(self, mapper: &mut M) -> Self {
340 ($(
341 {
342 let index_str = $i.to_string();
343 mapper.enter_module(&index_str, "Tuple");
344 let mapped = self.$i.map(mapper);
345 mapper.exit_module(&index_str, "Tuple");
346 mapped
347 }
348 ,)*)
349 }
350
351 fn load_record(self, record: Self::Record) -> Self {
352 ($(self.$i.load_record(record.$i),)*)
353 }
354
355 fn into_record(self) -> Self::Record {
356 ($(self.$i.into_record(),)*)
357 }
358 }
359
360 impl<B, $($l,)*> AutodiffModule<B> for ($($l,)*)
361 where
362 B: AutodiffBackend,
363 $($l: AutodiffModule<B> + Debug + Send + Clone,)*
364 {
365 type InnerModule = ($($l::InnerModule,)*);
366
367 fn valid(&self) -> Self::InnerModule {
368 ($(self.$i.valid(),)*)
369 }
370
371 fn from_inner(module: Self::InnerModule) -> Self {
372 ($($l::from_inner(module.$i),)*)
373 }
374 }
375
376 impl<$($l,)*> ModuleDisplayDefault for ($($l,)*)
377 where
378 $($l: ModuleDisplay,)*
379 {
380 fn content(&self, content: Content) -> Option<Content> {
381 let content = content
382 $(.add(&format!("{}", $i), &self.$i))*
383 .set_top_level_type(format!("({})", stringify!($($l),*)).as_str());
384 content.optional()
385 }
386 }
387
388 impl<$($l,)*> ModuleDisplay for ($($l,)*) where $($l: ModuleDisplay,)* {}
389
390 };
391}
392
393impl_module_tuple!([L0, L1][0, 1]);
394impl_module_tuple!([L0, L1, L2][0, 1, 2]);
395impl_module_tuple!([L0, L1, L2, L3][0, 1, 2, 3]);
396impl_module_tuple!([L0, L1, L2, L3, L4][0, 1, 2, 3, 4]);
397impl_module_tuple!([L0, L1, L2, L3, L4, L5][0, 1, 2, 3, 4, 5]);
398impl_module_tuple!([L0, L1, L2, L3, L4, L5, L6][0, 1, 2, 3, 4, 5, 6]);
399impl_module_tuple!([L0, L1, L2, L3, L4, L5, L6, L7][0, 1, 2, 3, 4, 5, 6, 7]);
400impl_module_tuple!([L0, L1, L2, L3, L4, L5, L6, L7, L8][0, 1, 2, 3, 4, 5, 6, 7, 8]);
401impl_module_tuple!([L0, L1, L2, L3, L4, L5, L6, L7, L8, L9][0, 1, 2, 3, 4, 5, 6, 7, 8, 9]);
402
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestBackend;

    // Loading the module's own record into a `Some(42)` module must return
    // the module unchanged.
    #[test]
    fn dont_override_constant_module_when_loading_record() {
        let constant = Some(42);

        let own_record = Module::<TestBackend>::into_record(constant);
        let reloaded = Module::<TestBackend>::load_record(constant, own_record);

        assert_eq!(reloaded, constant);
    }

    // Loading a `None` record into a `Some(42)` module must also return the
    // module unchanged.
    #[test]
    fn dont_override_constant_module_when_loading_none_record() {
        let constant = Some(42);

        let empty_record = None;
        let reloaded = Module::<TestBackend>::load_record(constant, empty_record);

        assert_eq!(reloaded, constant);
    }
}