// gloss_utils/bshare.rs

1use burn::{
2    prelude::Backend,
3    tensor::{Float, Int, Tensor, TensorKind},
4};
5use nalgebra as na;
6use ndarray as nd;
7
8// TODO: Take another look at these conversions with the burn update
9// ================ Tensor to Data Functions ================
10// Handle Float tensors for both 1D and 2D
11/// Convert a burn float tensor to a Vec on wasm
12#[cfg(target_arch = "wasm32")]
13pub fn tensor_to_data_float<B: Backend, const D: usize>(tensor: &Tensor<B, D, Float>) -> Vec<f32> {
14    tensor.to_data().to_vec::<f32>().unwrap()
15}
16
17/// Convert a burn float tensor to a Vec
18#[cfg(not(target_arch = "wasm32"))]
19pub fn tensor_to_data_float<B: Backend, const D: usize>(tensor: &Tensor<B, D, Float>) -> Vec<f32> {
20    tensor.to_data().to_vec::<f32>().unwrap()
21}
22
23// Handle Int tensors for both 1D and 2D
24/// Convert a burn int tensor to a Vec on wasm
25#[cfg(target_arch = "wasm32")]
26#[allow(clippy::cast_possible_truncation)]
27pub fn tensor_to_data_int<B: Backend, const D: usize>(tensor: &Tensor<B, D, Int>) -> Vec<i32> {
28    if let Ok(data) = tensor.to_data().to_vec::<i32>() {
29        return data;
30    }
31
32    // Fallback: Attempt `i64` conversion and downcast to `i32` if `i32` fails
33    let data_i64: Vec<i64> = tensor.to_data().to_vec::<i64>().unwrap();
34    data_i64.into_iter().map(|x| x as i32).collect()
35}
36
37/// Convert a burn int tensor to a Vec
38#[cfg(not(target_arch = "wasm32"))]
39#[allow(clippy::cast_possible_truncation)]
40pub fn tensor_to_data_int<B: Backend, const D: usize>(tensor: &Tensor<B, D, Int>) -> Vec<i32> {
41    // tensor.to_data().to_vec::<i32>().unwrap()
42    if let Ok(data) = tensor.to_data().to_vec::<i32>() {
43        return data;
44    }
45
46    // Fallback: Attempt `i64` conversion and downcast to `i32` if `i32` fails
47    let data_i64: Vec<i64> = tensor.to_data().to_vec::<i64>().unwrap();
48    data_i64.into_iter().map(|x| x as i32).collect()
49}
50
51// ================ To and Into Burn Conversions ================
52
/// Trait for converting ndarray to burn tensor (generic over Float/Int and
/// dimensionality)
pub trait ToBurn<B: Backend, const D: usize, T: TensorKind<B>> {
    /// Build a tensor on `device` from a borrowed array (copies the data).
    fn to_burn(&self, device: &B::Device) -> Tensor<B, D, T>;
    /// Build a tensor on `device`, consuming the array.
    fn into_burn(self, device: &B::Device) -> Tensor<B, D, T>;
}
59
60/// Implementation of the trait for 2D Float ndarray
61impl<B: Backend> ToBurn<B, 2, Float> for nd::Array2<f32> {
62    fn to_burn(&self, device: &B::Device) -> Tensor<B, 2, Float> {
63        let vec: Vec<f32>;
64        let bytes = if self.is_standard_layout() {
65            self.as_slice().unwrap()
66        } else {
67            vec = self.iter().copied().collect();
68            vec.as_slice()
69        };
70        let shape = [self.nrows(), self.ncols()];
71        Tensor::<B, 1, Float>::from_floats(bytes, device).reshape(shape)
72    }
73
74    fn into_burn(self, device: &B::Device) -> Tensor<B, 2, Float> {
75        let vec: Vec<f32>;
76        let bytes = if self.is_standard_layout() {
77            self.as_slice().expect("Array should have a slice if it's in standard layout")
78        } else {
79            vec = self.iter().copied().collect();
80            vec.as_slice()
81        };
82        let shape = [self.nrows(), self.ncols()];
83        Tensor::<B, 1, Float>::from_floats(bytes, device).reshape(shape)
84    }
85}
86
87/// Trait implementation for 1D Float ndarray
88impl<B: Backend> ToBurn<B, 1, Float> for nd::Array1<f32> {
89    fn to_burn(&self, device: &B::Device) -> Tensor<B, 1, Float> {
90        let vec: Vec<f32> = self.iter().copied().collect();
91        Tensor::<B, 1, Float>::from_floats(&vec[..], device)
92    }
93
94    fn into_burn(self, device: &B::Device) -> Tensor<B, 1, Float> {
95        let vec: Vec<f32>;
96        let bytes = if self.is_standard_layout() {
97            self.as_slice().expect("Array should have a slice if it's in standard layout")
98        } else {
99            vec = self.iter().copied().collect();
100            vec.as_slice()
101        };
102        Tensor::<B, 1, Float>::from_floats(bytes, device)
103    }
104}
105
106/// Trait implementation for 2D Int ndarray
107impl<B: Backend> ToBurn<B, 2, Int> for nd::Array2<u32> {
108    #[allow(clippy::cast_possible_wrap)]
109    fn to_burn(&self, device: &B::Device) -> Tensor<B, 2, Int> {
110        let array_i32 = self.mapv(|x| x as i32);
111        let vec: Vec<i32> = array_i32.into_raw_vec_and_offset().0;
112        // let vec: Vec<i32> = array_i32.into_raw_vec();
113        let shape = [self.nrows(), self.ncols()];
114        Tensor::<B, 1, Int>::from_ints(&vec[..], device).reshape(shape)
115    }
116
117    #[allow(clippy::cast_possible_wrap)]
118    fn into_burn(self, device: &B::Device) -> Tensor<B, 2, Int> {
119        let array_i32 = self.mapv(|x| x as i32);
120        let vec: Vec<i32> = array_i32.into_raw_vec_and_offset().0;
121        // let vec: Vec<i32> = array_i32.into_raw_vec();
122        let shape = [self.nrows(), self.ncols()];
123        Tensor::<B, 1, Int>::from_ints(&vec[..], device).reshape(shape)
124    }
125}
126
127/// Trait implementation for 1D Int ndarray
128impl<B: Backend> ToBurn<B, 1, Int> for nd::Array1<u32> {
129    #[allow(clippy::cast_possible_wrap)]
130    fn to_burn(&self, device: &B::Device) -> Tensor<B, 1, Int> {
131        let array_i32 = self.mapv(|x| x as i32);
132        let vec: Vec<i32> = array_i32.into_raw_vec_and_offset().0;
133        // let vec: Vec<i32> = array_i32.into_raw_vec();
134        Tensor::<B, 1, Int>::from_ints(&vec[..], device)
135    }
136
137    #[allow(clippy::cast_possible_wrap)]
138    fn into_burn(self, device: &B::Device) -> Tensor<B, 1, Int> {
139        let array_i32 = self.mapv(|x| x as i32);
140        let vec: Vec<i32> = array_i32.into_raw_vec_and_offset().0;
141        // let vec: Vec<i32> = array_i32.into_raw_vec();
142        Tensor::<B, 1, Int>::from_ints(&vec[..], device)
143    }
144}
145impl<B: Backend> ToBurn<B, 3, Float> for nd::Array3<f32> {
146    fn to_burn(&self, device: &B::Device) -> Tensor<B, 3, Float> {
147        let vec: Vec<f32>;
148        let bytes = if self.is_standard_layout() {
149            self.as_slice().unwrap()
150        } else {
151            vec = self.iter().copied().collect();
152            vec.as_slice()
153        };
154        let shape = [self.shape()[0], self.shape()[1], self.shape()[2]];
155        Tensor::<B, 1, Float>::from_floats(bytes, device).reshape(shape)
156    }
157
158    fn into_burn(self, device: &B::Device) -> Tensor<B, 3, Float> {
159        let vec: Vec<f32>;
160        let bytes = if self.is_standard_layout() {
161            self.as_slice().expect("Array should have a slice if it's in standard layout")
162        } else {
163            vec = self.iter().copied().collect();
164            vec.as_slice()
165        };
166        let shape = [self.shape()[0], self.shape()[1], self.shape()[2]];
167        Tensor::<B, 1, Float>::from_floats(bytes, device).reshape(shape)
168    }
169}
170
171/// Trait implementation for 3D Int ndarray
172impl<B: Backend> ToBurn<B, 3, Int> for nd::Array3<u32> {
173    #[allow(clippy::cast_possible_wrap)]
174    fn to_burn(&self, device: &B::Device) -> Tensor<B, 3, Int> {
175        let array_i32 = self.mapv(|x| x as i32);
176        let vec: Vec<i32> = array_i32.into_raw_vec_and_offset().0;
177        // let vec: Vec<i32> = array_i32.into_raw_vec();
178        let shape = [self.shape()[0], self.shape()[1], self.shape()[2]];
179        Tensor::<B, 1, Int>::from_ints(&vec[..], device).reshape(shape)
180    }
181
182    #[allow(clippy::cast_possible_wrap)]
183    fn into_burn(self, device: &B::Device) -> Tensor<B, 3, Int> {
184        let array_i32 = self.mapv(|x| x as i32);
185        let vec: Vec<i32> = array_i32.into_raw_vec_and_offset().0;
186        // let vec: Vec<i32> = array_i32.into_raw_vec();
187        let shape = [self.shape()[0], self.shape()[1], self.shape()[2]];
188        Tensor::<B, 1, Int>::from_ints(&vec[..], device).reshape(shape)
189    }
190}
191/// Implement `ToBurn` for converting `nalgebra::DMatrix<f32>` to a burn tensor
192/// (Float type)
193impl<B: Backend> ToBurn<B, 2, Float> for na::DMatrix<f32> {
194    fn to_burn(&self, device: &B::Device) -> Tensor<B, 2, Float> {
195        let num_rows = self.nrows();
196        let num_cols = self.ncols();
197        let flattened: Vec<f32> = self.transpose().as_slice().to_vec();
198        Tensor::<B, 1, Float>::from_floats(&flattened[..], device).reshape([num_rows, num_cols])
199    }
200
201    fn into_burn(self, device: &B::Device) -> Tensor<B, 2, Float> {
202        let num_rows = self.nrows();
203        let num_cols = self.ncols();
204        let flattened: Vec<f32> = self.transpose().as_slice().to_vec();
205        Tensor::<B, 1, Float>::from_floats(&flattened[..], device).reshape([num_rows, num_cols])
206    }
207}
208
209/// Implement `ToBurn` for converting `nalgebra::DMatrix<u32>` to a burn tensor
210/// (Int type)
211impl<B: Backend> ToBurn<B, 2, Int> for na::DMatrix<u32> {
212    fn to_burn(&self, device: &B::Device) -> Tensor<B, 2, Int> {
213        let num_rows = self.nrows();
214        let num_cols = self.ncols();
215        let flattened: Vec<i32> = self
216            .transpose()
217            .as_slice()
218            .iter()
219            .map(|&x| i32::try_from(x).expect("Value out of range for i32"))
220            .collect();
221        Tensor::<B, 1, Int>::from_ints(&flattened[..], device).reshape([num_rows, num_cols])
222    }
223
224    fn into_burn(self, device: &B::Device) -> Tensor<B, 2, Int> {
225        let num_rows = self.nrows();
226        let num_cols = self.ncols();
227        let flattened: Vec<i32> = self
228            .transpose()
229            .as_slice()
230            .iter()
231            .map(|&x| i32::try_from(x).expect("Value out of range for i32"))
232            .collect();
233        Tensor::<B, 1, Int>::from_ints(&flattened[..], device).reshape([num_rows, num_cols])
234    }
235}
236
237// ================ To and Into NdArray Conversions ================
238
/// Trait for converting burn tensor to ndarray (generic over Float/Int and
/// dimensionality)
pub trait ToNdArray<B: Backend, const D: usize, T> {
    /// Copy the tensor's data into a newly allocated `D`-dimensional ndarray.
    fn to_ndarray(&self) -> nd::Array<T, nd::Dim<[usize; D]>>;
    /// Consume the tensor and copy its data into a `D`-dimensional ndarray.
    fn into_ndarray(self) -> nd::Array<T, nd::Dim<[usize; D]>>;
}
245
246/// Trait implementation for converting 3D Float burn tensor to ndarray
247impl<B: Backend> ToNdArray<B, 3, f32> for Tensor<B, 3, Float> {
248    fn to_ndarray(&self) -> nd::Array3<f32> {
249        let tensor_data = tensor_to_data_float(self);
250        let shape = self.dims();
251        nd::Array3::from_shape_vec((shape[0], shape[1], shape[2]), tensor_data).unwrap()
252    }
253
254    fn into_ndarray(self) -> nd::Array3<f32> {
255        let tensor_data = tensor_to_data_float(&self);
256        let shape = self.dims();
257        nd::Array3::from_shape_vec((shape[0], shape[1], shape[2]), tensor_data).unwrap()
258    }
259}
260
261/// Trait implementation for converting 2D Float burn tensor to ndarray
262impl<B: Backend> ToNdArray<B, 2, f32> for Tensor<B, 2, Float> {
263    fn to_ndarray(&self) -> nd::Array2<f32> {
264        let tensor_data = tensor_to_data_float(self);
265        let shape = self.dims();
266        nd::Array2::from_shape_vec((shape[0], shape[1]), tensor_data).unwrap()
267    }
268
269    fn into_ndarray(self) -> nd::Array2<f32> {
270        let tensor_data = tensor_to_data_float(&self);
271        let shape = self.dims();
272        nd::Array2::from_shape_vec((shape[0], shape[1]), tensor_data).unwrap()
273    }
274}
275
276/// Trait implementation for converting 1D Float burn tensor to ndarray
277impl<B: Backend> ToNdArray<B, 1, f32> for Tensor<B, 1, Float> {
278    fn to_ndarray(&self) -> nd::Array1<f32> {
279        let tensor_data = tensor_to_data_float(self);
280        nd::Array1::from_vec(tensor_data)
281    }
282
283    fn into_ndarray(self) -> nd::Array1<f32> {
284        let tensor_data = tensor_to_data_float(&self);
285        nd::Array1::from_vec(tensor_data)
286    }
287}
288
289/// Trait implementation for converting 3D Int burn tensor to ndarray
290#[allow(clippy::cast_sign_loss)]
291impl<B: Backend> ToNdArray<B, 3, u32> for Tensor<B, 3, Int> {
292    fn to_ndarray(&self) -> nd::Array3<u32> {
293        let tensor_data = tensor_to_data_int(self);
294        let tensor_data_u32: Vec<u32> = tensor_data.into_iter().map(|x| x as u32).collect();
295        let shape = self.dims();
296        nd::Array3::from_shape_vec((shape[0], shape[1], shape[2]), tensor_data_u32).unwrap()
297    }
298
299    fn into_ndarray(self) -> nd::Array3<u32> {
300        let tensor_data = tensor_to_data_int(&self);
301        let tensor_data_u32: Vec<u32> = tensor_data.into_iter().map(|x| x as u32).collect();
302        let shape = self.dims();
303        nd::Array3::from_shape_vec((shape[0], shape[1], shape[2]), tensor_data_u32).unwrap()
304    }
305}
306
307/// Trait implementation for converting 2D Int burn tensor to ndarray
308#[allow(clippy::cast_sign_loss)]
309impl<B: Backend> ToNdArray<B, 2, u32> for Tensor<B, 2, Int> {
310    fn to_ndarray(&self) -> nd::Array2<u32> {
311        let tensor_data = tensor_to_data_int(self);
312        let tensor_data_u32: Vec<u32> = tensor_data.into_iter().map(|x| x as u32).collect();
313        let shape = self.dims();
314        nd::Array2::from_shape_vec((shape[0], shape[1]), tensor_data_u32).unwrap()
315    }
316
317    fn into_ndarray(self) -> nd::Array2<u32> {
318        let tensor_data = tensor_to_data_int(&self);
319        let tensor_data_u32: Vec<u32> = tensor_data.into_iter().map(|x| x as u32).collect();
320        let shape = self.dims();
321        nd::Array2::from_shape_vec((shape[0], shape[1]), tensor_data_u32).unwrap()
322    }
323}
324
325/// Trait implementation for converting 1D Int burn tensor to ndarray
326#[allow(clippy::cast_sign_loss)]
327impl<B: Backend> ToNdArray<B, 1, u32> for Tensor<B, 1, Int> {
328    fn to_ndarray(&self) -> nd::Array1<u32> {
329        let tensor_data = tensor_to_data_int(self);
330        let tensor_data_u32: Vec<u32> = tensor_data.into_iter().map(|x| x as u32).collect();
331        nd::Array1::from_vec(tensor_data_u32)
332    }
333
334    fn into_ndarray(self) -> nd::Array1<u32> {
335        let tensor_data = tensor_to_data_int(&self);
336        let tensor_data_u32: Vec<u32> = tensor_data.into_iter().map(|x| x as u32).collect();
337        nd::Array1::from_vec(tensor_data_u32)
338    }
339}
340
341// ================ To and Into Nalgebra Conversions ================
342
/// Trait for converting `burn` tensor to `nalgebra::DMatrix` or
/// `nalgebra::DVector` (Float type)
///
/// NOTE(review): both methods return `DMatrix`; a vector result would be a
/// 1xN or Nx1 matrix. Confirm whether a dedicated `DVector` impl was planned.
pub trait ToNalgebraFloat<B: Backend, const D: usize> {
    /// Copy the tensor's data into a new `DMatrix<f32>`.
    fn to_nalgebra(&self) -> na::DMatrix<f32>;
    /// Consume the tensor and convert it into a new `DMatrix<f32>`.
    fn into_nalgebra(self) -> na::DMatrix<f32>;
}
349
/// Trait for converting `burn` tensor to `nalgebra::DMatrix` or
/// `nalgebra::DVector` (Int type)
///
/// NOTE(review): both methods return `DMatrix`; a vector result would be a
/// 1xN or Nx1 matrix. Confirm whether a dedicated `DVector` impl was planned.
pub trait ToNalgebraInt<B: Backend, const D: usize> {
    /// Copy the tensor's data into a new `DMatrix<u32>`.
    fn to_nalgebra(&self) -> na::DMatrix<u32>;
    /// Consume the tensor and convert it into a new `DMatrix<u32>`.
    fn into_nalgebra(self) -> na::DMatrix<u32>;
}
356
357/// Implement trait to convert `burn` tensor to `nalgebra::DMatrix<f32>` (Float
358/// type)
359impl<B: Backend> ToNalgebraFloat<B, 2> for Tensor<B, 2, Float> {
360    fn to_nalgebra(&self) -> na::DMatrix<f32> {
361        let data = tensor_to_data_float(self);
362        let shape = self.shape().dims;
363        na::DMatrix::from_vec(shape[1], shape[0], data).transpose()
364    }
365
366    fn into_nalgebra(self) -> na::DMatrix<f32> {
367        let data = tensor_to_data_float(&self);
368        let shape = self.shape().dims;
369        na::DMatrix::from_vec(shape[1], shape[0], data).transpose()
370    }
371}
372
373/// Implement trait to convert `burn` tensor to `nalgebra::DMatrix<u32>` (Int
374/// type)
375impl<B: Backend> ToNalgebraInt<B, 2> for Tensor<B, 2, Int> {
376    #[allow(clippy::cast_sign_loss)]
377    fn to_nalgebra(&self) -> na::DMatrix<u32> {
378        let data = tensor_to_data_int(self);
379        let shape = self.shape().dims;
380        let data_u32: Vec<u32> = data.into_iter().map(|x| x as u32).collect();
381        na::DMatrix::from_vec(shape[1], shape[0], data_u32).transpose()
382    }
383    #[allow(clippy::cast_sign_loss)]
384    fn into_nalgebra(self) -> na::DMatrix<u32> {
385        let data = tensor_to_data_int(&self);
386        let shape = self.shape().dims;
387        let data_u32: Vec<u32> = data.into_iter().map(|x| x as u32).collect();
388        na::DMatrix::from_vec(shape[1], shape[0], data_u32).transpose()
389    }
390}