burn_rmexp_dyntensor/
kind.rs1use burn::Tensor;
2use burn::prelude::{Backend, Bool, Float, Int};
3use burn::tensor::{DType, TensorKind};
4use serde::{Deserialize, Serialize};
5use std::any::Any;
6
/// Error returned by [`KindFlag::kind`] when a tensor's kind parameter is not
/// one of `Float`, `Int` or `Bool`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct KindError {
    /// Human-readable description of why the kind could not be determined.
    pub msg: String,
}

// `Display` + `Error` let `KindError` flow through `?` into boxed/generic error
// types (e.g. `Box<dyn std::error::Error>`) instead of only being matchable.
impl std::fmt::Display for KindError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.msg)
    }
}

impl std::error::Error for KindError {}
11
12#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
14pub enum KindFlag {
15 Float,
16 Int,
17 Bool,
18}
19
20impl KindFlag {
21 pub fn kind<B: Backend, const R: usize, K: TensorKind<B> + 'static>(
23 tensor: &Tensor<B, R, K>
24 ) -> Result<Self, KindError> {
25 let any: &dyn Any = tensor;
26
27 if any.downcast_ref::<Tensor<B, R, Float>>().is_some() {
28 Ok(Self::Float)
29 } else if any.downcast_ref::<Tensor<B, R, Int>>().is_some() {
30 Ok(Self::Int)
31 } else if any.downcast_ref::<Tensor<B, R, Bool>>().is_some() {
32 Ok(Self::Bool)
33 } else {
34 Err(KindError {
35 msg: format!("Unsupported tensor kind: {:?}", K::name()),
36 })
37 }
38 }
39}
40
41impl From<DType> for KindFlag {
42 fn from(val: DType) -> Self {
43 if val.is_float() {
44 KindFlag::Float
45 } else if val.is_int() {
46 KindFlag::Int
47 } else {
48 KindFlag::Bool
49 }
50 }
51}
52
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks `KindFlag::kind` for each supported tensor kind.
    ///
    /// NOTE(review): this instantiates the Wgpu backend, so it presumably needs
    /// a usable GPU adapter at test time — confirm in CI.
    #[test]
    fn test_kind() {
        type B = burn::backend::Wgpu;
        let device = Default::default();

        assert_eq!(
            KindFlag::kind(&Tensor::<B, 2, Float>::ones([2, 3], &device)).unwrap(),
            KindFlag::Float
        );
        assert_eq!(
            KindFlag::kind(&Tensor::<B, 2, Int>::ones([2, 3], &device)).unwrap(),
            KindFlag::Int
        );
        assert_eq!(
            KindFlag::kind(&Tensor::<B, 2, Bool>::ones([2, 3], &device)).unwrap(),
            KindFlag::Bool
        );
    }

    /// Checks the `From<DType>` mapping; needs no backend or device.
    #[test]
    fn test_from_dtype() {
        assert_eq!(KindFlag::from(DType::F32), KindFlag::Float);
        assert_eq!(KindFlag::from(DType::F64), KindFlag::Float);
        assert_eq!(KindFlag::from(DType::I32), KindFlag::Int);
        assert_eq!(KindFlag::from(DType::I64), KindFlag::Int);
        assert_eq!(KindFlag::from(DType::Bool), KindFlag::Bool);
    }
}