Skip to main content

lumen_core/dynamic/
integer.rs

1
2use crate::{DType, Dim, Layout, Shape, Tensor, TensorId};
3
4#[derive(Clone)]
5pub enum IntTensor {
6    U8(Tensor<u8>),
7    I32(Tensor<i32>),
8    U32(Tensor<u32>),
9}
10
11impl IntTensor {
12    pub fn id(&self) -> TensorId {
13        match self {
14            Self::U8(t) => t.id(),
15            Self::I32(t) => t.id(),
16            Self::U32(t) => t.id(),
17        }
18    }
19
20    pub fn shape(&self) -> &Shape {
21        match self {
22            Self::U8(t) => t.shape(),
23            Self::I32(t) => t.shape(),
24            Self::U32(t) => t.shape(),
25        }
26    }
27
28    pub fn dtype(&self) -> DType {
29        match self {
30            Self::U8(t) => t.dtype(),
31            Self::I32(t) => t.dtype(),
32            Self::U32(t) => t.dtype(),
33        }
34    }
35
36    pub fn layout(&self) -> &Layout {
37        match self {
38            Self::U8(t) => t.layout(),
39            Self::I32(t) => t.layout(),
40            Self::U32(t) => t.layout(),
41        }
42    }
43
44    pub fn dims(&self) -> &[usize] {
45        match self {
46            Self::U8(t) => t.dims(),
47            Self::I32(t) => t.dims(),
48            Self::U32(t) => t.dims(),
49        }
50    }
51
52    pub fn dim<D: Dim>(&self, dim: D) -> crate::Result<usize> {
53        match self {
54            Self::U8(t) => t.dim(dim),
55            Self::I32(t) => t.dim(dim),
56            Self::U32(t) => t.dim(dim),
57        }
58    }
59
60    pub fn element_count(&self) -> usize {
61        match self {
62            Self::U8(t) => t.element_count(),
63            Self::I32(t) => t.element_count(),
64            Self::U32(t) => t.element_count(),
65        }
66    }
67
68    pub fn is_contiguous(&self) -> bool {
69        match self {
70            Self::U8(t) => t.is_contiguous(),
71            Self::I32(t) => t.is_contiguous(),
72            Self::U32(t) => t.is_contiguous(),
73        }
74    }
75
76    pub fn rank(&self) -> usize {
77        match self {
78            Self::U8(t) => t.rank(),
79            Self::I32(t) => t.rank(),
80            Self::U32(t) => t.rank(),
81        }
82    }
83
84    pub fn flatten_all(&self) -> crate::Result<Self> {
85        match self {
86            Self::U8(t) => t.flatten_all().map(Self::U8),
87            Self::I32(t) => t.flatten_all().map(Self::I32),
88            Self::U32(t) => t.flatten_all().map(Self::U32),
89        }
90    }
91}
92
93impl From<Tensor<u8>> for IntTensor {
94    fn from(value: Tensor<u8>) -> Self {
95        Self::U8(value)
96    }
97}
98
99impl From<&Tensor<u8>> for IntTensor {
100    fn from(value: &Tensor<u8>) -> Self {
101        Self::U8(value.clone())
102    }
103}
104
105impl From<Tensor<i32>> for IntTensor {
106    fn from(value: Tensor<i32>) -> Self {
107        Self::I32(value)
108    }
109}
110
111impl From<&Tensor<i32>> for IntTensor {
112    fn from(value: &Tensor<i32>) -> Self {
113        Self::I32(value.clone())
114    }
115}
116
117impl From<Tensor<u32>> for IntTensor {
118    fn from(value: Tensor<u32>) -> Self {
119        Self::U32(value)
120    }
121}
122
123impl From<&Tensor<u32>> for IntTensor {
124    fn from(value: &Tensor<u32>) -> Self {
125        Self::U32(value.clone())
126    }
127}