Skip to main content

radiate_utils/datatype/
dtype.rs

1use super::Scalar;
2use crate::{Primitive, SmallStr};
3#[cfg(feature = "serde")]
4use serde::{Deserialize, Serialize};
5use std::fmt::Display;
6
/// Canonical lowercase names for the data types handled in this module.
///
/// Keeping the spellings in one place lets parsing and formatting code refer
/// to the same string for a given type.
pub mod dtype_names {
    // --- Null and boolean -------------------------------------------------

    /// Name of the null type.
    pub const NULL: &str = "null";
    /// Name of the boolean type.
    pub const BOOLEAN: &str = "boolean";

    // --- Unsigned integers ------------------------------------------------

    /// Name of the 8-bit unsigned integer type.
    pub const UINT8: &str = "uint8";
    /// Name of the 16-bit unsigned integer type.
    pub const UINT16: &str = "uint16";
    /// Name of the 32-bit unsigned integer type.
    pub const UINT32: &str = "uint32";
    /// Name of the 64-bit unsigned integer type.
    pub const UINT64: &str = "uint64";
    /// Name of the 128-bit unsigned integer type.
    pub const UINT128: &str = "uint128";

    // --- Signed integers --------------------------------------------------

    /// Name of the 8-bit signed integer type.
    pub const INT8: &str = "int8";
    /// Name of the 16-bit signed integer type.
    pub const INT16: &str = "int16";
    /// Name of the 32-bit signed integer type.
    pub const INT32: &str = "int32";
    /// Name of the 64-bit signed integer type.
    pub const INT64: &str = "int64";
    /// Name of the 128-bit signed integer type.
    pub const INT128: &str = "int128";

    // --- Floating point and pointer-sized ---------------------------------

    /// Name of the 32-bit floating point type.
    pub const FLOAT32: &str = "float32";
    /// Name of the 64-bit floating point type.
    pub const FLOAT64: &str = "float64";
    /// Name of the pointer-sized unsigned integer type.
    pub const USIZE: &str = "usize";

    // --- Bytes, text and time ---------------------------------------------

    /// Name for binary data.
    pub const BINARY: &str = "binary";
    /// Name of the single character type.
    pub const CHAR: &str = "char";
    /// Name of the string type.
    pub const STRING: &str = "string";
    /// Name of the duration type.
    pub const DURATION: &str = "duration";

    // --- Nested types -----------------------------------------------------

    /// Name of the list (vector) type.
    pub const VEC: &str = "vec";
    /// Name of the struct type.
    pub const STRUCT: &str = "struct";
    /// Name of the dictionary type.
    pub const DICT: &str = "dict";
}
31
32pub trait DType {
33    fn dtype(&self) -> DataType;
34}
35
36#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
37#[derive(Debug, Clone, PartialEq, Eq, Hash, Default)]
38pub enum DataType {
39    #[default]
40    Null,
41
42    UInt8,
43    UInt16,
44    UInt32,
45    UInt64,
46    UInt128,
47
48    Int8,
49    Int16,
50    Int32,
51    Int64,
52    Int128,
53
54    Float32,
55    Float64,
56
57    Usize,
58
59    Duration,
60
61    Boolean,
62
63    Char,
64    String,
65
66    List(Box<DataType>),
67    Dict(Vec<(SmallStr, DataType)>),
68    Struct(SmallStr, Vec<(SmallStr, DataType)>),
69}
70
71impl DataType {
72    pub fn is_nested(&self) -> bool {
73        use DataType as D;
74        matches!(self, D::List(_) | D::Dict(_) | D::Struct(_, _))
75    }
76
77    pub fn is_numeric(&self) -> bool {
78        use DataType as D;
79        matches!(
80            self,
81            D::Int8
82                | D::Int16
83                | D::Int32
84                | D::Int64
85                | D::Int128
86                | D::UInt8
87                | D::UInt16
88                | D::UInt32
89                | D::UInt64
90                | D::Float32
91                | D::Float64
92                | D::Usize
93        )
94    }
95
96    pub fn is_primitive(&self) -> bool {
97        use DataType as D;
98        matches!(
99            self,
100            D::Null
101                | D::Boolean
102                | D::Int8
103                | D::Int16
104                | D::Int32
105                | D::Int64
106                | D::Int128
107                | D::UInt8
108                | D::UInt16
109                | D::UInt32
110                | D::UInt64
111                | D::Float32
112                | D::Float64
113                | D::Usize
114        )
115    }
116
117    pub fn max(&self) -> Option<Scalar> {
118        use DataType as D;
119        match self {
120            D::Int8 => Some(Scalar::from(<i8 as Primitive>::MAX)),
121            D::Int16 => Some(Scalar::from(<i16 as Primitive>::MAX)),
122            D::Int32 => Some(Scalar::from(<i32 as Primitive>::MAX)),
123            D::Int64 => Some(Scalar::from(<i64 as Primitive>::MAX)),
124            D::Int128 => Some(Scalar::from(<i128 as Primitive>::MAX)),
125            D::UInt8 => Some(Scalar::from(<u8 as Primitive>::MAX)),
126            D::UInt16 => Some(Scalar::from(<u16 as Primitive>::MAX)),
127            D::UInt32 => Some(Scalar::from(<u32 as Primitive>::MAX)),
128            D::UInt64 => Some(Scalar::from(<u64 as Primitive>::MAX)),
129            D::UInt128 => Some(Scalar::from(<u128 as Primitive>::MAX)),
130            D::Float32 => Some(Scalar::from(<f32 as Primitive>::MAX)),
131            D::Float64 => Some(Scalar::from(<f64 as Primitive>::MAX)),
132            _ => None,
133        }
134    }
135
136    pub fn min(&self) -> Option<Scalar> {
137        use DataType as D;
138        match self {
139            D::Int8 => Some(Scalar::from(<i8 as Primitive>::MIN)),
140            D::Int16 => Some(Scalar::from(<i16 as Primitive>::MIN)),
141            D::Int32 => Some(Scalar::from(<i32 as Primitive>::MIN)),
142            D::Int64 => Some(Scalar::from(<i64 as Primitive>::MIN)),
143            D::Int128 => Some(Scalar::from(<i128 as Primitive>::MIN)),
144            D::UInt8 => Some(Scalar::from(<u8 as Primitive>::MIN)),
145            D::UInt16 => Some(Scalar::from(<u16 as Primitive>::MIN)),
146            D::UInt32 => Some(Scalar::from(<u32 as Primitive>::MIN)),
147            D::UInt64 => Some(Scalar::from(<u64 as Primitive>::MIN)),
148            D::UInt128 => Some(Scalar::from(<u128 as Primitive>::MIN)),
149            D::Float32 => Some(Scalar::from(<f32 as Primitive>::MIN)),
150            D::Float64 => Some(Scalar::from(<f64 as Primitive>::MIN)),
151            _ => None,
152        }
153    }
154
155    pub fn primitive_bounds(&self) -> Option<(Scalar, Scalar)> {
156        match (self.min(), self.max()) {
157            (Some(min), Some(max)) => Some((min, max)),
158            _ => None,
159        }
160    }
161}
162
163impl From<String> for DataType {
164    fn from(value: String) -> Self {
165        match value.trim().to_lowercase().as_str() {
166            dtype_names::NULL => DataType::Null,
167
168            dtype_names::UINT8 => DataType::UInt8,
169            dtype_names::UINT16 => DataType::UInt16,
170            dtype_names::UINT32 => DataType::UInt32,
171            dtype_names::UINT64 => DataType::UInt64,
172            dtype_names::UINT128 => DataType::UInt128,
173
174            dtype_names::INT8 => DataType::Int8,
175            dtype_names::INT16 => DataType::Int16,
176            dtype_names::INT32 => DataType::Int32,
177            dtype_names::INT64 => DataType::Int64,
178            dtype_names::INT128 => DataType::Int128,
179
180            dtype_names::FLOAT32 => DataType::Float32,
181            dtype_names::FLOAT64 => DataType::Float64,
182
183            dtype_names::USIZE => DataType::Usize,
184
185            dtype_names::BOOLEAN => DataType::Boolean,
186
187            dtype_names::CHAR => DataType::Char,
188            dtype_names::STRING => DataType::String,
189
190            dtype_names::VEC => DataType::List(Box::new(DataType::Null)),
191            dtype_names::STRUCT => DataType::Struct(SmallStr::from_static("struct"), Vec::new()),
192
193            dtype_names::DICT => DataType::Dict(Vec::new()),
194
195            _ => panic!("Unknown data type: {}", value),
196        }
197    }
198}
199
200impl Display for DataType {
201    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
202        match self {
203            DataType::Null => write!(f, "{}", dtype_names::NULL)?,
204
205            DataType::UInt8 => write!(f, "{}", dtype_names::UINT8)?,
206            DataType::UInt16 => write!(f, "{}", dtype_names::UINT16)?,
207            DataType::UInt32 => write!(f, "{}", dtype_names::UINT32)?,
208            DataType::UInt64 => write!(f, "{}", dtype_names::UINT64)?,
209            DataType::UInt128 => write!(f, "{}", dtype_names::UINT128)?,
210
211            DataType::Int8 => write!(f, "{}", dtype_names::INT8)?,
212            DataType::Int16 => write!(f, "{}", dtype_names::INT16)?,
213            DataType::Int32 => write!(f, "{}", dtype_names::INT32)?,
214            DataType::Int64 => write!(f, "{}", dtype_names::INT64)?,
215            DataType::Int128 => write!(f, "{}", dtype_names::INT128)?,
216
217            DataType::Float32 => write!(f, "{}", dtype_names::FLOAT32)?,
218            DataType::Float64 => write!(f, "{}", dtype_names::FLOAT64)?,
219
220            DataType::Usize => write!(f, "{}", dtype_names::USIZE)?,
221
222            DataType::Duration => write!(f, "{}", dtype_names::DURATION)?,
223
224            DataType::Boolean => write!(f, "{}", dtype_names::BOOLEAN)?,
225
226            DataType::Char => write!(f, "{}", dtype_names::CHAR)?,
227            DataType::String => write!(f, "{}", dtype_names::STRING)?,
228
229            DataType::List(inner) => write!(f, "{}({})", dtype_names::VEC, inner)?,
230            DataType::Dict(vals) => write!(
231                f,
232                "{}({})",
233                dtype_names::DICT,
234                vals.iter()
235                    .map(|(name, _)| format!("{}", name))
236                    .collect::<Vec<_>>()
237                    .join(", ")
238            )?,
239
240            DataType::Struct(name, fields) => write!(
241                f,
242                "{} {} {{ {} }}",
243                dtype_names::STRUCT,
244                name,
245                fields
246                    .iter()
247                    .map(|(name, _)| format!("{}", name))
248                    .collect::<Vec<_>>()
249                    .join(", ")
250            )?,
251        };
252
253        Ok(())
254    }
255}