1use crate::module::{
2 AutodiffModule, Content, Module, ModuleDisplay, ModuleDisplayDefault, ModuleMapper,
3 ModuleVisitor,
4};
5
6use alloc::{format, vec::Vec};
7
8use burn_tensor::{
9 backend::{AutodiffBackend, Backend},
10 ops::Device,
11};
12use core::fmt::Debug;
13
14impl<T, B> Module<B> for Option<T>
15where
16 T: Module<B> + Debug + Send + Clone,
17 B: Backend,
18{
19 type Record = Option<T::Record>;
20
21 fn visit<V: ModuleVisitor<B>>(&self, visitor: &mut V) {
22 if let Some(module) = self {
23 module.visit(visitor)
24 }
25 }
26
27 fn map<M: ModuleMapper<B>>(self, mapper: &mut M) -> Self {
28 self.map(|module| module.map(mapper))
29 }
30
31 fn load_record(self, record: Self::Record) -> Self {
32 let is_constant = self.num_params() == 0;
33
34 if is_constant {
35 return self;
36 }
37
38 self.zip(record)
39 .map(|(module, record)| module.load_record(record))
40 }
41
42 fn into_record(self) -> Self::Record {
43 self.map(Module::into_record)
44 }
45
46 fn to_device(self, device: &Device<B>) -> Self {
47 self.map(|module| module.to_device(device))
48 }
49
50 fn fork(self, device: &Device<B>) -> Self {
51 self.map(|module| module.fork(device))
52 }
53
54 fn collect_devices(&self, mut devices: Vec<B::Device>) -> Vec<B::Device> {
55 if let Some(module) = self.as_ref() {
56 devices = module.collect_devices(devices);
57 }
58
59 devices
60 }
61}
62
63impl<T: ModuleDisplay> ModuleDisplayDefault for Option<T> {
64 fn content(&self, content: Content) -> Option<Content> {
65 match self {
66 Some(module) => content.add_single(module).optional(),
67 None => content.add_single("None").optional(),
68 }
69 }
70}
71
72impl<T: ModuleDisplay> ModuleDisplay for Option<T> {}
73
74impl<T, B> AutodiffModule<B> for Option<T>
75where
76 T: AutodiffModule<B> + Debug + Send + Clone,
77 B: AutodiffBackend,
78{
79 type InnerModule = Option<T::InnerModule>;
80
81 fn valid(&self) -> Self::InnerModule {
82 self.as_ref().map(|module| module.valid())
83 }
84}
85
86impl<T, B> Module<B> for Vec<T>
87where
88 T: Module<B> + Debug + Send + Clone,
89 B: Backend,
90{
91 type Record = Vec<T::Record>;
92
93 fn num_params(&self) -> usize {
94 let mut num_params = 0;
95 for module in self.iter() {
96 num_params += module.num_params();
97 }
98
99 num_params
100 }
101
102 fn visit<V: ModuleVisitor<B>>(&self, visitor: &mut V) {
103 self.iter().for_each(|module| {
104 module.visit(visitor);
105 });
106 }
107
108 fn map<M: ModuleMapper<B>>(self, mapper: &mut M) -> Self {
109 self.into_iter().map(|module| module.map(mapper)).collect()
110 }
111
112 fn into_record(self) -> Self::Record {
113 self.into_iter().map(Module::into_record).collect()
114 }
115
116 fn load_record(self, record: Self::Record) -> Self {
117 assert_eq!(
118 self.len(),
119 record.len(),
120 r#"[Load Record Error] The vec record does not the same length as the module.
121 Make sure you module initialization is compatible with the record being loaded.
122 "#,
123 );
124
125 self.into_iter()
126 .zip(record)
127 .map(|(module, record)| module.load_record(record))
128 .collect()
129 }
130
131 fn to_device(self, device: &Device<B>) -> Self {
132 self.into_iter()
133 .map(|module| module.to_device(device))
134 .collect()
135 }
136
137 fn fork(self, device: &Device<B>) -> Self {
138 self.into_iter().map(|module| module.fork(device)).collect()
139 }
140
141 fn collect_devices(&self, mut devices: Vec<B::Device>) -> Vec<B::Device> {
142 for module in self.iter() {
143 devices = module.collect_devices(devices);
144 }
145
146 devices
147 }
148}
149
150impl<T: ModuleDisplay> ModuleDisplayDefault for Vec<T> {
151 fn content(&self, content: Content) -> Option<Content> {
152 self.iter()
153 .enumerate()
154 .fold(content, |acc, (i, module)| {
155 let index = format!("{i}");
156 acc.add(&index, module)
157 })
158 .set_top_level_type(format!("Vec<0..{}>", self.len()).as_str())
159 .optional()
160 }
161}
162
163impl<T: ModuleDisplay> ModuleDisplay for Vec<T> {}
164
165impl<T, B> AutodiffModule<B> for Vec<T>
166where
167 T: AutodiffModule<B> + Debug + Send + Clone,
168 B: AutodiffBackend,
169{
170 type InnerModule = Vec<T::InnerModule>;
171
172 fn valid(&self) -> Self::InnerModule {
173 self.iter().map(|module| module.valid()).collect()
174 }
175}
176
177impl<const N: usize, T, B> Module<B> for [T; N]
178where
179 T: Module<B> + Debug + Send + Clone,
180 B: Backend,
181{
182 type Record = [T::Record; N];
183
184 fn collect_devices(&self, mut devices: Vec<B::Device>) -> Vec<B::Device> {
185 for module in self.iter() {
186 devices = module.collect_devices(devices);
187 }
188
189 devices
190 }
191
192 fn num_params(&self) -> usize {
193 let mut num_params = 0;
194 for module in self.iter() {
195 num_params += module.num_params();
196 }
197
198 num_params
199 }
200
201 fn visit<V: ModuleVisitor<B>>(&self, visitor: &mut V) {
202 self.iter().for_each(|module| {
203 module.visit(visitor);
204 });
205 }
206
207 fn map<M: ModuleMapper<B>>(self, mapper: &mut M) -> Self {
208 self.map(|module| module.map(mapper))
209 }
210
211 fn load_record(self, record: Self::Record) -> Self {
212 self.into_iter()
213 .zip(record)
214 .map(|(module, record)| module.load_record(record))
215 .collect::<Vec<_>>()
216 .try_into()
217 .unwrap()
218 }
219
220 fn into_record(self) -> Self::Record {
221 self.map(Module::into_record)
222 }
223
224 fn to_device(self, device: &Device<B>) -> Self {
225 self.map(|module| module.to_device(device))
226 }
227
228 fn fork(self, device: &Device<B>) -> Self {
229 self.map(|module| module.fork(device))
230 }
231}
232
233impl<const N: usize, T: ModuleDisplay> ModuleDisplayDefault for [T; N] {
234 fn content(&self, content: Content) -> Option<Content> {
235 self.iter()
236 .enumerate()
237 .fold(content, |acc, (i, module)| {
238 let index = format!("{i}");
239 acc.add(&index, module)
240 })
241 .set_top_level_type(format!("[0..{}]", self.len()).as_str())
242 .optional()
243 }
244}
245
246impl<const N: usize, T: ModuleDisplay> ModuleDisplay for [T; N] {}
247
248impl<const N: usize, T, B> AutodiffModule<B> for [T; N]
249where
250 T: AutodiffModule<B> + Debug + Send + Clone,
251 T::InnerModule: Debug,
252 B: AutodiffBackend,
253{
254 type InnerModule = [T::InnerModule; N];
255
256 fn valid(&self) -> Self::InnerModule {
257 self.clone().map(|module| module.valid())
258 }
259}
260
// Implements `Module`, `AutodiffModule`, `ModuleDisplayDefault` and `ModuleDisplay` for a
// tuple. `$l` lists the element type parameters and `$i` the matching positional field
// indices; both lists must have the same length and order.
macro_rules! impl_module_tuple {
    ([$($l:ident),*][$($i:tt),*]) => {
        impl<B, $($l,)*> Module<B> for ($($l,)*)
        where
            B: Backend,
            $($l: Module<B> + Debug + Send + Clone,)*
        {
            // The trailing comma keeps this a tuple type for every arity (including one),
            // consistent with `Self` and with `AutodiffModule::InnerModule` below.
            type Record = ($($l::Record,)*);

            fn collect_devices(&self, mut devices: Vec<B::Device>) -> Vec<B::Device> {
                $(devices = self.$i.collect_devices(devices);)*
                devices
            }

            fn fork(self, device: &Device<B>) -> Self {
                ($(self.$i.fork(device),)*)
            }

            fn to_device(self, device: &Device<B>) -> Self {
                ($(self.$i.to_device(device),)*)
            }

            fn visit<V: ModuleVisitor<B>>(&self, visitor: &mut V) {
                $(self.$i.visit(visitor);)*
            }

            fn map<M: ModuleMapper<B>>(self, mapper: &mut M) -> Self {
                ($(self.$i.map(mapper),)*)
            }

            // Each element loads the record field at the same position.
            fn load_record(self, record: Self::Record) -> Self {
                ($(self.$i.load_record(record.$i),)*)
            }

            fn into_record(self) -> Self::Record {
                ($(self.$i.into_record(),)*)
            }
        }

        impl<B, $($l,)*> AutodiffModule<B> for ($($l,)*)
        where
            B: AutodiffBackend,
            $($l: AutodiffModule<B> + Debug + Send + Clone,)*
        {
            type InnerModule = ($($l::InnerModule,)*);

            fn valid(&self) -> Self::InnerModule {
                ($(self.$i.valid(),)*)
            }
        }

        impl<$($l,)*> ModuleDisplayDefault for ($($l,)*)
        where
            $($l: ModuleDisplay,)*
        {
            fn content(&self, content: Content) -> Option<Content> {
                // Elements are labelled by position. The top-level type shows the generic
                // parameter names from the macro invocation (e.g. `(L0, L1)`).
                let content = content
                    $(.add(&format!("{}", $i), &self.$i))*
                    .set_top_level_type(format!("({})", stringify!($($l),*)).as_str());
                content.optional()
            }
        }

        impl<$($l,)*> ModuleDisplay for ($($l,)*) where $($l: ModuleDisplay,)* {}
    };
}
335
336impl_module_tuple!([L0, L1][0, 1]);
337impl_module_tuple!([L0, L1, L2][0, 1, 2]);
338impl_module_tuple!([L0, L1, L2, L3][0, 1, 2, 3]);
339impl_module_tuple!([L0, L1, L2, L3, L4][0, 1, 2, 3, 4]);
340impl_module_tuple!([L0, L1, L2, L3, L4, L5][0, 1, 2, 3, 4, 5]);
341impl_module_tuple!([L0, L1, L2, L3, L4, L5, L6][0, 1, 2, 3, 4, 5, 6]);
342impl_module_tuple!([L0, L1, L2, L3, L4, L5, L6, L7][0, 1, 2, 3, 4, 5, 6, 7]);
343impl_module_tuple!([L0, L1, L2, L3, L4, L5, L6, L7, L8][0, 1, 2, 3, 4, 5, 6, 7, 8]);
344impl_module_tuple!([L0, L1, L2, L3, L4, L5, L6, L7, L8, L9][0, 1, 2, 3, 4, 5, 6, 7, 8, 9]);
345
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestBackend;

    // `Some(42)` is presumably a parameter-free (constant) module. Loading any record into
    // it must hand back the module unchanged.

    #[test]
    fn dont_override_constant_module_when_loading_record() {
        let constant = Some(42);
        let saved = Module::<TestBackend>::into_record(constant);

        assert_eq!(Module::<TestBackend>::load_record(constant, saved), constant);
    }

    #[test]
    fn dont_override_constant_module_when_loading_none_record() {
        let constant = Some(42);

        assert_eq!(Module::<TestBackend>::load_record(constant, None), constant);
    }
}