// gloss_utils/tensor.rs

1use burn::{
2    backend::{candle::CandleDevice, ndarray::NdArrayDevice, wgpu::WgpuDevice, Candle, NdArray, Wgpu},
3    prelude::Backend,
4    tensor::{Float, Int, Tensor},
5};
6use core::panic;
7use ndarray as nd;
8// TODO: Consider using enum-dispatch whenever possible
9
10use crate::bshare::{tensor_to_data_float, tensor_to_data_int, ToBurn, ToNalgebraFloat, ToNalgebraInt, ToNdArray};
11extern crate nalgebra as na;
12use bytemuck;
13use log::warn;
14
15pub type DefaultBackend = NdArray; // Change this as needed
16
/// Identifies one of the burn backends that the `DynamicTensor*` enums can wrap.
///
/// The enum is fieldless, so it is `Copy` and can be compared or hashed directly
/// (e.g. `backend == BurnBackend::Wgpu`, or used as a map key).
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum BurnBackend {
    /// The `burn::backend::Candle` backend.
    Candle,
    /// The `burn::backend::NdArray` backend.
    NdArray,
    /// The `burn::backend::Wgpu` backend.
    Wgpu,
}
23
24/// `DynamicTensor` enum for Dynamic backend tensors in burn
25#[derive(Clone, Debug)]
26pub enum DynamicTensorFloat1D {
27    NdArray(Tensor<NdArray, 1, Float>),
28    Wgpu(Tensor<Wgpu, 1, Float>),
29    Candle(Tensor<Candle, 1, Float>),
30}
31
32/// `DynamicTensor` enum for Dynamic backend tensors in burn
33#[derive(Clone, Debug)]
34pub enum DynamicTensorFloat2D {
35    NdArray(Tensor<NdArray, 2, Float>),
36    Wgpu(Tensor<Wgpu, 2, Float>),
37    Candle(Tensor<Candle, 2, Float>),
38}
39
40/// `DynamicTensor` enum for Dynamic backend tensors in burn
41#[derive(Clone, Debug)]
42pub enum DynamicTensorFloat3D {
43    NdArray(Tensor<NdArray, 3, Float>),
44    Wgpu(Tensor<Wgpu, 3, Float>),
45    Candle(Tensor<Candle, 3, Float>),
46}
47
48/// `DynamicTensor` enum for Dynamic backend tensors in burn
49#[derive(Clone, Debug)]
50pub enum DynamicTensorInt1D {
51    NdArray(Tensor<NdArray, 1, Int>),
52    Wgpu(Tensor<Wgpu, 1, Int>),
53    Candle(Tensor<Candle, 1, Int>),
54}
55
56/// `DynamicTensor` enum for Dynamic backend tensors in burn
57#[derive(Clone, Debug)]
58pub enum DynamicTensorInt2D {
59    NdArray(Tensor<NdArray, 2, Int>),
60    Wgpu(Tensor<Wgpu, 2, Int>),
61    Candle(Tensor<Candle, 2, Int>),
62}
63
64/// `DynamicTensor` enum for Dynamic backend tensors in burn
65#[derive(Clone, Debug)]
66pub enum DynamicTensorInt3D {
67    NdArray(Tensor<NdArray, 3, Int>),
68    Wgpu(Tensor<Wgpu, 3, Int>),
69    Candle(Tensor<Candle, 3, Int>),
70}
71
72/// From methods for converting from Tensor to `DynamicTensor`
73impl DynamicTensorFloat1D {
74    pub fn from_ndarray_backend(tensor: Tensor<NdArray, 1, Float>) -> Self {
75        DynamicTensorFloat1D::NdArray(tensor)
76    }
77    pub fn from_wgpu_backend(tensor: Tensor<Wgpu, 1, Float>) -> Self {
78        DynamicTensorFloat1D::Wgpu(tensor)
79    }
80    pub fn from_candle_backend(tensor: Tensor<Candle, 1, Float>) -> Self {
81        DynamicTensorFloat1D::Candle(tensor)
82    }
83}
84
85/// From methods for converting from Tensor to `DynamicTensor`
86impl DynamicTensorFloat2D {
87    pub fn from_ndarray_backend(tensor: Tensor<NdArray, 2, Float>) -> Self {
88        DynamicTensorFloat2D::NdArray(tensor)
89    }
90    pub fn from_wgpu_backend(tensor: Tensor<Wgpu, 2, Float>) -> Self {
91        DynamicTensorFloat2D::Wgpu(tensor)
92    }
93    pub fn from_candle_backend(tensor: Tensor<Candle, 2, Float>) -> Self {
94        DynamicTensorFloat2D::Candle(tensor)
95    }
96}
97
98/// From methods for converting from Tensor to `DynamicTensor`
99impl DynamicTensorFloat3D {
100    pub fn from_ndarray_backend(tensor: Tensor<NdArray, 3, Float>) -> Self {
101        DynamicTensorFloat3D::NdArray(tensor)
102    }
103    pub fn from_wgpu_backend(tensor: Tensor<Wgpu, 3, Float>) -> Self {
104        DynamicTensorFloat3D::Wgpu(tensor)
105    }
106    pub fn from_candle_backend(tensor: Tensor<Candle, 3, Float>) -> Self {
107        DynamicTensorFloat3D::Candle(tensor)
108    }
109}
110
111/// From methods for converting from Tensor to `DynamicTensor`
112impl DynamicTensorInt1D {
113    pub fn from_ndarray_backend(tensor: Tensor<NdArray, 1, Int>) -> Self {
114        DynamicTensorInt1D::NdArray(tensor)
115    }
116    pub fn from_wgpu_backend(tensor: Tensor<Wgpu, 1, Int>) -> Self {
117        DynamicTensorInt1D::Wgpu(tensor)
118    }
119    pub fn from_candle_backend(tensor: Tensor<Candle, 1, Int>) -> Self {
120        DynamicTensorInt1D::Candle(tensor)
121    }
122}
123
124/// From methods for converting from Tensor to `DynamicTensor`
125impl DynamicTensorInt2D {
126    pub fn from_ndarray_backend(tensor: Tensor<NdArray, 2, Int>) -> Self {
127        DynamicTensorInt2D::NdArray(tensor)
128    }
129    pub fn from_wgpu_backend(tensor: Tensor<Wgpu, 2, Int>) -> Self {
130        DynamicTensorInt2D::Wgpu(tensor)
131    }
132    pub fn from_candle_backend(tensor: Tensor<Candle, 2, Int>) -> Self {
133        DynamicTensorInt2D::Candle(tensor)
134    }
135}
136
137/// From methods for converting from Tensor to `DynamicTensor`
138impl DynamicTensorInt3D {
139    pub fn from_ndarray_backend(tensor: Tensor<NdArray, 3, Int>) -> Self {
140        DynamicTensorInt3D::NdArray(tensor)
141    }
142    pub fn from_wgpu_backend(tensor: Tensor<Wgpu, 3, Int>) -> Self {
143        DynamicTensorInt3D::Wgpu(tensor)
144    }
145    pub fn from_candle_backend(tensor: Tensor<Candle, 3, Int>) -> Self {
146        DynamicTensorInt3D::Candle(tensor)
147    }
148}
149
150// Conversion and Utility Operations for DynamicTensor variants
151
/// Backend-agnostic, read-only operations shared by the `DynamicTensor*` enums.
///
/// `T` is the element type produced when data is read back to the host
/// (`f32` for the float enums and `u32` for the int enums in this module).
pub trait DynamicTensorOps<T> {
    /// Returns the tensor contents as a raw byte buffer.
    fn as_bytes(&self) -> Vec<u8>;

    /// Returns the size of the first dimension.
    fn nrows(&self) -> usize;
    /// Returns the full shape, one entry per dimension.
    fn shape(&self) -> Vec<usize>;

    /// Returns every element as a flat host-side vector.
    fn to_vec(&self) -> Vec<T>;
    /// Returns minimum value(s); the reduction axis is chosen by each implementation.
    fn min_vec(&self) -> Vec<T>;
    /// Returns maximum value(s); the reduction axis is chosen by each implementation.
    fn max_vec(&self) -> Vec<T>;
}
163
164/// `DynamicTensorOps` for Float 1D tensors
165impl DynamicTensorOps<f32> for DynamicTensorFloat1D {
166    fn as_bytes(&self) -> Vec<u8> {
167        match self {
168            DynamicTensorFloat1D::NdArray(tensor) => bytemuck::cast_slice(&tensor_to_data_float(tensor)).to_vec(),
169            DynamicTensorFloat1D::Wgpu(tensor) => bytemuck::cast_slice(&tensor_to_data_float(tensor)).to_vec(),
170            DynamicTensorFloat1D::Candle(tensor) => bytemuck::cast_slice(&tensor_to_data_float(tensor)).to_vec(),
171        }
172    }
173
174    fn nrows(&self) -> usize {
175        match self {
176            DynamicTensorFloat1D::NdArray(tensor) => tensor.dims()[0],
177            DynamicTensorFloat1D::Wgpu(tensor) => tensor.dims()[0],
178            DynamicTensorFloat1D::Candle(tensor) => tensor.dims()[0],
179        }
180    }
181
182    fn shape(&self) -> Vec<usize> {
183        match self {
184            DynamicTensorFloat1D::NdArray(tensor) => vec![tensor.dims()[0]],
185            DynamicTensorFloat1D::Wgpu(tensor) => vec![tensor.dims()[0]],
186            DynamicTensorFloat1D::Candle(tensor) => vec![tensor.dims()[0]],
187        }
188    }
189
190    fn to_vec(&self) -> Vec<f32> {
191        match &self {
192            DynamicTensorFloat1D::NdArray(tensor) => tensor_to_data_float(tensor),
193            DynamicTensorFloat1D::Wgpu(tensor) => tensor_to_data_float(tensor),
194            DynamicTensorFloat1D::Candle(tensor) => tensor_to_data_float(tensor),
195        }
196    }
197
198    fn min_vec(&self) -> Vec<f32> {
199        vec![self.to_vec().iter().copied().fold(f32::INFINITY, f32::min)]
200    }
201
202    fn max_vec(&self) -> Vec<f32> {
203        vec![self.to_vec().iter().copied().fold(f32::NEG_INFINITY, f32::max)]
204    }
205}
206
207/// `DynamicTensorOps` for Float 2D tensors
208impl DynamicTensorOps<f32> for DynamicTensorFloat2D {
209    fn as_bytes(&self) -> Vec<u8> {
210        match self {
211            DynamicTensorFloat2D::NdArray(tensor) => {
212                let tensor_data = tensor_to_data_float(tensor);
213                bytemuck::cast_slice(&tensor_data).to_vec()
214            }
215            DynamicTensorFloat2D::Wgpu(tensor) => {
216                warn!("Forcing DynamicTensor with Wgpu backend to CPU");
217                let tensor_data = tensor_to_data_float(tensor);
218                bytemuck::cast_slice(&tensor_data).to_vec()
219            }
220            DynamicTensorFloat2D::Candle(tensor) => {
221                let tensor_data = tensor_to_data_float(tensor);
222                bytemuck::cast_slice(&tensor_data).to_vec()
223            }
224        }
225    }
226
227    fn nrows(&self) -> usize {
228        match self {
229            DynamicTensorFloat2D::NdArray(tensor) => tensor.dims()[0],
230            DynamicTensorFloat2D::Wgpu(tensor) => tensor.dims()[0],
231            DynamicTensorFloat2D::Candle(tensor) => tensor.dims()[0],
232        }
233    }
234
235    fn shape(&self) -> Vec<usize> {
236        match self {
237            DynamicTensorFloat2D::NdArray(tensor) => vec![tensor.dims()[0], tensor.dims()[1]],
238            DynamicTensorFloat2D::Wgpu(tensor) => vec![tensor.dims()[0], tensor.dims()[1]],
239            DynamicTensorFloat2D::Candle(tensor) => vec![tensor.dims()[0], tensor.dims()[1]],
240        }
241    }
242
243    fn to_vec(&self) -> Vec<f32> {
244        match &self {
245            DynamicTensorFloat2D::NdArray(tensor) => tensor_to_data_float(tensor),
246            DynamicTensorFloat2D::Wgpu(tensor) => {
247                warn!("Forcing DynamicTensor with Wgpu backend to CPU");
248                tensor_to_data_float(tensor)
249            }
250            DynamicTensorFloat2D::Candle(tensor) => tensor_to_data_float(tensor),
251        }
252    }
253
254    fn min_vec(&self) -> Vec<f32> {
255        match &self {
256            DynamicTensorFloat2D::NdArray(tensor) => {
257                let min_tensor = tensor.clone().min_dim(0);
258                tensor_to_data_float(&min_tensor)
259            }
260            DynamicTensorFloat2D::Wgpu(tensor) => {
261                let min_tensor = tensor.clone().min_dim(0);
262                tensor_to_data_float(&min_tensor)
263            }
264            DynamicTensorFloat2D::Candle(tensor) => {
265                let min_tensor = tensor.clone().min_dim(0);
266                tensor_to_data_float(&min_tensor)
267            }
268        }
269    }
270
271    fn max_vec(&self) -> Vec<f32> {
272        match &self {
273            DynamicTensorFloat2D::NdArray(tensor) => {
274                let max_tensor = tensor.clone().max_dim(0);
275                tensor_to_data_float(&max_tensor)
276            }
277            DynamicTensorFloat2D::Wgpu(tensor) => {
278                let max_tensor = tensor.clone().max_dim(0);
279                tensor_to_data_float(&max_tensor)
280            }
281            DynamicTensorFloat2D::Candle(tensor) => {
282                let max_tensor = tensor.clone().max_dim(0);
283                tensor_to_data_float(&max_tensor)
284            }
285        }
286    }
287}
288
289/// `DynamicTensorOps` for Float 3D tensors
290impl DynamicTensorOps<f32> for DynamicTensorFloat3D {
291    fn as_bytes(&self) -> Vec<u8> {
292        match self {
293            DynamicTensorFloat3D::NdArray(tensor) => {
294                let tensor_data = tensor_to_data_float(tensor);
295                bytemuck::cast_slice(&tensor_data).to_vec()
296            }
297            DynamicTensorFloat3D::Wgpu(tensor) => {
298                warn!("Forcing DynamicTensor with Wgpu backend to CPU");
299                let tensor_data = tensor_to_data_float(tensor);
300                bytemuck::cast_slice(&tensor_data).to_vec()
301            }
302            DynamicTensorFloat3D::Candle(tensor) => {
303                let tensor_data = tensor_to_data_float(tensor);
304                bytemuck::cast_slice(&tensor_data).to_vec()
305            }
306        }
307    }
308
309    fn nrows(&self) -> usize {
310        match self {
311            DynamicTensorFloat3D::NdArray(tensor) => tensor.dims()[0],
312            DynamicTensorFloat3D::Wgpu(tensor) => tensor.dims()[0],
313            DynamicTensorFloat3D::Candle(tensor) => tensor.dims()[0],
314        }
315    }
316
317    fn shape(&self) -> Vec<usize> {
318        match self {
319            DynamicTensorFloat3D::NdArray(tensor) => vec![tensor.dims()[0], tensor.dims()[1], tensor.dims()[2]],
320            DynamicTensorFloat3D::Wgpu(tensor) => vec![tensor.dims()[0], tensor.dims()[1], tensor.dims()[2]],
321            DynamicTensorFloat3D::Candle(tensor) => vec![tensor.dims()[0], tensor.dims()[1], tensor.dims()[2]],
322        }
323    }
324
325    fn to_vec(&self) -> Vec<f32> {
326        match &self {
327            DynamicTensorFloat3D::NdArray(tensor) => tensor_to_data_float(tensor),
328            DynamicTensorFloat3D::Wgpu(tensor) => {
329                warn!("Forcing DynamicTensor with Wgpu backend to CPU");
330                tensor_to_data_float(tensor)
331            }
332            DynamicTensorFloat3D::Candle(tensor) => tensor_to_data_float(tensor),
333        }
334    }
335
336    fn min_vec(&self) -> Vec<f32> {
337        match &self {
338            DynamicTensorFloat3D::NdArray(tensor) => {
339                let min_tensor = tensor.clone().min_dim(0);
340                tensor_to_data_float(&min_tensor)
341            }
342            DynamicTensorFloat3D::Wgpu(tensor) => {
343                let min_tensor = tensor.clone().min_dim(0);
344                tensor_to_data_float(&min_tensor)
345            }
346            DynamicTensorFloat3D::Candle(tensor) => {
347                let min_tensor = tensor.clone().min_dim(0);
348                tensor_to_data_float(&min_tensor)
349            }
350        }
351    }
352
353    fn max_vec(&self) -> Vec<f32> {
354        match &self {
355            DynamicTensorFloat3D::NdArray(tensor) => {
356                let max_tensor = tensor.clone().max_dim(0);
357                tensor_to_data_float(&max_tensor)
358            }
359            DynamicTensorFloat3D::Wgpu(tensor) => {
360                let max_tensor = tensor.clone().max_dim(0);
361                tensor_to_data_float(&max_tensor)
362            }
363            DynamicTensorFloat3D::Candle(tensor) => {
364                let max_tensor = tensor.clone().max_dim(0);
365                tensor_to_data_float(&max_tensor)
366            }
367        }
368    }
369}
370
371/// `DynamicTensorOps` for Int 1D tensors
372impl DynamicTensorOps<u32> for DynamicTensorInt1D {
373    fn as_bytes(&self) -> Vec<u8> {
374        match self {
375            DynamicTensorInt1D::NdArray(tensor) => {
376                let tensor_data = tensor_to_data_int(tensor);
377                let u32_data: Vec<u32> = tensor_data
378                    .into_iter()
379                    .map(|x| x.try_into().expect("Negative value found during conversion to u32"))
380                    .collect();
381                bytemuck::cast_slice(&u32_data).to_vec()
382            }
383            DynamicTensorInt1D::Wgpu(tensor) => {
384                let tensor_data = tensor_to_data_int(tensor);
385                let u32_data: Vec<u32> = tensor_data
386                    .into_iter()
387                    .map(|x| x.try_into().expect("Negative value found during conversion to u32"))
388                    .collect();
389                bytemuck::cast_slice(&u32_data).to_vec()
390            }
391            DynamicTensorInt1D::Candle(tensor) => {
392                let tensor_data = tensor_to_data_int(tensor);
393                let u32_data: Vec<u32> = tensor_data
394                    .into_iter()
395                    .map(|x| x.try_into().expect("Negative value found during conversion to u32"))
396                    .collect();
397                bytemuck::cast_slice(&u32_data).to_vec()
398            }
399        }
400    }
401
402    fn nrows(&self) -> usize {
403        match self {
404            DynamicTensorInt1D::NdArray(tensor) => tensor.dims()[0],
405            DynamicTensorInt1D::Wgpu(tensor) => tensor.dims()[0],
406            DynamicTensorInt1D::Candle(tensor) => tensor.dims()[0],
407        }
408    }
409
410    fn shape(&self) -> Vec<usize> {
411        match self {
412            DynamicTensorInt1D::NdArray(tensor) => vec![tensor.dims()[0]],
413            DynamicTensorInt1D::Wgpu(tensor) => vec![tensor.dims()[0]],
414            DynamicTensorInt1D::Candle(tensor) => vec![tensor.dims()[0]],
415        }
416    }
417
418    fn to_vec(&self) -> Vec<u32> {
419        match &self {
420            DynamicTensorInt1D::NdArray(tensor) => {
421                let data = tensor_to_data_int(tensor);
422                data.into_iter()
423                    .map(|x| x.try_into().expect("Negative value found during conversion to u32"))
424                    .collect()
425            }
426            DynamicTensorInt1D::Wgpu(tensor) => {
427                let data = tensor_to_data_int(tensor);
428                data.into_iter()
429                    .map(|x| x.try_into().expect("Negative value found during conversion to u32"))
430                    .collect()
431            }
432            DynamicTensorInt1D::Candle(tensor) => {
433                let data = tensor_to_data_int(tensor);
434                data.into_iter()
435                    .map(|x| x.try_into().expect("Negative value found during conversion to u32"))
436                    .collect()
437            }
438        }
439    }
440
441    fn min_vec(&self) -> Vec<u32> {
442        vec![self.to_vec().into_iter().min().unwrap_or(0)]
443    }
444
445    fn max_vec(&self) -> Vec<u32> {
446        vec![self.to_vec().into_iter().max().unwrap_or(0)]
447    }
448}
449
450/// `DynamicTensorOps` for Int 2D tensors
451impl DynamicTensorOps<u32> for DynamicTensorInt2D {
452    fn as_bytes(&self) -> Vec<u8> {
453        match self {
454            DynamicTensorInt2D::NdArray(tensor) => {
455                let tensor_data = tensor_to_data_int(tensor);
456                let u32_data: Vec<u32> = tensor_data
457                    .into_iter()
458                    .map(|x| x.try_into().expect("Negative value found during conversion to u32"))
459                    .collect();
460                bytemuck::cast_slice(&u32_data).to_vec()
461            }
462            DynamicTensorInt2D::Wgpu(tensor) => {
463                let tensor_data = tensor_to_data_int(tensor);
464                let u32_data: Vec<u32> = tensor_data
465                    .into_iter()
466                    .map(|x| x.try_into().expect("Negative value found during conversion to u32"))
467                    .collect();
468                bytemuck::cast_slice(&u32_data).to_vec()
469            }
470            DynamicTensorInt2D::Candle(tensor) => {
471                let tensor_data = tensor_to_data_int(tensor);
472                let u32_data: Vec<u32> = tensor_data
473                    .into_iter()
474                    .map(|x| x.try_into().expect("Negative value found during conversion to u32"))
475                    .collect();
476                bytemuck::cast_slice(&u32_data).to_vec()
477            }
478        }
479    }
480
481    fn nrows(&self) -> usize {
482        match self {
483            DynamicTensorInt2D::NdArray(tensor) => tensor.dims()[0],
484            DynamicTensorInt2D::Wgpu(tensor) => tensor.dims()[0],
485            DynamicTensorInt2D::Candle(tensor) => tensor.dims()[0],
486        }
487    }
488
489    fn shape(&self) -> Vec<usize> {
490        match self {
491            DynamicTensorInt2D::NdArray(tensor) => vec![tensor.dims()[0], tensor.dims()[1]],
492            DynamicTensorInt2D::Wgpu(tensor) => vec![tensor.dims()[0], tensor.dims()[1]],
493            DynamicTensorInt2D::Candle(tensor) => vec![tensor.dims()[0], tensor.dims()[1]],
494        }
495    }
496
497    fn to_vec(&self) -> Vec<u32> {
498        match &self {
499            DynamicTensorInt2D::NdArray(tensor) => {
500                let data = tensor_to_data_int(tensor);
501                data.into_iter()
502                    .map(|x| x.try_into().expect("Negative value found during conversion to u32"))
503                    .collect()
504            }
505            DynamicTensorInt2D::Wgpu(tensor) => {
506                let data = tensor_to_data_int(tensor);
507                data.into_iter()
508                    .map(|x| x.try_into().expect("Negative value found during conversion to u32"))
509                    .collect()
510            }
511            DynamicTensorInt2D::Candle(tensor) => {
512                let data = tensor_to_data_int(tensor);
513                data.into_iter()
514                    .map(|x| x.try_into().expect("Negative value found during conversion to u32"))
515                    .collect()
516            }
517        }
518    }
519
520    fn min_vec(&self) -> Vec<u32> {
521        match &self {
522            DynamicTensorInt2D::NdArray(tensor) => {
523                let min_tensor = tensor.clone().min_dim(0);
524                tensor_to_data_int(&min_tensor)
525                    .into_iter()
526                    .map(|x| x.try_into().expect("Negative value found during conversion to u32"))
527                    .collect()
528            }
529            DynamicTensorInt2D::Wgpu(tensor) => {
530                let min_tensor = tensor.clone().min_dim(0);
531                tensor_to_data_int(&min_tensor)
532                    .into_iter()
533                    .map(|x| x.try_into().expect("Negative value found during conversion to u32"))
534                    .collect()
535            }
536            DynamicTensorInt2D::Candle(tensor) => {
537                let min_tensor = tensor.clone().min_dim(0);
538                tensor_to_data_int(&min_tensor)
539                    .into_iter()
540                    .map(|x| x.try_into().expect("Negative value found during conversion to u32"))
541                    .collect()
542            }
543        }
544    }
545
546    fn max_vec(&self) -> Vec<u32> {
547        match &self {
548            DynamicTensorInt2D::NdArray(tensor) => {
549                let max_tensor = tensor.clone().max_dim(0);
550                tensor_to_data_int(&max_tensor)
551                    .into_iter()
552                    .map(|x| x.try_into().expect("Negative value found during conversion to u32"))
553                    .collect()
554            }
555            DynamicTensorInt2D::Wgpu(tensor) => {
556                let max_tensor = tensor.clone().max_dim(0);
557                tensor_to_data_int(&max_tensor)
558                    .into_iter()
559                    .map(|x| x.try_into().expect("Negative value found during conversion to u32"))
560                    .collect()
561            }
562            DynamicTensorInt2D::Candle(tensor) => {
563                let max_tensor = tensor.clone().max_dim(0);
564                tensor_to_data_int(&max_tensor)
565                    .into_iter()
566                    .map(|x| x.try_into().expect("Negative value found during conversion to u32"))
567                    .collect()
568            }
569        }
570    }
571}
572
573/// `DynamicTensorOps` for Int 2D tensors
574impl DynamicTensorOps<u32> for DynamicTensorInt3D {
575    fn as_bytes(&self) -> Vec<u8> {
576        match self {
577            DynamicTensorInt3D::NdArray(tensor) => {
578                let tensor_data = tensor_to_data_int(tensor);
579                let u32_data: Vec<u32> = tensor_data
580                    .into_iter()
581                    .map(|x| x.try_into().expect("Negative value found during conversion to u32"))
582                    .collect();
583                bytemuck::cast_slice(&u32_data).to_vec()
584            }
585            DynamicTensorInt3D::Wgpu(tensor) => {
586                let tensor_data = tensor_to_data_int(tensor);
587                let u32_data: Vec<u32> = tensor_data
588                    .into_iter()
589                    .map(|x| x.try_into().expect("Negative value found during conversion to u32"))
590                    .collect();
591                bytemuck::cast_slice(&u32_data).to_vec()
592            }
593            DynamicTensorInt3D::Candle(tensor) => {
594                let tensor_data = tensor_to_data_int(tensor);
595                let u32_data: Vec<u32> = tensor_data
596                    .into_iter()
597                    .map(|x| x.try_into().expect("Negative value found during conversion to u32"))
598                    .collect();
599                bytemuck::cast_slice(&u32_data).to_vec()
600            }
601        }
602    }
603
604    fn nrows(&self) -> usize {
605        match self {
606            DynamicTensorInt3D::NdArray(tensor) => tensor.dims()[0],
607            DynamicTensorInt3D::Wgpu(tensor) => tensor.dims()[0],
608            DynamicTensorInt3D::Candle(tensor) => tensor.dims()[0],
609        }
610    }
611
612    fn shape(&self) -> Vec<usize> {
613        match self {
614            DynamicTensorInt3D::NdArray(tensor) => vec![tensor.dims()[0], tensor.dims()[1], tensor.dims()[2]],
615            DynamicTensorInt3D::Wgpu(tensor) => vec![tensor.dims()[0], tensor.dims()[1], tensor.dims()[2]],
616            DynamicTensorInt3D::Candle(tensor) => vec![tensor.dims()[0], tensor.dims()[1], tensor.dims()[2]],
617        }
618    }
619
620    fn to_vec(&self) -> Vec<u32> {
621        match &self {
622            DynamicTensorInt3D::NdArray(tensor) => {
623                let data = tensor_to_data_int(tensor);
624                data.into_iter()
625                    .map(|x| x.try_into().expect("Negative value found during conversion to u32"))
626                    .collect()
627            }
628            DynamicTensorInt3D::Wgpu(tensor) => {
629                let data = tensor_to_data_int(tensor);
630                data.into_iter()
631                    .map(|x| x.try_into().expect("Negative value found during conversion to u32"))
632                    .collect()
633            }
634            DynamicTensorInt3D::Candle(tensor) => {
635                let data = tensor_to_data_int(tensor);
636                data.into_iter()
637                    .map(|x| x.try_into().expect("Negative value found during conversion to u32"))
638                    .collect()
639            }
640        }
641    }
642
643    fn min_vec(&self) -> Vec<u32> {
644        match &self {
645            DynamicTensorInt3D::NdArray(tensor) => {
646                let min_tensor = tensor.clone().min_dim(0);
647                tensor_to_data_int(&min_tensor)
648                    .into_iter()
649                    .map(|x| x.try_into().expect("Negative value found during conversion to u32"))
650                    .collect()
651            }
652            DynamicTensorInt3D::Wgpu(tensor) => {
653                let min_tensor = tensor.clone().min_dim(0);
654                tensor_to_data_int(&min_tensor)
655                    .into_iter()
656                    .map(|x| x.try_into().expect("Negative value found during conversion to u32"))
657                    .collect()
658            }
659            DynamicTensorInt3D::Candle(tensor) => {
660                let min_tensor = tensor.clone().min_dim(0);
661                tensor_to_data_int(&min_tensor)
662                    .into_iter()
663                    .map(|x| x.try_into().expect("Negative value found during conversion to u32"))
664                    .collect()
665            }
666        }
667    }
668
669    fn max_vec(&self) -> Vec<u32> {
670        match &self {
671            DynamicTensorInt3D::NdArray(tensor) => {
672                let max_tensor = tensor.clone().max_dim(0);
673                tensor_to_data_int(&max_tensor)
674                    .into_iter()
675                    .map(|x| x.try_into().expect("Negative value found during conversion to u32"))
676                    .collect()
677            }
678            DynamicTensorInt3D::Wgpu(tensor) => {
679                let max_tensor = tensor.clone().max_dim(0);
680                tensor_to_data_int(&max_tensor)
681                    .into_iter()
682                    .map(|x| x.try_into().expect("Negative value found during conversion to u32"))
683                    .collect()
684            }
685            DynamicTensorInt3D::Candle(tensor) => {
686                let max_tensor = tensor.clone().max_dim(0);
687                tensor_to_data_int(&max_tensor)
688                    .into_iter()
689                    .map(|x| x.try_into().expect("Negative value found during conversion to u32"))
690                    .collect()
691            }
692        }
693    }
694}
695
696/// Trait for conversion to and from nalgebra matrices
697pub trait DynamicMatrixOps<T, const D: usize> {
698    fn from_ndarray(array: &nd::Array<T, nd::Dim<[usize; D]>>) -> Self;
699    fn to_ndarray(&self) -> nd::Array<T, nd::Dim<[usize; D]>>;
700    fn into_ndarray(self) -> nd::Array<T, nd::Dim<[usize; D]>>;
701    fn from_dmatrix(matrix: &na::DMatrix<T>) -> Self;
702    fn to_dmatrix(&self) -> na::DMatrix<T>;
703    fn into_dmatrix(self) -> na::DMatrix<T>;
704}
705
706/// `DynamicMatrixOps` for `DynamicTensorFloat2D`
707impl DynamicMatrixOps<f32, 2> for DynamicTensorFloat2D {
708    fn from_ndarray(array: &nd::Array<f32, nd::Dim<[usize; 2]>>) -> Self {
709        match std::any::TypeId::of::<DefaultBackend>() {
710            id if id == std::any::TypeId::of::<NdArray>() => {
711                let tensor = array.to_burn(&NdArrayDevice::Cpu);
712                DynamicTensorFloat2D::NdArray(tensor)
713            }
714            id if id == std::any::TypeId::of::<Candle>() => {
715                let tensor = array.to_burn(&CandleDevice::Cpu);
716                DynamicTensorFloat2D::Candle(tensor)
717            }
718            id if id == std::any::TypeId::of::<Wgpu>() => {
719                let tensor = array.to_burn(&WgpuDevice::BestAvailable);
720                DynamicTensorFloat2D::Wgpu(tensor)
721            }
722            _ => panic!("Unsupported backend!"),
723        }
724    }
725
726    fn to_ndarray(&self) -> nd::Array<f32, nd::Dim<[usize; 2]>> {
727        match self {
728            DynamicTensorFloat2D::NdArray(tensor) => tensor.to_ndarray(),
729            DynamicTensorFloat2D::Wgpu(tensor) => tensor.to_ndarray(),
730            DynamicTensorFloat2D::Candle(tensor) => tensor.to_ndarray(),
731        }
732    }
733
734    fn into_ndarray(self) -> nd::Array<f32, nd::Dim<[usize; 2]>> {
735        match self {
736            DynamicTensorFloat2D::NdArray(tensor) => tensor.into_ndarray(),
737            DynamicTensorFloat2D::Wgpu(tensor) => tensor.into_ndarray(),
738            DynamicTensorFloat2D::Candle(tensor) => tensor.into_ndarray(),
739        }
740    }
741
742    fn from_dmatrix(matrix: &na::DMatrix<f32>) -> Self {
743        match std::any::TypeId::of::<DefaultBackend>() {
744            id if id == std::any::TypeId::of::<NdArray>() => {
745                let tensor = matrix.to_burn(&NdArrayDevice::Cpu);
746                DynamicTensorFloat2D::NdArray(tensor)
747            }
748            id if id == std::any::TypeId::of::<Candle>() => {
749                let tensor = matrix.to_burn(&CandleDevice::Cpu);
750                DynamicTensorFloat2D::Candle(tensor)
751            }
752            id if id == std::any::TypeId::of::<Wgpu>() => {
753                let tensor = matrix.to_burn(&WgpuDevice::BestAvailable);
754                DynamicTensorFloat2D::Wgpu(tensor)
755            }
756            _ => panic!("Unsupported backend!"),
757        }
758    }
759
760    fn to_dmatrix(&self) -> na::DMatrix<f32> {
761        match self {
762            DynamicTensorFloat2D::NdArray(tensor) => tensor.to_nalgebra(),
763            DynamicTensorFloat2D::Wgpu(tensor) => tensor.to_nalgebra(),
764            DynamicTensorFloat2D::Candle(tensor) => tensor.to_nalgebra(),
765        }
766    }
767
768    fn into_dmatrix(self) -> na::DMatrix<f32> {
769        match self {
770            DynamicTensorFloat2D::NdArray(tensor) => tensor.into_nalgebra(),
771            DynamicTensorFloat2D::Wgpu(tensor) => tensor.into_nalgebra(),
772            DynamicTensorFloat2D::Candle(tensor) => tensor.into_nalgebra(),
773        }
774    }
775}
776
777/// `DynamicMatrixOps` for `DynamicTensorFloat3D`
778impl DynamicMatrixOps<f32, 3> for DynamicTensorFloat3D {
779    fn from_ndarray(array: &nd::Array<f32, nd::Dim<[usize; 3]>>) -> Self {
780        match std::any::TypeId::of::<DefaultBackend>() {
781            id if id == std::any::TypeId::of::<NdArray>() => {
782                let tensor = array.to_burn(&NdArrayDevice::Cpu);
783                DynamicTensorFloat3D::NdArray(tensor)
784            }
785            id if id == std::any::TypeId::of::<Candle>() => {
786                let tensor = array.to_burn(&CandleDevice::Cpu);
787                DynamicTensorFloat3D::Candle(tensor)
788            }
789            id if id == std::any::TypeId::of::<Wgpu>() => {
790                let tensor = array.to_burn(&WgpuDevice::BestAvailable);
791                DynamicTensorFloat3D::Wgpu(tensor)
792            }
793            _ => panic!("Unsupported backend!"),
794        }
795    }
796
797    fn to_ndarray(&self) -> nd::Array<f32, nd::Dim<[usize; 3]>> {
798        match self {
799            DynamicTensorFloat3D::NdArray(tensor) => tensor.to_ndarray(),
800            DynamicTensorFloat3D::Wgpu(tensor) => tensor.to_ndarray(),
801            DynamicTensorFloat3D::Candle(tensor) => tensor.to_ndarray(),
802        }
803    }
804
805    fn into_ndarray(self) -> nd::Array<f32, nd::Dim<[usize; 3]>> {
806        match self {
807            DynamicTensorFloat3D::NdArray(tensor) => tensor.into_ndarray(),
808            DynamicTensorFloat3D::Wgpu(tensor) => tensor.into_ndarray(),
809            DynamicTensorFloat3D::Candle(tensor) => tensor.into_ndarray(),
810        }
811    }
812
813    fn from_dmatrix(_matrix: &na::DMatrix<f32>) -> Self {
814        panic!("3D DynamicTensor interop with DMatrix is not supported!");
815    }
816
817    fn to_dmatrix(&self) -> na::DMatrix<f32> {
818        panic!("3D DynamicTensor interop with DMatrix is not supported!");
819    }
820
821    fn into_dmatrix(self) -> na::DMatrix<f32> {
822        panic!("3D DynamicTensor interop with DMatrix is not supported!");
823    }
824}
825
826/// `DynamicMatrixOps` for `DynamicTensorInt2D`
827impl DynamicMatrixOps<u32, 2> for DynamicTensorInt2D {
828    fn from_ndarray(array: &nd::Array<u32, nd::Dim<[usize; 2]>>) -> Self {
829        match std::any::TypeId::of::<DefaultBackend>() {
830            id if id == std::any::TypeId::of::<NdArray>() => {
831                let tensor = array.to_burn(&NdArrayDevice::Cpu);
832                DynamicTensorInt2D::NdArray(tensor)
833            }
834            id if id == std::any::TypeId::of::<Candle>() => {
835                let tensor = array.to_burn(&CandleDevice::Cpu);
836                DynamicTensorInt2D::Candle(tensor)
837            }
838            id if id == std::any::TypeId::of::<Wgpu>() => {
839                let tensor = array.to_burn(&WgpuDevice::BestAvailable);
840                DynamicTensorInt2D::Wgpu(tensor)
841            }
842            _ => panic!("Unsupported backend!"),
843        }
844    }
845
846    fn to_ndarray(&self) -> nd::Array<u32, nd::Dim<[usize; 2]>> {
847        match self {
848            DynamicTensorInt2D::NdArray(tensor) => tensor.to_ndarray(),
849            DynamicTensorInt2D::Wgpu(tensor) => tensor.to_ndarray(),
850            DynamicTensorInt2D::Candle(tensor) => tensor.to_ndarray(),
851        }
852    }
853
854    fn into_ndarray(self) -> nd::Array<u32, nd::Dim<[usize; 2]>> {
855        match self {
856            DynamicTensorInt2D::NdArray(tensor) => tensor.into_ndarray(),
857            DynamicTensorInt2D::Wgpu(tensor) => tensor.into_ndarray(),
858            DynamicTensorInt2D::Candle(tensor) => tensor.into_ndarray(),
859        }
860    }
861
862    fn from_dmatrix(matrix: &na::DMatrix<u32>) -> Self {
863        match std::any::TypeId::of::<DefaultBackend>() {
864            id if id == std::any::TypeId::of::<NdArray>() => {
865                let tensor = matrix.to_burn(&NdArrayDevice::Cpu);
866                DynamicTensorInt2D::NdArray(tensor)
867            }
868            id if id == std::any::TypeId::of::<Candle>() => {
869                let tensor = matrix.to_burn(&CandleDevice::Cpu);
870                DynamicTensorInt2D::Candle(tensor)
871            }
872            id if id == std::any::TypeId::of::<Wgpu>() => {
873                let tensor = matrix.to_burn(&WgpuDevice::BestAvailable);
874                DynamicTensorInt2D::Wgpu(tensor)
875            }
876            _ => panic!("Unsupported backend!"),
877        }
878    }
879
880    fn to_dmatrix(&self) -> na::DMatrix<u32> {
881        match self {
882            DynamicTensorInt2D::NdArray(tensor) => tensor.to_nalgebra(),
883            DynamicTensorInt2D::Wgpu(tensor) => tensor.to_nalgebra(),
884            DynamicTensorInt2D::Candle(tensor) => tensor.to_nalgebra(),
885        }
886    }
887
888    fn into_dmatrix(self) -> na::DMatrix<u32> {
889        match self {
890            DynamicTensorInt2D::NdArray(tensor) => tensor.into_nalgebra(),
891            DynamicTensorInt2D::Wgpu(tensor) => tensor.into_nalgebra(),
892            DynamicTensorInt2D::Candle(tensor) => tensor.into_nalgebra(),
893        }
894    }
895}
896
897/// `DynamicMatrixOps` for `DynamicTensorInt3D`
898impl DynamicMatrixOps<u32, 3> for DynamicTensorInt3D {
899    fn from_ndarray(array: &nd::Array<u32, nd::Dim<[usize; 3]>>) -> Self {
900        match std::any::TypeId::of::<DefaultBackend>() {
901            id if id == std::any::TypeId::of::<NdArray>() => {
902                let tensor = array.to_burn(&NdArrayDevice::Cpu);
903                DynamicTensorInt3D::NdArray(tensor)
904            }
905            id if id == std::any::TypeId::of::<Candle>() => {
906                let tensor = array.to_burn(&CandleDevice::Cpu);
907                DynamicTensorInt3D::Candle(tensor)
908            }
909            id if id == std::any::TypeId::of::<Wgpu>() => {
910                let tensor = array.to_burn(&WgpuDevice::BestAvailable);
911                DynamicTensorInt3D::Wgpu(tensor)
912            }
913            _ => panic!("Unsupported backend!"),
914        }
915    }
916
917    fn to_ndarray(&self) -> nd::Array<u32, nd::Dim<[usize; 3]>> {
918        match self {
919            DynamicTensorInt3D::NdArray(tensor) => tensor.to_ndarray(),
920            DynamicTensorInt3D::Wgpu(tensor) => tensor.to_ndarray(),
921            DynamicTensorInt3D::Candle(tensor) => tensor.to_ndarray(),
922        }
923    }
924
925    fn into_ndarray(self) -> nd::Array<u32, nd::Dim<[usize; 3]>> {
926        match self {
927            DynamicTensorInt3D::NdArray(tensor) => tensor.into_ndarray(),
928            DynamicTensorInt3D::Wgpu(tensor) => tensor.into_ndarray(),
929            DynamicTensorInt3D::Candle(tensor) => tensor.into_ndarray(),
930        }
931    }
932
933    fn from_dmatrix(_matrix: &na::DMatrix<u32>) -> Self {
934        panic!("3D DynamicTensor interop with DMatrix is not supported!");
935    }
936
937    fn to_dmatrix(&self) -> na::DMatrix<u32> {
938        panic!("3D DynamicTensor interop with DMatrix is not supported!");
939    }
940
941    fn into_dmatrix(self) -> na::DMatrix<u32> {
942        panic!("3D DynamicTensor interop with DMatrix is not supported!");
943    }
944}
945
946// /////////////////////////////////////////////////////////////////////////////
947// ///////////////////////////////////////// ///// Some burn utilities
948// /////////////////////////////////////////////////////////////////////////////
949// /////////////////////////////////////////
950/// Normalise a 2D tensor across dim 1
951pub fn normalize_tensor<B: Backend>(tensor: Tensor<B, 2, Float>) -> Tensor<B, 2, Float> {
952    let norm = tensor.clone().powf_scalar(2.0).sum_dim(1).sqrt(); // Compute the L2 norm along the last axis (dim = 1)
953    tensor.div(norm) // Divide each vector by its norm
954}
955
956/// Cross product of 2 2D Tensors
957pub fn cross_product<B: Backend>(
958    a: &Tensor<B, 2, Float>, // Tensor of shape [N, 3]
959    b: &Tensor<B, 2, Float>, // Tensor of shape [N, 3]
960) -> Tensor<B, 2, Float> {
961    // Split the input tensors along dimension 1 (the 3 components) using chunk
962    let a_chunks = a.clone().chunk(3, 1); // Split tensor `a` into 3 chunks: ax, ay, az
963    let b_chunks = b.clone().chunk(3, 1); // Split tensor `b` into 3 chunks: bx, by, bz
964
965    let ax: Tensor<B, 1> = a_chunks[0].clone().squeeze(1); // x component of a
966    let ay: Tensor<B, 1> = a_chunks[1].clone().squeeze(1); // y component of a
967    let az: Tensor<B, 1> = a_chunks[2].clone().squeeze(1); // z component of a
968
969    let bx: Tensor<B, 1> = b_chunks[0].clone().squeeze(1); // x component of b
970    let by: Tensor<B, 1> = b_chunks[1].clone().squeeze(1); // y component of b
971    let bz: Tensor<B, 1> = b_chunks[2].clone().squeeze(1); // z component of b
972
973    // Compute the components of the cross product
974    let cx = ay.clone().mul(bz.clone()).sub(az.clone().mul(by.clone())); // cx = ay * bz - az * by
975    let cy = az.mul(bx.clone()).sub(ax.clone().mul(bz)); // cy = az * bx - ax * bz
976    let cz = ax.mul(by).sub(ay.mul(bx)); // cz = ax * by - ay * bx
977
978    // Stack the result to form the resulting [N, 3] tensor
979    Tensor::stack(vec![cx, cy, cz], 1) // Concatenate along the second dimension
980}