gloss_utils/
bshare.rs

1use burn::{
2    prelude::Backend,
3    tensor::{Float, Int, Tensor, TensorKind},
4};
5use nalgebra as na;
6use ndarray as nd;
7
// ================ Tensor to Data Functions ================
// Handle Float tensors of any rank (the function is generic over `D`).
/// Convert a burn float tensor to a flat `Vec<f32>` on wasm.
///
/// The tensor is read back from its device exactly once. If the backend's
/// float element is `f64`, each value is narrowed to `f32`.
///
/// # Panics
///
/// Panics if the tensor data is neither `f32` nor `f64`.
#[cfg(target_arch = "wasm32")]
// The `f64 -> f32` narrowing in the fallback is intentional: the API returns `f32`.
#[allow(clippy::cast_possible_truncation)]
pub fn tensor_to_data_float<B: Backend, const D: usize>(tensor: &Tensor<B, D, Float>) -> Vec<f32> {
    let data = tensor.to_data();
    if let Ok(values) = data.to_vec::<f32>() {
        return values;
    }

    // Fallback for backends that store floats as `f64`.
    data.to_vec::<f64>()
        .expect("float tensor data is neither f32 nor f64")
        .into_iter()
        .map(|x| x as f32)
        .collect()
}
16
17/// Convert a burn float tensor to a Vec
18#[cfg(not(target_arch = "wasm32"))]
19pub fn tensor_to_data_float<B: Backend, const D: usize>(tensor: &Tensor<B, D, Float>) -> Vec<f32> {
20    tensor.to_data().to_vec::<f32>().unwrap()
21}
22
// Handle Int tensors of any rank (the function is generic over `D`).
/// Convert a burn int tensor to a flat `Vec<i32>` on wasm.
///
/// The tensor is read back from its device exactly once; the `i32` attempt
/// and the `i64` fallback both inspect that single `TensorData`. Values of an
/// `i64` backend are truncated to `i32` with an `as` cast.
///
/// # Panics
///
/// Panics if the tensor data is neither `i32` nor `i64`.
#[cfg(target_arch = "wasm32")]
// The `i64 -> i32` truncation in the fallback is intentional: the API returns `i32`.
#[allow(clippy::cast_possible_truncation)]
pub fn tensor_to_data_int<B: Backend, const D: usize>(tensor: &Tensor<B, D, Int>) -> Vec<i32> {
    let data = tensor.to_data();
    if let Ok(values) = data.to_vec::<i32>() {
        return values;
    }

    // Fallback: the backend stores ints as `i64`; downcast each value to `i32`.
    data.to_vec::<i64>()
        .expect("int tensor data is neither i32 nor i64")
        .into_iter()
        .map(|x| x as i32)
        .collect()
}
36
37/// Convert a burn int tensor to a Vec
38#[cfg(not(target_arch = "wasm32"))]
39#[allow(clippy::cast_possible_truncation)]
40pub fn tensor_to_data_int<B: Backend, const D: usize>(tensor: &Tensor<B, D, Int>) -> Vec<i32> {
41    // tensor.to_data().to_vec::<i32>().unwrap()
42    if let Ok(data) = tensor.to_data().to_vec::<i32>() {
43        return data;
44    }
45
46    // Fallback: Attempt `i64` conversion and downcast to `i32` if `i32` fails
47    let data_i64: Vec<i64> = tensor.to_data().to_vec::<i64>().unwrap();
48    data_i64.into_iter().map(|x| x as i32).collect()
49}
50
// ================ To and Into Burn Conversions ================
52
53/// Trait for converting ndarray to burn tensor (generic over Float/Int and
54/// dimensionality)
55pub trait ToBurn<B: Backend, const D: usize, T: TensorKind<B>> {
56    fn to_burn(&self, device: &B::Device) -> Tensor<B, D, T>;
57    fn into_burn(self, device: &B::Device) -> Tensor<B, D, T>;
58}
59
60/// Implementation of the trait for 2D Float ndarray
61impl<B: Backend> ToBurn<B, 2, Float> for nd::Array2<f32> {
62    fn to_burn(&self, device: &B::Device) -> Tensor<B, 2, Float> {
63        let vec: Vec<f32>;
64        let bytes = if self.is_standard_layout() {
65            self.as_slice().unwrap()
66        } else {
67            vec = self.iter().copied().collect();
68            vec.as_slice()
69        };
70        let shape = [self.nrows(), self.ncols()];
71        Tensor::<B, 1, Float>::from_floats(bytes, device).reshape(shape)
72    }
73
74    fn into_burn(self, device: &B::Device) -> Tensor<B, 2, Float> {
75        let vec: Vec<f32>;
76        let bytes = if self.is_standard_layout() {
77            self.as_slice().expect("Array should have a slice if it's in standard layout")
78        } else {
79            vec = self.iter().copied().collect();
80            vec.as_slice()
81        };
82        let shape = [self.nrows(), self.ncols()];
83        Tensor::<B, 1, Float>::from_floats(bytes, device).reshape(shape)
84    }
85}
86
87/// Trait implementation for 1D Float ndarray
88impl<B: Backend> ToBurn<B, 1, Float> for nd::Array1<f32> {
89    fn to_burn(&self, device: &B::Device) -> Tensor<B, 1, Float> {
90        let vec: Vec<f32> = self.iter().copied().collect();
91        Tensor::<B, 1, Float>::from_floats(&vec[..], device)
92    }
93
94    fn into_burn(self, device: &B::Device) -> Tensor<B, 1, Float> {
95        let vec: Vec<f32>;
96        let bytes = if self.is_standard_layout() {
97            self.as_slice().expect("Array should have a slice if it's in standard layout")
98        } else {
99            vec = self.iter().copied().collect();
100            vec.as_slice()
101        };
102        Tensor::<B, 1, Float>::from_floats(bytes, device)
103    }
104}
105
106/// Trait implementation for 2D Int ndarray
107impl<B: Backend> ToBurn<B, 2, Int> for nd::Array2<u32> {
108    #[allow(clippy::cast_possible_wrap)]
109    fn to_burn(&self, device: &B::Device) -> Tensor<B, 2, Int> {
110        let array_i32 = self.mapv(|x| x as i32);
111        let vec: Vec<i32> = array_i32.into_raw_vec_and_offset().0;
112        // let vec: Vec<i32> = array_i32.into_raw_vec();
113        let shape = [self.nrows(), self.ncols()];
114        Tensor::<B, 1, Int>::from_ints(&vec[..], device).reshape(shape)
115    }
116
117    #[allow(clippy::cast_possible_wrap)]
118    fn into_burn(self, device: &B::Device) -> Tensor<B, 2, Int> {
119        let array_i32 = self.mapv(|x| x as i32);
120        let vec: Vec<i32> = array_i32.into_raw_vec_and_offset().0;
121        // let vec: Vec<i32> = array_i32.into_raw_vec();
122        let shape = [self.nrows(), self.ncols()];
123        Tensor::<B, 1, Int>::from_ints(&vec[..], device).reshape(shape)
124    }
125}
126
127/// Trait implementation for 1D Int ndarray
128impl<B: Backend> ToBurn<B, 1, Int> for nd::Array1<u32> {
129    #[allow(clippy::cast_possible_wrap)]
130    fn to_burn(&self, device: &B::Device) -> Tensor<B, 1, Int> {
131        let array_i32 = self.mapv(|x| x as i32);
132        let vec: Vec<i32> = array_i32.into_raw_vec_and_offset().0;
133        // let vec: Vec<i32> = array_i32.into_raw_vec();
134        Tensor::<B, 1, Int>::from_ints(&vec[..], device)
135    }
136
137    #[allow(clippy::cast_possible_wrap)]
138    fn into_burn(self, device: &B::Device) -> Tensor<B, 1, Int> {
139        let array_i32 = self.mapv(|x| x as i32);
140        let vec: Vec<i32> = array_i32.into_raw_vec_and_offset().0;
141        // let vec: Vec<i32> = array_i32.into_raw_vec();
142        Tensor::<B, 1, Int>::from_ints(&vec[..], device)
143    }
144}
145impl<B: Backend> ToBurn<B, 3, Float> for nd::Array3<f32> {
146    fn to_burn(&self, device: &B::Device) -> Tensor<B, 3, Float> {
147        let vec: Vec<f32>;
148        let bytes = if self.is_standard_layout() {
149            self.as_slice().unwrap()
150        } else {
151            vec = self.iter().copied().collect();
152            vec.as_slice()
153        };
154        let shape = [self.shape()[0], self.shape()[1], self.shape()[2]];
155        Tensor::<B, 1, Float>::from_floats(bytes, device).reshape(shape)
156    }
157
158    fn into_burn(self, device: &B::Device) -> Tensor<B, 3, Float> {
159        let vec: Vec<f32>;
160        let bytes = if self.is_standard_layout() {
161            self.as_slice().expect("Array should have a slice if it's in standard layout")
162        } else {
163            vec = self.iter().copied().collect();
164            vec.as_slice()
165        };
166        let shape = [self.shape()[0], self.shape()[1], self.shape()[2]];
167        Tensor::<B, 1, Float>::from_floats(bytes, device).reshape(shape)
168    }
169}
170
171/// Trait implementation for 3D Int ndarray
172impl<B: Backend> ToBurn<B, 3, Int> for nd::Array3<u32> {
173    #[allow(clippy::cast_possible_wrap)]
174    fn to_burn(&self, device: &B::Device) -> Tensor<B, 3, Int> {
175        let array_i32 = self.mapv(|x| x as i32);
176        let vec: Vec<i32> = array_i32.into_raw_vec_and_offset().0;
177        // let vec: Vec<i32> = array_i32.into_raw_vec();
178        let shape = [self.shape()[0], self.shape()[1], self.shape()[2]];
179        Tensor::<B, 1, Int>::from_ints(&vec[..], device).reshape(shape)
180    }
181
182    #[allow(clippy::cast_possible_wrap)]
183    fn into_burn(self, device: &B::Device) -> Tensor<B, 3, Int> {
184        let array_i32 = self.mapv(|x| x as i32);
185        let vec: Vec<i32> = array_i32.into_raw_vec_and_offset().0;
186        // let vec: Vec<i32> = array_i32.into_raw_vec();
187        let shape = [self.shape()[0], self.shape()[1], self.shape()[2]];
188        Tensor::<B, 1, Int>::from_ints(&vec[..], device).reshape(shape)
189    }
190}
191/// Implement `ToBurn` for converting `nalgebra::DMatrix<f32>` to a burn tensor
192/// (Float type)
193impl<B: Backend> ToBurn<B, 2, Float> for na::DMatrix<f32> {
194    fn to_burn(&self, device: &B::Device) -> Tensor<B, 2, Float> {
195        let num_rows = self.nrows();
196        let num_cols = self.ncols();
197        let flattened: Vec<f32> = self.transpose().as_slice().to_vec();
198        Tensor::<B, 1, Float>::from_floats(&flattened[..], device).reshape([num_rows, num_cols])
199    }
200
201    fn into_burn(self, device: &B::Device) -> Tensor<B, 2, Float> {
202        let num_rows = self.nrows();
203        let num_cols = self.ncols();
204        let flattened: Vec<f32> = self.transpose().as_slice().to_vec();
205        Tensor::<B, 1, Float>::from_floats(&flattened[..], device).reshape([num_rows, num_cols])
206    }
207}
208
209/// Implement `ToBurn` for converting `nalgebra::DMatrix<u32>` to a burn tensor
210/// (Int type)
211impl<B: Backend> ToBurn<B, 2, Int> for na::DMatrix<u32> {
212    fn to_burn(&self, device: &B::Device) -> Tensor<B, 2, Int> {
213        let num_rows = self.nrows();
214        let num_cols = self.ncols();
215        let flattened: Vec<i32> = self
216            .transpose()
217            .as_slice()
218            .iter()
219            .map(|&x| i32::try_from(x).expect("Value out of range for i32"))
220            .collect();
221        Tensor::<B, 1, Int>::from_ints(&flattened[..], device).reshape([num_rows, num_cols])
222    }
223
224    fn into_burn(self, device: &B::Device) -> Tensor<B, 2, Int> {
225        let num_rows = self.nrows();
226        let num_cols = self.ncols();
227        let flattened: Vec<i32> = self
228            .transpose()
229            .as_slice()
230            .iter()
231            .map(|&x| i32::try_from(x).expect("Value out of range for i32"))
232            .collect();
233        Tensor::<B, 1, Int>::from_ints(&flattened[..], device).reshape([num_rows, num_cols])
234    }
235}
236
// ================ To and Into NdArray Conversions ================
238
239/// Trait for converting burn tensor to ndarray (generic over Float/Int and
240/// dimensionality)
241pub trait ToNdArray<B: Backend, const D: usize, T> {
242    fn to_ndarray(&self) -> nd::Array<T, nd::Dim<[usize; D]>>;
243    fn into_ndarray(self) -> nd::Array<T, nd::Dim<[usize; D]>>;
244}
245
246/// Trait implementation for converting 2D Float burn tensor to ndarray
247impl<B: Backend> ToNdArray<B, 2, f32> for Tensor<B, 2, Float> {
248    fn to_ndarray(&self) -> nd::Array2<f32> {
249        let tensor_data = tensor_to_data_float(self);
250        let shape = self.dims();
251        nd::Array2::from_shape_vec((shape[0], shape[1]), tensor_data).unwrap()
252    }
253
254    fn into_ndarray(self) -> nd::Array2<f32> {
255        let tensor_data = tensor_to_data_float(&self);
256        let shape = self.dims();
257        nd::Array2::from_shape_vec((shape[0], shape[1]), tensor_data).unwrap()
258    }
259}
260
261/// Trait implementation for converting 1D Float burn tensor to ndarray
262impl<B: Backend> ToNdArray<B, 1, f32> for Tensor<B, 1, Float> {
263    fn to_ndarray(&self) -> nd::Array1<f32> {
264        let tensor_data = tensor_to_data_float(self);
265        nd::Array1::from_vec(tensor_data)
266    }
267
268    fn into_ndarray(self) -> nd::Array1<f32> {
269        let tensor_data = tensor_to_data_float(&self);
270        nd::Array1::from_vec(tensor_data)
271    }
272}
273
274/// Trait implementation for converting 2D Int burn tensor to ndarray
275#[allow(clippy::cast_sign_loss)]
276impl<B: Backend> ToNdArray<B, 2, u32> for Tensor<B, 2, Int> {
277    fn to_ndarray(&self) -> nd::Array2<u32> {
278        let tensor_data = tensor_to_data_int(self);
279        let tensor_data_u32: Vec<u32> = tensor_data.into_iter().map(|x| x as u32).collect();
280        let shape = self.dims();
281        nd::Array2::from_shape_vec((shape[0], shape[1]), tensor_data_u32).unwrap()
282    }
283
284    fn into_ndarray(self) -> nd::Array2<u32> {
285        let tensor_data = tensor_to_data_int(&self);
286        let tensor_data_u32: Vec<u32> = tensor_data.into_iter().map(|x| x as u32).collect();
287        let shape = self.dims();
288        nd::Array2::from_shape_vec((shape[0], shape[1]), tensor_data_u32).unwrap()
289    }
290}
291
292/// Trait implementation for converting 1D Int burn tensor to ndarray
293#[allow(clippy::cast_sign_loss)]
294impl<B: Backend> ToNdArray<B, 1, u32> for Tensor<B, 1, Int> {
295    fn to_ndarray(&self) -> nd::Array1<u32> {
296        let tensor_data = tensor_to_data_int(self);
297        let tensor_data_u32: Vec<u32> = tensor_data.into_iter().map(|x| x as u32).collect();
298        nd::Array1::from_vec(tensor_data_u32)
299    }
300
301    fn into_ndarray(self) -> nd::Array1<u32> {
302        let tensor_data = tensor_to_data_int(&self);
303        let tensor_data_u32: Vec<u32> = tensor_data.into_iter().map(|x| x as u32).collect();
304        nd::Array1::from_vec(tensor_data_u32)
305    }
306}
307
// ================ To and Into Nalgebra Conversions ================
309
310/// Trait for converting `burn` tensor to `nalgebra::DMatrix` or
311/// `nalgebra::DVector` (Float type)
312pub trait ToNalgebraFloat<B: Backend, const D: usize> {
313    fn to_nalgebra(&self) -> na::DMatrix<f32>;
314    fn into_nalgebra(self) -> na::DMatrix<f32>;
315}
316
317/// Trait for converting `burn` tensor to `nalgebra::DMatrix` or
318/// `nalgebra::DVector` (Int type)
319pub trait ToNalgebraInt<B: Backend, const D: usize> {
320    fn to_nalgebra(&self) -> na::DMatrix<u32>;
321    fn into_nalgebra(self) -> na::DMatrix<u32>;
322}
323
324/// Implement trait to convert `burn` tensor to `nalgebra::DMatrix<f32>` (Float
325/// type)
326impl<B: Backend> ToNalgebraFloat<B, 2> for Tensor<B, 2, Float> {
327    fn to_nalgebra(&self) -> na::DMatrix<f32> {
328        let data = tensor_to_data_float(self);
329        let shape = self.shape().dims;
330        na::DMatrix::from_vec(shape[1], shape[0], data).transpose()
331    }
332
333    fn into_nalgebra(self) -> na::DMatrix<f32> {
334        let data = tensor_to_data_float(&self);
335        let shape = self.shape().dims;
336        na::DMatrix::from_vec(shape[1], shape[0], data).transpose()
337    }
338}
339
340/// Implement trait to convert `burn` tensor to `nalgebra::DMatrix<u32>` (Int
341/// type)
342impl<B: Backend> ToNalgebraInt<B, 2> for Tensor<B, 2, Int> {
343    #[allow(clippy::cast_sign_loss)]
344    fn to_nalgebra(&self) -> na::DMatrix<u32> {
345        let data = tensor_to_data_int(self);
346        let shape = self.shape().dims;
347        let data_u32: Vec<u32> = data.into_iter().map(|x| x as u32).collect();
348        na::DMatrix::from_vec(shape[1], shape[0], data_u32).transpose()
349    }
350    #[allow(clippy::cast_sign_loss)]
351    fn into_nalgebra(self) -> na::DMatrix<u32> {
352        let data = tensor_to_data_int(&self);
353        let shape = self.shape().dims;
354        let data_u32: Vec<u32> = data.into_iter().map(|x| x as u32).collect();
355        na::DMatrix::from_vec(shape[1], shape[0], data_u32).transpose()
356    }
357}