// burn_core/module/param/constant.rs

1use alloc::{format, string::ToString};
2use core::{fmt::Display, marker::PhantomData};
3
4use crate as burn;
5use crate::{
6    module::{
7        AutodiffModule, Content, Devices, Module, ModuleDisplay, ModuleDisplayDefault,
8        ModuleMapper, ModuleVisitor,
9    },
10    record::{PrecisionSettings, Record},
11};
12use burn_tensor::{
13    BasicAutodiffOps, BasicOps, Tensor,
14    backend::{AutodiffBackend, Backend},
15    ops::Device,
16};
17
18/// Record used for constant type implementing the [module](crate::module::Module) trait.
19#[derive(Debug, Clone, Copy, new, Default)]
20pub struct ConstantRecord;
21
22impl serde::Serialize for ConstantRecord {
23    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
24    where
25        S: serde::Serializer,
26    {
27        // nothing to serialize
28        S::serialize_none(serializer)
29    }
30}
31
32impl<'de> serde::Deserialize<'de> for ConstantRecord {
33    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
34    where
35        D: serde::Deserializer<'de>,
36    {
37        deserializer.deserialize_option(serde::de::IgnoredAny).ok();
38        Ok(ConstantRecord::new())
39    }
40}
41
42impl<B: Backend> Record<B> for ConstantRecord {
43    type Item<S: PrecisionSettings> = ConstantRecord;
44
45    fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {
46        self
47    }
48
49    fn from_item<S: PrecisionSettings>(item: Self::Item<S>, _device: &B::Device) -> Self {
50        item
51    }
52}
/// Implements the module traits for a type that is treated as a constant.
///
/// `constant!(Type)` generates implementations of
/// [`Module`](crate::module::Module), [`AutodiffModule`](crate::module::AutodiffModule),
/// [`ModuleDisplayDefault`](crate::module::ModuleDisplayDefault) and
/// [`ModuleDisplay`](crate::module::ModuleDisplay) for `Type`:
///
/// - visiting and mapping are no-ops;
/// - the record is an empty [`ConstantRecord`](crate::module::ConstantRecord), and loading a
///   record returns the value unchanged;
/// - `to_device` / `fork` return the value unchanged, and no devices are collected;
/// - `AutodiffModule::valid` returns a clone, and the value is displayed through its
///   `Display` implementation.
///
/// `Type` must therefore implement `Clone` and `Display`, in addition to whatever the module
/// traits themselves require. The display impl calls `format!`, so the invoking crate must
/// have that macro in scope. It is in the `std` prelude; `no_std` crates can use
/// `use alloc::format;`.
///
/// The `module` and `ad_module` arms are internal building blocks of the `($type:ty)` arm.
/// Every path, including the recursive invocations, goes through `$crate`. The macro
/// therefore works when called as a path (e.g. `burn::constant!(...)`) without importing
/// `constant`, and it does not require a crate named `burn` in the caller's scope.
#[macro_export]
macro_rules! constant {
    (module) => {
        type Record = $crate::module::ConstantRecord;

        fn visit<V: $crate::module::ModuleVisitor<B>>(&self, _visitor: &mut V) {
            // Nothing to do
        }

        fn map<M: $crate::module::ModuleMapper<B>>(self, _mapper: &mut M) -> Self {
            self
        }

        fn load_record(self, _record: Self::Record) -> Self {
            self
        }

        fn into_record(self) -> Self::Record {
            $crate::module::ConstantRecord::new()
        }

        fn to_device(self, _: &B::Device) -> Self {
            self
        }

        fn fork(self, _: &B::Device) -> Self {
            self
        }

        fn collect_devices(
            &self,
            devices: $crate::module::Devices<B>,
        ) -> $crate::module::Devices<B> {
            devices
        }
    };

    (ad_module, $type:ty) => {
        type InnerModule = $type;

        fn valid(&self) -> Self::InnerModule {
            self.clone()
        }
    };

    ($type:ty) => {
        impl<B: $crate::tensor::backend::Backend> $crate::module::Module<B> for $type {
            $crate::constant!(module);
        }

        impl<B: $crate::tensor::backend::AutodiffBackend> $crate::module::AutodiffModule<B>
            for $type
        {
            $crate::constant!(ad_module, $type);
        }

        impl $crate::module::ModuleDisplayDefault for $type {
            fn content(
                &self,
                content: $crate::module::Content,
            ) -> Option<$crate::module::Content> {
                let string = format!("{}", self);
                content.add_formatted(&string).optional()
            }
        }

        impl $crate::module::ModuleDisplay for $type {}
    };
}
115
116// General Types
117constant!(alloc::string::String);
118constant!(bool);
119
120// Float Types
121constant!(f64);
122constant!(f32);
123constant!(half::bf16);
124constant!(half::f16);
125
126// Unsigned Integer Types
127constant!(usize);
128constant!(u64);
129constant!(u32);
130constant!(u16);
131constant!(u8);
132
133// Signed Integer Types
134constant!(isize);
135constant!(i64);
136constant!(i32);
137constant!(i16);
138constant!(i8);
139
140impl burn::module::ModuleDisplay for str {}
141impl burn::module::ModuleDisplayDefault for str {
142    fn content(&self, content: burn::module::Content) -> Option<burn::module::Content> {
143        content.add_formatted(&self).optional()
144    }
145}
146
147impl<const D: usize, B: Backend, K: BasicOps<B>> Module<B> for Tensor<B, D, K> {
148    type Record = ConstantRecord;
149
150    fn visit<V: ModuleVisitor<B>>(&self, _visitor: &mut V) {}
151
152    fn map<M: ModuleMapper<B>>(self, _mapper: &mut M) -> Self {
153        self
154    }
155
156    fn into_record(self) -> Self::Record {
157        ConstantRecord
158    }
159
160    fn load_record(self, _record: Self::Record) -> Self {
161        self
162    }
163
164    fn to_device(self, device: &B::Device) -> Self {
165        self.to_device(device)
166    }
167
168    fn fork(self, device: &B::Device) -> Self {
169        self.to_device(device)
170    }
171
172    fn collect_devices(&self, mut devices: Devices<B>) -> Devices<B> {
173        let device = self.device();
174
175        if !devices.contains(&device) {
176            devices.push(device)
177        }
178
179        devices
180    }
181}
182
183impl<const D: usize, B: Backend, K: BasicOps<B>> ModuleDisplayDefault for Tensor<B, D, K> {
184    fn content(&self, content: Content) -> Option<Content> {
185        let string = format!("Tensor {{rank: {D}, shape: {:?}}}", self.shape().dims);
186        content.add_single(&string).optional()
187    }
188}
189
190impl<const D: usize, B: Backend, K: BasicOps<B>> ModuleDisplay for Tensor<B, D, K> {}
191
192impl<const D: usize, B: AutodiffBackend, K: BasicAutodiffOps<B>> AutodiffModule<B>
193    for Tensor<B, D, K>
194{
195    type InnerModule = Tensor<B::InnerBackend, D, K::InnerKind>;
196
197    fn valid(&self) -> Self::InnerModule {
198        self.clone().inner()
199    }
200}
201
202impl<B: Backend> Module<B> for PhantomData<B> {
203    type Record = ConstantRecord;
204
205    fn visit<V: ModuleVisitor<B>>(&self, _visitor: &mut V) {
206        // Nothing to do
207    }
208
209    fn map<M: ModuleMapper<B>>(self, _mapper: &mut M) -> Self {
210        self
211    }
212
213    fn load_record(self, _record: Self::Record) -> Self {
214        self
215    }
216
217    fn into_record(self) -> Self::Record {
218        ConstantRecord::new()
219    }
220
221    fn to_device(self, _: &Device<B>) -> Self {
222        self
223    }
224
225    fn fork(self, _: &Device<B>) -> Self {
226        self
227    }
228
229    fn collect_devices(&self, devices: Devices<B>) -> Devices<B> {
230        devices
231    }
232}
233
234impl<B: Backend> ModuleDisplayDefault for PhantomData<B> {
235    fn content(&self, content: Content) -> Option<Content> {
236        content.add_single(&"PhantomData".to_string()).optional()
237    }
238}
239
240impl<B: Backend> ModuleDisplay for PhantomData<B> {}
241
242impl<B: AutodiffBackend> AutodiffModule<B> for PhantomData<B> {
243    type InnerModule = PhantomData<B::InnerBackend>;
244
245    fn valid(&self) -> Self::InnerModule {
246        PhantomData
247    }
248}
249
/// Wrapper that lets a value which is not itself a module be stored in a module.
///
/// The wrapped value is skipped by visitors and mappers and is not saved in records.
#[derive(Debug, Clone)]
pub struct Ignored<T>(pub T);
253
254impl<B, T> Module<B> for Ignored<T>
255where
256    B: Backend,
257    T: Sync + Send + core::fmt::Debug + Clone,
258{
259    type Record = ConstantRecord;
260
261    fn visit<V: ModuleVisitor<B>>(&self, _visitor: &mut V) {
262        // Nothing to do
263    }
264
265    fn map<M: ModuleMapper<B>>(self, _mapper: &mut M) -> Self {
266        self
267    }
268
269    fn load_record(self, _record: Self::Record) -> Self {
270        self
271    }
272
273    fn into_record(self) -> Self::Record {
274        ConstantRecord::new()
275    }
276
277    fn to_device(self, _: &Device<B>) -> Self {
278        self
279    }
280
281    fn fork(self, _: &Device<B>) -> Self {
282        self
283    }
284
285    fn collect_devices(&self, devices: Devices<B>) -> Devices<B> {
286        devices
287    }
288}
289
290impl<T> ModuleDisplayDefault for Ignored<T>
291where
292    T: Sync + Send + core::fmt::Debug + Clone,
293{
294    fn content(&self, content: Content) -> Option<Content> {
295        // For now, just print the debug representation of the ignored value
296        content.add_single(&format!("{:?}", self.0)).optional()
297    }
298}
299
300impl<T> ModuleDisplay for Ignored<T> where T: Sync + Send + core::fmt::Debug + Clone {}
301
302impl<T> Display for Ignored<T>
303where
304    T: Sync + Send + core::fmt::Debug + Clone,
305{
306    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
307        write!(f, "{:?}", self.0)
308    }
309}
310
311impl<B: AutodiffBackend, T> AutodiffModule<B> for Ignored<T>
312where
313    B: AutodiffBackend,
314    T: Sync + Send + core::fmt::Debug + Clone,
315{
316    type InnerModule = Ignored<T>;
317
318    fn valid(&self) -> Self::InnerModule {
319        self.clone()
320    }
321}
322
323// Implement deref for Ignored
324impl<T> core::ops::Deref for Ignored<T> {
325    type Target = T;
326
327    fn deref(&self) -> &Self::Target {
328        &self.0
329    }
330}
331
#[cfg(all(test, feature = "std"))]
mod tests {
    use core::marker::PhantomData;

    use burn_tensor::{Device, Tensor, backend::Backend};

    use crate as burn;
    use crate::{
        TestAutodiffBackend, TestBackend,
        record::{BinBytesRecorder, FullPrecisionSettings, Recorder},
    };
    use burn::module::Module;

    #[test]
    fn tensor_load_record_setting() {
        let device: &Device<TestAutodiffBackend> = &Default::default();
        let tensor = Tensor::<TestAutodiffBackend, 2>::ones([3, 3], device);

        let recorder = BinBytesRecorder::<FullPrecisionSettings>::default();
        let bytes = Recorder::<TestAutodiffBackend>::record(
            &recorder,
            tensor.clone().into_record(),
            (),
        )
        .unwrap();

        // Loading into a tensor that was explicitly detached from autodiff.
        let detached = tensor.clone().no_grad().load_record(
            Recorder::<TestAutodiffBackend>::load(&recorder, bytes.clone(), device).unwrap(),
        );
        assert!(!detached.is_require_grad());

        // Loading into the tensor exactly as it was created.
        let untouched = tensor.load_record(
            Recorder::<TestAutodiffBackend>::load(&recorder, bytes, device).unwrap(),
        );
        assert!(!untouched.is_require_grad());
    }

    #[test]
    fn empty_module_with_phantom() {
        #[derive(Module, Debug, new)]
        struct EmptyModule<B: Backend> {
            _phantom: PhantomData<B>,
        }

        let _module = EmptyModule::<TestBackend>::new();

        // A module made only of a type marker must be zero-sized.
        assert_eq!(core::mem::size_of::<EmptyModule<TestBackend>>(), 0);
    }
}