1use super::{Param, ParamId, Parameter};
2use crate::module::{
3 AutodiffModule, Content, HasAutodiffModule, Module, ModuleDisplay, ModuleDisplayDefault,
4 ModuleMapper, ModuleVisitor,
5};
6use crate::tensor::{
7 Tensor,
8 backend::{AutodiffBackend, Backend},
9};
10use alloc::{format, string::ToString, vec::Vec};
11use burn_tensor::{Bool, Float, Int, TensorData, ops::Device};
12
/// Float tensors are trainable parameters: gradient tracking is delegated
/// to the tensor itself.
impl<B: Backend, const D: usize> Parameter for Tensor<B, D, Float> {
    type Device = B::Device;

    fn device(&self) -> Self::Device {
        // Fully-qualified call to the inherent `Tensor` method, avoiding any
        // ambiguity with this trait method of the same name.
        Tensor::device(self)
    }

    fn is_require_grad(&self) -> bool {
        Tensor::is_require_grad(self)
    }

    fn set_require_grad(self, require_grad: bool) -> Self {
        Tensor::set_require_grad(self, require_grad)
    }
}
28
/// Int tensors are non-differentiable parameters: they never track
/// gradients, so `set_require_grad` is a no-op.
impl<B: Backend, const D: usize> Parameter for Tensor<B, D, Int> {
    type Device = B::Device;

    fn device(&self) -> Self::Device {
        // Fully-qualified call to the inherent `Tensor` method.
        Tensor::device(self)
    }

    fn is_require_grad(&self) -> bool {
        // Integer tensors cannot carry gradients.
        false
    }

    fn set_require_grad(self, _require_grad: bool) -> Self {
        // No gradient state to update; return the tensor unchanged.
        self
    }
}
44
/// Bool tensors are non-differentiable parameters: they never track
/// gradients, so `set_require_grad` is a no-op.
impl<B: Backend, const D: usize> Parameter for Tensor<B, D, Bool> {
    type Device = B::Device;

    fn device(&self) -> Self::Device {
        // Fully-qualified call to the inherent `Tensor` method.
        Tensor::device(self)
    }

    fn is_require_grad(&self) -> bool {
        // Boolean tensors cannot carry gradients.
        false
    }

    fn set_require_grad(self, _require_grad: bool) -> Self {
        // No gradient state to update; return the tensor unchanged.
        self
    }
}
60
impl<B: Backend, const D: usize> Param<Tensor<B, D>> {
    /// Creates an initialized, trainable parameter from an existing float tensor.
    pub fn from_tensor(value: Tensor<B, D>) -> Self {
        // Float parameters track gradients by default.
        Param::initialized(ParamId::new(), value.require_grad())
    }

    /// Returns the parameter shape without forcing a pending lazy
    /// initialization to materialize the tensor value.
    pub fn lazy_shape(&self) -> burn_tensor::Shape {
        // No deferred initialization stored: read the shape from the value.
        let initialization = match &self.initialization {
            Some(init) => init,
            None => return self.shape(),
        };

        // Panics if the lock is poisoned (a writer panicked during init).
        let init = initialization.read().unwrap();

        match init.as_ref() {
            // Deferred init still pending: its recorded shape is authoritative.
            Some(value) => value.shape.clone(),
            // Init already consumed: fall back to the materialized value.
            None => self.shape(),
        }
    }

    /// Creates a trainable parameter on `device` from raw tensor data.
    pub fn from_data<T>(data: T, device: &B::Device) -> Self
    where
        T: Into<TensorData>,
    {
        let data: TensorData = data.into();
        // NOTE(review): presumably scopes the allocation as long-lived
        // parameter memory on the backend — confirm against the Backend trait.
        B::memory_persistent_allocations(device, data, |data| {
            let value = Tensor::from_data(data, device);
            Param::initialized(ParamId::new(), value.require_grad())
        })
    }

    /// Builds the loaded parameter from a `tensor` read out of a record,
    /// reusing this parameter's device, gradient setting and mapper.
    ///
    /// The order below matters: device move (with detach) first, then the
    /// load mapper, then the gradient requirement is restored.
    pub fn transform_for_load(self, tensor: Tensor<B, D>, param_id: ParamId) -> Self {
        let mut new_tensor = tensor;

        let mapper = self.param_mapper.clone();

        let expected_device = self.lazy_device();
        let expected_require_grad = self.lazy_is_require_grad();

        // Moving across devices also detaches from any autodiff graph; the
        // gradient flag is re-applied explicitly below.
        if new_tensor.device() != expected_device {
            new_tensor = new_tensor.to_device(&expected_device).detach();
        }

        new_tensor = mapper.on_load(new_tensor);

        new_tensor = new_tensor.set_require_grad(expected_require_grad);

        let mut loaded = Self::initialized(param_id, new_tensor);
        // Keep the mapper so future save/load round trips apply it again.
        loaded.param_mapper = mapper;
        loaded
    }

    /// Produces the record value for this parameter, applying the save mapper.
    pub fn transform_for_save(&self) -> Self {
        let mut tensor = self.val();
        let mapper = self.param_mapper.clone();

        tensor = mapper.on_save(tensor);

        Self::initialized(self.id, tensor)
    }
}
153
impl<B: Backend, const D: usize> Param<Tensor<B, D, Int>> {
    /// Returns the parameter shape without forcing a pending lazy
    /// initialization to materialize the tensor value.
    pub fn lazy_shape(&self) -> burn_tensor::Shape {
        // No deferred initialization stored: read the shape from the value.
        let initialization = match &self.initialization {
            Some(init) => init,
            None => return self.shape(),
        };

        // Panics if the lock is poisoned (a writer panicked during init).
        let init = initialization.read().unwrap();

        match init.as_ref() {
            // Deferred init still pending: its recorded shape is authoritative.
            Some(value) => value.shape.clone(),
            // Init already consumed: fall back to the materialized value.
            None => self.shape(),
        }
    }

    /// Builds the loaded parameter from a `tensor` read out of a record.
    ///
    /// Unlike the float version, int tensors never track gradients, so there
    /// is no detach or require-grad restoration step.
    pub fn transform_for_load(self, tensor: Tensor<B, D, Int>, param_id: ParamId) -> Self {
        let mut new_tensor = tensor;

        let mapper = self.param_mapper.clone();

        let expected_device = self.lazy_device();

        if new_tensor.device() != expected_device {
            new_tensor = new_tensor.to_device(&expected_device);
        }

        new_tensor = mapper.on_load(new_tensor);

        let mut loaded = Self::initialized(param_id, new_tensor);
        // Keep the mapper so future save/load round trips apply it again.
        loaded.param_mapper = mapper;
        loaded
    }

    /// Produces the record value for this parameter, applying the save mapper.
    pub fn transform_for_save(&self) -> Self {
        let mut tensor = self.val();
        let mapper = self.param_mapper.clone();

        tensor = mapper.on_save(tensor);

        Self::initialized(self.id, tensor)
    }
}
215
impl<B: Backend, const D: usize> Param<Tensor<B, D, Bool>> {
    /// Returns the parameter shape without forcing a pending lazy
    /// initialization to materialize the tensor value.
    pub fn lazy_shape(&self) -> burn_tensor::Shape {
        // No deferred initialization stored: read the shape from the value.
        let initialization = match &self.initialization {
            Some(init) => init,
            None => return self.shape(),
        };

        // Panics if the lock is poisoned (a writer panicked during init).
        let init = initialization.read().unwrap();

        match init.as_ref() {
            // Deferred init still pending: its recorded shape is authoritative.
            Some(value) => value.shape.clone(),
            // Init already consumed: fall back to the materialized value.
            None => self.shape(),
        }
    }

    /// Builds the loaded parameter from a `tensor` read out of a record.
    ///
    /// Bool tensors never track gradients, so there is no detach or
    /// require-grad restoration step.
    pub fn transform_for_load(self, tensor: Tensor<B, D, Bool>, param_id: ParamId) -> Self {
        let mut new_tensor = tensor;

        let mapper = self.param_mapper.clone();

        let expected_device = self.lazy_device();

        if new_tensor.device() != expected_device {
            new_tensor = new_tensor.to_device(&expected_device);
        }

        new_tensor = mapper.on_load(new_tensor);

        let mut loaded = Self::initialized(param_id, new_tensor);
        // Keep the mapper so future save/load round trips apply it again.
        loaded.param_mapper = mapper;
        loaded
    }

    /// Produces the record value for this parameter, applying the save mapper.
    pub fn transform_for_save(&self) -> Self {
        let mut tensor = self.val();
        let mapper = self.param_mapper.clone();

        tensor = mapper.on_save(tensor);

        Self::initialized(self.id, tensor)
    }
}
281
282impl<const D: usize, B: Backend> Module<B> for Param<Tensor<B, D>> {
283 type Record = Param<Tensor<B, D>>;
284
285 fn visit<V: ModuleVisitor<B>>(&self, visitor: &mut V) {
286 visitor.visit_float(self)
287 }
288
289 fn map<M: ModuleMapper<B>>(self, mapper: &mut M) -> Self {
290 mapper.map_float(self)
291 }
292
293 fn into_record(self) -> Self::Record {
294 self.transform_for_save()
295 }
296
297 fn load_record(self, record: Self::Record) -> Self {
298 let (record_param_id, record_tensor, _) = record.consume();
299 self.transform_for_load(record_tensor, record_param_id)
300 }
301
302 fn to_device(self, device: &Device<B>) -> Self {
303 self.map(|tensor| tensor.to_device(device))
304 }
305
306 fn fork(self, device: &Device<B>) -> Self {
307 self.map(|tensor| {
308 let is_require_grad = tensor.is_require_grad();
309 let mut tensor = tensor.to_device(device).detach();
310
311 if is_require_grad {
312 tensor = tensor.require_grad();
313 }
314
315 tensor
316 })
317 }
318
319 fn collect_devices(&self, mut devices: Vec<Device<B>>) -> Vec<Device<B>> {
320 let device = self.val().device();
321
322 if !devices.contains(&device) {
323 devices.push(device)
324 }
325
326 devices
327 }
328}
329
330impl<const D: usize, B: Backend> ModuleDisplayDefault for Param<Tensor<B, D>> {
331 fn content(&self, content: Content) -> Option<Content> {
332 let id = if content.display_settings.show_param_id() {
333 format!(", id: {}", self.id)
334 } else {
335 "".to_string()
336 };
337 let string = format!(
338 "ParamTensor {{rank: {D}, shape: {:?}, kind: float{id}}}",
339 self.shape().as_slice()
340 );
341 content.add_formatted(&string).optional()
342 }
343}
344impl<const D: usize, B: Backend> ModuleDisplay for Param<Tensor<B, D>> {}
345
346impl<const D: usize, B: Backend> Module<B> for Param<Tensor<B, D, Int>> {
347 type Record = Param<Tensor<B, D, Int>>;
348
349 fn visit<V: ModuleVisitor<B>>(&self, visitor: &mut V) {
350 visitor.visit_int(self)
351 }
352
353 fn map<M: ModuleMapper<B>>(self, mapper: &mut M) -> Self {
354 mapper.map_int(self)
355 }
356
357 fn into_record(self) -> Self::Record {
358 self.transform_for_save()
359 }
360
361 fn load_record(self, record: Self::Record) -> Self {
362 let (record_param_id, record_tensor, _) = record.consume();
363 self.transform_for_load(record_tensor, record_param_id)
364 }
365
366 fn to_device(self, device: &Device<B>) -> Self {
367 self.map(|tensor| tensor.to_device(device))
368 }
369
370 fn fork(self, device: &Device<B>) -> Self {
371 self.to_device(device) }
373
374 fn collect_devices(&self, mut devices: Vec<Device<B>>) -> Vec<Device<B>> {
375 let device = self.val().device();
376
377 if !devices.contains(&device) {
378 devices.push(device)
379 }
380
381 devices
382 }
383}
384
385impl<const D: usize, B: Backend> ModuleDisplayDefault for Param<Tensor<B, D, Int>> {
386 fn content(&self, content: Content) -> Option<Content> {
387 let id = if content.display_settings.show_param_id() {
388 format!(", id: {}", self.id)
389 } else {
390 "".to_string()
391 };
392 let string = format!(
393 "ParamTensor {{rank: {D}, shape: {:?}, kind: int{id}}}",
394 self.shape().as_slice()
395 );
396 content.add_formatted(&string).optional()
397 }
398}
399impl<const D: usize, B: Backend> ModuleDisplay for Param<Tensor<B, D, Int>> {}
400
401impl<const D: usize, B: Backend> Module<B> for Param<Tensor<B, D, Bool>> {
402 type Record = Param<Tensor<B, D, Bool>>;
403
404 fn visit<V: ModuleVisitor<B>>(&self, visitor: &mut V) {
405 visitor.visit_bool(self)
406 }
407
408 fn map<M: ModuleMapper<B>>(self, mapper: &mut M) -> Self {
409 mapper.map_bool(self)
410 }
411
412 fn into_record(self) -> Self::Record {
413 self.transform_for_save()
414 }
415
416 fn load_record(self, record: Self::Record) -> Self {
417 let (record_param_id, record_tensor, _) = record.consume();
418 self.transform_for_load(record_tensor, record_param_id)
419 }
420
421 fn to_device(self, device: &Device<B>) -> Self {
422 self.map(|tensor| tensor.to_device(device))
423 }
424
425 fn fork(self, device: &Device<B>) -> Self {
426 self.to_device(device) }
428
429 fn collect_devices(&self, mut devices: Vec<Device<B>>) -> Vec<Device<B>> {
430 let device = self.val().device();
431
432 if !devices.contains(&device) {
433 devices.push(device)
434 }
435
436 devices
437 }
438}
439
440impl<const D: usize, B: Backend> ModuleDisplayDefault for Param<Tensor<B, D, Bool>> {
441 fn content(&self, content: Content) -> Option<Content> {
442 let id = if content.display_settings.show_param_id() {
443 format!(", id: {}", self.id)
444 } else {
445 "".to_string()
446 };
447
448 let string = format!(
449 "ParamTensor {{rank: {D}, shape: {:?}, kind: bool{id}}}",
450 self.shape().as_slice()
451 );
452 content.add_formatted(&string).optional()
453 }
454}
455
456impl<const D: usize, B: Backend> ModuleDisplay for Param<Tensor<B, D, Bool>> {}
457
458impl<const D: usize, B: AutodiffBackend> AutodiffModule<B> for Param<Tensor<B, D>> {
459 type InnerModule = Param<Tensor<B::InnerBackend, D>>;
460
461 fn valid(&self) -> Self::InnerModule {
462 let require_grad = self.require_grad;
464 let mut param = Param::initialized(self.id, self.val().inner().set_require_grad(false));
465 param.require_grad = require_grad;
466 param
467 }
468
469 fn from_inner(module: Self::InnerModule) -> Self {
470 let tensor = Tensor::from_inner(module.val()).set_require_grad(module.require_grad);
472 Param::initialized(module.id, tensor)
473 }
474}
475
476impl<const D: usize, B: AutodiffBackend> HasAutodiffModule<B>
477 for Param<Tensor<B::InnerBackend, D>>
478{
479 type TrainModule = Param<Tensor<B, D>>;
480}
481
482impl<const D: usize, B: AutodiffBackend> AutodiffModule<B> for Param<Tensor<B, D, Int>> {
483 type InnerModule = Param<Tensor<B::InnerBackend, D, Int>>;
484
485 fn valid(&self) -> Self::InnerModule {
486 Param::initialized(self.id, self.val().inner())
487 }
488
489 fn from_inner(module: Self::InnerModule) -> Self {
490 Param::initialized(module.id, Tensor::from_inner(module.val()))
491 }
492}
493
494impl<const D: usize, B: AutodiffBackend> AutodiffModule<B> for Param<Tensor<B, D, Bool>> {
495 type InnerModule = Param<Tensor<B::InnerBackend, D, Bool>>;
496
497 fn valid(&self) -> Self::InnerModule {
498 Param::initialized(self.id, self.val().inner())
499 }
500
501 fn from_inner(module: Self::InnerModule) -> Self {
502 Param::initialized(module.id, Tensor::from_inner(module.val()))
503 }
504}
505
#[cfg(all(test, feature = "std"))]
mod tests {
    use super::*;
    use crate::{
        TestAutodiffBackend,
        module::Module,
        record::{BinBytesRecorder, FullPrecisionSettings, Recorder},
    };

    #[test]
    fn test_load_record_setting() {
        let device = Default::default();
        let tensor = Tensor::<TestAutodiffBackend, 2>::ones([3, 3], &device).require_grad();

        let recorder = BinBytesRecorder::<FullPrecisionSettings>::default();
        let bytes = recorder
            .record(
                Param::initialized(ParamId::new(), tensor.clone()).into_record(),
                (),
            )
            .unwrap();

        // Loading into a `no_grad` parameter must not re-enable gradients.
        let no_grad_is_require_grad = Param::initialized(ParamId::new(), tensor.clone())
            .no_grad()
            .load_record(recorder.load(bytes.clone(), &device).unwrap())
            .is_require_grad();

        // Loading into a default parameter keeps gradients enabled.
        let with_default_is_require_grad = Param::initialized(ParamId::new(), tensor)
            .load_record(recorder.load(bytes, &device).unwrap())
            .is_require_grad();

        assert!(!no_grad_is_require_grad);
        assert!(with_default_is_require_grad);
    }

    #[test]
    fn test_param_require_grad_stateful() {
        let device = Default::default();
        let tensor = Tensor::<TestAutodiffBackend, 2>::ones([3, 3], &device).require_grad();

        let param = Param::initialized(ParamId::new(), tensor);
        assert!(param.is_require_grad());
        assert!(param.require_grad);

        // `valid` detaches the value but remembers the gradient requirement.
        let param = param.valid();
        assert!(!param.is_require_grad());
        assert!(param.require_grad);

        // Moving back to training restores gradient tracking.
        let param = param.train::<TestAutodiffBackend>();
        assert!(param.is_require_grad());
        assert!(param.require_grad);

        // `no_grad` clears both the live tensor flag and the stored flag...
        let param = param.no_grad();
        assert!(!param.is_require_grad());
        assert!(!param.require_grad);

        // ...and the cleared flag survives a valid/train round trip.
        let param = param.valid();
        assert!(!param.is_require_grad());
        assert!(!param.require_grad);

        let param = param.train::<TestAutodiffBackend>();
        assert!(!param.is_require_grad());
        assert!(!param.require_grad);
    }
}