gloss_utils/tensor.rs

use core::panic;

use burn::{
    backend::{candle::CandleDevice, ndarray::NdArrayDevice, wgpu::WgpuDevice, Candle, NdArray, Wgpu},
    prelude::Backend,
    tensor::{Float, Int, Tensor},
};
// use burn::backend::ndarray::PrecisionBridge as NdArrayBridge;
// use burn::backend::candle::PrecisionBridge as CandleBridge;
// use burn::backend::wgpu::PrecisionBridge as WgpuBridge;
// use burn::tensor::backend::BackendBridge;

use crate::bshare::{tensor_to_data_float, tensor_to_data_int, ToBurn, ToNalgebraFloat, ToNalgebraInt};
extern crate nalgebra as na;
use bytemuck;
use log::warn;

pub type DefaultBackend = NdArray; // Change this as needed

#[derive(Clone)]
pub enum BurnBackend {
    Candle,
    NdArray,
    Wgpu,
}

/// Backend-dynamic wrapper around a 1D `Float` tensor in burn
#[derive(Clone, Debug)]
pub enum DynamicTensorFloat1D {
    NdArray(Tensor<NdArray, 1, Float>),
    Wgpu(Tensor<Wgpu, 1, Float>),
    Candle(Tensor<Candle, 1, Float>),
}

/// Backend-dynamic wrapper around a 2D `Float` tensor in burn
#[derive(Clone, Debug)]
pub enum DynamicTensorFloat2D {
    NdArray(Tensor<NdArray, 2, Float>),
    Wgpu(Tensor<Wgpu, 2, Float>),
    Candle(Tensor<Candle, 2, Float>),
}

/// Backend-dynamic wrapper around a 1D `Int` tensor in burn
#[derive(Clone, Debug)]
pub enum DynamicTensorInt1D {
    NdArray(Tensor<NdArray, 1, Int>),
    Wgpu(Tensor<Wgpu, 1, Int>),
    Candle(Tensor<Candle, 1, Int>),
}

/// Backend-dynamic wrapper around a 2D `Int` tensor in burn
#[derive(Clone, Debug)]
pub enum DynamicTensorInt2D {
    NdArray(Tensor<NdArray, 2, Int>),
    Wgpu(Tensor<Wgpu, 2, Int>),
    Candle(Tensor<Candle, 2, Int>),
}

/// Constructors for wrapping a backend tensor into a `DynamicTensorFloat1D`
impl DynamicTensorFloat1D {
    pub fn from_ndarray(tensor: Tensor<NdArray, 1, Float>) -> Self {
        DynamicTensorFloat1D::NdArray(tensor)
    }
    pub fn from_wgpu(tensor: Tensor<Wgpu, 1, Float>) -> Self {
        DynamicTensorFloat1D::Wgpu(tensor)
    }
    pub fn from_candle(tensor: Tensor<Candle, 1, Float>) -> Self {
        DynamicTensorFloat1D::Candle(tensor)
    }
}

/// Constructors for wrapping a backend tensor into a `DynamicTensorFloat2D`
impl DynamicTensorFloat2D {
    pub fn from_ndarray(tensor: Tensor<NdArray, 2, Float>) -> Self {
        DynamicTensorFloat2D::NdArray(tensor)
    }
    pub fn from_wgpu(tensor: Tensor<Wgpu, 2, Float>) -> Self {
        DynamicTensorFloat2D::Wgpu(tensor)
    }
    pub fn from_candle(tensor: Tensor<Candle, 2, Float>) -> Self {
        DynamicTensorFloat2D::Candle(tensor)
    }
}

/// Constructors for wrapping a backend tensor into a `DynamicTensorInt1D`
impl DynamicTensorInt1D {
    pub fn from_ndarray(tensor: Tensor<NdArray, 1, Int>) -> Self {
        DynamicTensorInt1D::NdArray(tensor)
    }
    pub fn from_wgpu(tensor: Tensor<Wgpu, 1, Int>) -> Self {
        DynamicTensorInt1D::Wgpu(tensor)
    }
    pub fn from_candle(tensor: Tensor<Candle, 1, Int>) -> Self {
        DynamicTensorInt1D::Candle(tensor)
    }
}

/// Constructors for wrapping a backend tensor into a `DynamicTensorInt2D`
impl DynamicTensorInt2D {
    pub fn from_ndarray(tensor: Tensor<NdArray, 2, Int>) -> Self {
        DynamicTensorInt2D::NdArray(tensor)
    }
    pub fn from_wgpu(tensor: Tensor<Wgpu, 2, Int>) -> Self {
        DynamicTensorInt2D::Wgpu(tensor)
    }
    pub fn from_candle(tensor: Tensor<Candle, 2, Int>) -> Self {
        DynamicTensorInt2D::Candle(tensor)
    }
}
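
// Example (sketch): wrapping a concrete backend tensor into its dynamic counterpart.
// The matrix values are illustrative; the tensor is built through the crate's `ToBurn`
// conversion on the CPU `NdArray` backend.
#[cfg(test)]
mod construct_example {
    use super::*;

    #[test]
    fn wrap_ndarray_tensor() {
        let m = na::DMatrix::<f32>::from_row_slice(2, 3, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let tensor: Tensor<NdArray, 2, Float> = m.to_burn(&NdArrayDevice::Cpu);
        let dynamic = DynamicTensorFloat2D::from_ndarray(tensor);
        assert!(matches!(dynamic, DynamicTensorFloat2D::NdArray(_)));
    }
}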

// Conversion and Utility Operations for DynamicTensor variants

/// Trait for common `DynamicTensor` operations
pub trait DynamicTensorOps<T> {
    fn as_bytes(&self) -> Vec<u8>;

    fn nrows(&self) -> usize;
    fn shape(&self) -> (usize, usize);

    fn to_vec(&self) -> Vec<T>;
    fn min_vec(&self) -> Vec<T>;
    fn max_vec(&self) -> Vec<T>;
}
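
// Sketch (hypothetical helper, not part of the crate): the trait lets downstream code
// stay backend-agnostic, e.g. summarising any dynamic tensor through the trait alone.
#[allow(dead_code)]
fn summarize_dynamic<T: std::fmt::Debug, D: DynamicTensorOps<T>>(tensor: &D) -> String {
    let (rows, cols) = tensor.shape();
    format!("{}x{} tensor, min {:?}, max {:?}", rows, cols, tensor.min_vec(), tensor.max_vec())
}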

/// `DynamicTensorOps` for Float 1D tensors
impl DynamicTensorOps<f32> for DynamicTensorFloat1D {
    fn as_bytes(&self) -> Vec<u8> {
        match self {
            DynamicTensorFloat1D::NdArray(tensor) => bytemuck::cast_slice(&tensor_to_data_float(tensor)).to_vec(),
            DynamicTensorFloat1D::Wgpu(tensor) => bytemuck::cast_slice(&tensor_to_data_float(tensor)).to_vec(),
            DynamicTensorFloat1D::Candle(tensor) => bytemuck::cast_slice(&tensor_to_data_float(tensor)).to_vec(),
        }
    }

    fn nrows(&self) -> usize {
        match self {
            DynamicTensorFloat1D::NdArray(tensor) => tensor.dims()[0],
            DynamicTensorFloat1D::Wgpu(tensor) => tensor.dims()[0],
            DynamicTensorFloat1D::Candle(tensor) => tensor.dims()[0],
        }
    }

    fn shape(&self) -> (usize, usize) {
        match self {
            DynamicTensorFloat1D::NdArray(tensor) => (tensor.dims()[0], 1),
            DynamicTensorFloat1D::Wgpu(tensor) => (tensor.dims()[0], 1),
            DynamicTensorFloat1D::Candle(tensor) => (tensor.dims()[0], 1),
        }
    }

    fn to_vec(&self) -> Vec<f32> {
        match &self {
            DynamicTensorFloat1D::NdArray(tensor) => tensor_to_data_float(tensor),
            DynamicTensorFloat1D::Wgpu(tensor) => tensor_to_data_float(tensor),
            DynamicTensorFloat1D::Candle(tensor) => tensor_to_data_float(tensor),
        }
    }

    fn min_vec(&self) -> Vec<f32> {
        vec![self.to_vec().iter().copied().fold(f32::INFINITY, f32::min)]
    }

    fn max_vec(&self) -> Vec<f32> {
        vec![self.to_vec().iter().copied().fold(f32::NEG_INFINITY, f32::max)]
    }
}

/// `DynamicTensorOps` for Float 2D tensors
impl DynamicTensorOps<f32> for DynamicTensorFloat2D {
    fn as_bytes(&self) -> Vec<u8> {
        match self {
            DynamicTensorFloat2D::NdArray(tensor) => {
                let tensor_data = tensor_to_data_float(tensor);
                bytemuck::cast_slice(&tensor_data).to_vec()
            }
            DynamicTensorFloat2D::Wgpu(tensor) => {
                warn!("Forcing DynamicTensor with Wgpu backend to CPU");
                let tensor_data = tensor_to_data_float(tensor);
                bytemuck::cast_slice(&tensor_data).to_vec()
            }
            DynamicTensorFloat2D::Candle(tensor) => {
                let tensor_data = tensor_to_data_float(tensor);
                bytemuck::cast_slice(&tensor_data).to_vec()
            }
        }
    }

    fn nrows(&self) -> usize {
        match self {
            DynamicTensorFloat2D::NdArray(tensor) => tensor.dims()[0],
            DynamicTensorFloat2D::Wgpu(tensor) => tensor.dims()[0],
            DynamicTensorFloat2D::Candle(tensor) => tensor.dims()[0],
        }
    }

    fn shape(&self) -> (usize, usize) {
        match self {
            DynamicTensorFloat2D::NdArray(tensor) => (tensor.dims()[0], tensor.dims()[1]),
            DynamicTensorFloat2D::Wgpu(tensor) => (tensor.dims()[0], tensor.dims()[1]),
            DynamicTensorFloat2D::Candle(tensor) => (tensor.dims()[0], tensor.dims()[1]),
        }
    }

    fn to_vec(&self) -> Vec<f32> {
        match &self {
            DynamicTensorFloat2D::NdArray(tensor) => tensor_to_data_float(tensor),
            DynamicTensorFloat2D::Wgpu(tensor) => {
                warn!("Forcing DynamicTensor with Wgpu backend to CPU");
                tensor_to_data_float(tensor)
            }
            DynamicTensorFloat2D::Candle(tensor) => tensor_to_data_float(tensor),
        }
    }

    fn min_vec(&self) -> Vec<f32> {
        match &self {
            DynamicTensorFloat2D::NdArray(tensor) => {
                let min_tensor = tensor.clone().min_dim(0);
                tensor_to_data_float(&min_tensor)
            }
            DynamicTensorFloat2D::Wgpu(tensor) => {
                let min_tensor = tensor.clone().min_dim(0);
                tensor_to_data_float(&min_tensor)
            }
            DynamicTensorFloat2D::Candle(tensor) => {
                let min_tensor = tensor.clone().min_dim(0);
                tensor_to_data_float(&min_tensor)
            }
        }
    }

    fn max_vec(&self) -> Vec<f32> {
        match &self {
            DynamicTensorFloat2D::NdArray(tensor) => {
                let max_tensor = tensor.clone().max_dim(0);
                tensor_to_data_float(&max_tensor)
            }
            DynamicTensorFloat2D::Wgpu(tensor) => {
                let max_tensor = tensor.clone().max_dim(0);
                tensor_to_data_float(&max_tensor)
            }
            DynamicTensorFloat2D::Candle(tensor) => {
                let max_tensor = tensor.clone().max_dim(0);
                tensor_to_data_float(&max_tensor)
            }
        }
    }
}

/// `DynamicTensorOps` for Int 1D tensors
impl DynamicTensorOps<u32> for DynamicTensorInt1D {
    fn as_bytes(&self) -> Vec<u8> {
        match self {
            DynamicTensorInt1D::NdArray(tensor) => {
                let tensor_data = tensor_to_data_int(tensor);
                let u32_data: Vec<u32> = tensor_data
                    .into_iter()
                    .map(|x| x.try_into().expect("Negative value found during conversion to u32"))
                    .collect();
                bytemuck::cast_slice(&u32_data).to_vec()
            }
            DynamicTensorInt1D::Wgpu(tensor) => {
                let tensor_data = tensor_to_data_int(tensor);
                let u32_data: Vec<u32> = tensor_data
                    .into_iter()
                    .map(|x| x.try_into().expect("Negative value found during conversion to u32"))
                    .collect();
                bytemuck::cast_slice(&u32_data).to_vec()
            }
            DynamicTensorInt1D::Candle(tensor) => {
                let tensor_data = tensor_to_data_int(tensor);
                let u32_data: Vec<u32> = tensor_data
                    .into_iter()
                    .map(|x| x.try_into().expect("Negative value found during conversion to u32"))
                    .collect();
                bytemuck::cast_slice(&u32_data).to_vec()
            }
        }
    }

    fn nrows(&self) -> usize {
        match self {
            DynamicTensorInt1D::NdArray(tensor) => tensor.dims()[0],
            DynamicTensorInt1D::Wgpu(tensor) => tensor.dims()[0],
            DynamicTensorInt1D::Candle(tensor) => tensor.dims()[0],
        }
    }

    fn shape(&self) -> (usize, usize) {
        match self {
            DynamicTensorInt1D::NdArray(tensor) => (tensor.dims()[0], 1),
            DynamicTensorInt1D::Wgpu(tensor) => (tensor.dims()[0], 1),
            DynamicTensorInt1D::Candle(tensor) => (tensor.dims()[0], 1),
        }
    }

    fn to_vec(&self) -> Vec<u32> {
        match &self {
            DynamicTensorInt1D::NdArray(tensor) => {
                let data = tensor_to_data_int(tensor);
                data.into_iter()
                    .map(|x| x.try_into().expect("Negative value found during conversion to u32"))
                    .collect()
            }
            DynamicTensorInt1D::Wgpu(tensor) => {
                let data = tensor_to_data_int(tensor);
                data.into_iter()
                    .map(|x| x.try_into().expect("Negative value found during conversion to u32"))
                    .collect()
            }
            DynamicTensorInt1D::Candle(tensor) => {
                let data = tensor_to_data_int(tensor);
                data.into_iter()
                    .map(|x| x.try_into().expect("Negative value found during conversion to u32"))
                    .collect()
            }
        }
    }

    fn min_vec(&self) -> Vec<u32> {
        vec![self.to_vec().into_iter().min().unwrap_or(0)]
    }

    fn max_vec(&self) -> Vec<u32> {
        vec![self.to_vec().into_iter().max().unwrap_or(0)]
    }
}

/// `DynamicTensorOps` for Int 2D tensors
impl DynamicTensorOps<u32> for DynamicTensorInt2D {
    fn as_bytes(&self) -> Vec<u8> {
        match self {
            DynamicTensorInt2D::NdArray(tensor) => {
                let tensor_data = tensor_to_data_int(tensor);
                let u32_data: Vec<u32> = tensor_data
                    .into_iter()
                    .map(|x| x.try_into().expect("Negative value found during conversion to u32"))
                    .collect();
                bytemuck::cast_slice(&u32_data).to_vec()
            }
            DynamicTensorInt2D::Wgpu(tensor) => {
                let tensor_data = tensor_to_data_int(tensor);
                let u32_data: Vec<u32> = tensor_data
                    .into_iter()
                    .map(|x| x.try_into().expect("Negative value found during conversion to u32"))
                    .collect();
                bytemuck::cast_slice(&u32_data).to_vec()
            }
            DynamicTensorInt2D::Candle(tensor) => {
                let tensor_data = tensor_to_data_int(tensor);
                let u32_data: Vec<u32> = tensor_data
                    .into_iter()
                    .map(|x| x.try_into().expect("Negative value found during conversion to u32"))
                    .collect();
                bytemuck::cast_slice(&u32_data).to_vec()
            }
        }
    }

    fn nrows(&self) -> usize {
        match self {
            DynamicTensorInt2D::NdArray(tensor) => tensor.dims()[0],
            DynamicTensorInt2D::Wgpu(tensor) => tensor.dims()[0],
            DynamicTensorInt2D::Candle(tensor) => tensor.dims()[0],
        }
    }

    fn shape(&self) -> (usize, usize) {
        match self {
            DynamicTensorInt2D::NdArray(tensor) => (tensor.dims()[0], tensor.dims()[1]),
            DynamicTensorInt2D::Wgpu(tensor) => (tensor.dims()[0], tensor.dims()[1]),
            DynamicTensorInt2D::Candle(tensor) => (tensor.dims()[0], tensor.dims()[1]),
        }
    }

    fn to_vec(&self) -> Vec<u32> {
        match &self {
            DynamicTensorInt2D::NdArray(tensor) => {
                let data = tensor_to_data_int(tensor);
                data.into_iter()
                    .map(|x| x.try_into().expect("Negative value found during conversion to u32"))
                    .collect()
            }
            DynamicTensorInt2D::Wgpu(tensor) => {
                let data = tensor_to_data_int(tensor);
                data.into_iter()
                    .map(|x| x.try_into().expect("Negative value found during conversion to u32"))
                    .collect()
            }
            DynamicTensorInt2D::Candle(tensor) => {
                let data = tensor_to_data_int(tensor);
                data.into_iter()
                    .map(|x| x.try_into().expect("Negative value found during conversion to u32"))
                    .collect()
            }
        }
    }

    fn min_vec(&self) -> Vec<u32> {
        match &self {
            DynamicTensorInt2D::NdArray(tensor) => {
                let min_tensor = tensor.clone().min_dim(0);
                tensor_to_data_int(&min_tensor)
                    .into_iter()
                    .map(|x| x.try_into().expect("Negative value found during conversion to u32"))
                    .collect()
            }
            DynamicTensorInt2D::Wgpu(tensor) => {
                let min_tensor = tensor.clone().min_dim(0);
                tensor_to_data_int(&min_tensor)
                    .into_iter()
                    .map(|x| x.try_into().expect("Negative value found during conversion to u32"))
                    .collect()
            }
            DynamicTensorInt2D::Candle(tensor) => {
                let min_tensor = tensor.clone().min_dim(0);
                tensor_to_data_int(&min_tensor)
                    .into_iter()
                    .map(|x| x.try_into().expect("Negative value found during conversion to u32"))
                    .collect()
            }
        }
    }

    fn max_vec(&self) -> Vec<u32> {
        match &self {
            DynamicTensorInt2D::NdArray(tensor) => {
                let max_tensor = tensor.clone().max_dim(0);
                tensor_to_data_int(&max_tensor)
                    .into_iter()
                    .map(|x| x.try_into().expect("Negative value found during conversion to u32"))
                    .collect()
            }
            DynamicTensorInt2D::Wgpu(tensor) => {
                let max_tensor = tensor.clone().max_dim(0);
                tensor_to_data_int(&max_tensor)
                    .into_iter()
                    .map(|x| x.try_into().expect("Negative value found during conversion to u32"))
                    .collect()
            }
            DynamicTensorInt2D::Candle(tensor) => {
                let max_tensor = tensor.clone().max_dim(0);
                tensor_to_data_int(&max_tensor)
                    .into_iter()
                    .map(|x| x.try_into().expect("Negative value found during conversion to u32"))
                    .collect()
            }
        }
    }
}
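
// Example (sketch): exercising `DynamicTensorOps` on a 2D float tensor built on the
// CPU `NdArray` backend. This assumes the `ToBurn` conversion preserves the matrix
// layout, so `min_vec`/`max_vec` reduce along dim 0 and yield one value per column.
#[cfg(test)]
mod dynamic_ops_example {
    use super::*;

    #[test]
    fn float_2d_ops() {
        let m = na::DMatrix::<f32>::from_row_slice(3, 2, &[1.0, 10.0, -2.0, 20.0, 3.0, 30.0]);
        let tensor: Tensor<NdArray, 2, Float> = m.to_burn(&NdArrayDevice::Cpu);
        let dynamic = DynamicTensorFloat2D::from_ndarray(tensor);

        assert_eq!(dynamic.nrows(), 3);
        assert_eq!(dynamic.shape(), (3, 2));
        // One minimum and one maximum per column.
        assert_eq!(dynamic.min_vec(), vec![-2.0, 10.0]);
        assert_eq!(dynamic.max_vec(), vec![3.0, 30.0]);
        // Four bytes per f32 element.
        assert_eq!(dynamic.as_bytes().len(), 3 * 2 * 4);
    }
}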

/// Trait for conversion to and from nalgebra matrices
pub trait DynamicMatrixOps<T> {
    fn from_dmatrix(matrix: &na::DMatrix<T>) -> Self;
    fn to_dmatrix(&self) -> na::DMatrix<T>;
    fn into_dmatrix(self) -> na::DMatrix<T>;
}

/// `DynamicMatrixOps` for `DynamicTensorFloat2D`
impl DynamicMatrixOps<f32> for DynamicTensorFloat2D {
    fn from_dmatrix(matrix: &na::DMatrix<f32>) -> Self {
        match std::any::TypeId::of::<DefaultBackend>() {
            id if id == std::any::TypeId::of::<NdArray>() => {
                let tensor = matrix.to_burn(&NdArrayDevice::Cpu);
                DynamicTensorFloat2D::NdArray(tensor)
            }
            id if id == std::any::TypeId::of::<Candle>() => {
                let tensor = matrix.to_burn(&CandleDevice::Cpu);
                DynamicTensorFloat2D::Candle(tensor)
            }
            id if id == std::any::TypeId::of::<Wgpu>() => {
                let tensor = matrix.to_burn(&WgpuDevice::BestAvailable);
                DynamicTensorFloat2D::Wgpu(tensor)
            }
            _ => panic!("Unsupported backend!"),
        }
    }

    fn to_dmatrix(&self) -> na::DMatrix<f32> {
        match self {
            DynamicTensorFloat2D::NdArray(tensor) => tensor.to_nalgebra(),
            DynamicTensorFloat2D::Wgpu(tensor) => tensor.to_nalgebra(),
            DynamicTensorFloat2D::Candle(tensor) => tensor.to_nalgebra(),
        }
    }

    fn into_dmatrix(self) -> na::DMatrix<f32> {
        match self {
            DynamicTensorFloat2D::NdArray(tensor) => tensor.into_nalgebra(),
            DynamicTensorFloat2D::Wgpu(tensor) => tensor.into_nalgebra(),
            DynamicTensorFloat2D::Candle(tensor) => tensor.into_nalgebra(),
        }
    }
}

/// `DynamicMatrixOps` for `DynamicTensorInt2D`
impl DynamicMatrixOps<u32> for DynamicTensorInt2D {
    fn from_dmatrix(matrix: &na::DMatrix<u32>) -> Self {
        match std::any::TypeId::of::<DefaultBackend>() {
            id if id == std::any::TypeId::of::<NdArray>() => {
                let tensor = matrix.to_burn(&NdArrayDevice::Cpu);
                DynamicTensorInt2D::NdArray(tensor)
            }
            id if id == std::any::TypeId::of::<Candle>() => {
                let tensor = matrix.to_burn(&CandleDevice::Cpu);
                DynamicTensorInt2D::Candle(tensor)
            }
            id if id == std::any::TypeId::of::<Wgpu>() => {
                let tensor = matrix.to_burn(&WgpuDevice::BestAvailable);
                DynamicTensorInt2D::Wgpu(tensor)
            }
            _ => panic!("Unsupported backend!"),
        }
    }

    fn to_dmatrix(&self) -> na::DMatrix<u32> {
        match self {
            DynamicTensorInt2D::NdArray(tensor) => tensor.to_nalgebra(),
            DynamicTensorInt2D::Wgpu(tensor) => tensor.to_nalgebra(),
            DynamicTensorInt2D::Candle(tensor) => tensor.to_nalgebra(),
        }
    }

    fn into_dmatrix(self) -> na::DMatrix<u32> {
        match self {
            DynamicTensorInt2D::NdArray(tensor) => tensor.into_nalgebra(),
            DynamicTensorInt2D::Wgpu(tensor) => tensor.into_nalgebra(),
            DynamicTensorInt2D::Candle(tensor) => tensor.into_nalgebra(),
        }
    }
}
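
// Example (sketch): round-tripping an nalgebra matrix through the dynamic wrapper.
// With `DefaultBackend = NdArray`, `from_dmatrix` yields the `NdArray` variant, and
// `to_dmatrix` should reproduce the original matrix exactly.
#[cfg(test)]
mod matrix_ops_example {
    use super::*;

    #[test]
    fn dmatrix_round_trip() {
        let m = na::DMatrix::<f32>::from_row_slice(2, 2, &[1.0, 2.0, 3.0, 4.0]);
        let dynamic = DynamicTensorFloat2D::from_dmatrix(&m);
        assert!(matches!(dynamic, DynamicTensorFloat2D::NdArray(_)));
        assert_eq!(dynamic.to_dmatrix(), m);
    }
}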

// /////////////////////////////////////////////////////////////////////////////
// Some burn utilities
// /////////////////////////////////////////////////////////////////////////////
/// Normalise each row of a 2D tensor to unit L2 norm (computed along dim 1)
pub fn normalize_tensor<B: Backend>(tensor: Tensor<B, 2, Float>) -> Tensor<B, 2, Float> {
    let norm = tensor.clone().powf_scalar(2.0).sum_dim(1).sqrt(); // Compute the L2 norm along the last axis (dim = 1)
    tensor.div(norm) // Divide each row by its norm (broadcast over dim 1)
}
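
// Example (sketch): after normalisation every row has unit L2 norm. The row (3, 0, 4)
// has norm 5, so it should become (0.6, 0.0, 0.8). This assumes the flattened data
// returned by `tensor_to_data_float` is row-major.
#[cfg(test)]
mod normalize_example {
    use super::*;

    #[test]
    fn rows_have_unit_norm() {
        let m = na::DMatrix::<f32>::from_row_slice(2, 3, &[3.0, 0.0, 4.0, 0.0, 5.0, 0.0]);
        let tensor: Tensor<NdArray, 2, Float> = m.to_burn(&NdArrayDevice::Cpu);
        let normalized = normalize_tensor(tensor);
        let data = tensor_to_data_float(&normalized);
        assert!((data[0] - 0.6).abs() < 1e-6);
        assert!((data[1]).abs() < 1e-6);
        assert!((data[2] - 0.8).abs() < 1e-6);
    }
}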

/// Row-wise cross product of two `[N, 3]` tensors
pub fn cross_product<B: Backend>(
    a: &Tensor<B, 2, Float>, // Tensor of shape [N, 3]
    b: &Tensor<B, 2, Float>, // Tensor of shape [N, 3]
) -> Tensor<B, 2, Float> {
    // Split the input tensors along dimension 1 (the 3 components) using chunk
    let a_chunks = a.clone().chunk(3, 1); // Split tensor `a` into 3 chunks: ax, ay, az
    let b_chunks = b.clone().chunk(3, 1); // Split tensor `b` into 3 chunks: bx, by, bz

    let ax: Tensor<B, 1> = a_chunks[0].clone().squeeze(1); // x component of a
    let ay: Tensor<B, 1> = a_chunks[1].clone().squeeze(1); // y component of a
    let az: Tensor<B, 1> = a_chunks[2].clone().squeeze(1); // z component of a

    let bx: Tensor<B, 1> = b_chunks[0].clone().squeeze(1); // x component of b
    let by: Tensor<B, 1> = b_chunks[1].clone().squeeze(1); // y component of b
    let bz: Tensor<B, 1> = b_chunks[2].clone().squeeze(1); // z component of b

    // Compute the components of the cross product
    let cx = ay.clone().mul(bz.clone()).sub(az.clone().mul(by.clone())); // cx = ay * bz - az * by
    let cy = az.mul(bx.clone()).sub(ax.clone().mul(bz)); // cy = az * bx - ax * bz
    let cz = ax.mul(by).sub(ay.mul(bx)); // cz = ax * by - ay * bx

    // Stack the components along dim 1 to form the resulting [N, 3] tensor
    Tensor::stack(vec![cx, cy, cz], 1)
}
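
// Example (sketch): for unit vectors, cross(x, y) = z. A single row is enough to check
// the component wiring; the result is read back as a row-major flat vector.
#[cfg(test)]
mod cross_product_example {
    use super::*;

    #[test]
    fn x_cross_y_is_z() {
        let a = na::DMatrix::<f32>::from_row_slice(1, 3, &[1.0, 0.0, 0.0]);
        let b = na::DMatrix::<f32>::from_row_slice(1, 3, &[0.0, 1.0, 0.0]);
        let a_t: Tensor<NdArray, 2, Float> = a.to_burn(&NdArrayDevice::Cpu);
        let b_t: Tensor<NdArray, 2, Float> = b.to_burn(&NdArrayDevice::Cpu);

        let c = cross_product(&a_t, &b_t);
        let data = tensor_to_data_float(&c);
        assert!((data[0]).abs() < 1e-6);
        assert!((data[1]).abs() < 1e-6);
        assert!((data[2] - 1.0).abs() < 1e-6);
    }
}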