burn_rmexp_dyntensor/
dyn_tensor.rs

1use crate::clone_box::CloneBox;
2use crate::errors::DynTensorError;
3use crate::kind::KindFlag;
4use crate::operations;
5use crate::rank_dispatch::RankHandler;
6use crate::{indexing, rank_dispatch};
7use burn::Tensor;
8use burn::prelude::{Backend, Bool, Float, Int, Shape, SliceArg, TensorData};
9use burn::tensor::{BasicOps, DType, Slice};
10
/// Values conversion trait for [`DynTensor::slice_assign`].
///
/// Implemented for anything convertible into a [`DynTensor`] (see the
/// blanket impl below) and for raw [`TensorData`].
pub trait ValuesArg<B: Backend>: Sized {
    /// Convert to a [`DynTensor`] on a given device.
    ///
    /// # Arguments
    /// - `device`: target device for the resulting tensor.
    ///
    /// # Result
    /// - `Ok(DynTensor)`: the converted tensor, on `device`.
    /// - `Err(DynTensorError)`: conversion or transfer failure.
    fn into_values(
        self,
        device: &B::Device,
    ) -> Result<DynTensor<B>, DynTensorError>;
}
19
20impl<B: Backend, T: Into<DynTensor<B>>> ValuesArg<B> for T {
21    fn into_values(
22        self,
23        device: &B::Device,
24    ) -> Result<DynTensor<B>, DynTensorError> {
25        self.into().to_device(device)
26    }
27}
28
29impl<B: Backend> ValuesArg<B> for TensorData {
30    fn into_values(
31        self,
32        device: &B::Device,
33    ) -> Result<DynTensor<B>, DynTensorError> {
34        DynTensor::from_data(self, device)
35    }
36}
37
/// A dynamic [`Tensor`] wrapper that can be sliced.
#[derive(Debug, Clone)]
pub struct DynTensor<B: Backend> {
    // Cached shape of the wrapped tensor.
    shape: Shape,
    // Cached element data type.
    dtype: DType,
    // Kind flag (Float/Int/Bool) derived from `dtype`; drives dispatch.
    kind: KindFlag,
    // Cached device of the wrapped tensor.
    device: B::Device,
    // Type-erased `Tensor<B, R, K>`; recovered via `downcast_clone`.
    tensor: Box<dyn CloneBox>,
    // Ties the erased tensor to backend `B` at the type level.
    phantom: std::marker::PhantomData<B>,
}
48
49impl<B: Backend, const R: usize, K> From<Tensor<B, R, K>> for DynTensor<B>
50where
51    K: 'static + BasicOps<B>,
52{
53    fn from(val: Tensor<B, R, K>) -> Self {
54        DynTensor::new(val)
55    }
56}
57
58impl<B: Backend> DynTensor<B> {
    /// Create a new `DynTensor` from a static tensor.
    ///
    /// Caches the shape, dtype, kind, and device so they can be queried
    /// later without downcasting the type-erased tensor.
    pub fn new<const R: usize, K>(tensor: Tensor<B, R, K>) -> Self
    where
        K: BasicOps<B> + 'static,
    {
        Self {
            shape: tensor.shape(),
            dtype: tensor.dtype(),
            // Kind is fully determined by the dtype.
            kind: tensor.dtype().into(),
            device: tensor.device(),
            // Type-erase the tensor; `R` and `K` are recovered on downcast.
            tensor: Box::new(tensor),
            phantom: std::marker::PhantomData,
        }
    }
73
    /// Get the tensor rank.
    pub fn rank(&self) -> usize {
        self.shape.rank()
    }

    /// Get the tensor shape.
    ///
    /// Returns a clone of the cached shape.
    pub fn shape(&self) -> Shape {
        self.shape.clone()
    }

    /// Get the number of elements in the tensor.
    pub fn num_elements(&self) -> usize {
        self.shape.num_elements()
    }

    /// Returns the size estimate of the tensor in bytes.
    ///
    /// This is `self.dtype().size() * self.num_elements()`.
    pub fn size_estimate(&self) -> usize {
        self.dtype.size() * self.num_elements()
    }

    /// Get the tensor data type.
    pub fn dtype(&self) -> DType {
        self.dtype
    }

    /// Get the tensor kind.
    pub fn kind(&self) -> KindFlag {
        self.kind
    }

    /// Get the tensor device.
    ///
    /// Returns a clone of the cached device handle.
    pub fn device(&self) -> B::Device {
        self.device.clone()
    }
110
111    /// Downcasts the tensor to a specific rank and kind.
112    ///
113    /// # Result
114    /// - `Some(Tensor<B, R, K>)`: if the params are correct,
115    /// - `None`: otherwise.
116    pub fn downcast_clone<const R: usize, K>(&self) -> Option<Tensor<B, R, K>>
117    where
118        K: 'static + BasicOps<B>,
119    {
120        self.tensor.downcast_ref::<Tensor<B, R, K>>().cloned()
121    }
122
123    /// Downcasts to a static tensor.
124    ///
125    /// # Result
126    /// - the static tensor: if the params are correct,
127    ///
128    /// # Panics
129    /// If the types are incorrect.
130    pub fn unwrap_clone<const R: usize, K>(&self) -> Tensor<B, R, K>
131    where
132        K: 'static + BasicOps<B>,
133    {
134        self.downcast_clone::<R, K>()
135            .expect("downcast_clone failed")
136    }
137
138    /// Slice the stub tensor.
139    ///
140    /// Dispatches via [`rank_dispatch::dispatch_rank`].
141    ///
142    /// # Arguments
143    /// - `slices`: a `SliceArg<R>`.
144    ///
145    /// # Result
146    /// - `Ok(DynTensor)`: the sliced tensor.
147    /// - `Err(DynTensorError)`: an error.
148    pub fn slice<const R: usize, S>(
149        self,
150        slices: S,
151    ) -> Result<Self, DynTensorError>
152    where
153        S: SliceArg<R>,
154    {
155        let rank = self.rank();
156        let slices = self.shape().into_slices(slices);
157
158        indexing::check_slices_bounds(&self.shape(), &slices)
159            .map_err(DynTensorError::SliceError)?;
160
161        struct SliceHandler<B: Backend, const R: usize> {
162            this: DynTensor<B>,
163            slices: [Slice; R],
164        }
165        impl<B: Backend, const R: usize> RankHandler for SliceHandler<B, R> {
166            type Output = DynTensor<B>;
167            fn call<const R2: usize>(self) -> Result<Self::Output, DynTensorError> {
168                Ok(match self.this.kind {
169                    KindFlag::Float => self
170                        .this
171                        .unwrap_clone::<R, Float>()
172                        .slice(self.slices)
173                        .into(),
174                    KindFlag::Int => self.this.unwrap_clone::<R, Int>().slice(self.slices).into(),
175                    KindFlag::Bool => self
176                        .this
177                        .unwrap_clone::<R, Bool>()
178                        .slice(self.slices)
179                        .into(),
180                })
181            }
182        }
183        rank_dispatch::dispatch_rank(rank, SliceHandler { this: self, slices })
184    }
185
186    /// A dynamic version of [`DynTensor::slice`].
187    ///
188    /// Dispatches via [`rank_dispatch::dispatch_rank`].
189    ///
190    /// # Arguments
191    /// - `slices`: a dynamic slice of `Slice`.
192    ///
193    /// # Result
194    /// - `Ok(DynTensor)`: the sliced tensor.
195    /// - `Err(DynTensorError)`: an error.
196    pub fn slice_dyn(
197        self,
198        slices: &[Slice],
199    ) -> Result<Self, DynTensorError> {
200        let rank = self.rank();
201
202        indexing::check_slices_bounds(&self.shape(), slices).map_err(DynTensorError::SliceError)?;
203
204        struct SliceDynHandler<'a, B: Backend> {
205            this: DynTensor<B>,
206            slices: &'a [Slice],
207        }
208        impl<'a, B: Backend> RankHandler for SliceDynHandler<'a, B> {
209            type Output = DynTensor<B>;
210            fn call<const R: usize>(self) -> Result<Self::Output, DynTensorError> {
211                Ok(match self.this.kind {
212                    KindFlag::Float => {
213                        operations::slice_dyn(self.this.unwrap_clone::<R, Float>(), self.slices)
214                            .into()
215                    }
216                    KindFlag::Int => {
217                        operations::slice_dyn(self.this.unwrap_clone::<R, Int>(), self.slices)
218                            .into()
219                    }
220                    KindFlag::Bool => {
221                        operations::slice_dyn(self.this.unwrap_clone::<R, Bool>(), self.slices)
222                            .into()
223                    }
224                })
225            }
226        }
227        rank_dispatch::dispatch_rank(rank, SliceDynHandler { this: self, slices })
228    }
229
230    /// Assign values to a slice.
231    ///
232    /// Dispatches via [`rank_dispatch::dispatch_rank`].
233    ///
234    /// # Arguments
235    /// - `slices`: a `SlicesArg<R2>`.
236    /// - `values`: a coercible value; see [`ValuesArg`].
237    ///
238    /// # Result
239    /// - `Ok(DynTensor)`: a converted tensor.
240    /// - `Err(DynTensorError)`: an error.
241    pub fn slice_assign<const R2: usize, S, V>(
242        self,
243        slices: S,
244        values: V,
245    ) -> Result<Self, DynTensorError>
246    where
247        S: SliceArg<R2>,
248        V: ValuesArg<B>,
249    {
250        let rank = self.rank();
251        let slices = self.shape().into_slices(slices);
252        let values: DynTensor<B> = values.into_values(&self.device())?;
253
254        indexing::check_slices_bounds(&self.shape(), &slices)
255            .map_err(DynTensorError::SliceError)?;
256
257        if rank != values.rank() {
258            return Err(DynTensorError::InvalidArgument {
259                msg: format!(
260                    "slice of rank ({}) cannot be assigned to tensor of rank ({})",
261                    values.rank(),
262                    rank
263                ),
264            });
265        }
266
267        let values = values.cast(self.dtype())?;
268
269        // TODO: check that slices shape == source.shape
270
271        struct SliceAssignHandler<B: Backend, const R2: usize> {
272            this: DynTensor<B>,
273            slices: [Slice; R2],
274            values: DynTensor<B>,
275        }
276        impl<B: Backend, const R2: usize> RankHandler for SliceAssignHandler<B, R2> {
277            type Output = DynTensor<B>;
278            fn call<const R: usize>(self) -> Result<Self::Output, DynTensorError> {
279                Ok(match self.this.kind {
280                    KindFlag::Float => self
281                        .this
282                        .unwrap_clone::<R, Float>()
283                        .slice_assign(self.slices, self.values.unwrap_clone())
284                        .into(),
285                    KindFlag::Int => self
286                        .this
287                        .unwrap_clone::<R, Int>()
288                        .slice_assign(self.slices, self.values.unwrap_clone())
289                        .into(),
290                    KindFlag::Bool => self
291                        .this
292                        .unwrap_clone::<R, Bool>()
293                        .slice_assign(self.slices, self.values.unwrap_clone())
294                        .into(),
295                })
296            }
297        }
298        rank_dispatch::dispatch_rank(
299            rank,
300            SliceAssignHandler {
301                this: self.clone(),
302                slices,
303                values,
304            },
305        )
306    }
307
308    /// Dynamic slice rank version of [`DynTensor::slice_assign`].
309    ///
310    /// Dispatches via [`rank_dispatch::dispatch_rank`].
311    ///
312    /// # Arguments
313    /// - `slices`: a dynamic slice of `Slice`.
314    /// - `values`: a coercible value; see [`ValuesArg`].
315    ///
316    /// # Result
317    /// - `Ok(DynTensor)`: a converted tensor.
318    /// - `Err(DynTensorError)`: an error.
319    pub fn slice_assign_dyn<V>(
320        self,
321        slices: &[Slice],
322        values: V,
323    ) -> Result<Self, DynTensorError>
324    where
325        V: ValuesArg<B>,
326    {
327        struct SliceAssignDynHandler<'a, B: Backend> {
328            this: DynTensor<B>,
329            slices: &'a [Slice],
330            values: DynTensor<B>,
331        }
332        impl<'a, B: Backend> RankHandler for SliceAssignDynHandler<'a, B> {
333            type Output = DynTensor<B>;
334            fn call<const R: usize>(self) -> Result<Self::Output, DynTensorError> {
335                let slices: [Slice; R] = self.slices.try_into().unwrap();
336                self.this.slice_assign(slices, self.values)
337            }
338        }
339        let values = values.into_values(&self.device())?;
340        rank_dispatch::dispatch_rank(
341            self.rank(),
342            SliceAssignDynHandler {
343                this: self,
344                slices,
345                values,
346            },
347        )
348    }
349
350    /// Flatten the tensor.
351    ///
352    /// Dispatches via [`rank_dispatch::dispatch_rank`].
353    ///
354    /// # Result
355    /// - `Ok(DynTensor)`: a flattened (rank=1) tensor.
356    /// - `Err(DynTensorError)`: an error.
357    pub fn flatten(self) -> Result<Self, DynTensorError> {
358        struct FlattenHandler<B: Backend> {
359            tensor: DynTensor<B>,
360        }
361        impl<B: Backend> RankHandler for FlattenHandler<B> {
362            type Output = DynTensor<B>;
363            fn call<const R: usize>(self) -> Result<Self::Output, DynTensorError> {
364                Ok(match self.tensor.kind {
365                    KindFlag::Float => self
366                        .tensor
367                        .unwrap_clone::<R, Float>()
368                        .flatten::<1>(0, self.tensor.rank() - 1)
369                        .into(),
370                    KindFlag::Int => self
371                        .tensor
372                        .unwrap_clone::<R, Int>()
373                        .flatten::<1>(0, self.tensor.rank() - 1)
374                        .into(),
375                    KindFlag::Bool => self
376                        .tensor
377                        .unwrap_clone::<R, Bool>()
378                        .flatten::<1>(0, self.tensor.rank() - 1)
379                        .into(),
380                })
381            }
382        }
383        rank_dispatch::dispatch_rank(self.rank(), FlattenHandler { tensor: self })
384    }
385
386    /// Cast the tensor.
387    ///
388    /// Auto-converts kind if necessary.
389    ///
390    /// Dispatches via [`rank_dispatch::dispatch_rank`].
391    ///
392    /// # Arguments
393    /// - `dtype`: the target data type.
394    ///
395    /// # Result
396    /// - `Ok(DynTensor)`: a converted tensor.
397    /// - `Err(DynTensorError)`: an error.
398    pub fn cast(
399        self,
400        dtype: DType,
401    ) -> Result<Self, DynTensorError> {
402        struct CastHandler<B: Backend> {
403            this: DynTensor<B>,
404            dtype: DType,
405        }
406        impl<B: Backend> RankHandler for CastHandler<B> {
407            type Output = DynTensor<B>;
408            fn call<const R: usize>(self) -> Result<Self::Output, DynTensorError> {
409                let target_kind: KindFlag = self.dtype.into();
410                Ok(match self.this.kind {
411                    KindFlag::Float => {
412                        let tensor: Tensor<B, R, Float> = self.this.unwrap_clone();
413                        match target_kind {
414                            KindFlag::Float => tensor.cast(self.dtype).into(),
415                            KindFlag::Int => tensor.int().cast(self.dtype).into(),
416                            KindFlag::Bool => tensor.bool().into(),
417                        }
418                    }
419                    KindFlag::Int => {
420                        let tensor: Tensor<B, R, Int> = self.this.unwrap_clone();
421                        match target_kind {
422                            KindFlag::Float => tensor.float().cast(self.dtype).into(),
423                            KindFlag::Int => tensor.cast(self.dtype).into(),
424                            KindFlag::Bool => tensor.bool().into(),
425                        }
426                    }
427                    KindFlag::Bool => {
428                        let tensor: Tensor<B, R, Bool> = self.this.unwrap_clone();
429                        match target_kind {
430                            KindFlag::Float => tensor.float().cast(self.dtype).into(),
431                            KindFlag::Int => tensor.int().cast(self.dtype).into(),
432                            KindFlag::Bool => self.this,
433                        }
434                    }
435                })
436            }
437        }
438        rank_dispatch::dispatch_rank(self.rank(), CastHandler { this: self, dtype })
439    }
440
441    /// Move the tensor to the given device.
442    ///
443    /// Moving to the same device is an inexpensive no-op.
444    ///
445    /// Dispatches via [`rank_dispatch::dispatch_rank`].
446    ///
447    /// # Arguments
448    /// - `device`: the target device.
449    ///
450    /// # Result
451    /// - `Ok(DynTensor<B>)`: the moved tensor.
452    /// - `Err(DynTensorError)`: an error.
453    pub fn to_device(
454        self,
455        device: &B::Device,
456    ) -> Result<Self, DynTensorError> {
457        if &self.device() == device {
458            return Ok(self);
459        }
460
461        struct ToDeviceHandler<'a, B: Backend> {
462            this: DynTensor<B>,
463            device: &'a B::Device,
464        }
465        impl<'a, B: Backend> RankHandler for ToDeviceHandler<'a, B> {
466            type Output = DynTensor<B>;
467            fn call<const R: usize>(self) -> Result<Self::Output, DynTensorError> {
468                Ok(match self.this.kind {
469                    KindFlag::Float => self
470                        .this
471                        .unwrap_clone::<R, Float>()
472                        .to_device(self.device)
473                        .into(),
474                    KindFlag::Int => self
475                        .this
476                        .unwrap_clone::<R, Int>()
477                        .to_device(self.device)
478                        .into(),
479                    KindFlag::Bool => self
480                        .this
481                        .unwrap_clone::<R, Bool>()
482                        .to_device(self.device)
483                        .into(),
484                })
485            }
486        }
487        rank_dispatch::dispatch_rank(self.rank(), ToDeviceHandler { this: self, device })
488    }
489
490    /// Convert a [`TensorData`] to a [`DynTensor`].
491    ///
492    /// Dispatches via [`rank_dispatch::dispatch_rank`].
493    ///
494    /// # Arguments
495    /// - `data`: source [`TensorData`].
496    /// - `device`: the target device.
497    ///
498    /// # Result
499    /// - `Ok(DynTensor<B>)`: the converted tensor.
500    /// - `Err(DynTensorError)`: an error.
501    pub fn from_data(
502        data: TensorData,
503        device: &B::Device,
504    ) -> Result<Self, DynTensorError> {
505        struct FromDataHandler<'a, B: Backend> {
506            data: TensorData,
507            device: &'a B::Device,
508        }
509        impl<'a, B: Backend> RankHandler for FromDataHandler<'a, B> {
510            type Output = DynTensor<B>;
511            fn call<const R: usize>(self) -> Result<Self::Output, DynTensorError> {
512                let kind: KindFlag = self.data.dtype.into();
513                Ok(match kind {
514                    KindFlag::Float => {
515                        Tensor::<B, R, Float>::from_data(self.data, self.device).into()
516                    }
517                    KindFlag::Int => Tensor::<B, R, Int>::from_data(self.data, self.device).into(),
518                    KindFlag::Bool => {
519                        Tensor::<B, R, Bool>::from_data(self.data, self.device).into()
520                    }
521                })
522            }
523        }
524        rank_dispatch::dispatch_rank(data.rank(), FromDataHandler { data, device })
525    }
526
527    /// Convert the tensor to a [`TensorData`].
528    ///
529    /// Dispatches via [`rank_dispatch::dispatch_rank`].
530    ///
531    /// # Result
532    /// - `Ok(TensorData)`: the converted data.
533    /// - `Err(DynTensorError)`: an error.
534    pub fn into_data(self) -> Result<TensorData, DynTensorError> {
535        struct ToDataHandler<B: Backend> {
536            this: DynTensor<B>,
537        }
538        impl<B: Backend> RankHandler for ToDataHandler<B> {
539            type Output = TensorData;
540            fn call<const R: usize>(self) -> Result<Self::Output, DynTensorError> {
541                Ok(match self.this.kind {
542                    KindFlag::Float => self.this.unwrap_clone::<R, Float>().into_data(),
543                    KindFlag::Int => self.this.unwrap_clone::<R, Int>().into_data(),
544                    KindFlag::Bool => self.this.unwrap_clone::<R, Bool>().into_data(),
545                })
546            }
547        }
548        rank_dispatch::dispatch_rank(self.rank(), ToDataHandler { this: self })
549    }
550
551    /// Convert the tensor to a [`TensorData`].
552    ///
553    /// Dispatches via [`rank_dispatch::dispatch_rank`].
554    ///
555    /// # Result
556    /// - `Ok(TensorData)`: the converted data.
557    /// - `Err(DynTensorError)`: an error.
558    pub fn to_data(self) -> Result<TensorData, DynTensorError> {
559        self.clone().into_data()
560    }
561}
562
563#[cfg(test)]
564mod tests {
565    use super::*;
566    use burn::backend::Wgpu;
567    use burn::prelude::s;
568    use burn::tensor::Distribution;
569
570    fn assert_send<T: Send>() {}
571
572    #[test]
573    fn test_send() {
574        type B = Wgpu;
575        assert_send::<DynTensor<B>>();
576    }
577
578    #[test]
579    fn test_stub_float() {
580        type B = Wgpu;
581        let device = Default::default();
582
583        let source: Tensor<B, 2> = Tensor::random([2, 3], Distribution::Default, &device);
584
585        let stub = DynTensor::new(source.clone());
586
587        assert_eq!(stub.rank(), 2);
588        assert_eq!(stub.shape(), source.shape());
589        assert_eq!(stub.num_elements(), 6);
590
591        assert_eq!(stub.dtype(), source.dtype());
592        assert_eq!(
593            stub.size_estimate(),
594            stub.num_elements() * source.dtype().size()
595        );
596
597        assert_eq!(stub.kind(), KindFlag::Float);
598
599        assert_eq!(stub.device(), device);
600
601        assert!(stub.downcast_clone::<2, Int>().is_none());
602        assert!(stub.downcast_clone::<2, Bool>().is_none());
603
604        assert!(stub.downcast_clone::<3, Float>().is_none());
605
606        let clone = stub.downcast_clone::<2, Float>().unwrap();
607        clone.to_data().assert_eq(&source.clone().to_data(), true);
608
609        stub.clone()
610            .into_data()
611            .unwrap()
612            .assert_eq(&source.clone().to_data(), true);
613
614        let flatten = stub.clone().flatten().unwrap();
615        assert_eq!(flatten.shape(), [6].into());
616        flatten
617            .into_data()
618            .unwrap()
619            .assert_eq(&source.clone().flatten::<1>(0, 1).to_data(), true);
620    }
621
622    #[test]
623    fn test_stub_int() {
624        type B = Wgpu;
625        let device = Default::default();
626
627        let source: Tensor<B, 2> = Tensor::random([2, 3], Distribution::Default, &device);
628        let source = source.int();
629
630        let stub = DynTensor::new(source.clone());
631
632        assert_eq!(stub.rank(), 2);
633        assert_eq!(stub.shape(), source.shape());
634        assert_eq!(stub.num_elements(), 6);
635
636        assert_eq!(stub.dtype(), source.dtype());
637        assert_eq!(
638            stub.size_estimate(),
639            stub.num_elements() * source.dtype().size()
640        );
641
642        assert_eq!(stub.kind(), KindFlag::Int);
643
644        assert_eq!(stub.device(), device);
645
646        assert!(stub.downcast_clone::<2, Float>().is_none());
647        assert!(stub.downcast_clone::<2, Bool>().is_none());
648
649        assert!(stub.downcast_clone::<3, Int>().is_none());
650
651        let clone = stub.downcast_clone::<2, Int>().unwrap();
652        clone.to_data().assert_eq(&source.clone().to_data(), true);
653
654        stub.clone()
655            .into_data()
656            .unwrap()
657            .assert_eq(&source.clone().to_data(), true);
658
659        let flatten = stub.clone().flatten().unwrap();
660        assert_eq!(flatten.shape(), [6].into());
661        flatten
662            .into_data()
663            .unwrap()
664            .assert_eq(&source.clone().flatten::<1>(0, 1).to_data(), true);
665    }
666
667    #[test]
668    fn test_stub_bool() {
669        type B = Wgpu;
670        let device = Default::default();
671
672        let source: Tensor<B, 2> = Tensor::random([2, 3], Distribution::Bernoulli(0.5), &device);
673        let source = source.bool();
674
675        let stub = DynTensor::new(source.clone());
676
677        assert_eq!(stub.rank(), 2);
678        assert_eq!(stub.shape(), source.shape());
679        assert_eq!(stub.num_elements(), 6);
680
681        assert_eq!(stub.dtype(), source.dtype());
682        assert_eq!(
683            stub.size_estimate(),
684            stub.num_elements() * source.dtype().size()
685        );
686
687        assert_eq!(stub.kind(), KindFlag::Bool);
688
689        assert_eq!(stub.device(), device);
690
691        assert!(stub.downcast_clone::<2, Int>().is_none());
692        assert!(stub.downcast_clone::<2, Float>().is_none());
693
694        assert!(stub.downcast_clone::<3, Bool>().is_none());
695
696        let clone = stub.downcast_clone::<2, Bool>().unwrap();
697        clone.to_data().assert_eq(&source.clone().to_data(), true);
698
699        stub.clone()
700            .into_data()
701            .unwrap()
702            .assert_eq(&source.clone().to_data(), true);
703
704        let flatten = stub.clone().flatten().unwrap();
705        assert_eq!(flatten.shape(), [6].into());
706        flatten
707            .into_data()
708            .unwrap()
709            .assert_eq(&source.clone().flatten::<1>(0, 1).to_data(), true);
710    }
711
712    #[test]
713    fn test_clone() {
714        type B = Wgpu;
715        let device = Default::default();
716
717        let source: Tensor<B, 2> = Tensor::random([2, 3], Distribution::Default, &device);
718
719        let stub = DynTensor::new(source.clone());
720
721        let stub_clone = stub.clone();
722
723        assert!(stub_clone.downcast_clone::<3, Float>().is_none());
724        assert!(stub_clone.downcast_clone::<2, Int>().is_none());
725        let clone = stub_clone.downcast_clone::<2, Float>().unwrap();
726        clone.to_data().assert_eq(&source.clone().to_data(), true);
727    }
728
729    #[test]
730    fn test_slice() {
731        type B = Wgpu;
732        let device = Default::default();
733
734        let source: Tensor<B, 2> = Tensor::random([2, 3], Distribution::Default, &device);
735
736        let stub = DynTensor::new(source.clone());
737
738        let slice = stub.slice(s![.., 1..]).unwrap();
739        assert_eq!(slice.shape(), [2, 2].into());
740        slice
741            .downcast_clone::<2, Float>()
742            .unwrap()
743            .to_data()
744            .assert_eq(&source.clone().slice(s![.., 1..]).to_data(), true);
745    }
746
747    #[test]
748    fn test_slice_dyn() {
749        type B = Wgpu;
750        let device = Default::default();
751
752        let source: Tensor<B, 2> = Tensor::random([2, 3], Distribution::Default, &device);
753
754        let stub = DynTensor::new(source.clone());
755
756        let slice = stub
757            .slice_dyn(&vec![Slice::new(0, None, 1), Slice::new(1, None, 1)])
758            .unwrap();
759        assert_eq!(slice.shape(), [2, 2].into());
760        slice
761            .downcast_clone::<2, Float>()
762            .unwrap()
763            .to_data()
764            .assert_eq(&source.clone().slice(s![.., 1..]).to_data(), true);
765    }
766}