burn_core/module/param/primitive.rs

1use crate::module::{
2    AutodiffModule, Content, Module, ModuleDisplay, ModuleDisplayDefault, ModuleMapper,
3    ModuleVisitor,
4};
5
6use alloc::{format, vec::Vec};
7
8use burn_tensor::{
9    backend::{AutodiffBackend, Backend},
10    ops::Device,
11};
12use core::fmt::Debug;
13
14impl<T, B> Module<B> for Option<T>
15where
16    T: Module<B> + Debug + Send + Clone,
17    B: Backend,
18{
19    type Record = Option<T::Record>;
20
21    fn visit<V: ModuleVisitor<B>>(&self, visitor: &mut V) {
22        if let Some(module) = self {
23            module.visit(visitor)
24        }
25    }
26
27    fn map<M: ModuleMapper<B>>(self, mapper: &mut M) -> Self {
28        self.map(|module| module.map(mapper))
29    }
30
31    fn load_record(self, record: Self::Record) -> Self {
32        let is_constant = self.num_params() == 0;
33
34        if is_constant {
35            return self;
36        }
37
38        self.zip(record)
39            .map(|(module, record)| module.load_record(record))
40    }
41
42    fn into_record(self) -> Self::Record {
43        self.map(Module::into_record)
44    }
45
46    fn to_device(self, device: &Device<B>) -> Self {
47        self.map(|module| module.to_device(device))
48    }
49
50    fn fork(self, device: &Device<B>) -> Self {
51        self.map(|module| module.fork(device))
52    }
53
54    fn collect_devices(&self, mut devices: Vec<B::Device>) -> Vec<B::Device> {
55        if let Some(module) = self.as_ref() {
56            devices = module.collect_devices(devices);
57        }
58
59        devices
60    }
61}
62
63impl<T: ModuleDisplay> ModuleDisplayDefault for Option<T> {
64    fn content(&self, content: Content) -> Option<Content> {
65        match self {
66            Some(module) => content.add_single(module).optional(),
67            None => content.add_single("None").optional(),
68        }
69    }
70}
71
72impl<T: ModuleDisplay> ModuleDisplay for Option<T> {}
73
74impl<T, B> AutodiffModule<B> for Option<T>
75where
76    T: AutodiffModule<B> + Debug + Send + Clone,
77    B: AutodiffBackend,
78{
79    type InnerModule = Option<T::InnerModule>;
80
81    fn valid(&self) -> Self::InnerModule {
82        self.as_ref().map(|module| module.valid())
83    }
84}
85
86impl<T, B> Module<B> for Vec<T>
87where
88    T: Module<B> + Debug + Send + Clone,
89    B: Backend,
90{
91    type Record = Vec<T::Record>;
92
93    fn num_params(&self) -> usize {
94        let mut num_params = 0;
95        for module in self.iter() {
96            num_params += module.num_params();
97        }
98
99        num_params
100    }
101
102    fn visit<V: ModuleVisitor<B>>(&self, visitor: &mut V) {
103        for (i, module) in self.iter().enumerate() {
104            let index_str = alloc::format!("{}", i);
105            visitor.enter_module(&index_str, "Vec");
106            module.visit(visitor);
107            visitor.exit_module(&index_str, "Vec");
108        }
109    }
110
111    fn map<M: ModuleMapper<B>>(self, mapper: &mut M) -> Self {
112        self.into_iter()
113            .enumerate()
114            .map(|(i, module)| {
115                let index_str = alloc::format!("{}", i);
116                mapper.enter_module(&index_str, "Vec");
117                let mapped = module.map(mapper);
118                mapper.exit_module(&index_str, "Vec");
119                mapped
120            })
121            .collect()
122    }
123
124    fn into_record(self) -> Self::Record {
125        self.into_iter().map(Module::into_record).collect()
126    }
127
128    fn load_record(self, record: Self::Record) -> Self {
129        assert_eq!(
130            self.len(),
131            record.len(),
132            r#"[Load Record Error] The vec record does not the same length as the module.
133            Make sure you module initialization is compatible with the record being loaded.
134            "#,
135        );
136
137        self.into_iter()
138            .zip(record)
139            .map(|(module, record)| module.load_record(record))
140            .collect()
141    }
142
143    fn to_device(self, device: &Device<B>) -> Self {
144        self.into_iter()
145            .map(|module| module.to_device(device))
146            .collect()
147    }
148
149    fn fork(self, device: &Device<B>) -> Self {
150        self.into_iter().map(|module| module.fork(device)).collect()
151    }
152
153    fn collect_devices(&self, mut devices: Vec<B::Device>) -> Vec<B::Device> {
154        for module in self.iter() {
155            devices = module.collect_devices(devices);
156        }
157
158        devices
159    }
160}
161
162impl<T: ModuleDisplay> ModuleDisplayDefault for Vec<T> {
163    fn content(&self, content: Content) -> Option<Content> {
164        self.iter()
165            .enumerate()
166            .fold(content, |acc, (i, module)| {
167                let index = format!("{i}");
168                acc.add(&index, module)
169            })
170            .set_top_level_type(format!("Vec<0..{}>", self.len()).as_str())
171            .optional()
172    }
173}
174
175impl<T: ModuleDisplay> ModuleDisplay for Vec<T> {}
176
177impl<T, B> AutodiffModule<B> for Vec<T>
178where
179    T: AutodiffModule<B> + Debug + Send + Clone,
180    B: AutodiffBackend,
181{
182    type InnerModule = Vec<T::InnerModule>;
183
184    fn valid(&self) -> Self::InnerModule {
185        self.iter().map(|module| module.valid()).collect()
186    }
187}
188
189impl<const N: usize, T, B> Module<B> for [T; N]
190where
191    T: Module<B> + Debug + Send + Clone,
192    B: Backend,
193{
194    type Record = [T::Record; N];
195
196    fn collect_devices(&self, mut devices: Vec<B::Device>) -> Vec<B::Device> {
197        for module in self.iter() {
198            devices = module.collect_devices(devices);
199        }
200
201        devices
202    }
203
204    fn num_params(&self) -> usize {
205        let mut num_params = 0;
206        for module in self.iter() {
207            num_params += module.num_params();
208        }
209
210        num_params
211    }
212
213    fn visit<V: ModuleVisitor<B>>(&self, visitor: &mut V) {
214        for (i, module) in self.iter().enumerate() {
215            let index_str = alloc::format!("{}", i);
216            visitor.enter_module(&index_str, "Vec");
217            module.visit(visitor);
218            visitor.exit_module(&index_str, "Vec");
219        }
220    }
221
222    fn map<M: ModuleMapper<B>>(self, mapper: &mut M) -> Self {
223        let mut result = Vec::with_capacity(N);
224        for (i, module) in IntoIterator::into_iter(self).enumerate() {
225            let index_str = alloc::format!("{}", i);
226            mapper.enter_module(&index_str, "Array");
227            let mapped = module.map(mapper);
228            mapper.exit_module(&index_str, "Array");
229            result.push(mapped);
230        }
231        result
232            .try_into()
233            .unwrap_or_else(|v: Vec<T>| panic!("Expected array of length {}, got {}", N, v.len()))
234    }
235
236    fn load_record(self, record: Self::Record) -> Self {
237        self.into_iter()
238            .zip(record)
239            .map(|(module, record)| module.load_record(record))
240            .collect::<Vec<_>>()
241            .try_into()
242            .unwrap()
243    }
244
245    fn into_record(self) -> Self::Record {
246        self.map(Module::into_record)
247    }
248
249    fn to_device(self, device: &Device<B>) -> Self {
250        self.map(|module| module.to_device(device))
251    }
252
253    fn fork(self, device: &Device<B>) -> Self {
254        self.map(|module| module.fork(device))
255    }
256}
257
258impl<const N: usize, T: ModuleDisplay> ModuleDisplayDefault for [T; N] {
259    fn content(&self, content: Content) -> Option<Content> {
260        self.iter()
261            .enumerate()
262            .fold(content, |acc, (i, module)| {
263                let index = format!("{i}");
264                acc.add(&index, module)
265            })
266            .set_top_level_type(format!("[0..{}]", self.len()).as_str())
267            .optional()
268    }
269}
270
271impl<const N: usize, T: ModuleDisplay> ModuleDisplay for [T; N] {}
272
273impl<const N: usize, T, B> AutodiffModule<B> for [T; N]
274where
275    T: AutodiffModule<B> + Debug + Send + Clone,
276    T::InnerModule: Debug,
277    B: AutodiffBackend,
278{
279    type InnerModule = [T::InnerModule; N];
280
281    fn valid(&self) -> Self::InnerModule {
282        self.clone().map(|module| module.valid())
283    }
284}
285
/// A macro for generating implementations for tuple modules of different sizes.
///
/// For example, `impl_module_tuple!([L0, L1][0, 1])` generates the implementations for a
/// tuple of size 2. For this macro to work properly, adhere to the convention
/// `impl_module_tuple!([L0, L1, ..., Ln][0, 1, ..., n])`. The first list names one generic
/// parameter per tuple element; the second lists the matching tuple indices, in the same order.
macro_rules! impl_module_tuple {
    // `$l` represents the generic modules.
    // `$i` represents the indices of the modules in the tuple.
    ([$($l:ident),*][$($i:tt),*]) => {
        impl<B, $($l,)*> Module<B> for ($($l,)*)
        where
            B: Backend,
            $($l: Module<B> + Debug + Send + Clone,)*
        {
            // The trailing comma keeps this a tuple type for every arity (including 1),
            // matching the tuple built by `into_record` and indexed by `load_record`.
            type Record = ($($l::Record,)*);

            fn collect_devices(&self, mut devices: Vec<B::Device>) -> Vec<B::Device> {
                $(devices = self.$i.collect_devices(devices);)*
                devices
            }

            fn fork(self, device: &Device<B>) -> Self {
                ($(self.$i.fork(device),)*)
            }

            fn to_device(self, device: &Device<B>) -> Self {
                ($(self.$i.to_device(device),)*)
            }

            fn visit<V: ModuleVisitor<B>>(&self, visitor: &mut V) {
                $(self.$i.visit(visitor);)*
            }

            fn map<M: ModuleMapper<B>>(self, mapper: &mut M) -> Self {
                ($(self.$i.map(mapper),)*)
            }

            fn load_record(self, record: Self::Record) -> Self {
                ($(self.$i.load_record(record.$i),)*)
            }

            fn into_record(self) -> Self::Record {
                ($(self.$i.into_record(),)*)
            }
        }

        impl<B, $($l,)*> AutodiffModule<B> for ($($l,)*)
        where
            B: AutodiffBackend,
            $($l: AutodiffModule<B> + Debug + Send + Clone,)*
        {
            type InnerModule = ($($l::InnerModule,)*);

            fn valid(&self) -> Self::InnerModule {
                ($(self.$i.valid(),)*)
            }
        }

        impl<$($l,)*> ModuleDisplayDefault for ($($l,)*)
        where
            $($l: ModuleDisplay,)*
        {
            fn content(&self, content: Content) -> Option<Content> {
                let content = content
                    $(.add(&format!("{}", $i), &self.$i))*
                    .set_top_level_type(format!("({})", stringify!($($l),*)).as_str());
                content.optional()
            }
        }

        impl<$($l,)*> ModuleDisplay for ($($l,)*) where $($l: ModuleDisplay,)* {}

    };
}
360
// Tuples of 2 up to 10 elements are supported. Each invocation expands to the `Module`,
// `AutodiffModule`, `ModuleDisplayDefault` and `ModuleDisplay` impls for one tuple arity.
impl_module_tuple!([L0, L1][0, 1]);
impl_module_tuple!([L0, L1, L2][0, 1, 2]);
impl_module_tuple!([L0, L1, L2, L3][0, 1, 2, 3]);
impl_module_tuple!([L0, L1, L2, L3, L4][0, 1, 2, 3, 4]);
impl_module_tuple!([L0, L1, L2, L3, L4, L5][0, 1, 2, 3, 4, 5]);
impl_module_tuple!([L0, L1, L2, L3, L4, L5, L6][0, 1, 2, 3, 4, 5, 6]);
impl_module_tuple!([L0, L1, L2, L3, L4, L5, L6, L7][0, 1, 2, 3, 4, 5, 6, 7]);
impl_module_tuple!([L0, L1, L2, L3, L4, L5, L6, L7, L8][0, 1, 2, 3, 4, 5, 6, 7, 8]);
impl_module_tuple!([L0, L1, L2, L3, L4, L5, L6, L7, L8, L9][0, 1, 2, 3, 4, 5, 6, 7, 8, 9]);
370
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestBackend;

    /// Round-tripping a parameter-free optional module through its own record keeps it intact.
    #[test]
    fn dont_override_constant_module_when_loading_record() {
        let module = Some(42);
        let record = Module::<TestBackend>::into_record(module);

        assert_eq!(Module::<TestBackend>::load_record(module, record), module);
    }

    /// Loading a `None` record into a parameter-free `Some` module must not erase it.
    #[test]
    fn dont_override_constant_module_when_loading_none_record() {
        let module = Some(42);

        let restored = Module::<TestBackend>::load_record(module, None);

        assert_eq!(restored, module);
    }
}