1use alloc::{format, string::ToString};
2use core::{fmt::Display, marker::PhantomData};
3
4use crate::{
5 self as burn,
6 module::{
7 AutodiffModule, Content, Devices, Module, ModuleDisplay, ModuleDisplayDefault,
8 ModuleMapper, ModuleVisitor,
9 },
10 record::Record,
11};
12use burn::record::PrecisionSettings;
13use burn_tensor::{
14 BasicAutodiffOps, BasicOps, Tensor,
15 backend::{AutodiffBackend, Backend},
16 ops::Device,
17};
18
19#[derive(Debug, Clone, Copy, new, Default)]
21pub struct ConstantRecord;
22
23impl serde::Serialize for ConstantRecord {
24 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
25 where
26 S: serde::Serializer,
27 {
28 S::serialize_none(serializer)
30 }
31}
32
33impl<'de> serde::Deserialize<'de> for ConstantRecord {
34 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
35 where
36 D: serde::Deserializer<'de>,
37 {
38 deserializer.deserialize_option(serde::de::IgnoredAny).ok();
39 Ok(ConstantRecord::new())
40 }
41}
42
43impl<B: Backend> Record<B> for ConstantRecord {
44 type Item<S: PrecisionSettings> = ConstantRecord;
45
46 fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {
47 self
48 }
49
50 fn from_item<S: PrecisionSettings>(item: Self::Item<S>, _device: &B::Device) -> Self {
51 item
52 }
53}
/// Implements `Module`, `AutodiffModule`, `ModuleDisplayDefault` and `ModuleDisplay` for a
/// type that behaves as a constant inside a module. Such a type:
/// - exposes nothing to visitors or mappers;
/// - records as an empty [`ConstantRecord`](crate::module::ConstantRecord);
/// - is left unchanged by `to_device` and `fork`.
///
/// Usage: `constant!(MyType);`. The type must implement `Clone` (used by `valid`) and
/// `Display` (used to render its content).
///
/// The `module` and `ad_module` arms are internal helpers. They expand to the bodies of the
/// `Module` and `AutodiffModule` impls respectively.
///
/// Crate items, including the recursive calls to this macro, are reached through `$crate`.
/// This lets the macro expand both when invoked by a qualified path (e.g.
/// `burn::constant!(T)`) and at call sites that have no `burn` alias in scope. `format!`
/// must still be in scope where the macro is invoked; it is in the std prelude.
#[macro_export]
macro_rules! constant {
    (module) => {
        type Record = $crate::module::ConstantRecord;

        fn visit<V: $crate::module::ModuleVisitor<B>>(&self, _visitor: &mut V) {
            // A constant holds no parameters, so there is nothing to visit.
        }

        fn map<M: $crate::module::ModuleMapper<B>>(self, _mapper: &mut M) -> Self {
            self
        }

        fn load_record(self, _record: Self::Record) -> Self {
            self
        }

        fn into_record(self) -> Self::Record {
            $crate::module::ConstantRecord::new()
        }

        fn to_device(self, _: &B::Device) -> Self {
            self
        }

        fn fork(self, _: &B::Device) -> Self {
            self
        }

        fn collect_devices(
            &self,
            devices: $crate::module::Devices<B>,
        ) -> $crate::module::Devices<B> {
            devices
        }
    };

    (ad_module, $type:ty) => {
        type InnerModule = $type;

        fn valid(&self) -> Self::InnerModule {
            self.clone()
        }
    };

    ($type:ty) => {
        impl<B: $crate::tensor::backend::Backend> $crate::module::Module<B> for $type {
            $crate::constant!(module);
        }

        impl<B: $crate::tensor::backend::AutodiffBackend> $crate::module::AutodiffModule<B>
            for $type
        {
            $crate::constant!(ad_module, $type);
        }

        impl $crate::module::ModuleDisplayDefault for $type {
            fn content(
                &self,
                content: $crate::module::Content,
            ) -> Option<$crate::module::Content> {
                let string = format!("{}", self);
                content.add_formatted(&string).optional()
            }
        }

        impl $crate::module::ModuleDisplay for $type {}
    };
}
116
117constant!(alloc::string::String);
119constant!(bool);
120
121constant!(f64);
123constant!(f32);
124constant!(half::bf16);
125constant!(half::f16);
126
127constant!(usize);
129constant!(u64);
130constant!(u32);
131constant!(u16);
132constant!(u8);
133
134constant!(isize);
136constant!(i64);
137constant!(i32);
138constant!(i16);
139constant!(i8);
140
141impl burn::module::ModuleDisplay for str {}
142impl burn::module::ModuleDisplayDefault for str {
143 fn content(&self, content: burn::module::Content) -> Option<burn::module::Content> {
144 content.add_formatted(&self).optional()
145 }
146}
147
148impl<const D: usize, B: Backend, K: BasicOps<B>> Module<B> for Tensor<B, D, K> {
149 type Record = ConstantRecord;
150
151 fn visit<V: ModuleVisitor<B>>(&self, _visitor: &mut V) {}
152
153 fn map<M: ModuleMapper<B>>(self, _mapper: &mut M) -> Self {
154 self
155 }
156
157 fn into_record(self) -> Self::Record {
158 ConstantRecord
159 }
160
161 fn load_record(self, _record: Self::Record) -> Self {
162 self
163 }
164
165 fn to_device(self, device: &B::Device) -> Self {
166 self.to_device(device)
167 }
168
169 fn fork(self, device: &B::Device) -> Self {
170 self.to_device(device)
171 }
172
173 fn collect_devices(&self, mut devices: Devices<B>) -> Devices<B> {
174 let device = self.device();
175
176 if !devices.contains(&device) {
177 devices.push(device)
178 }
179
180 devices
181 }
182}
183
184impl<const D: usize, B: Backend, K: BasicOps<B>> ModuleDisplayDefault for Tensor<B, D, K> {
185 fn content(&self, content: Content) -> Option<Content> {
186 let string = format!("Tensor {{rank: {D}, shape: {:?}}}", self.shape().dims);
187 content.add_single(&string).optional()
188 }
189}
190
191impl<const D: usize, B: Backend, K: BasicOps<B>> ModuleDisplay for Tensor<B, D, K> {}
192
193impl<const D: usize, B: AutodiffBackend, K: BasicAutodiffOps<B>> AutodiffModule<B>
194 for Tensor<B, D, K>
195{
196 type InnerModule = Tensor<B::InnerBackend, D, K::InnerKind>;
197
198 fn valid(&self) -> Self::InnerModule {
199 self.clone().inner()
200 }
201}
202
203impl<B: Backend> Module<B> for PhantomData<B> {
204 type Record = ConstantRecord;
205
206 fn visit<V: ModuleVisitor<B>>(&self, _visitor: &mut V) {
207 }
209
210 fn map<M: ModuleMapper<B>>(self, _mapper: &mut M) -> Self {
211 self
212 }
213
214 fn load_record(self, _record: Self::Record) -> Self {
215 self
216 }
217
218 fn into_record(self) -> Self::Record {
219 ConstantRecord::new()
220 }
221
222 fn to_device(self, _: &Device<B>) -> Self {
223 self
224 }
225
226 fn fork(self, _: &Device<B>) -> Self {
227 self
228 }
229
230 fn collect_devices(&self, devices: Devices<B>) -> Devices<B> {
231 devices
232 }
233}
234
235impl<B: Backend> ModuleDisplayDefault for PhantomData<B> {
236 fn content(&self, content: Content) -> Option<Content> {
237 content.add_single(&"PhantomData".to_string()).optional()
238 }
239}
240
241impl<B: Backend> ModuleDisplay for PhantomData<B> {}
242
243impl<B: AutodiffBackend> AutodiffModule<B> for PhantomData<B> {
244 type InnerModule = PhantomData<B::InnerBackend>;
245
246 fn valid(&self) -> Self::InnerModule {
247 PhantomData
248 }
249}
250
/// Wraps a value that the module system should leave alone.
///
/// The wrapped value is not exposed as a parameter, is not saved in records, and is not
/// moved between devices. The inner value is reachable through the public field `.0`.
#[derive(Debug, Clone)]
pub struct Ignored<T>(pub T);
254
255impl<B, T> Module<B> for Ignored<T>
256where
257 B: Backend,
258 T: Sync + Send + core::fmt::Debug + Clone,
259{
260 type Record = ConstantRecord;
261
262 fn visit<V: ModuleVisitor<B>>(&self, _visitor: &mut V) {
263 }
265
266 fn map<M: ModuleMapper<B>>(self, _mapper: &mut M) -> Self {
267 self
268 }
269
270 fn load_record(self, _record: Self::Record) -> Self {
271 self
272 }
273
274 fn into_record(self) -> Self::Record {
275 ConstantRecord::new()
276 }
277
278 fn to_device(self, _: &Device<B>) -> Self {
279 self
280 }
281
282 fn fork(self, _: &Device<B>) -> Self {
283 self
284 }
285
286 fn collect_devices(&self, devices: Devices<B>) -> Devices<B> {
287 devices
288 }
289}
290
291impl<T> ModuleDisplayDefault for Ignored<T>
292where
293 T: Sync + Send + core::fmt::Debug + Clone,
294{
295 fn content(&self, content: Content) -> Option<Content> {
296 content.add_single(&format!("{:?}", self.0)).optional()
298 }
299}
300
301impl<T> ModuleDisplay for Ignored<T> where T: Sync + Send + core::fmt::Debug + Clone {}
302
303impl<T> Display for Ignored<T>
304where
305 T: Sync + Send + core::fmt::Debug + Clone,
306{
307 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
308 write!(f, "{:?}", self.0)
309 }
310}
311
312impl<B: AutodiffBackend, T> AutodiffModule<B> for Ignored<T>
313where
314 B: AutodiffBackend,
315 T: Sync + Send + core::fmt::Debug + Clone,
316{
317 type InnerModule = Ignored<T>;
318
319 fn valid(&self) -> Self::InnerModule {
320 self.clone()
321 }
322}
323
324impl<T> core::ops::Deref for Ignored<T> {
326 type Target = T;
327
328 fn deref(&self) -> &Self::Target {
329 &self.0
330 }
331}
332
#[cfg(all(test, feature = "std"))]
mod tests {
    use core::marker::PhantomData;

    use burn_tensor::backend::Backend;
    use burn_tensor::{Device, Tensor};

    use crate::TestBackend;
    use crate::{
        TestAutodiffBackend,
        record::{BinBytesRecorder, FullPrecisionSettings, Recorder},
    };
    use burn::module::Module;

    use crate as burn;

    /// Round-trips a tensor record through bytes and checks that loading it does not turn
    /// on gradient tracking. Both a `no_grad` tensor and an untouched tensor are checked.
    #[test]
    fn tensor_load_record_setting() {
        let device: &Device<TestAutodiffBackend> = &Default::default();
        let original = Tensor::<TestAutodiffBackend, 2>::ones([3, 3], device);

        let recorder = BinBytesRecorder::<FullPrecisionSettings>::default();
        let encoded = Recorder::<TestAutodiffBackend>::record(
            &recorder,
            original.clone().into_record(),
            (),
        )
        .unwrap();

        // Loading consumes its input, so every call decodes from a fresh copy of the bytes.
        let decode = || {
            Recorder::<TestAutodiffBackend>::load(&recorder, encoded.clone(), device).unwrap()
        };

        let detached_requires_grad = original
            .clone()
            .no_grad()
            .load_record(decode())
            .is_require_grad();
        let untouched_requires_grad = original.load_record(decode()).is_require_grad();

        assert!(!detached_requires_grad);
        assert!(!untouched_requires_grad);
    }

    /// A module whose only field is a backend marker must be zero-sized.
    #[test]
    fn empty_module_with_phantom() {
        #[derive(Module, Debug, new)]
        struct EmptyModule<B: Backend> {
            _phantom: PhantomData<B>,
        }

        let _module = EmptyModule::<TestBackend>::new();

        let size = core::mem::size_of::<EmptyModule<TestBackend>>();
        assert_eq!(size, 0);
    }
}