Skip to main content

burn_core/module/param/
primitive.rs

1use crate::module::{
2    AutodiffModule, Content, Module, ModuleDisplay, ModuleDisplayDefault, ModuleMapper,
3    ModuleVisitor,
4};
5
6use alloc::{format, string::ToString, vec::Vec};
7
8use burn_tensor::{
9    backend::{AutodiffBackend, Backend},
10    ops::Device,
11};
12use core::fmt::Debug;
13
14impl<T, B> Module<B> for Option<T>
15where
16    T: Module<B> + Debug + Send + Clone,
17    B: Backend,
18{
19    type Record = Option<T::Record>;
20
21    fn visit<V: ModuleVisitor<B>>(&self, visitor: &mut V) {
22        if let Some(module) = self {
23            module.visit(visitor)
24        }
25    }
26
27    fn map<M: ModuleMapper<B>>(self, mapper: &mut M) -> Self {
28        self.map(|module| module.map(mapper))
29    }
30
31    fn load_record(self, record: Self::Record) -> Self {
32        let is_constant = self.num_params() == 0;
33
34        if is_constant {
35            return self;
36        }
37
38        self.zip(record)
39            .map(|(module, record)| module.load_record(record))
40    }
41
42    fn into_record(self) -> Self::Record {
43        self.map(Module::into_record)
44    }
45
46    fn to_device(self, device: &Device<B>) -> Self {
47        self.map(|module| module.to_device(device))
48    }
49
50    fn fork(self, device: &Device<B>) -> Self {
51        self.map(|module| module.fork(device))
52    }
53
54    fn collect_devices(&self, mut devices: Vec<B::Device>) -> Vec<B::Device> {
55        if let Some(module) = self.as_ref() {
56            devices = module.collect_devices(devices);
57        }
58
59        devices
60    }
61}
62
63impl<T: ModuleDisplay> ModuleDisplayDefault for Option<T> {
64    fn content(&self, content: Content) -> Option<Content> {
65        match self {
66            Some(module) => content.add_single(module).optional(),
67            None => content.add_single("None").optional(),
68        }
69    }
70}
71
72impl<T: ModuleDisplay> ModuleDisplay for Option<T> {}
73
74impl<T, B> AutodiffModule<B> for Option<T>
75where
76    T: AutodiffModule<B> + Debug + Send + Clone,
77    B: AutodiffBackend,
78{
79    type InnerModule = Option<T::InnerModule>;
80
81    fn valid(&self) -> Self::InnerModule {
82        self.as_ref().map(|module| module.valid())
83    }
84
85    fn from_inner(module: Self::InnerModule) -> Self {
86        module.map(|module| T::from_inner(module))
87    }
88}
89
90impl<T, B> Module<B> for Vec<T>
91where
92    T: Module<B> + Debug + Send + Clone,
93    B: Backend,
94{
95    type Record = Vec<T::Record>;
96
97    fn num_params(&self) -> usize {
98        let mut num_params = 0;
99        for module in self.iter() {
100            num_params += module.num_params();
101        }
102
103        num_params
104    }
105
106    fn visit<V: ModuleVisitor<B>>(&self, visitor: &mut V) {
107        for (i, module) in self.iter().enumerate() {
108            let index_str = alloc::format!("{}", i);
109            visitor.enter_module(&index_str, "Vec");
110            module.visit(visitor);
111            visitor.exit_module(&index_str, "Vec");
112        }
113    }
114
115    fn map<M: ModuleMapper<B>>(self, mapper: &mut M) -> Self {
116        self.into_iter()
117            .enumerate()
118            .map(|(i, module)| {
119                let index_str = alloc::format!("{}", i);
120                mapper.enter_module(&index_str, "Vec");
121                let mapped = module.map(mapper);
122                mapper.exit_module(&index_str, "Vec");
123                mapped
124            })
125            .collect()
126    }
127
128    fn into_record(self) -> Self::Record {
129        self.into_iter().map(Module::into_record).collect()
130    }
131
132    fn load_record(self, record: Self::Record) -> Self {
133        assert_eq!(
134            self.len(),
135            record.len(),
136            r#"[Load Record Error] The vec record does not the same length as the module.
137            Make sure you module initialization is compatible with the record being loaded.
138            "#,
139        );
140
141        self.into_iter()
142            .zip(record)
143            .map(|(module, record)| module.load_record(record))
144            .collect()
145    }
146
147    fn to_device(self, device: &Device<B>) -> Self {
148        self.into_iter()
149            .map(|module| module.to_device(device))
150            .collect()
151    }
152
153    fn fork(self, device: &Device<B>) -> Self {
154        self.into_iter().map(|module| module.fork(device)).collect()
155    }
156
157    fn collect_devices(&self, mut devices: Vec<B::Device>) -> Vec<B::Device> {
158        for module in self.iter() {
159            devices = module.collect_devices(devices);
160        }
161
162        devices
163    }
164}
165
166impl<T: ModuleDisplay> ModuleDisplayDefault for Vec<T> {
167    fn content(&self, content: Content) -> Option<Content> {
168        self.iter()
169            .enumerate()
170            .fold(content, |acc, (i, module)| {
171                let index = format!("{i}");
172                acc.add(&index, module)
173            })
174            .set_top_level_type(format!("Vec<0..{}>", self.len()).as_str())
175            .optional()
176    }
177}
178
179impl<T: ModuleDisplay> ModuleDisplay for Vec<T> {}
180
181impl<T, B> AutodiffModule<B> for Vec<T>
182where
183    T: AutodiffModule<B> + Debug + Send + Clone,
184    B: AutodiffBackend,
185{
186    type InnerModule = Vec<T::InnerModule>;
187
188    fn valid(&self) -> Self::InnerModule {
189        self.iter().map(|module| module.valid()).collect()
190    }
191
192    fn from_inner(module: Self::InnerModule) -> Self {
193        module
194            .into_iter()
195            .map(|module| T::from_inner(module))
196            .collect()
197    }
198}
199
200impl<const N: usize, T, B> Module<B> for [T; N]
201where
202    T: Module<B> + Debug + Send + Clone,
203    B: Backend,
204{
205    type Record = [T::Record; N];
206
207    fn collect_devices(&self, mut devices: Vec<B::Device>) -> Vec<B::Device> {
208        for module in self.iter() {
209            devices = module.collect_devices(devices);
210        }
211
212        devices
213    }
214
215    fn num_params(&self) -> usize {
216        let mut num_params = 0;
217        for module in self.iter() {
218            num_params += module.num_params();
219        }
220
221        num_params
222    }
223
224    fn visit<V: ModuleVisitor<B>>(&self, visitor: &mut V) {
225        for (i, module) in self.iter().enumerate() {
226            let index_str = alloc::format!("{}", i);
227            visitor.enter_module(&index_str, "Array");
228            module.visit(visitor);
229            visitor.exit_module(&index_str, "Array");
230        }
231    }
232
233    fn map<M: ModuleMapper<B>>(self, mapper: &mut M) -> Self {
234        let mut result = Vec::with_capacity(N);
235        for (i, module) in IntoIterator::into_iter(self).enumerate() {
236            let index_str = alloc::format!("{}", i);
237            mapper.enter_module(&index_str, "Array");
238            let mapped = module.map(mapper);
239            mapper.exit_module(&index_str, "Array");
240            result.push(mapped);
241        }
242        result
243            .try_into()
244            .unwrap_or_else(|v: Vec<T>| panic!("Expected array of length {}, got {}", N, v.len()))
245    }
246
247    fn load_record(self, record: Self::Record) -> Self {
248        self.into_iter()
249            .zip(record)
250            .map(|(module, record)| module.load_record(record))
251            .collect::<Vec<_>>()
252            .try_into()
253            .unwrap()
254    }
255
256    fn into_record(self) -> Self::Record {
257        self.map(Module::into_record)
258    }
259
260    fn to_device(self, device: &Device<B>) -> Self {
261        self.map(|module| module.to_device(device))
262    }
263
264    fn fork(self, device: &Device<B>) -> Self {
265        self.map(|module| module.fork(device))
266    }
267}
268
269impl<const N: usize, T: ModuleDisplay> ModuleDisplayDefault for [T; N] {
270    fn content(&self, content: Content) -> Option<Content> {
271        self.iter()
272            .enumerate()
273            .fold(content, |acc, (i, module)| {
274                let index = format!("{i}");
275                acc.add(&index, module)
276            })
277            .set_top_level_type(format!("[0..{}]", self.len()).as_str())
278            .optional()
279    }
280}
281
282impl<const N: usize, T: ModuleDisplay> ModuleDisplay for [T; N] {}
283
284impl<const N: usize, T, B> AutodiffModule<B> for [T; N]
285where
286    T: AutodiffModule<B> + Debug + Send + Clone,
287    T::InnerModule: Debug,
288    B: AutodiffBackend,
289{
290    type InnerModule = [T::InnerModule; N];
291
292    fn valid(&self) -> Self::InnerModule {
293        self.clone().map(|module| module.valid())
294    }
295
296    fn from_inner(module: Self::InnerModule) -> Self {
297        module.map(|module| T::from_inner(module))
298    }
299}
300
301/// A macro for generating implementations for tuple modules of different sizes.
302/// For example: `impl_module_tuple!([L0, L1][0, 1])`.
303/// Would generate an implementation for a tuple of size 2.
304/// For this macro to work properly, please adhere to the convention:
305/// `impl_module_tuple!([L0, L1, ..., Ln][0, 1, ..., n])`.
306macro_rules! impl_module_tuple {
307    // `$l` represents the generic modules.
308    // `$i` represents the indices of the modules in the tuple.
309    ([$($l:ident),*][$($i:tt),*]) => {
310        impl<B, $($l,)*> Module<B> for ($($l,)*)
311        where
312            B: Backend,
313            $($l: Module<B> + Debug + Send + Clone,)*
314        {
315            type Record = ($($l::Record),*);
316
317            fn collect_devices(&self, mut devices: Vec<B::Device>) -> Vec<B::Device> {
318                $(devices = self.$i.collect_devices(devices);)*
319                devices
320            }
321
322            fn fork(self, device: &Device<B>) -> Self {
323                ($(self.$i.fork(device),)*)
324            }
325
326            fn to_device(self, device: &Device<B>) -> Self {
327                ($(self.$i.to_device(device),)*)
328            }
329
330            fn visit<V: ModuleVisitor<B>>(&self, visitor: &mut V) {
331                $(
332                    let index_str = $i.to_string();
333                    visitor.enter_module(&index_str, "Tuple");
334                    self.$i.visit(visitor);
335                    visitor.exit_module(&index_str, "Tuple");
336                )*
337            }
338
339            fn map<M: ModuleMapper<B>>(self, mapper: &mut M) -> Self {
340                ($(
341                    {
342                        let index_str = $i.to_string();
343                        mapper.enter_module(&index_str, "Tuple");
344                        let mapped = self.$i.map(mapper);
345                        mapper.exit_module(&index_str, "Tuple");
346                        mapped
347                    }
348                ,)*)
349            }
350
351            fn load_record(self, record: Self::Record) -> Self {
352                ($(self.$i.load_record(record.$i),)*)
353            }
354
355            fn into_record(self) -> Self::Record {
356                ($(self.$i.into_record(),)*)
357            }
358        }
359
360        impl<B, $($l,)*> AutodiffModule<B> for ($($l,)*)
361        where
362            B: AutodiffBackend,
363            $($l: AutodiffModule<B> + Debug + Send + Clone,)*
364        {
365            type InnerModule = ($($l::InnerModule,)*);
366
367            fn valid(&self) -> Self::InnerModule {
368                ($(self.$i.valid(),)*)
369            }
370
371            fn from_inner(module: Self::InnerModule) -> Self {
372                ($($l::from_inner(module.$i),)*)
373            }
374        }
375
376        impl<$($l,)*> ModuleDisplayDefault for ($($l,)*)
377        where
378            $($l: ModuleDisplay,)*
379        {
380            fn content(&self, content: Content) -> Option<Content> {
381                let content = content
382                    $(.add(&format!("{}", $i), &self.$i))*
383                    .set_top_level_type(format!("({})", stringify!($($l),*)).as_str());
384                content.optional()
385            }
386        }
387
388        impl<$($l,)*> ModuleDisplay for ($($l,)*) where $($l: ModuleDisplay,)* {}
389
390    };
391}
392
393impl_module_tuple!([L0, L1][0, 1]);
394impl_module_tuple!([L0, L1, L2][0, 1, 2]);
395impl_module_tuple!([L0, L1, L2, L3][0, 1, 2, 3]);
396impl_module_tuple!([L0, L1, L2, L3, L4][0, 1, 2, 3, 4]);
397impl_module_tuple!([L0, L1, L2, L3, L4, L5][0, 1, 2, 3, 4, 5]);
398impl_module_tuple!([L0, L1, L2, L3, L4, L5, L6][0, 1, 2, 3, 4, 5, 6]);
399impl_module_tuple!([L0, L1, L2, L3, L4, L5, L6, L7][0, 1, 2, 3, 4, 5, 6, 7]);
400impl_module_tuple!([L0, L1, L2, L3, L4, L5, L6, L7, L8][0, 1, 2, 3, 4, 5, 6, 7, 8]);
401impl_module_tuple!([L0, L1, L2, L3, L4, L5, L6, L7, L8, L9][0, 1, 2, 3, 4, 5, 6, 7, 8, 9]);
402
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestBackend;

    /// A parameter-free optional module must survive a round trip through its own record.
    #[test]
    fn dont_override_constant_module_when_loading_record() {
        let constant = Some(42);

        let saved = Module::<TestBackend>::into_record(constant);
        let restored = Module::<TestBackend>::load_record(constant, saved);

        assert_eq!(restored, constant);
    }

    /// A parameter-free optional module must not be cleared by an empty (`None`) record.
    #[test]
    fn dont_override_constant_module_when_loading_none_record() {
        let constant = Some(42);

        let restored = Module::<TestBackend>::load_record(constant, None);

        assert_eq!(restored, constant);
    }
}