1use core::panic;
2
3use burn::{
4 backend::{candle::CandleDevice, ndarray::NdArrayDevice, wgpu::WgpuDevice, Candle, NdArray, Wgpu},
5 prelude::Backend,
6 tensor::{Float, Int, Tensor},
7};
8use crate::bshare::{tensor_to_data_float, tensor_to_data_int, ToBurn, ToNalgebraFloat, ToNalgebraInt};
14extern crate nalgebra as na;
15use bytemuck;
16use log::warn;
17
18pub type DefaultBackend = NdArray; #[derive(Clone)]
21pub enum BurnBackend {
22 Candle,
23 NdArray,
24 Wgpu,
25}
26
27#[derive(Clone, Debug)]
29pub enum DynamicTensorFloat1D {
30 NdArray(Tensor<NdArray, 1, Float>),
31 Wgpu(Tensor<Wgpu, 1, Float>),
32 Candle(Tensor<Candle, 1, Float>),
33}
34
35#[derive(Clone, Debug)]
37pub enum DynamicTensorFloat2D {
38 NdArray(Tensor<NdArray, 2, Float>),
39 Wgpu(Tensor<Wgpu, 2, Float>),
40 Candle(Tensor<Candle, 2, Float>),
41}
42
43#[derive(Clone, Debug)]
45pub enum DynamicTensorInt1D {
46 NdArray(Tensor<NdArray, 1, Int>),
47 Wgpu(Tensor<Wgpu, 1, Int>),
48 Candle(Tensor<Candle, 1, Int>),
49}
50
51#[derive(Clone, Debug)]
53pub enum DynamicTensorInt2D {
54 NdArray(Tensor<NdArray, 2, Int>),
55 Wgpu(Tensor<Wgpu, 2, Int>),
56 Candle(Tensor<Candle, 2, Int>),
57}
58
59impl DynamicTensorFloat1D {
61 pub fn from_ndarray(tensor: Tensor<NdArray, 1, Float>) -> Self {
62 DynamicTensorFloat1D::NdArray(tensor)
63 }
64 pub fn from_wgpu(tensor: Tensor<Wgpu, 1, Float>) -> Self {
65 DynamicTensorFloat1D::Wgpu(tensor)
66 }
67 pub fn from_candle(tensor: Tensor<Candle, 1, Float>) -> Self {
68 DynamicTensorFloat1D::Candle(tensor)
69 }
70}
71
72impl DynamicTensorFloat2D {
74 pub fn from_ndarray(tensor: Tensor<NdArray, 2, Float>) -> Self {
75 DynamicTensorFloat2D::NdArray(tensor)
76 }
77 pub fn from_wgpu(tensor: Tensor<Wgpu, 2, Float>) -> Self {
78 DynamicTensorFloat2D::Wgpu(tensor)
79 }
80 pub fn from_candle(tensor: Tensor<Candle, 2, Float>) -> Self {
81 DynamicTensorFloat2D::Candle(tensor)
82 }
83}
84
85impl DynamicTensorInt1D {
87 pub fn from_ndarray(tensor: Tensor<NdArray, 1, Int>) -> Self {
88 DynamicTensorInt1D::NdArray(tensor)
89 }
90 pub fn from_wgpu(tensor: Tensor<Wgpu, 1, Int>) -> Self {
91 DynamicTensorInt1D::Wgpu(tensor)
92 }
93 pub fn from_candle(tensor: Tensor<Candle, 1, Int>) -> Self {
94 DynamicTensorInt1D::Candle(tensor)
95 }
96}
97
98impl DynamicTensorInt2D {
100 pub fn from_ndarray(tensor: Tensor<NdArray, 2, Int>) -> Self {
101 DynamicTensorInt2D::NdArray(tensor)
102 }
103 pub fn from_wgpu(tensor: Tensor<Wgpu, 2, Int>) -> Self {
104 DynamicTensorInt2D::Wgpu(tensor)
105 }
106 pub fn from_candle(tensor: Tensor<Candle, 2, Int>) -> Self {
107 DynamicTensorInt2D::Candle(tensor)
108 }
109}
110
/// Backend-agnostic read access shared by the dynamic tensor wrappers.
///
/// `T` is the element type handed back to callers; the implementations in this
/// module use `f32` for float tensors and `u32` for integer tensors.
pub trait DynamicTensorOps<T> {
    /// Size of the first dimension.
    fn nrows(&self) -> usize;

    /// `(rows, cols)`; the rank-1 implementations in this module report one column.
    fn shape(&self) -> (usize, usize);

    /// Copies the tensor's elements into a host-side `Vec`.
    fn to_vec(&self) -> Vec<T>;

    /// The elements as raw bytes (the same values `to_vec` yields, reinterpreted).
    fn as_bytes(&self) -> Vec<u8>;

    /// Minimum value(s): one overall value for the rank-1 implementations,
    /// one per index of dimension 1 (reduced over dimension 0) for rank-2 ones.
    fn min_vec(&self) -> Vec<T>;

    /// Maximum value(s), with the same layout as `min_vec`.
    fn max_vec(&self) -> Vec<T>;
}
124
125impl DynamicTensorOps<f32> for DynamicTensorFloat1D {
127 fn as_bytes(&self) -> Vec<u8> {
128 match self {
129 DynamicTensorFloat1D::NdArray(tensor) => bytemuck::cast_slice(&tensor_to_data_float(tensor)).to_vec(),
130 DynamicTensorFloat1D::Wgpu(tensor) => bytemuck::cast_slice(&tensor_to_data_float(tensor)).to_vec(),
131 DynamicTensorFloat1D::Candle(tensor) => bytemuck::cast_slice(&tensor_to_data_float(tensor)).to_vec(),
132 }
133 }
134
135 fn nrows(&self) -> usize {
136 match self {
137 DynamicTensorFloat1D::NdArray(tensor) => tensor.dims()[0],
138 DynamicTensorFloat1D::Wgpu(tensor) => tensor.dims()[0],
139 DynamicTensorFloat1D::Candle(tensor) => tensor.dims()[0],
140 }
141 }
142
143 fn shape(&self) -> (usize, usize) {
144 match self {
145 DynamicTensorFloat1D::NdArray(tensor) => (tensor.dims()[0], 1),
146 DynamicTensorFloat1D::Wgpu(tensor) => (tensor.dims()[0], 1),
147 DynamicTensorFloat1D::Candle(tensor) => (tensor.dims()[0], 1),
148 }
149 }
150
151 fn to_vec(&self) -> Vec<f32> {
152 match &self {
153 DynamicTensorFloat1D::NdArray(tensor) => tensor_to_data_float(tensor),
154 DynamicTensorFloat1D::Wgpu(tensor) => tensor_to_data_float(tensor),
155 DynamicTensorFloat1D::Candle(tensor) => tensor_to_data_float(tensor),
156 }
157 }
158
159 fn min_vec(&self) -> Vec<f32> {
160 vec![self.to_vec().iter().copied().fold(f32::INFINITY, f32::min)]
161 }
162
163 fn max_vec(&self) -> Vec<f32> {
164 vec![self.to_vec().iter().copied().fold(f32::NEG_INFINITY, f32::max)]
165 }
166}
167
168impl DynamicTensorOps<f32> for DynamicTensorFloat2D {
170 fn as_bytes(&self) -> Vec<u8> {
171 match self {
172 DynamicTensorFloat2D::NdArray(tensor) => {
173 let tensor_data = tensor_to_data_float(tensor);
174 bytemuck::cast_slice(&tensor_data).to_vec()
175 }
176 DynamicTensorFloat2D::Wgpu(tensor) => {
177 warn!("Forcing DynamicTensor with Wgpu backend to CPU");
178 let tensor_data = tensor_to_data_float(tensor);
179 bytemuck::cast_slice(&tensor_data).to_vec()
180 }
181 DynamicTensorFloat2D::Candle(tensor) => {
182 let tensor_data = tensor_to_data_float(tensor);
183 bytemuck::cast_slice(&tensor_data).to_vec()
184 }
185 }
186 }
187
188 fn nrows(&self) -> usize {
189 match self {
190 DynamicTensorFloat2D::NdArray(tensor) => tensor.dims()[0],
191 DynamicTensorFloat2D::Wgpu(tensor) => tensor.dims()[0],
192 DynamicTensorFloat2D::Candle(tensor) => tensor.dims()[0],
193 }
194 }
195
196 fn shape(&self) -> (usize, usize) {
197 match self {
198 DynamicTensorFloat2D::NdArray(tensor) => (tensor.dims()[0], tensor.dims()[1]),
199 DynamicTensorFloat2D::Wgpu(tensor) => (tensor.dims()[0], tensor.dims()[1]),
200 DynamicTensorFloat2D::Candle(tensor) => (tensor.dims()[0], tensor.dims()[1]),
201 }
202 }
203
204 fn to_vec(&self) -> Vec<f32> {
205 match &self {
206 DynamicTensorFloat2D::NdArray(tensor) => tensor_to_data_float(tensor),
207 DynamicTensorFloat2D::Wgpu(tensor) => {
208 warn!("Forcing DynamicTensor with Wgpu backend to CPU");
209 tensor_to_data_float(tensor)
210 }
211 DynamicTensorFloat2D::Candle(tensor) => tensor_to_data_float(tensor),
212 }
213 }
214
215 fn min_vec(&self) -> Vec<f32> {
216 match &self {
217 DynamicTensorFloat2D::NdArray(tensor) => {
218 let min_tensor = tensor.clone().min_dim(0);
219 tensor_to_data_float(&min_tensor)
220 }
221 DynamicTensorFloat2D::Wgpu(tensor) => {
222 let min_tensor = tensor.clone().min_dim(0);
223 tensor_to_data_float(&min_tensor)
224 }
225 DynamicTensorFloat2D::Candle(tensor) => {
226 let min_tensor = tensor.clone().min_dim(0);
227 tensor_to_data_float(&min_tensor)
228 }
229 }
230 }
231
232 fn max_vec(&self) -> Vec<f32> {
233 match &self {
234 DynamicTensorFloat2D::NdArray(tensor) => {
235 let max_tensor = tensor.clone().max_dim(0);
236 tensor_to_data_float(&max_tensor)
237 }
238 DynamicTensorFloat2D::Wgpu(tensor) => {
239 let max_tensor = tensor.clone().max_dim(0);
240 tensor_to_data_float(&max_tensor)
241 }
242 DynamicTensorFloat2D::Candle(tensor) => {
243 let max_tensor = tensor.clone().max_dim(0);
244 tensor_to_data_float(&max_tensor)
245 }
246 }
247 }
248}
249
250impl DynamicTensorOps<u32> for DynamicTensorInt1D {
252 fn as_bytes(&self) -> Vec<u8> {
253 match self {
254 DynamicTensorInt1D::NdArray(tensor) => {
255 let tensor_data = tensor_to_data_int(tensor);
256 let u32_data: Vec<u32> = tensor_data
257 .into_iter()
258 .map(|x| x.try_into().expect("Negative value found during conversion to u32"))
259 .collect();
260 bytemuck::cast_slice(&u32_data).to_vec()
261 }
262 DynamicTensorInt1D::Wgpu(tensor) => {
263 let tensor_data = tensor_to_data_int(tensor);
264 let u32_data: Vec<u32> = tensor_data
265 .into_iter()
266 .map(|x| x.try_into().expect("Negative value found during conversion to u32"))
267 .collect();
268 bytemuck::cast_slice(&u32_data).to_vec()
269 }
270 DynamicTensorInt1D::Candle(tensor) => {
271 let tensor_data = tensor_to_data_int(tensor);
272 let u32_data: Vec<u32> = tensor_data
273 .into_iter()
274 .map(|x| x.try_into().expect("Negative value found during conversion to u32"))
275 .collect();
276 bytemuck::cast_slice(&u32_data).to_vec()
277 }
278 }
279 }
280
281 fn nrows(&self) -> usize {
282 match self {
283 DynamicTensorInt1D::NdArray(tensor) => tensor.dims()[0],
284 DynamicTensorInt1D::Wgpu(tensor) => tensor.dims()[0],
285 DynamicTensorInt1D::Candle(tensor) => tensor.dims()[0],
286 }
287 }
288
289 fn shape(&self) -> (usize, usize) {
290 match self {
291 DynamicTensorInt1D::NdArray(tensor) => (tensor.dims()[0], 1),
292 DynamicTensorInt1D::Wgpu(tensor) => (tensor.dims()[0], 1),
293 DynamicTensorInt1D::Candle(tensor) => (tensor.dims()[0], 1),
294 }
295 }
296
297 fn to_vec(&self) -> Vec<u32> {
298 match &self {
299 DynamicTensorInt1D::NdArray(tensor) => {
300 let data = tensor_to_data_int(tensor);
301 data.into_iter()
302 .map(|x| x.try_into().expect("Negative value found during conversion to u32"))
303 .collect()
304 }
305 DynamicTensorInt1D::Wgpu(tensor) => {
306 let data = tensor_to_data_int(tensor);
307 data.into_iter()
308 .map(|x| x.try_into().expect("Negative value found during conversion to u32"))
309 .collect()
310 }
311 DynamicTensorInt1D::Candle(tensor) => {
312 let data = tensor_to_data_int(tensor);
313 data.into_iter()
314 .map(|x| x.try_into().expect("Negative value found during conversion to u32"))
315 .collect()
316 }
317 }
318 }
319
320 fn min_vec(&self) -> Vec<u32> {
321 vec![self.to_vec().into_iter().min().unwrap_or(0)]
322 }
323
324 fn max_vec(&self) -> Vec<u32> {
325 vec![self.to_vec().into_iter().max().unwrap_or(0)]
326 }
327}
328
329impl DynamicTensorOps<u32> for DynamicTensorInt2D {
331 fn as_bytes(&self) -> Vec<u8> {
332 match self {
333 DynamicTensorInt2D::NdArray(tensor) => {
334 let tensor_data = tensor_to_data_int(tensor);
335 let u32_data: Vec<u32> = tensor_data
336 .into_iter()
337 .map(|x| x.try_into().expect("Negative value found during conversion to u32"))
338 .collect();
339 bytemuck::cast_slice(&u32_data).to_vec()
340 }
341 DynamicTensorInt2D::Wgpu(tensor) => {
342 let tensor_data = tensor_to_data_int(tensor);
343 let u32_data: Vec<u32> = tensor_data
344 .into_iter()
345 .map(|x| x.try_into().expect("Negative value found during conversion to u32"))
346 .collect();
347 bytemuck::cast_slice(&u32_data).to_vec()
348 }
349 DynamicTensorInt2D::Candle(tensor) => {
350 let tensor_data = tensor_to_data_int(tensor);
351 let u32_data: Vec<u32> = tensor_data
352 .into_iter()
353 .map(|x| x.try_into().expect("Negative value found during conversion to u32"))
354 .collect();
355 bytemuck::cast_slice(&u32_data).to_vec()
356 }
357 }
358 }
359
360 fn nrows(&self) -> usize {
361 match self {
362 DynamicTensorInt2D::NdArray(tensor) => tensor.dims()[0],
363 DynamicTensorInt2D::Wgpu(tensor) => tensor.dims()[0],
364 DynamicTensorInt2D::Candle(tensor) => tensor.dims()[0],
365 }
366 }
367
368 fn shape(&self) -> (usize, usize) {
369 match self {
370 DynamicTensorInt2D::NdArray(tensor) => (tensor.dims()[0], tensor.dims()[1]),
371 DynamicTensorInt2D::Wgpu(tensor) => (tensor.dims()[0], tensor.dims()[1]),
372 DynamicTensorInt2D::Candle(tensor) => (tensor.dims()[0], tensor.dims()[1]),
373 }
374 }
375
376 fn to_vec(&self) -> Vec<u32> {
377 match &self {
378 DynamicTensorInt2D::NdArray(tensor) => {
379 let data = tensor_to_data_int(tensor);
380 data.into_iter()
381 .map(|x| x.try_into().expect("Negative value found during conversion to u32"))
382 .collect()
383 }
384 DynamicTensorInt2D::Wgpu(tensor) => {
385 let data = tensor_to_data_int(tensor);
386 data.into_iter()
387 .map(|x| x.try_into().expect("Negative value found during conversion to u32"))
388 .collect()
389 }
390 DynamicTensorInt2D::Candle(tensor) => {
391 let data = tensor_to_data_int(tensor);
392 data.into_iter()
393 .map(|x| x.try_into().expect("Negative value found during conversion to u32"))
394 .collect()
395 }
396 }
397 }
398
399 fn min_vec(&self) -> Vec<u32> {
400 match &self {
401 DynamicTensorInt2D::NdArray(tensor) => {
402 let min_tensor = tensor.clone().min_dim(0);
403 tensor_to_data_int(&min_tensor)
404 .into_iter()
405 .map(|x| x.try_into().expect("Negative value found during conversion to u32"))
406 .collect()
407 }
408 DynamicTensorInt2D::Wgpu(tensor) => {
409 let min_tensor = tensor.clone().min_dim(0);
410 tensor_to_data_int(&min_tensor)
411 .into_iter()
412 .map(|x| x.try_into().expect("Negative value found during conversion to u32"))
413 .collect()
414 }
415 DynamicTensorInt2D::Candle(tensor) => {
416 let min_tensor = tensor.clone().min_dim(0);
417 tensor_to_data_int(&min_tensor)
418 .into_iter()
419 .map(|x| x.try_into().expect("Negative value found during conversion to u32"))
420 .collect()
421 }
422 }
423 }
424
425 fn max_vec(&self) -> Vec<u32> {
426 match &self {
427 DynamicTensorInt2D::NdArray(tensor) => {
428 let max_tensor = tensor.clone().max_dim(0);
429 tensor_to_data_int(&max_tensor)
430 .into_iter()
431 .map(|x| x.try_into().expect("Negative value found during conversion to u32"))
432 .collect()
433 }
434 DynamicTensorInt2D::Wgpu(tensor) => {
435 let max_tensor = tensor.clone().max_dim(0);
436 tensor_to_data_int(&max_tensor)
437 .into_iter()
438 .map(|x| x.try_into().expect("Negative value found during conversion to u32"))
439 .collect()
440 }
441 DynamicTensorInt2D::Candle(tensor) => {
442 let max_tensor = tensor.clone().max_dim(0);
443 tensor_to_data_int(&max_tensor)
444 .into_iter()
445 .map(|x| x.try_into().expect("Negative value found during conversion to u32"))
446 .collect()
447 }
448 }
449 }
450}
451
452pub trait DynamicMatrixOps<T> {
454 fn from_dmatrix(matrix: &na::DMatrix<T>) -> Self;
455 fn to_dmatrix(&self) -> na::DMatrix<T>;
456 fn into_dmatrix(self) -> na::DMatrix<T>;
457}
458
459impl DynamicMatrixOps<f32> for DynamicTensorFloat2D {
461 fn from_dmatrix(matrix: &na::DMatrix<f32>) -> Self {
462 match std::any::TypeId::of::<DefaultBackend>() {
463 id if id == std::any::TypeId::of::<NdArray>() => {
464 let tensor = matrix.to_burn(&NdArrayDevice::Cpu);
465 DynamicTensorFloat2D::NdArray(tensor)
466 }
467 id if id == std::any::TypeId::of::<Candle>() => {
468 let tensor = matrix.to_burn(&CandleDevice::Cpu);
469 DynamicTensorFloat2D::Candle(tensor)
470 }
471 id if id == std::any::TypeId::of::<Wgpu>() => {
472 let tensor = matrix.to_burn(&WgpuDevice::BestAvailable);
473 DynamicTensorFloat2D::Wgpu(tensor)
474 }
475 _ => panic!("Unsupported backend!"),
476 }
477 }
478
479 fn to_dmatrix(&self) -> na::DMatrix<f32> {
480 match self {
481 DynamicTensorFloat2D::NdArray(tensor) => tensor.to_nalgebra(),
482 DynamicTensorFloat2D::Wgpu(tensor) => tensor.to_nalgebra(),
483 DynamicTensorFloat2D::Candle(tensor) => tensor.to_nalgebra(),
484 }
485 }
486
487 fn into_dmatrix(self) -> na::DMatrix<f32> {
488 match self {
489 DynamicTensorFloat2D::NdArray(tensor) => tensor.into_nalgebra(),
490 DynamicTensorFloat2D::Wgpu(tensor) => tensor.into_nalgebra(),
491 DynamicTensorFloat2D::Candle(tensor) => tensor.into_nalgebra(),
492 }
493 }
494}
495
496impl DynamicMatrixOps<u32> for DynamicTensorInt2D {
498 fn from_dmatrix(matrix: &na::DMatrix<u32>) -> Self {
499 match std::any::TypeId::of::<DefaultBackend>() {
500 id if id == std::any::TypeId::of::<NdArray>() => {
501 let tensor = matrix.to_burn(&NdArrayDevice::Cpu);
502 DynamicTensorInt2D::NdArray(tensor)
503 }
504 id if id == std::any::TypeId::of::<Candle>() => {
505 let tensor = matrix.to_burn(&CandleDevice::Cpu);
506 DynamicTensorInt2D::Candle(tensor)
507 }
508 id if id == std::any::TypeId::of::<Wgpu>() => {
509 let tensor = matrix.to_burn(&WgpuDevice::BestAvailable);
510 DynamicTensorInt2D::Wgpu(tensor)
511 }
512 _ => panic!("Unsupported backend!"),
513 }
514 }
515
516 fn to_dmatrix(&self) -> na::DMatrix<u32> {
517 match self {
518 DynamicTensorInt2D::NdArray(tensor) => tensor.to_nalgebra(),
519 DynamicTensorInt2D::Wgpu(tensor) => tensor.to_nalgebra(),
520 DynamicTensorInt2D::Candle(tensor) => tensor.to_nalgebra(),
521 }
522 }
523
524 fn into_dmatrix(self) -> na::DMatrix<u32> {
525 match self {
526 DynamicTensorInt2D::NdArray(tensor) => tensor.into_nalgebra(),
527 DynamicTensorInt2D::Wgpu(tensor) => tensor.into_nalgebra(),
528 DynamicTensorInt2D::Candle(tensor) => tensor.into_nalgebra(),
529 }
530 }
531}
532
533pub fn normalize_tensor<B: Backend>(tensor: Tensor<B, 2, Float>) -> Tensor<B, 2, Float> {
539 let norm = tensor.clone().powf_scalar(2.0).sum_dim(1).sqrt(); tensor.div(norm) }
542
543pub fn cross_product<B: Backend>(
545 a: &Tensor<B, 2, Float>, b: &Tensor<B, 2, Float>, ) -> Tensor<B, 2, Float> {
548 let a_chunks = a.clone().chunk(3, 1); let b_chunks = b.clone().chunk(3, 1); let ax: Tensor<B, 1> = a_chunks[0].clone().squeeze(1); let ay: Tensor<B, 1> = a_chunks[1].clone().squeeze(1); let az: Tensor<B, 1> = a_chunks[2].clone().squeeze(1); let bx: Tensor<B, 1> = b_chunks[0].clone().squeeze(1); let by: Tensor<B, 1> = b_chunks[1].clone().squeeze(1); let bz: Tensor<B, 1> = b_chunks[2].clone().squeeze(1); let cx = ay.clone().mul(bz.clone()).sub(az.clone().mul(by.clone())); let cy = az.mul(bx.clone()).sub(ax.clone().mul(bz)); let cz = ax.mul(by).sub(ay.mul(bx)); Tensor::stack(vec![cx, cy, cz], 1) }