1use alloc::{format, string::ToString};
2use core::{fmt::Display, marker::PhantomData};
3
4use crate as burn;
5use crate::{
6 module::{
7 AutodiffModule, Content, Devices, Module, ModuleDisplay, ModuleDisplayDefault,
8 ModuleMapper, ModuleVisitor,
9 },
10 record::{PrecisionSettings, Record},
11};
12use burn_tensor::{
13 BasicAutodiffOps, BasicOps, Tensor,
14 backend::{AutodiffBackend, Backend},
15 ops::Device,
16};
17
/// Deprecated alias of [`EmptyRecord`], kept so existing code keeps compiling.
///
/// The old name suggested that some constant value is persisted, but the record
/// carries no data at all.
#[deprecated(
    since = "0.21.0",
    note = "ConstantRecord is misleading as it doesn't persist data. Use EmptyRecord instead."
)]
pub type ConstantRecord = EmptyRecord;

/// Record for modules (or module fields) that have nothing to save.
///
/// It serializes as an absent value (`none`). When deserializing, it skips whatever
/// value is present (errors included) and always yields `EmptyRecord`. In this
/// file it is the `Module::Record` of plain tensors, `PhantomData<B>`, the types
/// covered by `empty!`, and the deprecated `Ignored<T>` wrapper.
#[derive(Debug, Clone, Copy, new, Default, PartialEq, Eq)]
pub struct EmptyRecord;
34
35impl serde::Serialize for EmptyRecord {
36 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
37 where
38 S: serde::Serializer,
39 {
40 S::serialize_none(serializer)
42 }
43}
44
45impl<'de> serde::Deserialize<'de> for EmptyRecord {
46 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
47 where
48 D: serde::Deserializer<'de>,
49 {
50 deserializer.deserialize_option(serde::de::IgnoredAny).ok();
51 Ok(EmptyRecord::new())
52 }
53}
54
55impl<B: Backend> Record<B> for EmptyRecord {
56 type Item<S: PrecisionSettings> = EmptyRecord;
57
58 fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {
59 self
60 }
61
62 fn from_item<S: PrecisionSettings>(item: Self::Item<S>, _device: &B::Device) -> Self {
63 item
64 }
65}
/// Implements the module traits for a type that holds no parameters.
///
/// `empty!(SomeType)` generates, for `SomeType`:
/// - `Module<B>` for every `B: Backend`. The record is `EmptyRecord`, visiting and
///   mapping do nothing, and device moves return the value unchanged.
/// - `AutodiffModule<B>` for every `B: AutodiffBackend`, with `SomeType` as its own
///   inner module (`valid` clones the value).
/// - `ModuleDisplayDefault` and `ModuleDisplay`, displaying the value via `Display`.
///
/// The type must therefore be `Clone` and `Display` (plus whatever `Module` itself
/// requires). The display impl calls `format!`, which must be in scope where the
/// macro is invoked. That is always true with `std`; `no_std` callers need
/// `alloc::format` imported.
///
/// The `module` and `ad_module` arms are helpers used by the `$type` arm and expand
/// to lists of trait items.
///
/// All paths go through `$crate`, including the recursive helper calls. The
/// expansion therefore works from any crate, without a `burn` alias or `empty` in
/// scope.
#[macro_export]
macro_rules! empty {
    // Helper arm: the `Module<B>` items for an inert type. Relies on `B` being the
    // backend type parameter of the surrounding `impl`.
    (module) => {
        type Record = $crate::module::EmptyRecord;

        fn visit<V: $crate::module::ModuleVisitor<B>>(&self, _visitor: &mut V) {
            // Nothing to visit: the type holds no parameters.
        }

        fn map<M: $crate::module::ModuleMapper<B>>(self, _mapper: &mut M) -> Self {
            self
        }

        fn load_record(self, _record: Self::Record) -> Self {
            self
        }

        fn into_record(self) -> Self::Record {
            $crate::module::EmptyRecord::new()
        }

        fn to_device(self, _: &B::Device) -> Self {
            self
        }

        fn fork(self, _: &B::Device) -> Self {
            self
        }

        fn collect_devices(
            &self,
            devices: $crate::module::Devices<B>,
        ) -> $crate::module::Devices<B> {
            devices
        }
    };

    // Helper arm: the `AutodiffModule<B>` items; the type is its own inner module.
    (ad_module, $type:ty) => {
        type InnerModule = $type;

        fn valid(&self) -> Self::InnerModule {
            self.clone()
        }

        fn from_inner(module: Self::InnerModule) -> Self {
            module
        }
    };

    ($type:ty) => {
        impl<B: $crate::tensor::backend::Backend> $crate::module::Module<B> for $type {
            $crate::empty!(module);
        }

        impl<B: $crate::tensor::backend::AutodiffBackend> $crate::module::AutodiffModule<B>
            for $type
        {
            $crate::empty!(ad_module, $type);
        }

        impl $crate::module::ModuleDisplayDefault for $type {
            fn content(
                &self,
                content: $crate::module::Content,
            ) -> Option<$crate::module::Content> {
                let string = format!("{}", self);
                content.add_formatted(&string).optional()
            }
        }

        impl $crate::module::ModuleDisplay for $type {}
    };
}
132
133empty!(alloc::string::String);
137empty!(bool);
138
139empty!(f64);
141empty!(f32);
142empty!(half::bf16);
143empty!(half::f16);
144
145empty!(usize);
147empty!(u64);
148empty!(u32);
149empty!(u16);
150empty!(u8);
151
152empty!(isize);
154empty!(i64);
155empty!(i32);
156empty!(i16);
157empty!(i8);
158
159impl burn::module::ModuleDisplay for str {}
160impl burn::module::ModuleDisplayDefault for str {
161 fn content(&self, content: burn::module::Content) -> Option<burn::module::Content> {
162 content.add_formatted(&self).optional()
163 }
164}
165
166impl<const D: usize, B: Backend, K: BasicOps<B>> Module<B> for Tensor<B, D, K> {
168 type Record = EmptyRecord;
169
170 fn visit<V: ModuleVisitor<B>>(&self, _visitor: &mut V) {}
171
172 fn map<M: ModuleMapper<B>>(self, _mapper: &mut M) -> Self {
173 self
174 }
175
176 fn into_record(self) -> Self::Record {
177 EmptyRecord
178 }
179
180 fn load_record(self, _record: Self::Record) -> Self {
181 self
182 }
183
184 fn to_device(self, device: &B::Device) -> Self {
185 self.to_device(device)
186 }
187
188 fn fork(self, device: &B::Device) -> Self {
189 self.to_device(device)
190 }
191
192 fn collect_devices(&self, mut devices: Devices<B>) -> Devices<B> {
193 let device = self.device();
194
195 if !devices.contains(&device) {
196 devices.push(device)
197 }
198
199 devices
200 }
201}
202
203impl<const D: usize, B: Backend, K: BasicOps<B>> ModuleDisplayDefault for Tensor<B, D, K> {
204 fn content(&self, content: Content) -> Option<Content> {
205 let string = format!("Tensor {{rank: {D}, shape: {:?}}}", self.shape().as_slice());
206 content.add_single(&string).optional()
207 }
208}
209
210impl<const D: usize, B: Backend, K: BasicOps<B>> ModuleDisplay for Tensor<B, D, K> {}
211
212impl<const D: usize, B: AutodiffBackend, K: BasicAutodiffOps<B>> AutodiffModule<B>
213 for Tensor<B, D, K>
214{
215 type InnerModule = Tensor<B::InnerBackend, D, K::InnerKind>;
216
217 fn valid(&self) -> Self::InnerModule {
218 self.clone().inner()
219 }
220
221 fn from_inner(tensor: Self::InnerModule) -> Self {
222 Tensor::from_inner(tensor)
223 }
224}
225
226impl<B: Backend> Module<B> for PhantomData<B> {
227 type Record = EmptyRecord;
228
229 fn visit<V: ModuleVisitor<B>>(&self, _visitor: &mut V) {
230 }
232
233 fn map<M: ModuleMapper<B>>(self, _mapper: &mut M) -> Self {
234 self
235 }
236
237 fn load_record(self, _record: Self::Record) -> Self {
238 self
239 }
240
241 fn into_record(self) -> Self::Record {
242 EmptyRecord::new()
243 }
244
245 fn to_device(self, _: &Device<B>) -> Self {
246 self
247 }
248
249 fn fork(self, _: &Device<B>) -> Self {
250 self
251 }
252
253 fn collect_devices(&self, devices: Devices<B>) -> Devices<B> {
254 devices
255 }
256}
257
258impl<B: Backend> ModuleDisplayDefault for PhantomData<B> {
259 fn content(&self, content: Content) -> Option<Content> {
260 content.add_single(&"PhantomData".to_string()).optional()
261 }
262}
263
264impl<B: Backend> ModuleDisplay for PhantomData<B> {}
265
266impl<B: AutodiffBackend> AutodiffModule<B> for PhantomData<B> {
267 type InnerModule = PhantomData<B::InnerBackend>;
268
269 fn valid(&self) -> Self::InnerModule {
270 PhantomData
271 }
272
273 fn from_inner(_module: Self::InnerModule) -> Self {
274 PhantomData
275 }
276}
277
/// Wrapper that keeps a value inside a module without treating it as module state.
///
/// The wrapped value is never visited or mapped and is not moved between devices.
/// Its record is always `EmptyRecord`, so nothing is saved. Both `Display` and the
/// module display show the value's `Debug` representation. Access the value with
/// `.0` or through `Deref`.
#[derive(Clone, Debug)]
#[deprecated(
    since = "0.21.0",
    note = "Ignored<T> is deprecated. Use #[module(skip)] for non-persistent fields (same behavior)."
)]
pub struct Ignored<T>(pub T);
285
286#[allow(deprecated)]
287impl<B, T> Module<B> for Ignored<T>
288where
289 B: Backend,
290 T: Sync + Send + core::fmt::Debug + Clone,
291{
292 type Record = EmptyRecord;
293
294 fn visit<V: ModuleVisitor<B>>(&self, _visitor: &mut V) {
295 }
297
298 fn map<M: ModuleMapper<B>>(self, _mapper: &mut M) -> Self {
299 self
300 }
301
302 fn load_record(self, _record: Self::Record) -> Self {
303 self
304 }
305
306 fn into_record(self) -> Self::Record {
307 EmptyRecord::new()
308 }
309
310 fn to_device(self, _: &Device<B>) -> Self {
311 self
312 }
313
314 fn fork(self, _: &Device<B>) -> Self {
315 self
316 }
317
318 fn collect_devices(&self, devices: Devices<B>) -> Devices<B> {
319 devices
320 }
321}
322
323#[allow(deprecated)]
324impl<T> ModuleDisplayDefault for Ignored<T>
325where
326 T: Sync + Send + core::fmt::Debug + Clone,
327{
328 fn content(&self, content: Content) -> Option<Content> {
329 content.add_single(&format!("{:?}", self.0)).optional()
331 }
332}
333
334#[allow(deprecated)]
335impl<T> ModuleDisplay for Ignored<T> where T: Sync + Send + core::fmt::Debug + Clone {}
336
337#[allow(deprecated)]
338impl<T> Display for Ignored<T>
339where
340 T: Sync + Send + core::fmt::Debug + Clone,
341{
342 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
343 write!(f, "{:?}", self.0)
344 }
345}
346
347#[allow(deprecated)]
348impl<B: AutodiffBackend, T> AutodiffModule<B> for Ignored<T>
349where
350 B: AutodiffBackend,
351 T: Sync + Send + core::fmt::Debug + Clone,
352{
353 type InnerModule = Ignored<T>;
354
355 fn valid(&self) -> Self::InnerModule {
356 self.clone()
357 }
358
359 fn from_inner(module: Self::InnerModule) -> Self {
360 module
361 }
362}
363
364#[allow(deprecated)]
365impl<T> core::ops::Deref for Ignored<T> {
367 type Target = T;
368
369 fn deref(&self) -> &Self::Target {
370 &self.0
371 }
372}
373
#[cfg(all(test, feature = "std"))]
mod tests {
    use core::marker::PhantomData;

    use burn_tensor::backend::Backend;
    use burn_tensor::{Device, Tensor};

    use crate::TestBackend;
    use crate::{
        TestAutodiffBackend,
        record::{BinBytesRecorder, FullPrecisionSettings, Recorder},
    };
    use burn::module::Module;

    use crate as burn;

    /// Loading a tensor's (empty) record must not turn gradient tracking on,
    /// whether or not the target tensor had `no_grad` applied.
    #[test]
    fn tensor_load_record_setting() {
        let device: &Device<TestAutodiffBackend> = &Default::default();
        let tensor = Tensor::<TestAutodiffBackend, 2>::ones([3, 3], device);

        let recorder = BinBytesRecorder::<FullPrecisionSettings>::default();
        let bytes = Recorder::<TestAutodiffBackend>::record(
            &recorder,
            tensor.clone().into_record(),
            (),
        )
        .unwrap();

        let record_for_no_grad =
            Recorder::<TestAutodiffBackend>::load(&recorder, bytes.clone(), device).unwrap();
        let no_grad_is_require_grad = tensor
            .clone()
            .no_grad()
            .load_record(record_for_no_grad)
            .is_require_grad();

        let record_for_default =
            Recorder::<TestAutodiffBackend>::load(&recorder, bytes.clone(), device).unwrap();
        let with_default_is_require_grad =
            tensor.load_record(record_for_default).is_require_grad();

        assert!(!no_grad_is_require_grad);
        assert!(!with_default_is_require_grad);
    }

    /// A module whose only field is a `PhantomData<B>` must be zero-sized.
    #[test]
    fn empty_module_with_phantom() {
        #[derive(Module, Debug, new)]
        struct EmptyModule<B: Backend> {
            _phantom: PhantomData<B>,
        }

        let _module = EmptyModule::<TestBackend>::new();

        let size = core::mem::size_of::<EmptyModule<TestBackend>>();
        assert_eq!(size, 0);
    }
}