1use alloc::{format, string::ToString};
2use core::{fmt::Display, marker::PhantomData};
3
4use crate as burn;
5use crate::{
6 module::{
7 AutodiffModule, Content, Devices, Module, ModuleDisplay, ModuleDisplayDefault,
8 ModuleMapper, ModuleVisitor,
9 },
10 record::{PrecisionSettings, Record},
11};
12use burn_tensor::{
13 BasicAutodiffOps, BasicOps, Tensor,
14 backend::{AutodiffBackend, Backend},
15 ops::Device,
16};
17
18#[derive(Debug, Clone, Copy, new, Default)]
20pub struct ConstantRecord;
21
22impl serde::Serialize for ConstantRecord {
23 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
24 where
25 S: serde::Serializer,
26 {
27 S::serialize_none(serializer)
29 }
30}
31
32impl<'de> serde::Deserialize<'de> for ConstantRecord {
33 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
34 where
35 D: serde::Deserializer<'de>,
36 {
37 deserializer.deserialize_option(serde::de::IgnoredAny).ok();
38 Ok(ConstantRecord::new())
39 }
40}
41
42impl<B: Backend> Record<B> for ConstantRecord {
43 type Item<S: PrecisionSettings> = ConstantRecord;
44
45 fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {
46 self
47 }
48
49 fn from_item<S: PrecisionSettings>(item: Self::Item<S>, _device: &B::Device) -> Self {
50 item
51 }
52}
/// Implements `Module`, `AutodiffModule`, `ModuleDisplayDefault` and `ModuleDisplay` for a
/// type that should behave as a constant inside a module.
///
/// Usage: `constant!(MyType);`. The generated `Module` impl has an empty record
/// (`ConstantRecord`), visits and maps nothing, ignores loaded records, and leaves the value
/// untouched on `to_device`/`fork`. `AutodiffModule::valid` returns a clone of the value, and
/// the display impls render it through `Display`.
///
/// Requirements on the type: `Clone` (used by `valid`) and `Display` (used for the module
/// display), plus whatever bounds the `Module` trait itself places on implementors. The
/// expansion calls `format!`, so it must be in scope at the call site (e.g. via
/// `use alloc::format;` in `no_std` crates).
///
/// The `module` and `ad_module` arms are internal helpers used by the main arm.
#[macro_export]
macro_rules! constant {
    // Internal arm: the `Module` items shared by every constant type.
    (module) => {
        type Record = $crate::module::ConstantRecord;

        fn visit<V: $crate::module::ModuleVisitor<B>>(&self, _visitor: &mut V) {
            // Nothing to visit: constants hold no parameters.
        }

        fn map<M: $crate::module::ModuleMapper<B>>(self, _mapper: &mut M) -> Self {
            self
        }

        fn load_record(self, _record: Self::Record) -> Self {
            self
        }

        fn into_record(self) -> Self::Record {
            $crate::module::ConstantRecord::new()
        }

        fn to_device(self, _: &B::Device) -> Self {
            self
        }

        fn fork(self, _: &B::Device) -> Self {
            self
        }

        fn collect_devices(
            &self,
            devices: $crate::module::Devices<B>,
        ) -> $crate::module::Devices<B> {
            devices
        }
    };

    // Internal arm: the `AutodiffModule` items; the constant is its own inner module.
    (ad_module, $type:ty) => {
        type InnerModule = $type;

        fn valid(&self) -> Self::InnerModule {
            self.clone()
        }
    };

    // Public arm: implement every module trait for `$type`.
    ($type:ty) => {
        impl<B: $crate::tensor::backend::Backend> $crate::module::Module<B> for $type {
            // `$crate::` keeps the recursive call resolvable from other crates.
            $crate::constant!(module);
        }

        impl<B: $crate::tensor::backend::AutodiffBackend> $crate::module::AutodiffModule<B>
            for $type
        {
            $crate::constant!(ad_module, $type);
        }

        impl $crate::module::ModuleDisplayDefault for $type {
            fn content(
                &self,
                content: $crate::module::Content,
            ) -> Option<$crate::module::Content> {
                let string = format!("{}", self);
                content.add_formatted(&string).optional()
            }
        }

        impl $crate::module::ModuleDisplay for $type {}
    };
}
115
116constant!(alloc::string::String);
118constant!(bool);
119
120constant!(f64);
122constant!(f32);
123constant!(half::bf16);
124constant!(half::f16);
125
126constant!(usize);
128constant!(u64);
129constant!(u32);
130constant!(u16);
131constant!(u8);
132
133constant!(isize);
135constant!(i64);
136constant!(i32);
137constant!(i16);
138constant!(i8);
139
140impl burn::module::ModuleDisplay for str {}
141impl burn::module::ModuleDisplayDefault for str {
142 fn content(&self, content: burn::module::Content) -> Option<burn::module::Content> {
143 content.add_formatted(&self).optional()
144 }
145}
146
147impl<const D: usize, B: Backend, K: BasicOps<B>> Module<B> for Tensor<B, D, K> {
148 type Record = ConstantRecord;
149
150 fn visit<V: ModuleVisitor<B>>(&self, _visitor: &mut V) {}
151
152 fn map<M: ModuleMapper<B>>(self, _mapper: &mut M) -> Self {
153 self
154 }
155
156 fn into_record(self) -> Self::Record {
157 ConstantRecord
158 }
159
160 fn load_record(self, _record: Self::Record) -> Self {
161 self
162 }
163
164 fn to_device(self, device: &B::Device) -> Self {
165 self.to_device(device)
166 }
167
168 fn fork(self, device: &B::Device) -> Self {
169 self.to_device(device)
170 }
171
172 fn collect_devices(&self, mut devices: Devices<B>) -> Devices<B> {
173 let device = self.device();
174
175 if !devices.contains(&device) {
176 devices.push(device)
177 }
178
179 devices
180 }
181}
182
183impl<const D: usize, B: Backend, K: BasicOps<B>> ModuleDisplayDefault for Tensor<B, D, K> {
184 fn content(&self, content: Content) -> Option<Content> {
185 let string = format!("Tensor {{rank: {D}, shape: {:?}}}", self.shape().dims);
186 content.add_single(&string).optional()
187 }
188}
189
190impl<const D: usize, B: Backend, K: BasicOps<B>> ModuleDisplay for Tensor<B, D, K> {}
191
192impl<const D: usize, B: AutodiffBackend, K: BasicAutodiffOps<B>> AutodiffModule<B>
193 for Tensor<B, D, K>
194{
195 type InnerModule = Tensor<B::InnerBackend, D, K::InnerKind>;
196
197 fn valid(&self) -> Self::InnerModule {
198 self.clone().inner()
199 }
200}
201
202impl<B: Backend> Module<B> for PhantomData<B> {
203 type Record = ConstantRecord;
204
205 fn visit<V: ModuleVisitor<B>>(&self, _visitor: &mut V) {
206 }
208
209 fn map<M: ModuleMapper<B>>(self, _mapper: &mut M) -> Self {
210 self
211 }
212
213 fn load_record(self, _record: Self::Record) -> Self {
214 self
215 }
216
217 fn into_record(self) -> Self::Record {
218 ConstantRecord::new()
219 }
220
221 fn to_device(self, _: &Device<B>) -> Self {
222 self
223 }
224
225 fn fork(self, _: &Device<B>) -> Self {
226 self
227 }
228
229 fn collect_devices(&self, devices: Devices<B>) -> Devices<B> {
230 devices
231 }
232}
233
234impl<B: Backend> ModuleDisplayDefault for PhantomData<B> {
235 fn content(&self, content: Content) -> Option<Content> {
236 content.add_single(&"PhantomData".to_string()).optional()
237 }
238}
239
240impl<B: Backend> ModuleDisplay for PhantomData<B> {}
241
242impl<B: AutodiffBackend> AutodiffModule<B> for PhantomData<B> {
243 type InnerModule = PhantomData<B::InnerBackend>;
244
245 fn valid(&self) -> Self::InnerModule {
246 PhantomData
247 }
248}
249
/// Wraps a value that a module should carry but never treat as part of its state.
///
/// As a module, `Ignored` has an empty record ([`ConstantRecord`]). Visitors and mappers skip
/// it, loading a record leaves it unchanged, and moving or forking to another device is a
/// no-op. The wrapped value is reachable through the public field or by dereferencing.
#[derive(Clone, Debug)]
pub struct Ignored<T>(pub T);
253
254impl<B, T> Module<B> for Ignored<T>
255where
256 B: Backend,
257 T: Sync + Send + core::fmt::Debug + Clone,
258{
259 type Record = ConstantRecord;
260
261 fn visit<V: ModuleVisitor<B>>(&self, _visitor: &mut V) {
262 }
264
265 fn map<M: ModuleMapper<B>>(self, _mapper: &mut M) -> Self {
266 self
267 }
268
269 fn load_record(self, _record: Self::Record) -> Self {
270 self
271 }
272
273 fn into_record(self) -> Self::Record {
274 ConstantRecord::new()
275 }
276
277 fn to_device(self, _: &Device<B>) -> Self {
278 self
279 }
280
281 fn fork(self, _: &Device<B>) -> Self {
282 self
283 }
284
285 fn collect_devices(&self, devices: Devices<B>) -> Devices<B> {
286 devices
287 }
288}
289
290impl<T> ModuleDisplayDefault for Ignored<T>
291where
292 T: Sync + Send + core::fmt::Debug + Clone,
293{
294 fn content(&self, content: Content) -> Option<Content> {
295 content.add_single(&format!("{:?}", self.0)).optional()
297 }
298}
299
300impl<T> ModuleDisplay for Ignored<T> where T: Sync + Send + core::fmt::Debug + Clone {}
301
302impl<T> Display for Ignored<T>
303where
304 T: Sync + Send + core::fmt::Debug + Clone,
305{
306 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
307 write!(f, "{:?}", self.0)
308 }
309}
310
311impl<B: AutodiffBackend, T> AutodiffModule<B> for Ignored<T>
312where
313 B: AutodiffBackend,
314 T: Sync + Send + core::fmt::Debug + Clone,
315{
316 type InnerModule = Ignored<T>;
317
318 fn valid(&self) -> Self::InnerModule {
319 self.clone()
320 }
321}
322
323impl<T> core::ops::Deref for Ignored<T> {
325 type Target = T;
326
327 fn deref(&self) -> &Self::Target {
328 &self.0
329 }
330}
331
#[cfg(all(test, feature = "std"))]
mod tests {
    use core::marker::PhantomData;

    use burn_tensor::{Device, Tensor, backend::Backend};

    use crate as burn;
    use crate::{
        TestAutodiffBackend, TestBackend,
        record::{BinBytesRecorder, FullPrecisionSettings, Recorder},
    };
    use burn::module::Module;

    /// Round-trips a tensor's record through a byte recorder and checks that loading it
    /// leaves the tensor without `require_grad`, both after `no_grad()` and without it.
    #[test]
    fn tensor_load_record_setting() {
        let device: &Device<TestAutodiffBackend> = &Default::default();
        let tensor = Tensor::<TestAutodiffBackend, 2>::ones([3, 3], device);

        let recorder = BinBytesRecorder::<FullPrecisionSettings>::default();
        let record = tensor.clone().into_record();
        let bytes = Recorder::<TestAutodiffBackend>::record(&recorder, record, ()).unwrap();

        let detached = tensor.clone().no_grad();
        let reloaded_detached = detached.load_record(
            Recorder::<TestAutodiffBackend>::load(&recorder, bytes.clone(), device).unwrap(),
        );
        assert!(!reloaded_detached.is_require_grad());

        let reloaded_default = tensor.load_record(
            Recorder::<TestAutodiffBackend>::load(&recorder, bytes.clone(), device).unwrap(),
        );
        assert!(!reloaded_default.is_require_grad());
    }

    /// A derived module whose only field is `PhantomData` must itself be zero-sized.
    #[test]
    fn empty_module_with_phantom() {
        #[derive(Module, Debug, new)]
        struct EmptyModule<B: Backend> {
            _phantom: PhantomData<B>,
        }

        let _module = EmptyModule::<TestBackend>::new();

        let size = core::mem::size_of::<EmptyModule<TestBackend>>();
        assert_eq!(size, 0);
    }
}