burn_rmexp_dyntensor/
kind.rs

1use burn::Tensor;
2use burn::prelude::{Backend, Bool, Float, Int};
3use burn::tensor::{DType, TensorKind};
4use serde::{Deserialize, Serialize};
5use std::any::Any;
6
/// Error returned when a tensor's kind cannot be classified as a `KindFlag`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct KindError {
    /// Human-readable description of why classification failed.
    pub msg: String,
}

// Displaying the error prints the message verbatim so it composes with `?`,
// `Box<dyn Error>` and error-reporting crates.
impl std::fmt::Display for KindError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.msg)
    }
}

impl std::error::Error for KindError {}
11
12/// A flag indicating the tensor kind.
13#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
14pub enum KindFlag {
15    Float,
16    Int,
17    Bool,
18}
19
20impl KindFlag {
21    /// Returns the kind of the given tensor.
22    pub fn kind<B: Backend, const R: usize, K: TensorKind<B> + 'static>(
23        tensor: &Tensor<B, R, K>
24    ) -> Result<Self, KindError> {
25        let any: &dyn Any = tensor;
26
27        if any.downcast_ref::<Tensor<B, R, Float>>().is_some() {
28            Ok(Self::Float)
29        } else if any.downcast_ref::<Tensor<B, R, Int>>().is_some() {
30            Ok(Self::Int)
31        } else if any.downcast_ref::<Tensor<B, R, Bool>>().is_some() {
32            Ok(Self::Bool)
33        } else {
34            Err(KindError {
35                msg: format!("Unsupported tensor kind: {:?}", K::name()),
36            })
37        }
38    }
39}
40
41impl From<DType> for KindFlag {
42    fn from(val: DType) -> Self {
43        if val.is_float() {
44            KindFlag::Float
45        } else if val.is_int() {
46            KindFlag::Int
47        } else {
48            KindFlag::Bool
49        }
50    }
51}
52
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_kind() {
        // NOTE(review): the Wgpu backend needs the `wgpu` feature and a usable
        // GPU adapter when tests run — confirm CI provides both.
        type B = burn::backend::Wgpu;
        let device = Default::default();

        assert_eq!(
            KindFlag::kind(&Tensor::<B, 2, Float>::ones([2, 3], &device)).unwrap(),
            KindFlag::Float
        );
        assert_eq!(
            KindFlag::kind(&Tensor::<B, 2, Int>::ones([2, 3], &device)).unwrap(),
            KindFlag::Int
        );
        assert_eq!(
            KindFlag::kind(&Tensor::<B, 2, Bool>::ones([2, 3], &device)).unwrap(),
            KindFlag::Bool
        );
    }

    /// `From<DType>` needs no backend or device, so it is checked directly.
    #[test]
    fn test_from_dtype() {
        assert_eq!(KindFlag::from(DType::F32), KindFlag::Float);
        assert_eq!(KindFlag::from(DType::F64), KindFlag::Float);
        assert_eq!(KindFlag::from(DType::I32), KindFlag::Int);
        assert_eq!(KindFlag::from(DType::I64), KindFlag::Int);
        assert_eq!(KindFlag::from(DType::Bool), KindFlag::Bool);
    }
}