use burn::{
    backend::{candle::CandleDevice, ndarray::NdArrayDevice, wgpu::WgpuDevice, Candle, NdArray, Wgpu},
    prelude::Backend,
    tensor::{Float, Int, Tensor},
};
use core::panic;
use ndarray as nd;
use crate::bshare::{tensor_to_data_float, tensor_to_data_int, ToBurn, ToNalgebraFloat, ToNalgebraInt, ToNdArray};
extern crate nalgebra as na;
use bytemuck;
use log::warn;

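/// Backend used by the `from_ndarray` / `from_dmatrix` constructors below to pick an enum variant.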
pub type DefaultBackend = NdArray;

#[derive(Clone, Debug)]
pub enum BurnBackend {
    Candle,
    NdArray,
    Wgpu,
}

#[derive(Clone, Debug)]
pub enum DynamicTensorFloat1D {
    NdArray(Tensor<NdArray, 1, Float>),
    Wgpu(Tensor<Wgpu, 1, Float>),
    Candle(Tensor<Candle, 1, Float>),
}

#[derive(Clone, Debug)]
pub enum DynamicTensorFloat2D {
    NdArray(Tensor<NdArray, 2, Float>),
    Wgpu(Tensor<Wgpu, 2, Float>),
    Candle(Tensor<Candle, 2, Float>),
}

#[derive(Clone, Debug)]
pub enum DynamicTensorFloat3D {
    NdArray(Tensor<NdArray, 3, Float>),
    Wgpu(Tensor<Wgpu, 3, Float>),
    Candle(Tensor<Candle, 3, Float>),
}

#[derive(Clone, Debug)]
pub enum DynamicTensorInt1D {
    NdArray(Tensor<NdArray, 1, Int>),
    Wgpu(Tensor<Wgpu, 1, Int>),
    Candle(Tensor<Candle, 1, Int>),
}

#[derive(Clone, Debug)]
pub enum DynamicTensorInt2D {
    NdArray(Tensor<NdArray, 2, Int>),
    Wgpu(Tensor<Wgpu, 2, Int>),
    Candle(Tensor<Candle, 2, Int>),
}

#[derive(Clone, Debug)]
pub enum DynamicTensorInt3D {
    NdArray(Tensor<NdArray, 3, Int>),
    Wgpu(Tensor<Wgpu, 3, Int>),
    Candle(Tensor<Candle, 3, Int>),
}

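// Thin constructors that wrap an existing backend-specific tensor in the matching enum variant.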
impl DynamicTensorFloat1D {
    pub fn from_ndarray_backend(tensor: Tensor<NdArray, 1, Float>) -> Self {
        DynamicTensorFloat1D::NdArray(tensor)
    }
    pub fn from_wgpu_backend(tensor: Tensor<Wgpu, 1, Float>) -> Self {
        DynamicTensorFloat1D::Wgpu(tensor)
    }
    pub fn from_candle_backend(tensor: Tensor<Candle, 1, Float>) -> Self {
        DynamicTensorFloat1D::Candle(tensor)
    }
}

impl DynamicTensorFloat2D {
    pub fn from_ndarray_backend(tensor: Tensor<NdArray, 2, Float>) -> Self {
        DynamicTensorFloat2D::NdArray(tensor)
    }
    pub fn from_wgpu_backend(tensor: Tensor<Wgpu, 2, Float>) -> Self {
        DynamicTensorFloat2D::Wgpu(tensor)
    }
    pub fn from_candle_backend(tensor: Tensor<Candle, 2, Float>) -> Self {
        DynamicTensorFloat2D::Candle(tensor)
    }
}

impl DynamicTensorFloat3D {
    pub fn from_ndarray_backend(tensor: Tensor<NdArray, 3, Float>) -> Self {
        DynamicTensorFloat3D::NdArray(tensor)
    }
    pub fn from_wgpu_backend(tensor: Tensor<Wgpu, 3, Float>) -> Self {
        DynamicTensorFloat3D::Wgpu(tensor)
    }
    pub fn from_candle_backend(tensor: Tensor<Candle, 3, Float>) -> Self {
        DynamicTensorFloat3D::Candle(tensor)
    }
}

impl DynamicTensorInt1D {
    pub fn from_ndarray_backend(tensor: Tensor<NdArray, 1, Int>) -> Self {
        DynamicTensorInt1D::NdArray(tensor)
    }
    pub fn from_wgpu_backend(tensor: Tensor<Wgpu, 1, Int>) -> Self {
        DynamicTensorInt1D::Wgpu(tensor)
    }
    pub fn from_candle_backend(tensor: Tensor<Candle, 1, Int>) -> Self {
        DynamicTensorInt1D::Candle(tensor)
    }
}

impl DynamicTensorInt2D {
    pub fn from_ndarray_backend(tensor: Tensor<NdArray, 2, Int>) -> Self {
        DynamicTensorInt2D::NdArray(tensor)
    }
    pub fn from_wgpu_backend(tensor: Tensor<Wgpu, 2, Int>) -> Self {
        DynamicTensorInt2D::Wgpu(tensor)
    }
    pub fn from_candle_backend(tensor: Tensor<Candle, 2, Int>) -> Self {
        DynamicTensorInt2D::Candle(tensor)
    }
}

impl DynamicTensorInt3D {
    pub fn from_ndarray_backend(tensor: Tensor<NdArray, 3, Int>) -> Self {
        DynamicTensorInt3D::NdArray(tensor)
    }
    pub fn from_wgpu_backend(tensor: Tensor<Wgpu, 3, Int>) -> Self {
        DynamicTensorInt3D::Wgpu(tensor)
    }
    pub fn from_candle_backend(tensor: Tensor<Candle, 3, Int>) -> Self {
        DynamicTensorInt3D::Candle(tensor)
    }
}

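/// Backend-agnostic operations shared by all dynamic tensor wrappers.
///
/// `as_bytes` and `to_vec` copy the tensor back to host memory (integer tensors are
/// converted to `u32` and panic on negative values). `min_vec` / `max_vec` reduce along
/// the first dimension for 2D and 3D tensors and return a single-element vector for 1D tensors.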
pub trait DynamicTensorOps<T> {
    fn as_bytes(&self) -> Vec<u8>;

    fn nrows(&self) -> usize;
    fn shape(&self) -> Vec<usize>;

    fn to_vec(&self) -> Vec<T>;
    fn min_vec(&self) -> Vec<T>;
    fn max_vec(&self) -> Vec<T>;
}

impl DynamicTensorOps<f32> for DynamicTensorFloat1D {
    fn as_bytes(&self) -> Vec<u8> {
        match self {
            DynamicTensorFloat1D::NdArray(tensor) => bytemuck::cast_slice(&tensor_to_data_float(tensor)).to_vec(),
            DynamicTensorFloat1D::Wgpu(tensor) => bytemuck::cast_slice(&tensor_to_data_float(tensor)).to_vec(),
            DynamicTensorFloat1D::Candle(tensor) => bytemuck::cast_slice(&tensor_to_data_float(tensor)).to_vec(),
        }
    }

    fn nrows(&self) -> usize {
        match self {
            DynamicTensorFloat1D::NdArray(tensor) => tensor.dims()[0],
            DynamicTensorFloat1D::Wgpu(tensor) => tensor.dims()[0],
            DynamicTensorFloat1D::Candle(tensor) => tensor.dims()[0],
        }
    }

    fn shape(&self) -> Vec<usize> {
        match self {
            DynamicTensorFloat1D::NdArray(tensor) => vec![tensor.dims()[0]],
            DynamicTensorFloat1D::Wgpu(tensor) => vec![tensor.dims()[0]],
            DynamicTensorFloat1D::Candle(tensor) => vec![tensor.dims()[0]],
        }
    }

    fn to_vec(&self) -> Vec<f32> {
        match &self {
            DynamicTensorFloat1D::NdArray(tensor) => tensor_to_data_float(tensor),
            DynamicTensorFloat1D::Wgpu(tensor) => tensor_to_data_float(tensor),
            DynamicTensorFloat1D::Candle(tensor) => tensor_to_data_float(tensor),
        }
    }

    fn min_vec(&self) -> Vec<f32> {
        vec![self.to_vec().iter().copied().fold(f32::INFINITY, f32::min)]
    }

    fn max_vec(&self) -> Vec<f32> {
        vec![self.to_vec().iter().copied().fold(f32::NEG_INFINITY, f32::max)]
    }
}

impl DynamicTensorOps<f32> for DynamicTensorFloat2D {
    fn as_bytes(&self) -> Vec<u8> {
        match self {
            DynamicTensorFloat2D::NdArray(tensor) => {
                let tensor_data = tensor_to_data_float(tensor);
                bytemuck::cast_slice(&tensor_data).to_vec()
            }
            DynamicTensorFloat2D::Wgpu(tensor) => {
                warn!("Forcing DynamicTensor with Wgpu backend to CPU");
                let tensor_data = tensor_to_data_float(tensor);
                bytemuck::cast_slice(&tensor_data).to_vec()
            }
            DynamicTensorFloat2D::Candle(tensor) => {
                let tensor_data = tensor_to_data_float(tensor);
                bytemuck::cast_slice(&tensor_data).to_vec()
            }
        }
    }

    fn nrows(&self) -> usize {
        match self {
            DynamicTensorFloat2D::NdArray(tensor) => tensor.dims()[0],
            DynamicTensorFloat2D::Wgpu(tensor) => tensor.dims()[0],
            DynamicTensorFloat2D::Candle(tensor) => tensor.dims()[0],
        }
    }

    fn shape(&self) -> Vec<usize> {
        match self {
            DynamicTensorFloat2D::NdArray(tensor) => vec![tensor.dims()[0], tensor.dims()[1]],
            DynamicTensorFloat2D::Wgpu(tensor) => vec![tensor.dims()[0], tensor.dims()[1]],
            DynamicTensorFloat2D::Candle(tensor) => vec![tensor.dims()[0], tensor.dims()[1]],
        }
    }

    fn to_vec(&self) -> Vec<f32> {
        match &self {
            DynamicTensorFloat2D::NdArray(tensor) => tensor_to_data_float(tensor),
            DynamicTensorFloat2D::Wgpu(tensor) => {
                warn!("Forcing DynamicTensor with Wgpu backend to CPU");
                tensor_to_data_float(tensor)
            }
            DynamicTensorFloat2D::Candle(tensor) => tensor_to_data_float(tensor),
        }
    }

    fn min_vec(&self) -> Vec<f32> {
        match &self {
            DynamicTensorFloat2D::NdArray(tensor) => {
                let min_tensor = tensor.clone().min_dim(0);
                tensor_to_data_float(&min_tensor)
            }
            DynamicTensorFloat2D::Wgpu(tensor) => {
                let min_tensor = tensor.clone().min_dim(0);
                tensor_to_data_float(&min_tensor)
            }
            DynamicTensorFloat2D::Candle(tensor) => {
                let min_tensor = tensor.clone().min_dim(0);
                tensor_to_data_float(&min_tensor)
            }
        }
    }

    fn max_vec(&self) -> Vec<f32> {
        match &self {
            DynamicTensorFloat2D::NdArray(tensor) => {
                let max_tensor = tensor.clone().max_dim(0);
                tensor_to_data_float(&max_tensor)
            }
            DynamicTensorFloat2D::Wgpu(tensor) => {
                let max_tensor = tensor.clone().max_dim(0);
                tensor_to_data_float(&max_tensor)
            }
            DynamicTensorFloat2D::Candle(tensor) => {
                let max_tensor = tensor.clone().max_dim(0);
                tensor_to_data_float(&max_tensor)
            }
        }
    }
}

impl DynamicTensorOps<f32> for DynamicTensorFloat3D {
    fn as_bytes(&self) -> Vec<u8> {
        match self {
            DynamicTensorFloat3D::NdArray(tensor) => {
                let tensor_data = tensor_to_data_float(tensor);
                bytemuck::cast_slice(&tensor_data).to_vec()
            }
            DynamicTensorFloat3D::Wgpu(tensor) => {
                warn!("Forcing DynamicTensor with Wgpu backend to CPU");
                let tensor_data = tensor_to_data_float(tensor);
                bytemuck::cast_slice(&tensor_data).to_vec()
            }
            DynamicTensorFloat3D::Candle(tensor) => {
                let tensor_data = tensor_to_data_float(tensor);
                bytemuck::cast_slice(&tensor_data).to_vec()
            }
        }
    }

    fn nrows(&self) -> usize {
        match self {
            DynamicTensorFloat3D::NdArray(tensor) => tensor.dims()[0],
            DynamicTensorFloat3D::Wgpu(tensor) => tensor.dims()[0],
            DynamicTensorFloat3D::Candle(tensor) => tensor.dims()[0],
        }
    }

    fn shape(&self) -> Vec<usize> {
        match self {
            DynamicTensorFloat3D::NdArray(tensor) => vec![tensor.dims()[0], tensor.dims()[1], tensor.dims()[2]],
            DynamicTensorFloat3D::Wgpu(tensor) => vec![tensor.dims()[0], tensor.dims()[1], tensor.dims()[2]],
            DynamicTensorFloat3D::Candle(tensor) => vec![tensor.dims()[0], tensor.dims()[1], tensor.dims()[2]],
        }
    }

    fn to_vec(&self) -> Vec<f32> {
        match &self {
            DynamicTensorFloat3D::NdArray(tensor) => tensor_to_data_float(tensor),
            DynamicTensorFloat3D::Wgpu(tensor) => {
                warn!("Forcing DynamicTensor with Wgpu backend to CPU");
                tensor_to_data_float(tensor)
            }
            DynamicTensorFloat3D::Candle(tensor) => tensor_to_data_float(tensor),
        }
    }

    fn min_vec(&self) -> Vec<f32> {
        match &self {
            DynamicTensorFloat3D::NdArray(tensor) => {
                let min_tensor = tensor.clone().min_dim(0);
                tensor_to_data_float(&min_tensor)
            }
            DynamicTensorFloat3D::Wgpu(tensor) => {
                let min_tensor = tensor.clone().min_dim(0);
                tensor_to_data_float(&min_tensor)
            }
            DynamicTensorFloat3D::Candle(tensor) => {
                let min_tensor = tensor.clone().min_dim(0);
                tensor_to_data_float(&min_tensor)
            }
        }
    }

    fn max_vec(&self) -> Vec<f32> {
        match &self {
            DynamicTensorFloat3D::NdArray(tensor) => {
                let max_tensor = tensor.clone().max_dim(0);
                tensor_to_data_float(&max_tensor)
            }
            DynamicTensorFloat3D::Wgpu(tensor) => {
                let max_tensor = tensor.clone().max_dim(0);
                tensor_to_data_float(&max_tensor)
            }
            DynamicTensorFloat3D::Candle(tensor) => {
                let max_tensor = tensor.clone().max_dim(0);
                tensor_to_data_float(&max_tensor)
            }
        }
    }
}

impl DynamicTensorOps<u32> for DynamicTensorInt1D {
    fn as_bytes(&self) -> Vec<u8> {
        match self {
            DynamicTensorInt1D::NdArray(tensor) => {
                let tensor_data = tensor_to_data_int(tensor);
                let u32_data: Vec<u32> = tensor_data
                    .into_iter()
                    .map(|x| x.try_into().expect("Negative value found during conversion to u32"))
                    .collect();
                bytemuck::cast_slice(&u32_data).to_vec()
            }
            DynamicTensorInt1D::Wgpu(tensor) => {
                let tensor_data = tensor_to_data_int(tensor);
                let u32_data: Vec<u32> = tensor_data
                    .into_iter()
                    .map(|x| x.try_into().expect("Negative value found during conversion to u32"))
                    .collect();
                bytemuck::cast_slice(&u32_data).to_vec()
            }
            DynamicTensorInt1D::Candle(tensor) => {
                let tensor_data = tensor_to_data_int(tensor);
                let u32_data: Vec<u32> = tensor_data
                    .into_iter()
                    .map(|x| x.try_into().expect("Negative value found during conversion to u32"))
                    .collect();
                bytemuck::cast_slice(&u32_data).to_vec()
            }
        }
    }

    fn nrows(&self) -> usize {
        match self {
            DynamicTensorInt1D::NdArray(tensor) => tensor.dims()[0],
            DynamicTensorInt1D::Wgpu(tensor) => tensor.dims()[0],
            DynamicTensorInt1D::Candle(tensor) => tensor.dims()[0],
        }
    }

    fn shape(&self) -> Vec<usize> {
        match self {
            DynamicTensorInt1D::NdArray(tensor) => vec![tensor.dims()[0]],
            DynamicTensorInt1D::Wgpu(tensor) => vec![tensor.dims()[0]],
            DynamicTensorInt1D::Candle(tensor) => vec![tensor.dims()[0]],
        }
    }

    fn to_vec(&self) -> Vec<u32> {
        match &self {
            DynamicTensorInt1D::NdArray(tensor) => {
                let data = tensor_to_data_int(tensor);
                data.into_iter()
                    .map(|x| x.try_into().expect("Negative value found during conversion to u32"))
                    .collect()
            }
            DynamicTensorInt1D::Wgpu(tensor) => {
                let data = tensor_to_data_int(tensor);
                data.into_iter()
                    .map(|x| x.try_into().expect("Negative value found during conversion to u32"))
                    .collect()
            }
            DynamicTensorInt1D::Candle(tensor) => {
                let data = tensor_to_data_int(tensor);
                data.into_iter()
                    .map(|x| x.try_into().expect("Negative value found during conversion to u32"))
                    .collect()
            }
        }
    }

    fn min_vec(&self) -> Vec<u32> {
        vec![self.to_vec().into_iter().min().unwrap_or(0)]
    }

    fn max_vec(&self) -> Vec<u32> {
        vec![self.to_vec().into_iter().max().unwrap_or(0)]
    }
}

impl DynamicTensorOps<u32> for DynamicTensorInt2D {
    fn as_bytes(&self) -> Vec<u8> {
        match self {
            DynamicTensorInt2D::NdArray(tensor) => {
                let tensor_data = tensor_to_data_int(tensor);
                let u32_data: Vec<u32> = tensor_data
                    .into_iter()
                    .map(|x| x.try_into().expect("Negative value found during conversion to u32"))
                    .collect();
                bytemuck::cast_slice(&u32_data).to_vec()
            }
            DynamicTensorInt2D::Wgpu(tensor) => {
                let tensor_data = tensor_to_data_int(tensor);
                let u32_data: Vec<u32> = tensor_data
                    .into_iter()
                    .map(|x| x.try_into().expect("Negative value found during conversion to u32"))
                    .collect();
                bytemuck::cast_slice(&u32_data).to_vec()
            }
            DynamicTensorInt2D::Candle(tensor) => {
                let tensor_data = tensor_to_data_int(tensor);
                let u32_data: Vec<u32> = tensor_data
                    .into_iter()
                    .map(|x| x.try_into().expect("Negative value found during conversion to u32"))
                    .collect();
                bytemuck::cast_slice(&u32_data).to_vec()
            }
        }
    }

    fn nrows(&self) -> usize {
        match self {
            DynamicTensorInt2D::NdArray(tensor) => tensor.dims()[0],
            DynamicTensorInt2D::Wgpu(tensor) => tensor.dims()[0],
            DynamicTensorInt2D::Candle(tensor) => tensor.dims()[0],
        }
    }

    fn shape(&self) -> Vec<usize> {
        match self {
            DynamicTensorInt2D::NdArray(tensor) => vec![tensor.dims()[0], tensor.dims()[1]],
            DynamicTensorInt2D::Wgpu(tensor) => vec![tensor.dims()[0], tensor.dims()[1]],
            DynamicTensorInt2D::Candle(tensor) => vec![tensor.dims()[0], tensor.dims()[1]],
        }
    }

    fn to_vec(&self) -> Vec<u32> {
        match &self {
            DynamicTensorInt2D::NdArray(tensor) => {
                let data = tensor_to_data_int(tensor);
                data.into_iter()
                    .map(|x| x.try_into().expect("Negative value found during conversion to u32"))
                    .collect()
            }
            DynamicTensorInt2D::Wgpu(tensor) => {
                let data = tensor_to_data_int(tensor);
                data.into_iter()
                    .map(|x| x.try_into().expect("Negative value found during conversion to u32"))
                    .collect()
            }
            DynamicTensorInt2D::Candle(tensor) => {
                let data = tensor_to_data_int(tensor);
                data.into_iter()
                    .map(|x| x.try_into().expect("Negative value found during conversion to u32"))
                    .collect()
            }
        }
    }

    fn min_vec(&self) -> Vec<u32> {
        match &self {
            DynamicTensorInt2D::NdArray(tensor) => {
                let min_tensor = tensor.clone().min_dim(0);
                tensor_to_data_int(&min_tensor)
                    .into_iter()
                    .map(|x| x.try_into().expect("Negative value found during conversion to u32"))
                    .collect()
            }
            DynamicTensorInt2D::Wgpu(tensor) => {
                let min_tensor = tensor.clone().min_dim(0);
                tensor_to_data_int(&min_tensor)
                    .into_iter()
                    .map(|x| x.try_into().expect("Negative value found during conversion to u32"))
                    .collect()
            }
            DynamicTensorInt2D::Candle(tensor) => {
                let min_tensor = tensor.clone().min_dim(0);
                tensor_to_data_int(&min_tensor)
                    .into_iter()
                    .map(|x| x.try_into().expect("Negative value found during conversion to u32"))
                    .collect()
            }
        }
    }

    fn max_vec(&self) -> Vec<u32> {
        match &self {
            DynamicTensorInt2D::NdArray(tensor) => {
                let max_tensor = tensor.clone().max_dim(0);
                tensor_to_data_int(&max_tensor)
                    .into_iter()
                    .map(|x| x.try_into().expect("Negative value found during conversion to u32"))
                    .collect()
            }
            DynamicTensorInt2D::Wgpu(tensor) => {
                let max_tensor = tensor.clone().max_dim(0);
                tensor_to_data_int(&max_tensor)
                    .into_iter()
                    .map(|x| x.try_into().expect("Negative value found during conversion to u32"))
                    .collect()
            }
            DynamicTensorInt2D::Candle(tensor) => {
                let max_tensor = tensor.clone().max_dim(0);
                tensor_to_data_int(&max_tensor)
                    .into_iter()
                    .map(|x| x.try_into().expect("Negative value found during conversion to u32"))
                    .collect()
            }
        }
    }
}

impl DynamicTensorOps<u32> for DynamicTensorInt3D {
    fn as_bytes(&self) -> Vec<u8> {
        match self {
            DynamicTensorInt3D::NdArray(tensor) => {
                let tensor_data = tensor_to_data_int(tensor);
                let u32_data: Vec<u32> = tensor_data
                    .into_iter()
                    .map(|x| x.try_into().expect("Negative value found during conversion to u32"))
                    .collect();
                bytemuck::cast_slice(&u32_data).to_vec()
            }
            DynamicTensorInt3D::Wgpu(tensor) => {
                let tensor_data = tensor_to_data_int(tensor);
                let u32_data: Vec<u32> = tensor_data
                    .into_iter()
                    .map(|x| x.try_into().expect("Negative value found during conversion to u32"))
                    .collect();
                bytemuck::cast_slice(&u32_data).to_vec()
            }
            DynamicTensorInt3D::Candle(tensor) => {
                let tensor_data = tensor_to_data_int(tensor);
                let u32_data: Vec<u32> = tensor_data
                    .into_iter()
                    .map(|x| x.try_into().expect("Negative value found during conversion to u32"))
                    .collect();
                bytemuck::cast_slice(&u32_data).to_vec()
            }
        }
    }

    fn nrows(&self) -> usize {
        match self {
            DynamicTensorInt3D::NdArray(tensor) => tensor.dims()[0],
            DynamicTensorInt3D::Wgpu(tensor) => tensor.dims()[0],
            DynamicTensorInt3D::Candle(tensor) => tensor.dims()[0],
        }
    }

    fn shape(&self) -> Vec<usize> {
        match self {
            DynamicTensorInt3D::NdArray(tensor) => vec![tensor.dims()[0], tensor.dims()[1], tensor.dims()[2]],
            DynamicTensorInt3D::Wgpu(tensor) => vec![tensor.dims()[0], tensor.dims()[1], tensor.dims()[2]],
            DynamicTensorInt3D::Candle(tensor) => vec![tensor.dims()[0], tensor.dims()[1], tensor.dims()[2]],
        }
    }

    fn to_vec(&self) -> Vec<u32> {
        match &self {
            DynamicTensorInt3D::NdArray(tensor) => {
                let data = tensor_to_data_int(tensor);
                data.into_iter()
                    .map(|x| x.try_into().expect("Negative value found during conversion to u32"))
                    .collect()
            }
            DynamicTensorInt3D::Wgpu(tensor) => {
                let data = tensor_to_data_int(tensor);
                data.into_iter()
                    .map(|x| x.try_into().expect("Negative value found during conversion to u32"))
                    .collect()
            }
            DynamicTensorInt3D::Candle(tensor) => {
                let data = tensor_to_data_int(tensor);
                data.into_iter()
                    .map(|x| x.try_into().expect("Negative value found during conversion to u32"))
                    .collect()
            }
        }
    }

    fn min_vec(&self) -> Vec<u32> {
        match &self {
            DynamicTensorInt3D::NdArray(tensor) => {
                let min_tensor = tensor.clone().min_dim(0);
                tensor_to_data_int(&min_tensor)
                    .into_iter()
                    .map(|x| x.try_into().expect("Negative value found during conversion to u32"))
                    .collect()
            }
            DynamicTensorInt3D::Wgpu(tensor) => {
                let min_tensor = tensor.clone().min_dim(0);
                tensor_to_data_int(&min_tensor)
                    .into_iter()
                    .map(|x| x.try_into().expect("Negative value found during conversion to u32"))
                    .collect()
            }
            DynamicTensorInt3D::Candle(tensor) => {
                let min_tensor = tensor.clone().min_dim(0);
                tensor_to_data_int(&min_tensor)
                    .into_iter()
                    .map(|x| x.try_into().expect("Negative value found during conversion to u32"))
                    .collect()
            }
        }
    }

    fn max_vec(&self) -> Vec<u32> {
        match &self {
            DynamicTensorInt3D::NdArray(tensor) => {
                let max_tensor = tensor.clone().max_dim(0);
                tensor_to_data_int(&max_tensor)
                    .into_iter()
                    .map(|x| x.try_into().expect("Negative value found during conversion to u32"))
                    .collect()
            }
            DynamicTensorInt3D::Wgpu(tensor) => {
                let max_tensor = tensor.clone().max_dim(0);
                tensor_to_data_int(&max_tensor)
                    .into_iter()
                    .map(|x| x.try_into().expect("Negative value found during conversion to u32"))
                    .collect()
            }
            DynamicTensorInt3D::Candle(tensor) => {
                let max_tensor = tensor.clone().max_dim(0);
                tensor_to_data_int(&max_tensor)
                    .into_iter()
                    .map(|x| x.try_into().expect("Negative value found during conversion to u32"))
                    .collect()
            }
        }
    }
}

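/// Conversions between the dynamic tensor wrappers and `ndarray` / `nalgebra` containers.
///
/// Construction picks the enum variant matching `DefaultBackend`; the 3D implementations
/// do not support `DMatrix` interop and panic if it is requested.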
pub trait DynamicMatrixOps<T, const D: usize> {
    fn from_ndarray(array: &nd::Array<T, nd::Dim<[usize; D]>>) -> Self;
    fn to_ndarray(&self) -> nd::Array<T, nd::Dim<[usize; D]>>;
    fn into_ndarray(self) -> nd::Array<T, nd::Dim<[usize; D]>>;
    fn from_dmatrix(matrix: &na::DMatrix<T>) -> Self;
    fn to_dmatrix(&self) -> na::DMatrix<T>;
    fn into_dmatrix(self) -> na::DMatrix<T>;
}

impl DynamicMatrixOps<f32, 2> for DynamicTensorFloat2D {
    fn from_ndarray(array: &nd::Array<f32, nd::Dim<[usize; 2]>>) -> Self {
        match std::any::TypeId::of::<DefaultBackend>() {
            id if id == std::any::TypeId::of::<NdArray>() => {
                let tensor = array.to_burn(&NdArrayDevice::Cpu);
                DynamicTensorFloat2D::NdArray(tensor)
            }
            id if id == std::any::TypeId::of::<Candle>() => {
                let tensor = array.to_burn(&CandleDevice::Cpu);
                DynamicTensorFloat2D::Candle(tensor)
            }
            id if id == std::any::TypeId::of::<Wgpu>() => {
                let tensor = array.to_burn(&WgpuDevice::BestAvailable);
                DynamicTensorFloat2D::Wgpu(tensor)
            }
            _ => panic!("Unsupported backend!"),
        }
    }

    fn to_ndarray(&self) -> nd::Array<f32, nd::Dim<[usize; 2]>> {
        match self {
            DynamicTensorFloat2D::NdArray(tensor) => tensor.to_ndarray(),
            DynamicTensorFloat2D::Wgpu(tensor) => tensor.to_ndarray(),
            DynamicTensorFloat2D::Candle(tensor) => tensor.to_ndarray(),
        }
    }

    fn into_ndarray(self) -> nd::Array<f32, nd::Dim<[usize; 2]>> {
        match self {
            DynamicTensorFloat2D::NdArray(tensor) => tensor.into_ndarray(),
            DynamicTensorFloat2D::Wgpu(tensor) => tensor.into_ndarray(),
            DynamicTensorFloat2D::Candle(tensor) => tensor.into_ndarray(),
        }
    }

    fn from_dmatrix(matrix: &na::DMatrix<f32>) -> Self {
        match std::any::TypeId::of::<DefaultBackend>() {
            id if id == std::any::TypeId::of::<NdArray>() => {
                let tensor = matrix.to_burn(&NdArrayDevice::Cpu);
                DynamicTensorFloat2D::NdArray(tensor)
            }
            id if id == std::any::TypeId::of::<Candle>() => {
                let tensor = matrix.to_burn(&CandleDevice::Cpu);
                DynamicTensorFloat2D::Candle(tensor)
            }
            id if id == std::any::TypeId::of::<Wgpu>() => {
                let tensor = matrix.to_burn(&WgpuDevice::BestAvailable);
                DynamicTensorFloat2D::Wgpu(tensor)
            }
            _ => panic!("Unsupported backend!"),
        }
    }

    fn to_dmatrix(&self) -> na::DMatrix<f32> {
        match self {
            DynamicTensorFloat2D::NdArray(tensor) => tensor.to_nalgebra(),
            DynamicTensorFloat2D::Wgpu(tensor) => tensor.to_nalgebra(),
            DynamicTensorFloat2D::Candle(tensor) => tensor.to_nalgebra(),
        }
    }

    fn into_dmatrix(self) -> na::DMatrix<f32> {
        match self {
            DynamicTensorFloat2D::NdArray(tensor) => tensor.into_nalgebra(),
            DynamicTensorFloat2D::Wgpu(tensor) => tensor.into_nalgebra(),
            DynamicTensorFloat2D::Candle(tensor) => tensor.into_nalgebra(),
        }
    }
}

impl DynamicMatrixOps<f32, 3> for DynamicTensorFloat3D {
    fn from_ndarray(array: &nd::Array<f32, nd::Dim<[usize; 3]>>) -> Self {
        match std::any::TypeId::of::<DefaultBackend>() {
            id if id == std::any::TypeId::of::<NdArray>() => {
                let tensor = array.to_burn(&NdArrayDevice::Cpu);
                DynamicTensorFloat3D::NdArray(tensor)
            }
            id if id == std::any::TypeId::of::<Candle>() => {
                let tensor = array.to_burn(&CandleDevice::Cpu);
                DynamicTensorFloat3D::Candle(tensor)
            }
            id if id == std::any::TypeId::of::<Wgpu>() => {
                let tensor = array.to_burn(&WgpuDevice::BestAvailable);
                DynamicTensorFloat3D::Wgpu(tensor)
            }
            _ => panic!("Unsupported backend!"),
        }
    }

    fn to_ndarray(&self) -> nd::Array<f32, nd::Dim<[usize; 3]>> {
        match self {
            DynamicTensorFloat3D::NdArray(tensor) => tensor.to_ndarray(),
            DynamicTensorFloat3D::Wgpu(tensor) => tensor.to_ndarray(),
            DynamicTensorFloat3D::Candle(tensor) => tensor.to_ndarray(),
        }
    }

    fn into_ndarray(self) -> nd::Array<f32, nd::Dim<[usize; 3]>> {
        match self {
            DynamicTensorFloat3D::NdArray(tensor) => tensor.into_ndarray(),
            DynamicTensorFloat3D::Wgpu(tensor) => tensor.into_ndarray(),
            DynamicTensorFloat3D::Candle(tensor) => tensor.into_ndarray(),
        }
    }

    fn from_dmatrix(_matrix: &na::DMatrix<f32>) -> Self {
        panic!("3D DynamicTensor interop with DMatrix is not supported!");
    }

    fn to_dmatrix(&self) -> na::DMatrix<f32> {
        panic!("3D DynamicTensor interop with DMatrix is not supported!");
    }

    fn into_dmatrix(self) -> na::DMatrix<f32> {
        panic!("3D DynamicTensor interop with DMatrix is not supported!");
    }
}

impl DynamicMatrixOps<u32, 2> for DynamicTensorInt2D {
    fn from_ndarray(array: &nd::Array<u32, nd::Dim<[usize; 2]>>) -> Self {
        match std::any::TypeId::of::<DefaultBackend>() {
            id if id == std::any::TypeId::of::<NdArray>() => {
                let tensor = array.to_burn(&NdArrayDevice::Cpu);
                DynamicTensorInt2D::NdArray(tensor)
            }
            id if id == std::any::TypeId::of::<Candle>() => {
                let tensor = array.to_burn(&CandleDevice::Cpu);
                DynamicTensorInt2D::Candle(tensor)
            }
            id if id == std::any::TypeId::of::<Wgpu>() => {
                let tensor = array.to_burn(&WgpuDevice::BestAvailable);
                DynamicTensorInt2D::Wgpu(tensor)
            }
            _ => panic!("Unsupported backend!"),
        }
    }

    fn to_ndarray(&self) -> nd::Array<u32, nd::Dim<[usize; 2]>> {
        match self {
            DynamicTensorInt2D::NdArray(tensor) => tensor.to_ndarray(),
            DynamicTensorInt2D::Wgpu(tensor) => tensor.to_ndarray(),
            DynamicTensorInt2D::Candle(tensor) => tensor.to_ndarray(),
        }
    }

    fn into_ndarray(self) -> nd::Array<u32, nd::Dim<[usize; 2]>> {
        match self {
            DynamicTensorInt2D::NdArray(tensor) => tensor.into_ndarray(),
            DynamicTensorInt2D::Wgpu(tensor) => tensor.into_ndarray(),
            DynamicTensorInt2D::Candle(tensor) => tensor.into_ndarray(),
        }
    }

    fn from_dmatrix(matrix: &na::DMatrix<u32>) -> Self {
        match std::any::TypeId::of::<DefaultBackend>() {
            id if id == std::any::TypeId::of::<NdArray>() => {
                let tensor = matrix.to_burn(&NdArrayDevice::Cpu);
                DynamicTensorInt2D::NdArray(tensor)
            }
            id if id == std::any::TypeId::of::<Candle>() => {
                let tensor = matrix.to_burn(&CandleDevice::Cpu);
                DynamicTensorInt2D::Candle(tensor)
            }
            id if id == std::any::TypeId::of::<Wgpu>() => {
                let tensor = matrix.to_burn(&WgpuDevice::BestAvailable);
                DynamicTensorInt2D::Wgpu(tensor)
            }
            _ => panic!("Unsupported backend!"),
        }
    }

    fn to_dmatrix(&self) -> na::DMatrix<u32> {
        match self {
            DynamicTensorInt2D::NdArray(tensor) => tensor.to_nalgebra(),
            DynamicTensorInt2D::Wgpu(tensor) => tensor.to_nalgebra(),
            DynamicTensorInt2D::Candle(tensor) => tensor.to_nalgebra(),
        }
    }

    fn into_dmatrix(self) -> na::DMatrix<u32> {
        match self {
            DynamicTensorInt2D::NdArray(tensor) => tensor.into_nalgebra(),
            DynamicTensorInt2D::Wgpu(tensor) => tensor.into_nalgebra(),
            DynamicTensorInt2D::Candle(tensor) => tensor.into_nalgebra(),
        }
    }
}

impl DynamicMatrixOps<u32, 3> for DynamicTensorInt3D {
    fn from_ndarray(array: &nd::Array<u32, nd::Dim<[usize; 3]>>) -> Self {
        match std::any::TypeId::of::<DefaultBackend>() {
            id if id == std::any::TypeId::of::<NdArray>() => {
                let tensor = array.to_burn(&NdArrayDevice::Cpu);
                DynamicTensorInt3D::NdArray(tensor)
            }
            id if id == std::any::TypeId::of::<Candle>() => {
                let tensor = array.to_burn(&CandleDevice::Cpu);
                DynamicTensorInt3D::Candle(tensor)
            }
            id if id == std::any::TypeId::of::<Wgpu>() => {
                let tensor = array.to_burn(&WgpuDevice::BestAvailable);
                DynamicTensorInt3D::Wgpu(tensor)
            }
            _ => panic!("Unsupported backend!"),
        }
    }

    fn to_ndarray(&self) -> nd::Array<u32, nd::Dim<[usize; 3]>> {
        match self {
            DynamicTensorInt3D::NdArray(tensor) => tensor.to_ndarray(),
            DynamicTensorInt3D::Wgpu(tensor) => tensor.to_ndarray(),
            DynamicTensorInt3D::Candle(tensor) => tensor.to_ndarray(),
        }
    }

    fn into_ndarray(self) -> nd::Array<u32, nd::Dim<[usize; 3]>> {
        match self {
            DynamicTensorInt3D::NdArray(tensor) => tensor.into_ndarray(),
            DynamicTensorInt3D::Wgpu(tensor) => tensor.into_ndarray(),
            DynamicTensorInt3D::Candle(tensor) => tensor.into_ndarray(),
        }
    }

    fn from_dmatrix(_matrix: &na::DMatrix<u32>) -> Self {
        panic!("3D DynamicTensor interop with DMatrix is not supported!");
    }

    fn to_dmatrix(&self) -> na::DMatrix<u32> {
        panic!("3D DynamicTensor interop with DMatrix is not supported!");
    }

    fn into_dmatrix(self) -> na::DMatrix<u32> {
        panic!("3D DynamicTensor interop with DMatrix is not supported!");
    }
}

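/// Normalizes each row of a 2D tensor to unit Euclidean (L2) length.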
pub fn normalize_tensor<B: Backend>(tensor: Tensor<B, 2, Float>) -> Tensor<B, 2, Float> {
    let norm = tensor.clone().powf_scalar(2.0).sum_dim(1).sqrt();
    tensor.div(norm)
}

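/// Row-wise cross product of two `[n, 3]` tensors, returning an `[n, 3]` tensor.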
pub fn cross_product<B: Backend>(
    a: &Tensor<B, 2, Float>,
    b: &Tensor<B, 2, Float>,
) -> Tensor<B, 2, Float> {
    let a_chunks = a.clone().chunk(3, 1);
    let b_chunks = b.clone().chunk(3, 1);

    let ax: Tensor<B, 1> = a_chunks[0].clone().squeeze(1);
    let ay: Tensor<B, 1> = a_chunks[1].clone().squeeze(1);
    let az: Tensor<B, 1> = a_chunks[2].clone().squeeze(1);

    let bx: Tensor<B, 1> = b_chunks[0].clone().squeeze(1);
    let by: Tensor<B, 1> = b_chunks[1].clone().squeeze(1);
    let bz: Tensor<B, 1> = b_chunks[2].clone().squeeze(1);

    let cx = ay.clone().mul(bz.clone()).sub(az.clone().mul(by.clone()));
    let cy = az.mul(bx.clone()).sub(ax.clone().mul(bz));
    let cz = ax.mul(by).sub(ay.mul(bx));

    Tensor::stack(vec![cx, cy, cz], 1)
}
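
// A minimal usage sketch, not part of the original module: it exercises only items defined
// above plus `ndarray`, and assumes the `ToBurn` conversion in `crate::bshare` preserves
// shape and row-major element order when `from_ndarray` builds the `DefaultBackend` variant.
#[cfg(test)]
mod dynamic_tensor_sketch {
    use super::*;

    #[test]
    fn float2d_roundtrip_and_reductions() {
        // 2 x 2 host matrix -> dynamic tensor on DefaultBackend (NdArray).
        let host = nd::arr2(&[[1.0f32, 2.0], [3.0, 4.0]]);
        let tensor = DynamicTensorFloat2D::from_ndarray(&host);

        assert_eq!(tensor.shape(), vec![2, 2]);
        assert_eq!(tensor.nrows(), 2);

        // min_dim(0) / max_dim(0) reduce over rows, i.e. per-column extrema.
        assert_eq!(tensor.min_vec(), vec![1.0f32, 2.0]);
        assert_eq!(tensor.max_vec(), vec![3.0f32, 4.0]);

        // Round-trip back to ndarray.
        assert_eq!(tensor.to_ndarray(), host);
    }
}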