burn_core/module/param/
constant.rs

1use alloc::{format, string::ToString};
2use core::{fmt::Display, marker::PhantomData};
3
4use crate::{
5    self as burn,
6    module::{
7        AutodiffModule, Content, Devices, Module, ModuleDisplay, ModuleDisplayDefault,
8        ModuleMapper, ModuleVisitor,
9    },
10    record::Record,
11};
12use burn::record::PrecisionSettings;
13use burn_tensor::{
14    backend::{AutodiffBackend, Backend},
15    ops::Device,
16    BasicAutodiffOps, BasicOps, Tensor,
17};
18
19/// Record used for constant type implementing the [module](crate::module::Module) trait.
20#[derive(Debug, Clone, Copy, new, Default)]
21pub struct ConstantRecord;
22
23impl serde::Serialize for ConstantRecord {
24    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
25    where
26        S: serde::Serializer,
27    {
28        // nothing to serialize
29        S::serialize_none(serializer)
30    }
31}
32
33impl<'de> serde::Deserialize<'de> for ConstantRecord {
34    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
35    where
36        D: serde::Deserializer<'de>,
37    {
38        deserializer.deserialize_option(serde::de::IgnoredAny).ok();
39        Ok(ConstantRecord::new())
40    }
41}
42
43impl<B: Backend> Record<B> for ConstantRecord {
44    type Item<S: PrecisionSettings> = ConstantRecord;
45
46    fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {
47        self
48    }
49
50    fn from_item<S: PrecisionSettings>(item: Self::Item<S>, _device: &B::Device) -> Self {
51        item
52    }
53}
/// Implements the module traits for a type that is treated as a constant.
///
/// The macro accepts three forms:
///
/// - `constant!(module)` expands to the items of a `burn::module::Module<B>` implementation:
///   the record is `burn::module::ConstantRecord`, `visit` does nothing, and `map`,
///   `load_record`, `to_device`, `fork` and `collect_devices` return their input unchanged.
/// - `constant!(ad_module, Type)` expands to the items of a `burn::module::AutodiffModule<B>`
///   implementation whose inner module is `Type` and whose `valid` returns a clone.
/// - `constant!(Type)` implements `Module`, `AutodiffModule`, `ModuleDisplayDefault` and
///   `ModuleDisplay` for `Type`, reusing the two forms above. The display content is the
///   value formatted with `{}`.
///
/// All generated paths start with `burn::`, so that name has to be resolvable where the
/// macro is invoked.
#[macro_export]
macro_rules! constant {
    // Items of a `Module<B>` implementation; `B` is supplied by the surrounding `impl`.
    (module) => {
        type Record = burn::module::ConstantRecord;

        fn into_record(self) -> Self::Record {
            burn::module::ConstantRecord::new()
        }

        fn load_record(self, _record: Self::Record) -> Self {
            self
        }

        fn visit<V: burn::module::ModuleVisitor<B>>(&self, _visitor: &mut V) {
            // A constant has nothing to show to the visitor.
        }

        fn map<M: burn::module::ModuleMapper<B>>(self, _mapper: &mut M) -> Self {
            self
        }

        fn collect_devices(&self, devices: burn::module::Devices<B>) -> burn::module::Devices<B> {
            devices
        }

        fn to_device(self, _device: &B::Device) -> Self {
            self
        }

        fn fork(self, _device: &B::Device) -> Self {
            self
        }
    };

    // Items of an `AutodiffModule<B>` implementation: the inner module is the type itself.
    (ad_module, $type:ty) => {
        type InnerModule = $type;

        fn valid(&self) -> Self::InnerModule {
            self.clone()
        }
    };

    // Full set of implementations for one type.
    ($type:ty) => {
        impl<B> burn::module::Module<B> for $type
        where
            B: burn::tensor::backend::Backend,
        {
            constant!(module);
        }

        impl<B> burn::module::AutodiffModule<B> for $type
        where
            B: burn::tensor::backend::AutodiffBackend,
        {
            constant!(ad_module, $type);
        }

        impl burn::module::ModuleDisplay for $type {}

        impl burn::module::ModuleDisplayDefault for $type {
            fn content(&self, content: burn::module::Content) -> Option<burn::module::Content> {
                let rendered = format!("{}", self);
                content.add_formatted(&rendered).optional()
            }
        }
    };
}
116
117// General Types
118constant!(alloc::string::String);
119constant!(bool);
120
121// Float Types
122constant!(f64);
123constant!(f32);
124constant!(half::bf16);
125constant!(half::f16);
126
127// Unsigned Integer Types
128constant!(usize);
129constant!(u64);
130constant!(u32);
131constant!(u16);
132constant!(u8);
133
134// Signed Integer Types
135constant!(i64);
136constant!(i32);
137constant!(i16);
138constant!(i8);
139
140impl burn::module::ModuleDisplay for str {}
141impl burn::module::ModuleDisplayDefault for str {
142    fn content(&self, content: burn::module::Content) -> Option<burn::module::Content> {
143        content.add_formatted(&self).optional()
144    }
145}
146
147impl<const D: usize, B: Backend, K: BasicOps<B>> Module<B> for Tensor<B, D, K> {
148    type Record = ConstantRecord;
149
150    fn visit<V: ModuleVisitor<B>>(&self, _visitor: &mut V) {}
151
152    fn map<M: ModuleMapper<B>>(self, _mapper: &mut M) -> Self {
153        self
154    }
155
156    fn into_record(self) -> Self::Record {
157        ConstantRecord
158    }
159
160    fn load_record(self, _record: Self::Record) -> Self {
161        self
162    }
163
164    fn to_device(self, device: &B::Device) -> Self {
165        self.to_device(device)
166    }
167
168    fn fork(self, device: &B::Device) -> Self {
169        self.to_device(device)
170    }
171
172    fn collect_devices(&self, mut devices: Devices<B>) -> Devices<B> {
173        let device = self.device();
174
175        if !devices.contains(&device) {
176            devices.push(device)
177        }
178
179        devices
180    }
181}
182
183impl<const D: usize, B: Backend, K: BasicOps<B>> ModuleDisplayDefault for Tensor<B, D, K> {
184    fn content(&self, content: Content) -> Option<Content> {
185        let string = format!("Tensor {{rank: {D}, shape: {:?}}}", self.shape().dims);
186        content.add_single(&string).optional()
187    }
188}
189
190impl<const D: usize, B: Backend, K: BasicOps<B>> ModuleDisplay for Tensor<B, D, K> {}
191
192impl<const D: usize, B: AutodiffBackend, K: BasicAutodiffOps<B>> AutodiffModule<B>
193    for Tensor<B, D, K>
194{
195    type InnerModule = Tensor<B::InnerBackend, D, K::InnerKind>;
196
197    fn valid(&self) -> Self::InnerModule {
198        self.clone().inner()
199    }
200}
201
202impl<B: Backend> Module<B> for PhantomData<B> {
203    type Record = ConstantRecord;
204
205    fn visit<V: ModuleVisitor<B>>(&self, _visitor: &mut V) {
206        // Nothing to do
207    }
208
209    fn map<M: ModuleMapper<B>>(self, _mapper: &mut M) -> Self {
210        self
211    }
212
213    fn load_record(self, _record: Self::Record) -> Self {
214        self
215    }
216
217    fn into_record(self) -> Self::Record {
218        ConstantRecord::new()
219    }
220
221    fn to_device(self, _: &Device<B>) -> Self {
222        self
223    }
224
225    fn fork(self, _: &Device<B>) -> Self {
226        self
227    }
228
229    fn collect_devices(&self, devices: Devices<B>) -> Devices<B> {
230        devices
231    }
232}
233
234impl<B: Backend> ModuleDisplayDefault for PhantomData<B> {
235    fn content(&self, content: Content) -> Option<Content> {
236        content.add_single(&"PhantomData".to_string()).optional()
237    }
238}
239
240impl<B: Backend> ModuleDisplay for PhantomData<B> {}
241
242impl<B: AutodiffBackend> AutodiffModule<B> for PhantomData<B> {
243    type InnerModule = PhantomData<B::InnerBackend>;
244
245    fn valid(&self) -> Self::InnerModule {
246        PhantomData
247    }
248}
249
/// Wrapper that satisfies the `Module` trait on behalf of a type that is not a module.
///
/// The wrapped value is the single public tuple field.
#[derive(Debug, Clone)]
pub struct Ignored<T>(pub T);
253
254impl<B, T> Module<B> for Ignored<T>
255where
256    B: Backend,
257    T: Sync + Send + core::fmt::Debug + Clone,
258{
259    type Record = ConstantRecord;
260
261    fn visit<V: ModuleVisitor<B>>(&self, _visitor: &mut V) {
262        // Nothing to do
263    }
264
265    fn map<M: ModuleMapper<B>>(self, _mapper: &mut M) -> Self {
266        self
267    }
268
269    fn load_record(self, _record: Self::Record) -> Self {
270        self
271    }
272
273    fn into_record(self) -> Self::Record {
274        ConstantRecord::new()
275    }
276
277    fn to_device(self, _: &Device<B>) -> Self {
278        self
279    }
280
281    fn fork(self, _: &Device<B>) -> Self {
282        self
283    }
284
285    fn collect_devices(&self, devices: Devices<B>) -> Devices<B> {
286        devices
287    }
288}
289
290impl<T> ModuleDisplayDefault for Ignored<T>
291where
292    T: Sync + Send + core::fmt::Debug + Clone,
293{
294    fn content(&self, content: Content) -> Option<Content> {
295        // For now, just print the debug representation of the ignored value
296        content.add_single(&format!("{:?}", self.0)).optional()
297    }
298}
299
300impl<T> ModuleDisplay for Ignored<T> where T: Sync + Send + core::fmt::Debug + Clone {}
301
302impl<T> Display for Ignored<T>
303where
304    T: Sync + Send + core::fmt::Debug + Clone,
305{
306    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
307        write!(f, "{:?}", self.0)
308    }
309}
310
311impl<B: AutodiffBackend, T> AutodiffModule<B> for Ignored<T>
312where
313    B: AutodiffBackend,
314    T: Sync + Send + core::fmt::Debug + Clone,
315{
316    type InnerModule = Ignored<T>;
317
318    fn valid(&self) -> Self::InnerModule {
319        self.clone()
320    }
321}
322
323// Implement deref for Ignored
324impl<T> core::ops::Deref for Ignored<T> {
325    type Target = T;
326
327    fn deref(&self) -> &Self::Target {
328        &self.0
329    }
330}
331
#[cfg(all(test, feature = "std"))]
mod tests {
    use core::marker::PhantomData;

    use burn_tensor::backend::Backend;
    use burn_tensor::{Device, Tensor};

    use crate as burn;
    use crate::record::{BinBytesRecorder, FullPrecisionSettings, Recorder};
    use crate::{TestAutodiffBackend, TestBackend};
    use burn::module::Module;

    /// After loading a record, the tensor does not require gradients, both when `no_grad`
    /// was called beforehand and when it was not.
    #[test]
    fn tensor_load_record_setting() {
        let device: &Device<TestAutodiffBackend> = &Default::default();
        let ones = Tensor::<TestAutodiffBackend, 2>::ones([3, 3], device);

        // Round-trip the tensor's record through an in-memory byte recorder.
        let recorder = BinBytesRecorder::<FullPrecisionSettings>::default();
        let serialized =
            Recorder::<TestAutodiffBackend>::record(&recorder, ones.clone().into_record(), ())
                .unwrap();

        let detached = ones.clone().no_grad().load_record(
            Recorder::<TestAutodiffBackend>::load(&recorder, serialized.clone(), device).unwrap(),
        );
        let no_grad_is_require_grad = detached.is_require_grad();

        let untouched = ones.load_record(
            Recorder::<TestAutodiffBackend>::load(&recorder, serialized.clone(), device).unwrap(),
        );
        let with_default_is_require_grad = untouched.is_require_grad();

        assert!(!no_grad_is_require_grad);
        assert!(!with_default_is_require_grad);
    }

    /// A derived module whose only field is `PhantomData` can be built and has size zero.
    #[test]
    fn empty_module_with_phantom() {
        #[derive(Module, Debug, new)]
        struct EmptyModule<B: Backend> {
            _phantom: PhantomData<B>,
        }

        let _module = EmptyModule::<TestBackend>::new();

        let size = core::mem::size_of::<EmptyModule<TestBackend>>();
        assert_eq!(size, 0);
    }
}