//! gloss_utils/bshare.rs — conversion utilities between `burn` tensors,
//! `ndarray` arrays, and `nalgebra` matrices.
1use burn::{
2    prelude::Backend,
3    tensor::{Float, Int, Tensor, TensorKind},
4};
5use nalgebra as na;
6use ndarray as nd;
7
8// TODO: Take another look at these conversions with the burn update
9// ================ Tensor to Data Functions ================
10// Handle Float tensors for both 1D and 2D
11/// Convert a burn float tensor to a Vec on wasm
12#[cfg(target_arch = "wasm32")]
13pub fn tensor_to_data_float<B: Backend, const D: usize>(tensor: &Tensor<B, D, Float>) -> Vec<f32> {
14    tensor.to_data().to_vec::<f32>().unwrap()
15}
16
17/// Convert a burn float tensor to a Vec
18#[cfg(not(target_arch = "wasm32"))]
19pub fn tensor_to_data_float<B: Backend, const D: usize>(tensor: &Tensor<B, D, Float>) -> Vec<f32> {
20    tensor.to_data().to_vec::<f32>().unwrap()
21}
22
23// Handle Int tensors for both 1D and 2D
24/// Convert a burn int tensor to a Vec on wasm
25#[cfg(target_arch = "wasm32")]
26#[allow(clippy::cast_possible_truncation)]
27pub fn tensor_to_data_int<B: Backend, const D: usize>(tensor: &Tensor<B, D, Int>) -> Vec<i32> {
28    if let Ok(data) = tensor.to_data().to_vec::<i32>() {
29        return data;
30    }
31
32    // Fallback: Attempt `i64` conversion and downcast to `i32` if `i32` fails
33    let data_i64: Vec<i64> = tensor.to_data().to_vec::<i64>().unwrap();
34    data_i64.into_iter().map(|x| x as i32).collect()
35}
36
37/// Convert a burn int tensor to a Vec
38#[cfg(not(target_arch = "wasm32"))]
39#[allow(clippy::cast_possible_truncation)]
40pub fn tensor_to_data_int<B: Backend, const D: usize>(tensor: &Tensor<B, D, Int>) -> Vec<i32> {
41    // tensor.to_data().to_vec::<i32>().unwrap()
42    if let Ok(data) = tensor.to_data().to_vec::<i32>() {
43        return data;
44    }
45
46    // Fallback: Attempt `i64` conversion and downcast to `i32` if `i32` fails
47    let data_i64: Vec<i64> = tensor.to_data().to_vec::<i64>().unwrap();
48    data_i64.into_iter().map(|x| x as i32).collect()
49}
50
51// ================ To and Into Burn Conversions ================
52
/// Trait for converting `ndarray` arrays and `nalgebra` matrices into burn
/// tensors, generic over element kind (`Float`/`Int`) and dimensionality `D`.
pub trait ToBurn<B: Backend, const D: usize, T: TensorKind<B>> {
    /// Convert by borrowing `self`; the underlying data may be copied.
    fn to_burn(&self, device: &B::Device) -> Tensor<B, D, T>;
    /// Convert by consuming `self`.
    fn into_burn(self, device: &B::Device) -> Tensor<B, D, T>;
}
59
60/// Implementation of the trait for 2D Float ndarray
61impl<B: Backend> ToBurn<B, 2, Float> for nd::Array2<f32> {
62    fn to_burn(&self, device: &B::Device) -> Tensor<B, 2, Float> {
63        let vec: Vec<f32>;
64        let bytes = if self.is_standard_layout() {
65            self.as_slice().unwrap()
66        } else {
67            vec = self.iter().copied().collect();
68            vec.as_slice()
69        };
70        let shape = [self.nrows(), self.ncols()];
71        Tensor::<B, 1, Float>::from_floats(bytes, device).reshape(shape)
72    }
73
74    fn into_burn(self, device: &B::Device) -> Tensor<B, 2, Float> {
75        let vec: Vec<f32>;
76        let bytes = if self.is_standard_layout() {
77            self.as_slice().expect("Array should have a slice if it's in standard layout")
78        } else {
79            vec = self.iter().copied().collect();
80            vec.as_slice()
81        };
82        let shape = [self.nrows(), self.ncols()];
83        Tensor::<B, 1, Float>::from_floats(bytes, device).reshape(shape)
84    }
85}
86
87/// Trait implementation for 1D Float ndarray
88impl<B: Backend> ToBurn<B, 1, Float> for nd::Array1<f32> {
89    fn to_burn(&self, device: &B::Device) -> Tensor<B, 1, Float> {
90        let vec: Vec<f32> = self.iter().copied().collect();
91        Tensor::<B, 1, Float>::from_floats(&vec[..], device)
92    }
93
94    fn into_burn(self, device: &B::Device) -> Tensor<B, 1, Float> {
95        let vec: Vec<f32>;
96        let bytes = if self.is_standard_layout() {
97            self.as_slice().expect("Array should have a slice if it's in standard layout")
98        } else {
99            vec = self.iter().copied().collect();
100            vec.as_slice()
101        };
102        Tensor::<B, 1, Float>::from_floats(bytes, device)
103    }
104}
105
106/// Trait implementation for 2D Int ndarray
107impl<B: Backend> ToBurn<B, 2, Int> for nd::Array2<u32> {
108    #[allow(clippy::cast_possible_wrap)]
109    fn to_burn(&self, device: &B::Device) -> Tensor<B, 2, Int> {
110        let array_i32 = self.mapv(|x| x as i32);
111        let vec: Vec<i32> = array_i32.into_raw_vec();
112        // let vec: Vec<i32> = array_i32.into_raw_vec();
113        let shape = [self.nrows(), self.ncols()];
114        Tensor::<B, 1, Int>::from_ints(&vec[..], device).reshape(shape)
115
116        //for torch the dtype is i64 only so we need to handle that case
117        // since the multibackend has a fixed dtype of i32 for int tensors we cannot use  B::IntElem::dtype() since it would just return i32
118        // we need to check if the device is torch or libtorch and cast to i64 in this case
119        // let backend_name = B::name(device);
120        // if backend_name.contains("torch") {
121        //     let array = self.mapv(i64::from);
122        //     let vec: Vec<i64> = array.into_raw_vec();
123        //     let shape = [self.nrows(), self.ncols()];
124        //     // Tensor::<B, 1, Int>::from_ints(&vec[..], device).reshape(shape) //internally it does a conversion to i32, sad :(
125        //     let data: TensorData = vec.as_slice().into();
126        //     Tensor::<B, 1, Int>::from_data(data, device).reshape(shape)
127        // } else {
128        //     let array = self.mapv(|x| x as i32);
129        //     let vec: Vec<i32> = array.into_raw_vec();
130        //     let shape = [self.nrows(), self.ncols()];
131        //     Tensor::<B, 1, Int>::from_ints(&vec[..], device).reshape(shape)
132        // }
133    }
134
135    #[allow(clippy::cast_possible_wrap)]
136    fn into_burn(self, device: &B::Device) -> Tensor<B, 2, Int> {
137        let array_i32 = self.mapv(|x| x as i32);
138        let vec: Vec<i32> = array_i32.into_raw_vec();
139        // let vec: Vec<i32> = array_i32.into_raw_vec();
140        let shape = [self.nrows(), self.ncols()];
141        Tensor::<B, 1, Int>::from_ints(&vec[..], device).reshape(shape)
142    }
143}
144
145/// Trait implementation for 1D Int ndarray
146impl<B: Backend> ToBurn<B, 1, Int> for nd::Array1<u32> {
147    #[allow(clippy::cast_possible_wrap)]
148    fn to_burn(&self, device: &B::Device) -> Tensor<B, 1, Int> {
149        let array_i32 = self.mapv(|x| x as i32);
150        let vec: Vec<i32> = array_i32.into_raw_vec();
151        // let vec: Vec<i32> = array_i32.into_raw_vec();
152        Tensor::<B, 1, Int>::from_ints(&vec[..], device)
153    }
154
155    #[allow(clippy::cast_possible_wrap)]
156    fn into_burn(self, device: &B::Device) -> Tensor<B, 1, Int> {
157        let array_i32 = self.mapv(|x| x as i32);
158        let vec: Vec<i32> = array_i32.into_raw_vec();
159        // let vec: Vec<i32> = array_i32.into_raw_vec();
160        Tensor::<B, 1, Int>::from_ints(&vec[..], device)
161    }
162}
163impl<B: Backend> ToBurn<B, 3, Float> for nd::Array3<f32> {
164    fn to_burn(&self, device: &B::Device) -> Tensor<B, 3, Float> {
165        let vec: Vec<f32>;
166        let bytes = if self.is_standard_layout() {
167            self.as_slice().unwrap()
168        } else {
169            vec = self.iter().copied().collect();
170            vec.as_slice()
171        };
172        let shape = [self.shape()[0], self.shape()[1], self.shape()[2]];
173        Tensor::<B, 1, Float>::from_floats(bytes, device).reshape(shape)
174    }
175
176    fn into_burn(self, device: &B::Device) -> Tensor<B, 3, Float> {
177        let vec: Vec<f32>;
178        let bytes = if self.is_standard_layout() {
179            self.as_slice().expect("Array should have a slice if it's in standard layout")
180        } else {
181            vec = self.iter().copied().collect();
182            vec.as_slice()
183        };
184        let shape = [self.shape()[0], self.shape()[1], self.shape()[2]];
185        Tensor::<B, 1, Float>::from_floats(bytes, device).reshape(shape)
186    }
187}
188
189/// Trait implementation for 3D Int ndarray
190impl<B: Backend> ToBurn<B, 3, Int> for nd::Array3<u32> {
191    #[allow(clippy::cast_possible_wrap)]
192    fn to_burn(&self, device: &B::Device) -> Tensor<B, 3, Int> {
193        let array_i32 = self.mapv(|x| x as i32);
194        let vec: Vec<i32> = array_i32.into_raw_vec();
195        // let vec: Vec<i32> = array_i32.into_raw_vec();
196        let shape = [self.shape()[0], self.shape()[1], self.shape()[2]];
197        Tensor::<B, 1, Int>::from_ints(&vec[..], device).reshape(shape)
198    }
199
200    #[allow(clippy::cast_possible_wrap)]
201    fn into_burn(self, device: &B::Device) -> Tensor<B, 3, Int> {
202        let array_i32 = self.mapv(|x| x as i32);
203        let vec: Vec<i32> = array_i32.into_raw_vec();
204        // let vec: Vec<i32> = array_i32.into_raw_vec();
205        let shape = [self.shape()[0], self.shape()[1], self.shape()[2]];
206        Tensor::<B, 1, Int>::from_ints(&vec[..], device).reshape(shape)
207    }
208}
209/// Implement `ToBurn` for converting `nalgebra::DMatrix<f32>` to a burn tensor
210/// (Float type)
211impl<B: Backend> ToBurn<B, 2, Float> for na::DMatrix<f32> {
212    fn to_burn(&self, device: &B::Device) -> Tensor<B, 2, Float> {
213        let num_rows = self.nrows();
214        let num_cols = self.ncols();
215        let flattened: Vec<f32> = self.transpose().as_slice().to_vec();
216        Tensor::<B, 1, Float>::from_floats(&flattened[..], device).reshape([num_rows, num_cols])
217    }
218
219    fn into_burn(self, device: &B::Device) -> Tensor<B, 2, Float> {
220        let num_rows = self.nrows();
221        let num_cols = self.ncols();
222        let flattened: Vec<f32> = self.transpose().as_slice().to_vec();
223        Tensor::<B, 1, Float>::from_floats(&flattened[..], device).reshape([num_rows, num_cols])
224    }
225}
226
227/// Implement `ToBurn` for converting `nalgebra::DMatrix<u32>` to a burn tensor
228/// (Int type)
229impl<B: Backend> ToBurn<B, 2, Int> for na::DMatrix<u32> {
230    fn to_burn(&self, device: &B::Device) -> Tensor<B, 2, Int> {
231        let num_rows = self.nrows();
232        let num_cols = self.ncols();
233        let flattened: Vec<i32> = self
234            .transpose()
235            .as_slice()
236            .iter()
237            .map(|&x| i32::try_from(x).expect("Value out of range for i32"))
238            .collect();
239        Tensor::<B, 1, Int>::from_ints(&flattened[..], device).reshape([num_rows, num_cols])
240    }
241
242    fn into_burn(self, device: &B::Device) -> Tensor<B, 2, Int> {
243        let num_rows = self.nrows();
244        let num_cols = self.ncols();
245        let flattened: Vec<i32> = self
246            .transpose()
247            .as_slice()
248            .iter()
249            .map(|&x| i32::try_from(x).expect("Value out of range for i32"))
250            .collect();
251        Tensor::<B, 1, Int>::from_ints(&flattened[..], device).reshape([num_rows, num_cols])
252    }
253}
254
255// ================ To and Into NdArray Conversions ================
256
/// Trait for converting a burn tensor back to an ndarray with element type `T`
/// (generic over Float/Int and dimensionality `D`).
pub trait ToNdArray<B: Backend, const D: usize, T> {
    /// Convert by borrowing the tensor; data is copied out of the backend.
    fn to_ndarray(&self) -> nd::Array<T, nd::Dim<[usize; D]>>;
    /// Convert by consuming the tensor.
    fn into_ndarray(self) -> nd::Array<T, nd::Dim<[usize; D]>>;
}
263
264/// Trait implementation for converting 3D Float burn tensor to ndarray
265impl<B: Backend> ToNdArray<B, 3, f32> for Tensor<B, 3, Float> {
266    fn to_ndarray(&self) -> nd::Array3<f32> {
267        let tensor_data = tensor_to_data_float(self);
268        let shape = self.dims();
269        nd::Array3::from_shape_vec((shape[0], shape[1], shape[2]), tensor_data).unwrap()
270    }
271
272    fn into_ndarray(self) -> nd::Array3<f32> {
273        let tensor_data = tensor_to_data_float(&self);
274        let shape = self.dims();
275        nd::Array3::from_shape_vec((shape[0], shape[1], shape[2]), tensor_data).unwrap()
276    }
277}
278
279/// Trait implementation for converting 2D Float burn tensor to ndarray
280impl<B: Backend> ToNdArray<B, 2, f32> for Tensor<B, 2, Float> {
281    fn to_ndarray(&self) -> nd::Array2<f32> {
282        let tensor_data = tensor_to_data_float(self);
283        let shape = self.dims();
284        nd::Array2::from_shape_vec((shape[0], shape[1]), tensor_data).unwrap()
285    }
286
287    fn into_ndarray(self) -> nd::Array2<f32> {
288        let tensor_data = tensor_to_data_float(&self);
289        let shape = self.dims();
290        nd::Array2::from_shape_vec((shape[0], shape[1]), tensor_data).unwrap()
291    }
292}
293
294/// Trait implementation for converting 1D Float burn tensor to ndarray
295impl<B: Backend> ToNdArray<B, 1, f32> for Tensor<B, 1, Float> {
296    fn to_ndarray(&self) -> nd::Array1<f32> {
297        let tensor_data = tensor_to_data_float(self);
298        nd::Array1::from_vec(tensor_data)
299    }
300
301    fn into_ndarray(self) -> nd::Array1<f32> {
302        let tensor_data = tensor_to_data_float(&self);
303        nd::Array1::from_vec(tensor_data)
304    }
305}
306
307/// Trait implementation for converting 3D Int burn tensor to ndarray
308#[allow(clippy::cast_sign_loss)]
309impl<B: Backend> ToNdArray<B, 3, u32> for Tensor<B, 3, Int> {
310    fn to_ndarray(&self) -> nd::Array3<u32> {
311        let tensor_data = tensor_to_data_int(self);
312        let tensor_data_u32: Vec<u32> = tensor_data.into_iter().map(|x| x as u32).collect();
313        let shape = self.dims();
314        nd::Array3::from_shape_vec((shape[0], shape[1], shape[2]), tensor_data_u32).unwrap()
315    }
316
317    fn into_ndarray(self) -> nd::Array3<u32> {
318        let tensor_data = tensor_to_data_int(&self);
319        let tensor_data_u32: Vec<u32> = tensor_data.into_iter().map(|x| x as u32).collect();
320        let shape = self.dims();
321        nd::Array3::from_shape_vec((shape[0], shape[1], shape[2]), tensor_data_u32).unwrap()
322    }
323}
324
325/// Trait implementation for converting 2D Int burn tensor to ndarray
326#[allow(clippy::cast_sign_loss)]
327impl<B: Backend> ToNdArray<B, 2, u32> for Tensor<B, 2, Int> {
328    fn to_ndarray(&self) -> nd::Array2<u32> {
329        let tensor_data = tensor_to_data_int(self);
330        let tensor_data_u32: Vec<u32> = tensor_data.into_iter().map(|x| x as u32).collect();
331        let shape = self.dims();
332        nd::Array2::from_shape_vec((shape[0], shape[1]), tensor_data_u32).unwrap()
333    }
334
335    fn into_ndarray(self) -> nd::Array2<u32> {
336        let tensor_data = tensor_to_data_int(&self);
337        let tensor_data_u32: Vec<u32> = tensor_data.into_iter().map(|x| x as u32).collect();
338        let shape = self.dims();
339        nd::Array2::from_shape_vec((shape[0], shape[1]), tensor_data_u32).unwrap()
340    }
341}
342
343/// Trait implementation for converting 1D Int burn tensor to ndarray
344#[allow(clippy::cast_sign_loss)]
345impl<B: Backend> ToNdArray<B, 1, u32> for Tensor<B, 1, Int> {
346    fn to_ndarray(&self) -> nd::Array1<u32> {
347        let tensor_data = tensor_to_data_int(self);
348        let tensor_data_u32: Vec<u32> = tensor_data.into_iter().map(|x| x as u32).collect();
349        nd::Array1::from_vec(tensor_data_u32)
350    }
351
352    fn into_ndarray(self) -> nd::Array1<u32> {
353        let tensor_data = tensor_to_data_int(&self);
354        let tensor_data_u32: Vec<u32> = tensor_data.into_iter().map(|x| x as u32).collect();
355        nd::Array1::from_vec(tensor_data_u32)
356    }
357}
358
359// ================ To and Into Nalgebra Conversions ================
360
/// Trait for converting a `burn` float tensor to a `nalgebra::DMatrix<f32>`.
pub trait ToNalgebraFloat<B: Backend, const D: usize> {
    /// Convert by borrowing the tensor; data is copied out of the backend.
    fn to_nalgebra(&self) -> na::DMatrix<f32>;
    /// Convert by consuming the tensor.
    fn into_nalgebra(self) -> na::DMatrix<f32>;
}
367
/// Trait for converting a `burn` int tensor to a `nalgebra::DMatrix<u32>`.
pub trait ToNalgebraInt<B: Backend, const D: usize> {
    /// Convert by borrowing the tensor; data is copied out of the backend.
    fn to_nalgebra(&self) -> na::DMatrix<u32>;
    /// Convert by consuming the tensor.
    fn into_nalgebra(self) -> na::DMatrix<u32>;
}
374
375/// Implement trait to convert `burn` tensor to `nalgebra::DMatrix<f32>` (Float
376/// type)
377impl<B: Backend> ToNalgebraFloat<B, 2> for Tensor<B, 2, Float> {
378    fn to_nalgebra(&self) -> na::DMatrix<f32> {
379        let data = tensor_to_data_float(self);
380        let shape = self.shape().dims;
381        na::DMatrix::from_vec(shape[1], shape[0], data).transpose()
382    }
383
384    fn into_nalgebra(self) -> na::DMatrix<f32> {
385        let data = tensor_to_data_float(&self);
386        let shape = self.shape().dims;
387        na::DMatrix::from_vec(shape[1], shape[0], data).transpose()
388    }
389}
390
391/// Implement trait to convert `burn` tensor to `nalgebra::DMatrix<u32>` (Int
392/// type)
393impl<B: Backend> ToNalgebraInt<B, 2> for Tensor<B, 2, Int> {
394    #[allow(clippy::cast_sign_loss)]
395    fn to_nalgebra(&self) -> na::DMatrix<u32> {
396        let data = tensor_to_data_int(self);
397        let shape = self.shape().dims;
398        let data_u32: Vec<u32> = data.into_iter().map(|x| x as u32).collect();
399        na::DMatrix::from_vec(shape[1], shape[0], data_u32).transpose()
400    }
401    #[allow(clippy::cast_sign_loss)]
402    fn into_nalgebra(self) -> na::DMatrix<u32> {
403        let data = tensor_to_data_int(&self);
404        let shape = self.shape().dims;
405        let data_u32: Vec<u32> = data.into_iter().map(|x| x as u32).collect();
406        na::DMatrix::from_vec(shape[1], shape[0], data_u32).transpose()
407    }
408}