1use crate::module::{
2 AutodiffModule, Content, Module, ModuleDisplay, ModuleDisplayDefault, ModuleMapper,
3 ModuleVisitor,
4};
5
6use alloc::{format, vec::Vec};
7
8use burn_tensor::{
9 backend::{AutodiffBackend, Backend},
10 ops::Device,
11};
12use core::fmt::Debug;
13
14impl<T, B> Module<B> for Option<T>
15where
16 T: Module<B> + Debug + Send + Clone,
17 B: Backend,
18{
19 type Record = Option<T::Record>;
20
21 fn visit<V: ModuleVisitor<B>>(&self, visitor: &mut V) {
22 if let Some(module) = self {
23 module.visit(visitor)
24 }
25 }
26
27 fn map<M: ModuleMapper<B>>(self, mapper: &mut M) -> Self {
28 self.map(|module| module.map(mapper))
29 }
30
31 fn load_record(self, record: Self::Record) -> Self {
32 let is_constant = self.num_params() == 0;
33
34 if is_constant {
35 return self;
36 }
37
38 self.zip(record)
39 .map(|(module, record)| module.load_record(record))
40 }
41
42 fn into_record(self) -> Self::Record {
43 self.map(Module::into_record)
44 }
45
46 fn to_device(self, device: &Device<B>) -> Self {
47 self.map(|module| module.to_device(device))
48 }
49
50 fn fork(self, device: &Device<B>) -> Self {
51 self.map(|module| module.fork(device))
52 }
53
54 fn collect_devices(&self, mut devices: Vec<B::Device>) -> Vec<B::Device> {
55 if let Some(module) = self.as_ref() {
56 devices = module.collect_devices(devices);
57 }
58
59 devices
60 }
61}
62
63impl<T: ModuleDisplay> ModuleDisplayDefault for Option<T> {
64 fn content(&self, content: Content) -> Option<Content> {
65 match self {
66 Some(module) => content.add_single(module).optional(),
67 None => content.add_single("None").optional(),
68 }
69 }
70}
71
72impl<T: ModuleDisplay> ModuleDisplay for Option<T> {}
73
74impl<T, B> AutodiffModule<B> for Option<T>
75where
76 T: AutodiffModule<B> + Debug + Send + Clone,
77 B: AutodiffBackend,
78{
79 type InnerModule = Option<T::InnerModule>;
80
81 fn valid(&self) -> Self::InnerModule {
82 self.as_ref().map(|module| module.valid())
83 }
84}
85
86impl<T, B> Module<B> for Vec<T>
87where
88 T: Module<B> + Debug + Send + Clone,
89 B: Backend,
90{
91 type Record = Vec<T::Record>;
92
93 fn num_params(&self) -> usize {
94 let mut num_params = 0;
95 for module in self.iter() {
96 num_params += module.num_params();
97 }
98
99 num_params
100 }
101
102 fn visit<V: ModuleVisitor<B>>(&self, visitor: &mut V) {
103 for (i, module) in self.iter().enumerate() {
104 let index_str = alloc::format!("{}", i);
105 visitor.enter_module(&index_str, "Vec");
106 module.visit(visitor);
107 visitor.exit_module(&index_str, "Vec");
108 }
109 }
110
111 fn map<M: ModuleMapper<B>>(self, mapper: &mut M) -> Self {
112 self.into_iter()
113 .enumerate()
114 .map(|(i, module)| {
115 let index_str = alloc::format!("{}", i);
116 mapper.enter_module(&index_str, "Vec");
117 let mapped = module.map(mapper);
118 mapper.exit_module(&index_str, "Vec");
119 mapped
120 })
121 .collect()
122 }
123
124 fn into_record(self) -> Self::Record {
125 self.into_iter().map(Module::into_record).collect()
126 }
127
128 fn load_record(self, record: Self::Record) -> Self {
129 assert_eq!(
130 self.len(),
131 record.len(),
132 r#"[Load Record Error] The vec record does not the same length as the module.
133 Make sure you module initialization is compatible with the record being loaded.
134 "#,
135 );
136
137 self.into_iter()
138 .zip(record)
139 .map(|(module, record)| module.load_record(record))
140 .collect()
141 }
142
143 fn to_device(self, device: &Device<B>) -> Self {
144 self.into_iter()
145 .map(|module| module.to_device(device))
146 .collect()
147 }
148
149 fn fork(self, device: &Device<B>) -> Self {
150 self.into_iter().map(|module| module.fork(device)).collect()
151 }
152
153 fn collect_devices(&self, mut devices: Vec<B::Device>) -> Vec<B::Device> {
154 for module in self.iter() {
155 devices = module.collect_devices(devices);
156 }
157
158 devices
159 }
160}
161
162impl<T: ModuleDisplay> ModuleDisplayDefault for Vec<T> {
163 fn content(&self, content: Content) -> Option<Content> {
164 self.iter()
165 .enumerate()
166 .fold(content, |acc, (i, module)| {
167 let index = format!("{i}");
168 acc.add(&index, module)
169 })
170 .set_top_level_type(format!("Vec<0..{}>", self.len()).as_str())
171 .optional()
172 }
173}
174
175impl<T: ModuleDisplay> ModuleDisplay for Vec<T> {}
176
177impl<T, B> AutodiffModule<B> for Vec<T>
178where
179 T: AutodiffModule<B> + Debug + Send + Clone,
180 B: AutodiffBackend,
181{
182 type InnerModule = Vec<T::InnerModule>;
183
184 fn valid(&self) -> Self::InnerModule {
185 self.iter().map(|module| module.valid()).collect()
186 }
187}
188
189impl<const N: usize, T, B> Module<B> for [T; N]
190where
191 T: Module<B> + Debug + Send + Clone,
192 B: Backend,
193{
194 type Record = [T::Record; N];
195
196 fn collect_devices(&self, mut devices: Vec<B::Device>) -> Vec<B::Device> {
197 for module in self.iter() {
198 devices = module.collect_devices(devices);
199 }
200
201 devices
202 }
203
204 fn num_params(&self) -> usize {
205 let mut num_params = 0;
206 for module in self.iter() {
207 num_params += module.num_params();
208 }
209
210 num_params
211 }
212
213 fn visit<V: ModuleVisitor<B>>(&self, visitor: &mut V) {
214 for (i, module) in self.iter().enumerate() {
215 let index_str = alloc::format!("{}", i);
216 visitor.enter_module(&index_str, "Vec");
217 module.visit(visitor);
218 visitor.exit_module(&index_str, "Vec");
219 }
220 }
221
222 fn map<M: ModuleMapper<B>>(self, mapper: &mut M) -> Self {
223 let mut result = Vec::with_capacity(N);
224 for (i, module) in IntoIterator::into_iter(self).enumerate() {
225 let index_str = alloc::format!("{}", i);
226 mapper.enter_module(&index_str, "Array");
227 let mapped = module.map(mapper);
228 mapper.exit_module(&index_str, "Array");
229 result.push(mapped);
230 }
231 result
232 .try_into()
233 .unwrap_or_else(|v: Vec<T>| panic!("Expected array of length {}, got {}", N, v.len()))
234 }
235
236 fn load_record(self, record: Self::Record) -> Self {
237 self.into_iter()
238 .zip(record)
239 .map(|(module, record)| module.load_record(record))
240 .collect::<Vec<_>>()
241 .try_into()
242 .unwrap()
243 }
244
245 fn into_record(self) -> Self::Record {
246 self.map(Module::into_record)
247 }
248
249 fn to_device(self, device: &Device<B>) -> Self {
250 self.map(|module| module.to_device(device))
251 }
252
253 fn fork(self, device: &Device<B>) -> Self {
254 self.map(|module| module.fork(device))
255 }
256}
257
258impl<const N: usize, T: ModuleDisplay> ModuleDisplayDefault for [T; N] {
259 fn content(&self, content: Content) -> Option<Content> {
260 self.iter()
261 .enumerate()
262 .fold(content, |acc, (i, module)| {
263 let index = format!("{i}");
264 acc.add(&index, module)
265 })
266 .set_top_level_type(format!("[0..{}]", self.len()).as_str())
267 .optional()
268 }
269}
270
271impl<const N: usize, T: ModuleDisplay> ModuleDisplay for [T; N] {}
272
273impl<const N: usize, T, B> AutodiffModule<B> for [T; N]
274where
275 T: AutodiffModule<B> + Debug + Send + Clone,
276 T::InnerModule: Debug,
277 B: AutodiffBackend,
278{
279 type InnerModule = [T::InnerModule; N];
280
281 fn valid(&self) -> Self::InnerModule {
282 self.clone().map(|module| module.valid())
283 }
284}
285
286macro_rules! impl_module_tuple {
292 ([$($l:ident),*][$($i:tt),*]) => {
295 impl<B, $($l,)*> Module<B> for ($($l,)*)
296 where
297 B: Backend,
298 $($l: Module<B> + Debug + Send + Clone,)*
299 {
300 type Record = ($($l::Record),*);
301
302 fn collect_devices(&self, mut devices: Vec<B::Device>) -> Vec<B::Device> {
303 $(devices = self.$i.collect_devices(devices);)*
304 devices
305 }
306
307 fn fork(self, device: &Device<B>) -> Self {
308 ($(self.$i.fork(device),)*)
309 }
310
311 fn to_device(self, device: &Device<B>) -> Self {
312 ($(self.$i.to_device(device),)*)
313 }
314
315 fn visit<V: ModuleVisitor<B>>(&self, visitor: &mut V) {
316 $(self.$i.visit(visitor);)*
317 }
318
319 fn map<M: ModuleMapper<B>>(self, mapper: &mut M) -> Self {
320 ($(self.$i.map(mapper),)*)
321 }
322
323 fn load_record(self, record: Self::Record) -> Self {
324 ($(self.$i.load_record(record.$i),)*)
325 }
326
327 fn into_record(self) -> Self::Record {
328 ($(self.$i.into_record(),)*)
329 }
330 }
331
332 impl<B, $($l,)*> AutodiffModule<B> for ($($l,)*)
333 where
334 B: AutodiffBackend,
335 $($l: AutodiffModule<B> + Debug + Send + Clone,)*
336 {
337 type InnerModule = ($($l::InnerModule,)*);
338
339 fn valid(&self) -> Self::InnerModule {
340 ($(self.$i.valid(),)*)
341 }
342 }
343
344 impl<$($l,)*> ModuleDisplayDefault for ($($l,)*)
345 where
346 $($l: ModuleDisplay,)*
347 {
348 fn content(&self, content: Content) -> Option<Content> {
349 let content = content
350 $(.add(&format!("{}", $i), &self.$i))*
351 .set_top_level_type(format!("({})", stringify!($($l),*)).as_str());
352 content.optional()
353 }
354 }
355
356 impl<$($l,)*> ModuleDisplay for ($($l,)*) where $($l: ModuleDisplay,)* {}
357
358 };
359}
360
361impl_module_tuple!([L0, L1][0, 1]);
362impl_module_tuple!([L0, L1, L2][0, 1, 2]);
363impl_module_tuple!([L0, L1, L2, L3][0, 1, 2, 3]);
364impl_module_tuple!([L0, L1, L2, L3, L4][0, 1, 2, 3, 4]);
365impl_module_tuple!([L0, L1, L2, L3, L4, L5][0, 1, 2, 3, 4, 5]);
366impl_module_tuple!([L0, L1, L2, L3, L4, L5, L6][0, 1, 2, 3, 4, 5, 6]);
367impl_module_tuple!([L0, L1, L2, L3, L4, L5, L6, L7][0, 1, 2, 3, 4, 5, 6, 7]);
368impl_module_tuple!([L0, L1, L2, L3, L4, L5, L6, L7, L8][0, 1, 2, 3, 4, 5, 6, 7, 8]);
369impl_module_tuple!([L0, L1, L2, L3, L4, L5, L6, L7, L8, L9][0, 1, 2, 3, 4, 5, 6, 7, 8, 9]);
370
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestBackend;

    /// Round-tripping a constant (parameter-free) optional module through its own record
    /// must give back the same module.
    #[test]
    fn dont_override_constant_module_when_loading_record() {
        let original: Option<i32> = Some(42);

        let saved = <Option<i32> as Module<TestBackend>>::into_record(original);
        let restored = <Option<i32> as Module<TestBackend>>::load_record(original, saved);

        assert_eq!(restored, original);
    }

    /// Loading a `None` record must not erase a constant optional module.
    #[test]
    fn dont_override_constant_module_when_loading_none_record() {
        let original: Option<i32> = Some(42);

        let restored = <Option<i32> as Module<TestBackend>>::load_record(original, None);

        assert_eq!(restored, original);
    }
}