1use burn::{
2 prelude::Backend,
3 tensor::{Float, Int, Tensor, TensorKind},
4};
5use nalgebra as na;
6use ndarray as nd;
7
/// Copies a float tensor out of its backend into a flat `Vec<f32>`, in the element order of
/// the tensor's `TensorData`.
///
/// The fast path reads the data directly as `f32`. If the backend stores floats with a
/// different element type, the direct read fails and the data is converted to `f32` first,
/// instead of panicking.
///
/// # Panics
/// Panics if the data, after conversion to `f32`, still cannot be read as `f32`.
#[cfg(target_arch = "wasm32")]
pub fn tensor_to_data_float<B: Backend, const D: usize>(tensor: &Tensor<B, D, Float>) -> Vec<f32> {
    let data = tensor.to_data();
    match data.to_vec::<f32>() {
        Ok(values) => values,
        // Element type mismatch: convert the already-fetched data rather than re-reading it.
        Err(_) => data
            .convert::<f32>()
            .to_vec::<f32>()
            .expect("tensor data converted to f32 should be readable as f32"),
    }
}
16
17#[cfg(not(target_arch = "wasm32"))]
19pub fn tensor_to_data_float<B: Backend, const D: usize>(tensor: &Tensor<B, D, Float>) -> Vec<f32> {
20 tensor.to_data().to_vec::<f32>().unwrap()
21}
22
/// Copies an integer tensor out of its backend into a flat `Vec<i32>`.
///
/// The data is fetched from the backend once. It is read directly as `i32` when possible.
/// Otherwise it is read as `i64` and narrowed with `as i32`, so values outside the `i32`
/// range wrap.
///
/// # Panics
/// Panics if the tensor data can be read neither as `i32` nor as `i64`.
#[cfg(target_arch = "wasm32")]
#[allow(clippy::cast_possible_truncation)] // i64 -> i32 narrowing is intentional (see doc above)
pub fn tensor_to_data_int<B: Backend, const D: usize>(tensor: &Tensor<B, D, Int>) -> Vec<i32> {
    // Fetch once and reuse for the fallback: each `to_data` call copies the tensor out of the backend.
    let data = tensor.to_data();
    match data.to_vec::<i32>() {
        Ok(values) => values,
        Err(_) => data
            .to_vec::<i64>()
            .unwrap()
            .into_iter()
            .map(|x| x as i32)
            .collect(),
    }
}
36
37#[cfg(not(target_arch = "wasm32"))]
39#[allow(clippy::cast_possible_truncation)]
40pub fn tensor_to_data_int<B: Backend, const D: usize>(tensor: &Tensor<B, D, Int>) -> Vec<i32> {
41 if let Ok(data) = tensor.to_data().to_vec::<i32>() {
43 return data;
44 }
45
46 let data_i64: Vec<i64> = tensor.to_data().to_vec::<i64>().unwrap();
48 data_i64.into_iter().map(|x| x as i32).collect()
49}
50
/// Conversion of a host-side array or matrix into a Burn tensor of rank `D` and kind `T`
/// (e.g. `Float` or `Int`) on backend `B`.
pub trait ToBurn<B: Backend, const D: usize, T: TensorKind<B>> {
    /// Builds a tensor on `device` from a borrowed value; `self` remains usable afterwards.
    fn to_burn(&self, device: &B::Device) -> Tensor<B, D, T>;
    /// Builds a tensor on `device`, consuming `self`.
    fn into_burn(self, device: &B::Device) -> Tensor<B, D, T>;
}
59
60impl<B: Backend> ToBurn<B, 2, Float> for nd::Array2<f32> {
62 fn to_burn(&self, device: &B::Device) -> Tensor<B, 2, Float> {
63 let vec: Vec<f32>;
64 let bytes = if self.is_standard_layout() {
65 self.as_slice().unwrap()
66 } else {
67 vec = self.iter().copied().collect();
68 vec.as_slice()
69 };
70 let shape = [self.nrows(), self.ncols()];
71 Tensor::<B, 1, Float>::from_floats(bytes, device).reshape(shape)
72 }
73
74 fn into_burn(self, device: &B::Device) -> Tensor<B, 2, Float> {
75 let vec: Vec<f32>;
76 let bytes = if self.is_standard_layout() {
77 self.as_slice().expect("Array should have a slice if it's in standard layout")
78 } else {
79 vec = self.iter().copied().collect();
80 vec.as_slice()
81 };
82 let shape = [self.nrows(), self.ncols()];
83 Tensor::<B, 1, Float>::from_floats(bytes, device).reshape(shape)
84 }
85}
86
87impl<B: Backend> ToBurn<B, 1, Float> for nd::Array1<f32> {
89 fn to_burn(&self, device: &B::Device) -> Tensor<B, 1, Float> {
90 let vec: Vec<f32> = self.iter().copied().collect();
91 Tensor::<B, 1, Float>::from_floats(&vec[..], device)
92 }
93
94 fn into_burn(self, device: &B::Device) -> Tensor<B, 1, Float> {
95 let vec: Vec<f32>;
96 let bytes = if self.is_standard_layout() {
97 self.as_slice().expect("Array should have a slice if it's in standard layout")
98 } else {
99 vec = self.iter().copied().collect();
100 vec.as_slice()
101 };
102 Tensor::<B, 1, Float>::from_floats(bytes, device)
103 }
104}
105
106impl<B: Backend> ToBurn<B, 2, Int> for nd::Array2<u32> {
108 #[allow(clippy::cast_possible_wrap)]
109 fn to_burn(&self, device: &B::Device) -> Tensor<B, 2, Int> {
110 let array_i32 = self.mapv(|x| x as i32);
111 let vec: Vec<i32> = array_i32.into_raw_vec_and_offset().0;
112 let shape = [self.nrows(), self.ncols()];
114 Tensor::<B, 1, Int>::from_ints(&vec[..], device).reshape(shape)
115 }
116
117 #[allow(clippy::cast_possible_wrap)]
118 fn into_burn(self, device: &B::Device) -> Tensor<B, 2, Int> {
119 let array_i32 = self.mapv(|x| x as i32);
120 let vec: Vec<i32> = array_i32.into_raw_vec_and_offset().0;
121 let shape = [self.nrows(), self.ncols()];
123 Tensor::<B, 1, Int>::from_ints(&vec[..], device).reshape(shape)
124 }
125}
126
127impl<B: Backend> ToBurn<B, 1, Int> for nd::Array1<u32> {
129 #[allow(clippy::cast_possible_wrap)]
130 fn to_burn(&self, device: &B::Device) -> Tensor<B, 1, Int> {
131 let array_i32 = self.mapv(|x| x as i32);
132 let vec: Vec<i32> = array_i32.into_raw_vec_and_offset().0;
133 Tensor::<B, 1, Int>::from_ints(&vec[..], device)
135 }
136
137 #[allow(clippy::cast_possible_wrap)]
138 fn into_burn(self, device: &B::Device) -> Tensor<B, 1, Int> {
139 let array_i32 = self.mapv(|x| x as i32);
140 let vec: Vec<i32> = array_i32.into_raw_vec_and_offset().0;
141 Tensor::<B, 1, Int>::from_ints(&vec[..], device)
143 }
144}
145impl<B: Backend> ToBurn<B, 3, Float> for nd::Array3<f32> {
146 fn to_burn(&self, device: &B::Device) -> Tensor<B, 3, Float> {
147 let vec: Vec<f32>;
148 let bytes = if self.is_standard_layout() {
149 self.as_slice().unwrap()
150 } else {
151 vec = self.iter().copied().collect();
152 vec.as_slice()
153 };
154 let shape = [self.shape()[0], self.shape()[1], self.shape()[2]];
155 Tensor::<B, 1, Float>::from_floats(bytes, device).reshape(shape)
156 }
157
158 fn into_burn(self, device: &B::Device) -> Tensor<B, 3, Float> {
159 let vec: Vec<f32>;
160 let bytes = if self.is_standard_layout() {
161 self.as_slice().expect("Array should have a slice if it's in standard layout")
162 } else {
163 vec = self.iter().copied().collect();
164 vec.as_slice()
165 };
166 let shape = [self.shape()[0], self.shape()[1], self.shape()[2]];
167 Tensor::<B, 1, Float>::from_floats(bytes, device).reshape(shape)
168 }
169}
170
171impl<B: Backend> ToBurn<B, 3, Int> for nd::Array3<u32> {
173 #[allow(clippy::cast_possible_wrap)]
174 fn to_burn(&self, device: &B::Device) -> Tensor<B, 3, Int> {
175 let array_i32 = self.mapv(|x| x as i32);
176 let vec: Vec<i32> = array_i32.into_raw_vec_and_offset().0;
177 let shape = [self.shape()[0], self.shape()[1], self.shape()[2]];
179 Tensor::<B, 1, Int>::from_ints(&vec[..], device).reshape(shape)
180 }
181
182 #[allow(clippy::cast_possible_wrap)]
183 fn into_burn(self, device: &B::Device) -> Tensor<B, 3, Int> {
184 let array_i32 = self.mapv(|x| x as i32);
185 let vec: Vec<i32> = array_i32.into_raw_vec_and_offset().0;
186 let shape = [self.shape()[0], self.shape()[1], self.shape()[2]];
188 Tensor::<B, 1, Int>::from_ints(&vec[..], device).reshape(shape)
189 }
190}
191impl<B: Backend> ToBurn<B, 2, Float> for na::DMatrix<f32> {
194 fn to_burn(&self, device: &B::Device) -> Tensor<B, 2, Float> {
195 let num_rows = self.nrows();
196 let num_cols = self.ncols();
197 let flattened: Vec<f32> = self.transpose().as_slice().to_vec();
198 Tensor::<B, 1, Float>::from_floats(&flattened[..], device).reshape([num_rows, num_cols])
199 }
200
201 fn into_burn(self, device: &B::Device) -> Tensor<B, 2, Float> {
202 let num_rows = self.nrows();
203 let num_cols = self.ncols();
204 let flattened: Vec<f32> = self.transpose().as_slice().to_vec();
205 Tensor::<B, 1, Float>::from_floats(&flattened[..], device).reshape([num_rows, num_cols])
206 }
207}
208
209impl<B: Backend> ToBurn<B, 2, Int> for na::DMatrix<u32> {
212 fn to_burn(&self, device: &B::Device) -> Tensor<B, 2, Int> {
213 let num_rows = self.nrows();
214 let num_cols = self.ncols();
215 let flattened: Vec<i32> = self
216 .transpose()
217 .as_slice()
218 .iter()
219 .map(|&x| i32::try_from(x).expect("Value out of range for i32"))
220 .collect();
221 Tensor::<B, 1, Int>::from_ints(&flattened[..], device).reshape([num_rows, num_cols])
222 }
223
224 fn into_burn(self, device: &B::Device) -> Tensor<B, 2, Int> {
225 let num_rows = self.nrows();
226 let num_cols = self.ncols();
227 let flattened: Vec<i32> = self
228 .transpose()
229 .as_slice()
230 .iter()
231 .map(|&x| i32::try_from(x).expect("Value out of range for i32"))
232 .collect();
233 Tensor::<B, 1, Int>::from_ints(&flattened[..], device).reshape([num_rows, num_cols])
234 }
235}
236
/// Conversion of a Burn tensor on backend `B` into an owned `ndarray` array of the same
/// rank `D` with element type `T`.
pub trait ToNdArray<B: Backend, const D: usize, T> {
    /// Converts a borrowed tensor; the tensor remains usable afterwards.
    fn to_ndarray(&self) -> nd::Array<T, nd::Dim<[usize; D]>>;
    /// Converts the tensor, consuming it.
    fn into_ndarray(self) -> nd::Array<T, nd::Dim<[usize; D]>>;
}
245
246impl<B: Backend> ToNdArray<B, 3, f32> for Tensor<B, 3, Float> {
248 fn to_ndarray(&self) -> nd::Array3<f32> {
249 let tensor_data = tensor_to_data_float(self);
250 let shape = self.dims();
251 nd::Array3::from_shape_vec((shape[0], shape[1], shape[2]), tensor_data).unwrap()
252 }
253
254 fn into_ndarray(self) -> nd::Array3<f32> {
255 let tensor_data = tensor_to_data_float(&self);
256 let shape = self.dims();
257 nd::Array3::from_shape_vec((shape[0], shape[1], shape[2]), tensor_data).unwrap()
258 }
259}
260
261impl<B: Backend> ToNdArray<B, 2, f32> for Tensor<B, 2, Float> {
263 fn to_ndarray(&self) -> nd::Array2<f32> {
264 let tensor_data = tensor_to_data_float(self);
265 let shape = self.dims();
266 nd::Array2::from_shape_vec((shape[0], shape[1]), tensor_data).unwrap()
267 }
268
269 fn into_ndarray(self) -> nd::Array2<f32> {
270 let tensor_data = tensor_to_data_float(&self);
271 let shape = self.dims();
272 nd::Array2::from_shape_vec((shape[0], shape[1]), tensor_data).unwrap()
273 }
274}
275
276impl<B: Backend> ToNdArray<B, 1, f32> for Tensor<B, 1, Float> {
278 fn to_ndarray(&self) -> nd::Array1<f32> {
279 let tensor_data = tensor_to_data_float(self);
280 nd::Array1::from_vec(tensor_data)
281 }
282
283 fn into_ndarray(self) -> nd::Array1<f32> {
284 let tensor_data = tensor_to_data_float(&self);
285 nd::Array1::from_vec(tensor_data)
286 }
287}
288
289#[allow(clippy::cast_sign_loss)]
291impl<B: Backend> ToNdArray<B, 3, u32> for Tensor<B, 3, Int> {
292 fn to_ndarray(&self) -> nd::Array3<u32> {
293 let tensor_data = tensor_to_data_int(self);
294 let tensor_data_u32: Vec<u32> = tensor_data.into_iter().map(|x| x as u32).collect();
295 let shape = self.dims();
296 nd::Array3::from_shape_vec((shape[0], shape[1], shape[2]), tensor_data_u32).unwrap()
297 }
298
299 fn into_ndarray(self) -> nd::Array3<u32> {
300 let tensor_data = tensor_to_data_int(&self);
301 let tensor_data_u32: Vec<u32> = tensor_data.into_iter().map(|x| x as u32).collect();
302 let shape = self.dims();
303 nd::Array3::from_shape_vec((shape[0], shape[1], shape[2]), tensor_data_u32).unwrap()
304 }
305}
306
307#[allow(clippy::cast_sign_loss)]
309impl<B: Backend> ToNdArray<B, 2, u32> for Tensor<B, 2, Int> {
310 fn to_ndarray(&self) -> nd::Array2<u32> {
311 let tensor_data = tensor_to_data_int(self);
312 let tensor_data_u32: Vec<u32> = tensor_data.into_iter().map(|x| x as u32).collect();
313 let shape = self.dims();
314 nd::Array2::from_shape_vec((shape[0], shape[1]), tensor_data_u32).unwrap()
315 }
316
317 fn into_ndarray(self) -> nd::Array2<u32> {
318 let tensor_data = tensor_to_data_int(&self);
319 let tensor_data_u32: Vec<u32> = tensor_data.into_iter().map(|x| x as u32).collect();
320 let shape = self.dims();
321 nd::Array2::from_shape_vec((shape[0], shape[1]), tensor_data_u32).unwrap()
322 }
323}
324
325#[allow(clippy::cast_sign_loss)]
327impl<B: Backend> ToNdArray<B, 1, u32> for Tensor<B, 1, Int> {
328 fn to_ndarray(&self) -> nd::Array1<u32> {
329 let tensor_data = tensor_to_data_int(self);
330 let tensor_data_u32: Vec<u32> = tensor_data.into_iter().map(|x| x as u32).collect();
331 nd::Array1::from_vec(tensor_data_u32)
332 }
333
334 fn into_ndarray(self) -> nd::Array1<u32> {
335 let tensor_data = tensor_to_data_int(&self);
336 let tensor_data_u32: Vec<u32> = tensor_data.into_iter().map(|x| x as u32).collect();
337 nd::Array1::from_vec(tensor_data_u32)
338 }
339}
340
/// Conversion of a Burn float tensor on backend `B` (rank `D`) into a dynamically sized
/// nalgebra `f32` matrix.
pub trait ToNalgebraFloat<B: Backend, const D: usize> {
    /// Converts a borrowed tensor; the tensor remains usable afterwards.
    fn to_nalgebra(&self) -> na::DMatrix<f32>;
    /// Converts the tensor, consuming it.
    fn into_nalgebra(self) -> na::DMatrix<f32>;
}
349
/// Conversion of a Burn integer tensor on backend `B` (rank `D`) into a dynamically sized
/// nalgebra `u32` matrix.
pub trait ToNalgebraInt<B: Backend, const D: usize> {
    /// Converts a borrowed tensor; the tensor remains usable afterwards.
    fn to_nalgebra(&self) -> na::DMatrix<u32>;
    /// Converts the tensor, consuming it.
    fn into_nalgebra(self) -> na::DMatrix<u32>;
}
356
357impl<B: Backend> ToNalgebraFloat<B, 2> for Tensor<B, 2, Float> {
360 fn to_nalgebra(&self) -> na::DMatrix<f32> {
361 let data = tensor_to_data_float(self);
362 let shape = self.shape().dims;
363 na::DMatrix::from_vec(shape[1], shape[0], data).transpose()
364 }
365
366 fn into_nalgebra(self) -> na::DMatrix<f32> {
367 let data = tensor_to_data_float(&self);
368 let shape = self.shape().dims;
369 na::DMatrix::from_vec(shape[1], shape[0], data).transpose()
370 }
371}
372
373impl<B: Backend> ToNalgebraInt<B, 2> for Tensor<B, 2, Int> {
376 #[allow(clippy::cast_sign_loss)]
377 fn to_nalgebra(&self) -> na::DMatrix<u32> {
378 let data = tensor_to_data_int(self);
379 let shape = self.shape().dims;
380 let data_u32: Vec<u32> = data.into_iter().map(|x| x as u32).collect();
381 na::DMatrix::from_vec(shape[1], shape[0], data_u32).transpose()
382 }
383 #[allow(clippy::cast_sign_loss)]
384 fn into_nalgebra(self) -> na::DMatrix<u32> {
385 let data = tensor_to_data_int(&self);
386 let shape = self.shape().dims;
387 let data_u32: Vec<u32> = data.into_iter().map(|x| x as u32).collect();
388 na::DMatrix::from_vec(shape[1], shape[0], data_u32).transpose()
389 }
390}