gloss_utils/bshare.rs

use burn::{
    prelude::Backend,
    tensor::{Float, Int, Tensor, TensorKind},
};
use nalgebra as na;
use ndarray as nd;

// TODO: Take another look at these conversions with the burn update
// ================ Tensor to Data Functions ================
// Handle Float tensors of any dimensionality
/// Convert a burn float tensor to a Vec on wasm
#[cfg(target_arch = "wasm32")]
pub fn tensor_to_data_float<B: Backend, const D: usize>(tensor: &Tensor<B, D, Float>) -> Vec<f32> {
    // tensor.to_data().block_on().to_vec::<f32>().unwrap()
    tensor.to_data().to_vec::<f32>().unwrap()
}

/// Convert a burn float tensor to a Vec
#[cfg(not(target_arch = "wasm32"))]
pub fn tensor_to_data_float<B: Backend, const D: usize>(tensor: &Tensor<B, D, Float>) -> Vec<f32> {
    tensor.to_data().to_vec::<f32>().unwrap()
}

// Handle Int tensors of any dimensionality
/// Convert a burn int tensor to a Vec on wasm
#[cfg(target_arch = "wasm32")]
#[allow(clippy::cast_possible_truncation)]
pub fn tensor_to_data_int<B: Backend, const D: usize>(tensor: &Tensor<B, D, Int>) -> Vec<i32> {
    if let Ok(data) = tensor.to_data().to_vec::<i32>() {
        return data;
    }

    // Fallback: if the backend stores int elements as `i64`, read them as `i64` and downcast to `i32`
    let data_i64: Vec<i64> = tensor.to_data().to_vec::<i64>().unwrap();
    data_i64.into_iter().map(|x| x as i32).collect()
}

/// Convert a burn int tensor to a Vec
#[cfg(not(target_arch = "wasm32"))]
#[allow(clippy::cast_possible_truncation)]
pub fn tensor_to_data_int<B: Backend, const D: usize>(tensor: &Tensor<B, D, Int>) -> Vec<i32> {
    // tensor.to_data().to_vec::<i32>().unwrap()
    if let Ok(data) = tensor.to_data().to_vec::<i32>() {
        return data;
    }

    // Fallback: if the backend stores int elements as `i64`, read them as `i64` and downcast to `i32`
    let data_i64: Vec<i64> = tensor.to_data().to_vec::<i64>().unwrap();
    data_i64.into_iter().map(|x| x as i32).collect()
}
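
// A minimal usage sketch for the two helpers above: pull host-side `Vec`s out
// of float and int tensors. It is generic over `B`, so it assumes no concrete
// backend; `example_extract_data` is an illustrative name, not part of this API.
#[allow(dead_code)]
fn example_extract_data<B: Backend>(
    values: &Tensor<B, 1, Float>,
    indices: &Tensor<B, 1, Int>,
) -> (Vec<f32>, Vec<i32>) {
    // Both helpers flatten the tensor into a plain `Vec` on the host.
    (tensor_to_data_float(values), tensor_to_data_int(indices))
}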

// ================ To and Into Burn Conversions ================

/// Trait for converting an `ndarray` or `nalgebra` container into a burn
/// tensor (generic over Float/Int and dimensionality)
pub trait ToBurn<B: Backend, const D: usize, T: TensorKind<B>> {
    fn to_burn(&self, device: &B::Device) -> Tensor<B, D, T>;
    fn into_burn(self, device: &B::Device) -> Tensor<B, D, T>;
}
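
// A minimal usage sketch for `ToBurn`: build a small row-major matrix with
// `ndarray` and move it onto a backend. Generic over `B`, so no concrete
// backend is assumed; `example_array_to_burn` is an illustrative name.
#[allow(dead_code)]
fn example_array_to_burn<B: Backend>(device: &B::Device) -> Tensor<B, 2, Float> {
    let arr = nd::Array2::from_shape_vec((2, 2), vec![1.0_f32, 2.0, 3.0, 4.0]).unwrap();
    // `to_burn` borrows the container; `into_burn` consumes it instead.
    let tensor: Tensor<B, 2, Float> = arr.to_burn(device);
    tensor
}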

/// Implementation of the trait for 2D Float ndarray
impl<B: Backend> ToBurn<B, 2, Float> for nd::Array2<f32> {
    fn to_burn(&self, device: &B::Device) -> Tensor<B, 2, Float> {
        let vec: Vec<f32>;
        let bytes = if self.is_standard_layout() {
            self.as_slice().unwrap()
        } else {
            vec = self.iter().copied().collect();
            vec.as_slice()
        };
        let shape = [self.nrows(), self.ncols()];
        Tensor::<B, 1, Float>::from_floats(bytes, device).reshape(shape)
    }

    fn into_burn(self, device: &B::Device) -> Tensor<B, 2, Float> {
        let vec: Vec<f32>;
        let bytes = if self.is_standard_layout() {
            self.as_slice().expect("Array should have a slice if it's in standard layout")
        } else {
            vec = self.iter().copied().collect();
            vec.as_slice()
        };
        let shape = [self.nrows(), self.ncols()];
        Tensor::<B, 1, Float>::from_floats(bytes, device).reshape(shape)
    }
}

/// Trait implementation for 1D Float ndarray
impl<B: Backend> ToBurn<B, 1, Float> for nd::Array1<f32> {
    fn to_burn(&self, device: &B::Device) -> Tensor<B, 1, Float> {
        let vec: Vec<f32> = self.iter().copied().collect();
        Tensor::<B, 1, Float>::from_floats(&vec[..], device)
    }

    fn into_burn(self, device: &B::Device) -> Tensor<B, 1, Float> {
        let vec: Vec<f32>;
        let bytes = if self.is_standard_layout() {
            self.as_slice().expect("Array should have a slice if it's in standard layout")
        } else {
            vec = self.iter().copied().collect();
            vec.as_slice()
        };
        Tensor::<B, 1, Float>::from_floats(bytes, device)
    }
}

/// Trait implementation for 2D Int ndarray
impl<B: Backend> ToBurn<B, 2, Int> for nd::Array2<u32> {
    #[allow(clippy::cast_possible_wrap)]
    fn to_burn(&self, device: &B::Device) -> Tensor<B, 2, Int> {
        // Collect in logical (row-major) order so the reshape below is correct
        // for any memory layout, narrowing `u32` to `i32` for `from_ints`.
        let vec: Vec<i32> = self.iter().map(|&x| x as i32).collect();
        let shape = [self.nrows(), self.ncols()];
        Tensor::<B, 1, Int>::from_ints(&vec[..], device).reshape(shape)
    }

    #[allow(clippy::cast_possible_wrap)]
    fn into_burn(self, device: &B::Device) -> Tensor<B, 2, Int> {
        let vec: Vec<i32> = self.iter().map(|&x| x as i32).collect();
        let shape = [self.nrows(), self.ncols()];
        Tensor::<B, 1, Int>::from_ints(&vec[..], device).reshape(shape)
    }
}

/// Trait implementation for 1D Int ndarray
impl<B: Backend> ToBurn<B, 1, Int> for nd::Array1<u32> {
    #[allow(clippy::cast_possible_wrap)]
    fn to_burn(&self, device: &B::Device) -> Tensor<B, 1, Int> {
        // Collect in logical order, narrowing `u32` to `i32` for `from_ints`.
        let vec: Vec<i32> = self.iter().map(|&x| x as i32).collect();
        Tensor::<B, 1, Int>::from_ints(&vec[..], device)
    }

    #[allow(clippy::cast_possible_wrap)]
    fn into_burn(self, device: &B::Device) -> Tensor<B, 1, Int> {
        let vec: Vec<i32> = self.iter().map(|&x| x as i32).collect();
        Tensor::<B, 1, Int>::from_ints(&vec[..], device)
    }
}

/// Trait implementation for 3D Float ndarray
impl<B: Backend> ToBurn<B, 3, Float> for nd::Array3<f32> {
    fn to_burn(&self, device: &B::Device) -> Tensor<B, 3, Float> {
        let vec: Vec<f32>;
        let bytes = if self.is_standard_layout() {
            self.as_slice().unwrap()
        } else {
            vec = self.iter().copied().collect();
            vec.as_slice()
        };
        let shape = [self.shape()[0], self.shape()[1], self.shape()[2]];
        Tensor::<B, 1, Float>::from_floats(bytes, device).reshape(shape)
    }

    fn into_burn(self, device: &B::Device) -> Tensor<B, 3, Float> {
        let vec: Vec<f32>;
        let bytes = if self.is_standard_layout() {
            self.as_slice().expect("Array should have a slice if it's in standard layout")
        } else {
            vec = self.iter().copied().collect();
            vec.as_slice()
        };
        let shape = [self.shape()[0], self.shape()[1], self.shape()[2]];
        Tensor::<B, 1, Float>::from_floats(bytes, device).reshape(shape)
    }
}

/// Trait implementation for 3D Int ndarray
impl<B: Backend> ToBurn<B, 3, Int> for nd::Array3<u32> {
    #[allow(clippy::cast_possible_wrap)]
    fn to_burn(&self, device: &B::Device) -> Tensor<B, 3, Int> {
        // Collect in logical (row-major) order, narrowing `u32` to `i32` for `from_ints`.
        let vec: Vec<i32> = self.iter().map(|&x| x as i32).collect();
        let shape = [self.shape()[0], self.shape()[1], self.shape()[2]];
        Tensor::<B, 1, Int>::from_ints(&vec[..], device).reshape(shape)
    }

    #[allow(clippy::cast_possible_wrap)]
    fn into_burn(self, device: &B::Device) -> Tensor<B, 3, Int> {
        let vec: Vec<i32> = self.iter().map(|&x| x as i32).collect();
        let shape = [self.shape()[0], self.shape()[1], self.shape()[2]];
        Tensor::<B, 1, Int>::from_ints(&vec[..], device).reshape(shape)
    }
}
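
// A short sketch of the integer path: a `u32` array (for example an index
// buffer) becomes an `Int` tensor with values narrowed to `i32`. Generic over
// `B`; `example_indices_to_burn` is an illustrative name.
#[allow(dead_code)]
fn example_indices_to_burn<B: Backend>(device: &B::Device) -> Tensor<B, 2, Int> {
    let faces = nd::Array2::<u32>::from_shape_vec((2, 3), vec![0, 1, 2, 2, 1, 3]).unwrap();
    let tensor: Tensor<B, 2, Int> = faces.into_burn(device);
    tensor
}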

/// Implement `ToBurn` for converting `nalgebra::DMatrix<f32>` to a burn tensor
/// (Float type)
impl<B: Backend> ToBurn<B, 2, Float> for na::DMatrix<f32> {
    fn to_burn(&self, device: &B::Device) -> Tensor<B, 2, Float> {
        let num_rows = self.nrows();
        let num_cols = self.ncols();
        let flattened: Vec<f32> = self.transpose().as_slice().to_vec();
        Tensor::<B, 1, Float>::from_floats(&flattened[..], device).reshape([num_rows, num_cols])
    }

    fn into_burn(self, device: &B::Device) -> Tensor<B, 2, Float> {
        let num_rows = self.nrows();
        let num_cols = self.ncols();
        let flattened: Vec<f32> = self.transpose().as_slice().to_vec();
        Tensor::<B, 1, Float>::from_floats(&flattened[..], device).reshape([num_rows, num_cols])
    }
}

/// Implement `ToBurn` for converting `nalgebra::DMatrix<u32>` to a burn tensor
/// (Int type)
impl<B: Backend> ToBurn<B, 2, Int> for na::DMatrix<u32> {
    fn to_burn(&self, device: &B::Device) -> Tensor<B, 2, Int> {
        let num_rows = self.nrows();
        let num_cols = self.ncols();
        let flattened: Vec<i32> = self
            .transpose()
            .as_slice()
            .iter()
            .map(|&x| i32::try_from(x).expect("Value out of range for i32"))
            .collect();
        Tensor::<B, 1, Int>::from_ints(&flattened[..], device).reshape([num_rows, num_cols])
    }

    fn into_burn(self, device: &B::Device) -> Tensor<B, 2, Int> {
        let num_rows = self.nrows();
        let num_cols = self.ncols();
        let flattened: Vec<i32> = self
            .transpose()
            .as_slice()
            .iter()
            .map(|&x| i32::try_from(x).expect("Value out of range for i32"))
            .collect();
        Tensor::<B, 1, Int>::from_ints(&flattened[..], device).reshape([num_rows, num_cols])
    }
}
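
// A minimal sketch of the nalgebra path: `DMatrix` stores its elements
// column-major, and the `transpose()` calls in the impls above reorder them
// into the row-major layout burn expects. Generic over `B`; the function name
// is illustrative.
#[allow(dead_code)]
fn example_dmatrix_to_burn<B: Backend>(device: &B::Device) -> Tensor<B, 2, Float> {
    // Filled column by column, this is the 2 x 3 matrix [[1, 3, 5], [2, 4, 6]].
    let m = na::DMatrix::from_vec(2, 3, vec![1.0_f32, 2.0, 3.0, 4.0, 5.0, 6.0]);
    let tensor: Tensor<B, 2, Float> = m.to_burn(device);
    tensor
}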

// ================ To and Into NdArray Conversions ================

/// Trait for converting burn tensor to ndarray (generic over Float/Int and
/// dimensionality)
pub trait ToNdArray<B: Backend, const D: usize, T> {
    fn to_ndarray(&self) -> nd::Array<T, nd::Dim<[usize; D]>>;
    fn into_ndarray(self) -> nd::Array<T, nd::Dim<[usize; D]>>;
}
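
// A minimal usage sketch for `ToNdArray`: bring a tensor back onto the host as
// an `ndarray` array. Generic over `B`; `example_to_ndarray` is an
// illustrative name.
#[allow(dead_code)]
fn example_to_ndarray<B: Backend>(tensor: &Tensor<B, 2, Float>) -> nd::Array2<f32> {
    // `to_ndarray` borrows the tensor; `into_ndarray` consumes it.
    tensor.to_ndarray()
}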

/// Trait implementation for converting 3D Float burn tensor to ndarray
impl<B: Backend> ToNdArray<B, 3, f32> for Tensor<B, 3, Float> {
    fn to_ndarray(&self) -> nd::Array3<f32> {
        let tensor_data = tensor_to_data_float(self);
        let shape = self.dims();
        nd::Array3::from_shape_vec((shape[0], shape[1], shape[2]), tensor_data).unwrap()
    }

    fn into_ndarray(self) -> nd::Array3<f32> {
        let tensor_data = tensor_to_data_float(&self);
        let shape = self.dims();
        nd::Array3::from_shape_vec((shape[0], shape[1], shape[2]), tensor_data).unwrap()
    }
}

/// Trait implementation for converting 2D Float burn tensor to ndarray
impl<B: Backend> ToNdArray<B, 2, f32> for Tensor<B, 2, Float> {
    fn to_ndarray(&self) -> nd::Array2<f32> {
        let tensor_data = tensor_to_data_float(self);
        let shape = self.dims();
        nd::Array2::from_shape_vec((shape[0], shape[1]), tensor_data).unwrap()
    }

    fn into_ndarray(self) -> nd::Array2<f32> {
        let tensor_data = tensor_to_data_float(&self);
        let shape = self.dims();
        nd::Array2::from_shape_vec((shape[0], shape[1]), tensor_data).unwrap()
    }
}

/// Trait implementation for converting 1D Float burn tensor to ndarray
impl<B: Backend> ToNdArray<B, 1, f32> for Tensor<B, 1, Float> {
    fn to_ndarray(&self) -> nd::Array1<f32> {
        let tensor_data = tensor_to_data_float(self);
        nd::Array1::from_vec(tensor_data)
    }

    fn into_ndarray(self) -> nd::Array1<f32> {
        let tensor_data = tensor_to_data_float(&self);
        nd::Array1::from_vec(tensor_data)
    }
}

/// Trait implementation for converting 3D Int burn tensor to ndarray
#[allow(clippy::cast_sign_loss)]
impl<B: Backend> ToNdArray<B, 3, u32> for Tensor<B, 3, Int> {
    fn to_ndarray(&self) -> nd::Array3<u32> {
        let tensor_data = tensor_to_data_int(self);
        let tensor_data_u32: Vec<u32> = tensor_data.into_iter().map(|x| x as u32).collect();
        let shape = self.dims();
        nd::Array3::from_shape_vec((shape[0], shape[1], shape[2]), tensor_data_u32).unwrap()
    }

    fn into_ndarray(self) -> nd::Array3<u32> {
        let tensor_data = tensor_to_data_int(&self);
        let tensor_data_u32: Vec<u32> = tensor_data.into_iter().map(|x| x as u32).collect();
        let shape = self.dims();
        nd::Array3::from_shape_vec((shape[0], shape[1], shape[2]), tensor_data_u32).unwrap()
    }
}

/// Trait implementation for converting 2D Int burn tensor to ndarray
#[allow(clippy::cast_sign_loss)]
impl<B: Backend> ToNdArray<B, 2, u32> for Tensor<B, 2, Int> {
    fn to_ndarray(&self) -> nd::Array2<u32> {
        let tensor_data = tensor_to_data_int(self);
        let tensor_data_u32: Vec<u32> = tensor_data.into_iter().map(|x| x as u32).collect();
        let shape = self.dims();
        nd::Array2::from_shape_vec((shape[0], shape[1]), tensor_data_u32).unwrap()
    }

    fn into_ndarray(self) -> nd::Array2<u32> {
        let tensor_data = tensor_to_data_int(&self);
        let tensor_data_u32: Vec<u32> = tensor_data.into_iter().map(|x| x as u32).collect();
        let shape = self.dims();
        nd::Array2::from_shape_vec((shape[0], shape[1]), tensor_data_u32).unwrap()
    }
}

/// Trait implementation for converting 1D Int burn tensor to ndarray
#[allow(clippy::cast_sign_loss)]
impl<B: Backend> ToNdArray<B, 1, u32> for Tensor<B, 1, Int> {
    fn to_ndarray(&self) -> nd::Array1<u32> {
        let tensor_data = tensor_to_data_int(self);
        let tensor_data_u32: Vec<u32> = tensor_data.into_iter().map(|x| x as u32).collect();
        nd::Array1::from_vec(tensor_data_u32)
    }

    fn into_ndarray(self) -> nd::Array1<u32> {
        let tensor_data = tensor_to_data_int(&self);
        let tensor_data_u32: Vec<u32> = tensor_data.into_iter().map(|x| x as u32).collect();
        nd::Array1::from_vec(tensor_data_u32)
    }
}
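
// A generic round-trip sketch: `ndarray` -> burn -> `ndarray` should preserve
// both shape and element order for any backend `B`. `example_roundtrip_f32` is
// an illustrative helper, not part of the public API.
#[allow(dead_code)]
fn example_roundtrip_f32<B: Backend>(device: &B::Device) {
    let src = nd::Array2::from_shape_vec((2, 3), vec![0.0_f32, 1.0, 2.0, 3.0, 4.0, 5.0]).unwrap();
    let tensor: Tensor<B, 2, Float> = src.to_burn(device);
    let back: nd::Array2<f32> = tensor.into_ndarray();
    assert_eq!(src, back);
}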

// ================ To and Into Nalgebra Conversions ================

/// Trait for converting a `burn` tensor to `nalgebra::DMatrix<f32>` (Float
/// type)
pub trait ToNalgebraFloat<B: Backend, const D: usize> {
    fn to_nalgebra(&self) -> na::DMatrix<f32>;
    fn into_nalgebra(self) -> na::DMatrix<f32>;
}

/// Trait for converting a `burn` tensor to `nalgebra::DMatrix<u32>` (Int
/// type)
pub trait ToNalgebraInt<B: Backend, const D: usize> {
    fn to_nalgebra(&self) -> na::DMatrix<u32>;
    fn into_nalgebra(self) -> na::DMatrix<u32>;
}

/// Implement trait to convert `burn` tensor to `nalgebra::DMatrix<f32>` (Float
/// type)
impl<B: Backend> ToNalgebraFloat<B, 2> for Tensor<B, 2, Float> {
    fn to_nalgebra(&self) -> na::DMatrix<f32> {
        let data = tensor_to_data_float(self);
        let shape = self.shape().dims;
        na::DMatrix::from_vec(shape[1], shape[0], data).transpose()
    }

    fn into_nalgebra(self) -> na::DMatrix<f32> {
        let data = tensor_to_data_float(&self);
        let shape = self.shape().dims;
        na::DMatrix::from_vec(shape[1], shape[0], data).transpose()
    }
}

/// Implement trait to convert `burn` tensor to `nalgebra::DMatrix<u32>` (Int
/// type)
impl<B: Backend> ToNalgebraInt<B, 2> for Tensor<B, 2, Int> {
    #[allow(clippy::cast_sign_loss)]
    fn to_nalgebra(&self) -> na::DMatrix<u32> {
        let data = tensor_to_data_int(self);
        let shape = self.shape().dims;
        let data_u32: Vec<u32> = data.into_iter().map(|x| x as u32).collect();
        na::DMatrix::from_vec(shape[1], shape[0], data_u32).transpose()
    }

    #[allow(clippy::cast_sign_loss)]
    fn into_nalgebra(self) -> na::DMatrix<u32> {
        let data = tensor_to_data_int(&self);
        let shape = self.shape().dims;
        let data_u32: Vec<u32> = data.into_iter().map(|x| x as u32).collect();
        na::DMatrix::from_vec(shape[1], shape[0], data_u32).transpose()
    }
}
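
// A minimal sketch of the reverse nalgebra path: copy a float tensor back into
// a `DMatrix` with matching dimensions. Generic over `B`;
// `example_to_nalgebra` is an illustrative name.
#[allow(dead_code)]
fn example_to_nalgebra<B: Backend>(tensor: &Tensor<B, 2, Float>) -> na::DMatrix<f32> {
    let m = tensor.to_nalgebra();
    // Burn reports dims as [rows, cols]; the matrix mirrors that shape.
    debug_assert_eq!((m.nrows(), m.ncols()), (tensor.dims()[0], tensor.dims()[1]));
    m
}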