burn_core/module/param/
constant.rs

1use alloc::{format, string::ToString};
2use core::{fmt::Display, marker::PhantomData};
3
4use crate::{
5    self as burn,
6    module::{
7        AutodiffModule, Content, Devices, Module, ModuleDisplay, ModuleDisplayDefault,
8        ModuleMapper, ModuleVisitor,
9    },
10    record::Record,
11};
12use burn::record::PrecisionSettings;
13use burn_tensor::{
14    BasicAutodiffOps, BasicOps, Tensor,
15    backend::{AutodiffBackend, Backend},
16    ops::Device,
17};
18
19/// Record used for constant type implementing the [module](crate::module::Module) trait.
20#[derive(Debug, Clone, Copy, new, Default)]
21pub struct ConstantRecord;
22
23impl serde::Serialize for ConstantRecord {
24    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
25    where
26        S: serde::Serializer,
27    {
28        // nothing to serialize
29        S::serialize_none(serializer)
30    }
31}
32
33impl<'de> serde::Deserialize<'de> for ConstantRecord {
34    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
35    where
36        D: serde::Deserializer<'de>,
37    {
38        deserializer.deserialize_option(serde::de::IgnoredAny).ok();
39        Ok(ConstantRecord::new())
40    }
41}
42
43impl<B: Backend> Record<B> for ConstantRecord {
44    type Item<S: PrecisionSettings> = ConstantRecord;
45
46    fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {
47        self
48    }
49
50    fn from_item<S: PrecisionSettings>(item: Self::Item<S>, _device: &B::Device) -> Self {
51        item
52    }
53}
/// Implements the module traits for a type that behaves as a constant inside a module.
///
/// `constant!(MyType)` generates, for `MyType`:
///
/// - [`Module`](crate::module::Module), with
///   [`ConstantRecord`](crate::module::ConstantRecord) as its record. Visiting, mapping,
///   loading a record and moving/forking to a device all return the value unchanged, and
///   `collect_devices` adds no device.
/// - [`AutodiffModule`](crate::module::AutodiffModule), whose inner module is the type
///   itself (`valid` returns a clone).
/// - [`ModuleDisplayDefault`](crate::module::ModuleDisplayDefault) and
///   [`ModuleDisplay`](crate::module::ModuleDisplay), which render the value through its
///   `Display` implementation.
///
/// The type therefore needs at least `Clone` and `Display`, on top of whatever the module
/// traits themselves require. The `module` and `ad_module` arms are helpers for the main
/// arm and are not meant to be invoked directly.
///
/// Every path in the expansion goes through `$crate`, including the recursive calls. The
/// macro therefore expands correctly from any crate, even when no `burn` alias and no
/// `use` of `constant` is in scope at the call site. The display impl still calls
/// `format!`, which resolves at the call site. It is in the standard prelude; `no_std`
/// crates need `alloc::format` imported.
#[macro_export]
macro_rules! constant {
    (module) => {
        type Record = $crate::module::ConstantRecord;

        fn visit<V: $crate::module::ModuleVisitor<B>>(&self, _visitor: &mut V) {
            // Nothing to do
        }

        fn map<M: $crate::module::ModuleMapper<B>>(self, _mapper: &mut M) -> Self {
            self
        }

        fn load_record(self, _record: Self::Record) -> Self {
            self
        }

        fn into_record(self) -> Self::Record {
            $crate::module::ConstantRecord::new()
        }

        fn to_device(self, _: &B::Device) -> Self {
            self
        }

        fn fork(self, _: &B::Device) -> Self {
            self
        }

        fn collect_devices(
            &self,
            devices: $crate::module::Devices<B>,
        ) -> $crate::module::Devices<B> {
            devices
        }
    };

    (ad_module, $type:ty) => {
        type InnerModule = $type;

        fn valid(&self) -> Self::InnerModule {
            self.clone()
        }
    };

    ($type:ty) => {
        impl<B: $crate::tensor::backend::Backend> $crate::module::Module<B> for $type {
            $crate::constant!(module);
        }

        impl<B: $crate::tensor::backend::AutodiffBackend> $crate::module::AutodiffModule<B>
            for $type
        {
            $crate::constant!(ad_module, $type);
        }

        impl $crate::module::ModuleDisplayDefault for $type {
            fn content(
                &self,
                content: $crate::module::Content,
            ) -> Option<$crate::module::Content> {
                let string = format!("{}", self);
                content.add_formatted(&string).optional()
            }
        }

        impl $crate::module::ModuleDisplay for $type {}
    };
}
116
117// General Types
118constant!(alloc::string::String);
119constant!(bool);
120
121// Float Types
122constant!(f64);
123constant!(f32);
124constant!(half::bf16);
125constant!(half::f16);
126
127// Unsigned Integer Types
128constant!(usize);
129constant!(u64);
130constant!(u32);
131constant!(u16);
132constant!(u8);
133
134// Signed Integer Types
135constant!(isize);
136constant!(i64);
137constant!(i32);
138constant!(i16);
139constant!(i8);
140
141impl burn::module::ModuleDisplay for str {}
142impl burn::module::ModuleDisplayDefault for str {
143    fn content(&self, content: burn::module::Content) -> Option<burn::module::Content> {
144        content.add_formatted(&self).optional()
145    }
146}
147
148impl<const D: usize, B: Backend, K: BasicOps<B>> Module<B> for Tensor<B, D, K> {
149    type Record = ConstantRecord;
150
151    fn visit<V: ModuleVisitor<B>>(&self, _visitor: &mut V) {}
152
153    fn map<M: ModuleMapper<B>>(self, _mapper: &mut M) -> Self {
154        self
155    }
156
157    fn into_record(self) -> Self::Record {
158        ConstantRecord
159    }
160
161    fn load_record(self, _record: Self::Record) -> Self {
162        self
163    }
164
165    fn to_device(self, device: &B::Device) -> Self {
166        self.to_device(device)
167    }
168
169    fn fork(self, device: &B::Device) -> Self {
170        self.to_device(device)
171    }
172
173    fn collect_devices(&self, mut devices: Devices<B>) -> Devices<B> {
174        let device = self.device();
175
176        if !devices.contains(&device) {
177            devices.push(device)
178        }
179
180        devices
181    }
182}
183
184impl<const D: usize, B: Backend, K: BasicOps<B>> ModuleDisplayDefault for Tensor<B, D, K> {
185    fn content(&self, content: Content) -> Option<Content> {
186        let string = format!("Tensor {{rank: {D}, shape: {:?}}}", self.shape().dims);
187        content.add_single(&string).optional()
188    }
189}
190
191impl<const D: usize, B: Backend, K: BasicOps<B>> ModuleDisplay for Tensor<B, D, K> {}
192
193impl<const D: usize, B: AutodiffBackend, K: BasicAutodiffOps<B>> AutodiffModule<B>
194    for Tensor<B, D, K>
195{
196    type InnerModule = Tensor<B::InnerBackend, D, K::InnerKind>;
197
198    fn valid(&self) -> Self::InnerModule {
199        self.clone().inner()
200    }
201}
202
203impl<B: Backend> Module<B> for PhantomData<B> {
204    type Record = ConstantRecord;
205
206    fn visit<V: ModuleVisitor<B>>(&self, _visitor: &mut V) {
207        // Nothing to do
208    }
209
210    fn map<M: ModuleMapper<B>>(self, _mapper: &mut M) -> Self {
211        self
212    }
213
214    fn load_record(self, _record: Self::Record) -> Self {
215        self
216    }
217
218    fn into_record(self) -> Self::Record {
219        ConstantRecord::new()
220    }
221
222    fn to_device(self, _: &Device<B>) -> Self {
223        self
224    }
225
226    fn fork(self, _: &Device<B>) -> Self {
227        self
228    }
229
230    fn collect_devices(&self, devices: Devices<B>) -> Devices<B> {
231        devices
232    }
233}
234
235impl<B: Backend> ModuleDisplayDefault for PhantomData<B> {
236    fn content(&self, content: Content) -> Option<Content> {
237        content.add_single(&"PhantomData".to_string()).optional()
238    }
239}
240
241impl<B: Backend> ModuleDisplay for PhantomData<B> {}
242
243impl<B: AutodiffBackend> AutodiffModule<B> for PhantomData<B> {
244    type InnerModule = PhantomData<B::InnerBackend>;
245
246    fn valid(&self) -> Self::InnerModule {
247        PhantomData
248    }
249}
250
/// Wrapper that lets a value which is not itself a module be stored as a module field.
///
/// The wrapped value is opaque to the module system: it is never recorded, visited,
/// mapped or moved between devices, and it is displayed through its `Debug` output.
/// The value stays reachable through the public `.0` field.
#[derive(Debug, Clone)]
pub struct Ignored<T>(pub T);
254
255impl<B, T> Module<B> for Ignored<T>
256where
257    B: Backend,
258    T: Sync + Send + core::fmt::Debug + Clone,
259{
260    type Record = ConstantRecord;
261
262    fn visit<V: ModuleVisitor<B>>(&self, _visitor: &mut V) {
263        // Nothing to do
264    }
265
266    fn map<M: ModuleMapper<B>>(self, _mapper: &mut M) -> Self {
267        self
268    }
269
270    fn load_record(self, _record: Self::Record) -> Self {
271        self
272    }
273
274    fn into_record(self) -> Self::Record {
275        ConstantRecord::new()
276    }
277
278    fn to_device(self, _: &Device<B>) -> Self {
279        self
280    }
281
282    fn fork(self, _: &Device<B>) -> Self {
283        self
284    }
285
286    fn collect_devices(&self, devices: Devices<B>) -> Devices<B> {
287        devices
288    }
289}
290
291impl<T> ModuleDisplayDefault for Ignored<T>
292where
293    T: Sync + Send + core::fmt::Debug + Clone,
294{
295    fn content(&self, content: Content) -> Option<Content> {
296        // For now, just print the debug representation of the ignored value
297        content.add_single(&format!("{:?}", self.0)).optional()
298    }
299}
300
301impl<T> ModuleDisplay for Ignored<T> where T: Sync + Send + core::fmt::Debug + Clone {}
302
303impl<T> Display for Ignored<T>
304where
305    T: Sync + Send + core::fmt::Debug + Clone,
306{
307    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
308        write!(f, "{:?}", self.0)
309    }
310}
311
312impl<B: AutodiffBackend, T> AutodiffModule<B> for Ignored<T>
313where
314    B: AutodiffBackend,
315    T: Sync + Send + core::fmt::Debug + Clone,
316{
317    type InnerModule = Ignored<T>;
318
319    fn valid(&self) -> Self::InnerModule {
320        self.clone()
321    }
322}
323
324// Implement deref for Ignored
325impl<T> core::ops::Deref for Ignored<T> {
326    type Target = T;
327
328    fn deref(&self) -> &Self::Target {
329        &self.0
330    }
331}
332
#[cfg(all(test, feature = "std"))]
mod tests {
    use core::marker::PhantomData;

    use burn_tensor::backend::Backend;
    use burn_tensor::{Device, Tensor};

    use crate::TestBackend;
    use crate::{
        TestAutodiffBackend,
        record::{BinBytesRecorder, FullPrecisionSettings, Recorder},
    };
    use burn::module::Module;

    use crate as burn;

    /// Round-trips a tensor's record through a byte recorder and loads it back.
    ///
    /// Neither a `no_grad` target nor a default target should report `is_require_grad()`.
    #[test]
    fn tensor_load_record_setting() {
        let device: &Device<TestAutodiffBackend> = &Default::default();
        let source = Tensor::<TestAutodiffBackend, 2>::ones([3, 3], device);

        let recorder = BinBytesRecorder::<FullPrecisionSettings>::default();
        let payload =
            Recorder::<TestAutodiffBackend>::record(&recorder, source.clone().into_record(), ())
                .unwrap();

        // Target explicitly detached with `no_grad`.
        let detached_record =
            Recorder::<TestAutodiffBackend>::load(&recorder, payload.clone(), device).unwrap();
        let detached_requires_grad = source
            .clone()
            .no_grad()
            .load_record(detached_record)
            .is_require_grad();

        // Target left with its default gradient setting (last use of `payload`).
        let default_record =
            Recorder::<TestAutodiffBackend>::load(&recorder, payload, device).unwrap();
        let default_requires_grad = source.load_record(default_record).is_require_grad();

        assert!(!detached_requires_grad);
        assert!(!default_requires_grad);
    }

    /// A module whose only field is `PhantomData` derives `Module` and stays zero-sized.
    #[test]
    fn empty_module_with_phantom() {
        #[derive(Module, Debug, new)]
        struct EmptyModule<B: Backend> {
            _phantom: PhantomData<B>,
        }

        let _module = EmptyModule::<TestBackend>::new();

        let size = core::mem::size_of::<EmptyModule<TestBackend>>();
        assert_eq!(size, 0);
    }
}
393}