burn_core/module/param/tensor.rs

use super::{Param, ParamId, Parameter};
use crate::module::{
    AutodiffModule, Content, Module, ModuleDisplay, ModuleDisplayDefault, ModuleMapper,
    ModuleVisitor,
};
use crate::tensor::{
    Tensor,
    backend::{AutodiffBackend, Backend},
};
use alloc::{format, string::ToString, vec::Vec};
use burn_tensor::{Bool, Float, Int, TensorData, ops::Device};

impl<B: Backend, const D: usize> Parameter for Tensor<B, D, Float> {
    type Device = B::Device;

    fn device(&self) -> Self::Device {
        Tensor::device(self)
    }

    fn is_require_grad(&self) -> bool {
        Tensor::is_require_grad(self)
    }

    fn set_require_grad(self, require_grad: bool) -> Self {
        Tensor::set_require_grad(self, require_grad)
    }
}

impl<B: Backend, const D: usize> Parameter for Tensor<B, D, Int> {
    type Device = B::Device;

    fn device(&self) -> Self::Device {
        Tensor::device(self)
    }

    fn is_require_grad(&self) -> bool {
        false
    }

    fn set_require_grad(self, _require_grad: bool) -> Self {
        self
    }
}

impl<B: Backend, const D: usize> Parameter for Tensor<B, D, Bool> {
    type Device = B::Device;

    fn device(&self) -> Self::Device {
        Tensor::device(self)
    }

    fn is_require_grad(&self) -> bool {
        false
    }

    fn set_require_grad(self, _require_grad: bool) -> Self {
        self
    }
}

impl<B: Backend, const D: usize> Param<Tensor<B, D>> {
    /// Create a new parameter from a float tensor.
    ///
    /// # Warnings
    ///
    /// We strongly recommend using [Param::uninitialized] if you are using this method to
    /// initialize parameters inside a module, since the tensor initialization will be lazy,
    /// making the loading of weights more performant.
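    ///
    /// # Example
    ///
    /// A minimal sketch (assumes an autodiff backend `B` and a `device` already in scope):
    ///
    /// ```rust,ignore
    /// let weight = Tensor::<B, 2>::ones([32, 32], &device);
    /// let param = Param::from_tensor(weight);
    /// // `from_tensor` marks the value as requiring gradients.
    /// assert!(param.is_require_grad());
    /// ```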
    pub fn from_tensor(value: Tensor<B, D>) -> Self {
        // When creating a parameter from a float tensor, we automatically mark it as requiring
        // gradients, so that it can be updated by an optimizer.
        Param::initialized(ParamId::new(), value.require_grad())
    }

    /// The shape of the parameter, **without triggering initialization**.
    ///
    /// This is critical for shape validation during loading: when applying tensors to an
    /// uninitialized parameter, we need to validate the shape without triggering the
    /// initialization function (which would allocate an unnecessary tensor).
    ///
    /// Use this instead of [crate::tensor::Tensor::shape] when you need the shape but want to
    /// preserve lazy initialization.
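    ///
    /// # Example
    ///
    /// A minimal sketch (for an already initialized parameter the result simply matches the
    /// tensor's shape; for an uninitialized one it is read from the pending initialization):
    ///
    /// ```rust,ignore
    /// let param = Param::from_tensor(Tensor::<B, 2>::zeros([8, 4], &device));
    /// assert_eq!(param.lazy_shape().dims, [8, 4]);
    /// ```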
    pub fn lazy_shape(&self) -> burn_tensor::Shape {
        let initialization = match &self.initialization {
            Some(init) => init,
            None => return self.shape(),
        };

        let init = initialization.read().unwrap();

        match init.as_ref() {
            Some(value) => value.shape.clone(),
            None => self.shape(),
        }
    }

    /// Create a new parameter from data.
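    ///
    /// # Example
    ///
    /// A minimal sketch (backend `B` and `device` assumed to be in scope):
    ///
    /// ```rust,ignore
    /// let param = Param::<Tensor<B, 2>>::from_data([[1.0, 2.0], [3.0, 4.0]], &device);
    /// ```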
    pub fn from_data<T>(data: T, device: &B::Device) -> Self
    where
        T: Into<TensorData>,
    {
        // When creating a float parameter from data, we automatically mark it as requiring
        // gradients, so that it can be updated by an optimizer.
        B::memory_persistent_allocations(device, data, |data| {
            let value = Tensor::from_data(data, device);
            Param::initialized(ParamId::new(), value.require_grad())
        })
    }

    /// Transform a parameter for loading by applying load transformations.
    ///
    /// This method is used to restore a parameter from a tensor (typically during deserialization).
    /// It ensures the tensor is moved to the expected device, applies the param mapper's
    /// `on_load` transformation, and preserves the autodiff settings (require_grad).
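    ///
    /// # Example
    ///
    /// A minimal sketch mirroring how [Module::load_record] uses this method (the `record`
    /// value is assumed to be another float `Param` of the same shape):
    ///
    /// ```rust,ignore
    /// let (record_id, record_tensor, _) = record.consume();
    /// let param = param.transform_for_load(record_tensor, record_id);
    /// ```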
    pub fn transform_for_load(self, tensor: Tensor<B, D>, param_id: ParamId) -> Self {
        let mut new_tensor = tensor;

        let mapper = self.param_mapper.clone();

        let expected_device = self.lazy_device();
        let expected_require_grad = self.lazy_is_require_grad();

        // Make sure we load the tensor into the same module device.
        if new_tensor.device() != expected_device {
            new_tensor = new_tensor.to_device(&expected_device).detach();
        }

        new_tensor = mapper.on_load(new_tensor);

        // Make sure we load the tensor with the same autodiff setting.
        new_tensor = new_tensor.set_require_grad(expected_require_grad);

        let mut loaded = Self::initialized(param_id, new_tensor);
        loaded.param_mapper = mapper;
        loaded
    }

    /// Transform a parameter for saving by applying save transformations.
    ///
    /// This method is used to prepare a parameter for saving (typically during serialization).
    /// It applies the param mapper's `on_save` transformation, which can be used
    /// to modify the tensor before serialization (e.g., quantization, precision conversion).
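    ///
    /// # Example
    ///
    /// A minimal sketch mirroring how [Module::into_record] uses this method:
    ///
    /// ```rust,ignore
    /// let record = param.transform_for_save();
    /// ```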
    pub fn transform_for_save(&self) -> Self {
        let mut tensor = self.val();
        let mapper = self.param_mapper.clone();

        tensor = mapper.on_save(tensor);

        Self::initialized(self.id, tensor)
    }
}

impl<B: Backend, const D: usize> Param<Tensor<B, D, Int>> {
    /// The shape of the parameter, **without triggering initialization**.
    ///
    /// This is critical for shape validation during loading: when applying tensors to an
    /// uninitialized parameter, we need to validate the shape without triggering the
    /// initialization function (which would allocate an unnecessary tensor).
    ///
    /// Use this instead of [crate::tensor::Tensor::shape] when you need the shape but want to
    /// preserve lazy initialization.
    pub fn lazy_shape(&self) -> burn_tensor::Shape {
        let initialization = match &self.initialization {
            Some(init) => init,
            None => return self.shape(),
        };

        let init = initialization.read().unwrap();

        match init.as_ref() {
            Some(value) => value.shape.clone(),
            None => self.shape(),
        }
    }

    /// Transform a parameter for loading by applying load transformations.
    ///
    /// This method is used to restore a parameter from a tensor (typically during deserialization).
    /// It ensures the tensor is moved to the expected device and applies the param mapper's
    /// `on_load` transformation.
    pub fn transform_for_load(self, tensor: Tensor<B, D, Int>, param_id: ParamId) -> Self {
        let mut new_tensor = tensor;

        let mapper = self.param_mapper.clone();

        let expected_device = self.lazy_device();

        // Make sure we load the tensor into the same module device.
        if new_tensor.device() != expected_device {
            new_tensor = new_tensor.to_device(&expected_device);
        }

        new_tensor = mapper.on_load(new_tensor);

        let mut loaded = Self::initialized(param_id, new_tensor);
        loaded.param_mapper = mapper;
        loaded
    }

    /// Transform a parameter for saving by applying save transformations.
    ///
    /// This method is used to prepare a parameter for saving (typically during serialization).
    /// It applies the param mapper's `on_save` transformation, which can be used
    /// to modify the tensor before serialization (e.g., quantization, precision conversion).
    pub fn transform_for_save(&self) -> Self {
        let mut tensor = self.val();
        let mapper = self.param_mapper.clone();

        tensor = mapper.on_save(tensor);

        Self::initialized(self.id, tensor)
    }
}

impl<B: Backend, const D: usize> Param<Tensor<B, D, Bool>> {
    /// The shape of the parameter, **without triggering initialization**.
    ///
    /// This is critical for shape validation during loading: when applying tensors to an
    /// uninitialized parameter, we need to validate the shape without triggering the
    /// initialization function (which would allocate an unnecessary tensor).
    ///
    /// **Returns:**
    /// - For uninitialized params: the shape from the `Uninitialized` struct
    /// - For initialized params: the actual shape from the tensor
    ///
    /// Use this instead of [crate::tensor::Tensor::shape] when you need the shape but want to
    /// preserve lazy initialization.
    pub fn lazy_shape(&self) -> burn_tensor::Shape {
        let initialization = match &self.initialization {
            Some(init) => init,
            None => return self.shape(),
        };

        let init = initialization.read().unwrap();

        match init.as_ref() {
            Some(value) => value.shape.clone(),
            None => self.shape(),
        }
    }

    /// Transform a parameter for loading by applying load transformations.
    ///
    /// This method is used to restore a parameter from a tensor (typically during deserialization).
    /// It ensures the tensor is moved to the expected device and applies the param mapper's
    /// `on_load` transformation.
    pub fn transform_for_load(self, tensor: Tensor<B, D, Bool>, param_id: ParamId) -> Self {
        let mut new_tensor = tensor;

        let mapper = self.param_mapper.clone();

        let expected_device = self.lazy_device();

        // Make sure we load the tensor into the same module device.
        if new_tensor.device() != expected_device {
            new_tensor = new_tensor.to_device(&expected_device);
        }

        new_tensor = mapper.on_load(new_tensor);

        let mut loaded = Self::initialized(param_id, new_tensor);
        loaded.param_mapper = mapper;
        loaded
    }

    /// Transform a parameter for saving by applying save transformations.
    ///
    /// This method is used to prepare a parameter for saving (typically during serialization).
    /// It applies the param mapper's `on_save` transformation, which can be used
    /// to modify the tensor before serialization (e.g., quantization, precision conversion).
    pub fn transform_for_save(&self) -> Self {
        let mut tensor = self.val();
        let mapper = self.param_mapper.clone();

        tensor = mapper.on_save(tensor);

        Self::initialized(self.id, tensor)
    }
}

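// The float parameter is itself a module: its record is the parameter, and saving/loading
// goes through `transform_for_save` / `transform_for_load` defined above.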
impl<const D: usize, B: Backend> Module<B> for Param<Tensor<B, D>> {
    type Record = Param<Tensor<B, D>>;

    fn visit<V: ModuleVisitor<B>>(&self, visitor: &mut V) {
        visitor.visit_float(self)
    }

    fn map<M: ModuleMapper<B>>(self, mapper: &mut M) -> Self {
        mapper.map_float(self)
    }

    fn into_record(self) -> Self::Record {
        self.transform_for_save()
    }

    fn load_record(self, record: Self::Record) -> Self {
        let (record_param_id, record_tensor, _) = record.consume();
        self.transform_for_load(record_tensor, record_param_id)
    }

    fn to_device(self, device: &Device<B>) -> Self {
        self.map(|tensor| tensor.to_device(device))
    }

    fn fork(self, device: &Device<B>) -> Self {
        self.map(|tensor| {
            let is_require_grad = tensor.is_require_grad();
            let mut tensor = tensor.to_device(device).detach();

            if is_require_grad {
                tensor = tensor.require_grad();
            }

            tensor
        })
    }

    fn collect_devices(&self, mut devices: Vec<Device<B>>) -> Vec<Device<B>> {
        let device = self.val().device();

        if !devices.contains(&device) {
            devices.push(device)
        }

        devices
    }
}

impl<const D: usize, B: Backend> ModuleDisplayDefault for Param<Tensor<B, D>> {
    fn content(&self, content: Content) -> Option<Content> {
        let id = if content.display_settings.show_param_id() {
            format!(", id: {}", self.id)
        } else {
            "".to_string()
        };
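        // Renders e.g. `ParamTensor {rank: 2, shape: [3, 3], kind: float}`; the `, id: ...`
        // suffix above is appended only when `show_param_id` is enabled.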
        let string = format!(
            "ParamTensor {{rank: {D}, shape: {:?}, kind: float{id}}}",
            self.shape().dims
        );
        content.add_formatted(&string).optional()
    }
}
impl<const D: usize, B: Backend> ModuleDisplay for Param<Tensor<B, D>> {}

impl<const D: usize, B: Backend> Module<B> for Param<Tensor<B, D, Int>> {
    type Record = Param<Tensor<B, D, Int>>;

    fn visit<V: ModuleVisitor<B>>(&self, visitor: &mut V) {
        visitor.visit_int(self)
    }

    fn map<M: ModuleMapper<B>>(self, mapper: &mut M) -> Self {
        mapper.map_int(self)
    }

    fn into_record(self) -> Self::Record {
        self.transform_for_save()
    }

    fn load_record(self, record: Self::Record) -> Self {
        let (record_param_id, record_tensor, _) = record.consume();
        self.transform_for_load(record_tensor, record_param_id)
    }

    fn to_device(self, device: &Device<B>) -> Self {
        self.map(|tensor| tensor.to_device(device))
    }

    fn fork(self, device: &Device<B>) -> Self {
        self.to_device(device) // Don't support autodiff.
    }

    fn collect_devices(&self, mut devices: Vec<Device<B>>) -> Vec<Device<B>> {
        let device = self.val().device();

        if !devices.contains(&device) {
            devices.push(device)
        }

        devices
    }
}

impl<const D: usize, B: Backend> ModuleDisplayDefault for Param<Tensor<B, D, Int>> {
    fn content(&self, content: Content) -> Option<Content> {
        let id = if content.display_settings.show_param_id() {
            format!(", id: {}", self.id)
        } else {
            "".to_string()
        };
        let string = format!(
            "ParamTensor {{rank: {D}, shape: {:?}, kind: int{id}}}",
            self.shape().dims
        );
        content.add_formatted(&string).optional()
    }
}
impl<const D: usize, B: Backend> ModuleDisplay for Param<Tensor<B, D, Int>> {}

impl<const D: usize, B: Backend> Module<B> for Param<Tensor<B, D, Bool>> {
    type Record = Param<Tensor<B, D, Bool>>;

    fn visit<V: ModuleVisitor<B>>(&self, visitor: &mut V) {
        visitor.visit_bool(self)
    }

    fn map<M: ModuleMapper<B>>(self, mapper: &mut M) -> Self {
        mapper.map_bool(self)
    }

    fn into_record(self) -> Self::Record {
        self.transform_for_save()
    }

    fn load_record(self, record: Self::Record) -> Self {
        let (record_param_id, record_tensor, _) = record.consume();
        self.transform_for_load(record_tensor, record_param_id)
    }

    fn to_device(self, device: &Device<B>) -> Self {
        self.map(|tensor| tensor.to_device(device))
    }

    fn fork(self, device: &Device<B>) -> Self {
        self.to_device(device) // Don't support autodiff.
    }

    fn collect_devices(&self, mut devices: Vec<Device<B>>) -> Vec<Device<B>> {
        let device = self.val().device();

        if !devices.contains(&device) {
            devices.push(device)
        }

        devices
    }
}

impl<const D: usize, B: Backend> ModuleDisplayDefault for Param<Tensor<B, D, Bool>> {
    fn content(&self, content: Content) -> Option<Content> {
        let id = if content.display_settings.show_param_id() {
            format!(", id: {}", self.id)
        } else {
            "".to_string()
        };

        let string = format!(
            "ParamTensor {{rank: {D}, shape: {:?}, kind: bool{id}}}",
            self.shape().dims
        );
        content.add_formatted(&string).optional()
    }
}

impl<const D: usize, B: Backend> ModuleDisplay for Param<Tensor<B, D, Bool>> {}

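// `valid` maps the parameter onto the inner (non-autodiff) backend: float tensors have
// gradient tracking explicitly disabled, while int and bool tensors never track gradients.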
impl<const D: usize, B: AutodiffBackend> AutodiffModule<B> for Param<Tensor<B, D>> {
    type InnerModule = Param<Tensor<B::InnerBackend, D>>;

    fn valid(&self) -> Self::InnerModule {
        Param::initialized(self.id, self.val().inner().set_require_grad(false))
    }
}

impl<const D: usize, B: AutodiffBackend> AutodiffModule<B> for Param<Tensor<B, D, Int>> {
    type InnerModule = Param<Tensor<B::InnerBackend, D, Int>>;

    fn valid(&self) -> Self::InnerModule {
        Param::initialized(self.id, self.val().inner())
    }
}

impl<const D: usize, B: AutodiffBackend> AutodiffModule<B> for Param<Tensor<B, D, Bool>> {
    type InnerModule = Param<Tensor<B::InnerBackend, D, Bool>>;

    fn valid(&self) -> Self::InnerModule {
        Param::initialized(self.id, self.val().inner())
    }
}

#[cfg(all(test, feature = "std"))]
mod tests {
    use super::*;
    use crate::{
        TestAutodiffBackend,
        module::Module,
        record::{BinBytesRecorder, FullPrecisionSettings, Recorder},
    };

    #[test]
    fn test_load_record_setting() {
        let device = Default::default();
        let tensor = Tensor::<TestAutodiffBackend, 2>::ones([3, 3], &device).require_grad();

        let byte_recorder = BinBytesRecorder::<FullPrecisionSettings>::default();
        let bytes = byte_recorder
            .record(
                Param::initialized(ParamId::new(), tensor.clone()).into_record(),
                (),
            )
            .unwrap();

        let no_grad_is_require_grad = Param::initialized(ParamId::new(), tensor.clone())
            .no_grad()
            .load_record(byte_recorder.load(bytes.clone(), &device).unwrap())
            .is_require_grad();

        let with_default_is_require_grad = Param::initialized(ParamId::new(), tensor)
            .load_record(byte_recorder.load(bytes, &device).unwrap())
            .is_require_grad();

        assert!(!no_grad_is_require_grad);
        assert!(with_default_is_require_grad);
    }
}