1use burn::{
2 prelude::Backend,
3 tensor::{Float, Int, Tensor, TensorKind},
4};
5use nalgebra as na;
6use ndarray as nd;
7
8#[cfg(target_arch = "wasm32")]
13pub fn tensor_to_data_float<B: Backend, const D: usize>(tensor: &Tensor<B, D, Float>) -> Vec<f32> {
14 tensor.to_data().to_vec::<f32>().unwrap()
15}
16
17#[cfg(not(target_arch = "wasm32"))]
19pub fn tensor_to_data_float<B: Backend, const D: usize>(tensor: &Tensor<B, D, Float>) -> Vec<f32> {
20 tensor.to_data().to_vec::<f32>().unwrap()
21}
22
23#[cfg(target_arch = "wasm32")]
26#[allow(clippy::cast_possible_truncation)]
27pub fn tensor_to_data_int<B: Backend, const D: usize>(tensor: &Tensor<B, D, Int>) -> Vec<i32> {
28 if let Ok(data) = tensor.to_data().to_vec::<i32>() {
29 return data;
30 }
31
32 let data_i64: Vec<i64> = tensor.to_data().to_vec::<i64>().unwrap();
34 data_i64.into_iter().map(|x| x as i32).collect()
35}
36
37#[cfg(not(target_arch = "wasm32"))]
39#[allow(clippy::cast_possible_truncation)]
40pub fn tensor_to_data_int<B: Backend, const D: usize>(tensor: &Tensor<B, D, Int>) -> Vec<i32> {
41 if let Ok(data) = tensor.to_data().to_vec::<i32>() {
43 return data;
44 }
45
46 let data_i64: Vec<i64> = tensor.to_data().to_vec::<i64>().unwrap();
48 data_i64.into_iter().map(|x| x as i32).collect()
49}
50
/// Conversion of host-side array/matrix types into burn tensors.
///
/// `to_burn` copies out of a borrowed value; `into_burn` consumes the value
/// (implementations may still copy internally before upload).
pub trait ToBurn<B: Backend, const D: usize, T: TensorKind<B>> {
    /// Copy `self` into a new `D`-dimensional tensor on `device`.
    fn to_burn(&self, device: &B::Device) -> Tensor<B, D, T>;
    /// Consume `self` and create a `D`-dimensional tensor on `device`.
    fn into_burn(self, device: &B::Device) -> Tensor<B, D, T>;
}
59
60impl<B: Backend> ToBurn<B, 2, Float> for nd::Array2<f32> {
62 fn to_burn(&self, device: &B::Device) -> Tensor<B, 2, Float> {
63 let vec: Vec<f32>;
64 let bytes = if self.is_standard_layout() {
65 self.as_slice().unwrap()
66 } else {
67 vec = self.iter().copied().collect();
68 vec.as_slice()
69 };
70 let shape = [self.nrows(), self.ncols()];
71 Tensor::<B, 1, Float>::from_floats(bytes, device).reshape(shape)
72 }
73
74 fn into_burn(self, device: &B::Device) -> Tensor<B, 2, Float> {
75 let vec: Vec<f32>;
76 let bytes = if self.is_standard_layout() {
77 self.as_slice().expect("Array should have a slice if it's in standard layout")
78 } else {
79 vec = self.iter().copied().collect();
80 vec.as_slice()
81 };
82 let shape = [self.nrows(), self.ncols()];
83 Tensor::<B, 1, Float>::from_floats(bytes, device).reshape(shape)
84 }
85}
86
87impl<B: Backend> ToBurn<B, 1, Float> for nd::Array1<f32> {
89 fn to_burn(&self, device: &B::Device) -> Tensor<B, 1, Float> {
90 let vec: Vec<f32> = self.iter().copied().collect();
91 Tensor::<B, 1, Float>::from_floats(&vec[..], device)
92 }
93
94 fn into_burn(self, device: &B::Device) -> Tensor<B, 1, Float> {
95 let vec: Vec<f32>;
96 let bytes = if self.is_standard_layout() {
97 self.as_slice().expect("Array should have a slice if it's in standard layout")
98 } else {
99 vec = self.iter().copied().collect();
100 vec.as_slice()
101 };
102 Tensor::<B, 1, Float>::from_floats(bytes, device)
103 }
104}
105
106impl<B: Backend> ToBurn<B, 2, Int> for nd::Array2<u32> {
108 #[allow(clippy::cast_possible_wrap)]
109 fn to_burn(&self, device: &B::Device) -> Tensor<B, 2, Int> {
110 let array_i32 = self.mapv(|x| x as i32);
111 let vec: Vec<i32> = array_i32.into_raw_vec();
112 let shape = [self.nrows(), self.ncols()];
114 Tensor::<B, 1, Int>::from_ints(&vec[..], device).reshape(shape)
115
116 }
134
135 #[allow(clippy::cast_possible_wrap)]
136 fn into_burn(self, device: &B::Device) -> Tensor<B, 2, Int> {
137 let array_i32 = self.mapv(|x| x as i32);
138 let vec: Vec<i32> = array_i32.into_raw_vec();
139 let shape = [self.nrows(), self.ncols()];
141 Tensor::<B, 1, Int>::from_ints(&vec[..], device).reshape(shape)
142 }
143}
144
145impl<B: Backend> ToBurn<B, 1, Int> for nd::Array1<u32> {
147 #[allow(clippy::cast_possible_wrap)]
148 fn to_burn(&self, device: &B::Device) -> Tensor<B, 1, Int> {
149 let array_i32 = self.mapv(|x| x as i32);
150 let vec: Vec<i32> = array_i32.into_raw_vec();
151 Tensor::<B, 1, Int>::from_ints(&vec[..], device)
153 }
154
155 #[allow(clippy::cast_possible_wrap)]
156 fn into_burn(self, device: &B::Device) -> Tensor<B, 1, Int> {
157 let array_i32 = self.mapv(|x| x as i32);
158 let vec: Vec<i32> = array_i32.into_raw_vec();
159 Tensor::<B, 1, Int>::from_ints(&vec[..], device)
161 }
162}
163impl<B: Backend> ToBurn<B, 3, Float> for nd::Array3<f32> {
164 fn to_burn(&self, device: &B::Device) -> Tensor<B, 3, Float> {
165 let vec: Vec<f32>;
166 let bytes = if self.is_standard_layout() {
167 self.as_slice().unwrap()
168 } else {
169 vec = self.iter().copied().collect();
170 vec.as_slice()
171 };
172 let shape = [self.shape()[0], self.shape()[1], self.shape()[2]];
173 Tensor::<B, 1, Float>::from_floats(bytes, device).reshape(shape)
174 }
175
176 fn into_burn(self, device: &B::Device) -> Tensor<B, 3, Float> {
177 let vec: Vec<f32>;
178 let bytes = if self.is_standard_layout() {
179 self.as_slice().expect("Array should have a slice if it's in standard layout")
180 } else {
181 vec = self.iter().copied().collect();
182 vec.as_slice()
183 };
184 let shape = [self.shape()[0], self.shape()[1], self.shape()[2]];
185 Tensor::<B, 1, Float>::from_floats(bytes, device).reshape(shape)
186 }
187}
188
189impl<B: Backend> ToBurn<B, 3, Int> for nd::Array3<u32> {
191 #[allow(clippy::cast_possible_wrap)]
192 fn to_burn(&self, device: &B::Device) -> Tensor<B, 3, Int> {
193 let array_i32 = self.mapv(|x| x as i32);
194 let vec: Vec<i32> = array_i32.into_raw_vec();
195 let shape = [self.shape()[0], self.shape()[1], self.shape()[2]];
197 Tensor::<B, 1, Int>::from_ints(&vec[..], device).reshape(shape)
198 }
199
200 #[allow(clippy::cast_possible_wrap)]
201 fn into_burn(self, device: &B::Device) -> Tensor<B, 3, Int> {
202 let array_i32 = self.mapv(|x| x as i32);
203 let vec: Vec<i32> = array_i32.into_raw_vec();
204 let shape = [self.shape()[0], self.shape()[1], self.shape()[2]];
206 Tensor::<B, 1, Int>::from_ints(&vec[..], device).reshape(shape)
207 }
208}
209impl<B: Backend> ToBurn<B, 2, Float> for na::DMatrix<f32> {
212 fn to_burn(&self, device: &B::Device) -> Tensor<B, 2, Float> {
213 let num_rows = self.nrows();
214 let num_cols = self.ncols();
215 let flattened: Vec<f32> = self.transpose().as_slice().to_vec();
216 Tensor::<B, 1, Float>::from_floats(&flattened[..], device).reshape([num_rows, num_cols])
217 }
218
219 fn into_burn(self, device: &B::Device) -> Tensor<B, 2, Float> {
220 let num_rows = self.nrows();
221 let num_cols = self.ncols();
222 let flattened: Vec<f32> = self.transpose().as_slice().to_vec();
223 Tensor::<B, 1, Float>::from_floats(&flattened[..], device).reshape([num_rows, num_cols])
224 }
225}
226
227impl<B: Backend> ToBurn<B, 2, Int> for na::DMatrix<u32> {
230 fn to_burn(&self, device: &B::Device) -> Tensor<B, 2, Int> {
231 let num_rows = self.nrows();
232 let num_cols = self.ncols();
233 let flattened: Vec<i32> = self
234 .transpose()
235 .as_slice()
236 .iter()
237 .map(|&x| i32::try_from(x).expect("Value out of range for i32"))
238 .collect();
239 Tensor::<B, 1, Int>::from_ints(&flattened[..], device).reshape([num_rows, num_cols])
240 }
241
242 fn into_burn(self, device: &B::Device) -> Tensor<B, 2, Int> {
243 let num_rows = self.nrows();
244 let num_cols = self.ncols();
245 let flattened: Vec<i32> = self
246 .transpose()
247 .as_slice()
248 .iter()
249 .map(|&x| i32::try_from(x).expect("Value out of range for i32"))
250 .collect();
251 Tensor::<B, 1, Int>::from_ints(&flattened[..], device).reshape([num_rows, num_cols])
252 }
253}
254
/// Conversion of burn tensors into `ndarray` arrays with element type `T`.
///
/// `to_ndarray` copies out of a borrowed tensor; `into_ndarray` consumes it.
pub trait ToNdArray<B: Backend, const D: usize, T> {
    /// Copy the tensor's data into a new `D`-dimensional array.
    fn to_ndarray(&self) -> nd::Array<T, nd::Dim<[usize; D]>>;
    /// Consume the tensor and produce a `D`-dimensional array.
    fn into_ndarray(self) -> nd::Array<T, nd::Dim<[usize; D]>>;
}
263
264impl<B: Backend> ToNdArray<B, 3, f32> for Tensor<B, 3, Float> {
266 fn to_ndarray(&self) -> nd::Array3<f32> {
267 let tensor_data = tensor_to_data_float(self);
268 let shape = self.dims();
269 nd::Array3::from_shape_vec((shape[0], shape[1], shape[2]), tensor_data).unwrap()
270 }
271
272 fn into_ndarray(self) -> nd::Array3<f32> {
273 let tensor_data = tensor_to_data_float(&self);
274 let shape = self.dims();
275 nd::Array3::from_shape_vec((shape[0], shape[1], shape[2]), tensor_data).unwrap()
276 }
277}
278
279impl<B: Backend> ToNdArray<B, 2, f32> for Tensor<B, 2, Float> {
281 fn to_ndarray(&self) -> nd::Array2<f32> {
282 let tensor_data = tensor_to_data_float(self);
283 let shape = self.dims();
284 nd::Array2::from_shape_vec((shape[0], shape[1]), tensor_data).unwrap()
285 }
286
287 fn into_ndarray(self) -> nd::Array2<f32> {
288 let tensor_data = tensor_to_data_float(&self);
289 let shape = self.dims();
290 nd::Array2::from_shape_vec((shape[0], shape[1]), tensor_data).unwrap()
291 }
292}
293
294impl<B: Backend> ToNdArray<B, 1, f32> for Tensor<B, 1, Float> {
296 fn to_ndarray(&self) -> nd::Array1<f32> {
297 let tensor_data = tensor_to_data_float(self);
298 nd::Array1::from_vec(tensor_data)
299 }
300
301 fn into_ndarray(self) -> nd::Array1<f32> {
302 let tensor_data = tensor_to_data_float(&self);
303 nd::Array1::from_vec(tensor_data)
304 }
305}
306
307#[allow(clippy::cast_sign_loss)]
309impl<B: Backend> ToNdArray<B, 3, u32> for Tensor<B, 3, Int> {
310 fn to_ndarray(&self) -> nd::Array3<u32> {
311 let tensor_data = tensor_to_data_int(self);
312 let tensor_data_u32: Vec<u32> = tensor_data.into_iter().map(|x| x as u32).collect();
313 let shape = self.dims();
314 nd::Array3::from_shape_vec((shape[0], shape[1], shape[2]), tensor_data_u32).unwrap()
315 }
316
317 fn into_ndarray(self) -> nd::Array3<u32> {
318 let tensor_data = tensor_to_data_int(&self);
319 let tensor_data_u32: Vec<u32> = tensor_data.into_iter().map(|x| x as u32).collect();
320 let shape = self.dims();
321 nd::Array3::from_shape_vec((shape[0], shape[1], shape[2]), tensor_data_u32).unwrap()
322 }
323}
324
325#[allow(clippy::cast_sign_loss)]
327impl<B: Backend> ToNdArray<B, 2, u32> for Tensor<B, 2, Int> {
328 fn to_ndarray(&self) -> nd::Array2<u32> {
329 let tensor_data = tensor_to_data_int(self);
330 let tensor_data_u32: Vec<u32> = tensor_data.into_iter().map(|x| x as u32).collect();
331 let shape = self.dims();
332 nd::Array2::from_shape_vec((shape[0], shape[1]), tensor_data_u32).unwrap()
333 }
334
335 fn into_ndarray(self) -> nd::Array2<u32> {
336 let tensor_data = tensor_to_data_int(&self);
337 let tensor_data_u32: Vec<u32> = tensor_data.into_iter().map(|x| x as u32).collect();
338 let shape = self.dims();
339 nd::Array2::from_shape_vec((shape[0], shape[1]), tensor_data_u32).unwrap()
340 }
341}
342
343#[allow(clippy::cast_sign_loss)]
345impl<B: Backend> ToNdArray<B, 1, u32> for Tensor<B, 1, Int> {
346 fn to_ndarray(&self) -> nd::Array1<u32> {
347 let tensor_data = tensor_to_data_int(self);
348 let tensor_data_u32: Vec<u32> = tensor_data.into_iter().map(|x| x as u32).collect();
349 nd::Array1::from_vec(tensor_data_u32)
350 }
351
352 fn into_ndarray(self) -> nd::Array1<u32> {
353 let tensor_data = tensor_to_data_int(&self);
354 let tensor_data_u32: Vec<u32> = tensor_data.into_iter().map(|x| x as u32).collect();
355 nd::Array1::from_vec(tensor_data_u32)
356 }
357}
358
/// Conversion of a float burn tensor into an `nalgebra` `f32` matrix.
///
/// Split from `ToNalgebraInt` because the two use the same method names with
/// different element types (`f32` vs `u32`).
pub trait ToNalgebraFloat<B: Backend, const D: usize> {
    /// Copy the tensor's data into a new `f32` matrix.
    fn to_nalgebra(&self) -> na::DMatrix<f32>;
    /// Consume the tensor and produce an `f32` matrix.
    fn into_nalgebra(self) -> na::DMatrix<f32>;
}
367
/// Conversion of an int burn tensor into an `nalgebra` `u32` matrix.
///
/// Counterpart of `ToNalgebraFloat` for integer tensors.
pub trait ToNalgebraInt<B: Backend, const D: usize> {
    /// Copy the tensor's data into a new `u32` matrix.
    fn to_nalgebra(&self) -> na::DMatrix<u32>;
    /// Consume the tensor and produce a `u32` matrix.
    fn into_nalgebra(self) -> na::DMatrix<u32>;
}
374
375impl<B: Backend> ToNalgebraFloat<B, 2> for Tensor<B, 2, Float> {
378 fn to_nalgebra(&self) -> na::DMatrix<f32> {
379 let data = tensor_to_data_float(self);
380 let shape = self.shape().dims;
381 na::DMatrix::from_vec(shape[1], shape[0], data).transpose()
382 }
383
384 fn into_nalgebra(self) -> na::DMatrix<f32> {
385 let data = tensor_to_data_float(&self);
386 let shape = self.shape().dims;
387 na::DMatrix::from_vec(shape[1], shape[0], data).transpose()
388 }
389}
390
391impl<B: Backend> ToNalgebraInt<B, 2> for Tensor<B, 2, Int> {
394 #[allow(clippy::cast_sign_loss)]
395 fn to_nalgebra(&self) -> na::DMatrix<u32> {
396 let data = tensor_to_data_int(self);
397 let shape = self.shape().dims;
398 let data_u32: Vec<u32> = data.into_iter().map(|x| x as u32).collect();
399 na::DMatrix::from_vec(shape[1], shape[0], data_u32).transpose()
400 }
401 #[allow(clippy::cast_sign_loss)]
402 fn into_nalgebra(self) -> na::DMatrix<u32> {
403 let data = tensor_to_data_int(&self);
404 let shape = self.shape().dims;
405 let data_u32: Vec<u32> = data.into_iter().map(|x| x as u32).collect();
406 na::DMatrix::from_vec(shape[1], shape[0], data_u32).transpose()
407 }
408}