burn_core/module/param/
primitive.rs

1use crate::module::{
2    AutodiffModule, Content, Module, ModuleDisplay, ModuleDisplayDefault, ModuleMapper,
3    ModuleVisitor,
4};
5
6use alloc::{format, vec::Vec};
7
8use burn_tensor::{
9    backend::{AutodiffBackend, Backend},
10    ops::Device,
11};
12use core::fmt::Debug;
13
14impl<T, B> Module<B> for Option<T>
15where
16    T: Module<B> + Debug + Send + Clone,
17    B: Backend,
18{
19    type Record = Option<T::Record>;
20
21    fn visit<V: ModuleVisitor<B>>(&self, visitor: &mut V) {
22        if let Some(module) = self {
23            module.visit(visitor)
24        }
25    }
26
27    fn map<M: ModuleMapper<B>>(self, mapper: &mut M) -> Self {
28        self.map(|module| module.map(mapper))
29    }
30
31    fn load_record(self, record: Self::Record) -> Self {
32        let is_constant = self.num_params() == 0;
33
34        if is_constant {
35            return self;
36        }
37
38        self.zip(record)
39            .map(|(module, record)| module.load_record(record))
40    }
41
42    fn into_record(self) -> Self::Record {
43        self.map(Module::into_record)
44    }
45
46    fn to_device(self, device: &Device<B>) -> Self {
47        self.map(|module| module.to_device(device))
48    }
49
50    fn fork(self, device: &Device<B>) -> Self {
51        self.map(|module| module.fork(device))
52    }
53
54    fn collect_devices(&self, mut devices: Vec<B::Device>) -> Vec<B::Device> {
55        if let Some(module) = self.as_ref() {
56            devices = module.collect_devices(devices);
57        }
58
59        devices
60    }
61}
62
63impl<T: ModuleDisplay> ModuleDisplayDefault for Option<T> {
64    fn content(&self, content: Content) -> Option<Content> {
65        match self {
66            Some(module) => content.add_single(module).optional(),
67            None => content.add_single("None").optional(),
68        }
69    }
70}
71
72impl<T: ModuleDisplay> ModuleDisplay for Option<T> {}
73
74impl<T, B> AutodiffModule<B> for Option<T>
75where
76    T: AutodiffModule<B> + Debug + Send + Clone,
77    B: AutodiffBackend,
78{
79    type InnerModule = Option<T::InnerModule>;
80
81    fn valid(&self) -> Self::InnerModule {
82        self.as_ref().map(|module| module.valid())
83    }
84}
85
86impl<T, B> Module<B> for Vec<T>
87where
88    T: Module<B> + Debug + Send + Clone,
89    B: Backend,
90{
91    type Record = Vec<T::Record>;
92
93    fn num_params(&self) -> usize {
94        let mut num_params = 0;
95        for module in self.iter() {
96            num_params += module.num_params();
97        }
98
99        num_params
100    }
101
102    fn visit<V: ModuleVisitor<B>>(&self, visitor: &mut V) {
103        self.iter().for_each(|module| {
104            module.visit(visitor);
105        });
106    }
107
108    fn map<M: ModuleMapper<B>>(self, mapper: &mut M) -> Self {
109        self.into_iter().map(|module| module.map(mapper)).collect()
110    }
111
112    fn into_record(self) -> Self::Record {
113        self.into_iter().map(Module::into_record).collect()
114    }
115
116    fn load_record(self, record: Self::Record) -> Self {
117        assert_eq!(
118            self.len(),
119            record.len(),
120            r#"[Load Record Error] The vec record does not the same length as the module.
121            Make sure you module initialization is compatible with the record being loaded.
122            "#,
123        );
124
125        self.into_iter()
126            .zip(record)
127            .map(|(module, record)| module.load_record(record))
128            .collect()
129    }
130
131    fn to_device(self, device: &Device<B>) -> Self {
132        self.into_iter()
133            .map(|module| module.to_device(device))
134            .collect()
135    }
136
137    fn fork(self, device: &Device<B>) -> Self {
138        self.into_iter().map(|module| module.fork(device)).collect()
139    }
140
141    fn collect_devices(&self, mut devices: Vec<B::Device>) -> Vec<B::Device> {
142        for module in self.iter() {
143            devices = module.collect_devices(devices);
144        }
145
146        devices
147    }
148}
149
150impl<T: ModuleDisplay> ModuleDisplayDefault for Vec<T> {
151    fn content(&self, content: Content) -> Option<Content> {
152        self.iter()
153            .enumerate()
154            .fold(content, |acc, (i, module)| {
155                let index = format!("{i}");
156                acc.add(&index, module)
157            })
158            .set_top_level_type(format!("Vec<0..{}>", self.len()).as_str())
159            .optional()
160    }
161}
162
163impl<T: ModuleDisplay> ModuleDisplay for Vec<T> {}
164
165impl<T, B> AutodiffModule<B> for Vec<T>
166where
167    T: AutodiffModule<B> + Debug + Send + Clone,
168    B: AutodiffBackend,
169{
170    type InnerModule = Vec<T::InnerModule>;
171
172    fn valid(&self) -> Self::InnerModule {
173        self.iter().map(|module| module.valid()).collect()
174    }
175}
176
177impl<const N: usize, T, B> Module<B> for [T; N]
178where
179    T: Module<B> + Debug + Send + Clone,
180    B: Backend,
181{
182    type Record = [T::Record; N];
183
184    fn collect_devices(&self, mut devices: Vec<B::Device>) -> Vec<B::Device> {
185        for module in self.iter() {
186            devices = module.collect_devices(devices);
187        }
188
189        devices
190    }
191
192    fn num_params(&self) -> usize {
193        let mut num_params = 0;
194        for module in self.iter() {
195            num_params += module.num_params();
196        }
197
198        num_params
199    }
200
201    fn visit<V: ModuleVisitor<B>>(&self, visitor: &mut V) {
202        self.iter().for_each(|module| {
203            module.visit(visitor);
204        });
205    }
206
207    fn map<M: ModuleMapper<B>>(self, mapper: &mut M) -> Self {
208        self.map(|module| module.map(mapper))
209    }
210
211    fn load_record(self, record: Self::Record) -> Self {
212        self.into_iter()
213            .zip(record)
214            .map(|(module, record)| module.load_record(record))
215            .collect::<Vec<_>>()
216            .try_into()
217            .unwrap()
218    }
219
220    fn into_record(self) -> Self::Record {
221        self.map(Module::into_record)
222    }
223
224    fn to_device(self, device: &Device<B>) -> Self {
225        self.map(|module| module.to_device(device))
226    }
227
228    fn fork(self, device: &Device<B>) -> Self {
229        self.map(|module| module.fork(device))
230    }
231}
232
233impl<const N: usize, T: ModuleDisplay> ModuleDisplayDefault for [T; N] {
234    fn content(&self, content: Content) -> Option<Content> {
235        self.iter()
236            .enumerate()
237            .fold(content, |acc, (i, module)| {
238                let index = format!("{i}");
239                acc.add(&index, module)
240            })
241            .set_top_level_type(format!("[0..{}]", self.len()).as_str())
242            .optional()
243    }
244}
245
246impl<const N: usize, T: ModuleDisplay> ModuleDisplay for [T; N] {}
247
248impl<const N: usize, T, B> AutodiffModule<B> for [T; N]
249where
250    T: AutodiffModule<B> + Debug + Send + Clone,
251    T::InnerModule: Debug,
252    B: AutodiffBackend,
253{
254    type InnerModule = [T::InnerModule; N];
255
256    fn valid(&self) -> Self::InnerModule {
257        self.clone().map(|module| module.valid())
258    }
259}
260
/// A macro for generating implementations for tuple modules of different sizes.
/// For example: `impl_module_tuple!([L0, L1][0, 1])`.
/// Would generate an implementation for a tuple of size 2.
/// For this macro to work properly, please adhere to the convention:
/// `impl_module_tuple!([L0, L1, ..., Ln][0, 1, ..., n])`.
///
/// Each expansion implements `Module`, `AutodiffModule`, `ModuleDisplayDefault`
/// and `ModuleDisplay` for the tuple, forwarding every operation to the
/// fields in index order.
macro_rules! impl_module_tuple {
    // `$l` represents the generic modules.
    // `$i` represents the indices of the modules in the tuple.
    ([$($l:ident),*][$($i:tt),*]) => {
        impl<B, $($l,)*> Module<B> for ($($l,)*)
        where
            B: Backend,
            $($l: Module<B> + Debug + Send + Clone,)*
        {
            // The trailing comma inside the repetition keeps this a tuple type
            // for every arity, matching `InnerModule` and the value built by
            // `into_record`.
            type Record = ($($l::Record,)*);

            fn collect_devices(&self, mut devices: Vec<B::Device>) -> Vec<B::Device> {
                // Thread the device list through each field in turn.
                $(devices = self.$i.collect_devices(devices);)*
                devices
            }

            fn fork(self, device: &Device<B>) -> Self {
                ($(self.$i.fork(device),)*)
            }

            fn to_device(self, device: &Device<B>) -> Self {
                ($(self.$i.to_device(device),)*)
            }

            fn visit<V: ModuleVisitor<B>>(&self, visitor: &mut V) {
                $(self.$i.visit(visitor);)*
            }

            fn map<M: ModuleMapper<B>>(self, mapper: &mut M) -> Self {
                ($(self.$i.map(mapper),)*)
            }

            fn load_record(self, record: Self::Record) -> Self {
                // Field `i` of the record is loaded into field `i` of the tuple.
                ($(self.$i.load_record(record.$i),)*)
            }

            fn into_record(self) -> Self::Record {
                ($(self.$i.into_record(),)*)
            }
        }

        impl<B, $($l,)*> AutodiffModule<B> for ($($l,)*)
        where
            B: AutodiffBackend,
            $($l: AutodiffModule<B> + Debug + Send + Clone,)*
        {
            type InnerModule = ($($l::InnerModule,)*);

            fn valid(&self) -> Self::InnerModule {
                ($(self.$i.valid(),)*)
            }
        }

        impl<$($l,)*> ModuleDisplayDefault for ($($l,)*)
        where
            $($l: ModuleDisplay,)*
        {
            fn content(&self, content: Content) -> Option<Content> {
                // One entry per field keyed by its index; the top-level type
                // label lists the generic parameter names in parentheses.
                let content = content
                    $(.add(&format!("{}", $i), &self.$i))*
                    .set_top_level_type(format!("({})", stringify!($($l),*)).as_str());
                content.optional()
            }
        }

        impl<$($l,)*> ModuleDisplay for ($($l,)*) where $($l: ModuleDisplay,)* {}

    };
}
335
// Instantiate the tuple implementations for every arity from 2 through 10.
// Each invocation lists the generic names `L0..Ln` and the matching field
// indices `0..n`, as required by the macro's calling convention.
impl_module_tuple!([L0, L1][0, 1]);
impl_module_tuple!([L0, L1, L2][0, 1, 2]);
impl_module_tuple!([L0, L1, L2, L3][0, 1, 2, 3]);
impl_module_tuple!([L0, L1, L2, L3, L4][0, 1, 2, 3, 4]);
impl_module_tuple!([L0, L1, L2, L3, L4, L5][0, 1, 2, 3, 4, 5]);
impl_module_tuple!([L0, L1, L2, L3, L4, L5, L6][0, 1, 2, 3, 4, 5, 6]);
impl_module_tuple!([L0, L1, L2, L3, L4, L5, L6, L7][0, 1, 2, 3, 4, 5, 6, 7]);
impl_module_tuple!([L0, L1, L2, L3, L4, L5, L6, L7, L8][0, 1, 2, 3, 4, 5, 6, 7, 8]);
impl_module_tuple!([L0, L1, L2, L3, L4, L5, L6, L7, L8, L9][0, 1, 2, 3, 4, 5, 6, 7, 8, 9]);
345
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestBackend;

    /// Loading the record taken from a constant module (`Some(42)`) back into
    /// that module must return the same value.
    #[test]
    fn dont_override_constant_module_when_loading_record() {
        let constant = Some(42);

        let saved = Module::<TestBackend>::into_record(constant);
        let restored = Module::<TestBackend>::load_record(constant, saved);

        assert_eq!(restored, constant);
    }

    /// Loading a `None` record into a constant module (`Some(42)`) must leave
    /// the module's value in place rather than replacing it with `None`.
    #[test]
    fn dont_override_constant_module_when_loading_none_record() {
        let constant = Some(42);

        let empty_record = None;
        let restored = Module::<TestBackend>::load_record(constant, empty_record);

        assert_eq!(restored, constant);
    }
}