1use crate::clone_box::CloneBox;
2use crate::errors::DynTensorError;
3use crate::kind::KindFlag;
4use crate::operations;
5use crate::rank_dispatch::RankHandler;
6use crate::{indexing, rank_dispatch};
7use burn::Tensor;
8use burn::prelude::{Backend, Bool, Float, Int, Shape, SliceArg, TensorData};
9use burn::tensor::{BasicOps, DType, Slice};
10
11pub trait ValuesArg<B: Backend>: Sized {
13 fn into_values(
15 self,
16 device: &B::Device,
17 ) -> Result<DynTensor<B>, DynTensorError>;
18}
19
20impl<B: Backend, T: Into<DynTensor<B>>> ValuesArg<B> for T {
21 fn into_values(
22 self,
23 device: &B::Device,
24 ) -> Result<DynTensor<B>, DynTensorError> {
25 self.into().to_device(device)
26 }
27}
28
29impl<B: Backend> ValuesArg<B> for TensorData {
30 fn into_values(
31 self,
32 device: &B::Device,
33 ) -> Result<DynTensor<B>, DynTensorError> {
34 DynTensor::from_data(self, device)
35 }
36}
37
/// Type-erased wrapper around a strongly-typed `burn` tensor.
///
/// Rank and kind are erased behind a clonable, downcastable trait object;
/// shape/dtype/kind/device are cached so they can be queried without a
/// downcast.
#[derive(Debug, Clone)]
pub struct DynTensor<B: Backend> {
    /// Cached shape of the boxed tensor.
    shape: Shape,
    /// Cached element data type of the boxed tensor.
    dtype: DType,
    /// Coarse kind (float / int / bool), derived from `dtype` at wrap time.
    kind: KindFlag,
    /// Cached device the boxed tensor lives on.
    device: B::Device,
    /// The erased tensor; concretely a `Tensor<B, R, K>` (see `downcast_clone`).
    tensor: Box<dyn CloneBox>,
    // NOTE(review): likely redundant — `device: B::Device` already uses `B`;
    // confirm before removing.
    phantom: std::marker::PhantomData<B>,
}
48
49impl<B: Backend, const R: usize, K> From<Tensor<B, R, K>> for DynTensor<B>
50where
51 K: 'static + BasicOps<B>,
52{
53 fn from(val: Tensor<B, R, K>) -> Self {
54 DynTensor::new(val)
55 }
56}
57
58impl<B: Backend> DynTensor<B> {
59 pub fn new<const R: usize, K>(tensor: Tensor<B, R, K>) -> Self
61 where
62 K: BasicOps<B> + 'static,
63 {
64 Self {
65 shape: tensor.shape(),
66 dtype: tensor.dtype(),
67 kind: tensor.dtype().into(),
68 device: tensor.device(),
69 tensor: Box::new(tensor),
70 phantom: std::marker::PhantomData,
71 }
72 }
73
74 pub fn rank(&self) -> usize {
76 self.shape.rank()
77 }
78
79 pub fn shape(&self) -> Shape {
81 self.shape.clone()
82 }
83
84 pub fn num_elements(&self) -> usize {
86 self.shape.num_elements()
87 }
88
89 pub fn size_estimate(&self) -> usize {
93 self.dtype.size() * self.num_elements()
94 }
95
96 pub fn dtype(&self) -> DType {
98 self.dtype
99 }
100
101 pub fn kind(&self) -> KindFlag {
103 self.kind
104 }
105
106 pub fn device(&self) -> B::Device {
108 self.device.clone()
109 }
110
111 pub fn downcast_clone<const R: usize, K>(&self) -> Option<Tensor<B, R, K>>
117 where
118 K: 'static + BasicOps<B>,
119 {
120 self.tensor.downcast_ref::<Tensor<B, R, K>>().cloned()
121 }
122
123 pub fn unwrap_clone<const R: usize, K>(&self) -> Tensor<B, R, K>
131 where
132 K: 'static + BasicOps<B>,
133 {
134 self.downcast_clone::<R, K>()
135 .expect("downcast_clone failed")
136 }
137
138 pub fn slice<const R: usize, S>(
149 self,
150 slices: S,
151 ) -> Result<Self, DynTensorError>
152 where
153 S: SliceArg<R>,
154 {
155 let rank = self.rank();
156 let slices = self.shape().into_slices(slices);
157
158 indexing::check_slices_bounds(&self.shape(), &slices)
159 .map_err(DynTensorError::SliceError)?;
160
161 struct SliceHandler<B: Backend, const R: usize> {
162 this: DynTensor<B>,
163 slices: [Slice; R],
164 }
165 impl<B: Backend, const R: usize> RankHandler for SliceHandler<B, R> {
166 type Output = DynTensor<B>;
167 fn call<const R2: usize>(self) -> Result<Self::Output, DynTensorError> {
168 Ok(match self.this.kind {
169 KindFlag::Float => self
170 .this
171 .unwrap_clone::<R, Float>()
172 .slice(self.slices)
173 .into(),
174 KindFlag::Int => self.this.unwrap_clone::<R, Int>().slice(self.slices).into(),
175 KindFlag::Bool => self
176 .this
177 .unwrap_clone::<R, Bool>()
178 .slice(self.slices)
179 .into(),
180 })
181 }
182 }
183 rank_dispatch::dispatch_rank(rank, SliceHandler { this: self, slices })
184 }
185
186 pub fn slice_dyn(
197 self,
198 slices: &[Slice],
199 ) -> Result<Self, DynTensorError> {
200 let rank = self.rank();
201
202 indexing::check_slices_bounds(&self.shape(), slices).map_err(DynTensorError::SliceError)?;
203
204 struct SliceDynHandler<'a, B: Backend> {
205 this: DynTensor<B>,
206 slices: &'a [Slice],
207 }
208 impl<'a, B: Backend> RankHandler for SliceDynHandler<'a, B> {
209 type Output = DynTensor<B>;
210 fn call<const R: usize>(self) -> Result<Self::Output, DynTensorError> {
211 Ok(match self.this.kind {
212 KindFlag::Float => {
213 operations::slice_dyn(self.this.unwrap_clone::<R, Float>(), self.slices)
214 .into()
215 }
216 KindFlag::Int => {
217 operations::slice_dyn(self.this.unwrap_clone::<R, Int>(), self.slices)
218 .into()
219 }
220 KindFlag::Bool => {
221 operations::slice_dyn(self.this.unwrap_clone::<R, Bool>(), self.slices)
222 .into()
223 }
224 })
225 }
226 }
227 rank_dispatch::dispatch_rank(rank, SliceDynHandler { this: self, slices })
228 }
229
230 pub fn slice_assign<const R2: usize, S, V>(
242 self,
243 slices: S,
244 values: V,
245 ) -> Result<Self, DynTensorError>
246 where
247 S: SliceArg<R2>,
248 V: ValuesArg<B>,
249 {
250 let rank = self.rank();
251 let slices = self.shape().into_slices(slices);
252 let values: DynTensor<B> = values.into_values(&self.device())?;
253
254 indexing::check_slices_bounds(&self.shape(), &slices)
255 .map_err(DynTensorError::SliceError)?;
256
257 if rank != values.rank() {
258 return Err(DynTensorError::InvalidArgument {
259 msg: format!(
260 "slice of rank ({}) cannot be assigned to tensor of rank ({})",
261 values.rank(),
262 rank
263 ),
264 });
265 }
266
267 let values = values.cast(self.dtype())?;
268
269 struct SliceAssignHandler<B: Backend, const R2: usize> {
272 this: DynTensor<B>,
273 slices: [Slice; R2],
274 values: DynTensor<B>,
275 }
276 impl<B: Backend, const R2: usize> RankHandler for SliceAssignHandler<B, R2> {
277 type Output = DynTensor<B>;
278 fn call<const R: usize>(self) -> Result<Self::Output, DynTensorError> {
279 Ok(match self.this.kind {
280 KindFlag::Float => self
281 .this
282 .unwrap_clone::<R, Float>()
283 .slice_assign(self.slices, self.values.unwrap_clone())
284 .into(),
285 KindFlag::Int => self
286 .this
287 .unwrap_clone::<R, Int>()
288 .slice_assign(self.slices, self.values.unwrap_clone())
289 .into(),
290 KindFlag::Bool => self
291 .this
292 .unwrap_clone::<R, Bool>()
293 .slice_assign(self.slices, self.values.unwrap_clone())
294 .into(),
295 })
296 }
297 }
298 rank_dispatch::dispatch_rank(
299 rank,
300 SliceAssignHandler {
301 this: self.clone(),
302 slices,
303 values,
304 },
305 )
306 }
307
308 pub fn slice_assign_dyn<V>(
320 self,
321 slices: &[Slice],
322 values: V,
323 ) -> Result<Self, DynTensorError>
324 where
325 V: ValuesArg<B>,
326 {
327 struct SliceAssignDynHandler<'a, B: Backend> {
328 this: DynTensor<B>,
329 slices: &'a [Slice],
330 values: DynTensor<B>,
331 }
332 impl<'a, B: Backend> RankHandler for SliceAssignDynHandler<'a, B> {
333 type Output = DynTensor<B>;
334 fn call<const R: usize>(self) -> Result<Self::Output, DynTensorError> {
335 let slices: [Slice; R] = self.slices.try_into().unwrap();
336 self.this.slice_assign(slices, self.values)
337 }
338 }
339 let values = values.into_values(&self.device())?;
340 rank_dispatch::dispatch_rank(
341 self.rank(),
342 SliceAssignDynHandler {
343 this: self,
344 slices,
345 values,
346 },
347 )
348 }
349
350 pub fn flatten(self) -> Result<Self, DynTensorError> {
358 struct FlattenHandler<B: Backend> {
359 tensor: DynTensor<B>,
360 }
361 impl<B: Backend> RankHandler for FlattenHandler<B> {
362 type Output = DynTensor<B>;
363 fn call<const R: usize>(self) -> Result<Self::Output, DynTensorError> {
364 Ok(match self.tensor.kind {
365 KindFlag::Float => self
366 .tensor
367 .unwrap_clone::<R, Float>()
368 .flatten::<1>(0, self.tensor.rank() - 1)
369 .into(),
370 KindFlag::Int => self
371 .tensor
372 .unwrap_clone::<R, Int>()
373 .flatten::<1>(0, self.tensor.rank() - 1)
374 .into(),
375 KindFlag::Bool => self
376 .tensor
377 .unwrap_clone::<R, Bool>()
378 .flatten::<1>(0, self.tensor.rank() - 1)
379 .into(),
380 })
381 }
382 }
383 rank_dispatch::dispatch_rank(self.rank(), FlattenHandler { tensor: self })
384 }
385
386 pub fn cast(
399 self,
400 dtype: DType,
401 ) -> Result<Self, DynTensorError> {
402 struct CastHandler<B: Backend> {
403 this: DynTensor<B>,
404 dtype: DType,
405 }
406 impl<B: Backend> RankHandler for CastHandler<B> {
407 type Output = DynTensor<B>;
408 fn call<const R: usize>(self) -> Result<Self::Output, DynTensorError> {
409 let target_kind: KindFlag = self.dtype.into();
410 Ok(match self.this.kind {
411 KindFlag::Float => {
412 let tensor: Tensor<B, R, Float> = self.this.unwrap_clone();
413 match target_kind {
414 KindFlag::Float => tensor.cast(self.dtype).into(),
415 KindFlag::Int => tensor.int().cast(self.dtype).into(),
416 KindFlag::Bool => tensor.bool().into(),
417 }
418 }
419 KindFlag::Int => {
420 let tensor: Tensor<B, R, Int> = self.this.unwrap_clone();
421 match target_kind {
422 KindFlag::Float => tensor.float().cast(self.dtype).into(),
423 KindFlag::Int => tensor.cast(self.dtype).into(),
424 KindFlag::Bool => tensor.bool().into(),
425 }
426 }
427 KindFlag::Bool => {
428 let tensor: Tensor<B, R, Bool> = self.this.unwrap_clone();
429 match target_kind {
430 KindFlag::Float => tensor.float().cast(self.dtype).into(),
431 KindFlag::Int => tensor.int().cast(self.dtype).into(),
432 KindFlag::Bool => self.this,
433 }
434 }
435 })
436 }
437 }
438 rank_dispatch::dispatch_rank(self.rank(), CastHandler { this: self, dtype })
439 }
440
441 pub fn to_device(
454 self,
455 device: &B::Device,
456 ) -> Result<Self, DynTensorError> {
457 if &self.device() == device {
458 return Ok(self);
459 }
460
461 struct ToDeviceHandler<'a, B: Backend> {
462 this: DynTensor<B>,
463 device: &'a B::Device,
464 }
465 impl<'a, B: Backend> RankHandler for ToDeviceHandler<'a, B> {
466 type Output = DynTensor<B>;
467 fn call<const R: usize>(self) -> Result<Self::Output, DynTensorError> {
468 Ok(match self.this.kind {
469 KindFlag::Float => self
470 .this
471 .unwrap_clone::<R, Float>()
472 .to_device(self.device)
473 .into(),
474 KindFlag::Int => self
475 .this
476 .unwrap_clone::<R, Int>()
477 .to_device(self.device)
478 .into(),
479 KindFlag::Bool => self
480 .this
481 .unwrap_clone::<R, Bool>()
482 .to_device(self.device)
483 .into(),
484 })
485 }
486 }
487 rank_dispatch::dispatch_rank(self.rank(), ToDeviceHandler { this: self, device })
488 }
489
490 pub fn from_data(
502 data: TensorData,
503 device: &B::Device,
504 ) -> Result<Self, DynTensorError> {
505 struct FromDataHandler<'a, B: Backend> {
506 data: TensorData,
507 device: &'a B::Device,
508 }
509 impl<'a, B: Backend> RankHandler for FromDataHandler<'a, B> {
510 type Output = DynTensor<B>;
511 fn call<const R: usize>(self) -> Result<Self::Output, DynTensorError> {
512 let kind: KindFlag = self.data.dtype.into();
513 Ok(match kind {
514 KindFlag::Float => {
515 Tensor::<B, R, Float>::from_data(self.data, self.device).into()
516 }
517 KindFlag::Int => Tensor::<B, R, Int>::from_data(self.data, self.device).into(),
518 KindFlag::Bool => {
519 Tensor::<B, R, Bool>::from_data(self.data, self.device).into()
520 }
521 })
522 }
523 }
524 rank_dispatch::dispatch_rank(data.rank(), FromDataHandler { data, device })
525 }
526
527 pub fn into_data(self) -> Result<TensorData, DynTensorError> {
535 struct ToDataHandler<B: Backend> {
536 this: DynTensor<B>,
537 }
538 impl<B: Backend> RankHandler for ToDataHandler<B> {
539 type Output = TensorData;
540 fn call<const R: usize>(self) -> Result<Self::Output, DynTensorError> {
541 Ok(match self.this.kind {
542 KindFlag::Float => self.this.unwrap_clone::<R, Float>().into_data(),
543 KindFlag::Int => self.this.unwrap_clone::<R, Int>().into_data(),
544 KindFlag::Bool => self.this.unwrap_clone::<R, Bool>().into_data(),
545 })
546 }
547 }
548 rank_dispatch::dispatch_rank(self.rank(), ToDataHandler { this: self })
549 }
550
551 pub fn to_data(self) -> Result<TensorData, DynTensorError> {
559 self.clone().into_data()
560 }
561}
562
#[cfg(test)]
mod tests {
    use super::*;
    use burn::backend::Wgpu;
    use burn::prelude::s;
    use burn::tensor::Distribution;

    /// Compile-time probe: fails to build if `T` is not `Send`.
    fn assert_send<T: Send>() {}

    #[test]
    fn test_send() {
        type B = Wgpu;
        assert_send::<DynTensor<B>>();
    }

    #[test]
    fn test_stub_float() {
        type B = Wgpu;
        let device = Default::default();

        let src: Tensor<B, 2> = Tensor::random([2, 3], Distribution::Default, &device);

        let dt = DynTensor::new(src.clone());

        // Metadata is captured at wrap time.
        assert_eq!(dt.rank(), 2);
        assert_eq!(dt.shape(), src.shape());
        assert_eq!(dt.num_elements(), 6);

        assert_eq!(dt.dtype(), src.dtype());
        assert_eq!(dt.size_estimate(), dt.num_elements() * src.dtype().size());

        assert_eq!(dt.kind(), KindFlag::Float);

        assert_eq!(dt.device(), device);

        // Downcasting with the wrong kind or rank must fail.
        assert!(dt.downcast_clone::<2, Int>().is_none());
        assert!(dt.downcast_clone::<2, Bool>().is_none());

        assert!(dt.downcast_clone::<3, Float>().is_none());

        // The matching kind and rank round-trip the data.
        let recovered = dt.downcast_clone::<2, Float>().unwrap();
        recovered.to_data().assert_eq(&src.clone().to_data(), true);

        dt.clone()
            .into_data()
            .unwrap()
            .assert_eq(&src.clone().to_data(), true);

        // Flattening yields a rank-1 tensor with identical contents.
        let flat = dt.clone().flatten().unwrap();
        assert_eq!(flat.shape(), [6].into());
        flat.into_data()
            .unwrap()
            .assert_eq(&src.clone().flatten::<1>(0, 1).to_data(), true);
    }

    #[test]
    fn test_stub_int() {
        type B = Wgpu;
        let device = Default::default();

        let src: Tensor<B, 2> = Tensor::random([2, 3], Distribution::Default, &device);
        let src = src.int();

        let dt = DynTensor::new(src.clone());

        // Metadata is captured at wrap time.
        assert_eq!(dt.rank(), 2);
        assert_eq!(dt.shape(), src.shape());
        assert_eq!(dt.num_elements(), 6);

        assert_eq!(dt.dtype(), src.dtype());
        assert_eq!(dt.size_estimate(), dt.num_elements() * src.dtype().size());

        assert_eq!(dt.kind(), KindFlag::Int);

        assert_eq!(dt.device(), device);

        // Downcasting with the wrong kind or rank must fail.
        assert!(dt.downcast_clone::<2, Float>().is_none());
        assert!(dt.downcast_clone::<2, Bool>().is_none());

        assert!(dt.downcast_clone::<3, Int>().is_none());

        // The matching kind and rank round-trip the data.
        let recovered = dt.downcast_clone::<2, Int>().unwrap();
        recovered.to_data().assert_eq(&src.clone().to_data(), true);

        dt.clone()
            .into_data()
            .unwrap()
            .assert_eq(&src.clone().to_data(), true);

        // Flattening yields a rank-1 tensor with identical contents.
        let flat = dt.clone().flatten().unwrap();
        assert_eq!(flat.shape(), [6].into());
        flat.into_data()
            .unwrap()
            .assert_eq(&src.clone().flatten::<1>(0, 1).to_data(), true);
    }

    #[test]
    fn test_stub_bool() {
        type B = Wgpu;
        let device = Default::default();

        let src: Tensor<B, 2> = Tensor::random([2, 3], Distribution::Bernoulli(0.5), &device);
        let src = src.bool();

        let dt = DynTensor::new(src.clone());

        // Metadata is captured at wrap time.
        assert_eq!(dt.rank(), 2);
        assert_eq!(dt.shape(), src.shape());
        assert_eq!(dt.num_elements(), 6);

        assert_eq!(dt.dtype(), src.dtype());
        assert_eq!(dt.size_estimate(), dt.num_elements() * src.dtype().size());

        assert_eq!(dt.kind(), KindFlag::Bool);

        assert_eq!(dt.device(), device);

        // Downcasting with the wrong kind or rank must fail.
        assert!(dt.downcast_clone::<2, Int>().is_none());
        assert!(dt.downcast_clone::<2, Float>().is_none());

        assert!(dt.downcast_clone::<3, Bool>().is_none());

        // The matching kind and rank round-trip the data.
        let recovered = dt.downcast_clone::<2, Bool>().unwrap();
        recovered.to_data().assert_eq(&src.clone().to_data(), true);

        dt.clone()
            .into_data()
            .unwrap()
            .assert_eq(&src.clone().to_data(), true);

        // Flattening yields a rank-1 tensor with identical contents.
        let flat = dt.clone().flatten().unwrap();
        assert_eq!(flat.shape(), [6].into());
        flat.into_data()
            .unwrap()
            .assert_eq(&src.clone().flatten::<1>(0, 1).to_data(), true);
    }

    #[test]
    fn test_clone() {
        type B = Wgpu;
        let device = Default::default();

        let src: Tensor<B, 2> = Tensor::random([2, 3], Distribution::Default, &device);

        let dt = DynTensor::new(src.clone());

        // A cloned wrapper behaves exactly like the original.
        let dt_copy = dt.clone();

        assert!(dt_copy.downcast_clone::<3, Float>().is_none());
        assert!(dt_copy.downcast_clone::<2, Int>().is_none());
        let recovered = dt_copy.downcast_clone::<2, Float>().unwrap();
        recovered.to_data().assert_eq(&src.clone().to_data(), true);
    }

    #[test]
    fn test_slice() {
        type B = Wgpu;
        let device = Default::default();

        let src: Tensor<B, 2> = Tensor::random([2, 3], Distribution::Default, &device);

        let dt = DynTensor::new(src.clone());

        // Slicing through the wrapper matches slicing the typed tensor.
        let sliced = dt.slice(s![.., 1..]).unwrap();
        assert_eq!(sliced.shape(), [2, 2].into());
        sliced
            .downcast_clone::<2, Float>()
            .unwrap()
            .to_data()
            .assert_eq(&src.clone().slice(s![.., 1..]).to_data(), true);
    }

    #[test]
    fn test_slice_dyn() {
        type B = Wgpu;
        let device = Default::default();

        let src: Tensor<B, 2> = Tensor::random([2, 3], Distribution::Default, &device);

        let dt = DynTensor::new(src.clone());

        // Runtime slice lists match the equivalent static slice spec.
        let sliced = dt
            .slice_dyn(&[Slice::new(0, None, 1), Slice::new(1, None, 1)])
            .unwrap();
        assert_eq!(sliced.shape(), [2, 2].into());
        sliced
            .downcast_clone::<2, Float>()
            .unwrap()
            .to_data()
            .assert_eq(&src.clone().slice(s![.., 1..]).to_data(), true);
    }
}