// burn_core/module/param/tensor.rs

use super::{Param, ParamId, Parameter};
use crate::module::{
    AutodiffModule, Content, HasAutodiffModule, Module, ModuleDisplay, ModuleDisplayDefault,
    ModuleMapper, ModuleVisitor,
};
use crate::tensor::{
    Tensor,
    backend::{AutodiffBackend, Backend},
};
use alloc::{format, string::ToString, vec::Vec};
use burn_tensor::{Bool, Float, Int, TensorData, ops::Device};

impl<B: Backend, const D: usize> Parameter for Tensor<B, D, Float> {
    type Device = B::Device;

    fn device(&self) -> Self::Device {
        Tensor::device(self)
    }

    fn is_require_grad(&self) -> bool {
        Tensor::is_require_grad(self)
    }

    fn set_require_grad(self, require_grad: bool) -> Self {
        Tensor::set_require_grad(self, require_grad)
    }
}

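// Int and Bool tensors are not differentiable: `is_require_grad` always returns false
// and `set_require_grad` is a no-op for these tensor kinds.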
impl<B: Backend, const D: usize> Parameter for Tensor<B, D, Int> {
    type Device = B::Device;

    fn device(&self) -> Self::Device {
        Tensor::device(self)
    }

    fn is_require_grad(&self) -> bool {
        false
    }

    fn set_require_grad(self, _require_grad: bool) -> Self {
        self
    }
}

impl<B: Backend, const D: usize> Parameter for Tensor<B, D, Bool> {
    type Device = B::Device;

    fn device(&self) -> Self::Device {
        Tensor::device(self)
    }

    fn is_require_grad(&self) -> bool {
        false
    }

    fn set_require_grad(self, _require_grad: bool) -> Self {
        self
    }
}

impl<B: Backend, const D: usize> Param<Tensor<B, D>> {
    /// Create a new parameter from a float tensor.
    ///
    /// # Warnings
    ///
    /// We strongly recommend using [Param::uninitialized] if you are using this method to
    /// initialize parameters inside a module, since the tensor initialization will be lazy,
    /// making the loading of weights more performant.
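    ///
    /// # Examples
    ///
    /// A minimal sketch (assumes a backend type `B` and a `device` in scope):
    ///
    /// ```rust,ignore
    /// let weight = Tensor::<B, 2>::ones([3, 3], &device);
    /// let param = Param::from_tensor(weight);
    /// // Gradients are enabled automatically for float parameters.
    /// assert!(param.is_require_grad());
    /// ```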
    pub fn from_tensor(value: Tensor<B, D>) -> Self {
        // When creating a parameter from a float tensor, we automatically mark it as requiring
        // gradients, so that it can be updated by an optimizer.
        Param::initialized(ParamId::new(), value.require_grad())
    }

    /// The shape of the parameter, **without triggering initialization**.
    ///
    /// This is critical for shape validation during loading: when applying tensors to an
    /// uninitialized parameter, we need to validate the shape without triggering the
    /// initialization function (which would allocate an unnecessary tensor).
    ///
    /// Use this instead of [crate::tensor::Tensor::shape] when you need the shape but want to
    /// preserve lazy initialization.
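    ///
    /// # Examples
    ///
    /// A minimal sketch (assumes a backend type `B` and a `device` in scope):
    ///
    /// ```rust,ignore
    /// let param = Param::from_tensor(Tensor::<B, 2>::ones([3, 3], &device));
    /// // Reads the shape without forcing a lazy parameter to materialize.
    /// let shape = param.lazy_shape();
    /// ```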
    pub fn lazy_shape(&self) -> burn_tensor::Shape {
        let initialization = match &self.initialization {
            Some(init) => init,
            None => return self.shape(),
        };

        let init = initialization.read().unwrap();

        match init.as_ref() {
            Some(value) => value.shape.clone(),
            None => self.shape(),
        }
    }

    /// Create a new parameter from data.
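    ///
    /// # Examples
    ///
    /// A minimal sketch (assumes a backend type `B` and a `device` in scope):
    ///
    /// ```rust,ignore
    /// let param: Param<Tensor<B, 2>> = Param::from_data([[1.0, 2.0], [3.0, 4.0]], &device);
    /// ```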
    pub fn from_data<T>(data: T, device: &B::Device) -> Self
    where
        T: Into<TensorData>,
    {
        let data: TensorData = data.into();
        // When creating a parameter from a float tensor, we automatically mark it as requiring
        // gradients, so that it can be updated by an optimizer.
        B::memory_persistent_allocations(device, data, |data| {
            let value = Tensor::from_data(data, device);
            Param::initialized(ParamId::new(), value.require_grad())
        })
    }

    /// Transform a parameter for loading by applying load transformations.
    ///
    /// This method is used to restore a parameter from a tensor (typically during deserialization).
    /// It ensures the tensor is moved to the expected device, applies the param mapper's
    /// `on_load` transformation, and preserves the autodiff settings (require_grad).
    pub fn transform_for_load(self, tensor: Tensor<B, D>, param_id: ParamId) -> Self {
        let mut new_tensor = tensor;

        let mapper = self.param_mapper.clone();

        let expected_device = self.lazy_device();
        let expected_require_grad = self.lazy_is_require_grad();

        // Make sure we load the tensor into the same module device.
        if new_tensor.device() != expected_device {
            new_tensor = new_tensor.to_device(&expected_device).detach();
        }

        new_tensor = mapper.on_load(new_tensor);

        // Make sure we load the tensor with the same autodiff setting.
        new_tensor = new_tensor.set_require_grad(expected_require_grad);

        let mut loaded = Self::initialized(param_id, new_tensor);
        loaded.param_mapper = mapper;
        loaded
    }

    /// Transform a parameter for saving by applying save transformations.
    ///
    /// This method is used to prepare a parameter for saving (typically during serialization).
    /// It applies the param mapper's `on_save` transformation, which can be used
    /// to modify the tensor before serialization (e.g., quantization, precision conversion).
    pub fn transform_for_save(&self) -> Self {
        let mut tensor = self.val();
        let mapper = self.param_mapper.clone();

        tensor = mapper.on_save(tensor);

        Self::initialized(self.id, tensor)
    }
}

impl<B: Backend, const D: usize> Param<Tensor<B, D, Int>> {
    /// The shape of the parameter, **without triggering initialization**.
    ///
    /// This is critical for shape validation during loading: when applying tensors to an
    /// uninitialized parameter, we need to validate the shape without triggering the
    /// initialization function (which would allocate an unnecessary tensor).
    ///
    /// Use this instead of [crate::tensor::Tensor::shape] when you need the shape but want to
    /// preserve lazy initialization.
    pub fn lazy_shape(&self) -> burn_tensor::Shape {
        let initialization = match &self.initialization {
            Some(init) => init,
            None => return self.shape(),
        };

        let init = initialization.read().unwrap();

        match init.as_ref() {
            Some(value) => value.shape.clone(),
            None => self.shape(),
        }
    }

    /// Transform a parameter for loading by applying load transformations.
    ///
    /// This method is used to restore a parameter from a tensor (typically during deserialization).
    /// It ensures the tensor is moved to the expected device and applies the param mapper's
    /// `on_load` transformation.
    pub fn transform_for_load(self, tensor: Tensor<B, D, Int>, param_id: ParamId) -> Self {
        let mut new_tensor = tensor;

        let mapper = self.param_mapper.clone();

        let expected_device = self.lazy_device();

        // Make sure we load the tensor into the same module device.
        if new_tensor.device() != expected_device {
            new_tensor = new_tensor.to_device(&expected_device);
        }

        new_tensor = mapper.on_load(new_tensor);

        let mut loaded = Self::initialized(param_id, new_tensor);
        loaded.param_mapper = mapper;
        loaded
    }

    /// Transform a parameter for saving by applying save transformations.
    ///
    /// This method is used to prepare a parameter for saving (typically during serialization).
    /// It applies the param mapper's `on_save` transformation, which can be used
    /// to modify the tensor before serialization (e.g., quantization, precision conversion).
    pub fn transform_for_save(&self) -> Self {
        let mut tensor = self.val();
        let mapper = self.param_mapper.clone();

        tensor = mapper.on_save(tensor);

        Self::initialized(self.id, tensor)
    }
}

impl<B: Backend, const D: usize> Param<Tensor<B, D, Bool>> {
    /// The shape of the parameter, **without triggering initialization**.
    ///
    /// This is critical for shape validation during loading: when applying tensors to an
    /// uninitialized parameter, we need to validate the shape without triggering the
    /// initialization function (which would allocate an unnecessary tensor).
    ///
    /// **Returns:**
    /// - For uninitialized params: the shape from the `Uninitialized` struct
    /// - For initialized params: the actual shape from the tensor
    ///
    /// Use this instead of [crate::tensor::Tensor::shape] when you need the shape but want to
    /// preserve lazy initialization.
    pub fn lazy_shape(&self) -> burn_tensor::Shape {
        let initialization = match &self.initialization {
            Some(init) => init,
            None => return self.shape(),
        };

        let init = initialization.read().unwrap();

        match init.as_ref() {
            Some(value) => value.shape.clone(),
            None => self.shape(),
        }
    }

    /// Transform a parameter for loading by applying load transformations.
    ///
    /// This method is used to restore a parameter from a tensor (typically during deserialization).
    /// It ensures the tensor is moved to the expected device and applies the param mapper's
    /// `on_load` transformation.
    pub fn transform_for_load(self, tensor: Tensor<B, D, Bool>, param_id: ParamId) -> Self {
        let mut new_tensor = tensor;

        let mapper = self.param_mapper.clone();

        let expected_device = self.lazy_device();

        // Make sure we load the tensor into the same module device.
        if new_tensor.device() != expected_device {
            new_tensor = new_tensor.to_device(&expected_device);
        }

        new_tensor = mapper.on_load(new_tensor);

        let mut loaded = Self::initialized(param_id, new_tensor);
        loaded.param_mapper = mapper;
        loaded
    }

    /// Transform a parameter for saving by applying save transformations.
    ///
    /// This method is used to prepare a parameter for saving (typically during serialization).
    /// It applies the param mapper's `on_save` transformation, which can be used
    /// to modify the tensor before serialization (e.g., quantization, precision conversion).
    pub fn transform_for_save(&self) -> Self {
        let mut tensor = self.val();
        let mapper = self.param_mapper.clone();

        tensor = mapper.on_save(tensor);

        Self::initialized(self.id, tensor)
    }
}

impl<const D: usize, B: Backend> Module<B> for Param<Tensor<B, D>> {
    type Record = Param<Tensor<B, D>>;

    fn visit<V: ModuleVisitor<B>>(&self, visitor: &mut V) {
        visitor.visit_float(self)
    }

    fn map<M: ModuleMapper<B>>(self, mapper: &mut M) -> Self {
        mapper.map_float(self)
    }

    fn into_record(self) -> Self::Record {
        self.transform_for_save()
    }

    fn load_record(self, record: Self::Record) -> Self {
        let (record_param_id, record_tensor, _) = record.consume();
        self.transform_for_load(record_tensor, record_param_id)
    }

    fn to_device(self, device: &Device<B>) -> Self {
        self.map(|tensor| tensor.to_device(device))
    }

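    // Forking moves the tensor to the target device and detaches it from the current
    // autodiff graph, then re-enables `require_grad` so the copy can be trained independently.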
    fn fork(self, device: &Device<B>) -> Self {
        self.map(|tensor| {
            let is_require_grad = tensor.is_require_grad();
            let mut tensor = tensor.to_device(device).detach();

            if is_require_grad {
                tensor = tensor.require_grad();
            }

            tensor
        })
    }

    fn collect_devices(&self, mut devices: Vec<Device<B>>) -> Vec<Device<B>> {
        let device = self.val().device();

        if !devices.contains(&device) {
            devices.push(device)
        }

        devices
    }
}

impl<const D: usize, B: Backend> ModuleDisplayDefault for Param<Tensor<B, D>> {
    fn content(&self, content: Content) -> Option<Content> {
        let id = if content.display_settings.show_param_id() {
            format!(", id: {}", self.id)
        } else {
            "".to_string()
        };
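        // Example output: `ParamTensor {rank: 2, shape: [3, 3], kind: float}`.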
        let string = format!(
            "ParamTensor {{rank: {D}, shape: {:?}, kind: float{id}}}",
            self.shape().as_slice()
        );
        content.add_formatted(&string).optional()
    }
}
impl<const D: usize, B: Backend> ModuleDisplay for Param<Tensor<B, D>> {}

impl<const D: usize, B: Backend> Module<B> for Param<Tensor<B, D, Int>> {
    type Record = Param<Tensor<B, D, Int>>;

    fn visit<V: ModuleVisitor<B>>(&self, visitor: &mut V) {
        visitor.visit_int(self)
    }

    fn map<M: ModuleMapper<B>>(self, mapper: &mut M) -> Self {
        mapper.map_int(self)
    }

    fn into_record(self) -> Self::Record {
        self.transform_for_save()
    }

    fn load_record(self, record: Self::Record) -> Self {
        let (record_param_id, record_tensor, _) = record.consume();
        self.transform_for_load(record_tensor, record_param_id)
    }

    fn to_device(self, device: &Device<B>) -> Self {
        self.map(|tensor| tensor.to_device(device))
    }

    fn fork(self, device: &Device<B>) -> Self {
        self.to_device(device) // Don't support autodiff.
    }

    fn collect_devices(&self, mut devices: Vec<Device<B>>) -> Vec<Device<B>> {
        let device = self.val().device();

        if !devices.contains(&device) {
            devices.push(device)
        }

        devices
    }
}

impl<const D: usize, B: Backend> ModuleDisplayDefault for Param<Tensor<B, D, Int>> {
    fn content(&self, content: Content) -> Option<Content> {
        let id = if content.display_settings.show_param_id() {
            format!(", id: {}", self.id)
        } else {
            "".to_string()
        };
        let string = format!(
            "ParamTensor {{rank: {D}, shape: {:?}, kind: int{id}}}",
            self.shape().as_slice()
        );
        content.add_formatted(&string).optional()
    }
}
impl<const D: usize, B: Backend> ModuleDisplay for Param<Tensor<B, D, Int>> {}

impl<const D: usize, B: Backend> Module<B> for Param<Tensor<B, D, Bool>> {
    type Record = Param<Tensor<B, D, Bool>>;

    fn visit<V: ModuleVisitor<B>>(&self, visitor: &mut V) {
        visitor.visit_bool(self)
    }

    fn map<M: ModuleMapper<B>>(self, mapper: &mut M) -> Self {
        mapper.map_bool(self)
    }

    fn into_record(self) -> Self::Record {
        self.transform_for_save()
    }

    fn load_record(self, record: Self::Record) -> Self {
        let (record_param_id, record_tensor, _) = record.consume();
        self.transform_for_load(record_tensor, record_param_id)
    }

    fn to_device(self, device: &Device<B>) -> Self {
        self.map(|tensor| tensor.to_device(device))
    }

    fn fork(self, device: &Device<B>) -> Self {
        self.to_device(device) // Don't support autodiff.
    }

    fn collect_devices(&self, mut devices: Vec<Device<B>>) -> Vec<Device<B>> {
        let device = self.val().device();

        if !devices.contains(&device) {
            devices.push(device)
        }

        devices
    }
}

impl<const D: usize, B: Backend> ModuleDisplayDefault for Param<Tensor<B, D, Bool>> {
    fn content(&self, content: Content) -> Option<Content> {
        let id = if content.display_settings.show_param_id() {
            format!(", id: {}", self.id)
        } else {
            "".to_string()
        };

        let string = format!(
            "ParamTensor {{rank: {D}, shape: {:?}, kind: bool{id}}}",
            self.shape().as_slice()
        );
        content.add_formatted(&string).optional()
    }
}

impl<const D: usize, B: Backend> ModuleDisplay for Param<Tensor<B, D, Bool>> {}

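// `valid` strips autodiff from the inner tensor so the param can run on the inner
// backend for inference, while keeping the param-level `require_grad` flag so that
// `from_inner` can restore it when converting back to the autodiff backend.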
impl<const D: usize, B: AutodiffBackend> AutodiffModule<B> for Param<Tensor<B, D>> {
    type InnerModule = Param<Tensor<B::InnerBackend, D>>;

    fn valid(&self) -> Self::InnerModule {
        // Preserve initialized param `require_grad` state, but reset the inner value's
        let require_grad = self.require_grad;
        let mut param = Param::initialized(self.id, self.val().inner().set_require_grad(false));
        param.require_grad = require_grad;
        param
    }

    fn from_inner(module: Self::InnerModule) -> Self {
        // Reinstate the param's `require_grad` state
        let tensor = Tensor::from_inner(module.val()).set_require_grad(module.require_grad);
        Param::initialized(module.id, tensor)
    }
}

impl<const D: usize, B: AutodiffBackend> HasAutodiffModule<B>
    for Param<Tensor<B::InnerBackend, D>>
{
    type TrainModule = Param<Tensor<B, D>>;
}

impl<const D: usize, B: AutodiffBackend> AutodiffModule<B> for Param<Tensor<B, D, Int>> {
    type InnerModule = Param<Tensor<B::InnerBackend, D, Int>>;

    fn valid(&self) -> Self::InnerModule {
        Param::initialized(self.id, self.val().inner())
    }

    fn from_inner(module: Self::InnerModule) -> Self {
        Param::initialized(module.id, Tensor::from_inner(module.val()))
    }
}

impl<const D: usize, B: AutodiffBackend> AutodiffModule<B> for Param<Tensor<B, D, Bool>> {
    type InnerModule = Param<Tensor<B::InnerBackend, D, Bool>>;

    fn valid(&self) -> Self::InnerModule {
        Param::initialized(self.id, self.val().inner())
    }

    fn from_inner(module: Self::InnerModule) -> Self {
        Param::initialized(module.id, Tensor::from_inner(module.val()))
    }
}

#[cfg(all(test, feature = "std"))]
mod tests {
    use super::*;
    use crate::{
        TestAutodiffBackend,
        module::Module,
        record::{BinBytesRecorder, FullPrecisionSettings, Recorder},
    };

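    // Loading a record should preserve the target param's own `require_grad` setting
    // (a `no_grad` param stays frozen), not the setting stored in the record.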
    #[test]
    fn test_load_record_setting() {
        let device = Default::default();
        let tensor = Tensor::<TestAutodiffBackend, 2>::ones([3, 3], &device).require_grad();

        let byte_recorder = BinBytesRecorder::<FullPrecisionSettings>::default();
        let bytes = byte_recorder
            .record(
                Param::initialized(ParamId::new(), tensor.clone()).into_record(),
                (),
            )
            .unwrap();

        let no_grad_is_require_grad = Param::initialized(ParamId::new(), tensor.clone())
            .no_grad()
            .load_record(byte_recorder.load(bytes.clone(), &device).unwrap())
            .is_require_grad();

        let with_default_is_require_grad = Param::initialized(ParamId::new(), tensor)
            .load_record(byte_recorder.load(bytes, &device).unwrap())
            .is_require_grad();

        assert!(!no_grad_is_require_grad);
        assert!(with_default_is_require_grad);
    }

    #[test]
    fn test_param_require_grad_stateful() {
        let device = Default::default();
        let tensor = Tensor::<TestAutodiffBackend, 2>::ones([3, 3], &device).require_grad();

        let param = Param::initialized(ParamId::new(), tensor);
        assert!(param.is_require_grad());
        assert!(param.require_grad);

        let param = param.valid();
        assert!(!param.is_require_grad());
        assert!(param.require_grad); // stateful

        // Without `HasAutodiffModule`, we would need to specify the param type as well, which would be annoying:
        // let param: Param<Tensor<TestAutodiffBackend, _>> = param.train();
        let param = param.train::<TestAutodiffBackend>();
        assert!(param.is_require_grad());
        assert!(param.require_grad); // stateful

        let param = param.no_grad();
        assert!(!param.is_require_grad());
        assert!(!param.require_grad); // stateful

        let param = param.valid();
        assert!(!param.is_require_grad()); // always
        assert!(!param.require_grad); // stateful

        let param = param.train::<TestAutodiffBackend>();
        assert!(!param.is_require_grad());
        assert!(!param.require_grad); // stateful
    }
}