use burn::{
    prelude::Backend,
    tensor::{Float, Int, Tensor, TensorKind},
};
use nalgebra as na;
use ndarray as nd;

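/// Copies a float tensor's data back to the host as a flat `Vec<f32>` in row-major order.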
pub fn tensor_to_data_float<B: Backend, const D: usize>(tensor: &Tensor<B, D, Float>) -> Vec<f32> {
    tensor.to_data().to_vec::<f32>().unwrap()
}

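/// Copies an int tensor's data back to the host as a flat `Vec<i32>`, falling back
/// to an `i64` read (with truncating casts) for backends whose int element type is `i64`.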
#[allow(clippy::cast_possible_truncation)]
pub fn tensor_to_data_int<B: Backend, const D: usize>(tensor: &Tensor<B, D, Int>) -> Vec<i32> {
    if let Ok(data) = tensor.to_data().to_vec::<i32>() {
        return data;
    }

    // Backends that use i64 as their int element type land here; narrow to i32.
    let data_i64: Vec<i64> = tensor.to_data().to_vec::<i64>().unwrap();
    data_i64.into_iter().map(|x| x as i32).collect()
}

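/// Conversion from host-side array and matrix types into burn tensors.
///
/// `to_burn` copies the data and leaves `self` intact; `into_burn` consumes
/// `self`. Both place the result on the given device.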
pub trait ToBurn<B: Backend, const D: usize, T: TensorKind<B>> {
    fn to_burn(&self, device: &B::Device) -> Tensor<B, D, T>;
    fn into_burn(self, device: &B::Device) -> Tensor<B, D, T>;
}

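/// Row-major copy of a 2-D `f32` ndarray. Standard-layout arrays are read
/// directly as a slice; other layouts are gathered into a temporary `Vec` first.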
impl<B: Backend> ToBurn<B, 2, Float> for nd::Array2<f32> {
    fn to_burn(&self, device: &B::Device) -> Tensor<B, 2, Float> {
        let vec: Vec<f32>;
        let data = if self.is_standard_layout() {
            self.as_slice().expect("Array should have a slice if it's in standard layout")
        } else {
            vec = self.iter().copied().collect();
            vec.as_slice()
        };
        let shape = [self.nrows(), self.ncols()];
        Tensor::<B, 1, Float>::from_floats(data, device).reshape(shape)
    }

    fn into_burn(self, device: &B::Device) -> Tensor<B, 2, Float> {
        let vec: Vec<f32>;
        let data = if self.is_standard_layout() {
            self.as_slice().expect("Array should have a slice if it's in standard layout")
        } else {
            vec = self.iter().copied().collect();
            vec.as_slice()
        };
        let shape = [self.nrows(), self.ncols()];
        Tensor::<B, 1, Float>::from_floats(data, device).reshape(shape)
    }
}

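/// 1-D `f32` ndarray conversion, using the same standard-layout fast path.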
impl<B: Backend> ToBurn<B, 1, Float> for nd::Array1<f32> {
    fn to_burn(&self, device: &B::Device) -> Tensor<B, 1, Float> {
        let vec: Vec<f32>;
        let data = if self.is_standard_layout() {
            self.as_slice().expect("Array should have a slice if it's in standard layout")
        } else {
            vec = self.iter().copied().collect();
            vec.as_slice()
        };
        Tensor::<B, 1, Float>::from_floats(data, device)
    }

    fn into_burn(self, device: &B::Device) -> Tensor<B, 1, Float> {
        let vec: Vec<f32>;
        let data = if self.is_standard_layout() {
            self.as_slice().expect("Array should have a slice if it's in standard layout")
        } else {
            vec = self.iter().copied().collect();
            vec.as_slice()
        };
        Tensor::<B, 1, Float>::from_floats(data, device)
    }
}

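/// 2-D `u32` ndarray conversion; values are cast to `i32` for burn's int tensors,
/// so values above `i32::MAX` wrap.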
impl<B: Backend> ToBurn<B, 2, Int> for nd::Array2<u32> {
    #[allow(clippy::cast_possible_wrap)]
    fn to_burn(&self, device: &B::Device) -> Tensor<B, 2, Int> {
        let array_i32 = self.mapv(|x| x as i32);
        let vec: Vec<i32> = array_i32.into_raw_vec_and_offset().0;
        let shape = [self.nrows(), self.ncols()];
        Tensor::<B, 1, Int>::from_ints(&vec[..], device).reshape(shape)
    }

    #[allow(clippy::cast_possible_wrap)]
    fn into_burn(self, device: &B::Device) -> Tensor<B, 2, Int> {
        let array_i32 = self.mapv(|x| x as i32);
        let vec: Vec<i32> = array_i32.into_raw_vec_and_offset().0;
        let shape = [self.nrows(), self.ncols()];
        Tensor::<B, 1, Int>::from_ints(&vec[..], device).reshape(shape)
    }
}

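/// 1-D `u32` ndarray conversion; values are cast to `i32` for burn's int tensors.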
impl<B: Backend> ToBurn<B, 1, Int> for nd::Array1<u32> {
    #[allow(clippy::cast_possible_wrap)]
    fn to_burn(&self, device: &B::Device) -> Tensor<B, 1, Int> {
        let array_i32 = self.mapv(|x| x as i32);
        let vec: Vec<i32> = array_i32.into_raw_vec_and_offset().0;
        Tensor::<B, 1, Int>::from_ints(&vec[..], device)
    }

    #[allow(clippy::cast_possible_wrap)]
    fn into_burn(self, device: &B::Device) -> Tensor<B, 1, Int> {
        let array_i32 = self.mapv(|x| x as i32);
        let vec: Vec<i32> = array_i32.into_raw_vec_and_offset().0;
        Tensor::<B, 1, Int>::from_ints(&vec[..], device)
    }
}
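
/// Row-major copy of a 3-D `f32` ndarray into a rank-3 tensor.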
impl<B: Backend> ToBurn<B, 3, Float> for nd::Array3<f32> {
    fn to_burn(&self, device: &B::Device) -> Tensor<B, 3, Float> {
        let vec: Vec<f32>;
        let data = if self.is_standard_layout() {
            self.as_slice().expect("Array should have a slice if it's in standard layout")
        } else {
            vec = self.iter().copied().collect();
            vec.as_slice()
        };
        let shape = [self.shape()[0], self.shape()[1], self.shape()[2]];
        Tensor::<B, 1, Float>::from_floats(data, device).reshape(shape)
    }

    fn into_burn(self, device: &B::Device) -> Tensor<B, 3, Float> {
        let vec: Vec<f32>;
        let data = if self.is_standard_layout() {
            self.as_slice().expect("Array should have a slice if it's in standard layout")
        } else {
            vec = self.iter().copied().collect();
            vec.as_slice()
        };
        let shape = [self.shape()[0], self.shape()[1], self.shape()[2]];
        Tensor::<B, 1, Float>::from_floats(data, device).reshape(shape)
    }
}

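/// 3-D `u32` ndarray conversion; values are cast to `i32` for burn's int tensors.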
impl<B: Backend> ToBurn<B, 3, Int> for nd::Array3<u32> {
    #[allow(clippy::cast_possible_wrap)]
    fn to_burn(&self, device: &B::Device) -> Tensor<B, 3, Int> {
        let array_i32 = self.mapv(|x| x as i32);
        let vec: Vec<i32> = array_i32.into_raw_vec_and_offset().0;
        let shape = [self.shape()[0], self.shape()[1], self.shape()[2]];
        Tensor::<B, 1, Int>::from_ints(&vec[..], device).reshape(shape)
    }

    #[allow(clippy::cast_possible_wrap)]
    fn into_burn(self, device: &B::Device) -> Tensor<B, 3, Int> {
        let array_i32 = self.mapv(|x| x as i32);
        let vec: Vec<i32> = array_i32.into_raw_vec_and_offset().0;
        let shape = [self.shape()[0], self.shape()[1], self.shape()[2]];
        Tensor::<B, 1, Int>::from_ints(&vec[..], device).reshape(shape)
    }
}
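
/// nalgebra matrices are stored column-major, so the matrix is transposed first;
/// reading the transpose's backing slice then yields the row-major order burn expects.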
impl<B: Backend> ToBurn<B, 2, Float> for na::DMatrix<f32> {
    fn to_burn(&self, device: &B::Device) -> Tensor<B, 2, Float> {
        let num_rows = self.nrows();
        let num_cols = self.ncols();
        let flattened: Vec<f32> = self.transpose().as_slice().to_vec();
        Tensor::<B, 1, Float>::from_floats(&flattened[..], device).reshape([num_rows, num_cols])
    }

    fn into_burn(self, device: &B::Device) -> Tensor<B, 2, Float> {
        let num_rows = self.nrows();
        let num_cols = self.ncols();
        let flattened: Vec<f32> = self.transpose().as_slice().to_vec();
        Tensor::<B, 1, Float>::from_floats(&flattened[..], device).reshape([num_rows, num_cols])
    }
}

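/// Column-major `u32` matrix conversion; values that don't fit in `i32` cause a panic.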
impl<B: Backend> ToBurn<B, 2, Int> for na::DMatrix<u32> {
    fn to_burn(&self, device: &B::Device) -> Tensor<B, 2, Int> {
        let num_rows = self.nrows();
        let num_cols = self.ncols();
        let flattened: Vec<i32> = self
            .transpose()
            .as_slice()
            .iter()
            .map(|&x| i32::try_from(x).expect("Value out of range for i32"))
            .collect();
        Tensor::<B, 1, Int>::from_ints(&flattened[..], device).reshape([num_rows, num_cols])
    }

    fn into_burn(self, device: &B::Device) -> Tensor<B, 2, Int> {
        let num_rows = self.nrows();
        let num_cols = self.ncols();
        let flattened: Vec<i32> = self
            .transpose()
            .as_slice()
            .iter()
            .map(|&x| i32::try_from(x).expect("Value out of range for i32"))
            .collect();
        Tensor::<B, 1, Int>::from_ints(&flattened[..], device).reshape([num_rows, num_cols])
    }
}

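/// Conversion from burn tensors back into ndarray arrays.
///
/// `to_ndarray` copies the data and leaves the tensor intact; `into_ndarray`
/// consumes the tensor.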
pub trait ToNdArray<B: Backend, const D: usize, T> {
    fn to_ndarray(&self) -> nd::Array<T, nd::Dim<[usize; D]>>;
    fn into_ndarray(self) -> nd::Array<T, nd::Dim<[usize; D]>>;
}

impl<B: Backend> ToNdArray<B, 2, f32> for Tensor<B, 2, Float> {
    fn to_ndarray(&self) -> nd::Array2<f32> {
        let tensor_data = tensor_to_data_float(self);
        let shape = self.dims();
        nd::Array2::from_shape_vec((shape[0], shape[1]), tensor_data).unwrap()
    }

    fn into_ndarray(self) -> nd::Array2<f32> {
        let tensor_data = tensor_to_data_float(&self);
        let shape = self.dims();
        nd::Array2::from_shape_vec((shape[0], shape[1]), tensor_data).unwrap()
    }
}

impl<B: Backend> ToNdArray<B, 1, f32> for Tensor<B, 1, Float> {
    fn to_ndarray(&self) -> nd::Array1<f32> {
        let tensor_data = tensor_to_data_float(self);
        nd::Array1::from_vec(tensor_data)
    }

    fn into_ndarray(self) -> nd::Array1<f32> {
        let tensor_data = tensor_to_data_float(&self);
        nd::Array1::from_vec(tensor_data)
    }
}

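/// Int tensors are read back as `i32` and reinterpreted as `u32`; negative values wrap.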
#[allow(clippy::cast_sign_loss)]
impl<B: Backend> ToNdArray<B, 2, u32> for Tensor<B, 2, Int> {
    fn to_ndarray(&self) -> nd::Array2<u32> {
        let tensor_data = tensor_to_data_int(self);
        let tensor_data_u32: Vec<u32> = tensor_data.into_iter().map(|x| x as u32).collect();
        let shape = self.dims();
        nd::Array2::from_shape_vec((shape[0], shape[1]), tensor_data_u32).unwrap()
    }

    fn into_ndarray(self) -> nd::Array2<u32> {
        let tensor_data = tensor_to_data_int(&self);
        let tensor_data_u32: Vec<u32> = tensor_data.into_iter().map(|x| x as u32).collect();
        let shape = self.dims();
        nd::Array2::from_shape_vec((shape[0], shape[1]), tensor_data_u32).unwrap()
    }
}

#[allow(clippy::cast_sign_loss)]
impl<B: Backend> ToNdArray<B, 1, u32> for Tensor<B, 1, Int> {
    fn to_ndarray(&self) -> nd::Array1<u32> {
        let tensor_data = tensor_to_data_int(self);
        let tensor_data_u32: Vec<u32> = tensor_data.into_iter().map(|x| x as u32).collect();
        nd::Array1::from_vec(tensor_data_u32)
    }

    fn into_ndarray(self) -> nd::Array1<u32> {
        let tensor_data = tensor_to_data_int(&self);
        let tensor_data_u32: Vec<u32> = tensor_data.into_iter().map(|x| x as u32).collect();
        nd::Array1::from_vec(tensor_data_u32)
    }
}

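/// Conversion from rank-2 float tensors into nalgebra matrices.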
pub trait ToNalgebraFloat<B: Backend, const D: usize> {
    fn to_nalgebra(&self) -> na::DMatrix<f32>;
    fn into_nalgebra(self) -> na::DMatrix<f32>;
}

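/// Conversion from rank-2 int tensors into nalgebra matrices of `u32`.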
pub trait ToNalgebraInt<B: Backend, const D: usize> {
    fn to_nalgebra(&self) -> na::DMatrix<u32>;
    fn into_nalgebra(self) -> na::DMatrix<u32>;
}

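/// The tensor's row-major data is loaded as the transposed (column-major) matrix
/// and transposed back, matching nalgebra's storage order.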
impl<B: Backend> ToNalgebraFloat<B, 2> for Tensor<B, 2, Float> {
    fn to_nalgebra(&self) -> na::DMatrix<f32> {
        let data = tensor_to_data_float(self);
        let shape = self.shape().dims;
        na::DMatrix::from_vec(shape[1], shape[0], data).transpose()
    }

    fn into_nalgebra(self) -> na::DMatrix<f32> {
        let data = tensor_to_data_float(&self);
        let shape = self.shape().dims;
        na::DMatrix::from_vec(shape[1], shape[0], data).transpose()
    }
}

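/// Same transpose round-trip as the float case; values are cast `i32` to `u32`,
/// so negative values wrap.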
impl<B: Backend> ToNalgebraInt<B, 2> for Tensor<B, 2, Int> {
    #[allow(clippy::cast_sign_loss)]
    fn to_nalgebra(&self) -> na::DMatrix<u32> {
        let data = tensor_to_data_int(self);
        let shape = self.shape().dims;
        let data_u32: Vec<u32> = data.into_iter().map(|x| x as u32).collect();
        na::DMatrix::from_vec(shape[1], shape[0], data_u32).transpose()
    }

    #[allow(clippy::cast_sign_loss)]
    fn into_nalgebra(self) -> na::DMatrix<u32> {
        let data = tensor_to_data_int(&self);
        let shape = self.shape().dims;
        let data_u32: Vec<u32> = data.into_iter().map(|x| x as u32).collect();
        na::DMatrix::from_vec(shape[1], shape[0], data_u32).transpose()
    }
}
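
// A minimal round-trip sketch, not part of the original file. It assumes the
// `burn-ndarray` backend is enabled and exposed as `burn::backend::NdArray`;
// swap in whichever backend this crate actually builds with.
#[cfg(test)]
mod tests {
    use super::*;

    type B = burn::backend::NdArray;

    #[test]
    fn ndarray_round_trip() {
        let device = Default::default();
        let arr = nd::arr2(&[[1.0_f32, 2.0], [3.0, 4.0]]);
        // Host -> tensor -> host should reproduce the exact same array.
        let tensor: Tensor<B, 2, Float> = arr.to_burn(&device);
        assert_eq!(tensor.into_ndarray(), arr);
    }

    #[test]
    fn nalgebra_round_trip() {
        let device = Default::default();
        let mat = na::DMatrix::from_row_slice(2, 3, &[1.0_f32, 2.0, 3.0, 4.0, 5.0, 6.0]);
        // The column-major <-> row-major transposes should cancel out.
        let tensor: Tensor<B, 2, Float> = mat.to_burn(&device);
        assert_eq!(tensor.into_nalgebra(), mat);
    }
}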