diskann_benchmark_runner/utils/
datatype.rs1use half::f16;
7use serde::{Deserialize, Serialize};
8
9use crate::dispatcher::{DispatchRule, FailureScore, MatchScore};
10
11#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
15#[serde(rename_all = "lowercase")]
16pub enum DataType {
17 Float64,
18 Float32,
19 Float16,
20 UInt8,
21 UInt16,
22 UInt32,
23 UInt64,
24 Int8,
25 Int16,
26 Int32,
27 Int64,
28 Bool,
29}
30
31impl DataType {
32 pub const fn as_str(self) -> &'static str {
36 match self {
37 Self::Float64 => "float64",
38 Self::Float32 => "float32",
39 Self::Float16 => "float16",
40 Self::UInt8 => "uint8",
41 Self::UInt16 => "uint16",
42 Self::UInt32 => "uint32",
43 Self::UInt64 => "uint64",
44 Self::Int8 => "int8",
45 Self::Int16 => "int16",
46 Self::Int32 => "int32",
47 Self::Int64 => "int64",
48 Self::Bool => "bool",
49 }
50 }
51}
52
53impl std::fmt::Display for DataType {
54 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
55 write!(f, "{}", self.as_str())
56 }
57}
58
/// Zero-sized marker naming the Rust type `T` as a dispatch target.
///
/// No `T` is ever stored, so the trait impls below are written by hand:
/// `#[derive(Debug, Default, Clone, Copy)]` would add `T: Debug`, `T: Default`,
/// `T: Clone` and `T: Copy` bounds that the marker does not actually need.
pub struct Type<T>(std::marker::PhantomData<T>);

impl<T> std::fmt::Debug for Type<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Same shape as the derived output: `Type(PhantomData<..>)`.
        f.debug_tuple("Type").field(&self.0).finish()
    }
}

impl<T> Default for Type<T> {
    fn default() -> Self {
        Self(std::marker::PhantomData)
    }
}

impl<T> Clone for Type<T> {
    fn clone(&self) -> Self {
        *self
    }
}

impl<T> Copy for Type<T> {}
62
/// Failure score returned by this module's `DispatchRule<DataType>` impls when the
/// requested [`DataType`] is not the one a given `Type<T>` accepts.
// NOTE(review): how the magnitude 1000 ranks against other failure scores is defined
// by `crate::dispatcher`'s `FailureScore` — confirm there before changing it.
pub const MATCH_FAIL: FailureScore = FailureScore(1000);
64
/// Implements `DispatchRule<DataType>` and `DispatchRule<&DataType>` for `Type<$type>`,
/// accepting exactly the `DataType::$var` tag.
///
/// * A matching tag scores `MatchScore(0)`; any other tag fails with [`MATCH_FAIL`].
/// * Names shown in diagnostics come from [`DataType::as_str`], so they always agree
///   with the canonical spelling of the tag instead of being re-derived from the
///   variant identifier.
/// * The `&DataType` impl simply forwards to the by-value impl.
macro_rules! dispatch_rule {
    ($type:ty, $var:ident) => {
        impl DispatchRule<DataType> for Type<$type> {
            type Error = std::convert::Infallible;

            fn try_match(from: &DataType) -> Result<MatchScore, FailureScore> {
                match from {
                    DataType::$var => Ok(MatchScore(0)),
                    _ => Err(MATCH_FAIL),
                }
            }

            // Panics on a mismatched tag. Presumably the dispatcher only calls this
            // after a successful `try_match` — TODO confirm against `crate::dispatcher`.
            fn convert(from: DataType) -> Result<Self, Self::Error> {
                assert!(
                    matches!(from, DataType::$var),
                    "invalid dispatch: expected \"{}\" but found \"{}\"",
                    DataType::$var.as_str(),
                    from.as_str(),
                );
                Ok(Self::default())
            }

            fn description(
                f: &mut std::fmt::Formatter<'_>,
                v: Option<&DataType>,
            ) -> std::fmt::Result {
                let expected = DataType::$var.as_str();
                match v {
                    // No concrete value: describe what this rule accepts.
                    None => f.write_str(expected),
                    Some(found) => match <Self as DispatchRule<DataType>>::try_match(found) {
                        Ok(_) => f.write_str("successful match"),
                        Err(_) => write!(
                            f,
                            "expected \"{}\" but found \"{}\"",
                            expected,
                            found.as_str()
                        ),
                    },
                }
            }
        }

        // Borrowed form: every method forwards to the by-value impl above.
        impl DispatchRule<&DataType> for Type<$type> {
            type Error = std::convert::Infallible;

            fn try_match(from: &&DataType) -> Result<MatchScore, FailureScore> {
                <Self as DispatchRule<DataType>>::try_match(*from)
            }

            fn convert(from: &DataType) -> Result<Self, Self::Error> {
                <Self as DispatchRule<DataType>>::convert(*from)
            }

            fn description(
                f: &mut std::fmt::Formatter<'_>,
                v: Option<&&DataType>,
            ) -> std::fmt::Result {
                <Self as DispatchRule<DataType>>::description(f, v.copied())
            }
        }
    };
}
118
119dispatch_rule!(f64, Float64);
120dispatch_rule!(f32, Float32);
121dispatch_rule!(f16, Float16);
122dispatch_rule!(u8, UInt8);
123dispatch_rule!(u16, UInt16);
124dispatch_rule!(u32, UInt32);
125dispatch_rule!(u64, UInt64);
126dispatch_rule!(i8, Int8);
127dispatch_rule!(i16, Int16);
128dispatch_rule!(i32, Int32);
129dispatch_rule!(i64, Int64);
130dispatch_rule!(bool, Bool);
131
#[cfg(test)]
mod tests {
    use super::*;

    use crate::dispatcher::{Description, Why};

    /// Every `DataType` variant. Extend this list whenever a variant is added to the
    /// enum so the name/serde checks below cover it.
    const ALL: [DataType; 12] = [
        DataType::Float64,
        DataType::Float32,
        DataType::Float16,
        DataType::UInt8,
        DataType::UInt16,
        DataType::UInt32,
        DataType::UInt64,
        DataType::Int8,
        DataType::Int16,
        DataType::Int32,
        DataType::Int64,
        DataType::Bool,
    ];

    #[test]
    fn test_as_str() {
        for x in ALL {
            // `Display` must print exactly the canonical name.
            assert_eq!(format!("{}", x), x.as_str());

            // serde must serialize to the same name ...
            let json = serde_json::to_string(&x).unwrap();
            assert_eq!(x.as_str(), json.trim_matches('"'));

            // ... and that name must deserialize back to the same variant.
            let back: DataType = serde_json::from_str(&json).unwrap();
            assert_eq!(back, x);
        }
    }

    fn test_description<T>(typename: &str)
    where
        Type<T>: DispatchRule<DataType>,
    {
        assert_eq!(
            Description::<DataType, Type<T>>::new().to_string(),
            typename
        );
    }

    fn test_dispatch_fail<T>(datatype: DataType, typename: &str)
    where
        Type<T>: DispatchRule<DataType>,
    {
        assert_eq!(<Type<T>>::try_match(&datatype), Err(MATCH_FAIL));
        assert_eq!(
            Why::<DataType, Type<T>>::new(&datatype).to_string(),
            format!("expected \"{}\" but found \"{}\"", typename, datatype)
        );
    }

    fn test_dispatch_success<T>(datatype: DataType)
    where
        Type<T>: DispatchRule<DataType>,
    {
        assert_eq!(<Type<T>>::try_match(&datatype), Ok(MatchScore(0)));
        assert_eq!(
            Why::<DataType, Type<T>>::new(&datatype).to_string(),
            "successful match",
        );
    }

    /// The borrowed (`&DataType`) rule must agree with the by-value rule.
    fn test_dispatch_ref<T>(matching: DataType, failing: &[DataType])
    where
        Type<T>: for<'a> DispatchRule<&'a DataType>,
    {
        assert_eq!(
            <Type<T> as DispatchRule<&DataType>>::try_match(&&matching),
            Ok(MatchScore(0))
        );
        for datatype in failing {
            assert_eq!(
                <Type<T> as DispatchRule<&DataType>>::try_match(&datatype),
                Err(MATCH_FAIL)
            );
        }
    }

    macro_rules! type_test {
        ($test:ident, $T:ty, $var:ident, $($fails:ident),* $(,)?) => {
            #[test]
            fn $test() {
                let typename = stringify!($var).to_lowercase();

                test_description::<$T>(&typename);
                test_dispatch_success::<$T>(DataType::$var);
                $(test_dispatch_fail::<$T>(DataType::$fails, &typename);)*
                test_dispatch_ref::<$T>(DataType::$var, &[$(DataType::$fails),*]);
            }
        }
    }

    type_test!(test_f64, f64, Float64, Float16, UInt8);
    type_test!(test_f32, f32, Float32, Float16, UInt8);
    type_test!(test_f16, f16, Float16, UInt8, UInt16);
    type_test!(test_u8, u8, UInt8, UInt16, UInt32);
    type_test!(test_u16, u16, UInt16, UInt32, UInt64);
    type_test!(test_u32, u32, UInt32, UInt64, Int8);
    type_test!(test_u64, u64, UInt64, Int8, Int16);
    type_test!(test_i8, i8, Int8, Int16, Int32);
    type_test!(test_i16, i16, Int16, Int32, Int64);
    type_test!(test_i32, i32, Int32, Int64, Bool);
    type_test!(test_i64, i64, Int64, Bool, Float32);
    type_test!(test_bool, bool, Bool, Float32, Float16);
}