diskann_benchmark_runner/utils/
datatype.rs1use half::f16;
7use serde::{Deserialize, Serialize};
8
9#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
13#[serde(rename_all = "lowercase")]
14pub enum DataType {
15 Float64,
16 Float32,
17 Float16,
18 UInt8,
19 UInt16,
20 UInt32,
21 UInt64,
22 Int8,
23 Int16,
24 Int32,
25 Int64,
26 Bool,
27}
28
29impl DataType {
30 pub const fn as_str(self) -> &'static str {
34 match self {
35 Self::Float64 => "float64",
36 Self::Float32 => "float32",
37 Self::Float16 => "float16",
38 Self::UInt8 => "uint8",
39 Self::UInt16 => "uint16",
40 Self::UInt32 => "uint32",
41 Self::UInt64 => "uint64",
42 Self::Int8 => "int8",
43 Self::Int16 => "int16",
44 Self::Int32 => "int32",
45 Self::Int64 => "int64",
46 Self::Bool => "bool",
47 }
48 }
49}
50
51impl std::fmt::Display for DataType {
52 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
53 write!(f, "{}", self.as_str())
54 }
55}
56
57pub trait AsDataType: 'static {
59 const DATA_TYPE: DataType;
61
62 fn is_match(data_type: DataType) -> bool {
64 data_type == Self::DATA_TYPE
65 }
66
67 fn describe(data_type: DataType) -> Describe {
82 if data_type == Self::DATA_TYPE {
83 Describe(DescribeInner::Match)
84 } else {
85 Describe(DescribeInner::Mismatch {
86 expected: Self::DATA_TYPE,
87 got: data_type,
88 })
89 }
90 }
91}
92
93#[derive(Debug, Clone, Copy)]
95pub struct Describe(DescribeInner);
96
97impl Describe {
98 pub fn is_match(&self) -> bool {
100 matches!(self.0, DescribeInner::Match)
101 }
102}
103
104impl std::fmt::Display for Describe {
105 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
106 self.0.fmt(f)
107 }
108}
109
110#[derive(Debug, Clone, Copy)]
111enum DescribeInner {
112 Match,
113 Mismatch { expected: DataType, got: DataType },
114}
115
116impl std::fmt::Display for DescribeInner {
117 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
118 match self {
119 Self::Match => write!(f, "successful match"),
120 Self::Mismatch { expected, got } => {
121 write!(f, "expected \"{}\" but found \"{}\"", expected, got)
122 }
123 }
124 }
125}
126
// Implements `AsDataType` for `$rust_type`, binding it to the `DataType::$variant` tag.
//
// Example: `as_data_type!(u8, UInt8);`
macro_rules! as_data_type {
    ($rust_type:ty, $variant:ident) => {
        impl AsDataType for $rust_type {
            const DATA_TYPE: DataType = DataType::$variant;
        }
    };
}
134
135as_data_type!(f64, Float64);
136as_data_type!(f32, Float32);
137as_data_type!(f16, Float16);
138as_data_type!(u8, UInt8);
139as_data_type!(u16, UInt16);
140as_data_type!(u32, UInt32);
141as_data_type!(u64, UInt64);
142as_data_type!(i8, Int8);
143as_data_type!(i16, Int16);
144as_data_type!(i32, Int32);
145as_data_type!(i64, Int64);
146as_data_type!(bool, Bool);
147
#[cfg(test)]
mod tests {
    use super::*;

    /// Every `DataType` variant, so the tests below cover all of them.
    ///
    /// NOTE: keep in sync with the enum. `all_is_exhaustive` below stops compiling
    /// when a variant is added, as a reminder to extend this list.
    const ALL: [DataType; 12] = [
        DataType::Float64,
        DataType::Float32,
        DataType::Float16,
        DataType::UInt8,
        DataType::UInt16,
        DataType::UInt32,
        DataType::UInt64,
        DataType::Int8,
        DataType::Int16,
        DataType::Int32,
        DataType::Int64,
        DataType::Bool,
    ];

    // Compile-time reminder: an exhaustive match with no wildcard, so adding a
    // `DataType` variant breaks the build here until `ALL` is reviewed.
    #[allow(dead_code)]
    fn all_is_exhaustive(x: DataType) {
        match x {
            DataType::Float64
            | DataType::Float32
            | DataType::Float16
            | DataType::UInt8
            | DataType::UInt16
            | DataType::UInt32
            | DataType::UInt64
            | DataType::Int8
            | DataType::Int16
            | DataType::Int32
            | DataType::Int64
            | DataType::Bool => {}
        }
    }

    /// Checks, for every variant, that:
    /// - `as_str`, `Display` and the serde representation agree;
    /// - the serde form round-trips;
    /// - no two variants share a name.
    #[test]
    fn test_as_str() {
        let mut seen = std::collections::HashSet::new();
        for x in ALL.iter().copied() {
            assert_eq!(format!("{}", x), x.as_str());

            let json = serde_json::to_string(&x).unwrap();
            assert_eq!(x.as_str(), json.trim_matches('"'));

            // The lowercase name must deserialize back to the same variant.
            let back: DataType = serde_json::from_str(&json).unwrap();
            assert_eq!(back, x);

            assert!(seen.insert(x.as_str()), "duplicate name \"{}\"", x);
        }
    }

    /// `T::DATA_TYPE` must render as `typename`.
    fn test_description<T>(typename: &str)
    where
        T: AsDataType,
    {
        assert_eq!(T::DATA_TYPE.as_str(), typename);
    }

    /// `T` must reject `datatype`, and the description must explain the mismatch.
    fn test_dispatch_fail<T>(datatype: DataType, typename: &str)
    where
        T: AsDataType,
    {
        assert!(!T::is_match(datatype));
        let describe = T::describe(datatype);
        assert!(!describe.is_match());
        assert_eq!(
            describe.to_string(),
            format!("expected \"{}\" but found \"{}\"", typename, datatype)
        );
    }

    /// `T` must accept `datatype`, and the description must report success.
    fn test_dispatch_success<T>(datatype: DataType)
    where
        T: AsDataType,
    {
        assert!(T::is_match(datatype));
        let describe = T::describe(datatype);
        assert!(describe.is_match());
        assert_eq!(describe.to_string(), "successful match");
    }

    macro_rules! type_test {
        ($test:ident, $T:ty, $var:ident, $($fails:ident),* $(,)?) => {
            #[test]
            fn $test() {
                let typename = stringify!($var).to_lowercase();

                test_description::<$T>(&typename);
                test_dispatch_success::<$T>(DataType::$var);
                $(test_dispatch_fail::<$T>(DataType::$fails, &typename);)*

                // Beyond the hand-picked cases above, every other variant must be rejected.
                for other in ALL.iter().copied().filter(|&d| d != DataType::$var) {
                    test_dispatch_fail::<$T>(other, &typename);
                }
            }
        };
    }

    type_test!(test_f64, f64, Float64, Float16, UInt8);
    type_test!(test_f32, f32, Float32, Float16, UInt8);
    type_test!(test_f16, f16, Float16, UInt8, UInt16);
    type_test!(test_u8, u8, UInt8, UInt16, UInt32);
    type_test!(test_u16, u16, UInt16, UInt32, UInt64);
    type_test!(test_u32, u32, UInt32, UInt64, Int8);
    type_test!(test_u64, u64, UInt64, Int8, Int16);
    type_test!(test_i8, i8, Int8, Int16, Int32);
    type_test!(test_i16, i16, Int16, Int32, Int64);
    type_test!(test_i32, i32, Int32, Int64, Bool);
    type_test!(test_i64, i64, Int64, Bool, Float32);
    type_test!(test_bool, bool, Bool, Float32, Float16);
}