// burn_core/module/param/tensor.rs

use super::{Param, ParamId, Parameter};
use crate::module::{
    AutodiffModule, Content, Module, ModuleDisplay, ModuleDisplayDefault, ModuleMapper,
    ModuleVisitor,
};
use crate::tensor::{
    Tensor,
    backend::{AutodiffBackend, Backend},
};
use alloc::{format, string::ToString, vec::Vec};
use burn_tensor::{Bool, Float, Int, TensorData, ops::Device};

impl<B: Backend, const D: usize> Parameter for Tensor<B, D, Float> {
    type Device = B::Device;

    fn device(&self) -> Self::Device {
        Tensor::device(self)
    }

    fn is_require_grad(&self) -> bool {
        Tensor::is_require_grad(self)
    }

    fn set_require_grad(self, require_grad: bool) -> Self {
        Tensor::set_require_grad(self, require_grad)
    }
}

impl<B: Backend, const D: usize> Parameter for Tensor<B, D, Int> {
    type Device = B::Device;

    fn device(&self) -> Self::Device {
        Tensor::device(self)
    }

    fn is_require_grad(&self) -> bool {
        // Int tensors are not differentiable, so they never track gradients.
        false
    }

    fn set_require_grad(self, _require_grad: bool) -> Self {
        self
    }
}

impl<B: Backend, const D: usize> Parameter for Tensor<B, D, Bool> {
    type Device = B::Device;

    fn device(&self) -> Self::Device {
        Tensor::device(self)
    }

    fn is_require_grad(&self) -> bool {
        // Bool tensors are not differentiable, so they never track gradients.
        false
    }

    fn set_require_grad(self, _require_grad: bool) -> Self {
        self
    }
}

impl<B: Backend, const D: usize> Param<Tensor<B, D>> {
    /// Create a new parameter from a float tensor.
    ///
    /// # Warnings
    ///
    /// We strongly recommend using [Param::uninitialized] if you are using this method to
    /// initialize parameters inside a module, since the tensor initialization is then lazy,
    /// which makes loading weights more performant.
    pub fn from_tensor(value: Tensor<B, D>) -> Self {
        // When creating a parameter from a float tensor, we automatically mark it as requiring
        // gradients, so that it can be updated by an optimizer.
        Param::initialized(ParamId::new(), value.require_grad())
    }

    /// Create a new parameter from data.
    pub fn from_data<T>(data: T, device: &B::Device) -> Self
    where
        T: Into<TensorData>,
    {
        // When creating a float parameter from data, we automatically mark it as requiring
        // gradients, so that it can be updated by an optimizer.
        let value = Tensor::from_data(data, device);
        Param::initialized(ParamId::new(), value.require_grad())
    }
}
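
// Example (a minimal sketch, not part of this file's API): constructing a
// float parameter eagerly. `B` stands for any `Backend` and the data is an
// illustrative assumption.
//
//     let device = Default::default();
//     let weight: Param<Tensor<B, 2>> = Param::from_data([[0.0, 1.0], [2.0, 3.0]], &device);
//     assert!(weight.is_require_grad()); // float params track gradients by default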

impl<const D: usize, B: Backend> Module<B> for Param<Tensor<B, D>> {
    type Record = Param<Tensor<B, D>>;

    fn visit<V: ModuleVisitor<B>>(&self, visitor: &mut V) {
        visitor.visit_float(self.id, &self.val())
    }

    fn map<M: ModuleMapper<B>>(self, mapper: &mut M) -> Self {
        let (id, tensor) = self.consume();
        let value = mapper.map_float(id, tensor);
        Self::initialized(id, value)
    }

    fn into_record(self) -> Self::Record {
        self
    }

    fn load_record(self, record: Self::Record) -> Self {
        let (new_id, mut new_value) = record.consume();

        let expected_device = self.lazy_device();
        let expected_require_grad = self.lazy_is_require_grad();

        // Make sure we load the record into the same module device.
        if new_value.device() != expected_device {
            new_value = new_value.to_device(&expected_device).detach();
        }

        // Make sure we load the record with the same autodiff setting.
        new_value = new_value.set_require_grad(expected_require_grad);

        Self::initialized(new_id, new_value)
    }

    fn to_device(self, device: &Device<B>) -> Self {
        self.map(|tensor| tensor.to_device(device))
    }

    fn fork(self, device: &Device<B>) -> Self {
        self.map(|tensor| {
            // Forking detaches the tensor from the current autodiff graph
            // before moving it, then restores gradient tracking if it was set.
            let is_require_grad = tensor.is_require_grad();
            let mut tensor = tensor.to_device(device).detach();

            if is_require_grad {
                tensor = tensor.require_grad();
            }

            tensor
        })
    }

    fn collect_devices(&self, mut devices: Vec<Device<B>>) -> Vec<Device<B>> {
        let device = self.val().device();

        if !devices.contains(&device) {
            devices.push(device)
        }

        devices
    }
}
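
// Sketch of how `visit` is consumed downstream: a hypothetical visitor that
// counts the scalar elements of every float parameter. `ParamCounter` is
// illustrative only, and this assumes the remaining `ModuleVisitor` methods
// have default implementations.
//
//     struct ParamCounter {
//         count: usize,
//     }
//
//     impl<B: Backend> ModuleVisitor<B> for ParamCounter {
//         fn visit_float<const D: usize>(&mut self, _id: ParamId, tensor: &Tensor<B, D>) {
//             self.count += tensor.shape().num_elements();
//         }
//     }
//
//     let mut counter = ParamCounter { count: 0 };
//     module.visit(&mut counter);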

impl<const D: usize, B: Backend> ModuleDisplayDefault for Param<Tensor<B, D>> {
    fn content(&self, content: Content) -> Option<Content> {
        let id = if content.display_settings.show_param_id() {
            format!(", id: {}", self.id)
        } else {
            "".to_string()
        };
        let string = format!(
            "ParamTensor {{rank: {D}, shape: {:?}, kind: float{id}}}",
            self.shape().dims
        );
        content.add_formatted(&string).optional()
    }
}

impl<const D: usize, B: Backend> ModuleDisplay for Param<Tensor<B, D>> {}
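
// Given the format string above, the rendered content reads, e.g.:
//     ParamTensor {rank: 2, shape: [3, 3], kind: float, id: <param id>}
// with the `id` segment present only when `show_param_id()` is enabled.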

impl<const D: usize, B: Backend> Module<B> for Param<Tensor<B, D, Int>> {
    type Record = Param<Tensor<B, D, Int>>;

    fn visit<V: ModuleVisitor<B>>(&self, visitor: &mut V) {
        visitor.visit_int(self.id, &self.val())
    }

    fn map<M: ModuleMapper<B>>(self, mapper: &mut M) -> Self {
        let value = mapper.map_int(self.id, self.val());
        Self::initialized(self.id, value)
    }

    fn into_record(self) -> Self::Record {
        self
    }

    fn load_record(self, record: Self::Record) -> Self {
        let (new_id, mut new_value) = record.consume();

        let expected_device = self.lazy_device();

        // Make sure we load the record into the same module device.
        if new_value.device() != expected_device {
            new_value = new_value.to_device(&expected_device);
        }

        Self::initialized(new_id, new_value)
    }

    fn to_device(self, device: &Device<B>) -> Self {
        self.map(|tensor| tensor.to_device(device))
    }

    fn fork(self, device: &Device<B>) -> Self {
        self.to_device(device) // Int tensors don't support autodiff.
    }

    fn collect_devices(&self, mut devices: Vec<Device<B>>) -> Vec<Device<B>> {
        let device = self.val().device();

        if !devices.contains(&device) {
            devices.push(device)
        }

        devices
    }
}
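
// Sketch of the `ModuleMapper` counterpart (illustrative only; assumes the
// remaining mapper methods have default pass-through implementations): a
// hypothetical mapper that clamps every int parameter in place.
//
//     struct ClampMapper;
//
//     impl<B: Backend> ModuleMapper<B> for ClampMapper {
//         fn map_int<const D: usize>(
//             &mut self,
//             _id: ParamId,
//             tensor: Tensor<B, D, Int>,
//         ) -> Tensor<B, D, Int> {
//             tensor.clamp(0, 255)
//         }
//     }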

impl<const D: usize, B: Backend> ModuleDisplayDefault for Param<Tensor<B, D, Int>> {
    fn content(&self, content: Content) -> Option<Content> {
        let id = if content.display_settings.show_param_id() {
            format!(", id: {}", self.id)
        } else {
            "".to_string()
        };
        let string = format!(
            "ParamTensor {{rank: {D}, shape: {:?}, kind: int{id}}}",
            self.shape().dims
        );
        content.add_formatted(&string).optional()
    }
}

impl<const D: usize, B: Backend> ModuleDisplay for Param<Tensor<B, D, Int>> {}

impl<const D: usize, B: Backend> Module<B> for Param<Tensor<B, D, Bool>> {
    type Record = Param<Tensor<B, D, Bool>>;

    fn visit<V: ModuleVisitor<B>>(&self, visitor: &mut V) {
        visitor.visit_bool(self.id, &self.val())
    }

    fn map<M: ModuleMapper<B>>(self, mapper: &mut M) -> Self {
        let value = mapper.map_bool(self.id, self.val());
        Self::initialized(self.id, value)
    }

    fn into_record(self) -> Self::Record {
        self
    }

    fn load_record(self, record: Self::Record) -> Self {
        let (new_id, mut new_value) = record.consume();

        let expected_device = self.lazy_device();

        // Make sure we load the record into the same module device.
        if new_value.device() != expected_device {
            new_value = new_value.to_device(&expected_device);
        }

        Self::initialized(new_id, new_value)
    }

    fn to_device(self, device: &Device<B>) -> Self {
        self.map(|tensor| tensor.to_device(device))
    }

    fn fork(self, device: &Device<B>) -> Self {
        self.to_device(device) // Bool tensors don't support autodiff.
    }

    fn collect_devices(&self, mut devices: Vec<Device<B>>) -> Vec<Device<B>> {
        let device = self.val().device();

        if !devices.contains(&device) {
            devices.push(device)
        }

        devices
    }
}

impl<const D: usize, B: Backend> ModuleDisplayDefault for Param<Tensor<B, D, Bool>> {
    fn content(&self, content: Content) -> Option<Content> {
        let id = if content.display_settings.show_param_id() {
            format!(", id: {}", self.id)
        } else {
            "".to_string()
        };

        let string = format!(
            "ParamTensor {{rank: {D}, shape: {:?}, kind: bool{id}}}",
            self.shape().dims
        );
        content.add_formatted(&string).optional()
    }
}

impl<const D: usize, B: Backend> ModuleDisplay for Param<Tensor<B, D, Bool>> {}

impl<const D: usize, B: AutodiffBackend> AutodiffModule<B> for Param<Tensor<B, D>> {
    type InnerModule = Param<Tensor<B::InnerBackend, D>>;

    fn valid(&self) -> Self::InnerModule {
        // The valid (inference) view runs on the inner backend with gradient
        // tracking disabled.
        Param::initialized(self.id, self.val().inner().set_require_grad(false))
    }
}

impl<const D: usize, B: AutodiffBackend> AutodiffModule<B> for Param<Tensor<B, D, Int>> {
    type InnerModule = Param<Tensor<B::InnerBackend, D, Int>>;

    fn valid(&self) -> Self::InnerModule {
        Param::initialized(self.id, self.val().inner())
    }
}

impl<const D: usize, B: AutodiffBackend> AutodiffModule<B> for Param<Tensor<B, D, Bool>> {
    type InnerModule = Param<Tensor<B::InnerBackend, D, Bool>>;

    fn valid(&self) -> Self::InnerModule {
        Param::initialized(self.id, self.val().inner())
    }
}
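
// Sketch (illustrative only; `B` is an assumed `AutodiffBackend`): obtaining
// an inference-time view of a float parameter via `valid`.
//
//     let param: Param<Tensor<B, 2>> = Param::from_tensor(tensor);
//     let inner: Param<Tensor<B::InnerBackend, 2>> = param.valid();
//     assert!(!inner.is_require_grad());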

#[cfg(all(test, feature = "std"))]
mod tests {
    use super::*;
    use crate::{
        TestAutodiffBackend,
        module::Module,
        record::{BinBytesRecorder, FullPrecisionSettings, Recorder},
    };

    #[test]
    fn test_load_record_setting() {
        let device = Default::default();
        let tensor = Tensor::<TestAutodiffBackend, 2>::ones([3, 3], &device).require_grad();

        let byte_recorder = BinBytesRecorder::<FullPrecisionSettings>::default();
        let bytes = byte_recorder
            .record(
                Param::initialized(ParamId::new(), tensor.clone()).into_record(),
                (),
            )
            .unwrap();

        let no_grad_is_require_grad = Param::initialized(ParamId::new(), tensor.clone())
            .no_grad()
            .load_record(byte_recorder.load(bytes.clone(), &device).unwrap())
            .is_require_grad();

        let with_default_is_require_grad = Param::initialized(ParamId::new(), tensor)
            .load_record(byte_recorder.load(bytes, &device).unwrap())
            .is_require_grad();

        // Loading a record must respect the module's own autodiff setting,
        // not the record's.
        assert!(!no_grad_is_require_grad);
        assert!(with_default_is_require_grad);
    }
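
    #[test]
    fn test_from_tensor_requires_grad() {
        // Added check (a sketch grounded in the `from_tensor` docs above):
        // float parameters are marked as requiring gradients on construction.
        let device = Default::default();
        let tensor = Tensor::<TestAutodiffBackend, 2>::ones([2, 2], &device);

        let param = Param::from_tensor(tensor);

        assert!(param.is_require_grad());
    }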
}