1use burn::{
2 prelude::Backend,
3 tensor::{Float, Int, Tensor, TensorKind},
4};
5use nalgebra as na;
6use ndarray as nd;
7
/// Reads a float tensor back from the device as a flat `Vec<f32>` in
/// row-major order.
///
/// # Panics
/// Panics if the tensor's data cannot be converted to `f32`.
// NOTE(review): the wasm32 and native variants are currently identical —
// confirm whether the cfg split is still needed.
#[cfg(target_arch = "wasm32")]
pub fn tensor_to_data_float<B: Backend, const D: usize>(tensor: &Tensor<B, D, Float>) -> Vec<f32> {
    tensor
        .to_data()
        .to_vec::<f32>()
        .expect("float tensor data should convert to Vec<f32>")
}
17
18#[cfg(not(target_arch = "wasm32"))]
20pub fn tensor_to_data_float<B: Backend, const D: usize>(tensor: &Tensor<B, D, Float>) -> Vec<f32> {
21 tensor.to_data().to_vec::<f32>().unwrap()
22}
23
/// Reads an int tensor back from the device as a flat `Vec<i32>` in
/// row-major order.
///
/// Backends whose native int element type is `i64` take the fallback path,
/// where values are narrowed with a plain `as` cast (values outside the
/// `i32` range wrap).
///
/// # Panics
/// Panics if the tensor data is neither `i32` nor `i64` convertible.
#[cfg(target_arch = "wasm32")]
#[allow(clippy::cast_possible_truncation)]
pub fn tensor_to_data_int<B: Backend, const D: usize>(tensor: &Tensor<B, D, Int>) -> Vec<i32> {
    // Read back from the device once — `to_data` can be expensive, and the
    // original code performed it twice on the fallback path.
    let data = tensor.to_data();
    match data.to_vec::<i32>() {
        Ok(values) => values,
        Err(_) => data
            .to_vec::<i64>()
            .expect("int tensor data should be convertible to i32 or i64")
            .into_iter()
            .map(|x| x as i32)
            .collect(),
    }
}
37
38#[cfg(not(target_arch = "wasm32"))]
40#[allow(clippy::cast_possible_truncation)]
41pub fn tensor_to_data_int<B: Backend, const D: usize>(tensor: &Tensor<B, D, Int>) -> Vec<i32> {
42 if let Ok(data) = tensor.to_data().to_vec::<i32>() {
44 return data;
45 }
46
47 let data_i64: Vec<i64> = tensor.to_data().to_vec::<i64>().unwrap();
49 data_i64.into_iter().map(|x| x as i32).collect()
50}
51
/// Conversion of host-side array types (`ndarray` arrays, `nalgebra`
/// matrices) into burn tensors of rank `D` and kind `T` on a given device.
pub trait ToBurn<B: Backend, const D: usize, T: TensorKind<B>> {
    /// Copies `self` into a new tensor allocated on `device`.
    fn to_burn(&self, device: &B::Device) -> Tensor<B, D, T>;
    /// Consumes `self` and creates a tensor on `device`.
    fn into_burn(self, device: &B::Device) -> Tensor<B, D, T>;
}
60
61impl<B: Backend> ToBurn<B, 2, Float> for nd::Array2<f32> {
63 fn to_burn(&self, device: &B::Device) -> Tensor<B, 2, Float> {
64 let vec: Vec<f32>;
65 let bytes = if self.is_standard_layout() {
66 self.as_slice().unwrap()
67 } else {
68 vec = self.iter().copied().collect();
69 vec.as_slice()
70 };
71 let shape = [self.nrows(), self.ncols()];
72 Tensor::<B, 1, Float>::from_floats(bytes, device).reshape(shape)
73 }
74
75 fn into_burn(self, device: &B::Device) -> Tensor<B, 2, Float> {
76 let vec: Vec<f32>;
77 let bytes = if self.is_standard_layout() {
78 self.as_slice().expect("Array should have a slice if it's in standard layout")
79 } else {
80 vec = self.iter().copied().collect();
81 vec.as_slice()
82 };
83 let shape = [self.nrows(), self.ncols()];
84 Tensor::<B, 1, Float>::from_floats(bytes, device).reshape(shape)
85 }
86}
87
88impl<B: Backend> ToBurn<B, 1, Float> for nd::Array1<f32> {
90 fn to_burn(&self, device: &B::Device) -> Tensor<B, 1, Float> {
91 let vec: Vec<f32> = self.iter().copied().collect();
92 Tensor::<B, 1, Float>::from_floats(&vec[..], device)
93 }
94
95 fn into_burn(self, device: &B::Device) -> Tensor<B, 1, Float> {
96 let vec: Vec<f32>;
97 let bytes = if self.is_standard_layout() {
98 self.as_slice().expect("Array should have a slice if it's in standard layout")
99 } else {
100 vec = self.iter().copied().collect();
101 vec.as_slice()
102 };
103 Tensor::<B, 1, Float>::from_floats(bytes, device)
104 }
105}
106
107impl<B: Backend> ToBurn<B, 2, Int> for nd::Array2<u32> {
109 #[allow(clippy::cast_possible_wrap)]
110 fn to_burn(&self, device: &B::Device) -> Tensor<B, 2, Int> {
111 let array_i32 = self.mapv(|x| x as i32);
112 let vec: Vec<i32> = array_i32.into_raw_vec_and_offset().0;
113 let shape = [self.nrows(), self.ncols()];
115 Tensor::<B, 1, Int>::from_ints(&vec[..], device).reshape(shape)
116 }
117
118 #[allow(clippy::cast_possible_wrap)]
119 fn into_burn(self, device: &B::Device) -> Tensor<B, 2, Int> {
120 let array_i32 = self.mapv(|x| x as i32);
121 let vec: Vec<i32> = array_i32.into_raw_vec_and_offset().0;
122 let shape = [self.nrows(), self.ncols()];
124 Tensor::<B, 1, Int>::from_ints(&vec[..], device).reshape(shape)
125 }
126}
127
128impl<B: Backend> ToBurn<B, 1, Int> for nd::Array1<u32> {
130 #[allow(clippy::cast_possible_wrap)]
131 fn to_burn(&self, device: &B::Device) -> Tensor<B, 1, Int> {
132 let array_i32 = self.mapv(|x| x as i32);
133 let vec: Vec<i32> = array_i32.into_raw_vec_and_offset().0;
134 Tensor::<B, 1, Int>::from_ints(&vec[..], device)
136 }
137
138 #[allow(clippy::cast_possible_wrap)]
139 fn into_burn(self, device: &B::Device) -> Tensor<B, 1, Int> {
140 let array_i32 = self.mapv(|x| x as i32);
141 let vec: Vec<i32> = array_i32.into_raw_vec_and_offset().0;
142 Tensor::<B, 1, Int>::from_ints(&vec[..], device)
144 }
145}
146impl<B: Backend> ToBurn<B, 3, Float> for nd::Array3<f32> {
147 fn to_burn(&self, device: &B::Device) -> Tensor<B, 3, Float> {
148 let vec: Vec<f32>;
149 let bytes = if self.is_standard_layout() {
150 self.as_slice().unwrap()
151 } else {
152 vec = self.iter().copied().collect();
153 vec.as_slice()
154 };
155 let shape = [self.shape()[0], self.shape()[1], self.shape()[2]];
156 Tensor::<B, 1, Float>::from_floats(bytes, device).reshape(shape)
157 }
158
159 fn into_burn(self, device: &B::Device) -> Tensor<B, 3, Float> {
160 let vec: Vec<f32>;
161 let bytes = if self.is_standard_layout() {
162 self.as_slice().expect("Array should have a slice if it's in standard layout")
163 } else {
164 vec = self.iter().copied().collect();
165 vec.as_slice()
166 };
167 let shape = [self.shape()[0], self.shape()[1], self.shape()[2]];
168 Tensor::<B, 1, Float>::from_floats(bytes, device).reshape(shape)
169 }
170}
171
172impl<B: Backend> ToBurn<B, 3, Int> for nd::Array3<u32> {
174 #[allow(clippy::cast_possible_wrap)]
175 fn to_burn(&self, device: &B::Device) -> Tensor<B, 3, Int> {
176 let array_i32 = self.mapv(|x| x as i32);
177 let vec: Vec<i32> = array_i32.into_raw_vec_and_offset().0;
178 let shape = [self.shape()[0], self.shape()[1], self.shape()[2]];
180 Tensor::<B, 1, Int>::from_ints(&vec[..], device).reshape(shape)
181 }
182
183 #[allow(clippy::cast_possible_wrap)]
184 fn into_burn(self, device: &B::Device) -> Tensor<B, 3, Int> {
185 let array_i32 = self.mapv(|x| x as i32);
186 let vec: Vec<i32> = array_i32.into_raw_vec_and_offset().0;
187 let shape = [self.shape()[0], self.shape()[1], self.shape()[2]];
189 Tensor::<B, 1, Int>::from_ints(&vec[..], device).reshape(shape)
190 }
191}
192impl<B: Backend> ToBurn<B, 2, Float> for na::DMatrix<f32> {
195 fn to_burn(&self, device: &B::Device) -> Tensor<B, 2, Float> {
196 let num_rows = self.nrows();
197 let num_cols = self.ncols();
198 let flattened: Vec<f32> = self.transpose().as_slice().to_vec();
199 Tensor::<B, 1, Float>::from_floats(&flattened[..], device).reshape([num_rows, num_cols])
200 }
201
202 fn into_burn(self, device: &B::Device) -> Tensor<B, 2, Float> {
203 let num_rows = self.nrows();
204 let num_cols = self.ncols();
205 let flattened: Vec<f32> = self.transpose().as_slice().to_vec();
206 Tensor::<B, 1, Float>::from_floats(&flattened[..], device).reshape([num_rows, num_cols])
207 }
208}
209
210impl<B: Backend> ToBurn<B, 2, Int> for na::DMatrix<u32> {
213 fn to_burn(&self, device: &B::Device) -> Tensor<B, 2, Int> {
214 let num_rows = self.nrows();
215 let num_cols = self.ncols();
216 let flattened: Vec<i32> = self
217 .transpose()
218 .as_slice()
219 .iter()
220 .map(|&x| i32::try_from(x).expect("Value out of range for i32"))
221 .collect();
222 Tensor::<B, 1, Int>::from_ints(&flattened[..], device).reshape([num_rows, num_cols])
223 }
224
225 fn into_burn(self, device: &B::Device) -> Tensor<B, 2, Int> {
226 let num_rows = self.nrows();
227 let num_cols = self.ncols();
228 let flattened: Vec<i32> = self
229 .transpose()
230 .as_slice()
231 .iter()
232 .map(|&x| i32::try_from(x).expect("Value out of range for i32"))
233 .collect();
234 Tensor::<B, 1, Int>::from_ints(&flattened[..], device).reshape([num_rows, num_cols])
235 }
236}
237
/// Conversion of burn tensors back into `ndarray` arrays.
///
/// `B` ties an implementation to a backend; `D` is the shared rank of the
/// tensor and the resulting array, and `T` is the array element type.
pub trait ToNdArray<B: Backend, const D: usize, T> {
    /// Copies the tensor's data into a new `D`-dimensional array.
    fn to_ndarray(&self) -> nd::Array<T, nd::Dim<[usize; D]>>;
    /// Consumes the tensor and produces a `D`-dimensional array.
    fn into_ndarray(self) -> nd::Array<T, nd::Dim<[usize; D]>>;
}
246
247impl<B: Backend> ToNdArray<B, 3, f32> for Tensor<B, 3, Float> {
249 fn to_ndarray(&self) -> nd::Array3<f32> {
250 let tensor_data = tensor_to_data_float(self);
251 let shape = self.dims();
252 nd::Array3::from_shape_vec((shape[0], shape[1], shape[2]), tensor_data).unwrap()
253 }
254
255 fn into_ndarray(self) -> nd::Array3<f32> {
256 let tensor_data = tensor_to_data_float(&self);
257 let shape = self.dims();
258 nd::Array3::from_shape_vec((shape[0], shape[1], shape[2]), tensor_data).unwrap()
259 }
260}
261
262impl<B: Backend> ToNdArray<B, 2, f32> for Tensor<B, 2, Float> {
264 fn to_ndarray(&self) -> nd::Array2<f32> {
265 let tensor_data = tensor_to_data_float(self);
266 let shape = self.dims();
267 nd::Array2::from_shape_vec((shape[0], shape[1]), tensor_data).unwrap()
268 }
269
270 fn into_ndarray(self) -> nd::Array2<f32> {
271 let tensor_data = tensor_to_data_float(&self);
272 let shape = self.dims();
273 nd::Array2::from_shape_vec((shape[0], shape[1]), tensor_data).unwrap()
274 }
275}
276
277impl<B: Backend> ToNdArray<B, 1, f32> for Tensor<B, 1, Float> {
279 fn to_ndarray(&self) -> nd::Array1<f32> {
280 let tensor_data = tensor_to_data_float(self);
281 nd::Array1::from_vec(tensor_data)
282 }
283
284 fn into_ndarray(self) -> nd::Array1<f32> {
285 let tensor_data = tensor_to_data_float(&self);
286 nd::Array1::from_vec(tensor_data)
287 }
288}
289
290#[allow(clippy::cast_sign_loss)]
292impl<B: Backend> ToNdArray<B, 3, u32> for Tensor<B, 3, Int> {
293 fn to_ndarray(&self) -> nd::Array3<u32> {
294 let tensor_data = tensor_to_data_int(self);
295 let tensor_data_u32: Vec<u32> = tensor_data.into_iter().map(|x| x as u32).collect();
296 let shape = self.dims();
297 nd::Array3::from_shape_vec((shape[0], shape[1], shape[2]), tensor_data_u32).unwrap()
298 }
299
300 fn into_ndarray(self) -> nd::Array3<u32> {
301 let tensor_data = tensor_to_data_int(&self);
302 let tensor_data_u32: Vec<u32> = tensor_data.into_iter().map(|x| x as u32).collect();
303 let shape = self.dims();
304 nd::Array3::from_shape_vec((shape[0], shape[1], shape[2]), tensor_data_u32).unwrap()
305 }
306}
307
308#[allow(clippy::cast_sign_loss)]
310impl<B: Backend> ToNdArray<B, 2, u32> for Tensor<B, 2, Int> {
311 fn to_ndarray(&self) -> nd::Array2<u32> {
312 let tensor_data = tensor_to_data_int(self);
313 let tensor_data_u32: Vec<u32> = tensor_data.into_iter().map(|x| x as u32).collect();
314 let shape = self.dims();
315 nd::Array2::from_shape_vec((shape[0], shape[1]), tensor_data_u32).unwrap()
316 }
317
318 fn into_ndarray(self) -> nd::Array2<u32> {
319 let tensor_data = tensor_to_data_int(&self);
320 let tensor_data_u32: Vec<u32> = tensor_data.into_iter().map(|x| x as u32).collect();
321 let shape = self.dims();
322 nd::Array2::from_shape_vec((shape[0], shape[1]), tensor_data_u32).unwrap()
323 }
324}
325
326#[allow(clippy::cast_sign_loss)]
328impl<B: Backend> ToNdArray<B, 1, u32> for Tensor<B, 1, Int> {
329 fn to_ndarray(&self) -> nd::Array1<u32> {
330 let tensor_data = tensor_to_data_int(self);
331 let tensor_data_u32: Vec<u32> = tensor_data.into_iter().map(|x| x as u32).collect();
332 nd::Array1::from_vec(tensor_data_u32)
333 }
334
335 fn into_ndarray(self) -> nd::Array1<u32> {
336 let tensor_data = tensor_to_data_int(&self);
337 let tensor_data_u32: Vec<u32> = tensor_data.into_iter().map(|x| x as u32).collect();
338 nd::Array1::from_vec(tensor_data_u32)
339 }
340}
341
/// Conversion of float tensors into nalgebra `DMatrix<f32>` matrices.
pub trait ToNalgebraFloat<B: Backend, const D: usize> {
    /// Copies the tensor's data into a new matrix.
    fn to_nalgebra(&self) -> na::DMatrix<f32>;
    /// Consumes the tensor and produces a matrix.
    fn into_nalgebra(self) -> na::DMatrix<f32>;
}
350
/// Conversion of int tensors into nalgebra `DMatrix<u32>` matrices.
pub trait ToNalgebraInt<B: Backend, const D: usize> {
    /// Copies the tensor's data into a new matrix.
    fn to_nalgebra(&self) -> na::DMatrix<u32>;
    /// Consumes the tensor and produces a matrix.
    fn into_nalgebra(self) -> na::DMatrix<u32>;
}
357
358impl<B: Backend> ToNalgebraFloat<B, 2> for Tensor<B, 2, Float> {
361 fn to_nalgebra(&self) -> na::DMatrix<f32> {
362 let data = tensor_to_data_float(self);
363 let shape = self.shape().dims;
364 na::DMatrix::from_vec(shape[1], shape[0], data).transpose()
365 }
366
367 fn into_nalgebra(self) -> na::DMatrix<f32> {
368 let data = tensor_to_data_float(&self);
369 let shape = self.shape().dims;
370 na::DMatrix::from_vec(shape[1], shape[0], data).transpose()
371 }
372}
373
374impl<B: Backend> ToNalgebraInt<B, 2> for Tensor<B, 2, Int> {
377 #[allow(clippy::cast_sign_loss)]
378 fn to_nalgebra(&self) -> na::DMatrix<u32> {
379 let data = tensor_to_data_int(self);
380 let shape = self.shape().dims;
381 let data_u32: Vec<u32> = data.into_iter().map(|x| x as u32).collect();
382 na::DMatrix::from_vec(shape[1], shape[0], data_u32).transpose()
383 }
384 #[allow(clippy::cast_sign_loss)]
385 fn into_nalgebra(self) -> na::DMatrix<u32> {
386 let data = tensor_to_data_int(&self);
387 let shape = self.shape().dims;
388 let data_u32: Vec<u32> = data.into_iter().map(|x| x as u32).collect();
389 na::DMatrix::from_vec(shape[1], shape[0], data_u32).transpose()
390 }
391}