lumen_core/dynamic/
mod.rs1mod float;
2mod integer;
3pub use float::*;
4pub use integer::*;
5use crate::{DType, Shape, Tensor, WithDType};
6use paste::paste;
7
8#[derive(Clone)]
9pub enum DynTensor {
10 Bool(Tensor<bool>),
11 F32(Tensor<f32>),
12 F64(Tensor<f64>),
13 I32(Tensor<i32>),
14 U32(Tensor<u32>),
15 U8(Tensor<u8>),
16}
17
18impl DynTensor {
19 pub fn dtype(&self) -> DType {
20 match self {
21 Self::Bool(_) => DType::Bool,
22 Self::F32(_) => DType::F32,
23 Self::F64(_) => DType::F64,
24 Self::U8(_) => DType::U8,
25 Self::I32(_) => DType::I32,
26 Self::U32(_) => DType::U32,
27 }
28 }
29
30 pub fn shape(&self) -> &Shape {
31 match self {
32 Self::Bool(t) => t.shape(),
33 Self::F32(t) => t.shape(),
34 Self::F64(t) => t.shape(),
35 Self::U8(t) => t.shape(),
36 Self::I32(t) => t.shape(),
37 Self::U32(t) => t.shape(),
38 }
39 }
40
41 pub fn as_tensor<T: WithDType>(&self) -> crate::Result<Tensor<T>> {
42 T::from_dyn(self)
43 }
44}
45
/// Generates the conversion surface between `DynTensor` and one concrete `Tensor<elem>`.
///
/// `impl_convert_with_type!(Variant, elem)` expands to:
/// - `DynTensor::is_<elem>(&self) -> bool`: whether `self` is `DynTensor::Variant`;
/// - `DynTensor::as_<elem>(&self) -> Option<Tensor<elem>>`: a clone of the inner
///   tensor on a match, `None` otherwise;
/// - `From<Tensor<elem>>` and `From<&Tensor<elem>>` for `DynTensor` (the latter clones);
/// - `TryFrom<DynTensor>` for `Tensor<elem>`, which fails with
///   `Error::UnexpectedDType` on a variant mismatch.
///
/// `Variant` must name both a `DynTensor` variant and a `DType` variant. `elem`
/// must be usable as an identifier fragment by `paste!` (a primitive type
/// name such as `f32`).
macro_rules! impl_convert_with_type {
    ($variant:ident, $inner:ty) => {
        paste! {
            impl DynTensor {
                /// Returns `true` when this tensor holds the element type named by this method.
                pub fn [< is_ $inner >](&self) -> bool {
                    matches!(self, Self::$variant(_))
                }

                /// Returns a clone of the inner tensor when the element type named by this
                /// method matches, and `None` otherwise.
                pub fn [< as_ $inner >](&self) -> Option<Tensor<$inner>> {
                    if let Self::$variant(t) = self {
                        Some(t.clone())
                    } else {
                        None
                    }
                }
            }

            impl From<Tensor<$inner>> for DynTensor {
                fn from(t: Tensor<$inner>) -> Self {
                    DynTensor::$variant(t)
                }
            }

            // Borrowed conversion: clones the tensor so the caller keeps ownership.
            impl From<&Tensor<$inner>> for DynTensor {
                fn from(t: &Tensor<$inner>) -> Self {
                    DynTensor::$variant(t.clone())
                }
            }

            impl TryFrom<DynTensor> for Tensor<$inner> {
                type Error = crate::Error;

                /// Unwraps the matching variant. Any other variant is rejected with
                /// `UnexpectedDType`, which reports both the expected and the actual dtype.
                fn try_from(value: DynTensor) -> Result<Self, Self::Error> {
                    match value {
                        DynTensor::$variant(t) => Ok(t),
                        other => Err(crate::Error::UnexpectedDType {
                            msg: "in dyn as",
                            expected: DType::$variant,
                            got: other.dtype(),
                        }),
                    }
                }
            }
        }
    };
}

91impl_convert_with_type!(Bool, bool);
92impl_convert_with_type!(F32, f32);
93impl_convert_with_type!(F64, f64);
94impl_convert_with_type!(I32, i32);
95impl_convert_with_type!(U32, u32);
96impl_convert_with_type!(U8, u8);
97