use super::{Param, ParamId, Parameter};
use crate::module::{
    AutodiffModule, Content, Module, ModuleDisplay, ModuleDisplayDefault, ModuleMapper,
    ModuleVisitor,
};
use crate::tensor::{
    Tensor,
    backend::{AutodiffBackend, Backend},
};
use alloc::{format, string::ToString, vec::Vec};
use burn_tensor::{Bool, Float, Int, TensorData, ops::Device};

impl<B: Backend, const D: usize> Parameter for Tensor<B, D, Float> {
    type Device = B::Device;

    fn device(&self) -> Self::Device {
        Tensor::device(self)
    }

    fn is_require_grad(&self) -> bool {
        Tensor::is_require_grad(self)
    }

    fn set_require_grad(self, require_grad: bool) -> Self {
        Tensor::set_require_grad(self, require_grad)
    }
}

impl<B: Backend, const D: usize> Parameter for Tensor<B, D, Int> {
    type Device = B::Device;

    fn device(&self) -> Self::Device {
        Tensor::device(self)
    }

    fn is_require_grad(&self) -> bool {
        false
    }

    fn set_require_grad(self, _require_grad: bool) -> Self {
        self
    }
}

impl<B: Backend, const D: usize> Parameter for Tensor<B, D, Bool> {
    type Device = B::Device;

    fn device(&self) -> Self::Device {
        Tensor::device(self)
    }

    fn is_require_grad(&self) -> bool {
        false
    }

    fn set_require_grad(self, _require_grad: bool) -> Self {
        self
    }
}

impl<B: Backend, const D: usize> Param<Tensor<B, D>> {
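    /// Creates a trainable parameter from a float tensor.
    ///
    /// A minimal usage sketch, assuming a backend `B` and a `device` in scope:
    ///
    /// ```rust,ignore
    /// let weight = Tensor::<B, 2>::ones([4, 4], &device);
    /// let param = Param::from_tensor(weight);
    /// assert!(param.is_require_grad());
    /// ```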
    pub fn from_tensor(value: Tensor<B, D>) -> Self {
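        // Float parameters are assumed to be trainable, so gradient tracking is enabled on
        // the value before it is stored.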
        Param::initialized(ParamId::new(), value.require_grad())
    }

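    /// Returns the parameter's shape without forcing initialization.
    ///
    /// For a lazy parameter whose initializer is still pending, the shape recorded at
    /// declaration time is returned; otherwise the shape of the initialized tensor is used.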
    pub fn lazy_shape(&self) -> burn_tensor::Shape {
        let initialization = match &self.initialization {
            Some(init) => init,
            None => return self.shape(),
        };

        let init = initialization.read().unwrap();

        match init.as_ref() {
            Some(value) => value.shape.clone(),
            None => self.shape(),
        }
    }

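    /// Creates a trainable parameter from the given data on the given device.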
    pub fn from_data<T>(data: T, device: &B::Device) -> Self
    where
        T: Into<TensorData>,
    {
        B::memory_persistent_allocations(device, data, |data| {
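            // The tensor is materialized inside the backend's persistent-allocation scope,
            // and gradient tracking is enabled just like in `from_tensor`.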
            let value = Tensor::from_data(data, device);
            Param::initialized(ParamId::new(), value.require_grad())
        })
    }

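    /// Builds the parameter from a freshly loaded tensor, preserving this parameter's
    /// expected device, gradient requirement, and parameter mapper.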
    pub fn transform_for_load(self, tensor: Tensor<B, D>, param_id: ParamId) -> Self {
        let mut new_tensor = tensor;

        let mapper = self.param_mapper.clone();

        let expected_device = self.lazy_device();
        let expected_require_grad = self.lazy_is_require_grad();

        if new_tensor.device() != expected_device {
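            // The device transfer is detached from the autodiff graph; the expected
            // gradient requirement is restored below.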
            new_tensor = new_tensor.to_device(&expected_device).detach();
        }

        new_tensor = mapper.on_load(new_tensor);

        new_tensor = new_tensor.set_require_grad(expected_require_grad);

        let mut loaded = Self::initialized(param_id, new_tensor);
        loaded.param_mapper = mapper;
        loaded
    }

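    /// Returns a copy of the parameter prepared for saving, with the parameter mapper's
    /// `on_save` transformation applied to the value.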
    pub fn transform_for_save(&self) -> Self {
        let mut tensor = self.val();
        let mapper = self.param_mapper.clone();

        tensor = mapper.on_save(tensor);

        Self::initialized(self.id, tensor)
    }
}

impl<B: Backend, const D: usize> Param<Tensor<B, D, Int>> {
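    /// Returns the parameter's shape without forcing initialization.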
    pub fn lazy_shape(&self) -> burn_tensor::Shape {
        let initialization = match &self.initialization {
            Some(init) => init,
            None => return self.shape(),
        };

        let init = initialization.read().unwrap();

        match init.as_ref() {
            Some(value) => value.shape.clone(),
            None => self.shape(),
        }
    }

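    /// Builds the parameter from a freshly loaded tensor. Unlike the float version, no
    /// gradient handling is needed: int tensors never track gradients.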
    pub fn transform_for_load(self, tensor: Tensor<B, D, Int>, param_id: ParamId) -> Self {
        let mut new_tensor = tensor;

        let mapper = self.param_mapper.clone();

        let expected_device = self.lazy_device();

        if new_tensor.device() != expected_device {
            new_tensor = new_tensor.to_device(&expected_device);
        }

        new_tensor = mapper.on_load(new_tensor);

        let mut loaded = Self::initialized(param_id, new_tensor);
        loaded.param_mapper = mapper;
        loaded
    }

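    /// Returns a copy of the parameter prepared for saving, with the parameter mapper's
    /// `on_save` transformation applied to the value.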
    pub fn transform_for_save(&self) -> Self {
        let mut tensor = self.val();
        let mapper = self.param_mapper.clone();

        tensor = mapper.on_save(tensor);

        Self::initialized(self.id, tensor)
    }
}

impl<B: Backend, const D: usize> Param<Tensor<B, D, Bool>> {
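    /// Returns the parameter's shape without forcing initialization.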
    pub fn lazy_shape(&self) -> burn_tensor::Shape {
        let initialization = match &self.initialization {
            Some(init) => init,
            None => return self.shape(),
        };

        let init = initialization.read().unwrap();

        match init.as_ref() {
            Some(value) => value.shape.clone(),
            None => self.shape(),
        }
    }

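    /// Builds the parameter from a freshly loaded tensor. Bool tensors never track
    /// gradients, so only the device and the parameter mapper are applied.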
    pub fn transform_for_load(self, tensor: Tensor<B, D, Bool>, param_id: ParamId) -> Self {
        let mut new_tensor = tensor;

        let mapper = self.param_mapper.clone();

        let expected_device = self.lazy_device();

        if new_tensor.device() != expected_device {
            new_tensor = new_tensor.to_device(&expected_device);
        }

        new_tensor = mapper.on_load(new_tensor);

        let mut loaded = Self::initialized(param_id, new_tensor);
        loaded.param_mapper = mapper;
        loaded
    }

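    /// Returns a copy of the parameter prepared for saving, with the parameter mapper's
    /// `on_save` transformation applied to the value.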
    pub fn transform_for_save(&self) -> Self {
        let mut tensor = self.val();
        let mapper = self.param_mapper.clone();

        tensor = mapper.on_save(tensor);

        Self::initialized(self.id, tensor)
    }
}

impl<const D: usize, B: Backend> Module<B> for Param<Tensor<B, D>> {
    type Record = Param<Tensor<B, D>>;

    fn visit<V: ModuleVisitor<B>>(&self, visitor: &mut V) {
        visitor.visit_float(self)
    }

    fn map<M: ModuleMapper<B>>(self, mapper: &mut M) -> Self {
        mapper.map_float(self)
    }

    fn into_record(self) -> Self::Record {
        self.transform_for_save()
    }

    fn load_record(self, record: Self::Record) -> Self {
        let (record_param_id, record_tensor, _) = record.consume();
        self.transform_for_load(record_tensor, record_param_id)
    }

    fn to_device(self, device: &Device<B>) -> Self {
        self.map(|tensor| tensor.to_device(device))
    }

    fn fork(self, device: &Device<B>) -> Self {
        self.map(|tensor| {
            let is_require_grad = tensor.is_require_grad();
            let mut tensor = tensor.to_device(device).detach();

            if is_require_grad {
                tensor = tensor.require_grad();
            }

            tensor
        })
    }

    fn collect_devices(&self, mut devices: Vec<Device<B>>) -> Vec<Device<B>> {
        let device = self.val().device();

        if !devices.contains(&device) {
            devices.push(device)
        }

        devices
    }
}

impl<const D: usize, B: Backend> ModuleDisplayDefault for Param<Tensor<B, D>> {
    fn content(&self, content: Content) -> Option<Content> {
        let id = if content.display_settings.show_param_id() {
            format!(", id: {}", self.id)
        } else {
            "".to_string()
        };
        let string = format!(
            "ParamTensor {{rank: {D}, shape: {:?}, kind: float{id}}}",
            self.shape().dims
        );
        content.add_formatted(&string).optional()
    }
}
impl<const D: usize, B: Backend> ModuleDisplay for Param<Tensor<B, D>> {}

impl<const D: usize, B: Backend> Module<B> for Param<Tensor<B, D, Int>> {
    type Record = Param<Tensor<B, D, Int>>;

    fn visit<V: ModuleVisitor<B>>(&self, visitor: &mut V) {
        visitor.visit_int(self)
    }

    fn map<M: ModuleMapper<B>>(self, mapper: &mut M) -> Self {
        mapper.map_int(self)
    }

    fn into_record(self) -> Self::Record {
        self.transform_for_save()
    }

    fn load_record(self, record: Self::Record) -> Self {
        let (record_param_id, record_tensor, _) = record.consume();
        self.transform_for_load(record_tensor, record_param_id)
    }

    fn to_device(self, device: &Device<B>) -> Self {
        self.map(|tensor| tensor.to_device(device))
    }

    fn fork(self, device: &Device<B>) -> Self {
        self.to_device(device)
    }

    fn collect_devices(&self, mut devices: Vec<Device<B>>) -> Vec<Device<B>> {
        let device = self.val().device();

        if !devices.contains(&device) {
            devices.push(device)
        }

        devices
    }
}

impl<const D: usize, B: Backend> ModuleDisplayDefault for Param<Tensor<B, D, Int>> {
    fn content(&self, content: Content) -> Option<Content> {
        let id = if content.display_settings.show_param_id() {
            format!(", id: {}", self.id)
        } else {
            "".to_string()
        };
        let string = format!(
            "ParamTensor {{rank: {D}, shape: {:?}, kind: int{id}}}",
            self.shape().dims
        );
        content.add_formatted(&string).optional()
    }
}
impl<const D: usize, B: Backend> ModuleDisplay for Param<Tensor<B, D, Int>> {}

impl<const D: usize, B: Backend> Module<B> for Param<Tensor<B, D, Bool>> {
    type Record = Param<Tensor<B, D, Bool>>;

    fn visit<V: ModuleVisitor<B>>(&self, visitor: &mut V) {
        visitor.visit_bool(self)
    }

    fn map<M: ModuleMapper<B>>(self, mapper: &mut M) -> Self {
        mapper.map_bool(self)
    }

    fn into_record(self) -> Self::Record {
        self.transform_for_save()
    }

    fn load_record(self, record: Self::Record) -> Self {
        let (record_param_id, record_tensor, _) = record.consume();
        self.transform_for_load(record_tensor, record_param_id)
    }

    fn to_device(self, device: &Device<B>) -> Self {
        self.map(|tensor| tensor.to_device(device))
    }

    fn fork(self, device: &Device<B>) -> Self {
        self.to_device(device)
    }

    fn collect_devices(&self, mut devices: Vec<Device<B>>) -> Vec<Device<B>> {
        let device = self.val().device();

        if !devices.contains(&device) {
            devices.push(device)
        }

        devices
    }
}

impl<const D: usize, B: Backend> ModuleDisplayDefault for Param<Tensor<B, D, Bool>> {
    fn content(&self, content: Content) -> Option<Content> {
        let id = if content.display_settings.show_param_id() {
            format!(", id: {}", self.id)
        } else {
            "".to_string()
        };

        let string = format!(
            "ParamTensor {{rank: {D}, shape: {:?}, kind: bool{id}}}",
            self.shape().dims
        );
        content.add_formatted(&string).optional()
    }
}

impl<const D: usize, B: Backend> ModuleDisplay for Param<Tensor<B, D, Bool>> {}

impl<const D: usize, B: AutodiffBackend> AutodiffModule<B> for Param<Tensor<B, D>> {
    type InnerModule = Param<Tensor<B::InnerBackend, D>>;

    fn valid(&self) -> Self::InnerModule {
        Param::initialized(self.id, self.val().inner().set_require_grad(false))
    }
}

impl<const D: usize, B: AutodiffBackend> AutodiffModule<B> for Param<Tensor<B, D, Int>> {
    type InnerModule = Param<Tensor<B::InnerBackend, D, Int>>;

    fn valid(&self) -> Self::InnerModule {
        Param::initialized(self.id, self.val().inner())
    }
}

impl<const D: usize, B: AutodiffBackend> AutodiffModule<B> for Param<Tensor<B, D, Bool>> {
    type InnerModule = Param<Tensor<B::InnerBackend, D, Bool>>;

    fn valid(&self) -> Self::InnerModule {
        Param::initialized(self.id, self.val().inner())
    }
}

#[cfg(all(test, feature = "std"))]
mod tests {
    use super::*;
    use crate::{
        TestAutodiffBackend,
        module::Module,
        record::{BinBytesRecorder, FullPrecisionSettings, Recorder},
    };

    #[test]
    fn test_load_record_setting() {
        let device = Default::default();
        let tensor = Tensor::<TestAutodiffBackend, 2>::ones([3, 3], &device).require_grad();

        let byte_recorder = BinBytesRecorder::<FullPrecisionSettings>::default();
        let bytes = byte_recorder
            .record(
                Param::initialized(ParamId::new(), tensor.clone()).into_record(),
                (),
            )
            .unwrap();

        let no_grad_is_require_grad = Param::initialized(ParamId::new(), tensor.clone())
            .no_grad()
            .load_record(byte_recorder.load(bytes.clone(), &device).unwrap())
            .is_require_grad();

        let with_default_is_require_grad = Param::initialized(ParamId::new(), tensor)
            .load_record(byte_recorder.load(bytes, &device).unwrap())
            .is_require_grad();

        assert!(!no_grad_is_require_grad);
        assert!(with_default_is_require_grad);
    }
}