Skip to main content

lumen_core/dynamic/
mod.rs

1mod float;
2mod integer;
3pub use float::*;
4pub use integer::*;
5use crate::{DType, Shape, Tensor, WithDType};
6use paste::paste;
7
8#[derive(Clone)]
9pub enum DynTensor {
10    Bool(Tensor<bool>),
11    F32(Tensor<f32>),
12    F64(Tensor<f64>),
13    I32(Tensor<i32>),
14    U32(Tensor<u32>),
15    U8(Tensor<u8>),
16}
17
18impl DynTensor {
19    pub fn dtype(&self) -> DType {
20        match self {
21            Self::Bool(_) => DType::Bool,
22            Self::F32(_) => DType::F32,
23            Self::F64(_) => DType::F64,
24            Self::U8(_) => DType::U8,
25            Self::I32(_) => DType::I32,
26            Self::U32(_) => DType::U32,
27        }
28    }
29
30    pub fn shape(&self) -> &Shape {
31        match self {
32            Self::Bool(t) => t.shape(),
33            Self::F32(t) => t.shape(),
34            Self::F64(t) => t.shape(),
35            Self::U8(t) => t.shape(),
36            Self::I32(t) => t.shape(),
37            Self::U32(t) => t.shape(),
38        }
39    }
40
41    pub fn as_tensor<T: WithDType>(&self) -> crate::Result<Tensor<T>> {
42        T::from_dyn(self)
43    }
44}
45
46macro_rules! impl_convert_with_type {
47    ($variant:ident, $inner:ty) => {
48        paste! {
49            impl DynTensor {
50                pub fn [< is_ $inner >](&self) -> bool {
51                    match self {
52                        Self::$variant(_) => true,
53                        _ => false,
54                    }
55                }
56
57                pub fn [< as_ $inner >](&self) -> Option<Tensor<$inner>> {
58                    match self {
59                        Self::$variant(t) => Some(t.clone()),
60                        _ => None,
61                    }
62                }
63            }
64    
65            impl From<Tensor<$inner>> for DynTensor {
66                fn from(t: Tensor<$inner>) -> Self {
67                    DynTensor::$variant(t)
68                }
69            }
70
71            impl From<&Tensor<$inner>> for DynTensor {
72                fn from(t: &Tensor<$inner>) -> Self {
73                    DynTensor::$variant(t.clone())
74                }
75            }
76    
77            impl TryFrom<DynTensor> for Tensor<$inner> {
78                type Error = crate::Error;
79    
80                fn try_from(value: DynTensor) -> Result<Self, Self::Error> {
81                    match value {
82                        DynTensor::$variant(t) => Ok(t),
83                        _ => Err(crate::Error::UnexpectedDType { msg: "in dyn as", expected: DType::$variant, got: value.dtype() }),
84                    }
85                }
86            }
87        }
88    };
89}
90
91impl_convert_with_type!(Bool, bool);
92impl_convert_with_type!(F32, f32);
93impl_convert_with_type!(F64, f64);
94impl_convert_with_type!(I32, i32);
95impl_convert_with_type!(U32, u32);
96impl_convert_with_type!(U8, u8);
97