1use super::{Param, ParamId, Parameter};
2use crate::module::{
3 AutodiffModule, Content, Module, ModuleDisplay, ModuleDisplayDefault, ModuleMapper,
4 ModuleVisitor,
5};
6use crate::tensor::{
7 Tensor,
8 backend::{AutodiffBackend, Backend},
9};
10use alloc::{format, string::ToString, vec::Vec};
11use burn_tensor::{Bool, Float, Int, TensorData, ops::Device};
12
/// Float tensors are the trainable parameter kind: they expose their real
/// autodiff state and allow toggling gradient tracking.
impl<B: Backend, const D: usize> Parameter for Tensor<B, D, Float> {
    type Device = B::Device;

    fn device(&self) -> Self::Device {
        // Fully-qualified call to the inherent `Tensor::device`; the trait
        // method has the same name, so UFCS avoids any ambiguity/recursion.
        Tensor::device(self)
    }

    fn is_require_grad(&self) -> bool {
        // Delegates to the inherent method (same-name disambiguation as above).
        Tensor::is_require_grad(self)
    }

    fn set_require_grad(self, require_grad: bool) -> Self {
        // Delegates to the inherent method (same-name disambiguation as above).
        Tensor::set_require_grad(self, require_grad)
    }
}
28
29impl<B: Backend, const D: usize> Parameter for Tensor<B, D, Int> {
30 type Device = B::Device;
31
32 fn device(&self) -> Self::Device {
33 Tensor::device(self)
34 }
35
36 fn is_require_grad(&self) -> bool {
37 false
38 }
39
40 fn set_require_grad(self, _require_grad: bool) -> Self {
41 self
42 }
43}
44
45impl<B: Backend, const D: usize> Parameter for Tensor<B, D, Bool> {
46 type Device = B::Device;
47
48 fn device(&self) -> Self::Device {
49 Tensor::device(self)
50 }
51
52 fn is_require_grad(&self) -> bool {
53 false
54 }
55
56 fn set_require_grad(self, _require_grad: bool) -> Self {
57 self
58 }
59}
60
61impl<B: Backend, const D: usize> Param<Tensor<B, D>> {
62 pub fn from_tensor(value: Tensor<B, D>) -> Self {
70 Param::initialized(ParamId::new(), value.require_grad())
73 }
74
75 pub fn from_data<T>(data: T, device: &B::Device) -> Self
77 where
78 T: Into<TensorData>,
79 {
80 let value = Tensor::from_data(data, device);
83 Param::initialized(ParamId::new(), value.require_grad())
84 }
85}
86
/// Float parameters are the trainable case: records restore device placement
/// and gradient-tracking state, and `fork` preserves `require_grad`.
impl<const D: usize, B: Backend> Module<B> for Param<Tensor<B, D>> {
    type Record = Param<Tensor<B, D>>;

    fn visit<V: ModuleVisitor<B>>(&self, visitor: &mut V) {
        visitor.visit_float(self.id, &self.val())
    }

    fn map<M: ModuleMapper<B>>(self, mapper: &mut M) -> Self {
        // Consume the param so the mapper receives an owned tensor, then
        // rebuild the param under the same id.
        let (id, tensor) = self.consume();
        let value = mapper.map_float(id, tensor);
        Self::initialized(id, value)
    }

    fn into_record(self) -> Self::Record {
        self
    }

    fn load_record(self, record: Self::Record) -> Self {
        let (new_id, mut new_value) = record.consume();

        // `lazy_*` accessors read the expected state without forcing
        // initialization of the current (possibly lazy) value.
        let expected_device = self.lazy_device();
        let expected_require_grad = self.lazy_is_require_grad();

        // Make sure the loaded value lives on the module's device. The move
        // is detached first; gradient tracking is re-applied just below.
        if new_value.device() != expected_device {
            new_value = new_value.to_device(&expected_device).detach();
        }

        // Restore the module's autodiff setting regardless of what the
        // record carried. Order matters: this must come after the device move.
        new_value = new_value.set_require_grad(expected_require_grad);

        Self::initialized(new_id, new_value)
    }

    fn to_device(self, device: &Device<B>) -> Self {
        // Note: this is `Param::map` (closure over the value), not `Module::map`.
        self.map(|tensor| tensor.to_device(device))
    }

    fn fork(self, device: &Device<B>) -> Self {
        self.map(|tensor| {
            // Detach from the current autodiff graph while moving devices,
            // then re-enable gradient tracking if it was on before.
            let is_require_grad = tensor.is_require_grad();
            let mut tensor = tensor.to_device(device).detach();

            if is_require_grad {
                tensor = tensor.require_grad();
            }

            tensor
        })
    }

    fn collect_devices(&self, mut devices: Vec<Device<B>>) -> Vec<Device<B>> {
        let device = self.val().device();

        // Deduplicate: only record a device once across the module tree.
        if !devices.contains(&device) {
            devices.push(device)
        }

        devices
    }
}
148
149impl<const D: usize, B: Backend> ModuleDisplayDefault for Param<Tensor<B, D>> {
150 fn content(&self, content: Content) -> Option<Content> {
151 let id = if content.display_settings.show_param_id() {
152 format!(", id: {}", self.id)
153 } else {
154 "".to_string()
155 };
156 let string = format!(
157 "ParamTensor {{rank: {D}, shape: {:?}, kind: float{id}}}",
158 self.shape().dims
159 );
160 content.add_formatted(&string).optional()
161 }
162}
163impl<const D: usize, B: Backend> ModuleDisplay for Param<Tensor<B, D>> {}
164
165impl<const D: usize, B: Backend> Module<B> for Param<Tensor<B, D, Int>> {
166 type Record = Param<Tensor<B, D, Int>>;
167
168 fn visit<V: ModuleVisitor<B>>(&self, visitor: &mut V) {
169 visitor.visit_int(self.id, &self.val())
170 }
171
172 fn map<M: ModuleMapper<B>>(self, mapper: &mut M) -> Self {
173 let value = mapper.map_int(self.id, self.val());
174 Self::initialized(self.id, value)
175 }
176
177 fn into_record(self) -> Self::Record {
178 self
179 }
180
181 fn load_record(self, record: Self::Record) -> Self {
182 let (new_id, mut new_value) = record.consume();
183
184 let expected_device = self.lazy_device();
185
186 if new_value.device() != expected_device {
188 new_value = new_value.to_device(&expected_device);
189 }
190
191 Self::initialized(new_id, new_value)
192 }
193
194 fn to_device(self, device: &Device<B>) -> Self {
195 self.map(|tensor| tensor.to_device(device))
196 }
197
198 fn fork(self, device: &Device<B>) -> Self {
199 self.to_device(device) }
201
202 fn collect_devices(&self, mut devices: Vec<Device<B>>) -> Vec<Device<B>> {
203 let device = self.val().device();
204
205 if !devices.contains(&device) {
206 devices.push(device)
207 }
208
209 devices
210 }
211}
212
213impl<const D: usize, B: Backend> ModuleDisplayDefault for Param<Tensor<B, D, Int>> {
214 fn content(&self, content: Content) -> Option<Content> {
215 let id = if content.display_settings.show_param_id() {
216 format!(", id: {}", self.id)
217 } else {
218 "".to_string()
219 };
220 let string = format!(
221 "ParamTensor {{rank: {D}, shape: {:?}, kind: int{id}}}",
222 self.shape().dims
223 );
224 content.add_formatted(&string).optional()
225 }
226}
227impl<const D: usize, B: Backend> ModuleDisplay for Param<Tensor<B, D, Int>> {}
228
229impl<const D: usize, B: Backend> Module<B> for Param<Tensor<B, D, Bool>> {
230 type Record = Param<Tensor<B, D, Bool>>;
231
232 fn visit<V: ModuleVisitor<B>>(&self, visitor: &mut V) {
233 visitor.visit_bool(self.id, &self.val())
234 }
235
236 fn map<M: ModuleMapper<B>>(self, mapper: &mut M) -> Self {
237 let value = mapper.map_bool(self.id, self.val());
238 Self::initialized(self.id, value)
239 }
240
241 fn into_record(self) -> Self::Record {
242 self
243 }
244
245 fn load_record(self, record: Self::Record) -> Self {
246 let (new_id, mut new_value) = record.consume();
247
248 let expected_device = self.lazy_device();
249
250 if new_value.device() != expected_device {
252 new_value = new_value.to_device(&expected_device);
253 }
254
255 Self::initialized(new_id, new_value)
256 }
257
258 fn to_device(self, device: &Device<B>) -> Self {
259 self.map(|tensor| tensor.to_device(device))
260 }
261
262 fn fork(self, device: &Device<B>) -> Self {
263 self.to_device(device) }
265
266 fn collect_devices(&self, mut devices: Vec<Device<B>>) -> Vec<Device<B>> {
267 let device = self.val().device();
268
269 if !devices.contains(&device) {
270 devices.push(device)
271 }
272
273 devices
274 }
275}
276
277impl<const D: usize, B: Backend> ModuleDisplayDefault for Param<Tensor<B, D, Bool>> {
278 fn content(&self, content: Content) -> Option<Content> {
279 let id = if content.display_settings.show_param_id() {
280 format!(", id: {}", self.id)
281 } else {
282 "".to_string()
283 };
284
285 let string = format!(
286 "ParamTensor {{rank: {D}, shape: {:?}, kind: bool{id}}}",
287 self.shape().dims
288 );
289 content.add_formatted(&string).optional()
290 }
291}
292
293impl<const D: usize, B: Backend> ModuleDisplay for Param<Tensor<B, D, Bool>> {}
294
295impl<const D: usize, B: AutodiffBackend> AutodiffModule<B> for Param<Tensor<B, D>> {
296 type InnerModule = Param<Tensor<B::InnerBackend, D>>;
297
298 fn valid(&self) -> Self::InnerModule {
299 Param::initialized(self.id, self.val().inner().set_require_grad(false))
300 }
301}
302
303impl<const D: usize, B: AutodiffBackend> AutodiffModule<B> for Param<Tensor<B, D, Int>> {
304 type InnerModule = Param<Tensor<B::InnerBackend, D, Int>>;
305
306 fn valid(&self) -> Self::InnerModule {
307 Param::initialized(self.id, self.val().inner())
308 }
309}
310
311impl<const D: usize, B: AutodiffBackend> AutodiffModule<B> for Param<Tensor<B, D, Bool>> {
312 type InnerModule = Param<Tensor<B::InnerBackend, D, Bool>>;
313
314 fn valid(&self) -> Self::InnerModule {
315 Param::initialized(self.id, self.val().inner())
316 }
317}
318
#[cfg(all(test, feature = "std"))]
mod tests {
    use super::*;
    use crate::{
        TestAutodiffBackend,
        module::Module,
        record::{BinBytesRecorder, FullPrecisionSettings, Recorder},
    };

    /// Verifies that `load_record` restores the *module's* autodiff setting
    /// rather than the one stored in the record: a `no_grad` param stays
    /// non-tracking after loading, while a default (tracking) param keeps
    /// requiring gradients.
    #[test]
    fn test_load_record_setting() {
        let device = Default::default();
        let tensor = Tensor::<TestAutodiffBackend, 2>::ones([3, 3], &device).require_grad();

        // Round-trip the param through an in-memory binary recorder.
        let byte_recorder = BinBytesRecorder::<FullPrecisionSettings>::default();
        let bytes = byte_recorder
            .record(
                Param::initialized(ParamId::new(), tensor.clone()).into_record(),
                (),
            )
            .unwrap();

        // Loading into a `no_grad` param must not re-enable gradient tracking,
        // even though the recorded tensor required gradients.
        let no_grad_is_require_grad = Param::initialized(ParamId::new(), tensor.clone())
            .no_grad()
            .load_record(byte_recorder.load(bytes.clone(), &device).unwrap())
            .is_require_grad();

        // Loading into a default (tracking) param must keep gradients enabled.
        let with_default_is_require_grad = Param::initialized(ParamId::new(), tensor)
            .load_record(byte_recorder.load(bytes, &device).unwrap())
            .is_require_grad();

        assert!(!no_grad_is_require_grad);
        assert!(with_default_is_require_grad);
    }
}