diskann_benchmark_runner/utils/
datatype.rs1use half::f16;
7use serde::{Deserialize, Serialize};
8
9use crate::dispatcher::{DispatchRule, FailureScore, Map, MatchScore};
10
11#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
15#[serde(rename_all = "lowercase")]
16pub enum DataType {
17 Float64,
18 Float32,
19 Float16,
20 UInt8,
21 UInt16,
22 UInt32,
23 UInt64,
24 Int8,
25 Int16,
26 Int32,
27 Int64,
28 Bool,
29}
30
31impl DataType {
32 pub const fn as_str(self) -> &'static str {
36 match self {
37 Self::Float64 => "float64",
38 Self::Float32 => "float32",
39 Self::Float16 => "float16",
40 Self::UInt8 => "uint8",
41 Self::UInt16 => "uint16",
42 Self::UInt32 => "uint32",
43 Self::UInt64 => "uint64",
44 Self::Int8 => "int8",
45 Self::Int16 => "int16",
46 Self::Int32 => "int32",
47 Self::Int64 => "int64",
48 Self::Bool => "bool",
49 }
50 }
51}
52
53impl std::fmt::Display for DataType {
54 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
55 write!(f, "{}", self.as_str())
56 }
57}
58
/// Zero-sized marker naming the Rust type `T` as a dispatch target.
///
/// `Debug`, `Default`, `Clone` and `Copy` are implemented by hand rather than derived.
/// `#[derive]` would add `T: Debug`/`T: Default`/`T: Clone`/`T: Copy` bounds even though
/// only a `PhantomData<T>` is stored. Those bounds would needlessly stop e.g.
/// `Type<String>` from being `Copy`. The `Debug` output is identical to the derived one.
pub struct Type<T>(std::marker::PhantomData<T>);

impl<T> std::fmt::Debug for Type<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_tuple("Type").field(&self.0).finish()
    }
}

impl<T> Default for Type<T> {
    fn default() -> Self {
        Self(std::marker::PhantomData)
    }
}

impl<T> Clone for Type<T> {
    fn clone(&self) -> Self {
        *self
    }
}

impl<T> Copy for Type<T> {}
62
63impl<T: 'static> Map for Type<T> {
65 type Type<'a> = Self;
66}
67
68pub const MATCH_FAIL: FailureScore = FailureScore(1000);
69
/// Implements `DispatchRule<DataType>` and `DispatchRule<&DataType>` for `Type<$type>`.
///
/// The generated rules accept exactly the `DataType::$var` variant (scoring
/// `MatchScore(0)`) and reject every other variant with [`MATCH_FAIL`].
/// Human-readable names are taken from [`DataType::as_str`], so they always agree with
/// `Display` and the serde representation, and producing them allocates nothing.
macro_rules! dispatch_rule {
    ($type:ty, $var:ident) => {
        impl DispatchRule<DataType> for Type<$type> {
            type Error = std::convert::Infallible;

            fn try_match(from: &DataType) -> Result<MatchScore, FailureScore> {
                match from {
                    DataType::$var => Ok(MatchScore(0)),
                    _ => Err(MATCH_FAIL),
                }
            }

            // Panics if `from` is not the variant this rule accepts; callers are
            // expected to have checked `try_match` first.
            fn convert(from: DataType) -> Result<Self, Self::Error> {
                assert!(matches!(from, DataType::$var), "invalid dispatch");
                Ok(Self::default())
            }

            fn description(
                f: &mut std::fmt::Formatter<'_>,
                v: Option<&DataType>,
            ) -> std::fmt::Result {
                // Canonical name of the accepted variant. This stays consistent with
                // `as_str` even if a variant's serialized name ever differs from its
                // lowercased identifier.
                let expected = DataType::$var.as_str();
                match v {
                    None => f.write_str(expected),
                    Some(v) => match <Self as DispatchRule<DataType>>::try_match(v) {
                        Ok(_) => f.write_str("successful match"),
                        Err(_) => write!(
                            f,
                            "expected \"{}\" but found \"{}\"",
                            expected,
                            v.as_str()
                        ),
                    },
                }
            }
        }

        // By-reference variant: every method forwards to the by-value rule above.
        // The target trait is named explicitly because `Self` implements both.
        impl DispatchRule<&DataType> for Type<$type> {
            type Error = std::convert::Infallible;

            fn try_match(from: &&DataType) -> Result<MatchScore, FailureScore> {
                <Self as DispatchRule<DataType>>::try_match(*from)
            }

            fn convert(from: &DataType) -> Result<Self, Self::Error> {
                <Self as DispatchRule<DataType>>::convert(*from)
            }

            fn description(
                f: &mut std::fmt::Formatter<'_>,
                v: Option<&&DataType>,
            ) -> std::fmt::Result {
                <Self as DispatchRule<DataType>>::description(f, v.copied())
            }
        }
    };
}
123
124dispatch_rule!(f64, Float64);
125dispatch_rule!(f32, Float32);
126dispatch_rule!(f16, Float16);
127dispatch_rule!(u8, UInt8);
128dispatch_rule!(u16, UInt16);
129dispatch_rule!(u32, UInt32);
130dispatch_rule!(u64, UInt64);
131dispatch_rule!(i8, Int8);
132dispatch_rule!(i16, Int16);
133dispatch_rule!(i32, Int32);
134dispatch_rule!(i64, Int64);
135dispatch_rule!(bool, Bool);
136
#[cfg(test)]
mod tests {
    use super::*;

    use crate::dispatcher::{Description, Why};

    /// Every `DataType` variant, so per-variant checks cannot silently skip one.
    const ALL_TYPES: [DataType; 12] = [
        DataType::Float64,
        DataType::Float32,
        DataType::Float16,
        DataType::UInt8,
        DataType::UInt16,
        DataType::UInt32,
        DataType::UInt64,
        DataType::Int8,
        DataType::Int16,
        DataType::Int32,
        DataType::Int64,
        DataType::Bool,
    ];

    #[test]
    fn test_as_str() {
        for x in ALL_TYPES {
            // `Display` must print exactly the canonical name.
            assert_eq!(format!("{}", x), x.as_str());

            // The serde representation must be that same name and must round-trip.
            let json = serde_json::to_string(&x).unwrap();
            assert_eq!(x.as_str(), json.trim_matches('"'));
            assert_eq!(serde_json::from_str::<DataType>(&json).unwrap(), x);
        }
    }

    /// Checks the description printed when no value is being matched.
    fn test_description<T>(typename: &str)
    where
        Type<T>: DispatchRule<DataType>,
    {
        assert_eq!(
            Description::<DataType, Type<T>>::new().to_string(),
            typename
        );
    }

    /// Checks that `datatype` is rejected with `MATCH_FAIL` and explained correctly.
    fn test_dispatch_fail<T>(datatype: DataType, typename: &str)
    where
        Type<T>: DispatchRule<DataType>,
    {
        assert_eq!(
            <Type<T> as DispatchRule<DataType>>::try_match(&datatype),
            Err(MATCH_FAIL)
        );
        assert_eq!(
            Why::<DataType, Type<T>>::new(&datatype).to_string(),
            format!("expected \"{}\" but found \"{}\"", typename, datatype)
        );
    }

    /// Checks that `datatype` is accepted, converts, and is reported as a match.
    fn test_dispatch_success<T>(datatype: DataType)
    where
        Type<T>: DispatchRule<DataType>,
    {
        assert_eq!(
            <Type<T> as DispatchRule<DataType>>::try_match(&datatype),
            Ok(MatchScore(0))
        );
        assert!(<Type<T> as DispatchRule<DataType>>::convert(datatype).is_ok());
        assert_eq!(
            Why::<DataType, Type<T>>::new(&datatype).to_string(),
            "successful match",
        );
    }

    /// Checks the by-reference (`&DataType`) rule agrees with the by-value rule.
    fn test_dispatch_by_ref<T>(matching: DataType, failing: DataType)
    where
        for<'a> Type<T>: DispatchRule<&'a DataType>,
    {
        assert_eq!(
            <Type<T> as DispatchRule<&DataType>>::try_match(&&matching),
            Ok(MatchScore(0))
        );
        assert_eq!(
            <Type<T> as DispatchRule<&DataType>>::try_match(&&failing),
            Err(MATCH_FAIL)
        );
        assert!(<Type<T> as DispatchRule<&DataType>>::convert(&matching).is_ok());
    }

    macro_rules! type_test {
        ($test:ident, $T:ty, $var:ident, $($fails:ident),* $(,)?) => {
            #[test]
            fn $test() {
                // Derived from the identifier, independently of `as_str`, so the
                // expected name is not computed by the code under test.
                let typename = stringify!($var).to_lowercase();

                test_description::<$T>(&typename);
                test_dispatch_success::<$T>(DataType::$var);
                $(
                    test_dispatch_fail::<$T>(DataType::$fails, &typename);
                    test_dispatch_by_ref::<$T>(DataType::$var, DataType::$fails);
                )*
            }
        }
    }

    type_test!(test_f64, f64, Float64, Float16, UInt8);
    type_test!(test_f32, f32, Float32, Float16, UInt8);
    type_test!(test_f16, f16, Float16, UInt8, UInt16);
    type_test!(test_u8, u8, UInt8, UInt16, UInt32);
    type_test!(test_u16, u16, UInt16, UInt32, UInt64);
    type_test!(test_u32, u32, UInt32, UInt64, Int8);
    type_test!(test_u64, u64, UInt64, Int8, Int16);
    type_test!(test_i8, i8, Int8, Int16, Int32);
    type_test!(test_i16, i16, Int16, Int32, Int64);
    type_test!(test_i32, i32, Int32, Int64, Bool);
    type_test!(test_i64, i64, Int64, Bool, Float32);
    type_test!(test_bool, bool, Bool, Float32, Float16);

    #[test]
    #[should_panic(expected = "invalid dispatch")]
    fn test_convert_mismatch_panics() {
        let _ = <Type<f32> as DispatchRule<DataType>>::convert(DataType::UInt8);
    }
}