Skip to main content

burn_core/module/param/
constant.rs

1use alloc::{format, string::ToString};
2use core::{fmt::Display, marker::PhantomData};
3
4use crate as burn;
5use crate::{
6    module::{
7        AutodiffModule, Content, Devices, Module, ModuleDisplay, ModuleDisplayDefault,
8        ModuleMapper, ModuleVisitor,
9    },
10    record::{PrecisionSettings, Record},
11};
12use burn_tensor::{
13    BasicAutodiffOps, BasicOps, Tensor,
14    backend::{AutodiffBackend, Backend},
15    ops::Device,
16};
17
#[deprecated(
    since = "0.21.0",
    note = "ConstantRecord is misleading as it doesn't persist data. Use EmptyRecord instead."
)]
/// Deprecated alias of [`EmptyRecord`]; both names refer to the same type.
///
/// Like [`EmptyRecord`], it represents the absence of persistent module state.
pub type ConstantRecord = EmptyRecord;
24
/// A record representing the absence of persistent module state.
///
/// `EmptyRecord` is used for modules that do not store any data to be
/// serialized or restored (e.g., modules marked with `#[module(skip)]`
/// or modules without parameters).
///
/// This record contains no fields and serializes to `None`.
///
/// Deserialization always yields an `EmptyRecord`; see the `serde::Deserialize`
/// implementation in this file.
// NOTE(review): `EmptyRecord::new()` (used elsewhere in this file) is presumably generated by
// the `new` derive below — confirm where that derive macro is brought into scope.
#[derive(Debug, Clone, Copy, new, Default, PartialEq, Eq)]
pub struct EmptyRecord;
34
35impl serde::Serialize for EmptyRecord {
36    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
37    where
38        S: serde::Serializer,
39    {
40        // nothing to serialize
41        S::serialize_none(serializer)
42    }
43}
44
45impl<'de> serde::Deserialize<'de> for EmptyRecord {
46    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
47    where
48        D: serde::Deserializer<'de>,
49    {
50        deserializer.deserialize_option(serde::de::IgnoredAny).ok();
51        Ok(EmptyRecord::new())
52    }
53}
54
55impl<B: Backend> Record<B> for EmptyRecord {
56    type Item<S: PrecisionSettings> = EmptyRecord;
57
58    fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {
59        self
60    }
61
62    fn from_item<S: PrecisionSettings>(item: Self::Item<S>, _device: &B::Device) -> Self {
63        item
64    }
65}
/// Implements the module traits for a type that has no persistent state.
///
/// The macro has three arms:
///
/// - `empty!(module)` expands to the items of a `Module<B>` implementation: the record type is
///   `EmptyRecord`, `visit` does nothing, `map`, `load_record`, `to_device` and `fork` return
///   `self`, and `collect_devices` returns the device list unchanged.
/// - `empty!(ad_module, T)` expands to the items of an `AutodiffModule<B>` implementation whose
///   inner module is `T` itself: `valid` clones `self` and `from_inner` returns its argument.
/// - `empty!(T)` emits the `Module`, `AutodiffModule`, `ModuleDisplayDefault` and
///   `ModuleDisplay` implementations for `T`, using the two arms above.
///
/// The expansion names items through `burn::…` paths and calls `format!("{}", self)`, so a
/// `burn` name must resolve where the macro is invoked and `T` must be formattable with `{}`.
#[macro_export]
macro_rules! empty {
    // Items of a `Module<B>` impl; expected to be expanded inside an impl block with `B` in scope.
    (module) => {
        type Record = burn::module::EmptyRecord;

        fn visit<V: burn::module::ModuleVisitor<B>>(&self, _visitor: &mut V) {
            // Nothing to do
        }

        fn map<M: burn::module::ModuleMapper<B>>(self, _mapper: &mut M) -> Self {
            self
        }

        fn load_record(self, _record: Self::Record) -> Self {
            self
        }

        fn into_record(self) -> Self::Record {
            burn::module::EmptyRecord::new()
        }

        fn to_device(self, _: &B::Device) -> Self {
            self
        }

        fn fork(self, _: &B::Device) -> Self {
            self
        }

        fn collect_devices(&self, devices: burn::module::Devices<B>) -> burn::module::Devices<B> {
            devices
        }
    };

    // Items of an `AutodiffModule<B>` impl where the inner module is the type itself.
    (ad_module, $type:ty) => {
        type InnerModule = $type;

        fn valid(&self) -> Self::InnerModule {
            self.clone()
        }

        fn from_inner(module: Self::InnerModule) -> Self {
            module
        }
    };

    // Full set of impls for `$type`, built from the two arms above.
    ($type:ty) => {
        impl<B: burn::tensor::backend::Backend> burn::module::Module<B> for $type {
            empty!(module);
        }

        impl<B: burn::tensor::backend::AutodiffBackend> burn::module::AutodiffModule<B> for $type {
            empty!(ad_module, $type);
        }

        impl burn::module::ModuleDisplayDefault for $type {
            fn content(&self, content: burn::module::Content) -> Option<burn::module::Content> {
                // The displayed content is the value's `{}` formatting.
                let string = format!("{}", self);
                content.add_formatted(&string).optional()
            }
        }

        impl burn::module::ModuleDisplay for $type {}
    };
}
132
133// TODO: breaking change for these constant types (currently empty record, non-persistent)?
134
// Each invocation below expands the `empty!($type)` arm, implementing `Module`,
// `AutodiffModule`, `ModuleDisplayDefault` and `ModuleDisplay` for the listed type
// with an `EmptyRecord`.

// General Types
empty!(alloc::string::String);
empty!(bool);

// Float Types
empty!(f64);
empty!(f32);
empty!(half::bf16);
empty!(half::f16);

// Unsigned Integer Types
empty!(usize);
empty!(u64);
empty!(u32);
empty!(u16);
empty!(u8);

// Signed Integer Types
empty!(isize);
empty!(i64);
empty!(i32);
empty!(i16);
empty!(i8);
158
159impl burn::module::ModuleDisplay for str {}
160impl burn::module::ModuleDisplayDefault for str {
161    fn content(&self, content: burn::module::Content) -> Option<burn::module::Content> {
162        content.add_formatted(&self).optional()
163    }
164}
165
166// TODO: tensor record should persist
167impl<const D: usize, B: Backend, K: BasicOps<B>> Module<B> for Tensor<B, D, K> {
168    type Record = EmptyRecord;
169
170    fn visit<V: ModuleVisitor<B>>(&self, _visitor: &mut V) {}
171
172    fn map<M: ModuleMapper<B>>(self, _mapper: &mut M) -> Self {
173        self
174    }
175
176    fn into_record(self) -> Self::Record {
177        EmptyRecord
178    }
179
180    fn load_record(self, _record: Self::Record) -> Self {
181        self
182    }
183
184    fn to_device(self, device: &B::Device) -> Self {
185        self.to_device(device)
186    }
187
188    fn fork(self, device: &B::Device) -> Self {
189        self.to_device(device)
190    }
191
192    fn collect_devices(&self, mut devices: Devices<B>) -> Devices<B> {
193        let device = self.device();
194
195        if !devices.contains(&device) {
196            devices.push(device)
197        }
198
199        devices
200    }
201}
202
203impl<const D: usize, B: Backend, K: BasicOps<B>> ModuleDisplayDefault for Tensor<B, D, K> {
204    fn content(&self, content: Content) -> Option<Content> {
205        let string = format!("Tensor {{rank: {D}, shape: {:?}}}", self.shape().as_slice());
206        content.add_single(&string).optional()
207    }
208}
209
210impl<const D: usize, B: Backend, K: BasicOps<B>> ModuleDisplay for Tensor<B, D, K> {}
211
212impl<const D: usize, B: AutodiffBackend, K: BasicAutodiffOps<B>> AutodiffModule<B>
213    for Tensor<B, D, K>
214{
215    type InnerModule = Tensor<B::InnerBackend, D, K::InnerKind>;
216
217    fn valid(&self) -> Self::InnerModule {
218        self.clone().inner()
219    }
220
221    fn from_inner(tensor: Self::InnerModule) -> Self {
222        Tensor::from_inner(tensor)
223    }
224}
225
226impl<B: Backend> Module<B> for PhantomData<B> {
227    type Record = EmptyRecord;
228
229    fn visit<V: ModuleVisitor<B>>(&self, _visitor: &mut V) {
230        // Nothing to do
231    }
232
233    fn map<M: ModuleMapper<B>>(self, _mapper: &mut M) -> Self {
234        self
235    }
236
237    fn load_record(self, _record: Self::Record) -> Self {
238        self
239    }
240
241    fn into_record(self) -> Self::Record {
242        EmptyRecord::new()
243    }
244
245    fn to_device(self, _: &Device<B>) -> Self {
246        self
247    }
248
249    fn fork(self, _: &Device<B>) -> Self {
250        self
251    }
252
253    fn collect_devices(&self, devices: Devices<B>) -> Devices<B> {
254        devices
255    }
256}
257
258impl<B: Backend> ModuleDisplayDefault for PhantomData<B> {
259    fn content(&self, content: Content) -> Option<Content> {
260        content.add_single(&"PhantomData".to_string()).optional()
261    }
262}
263
264impl<B: Backend> ModuleDisplay for PhantomData<B> {}
265
266impl<B: AutodiffBackend> AutodiffModule<B> for PhantomData<B> {
267    type InnerModule = PhantomData<B::InnerBackend>;
268
269    fn valid(&self) -> Self::InnerModule {
270        PhantomData
271    }
272
273    fn from_inner(_module: Self::InnerModule) -> Self {
274        PhantomData
275    }
276}
277
/// Container to satisfy the Module trait for types that are not modules.
///
/// The wrapped value is exposed as the public field `.0`. The `Module` implementation in
/// this file uses an [`EmptyRecord`], so the wrapped value is not part of the module record.
#[derive(Clone, Debug)]
#[deprecated(
    since = "0.21.0",
    note = "Ignored<T> is deprecated. Use #[module(skip)] for non-persistent fields (same behavior)."
)]
pub struct Ignored<T>(pub T);
285
286#[allow(deprecated)]
287impl<B, T> Module<B> for Ignored<T>
288where
289    B: Backend,
290    T: Sync + Send + core::fmt::Debug + Clone,
291{
292    type Record = EmptyRecord;
293
294    fn visit<V: ModuleVisitor<B>>(&self, _visitor: &mut V) {
295        // Nothing to do
296    }
297
298    fn map<M: ModuleMapper<B>>(self, _mapper: &mut M) -> Self {
299        self
300    }
301
302    fn load_record(self, _record: Self::Record) -> Self {
303        self
304    }
305
306    fn into_record(self) -> Self::Record {
307        EmptyRecord::new()
308    }
309
310    fn to_device(self, _: &Device<B>) -> Self {
311        self
312    }
313
314    fn fork(self, _: &Device<B>) -> Self {
315        self
316    }
317
318    fn collect_devices(&self, devices: Devices<B>) -> Devices<B> {
319        devices
320    }
321}
322
323#[allow(deprecated)]
324impl<T> ModuleDisplayDefault for Ignored<T>
325where
326    T: Sync + Send + core::fmt::Debug + Clone,
327{
328    fn content(&self, content: Content) -> Option<Content> {
329        // For now, just print the debug representation of the ignored value
330        content.add_single(&format!("{:?}", self.0)).optional()
331    }
332}
333
334#[allow(deprecated)]
335impl<T> ModuleDisplay for Ignored<T> where T: Sync + Send + core::fmt::Debug + Clone {}
336
337#[allow(deprecated)]
338impl<T> Display for Ignored<T>
339where
340    T: Sync + Send + core::fmt::Debug + Clone,
341{
342    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
343        write!(f, "{:?}", self.0)
344    }
345}
346
347#[allow(deprecated)]
348impl<B: AutodiffBackend, T> AutodiffModule<B> for Ignored<T>
349where
350    B: AutodiffBackend,
351    T: Sync + Send + core::fmt::Debug + Clone,
352{
353    type InnerModule = Ignored<T>;
354
355    fn valid(&self) -> Self::InnerModule {
356        self.clone()
357    }
358
359    fn from_inner(module: Self::InnerModule) -> Self {
360        module
361    }
362}
363
364#[allow(deprecated)]
365// Implement deref for Ignored
366impl<T> core::ops::Deref for Ignored<T> {
367    type Target = T;
368
369    fn deref(&self) -> &Self::Target {
370        &self.0
371    }
372}
373
// Unit tests for the `Module` implementations in this file; compiled only for test builds
// with the `std` feature enabled.
#[cfg(all(test, feature = "std"))]
mod tests {
    use core::marker::PhantomData;

    use burn_tensor::backend::Backend;
    use burn_tensor::{Device, Tensor};

    use crate::TestBackend;
    use crate::{
        TestAutodiffBackend,
        record::{BinBytesRecorder, FullPrecisionSettings, Recorder},
    };
    use burn::module::Module;

    use crate as burn;

    /// Records a tensor's record to bytes, loads it back into both a `no_grad` copy and the
    /// original tensor, and checks that neither resulting tensor requires grad.
    #[test]
    fn tensor_load_record_setting() {
        let device: &Device<TestAutodiffBackend> = &Default::default();
        let tensor = Tensor::<TestAutodiffBackend, 2>::ones([3, 3], device);

        // Serialize the tensor's record with the in-memory binary recorder.
        let byte_recorder = BinBytesRecorder::<FullPrecisionSettings>::default();
        let bytes = Recorder::<TestAutodiffBackend>::record(
            &byte_recorder,
            tensor.clone().into_record(),
            (),
        )
        .unwrap();

        // Case 1: load into a copy on which `no_grad()` was called.
        let no_grad_is_require_grad = tensor
            .clone()
            .no_grad()
            .load_record(
                Recorder::<TestAutodiffBackend>::load(&byte_recorder, bytes.clone(), device)
                    .unwrap(),
            )
            .is_require_grad();

        // Case 2: load into the tensor as created above.
        let with_default_is_require_grad = tensor
            .load_record(
                Recorder::<TestAutodiffBackend>::load(&byte_recorder, bytes.clone(), device)
                    .unwrap(),
            )
            .is_require_grad();

        assert!(!no_grad_is_require_grad);
        assert!(!with_default_is_require_grad);
    }

    /// A derived module whose only field is `PhantomData<B>` can be constructed and is
    /// zero-sized.
    #[test]
    fn empty_module_with_phantom() {
        #[derive(Module, Debug, new)]
        struct EmptyModule<B: Backend> {
            _phantom: PhantomData<B>,
        }

        let _module = EmptyModule::<TestBackend>::new();

        assert_eq!(core::mem::size_of::<EmptyModule<TestBackend>>(), 0);
    }
}