Skip to main content

diskann_benchmark_runner/utils/
datatype.rs

1/*
2 * Copyright (c) Microsoft Corporation.
3 * Licensed under the MIT license.
4 */
5
6use half::f16;
7use serde::{Deserialize, Serialize};
8
9use crate::dispatcher::{DispatchRule, FailureScore, Map, MatchScore};
10
11/// An enum representation for common DiskANN data types.
12///
13/// [`DispatchRule]`s are defined for each type here and it's corresponding [`Type`].
14#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
15#[serde(rename_all = "lowercase")]
16pub enum DataType {
17    Float64,
18    Float32,
19    Float16,
20    UInt8,
21    UInt16,
22    UInt32,
23    UInt64,
24    Int8,
25    Int16,
26    Int32,
27    Int64,
28    Bool,
29}
30
31impl DataType {
32    /// Return the string representation of the enum.
33    ///
34    /// This is more efficient than using `serde` directly.
35    pub const fn as_str(self) -> &'static str {
36        match self {
37            Self::Float64 => "float64",
38            Self::Float32 => "float32",
39            Self::Float16 => "float16",
40            Self::UInt8 => "uint8",
41            Self::UInt16 => "uint16",
42            Self::UInt32 => "uint32",
43            Self::UInt64 => "uint64",
44            Self::Int8 => "int8",
45            Self::Int16 => "int16",
46            Self::Int32 => "int32",
47            Self::Int64 => "int64",
48            Self::Bool => "bool",
49        }
50    }
51}
52
53impl std::fmt::Display for DataType {
54    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
55        write!(f, "{}", self.as_str())
56    }
57}
58
/// Lifting the enum [`DataType`] into the Rust type domain.
///
/// `Type<T>` is a zero-sized marker that carries only the type `T`; a [`DataType`]
/// variant dispatches to the `Type<T>` whose `T` is the matching Rust type
/// (e.g. `Float32` to `Type<f32>`).
///
/// The trait impls below are written by hand rather than derived: `#[derive]` would add
/// `T: Debug` / `T: Default` / `T: Clone` / `T: Copy` bounds even though no `T` value is
/// stored, needlessly restricting which `T` the marker can be used with.
pub struct Type<T>(std::marker::PhantomData<T>);

impl<T> std::fmt::Debug for Type<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Same format `#[derive(Debug)]` produces for this tuple struct.
        f.debug_tuple("Type").field(&self.0).finish()
    }
}

impl<T> Default for Type<T> {
    fn default() -> Self {
        Self(std::marker::PhantomData)
    }
}

impl<T> Clone for Type<T> {
    fn clone(&self) -> Self {
        *self
    }
}

impl<T> Copy for Type<T> {}
62
/// The `Type` meta variable maps to itself.
///
/// `Type<T>` holds no borrowed data, so the associated `Type<'a>` ignores `'a` and is
/// `Type<T>` itself for every lifetime. The `T: 'static` bound restricts this to types
/// that contain no non-`'static` borrows.
impl<T: 'static> Map for Type<T> {
    type Type<'a> = Self;
}
67
/// The [`FailureScore`] reported by the [`DispatchRule`]s in this module when the
/// requested [`DataType`] variant does not correspond to the target [`Type`].
pub const MATCH_FAIL: FailureScore = FailureScore(1000);
69
/// Implements [`DispatchRule`] from both `DataType` and `&DataType` to `Type<$type>`.
///
/// * `$type` - the Rust type being lifted to (e.g. `f32`).
/// * `$var` - the `DataType` variant corresponding to `$type` (e.g. `Float32`).
///
/// The generated rules accept exactly `DataType::$var` with `MatchScore(0)` and reject
/// every other variant with [`MATCH_FAIL`]. Names shown in descriptions come from
/// [`DataType::as_str`], so they always agree with the `serde` representation.
macro_rules! dispatch_rule {
    ($type:ty, $var:ident) => {
        impl DispatchRule<DataType> for Type<$type> {
            type Error = std::convert::Infallible;

            fn try_match(from: &DataType) -> Result<MatchScore, FailureScore> {
                match from {
                    DataType::$var => Ok(MatchScore(0)),
                    _ => Err(MATCH_FAIL),
                }
            }

            fn convert(from: DataType) -> Result<Self, Self::Error> {
                // `Error` is `Infallible`: reaching here with another variant means
                // `convert` was called without a successful `try_match`, which is a
                // logic error rather than a recoverable condition.
                assert!(matches!(from, DataType::$var), "invalid dispatch");
                Ok(Self::default())
            }

            fn description(
                f: &mut std::fmt::Formatter<'_>,
                v: Option<&DataType>,
            ) -> std::fmt::Result {
                // Canonical (lowercase) name of the variant this rule accepts.
                let expected = DataType::$var.as_str();
                match v {
                    Some(v) => match Self::try_match(v) {
                        Ok(_) => write!(f, "successful match"),
                        Err(_) => write!(
                            f,
                            "expected \"{}\" but found \"{}\"",
                            expected,
                            v.as_str()
                        ),
                    },
                    None => f.write_str(expected),
                }
            }
        }

        // Borrowed input: forwards to the by-value rule above.
        impl DispatchRule<&DataType> for Type<$type> {
            type Error = std::convert::Infallible;

            fn try_match(from: &&DataType) -> Result<MatchScore, FailureScore> {
                Self::try_match(*from)
            }

            fn convert(from: &DataType) -> Result<Self, Self::Error> {
                Self::convert(*from)
            }

            fn description(
                f: &mut std::fmt::Formatter<'_>,
                v: Option<&&DataType>,
            ) -> std::fmt::Result {
                Self::description(f, v.copied())
            }
        }
    };
}
123
// One pair of rules (from `DataType` and from `&DataType`) per variant, each mapping the
// variant to `Type<T>` for its Rust type `T`.

// Floating point.
dispatch_rule!(f64, Float64);
dispatch_rule!(f32, Float32);
dispatch_rule!(f16, Float16);
// Unsigned integers.
dispatch_rule!(u8, UInt8);
dispatch_rule!(u16, UInt16);
dispatch_rule!(u32, UInt32);
dispatch_rule!(u64, UInt64);
// Signed integers.
dispatch_rule!(i8, Int8);
dispatch_rule!(i16, Int16);
dispatch_rule!(i32, Int32);
dispatch_rule!(i64, Int64);
// Boolean.
dispatch_rule!(bool, Bool);
136
137///////////
138// Tests //
139///////////
140
#[cfg(test)]
mod tests {
    use super::*;

    use crate::dispatcher::{Description, Why};

    /// Every `DataType` variant, so the whole-enum tests below cannot silently skip one.
    const ALL_TYPES: [DataType; 12] = [
        DataType::Float64,
        DataType::Float32,
        DataType::Float16,
        DataType::UInt8,
        DataType::UInt16,
        DataType::UInt32,
        DataType::UInt64,
        DataType::Int8,
        DataType::Int16,
        DataType::Int32,
        DataType::Int64,
        DataType::Bool,
    ];

    #[test]
    fn test_as_str() {
        for x in ALL_TYPES {
            assert_eq!(x.to_string(), x.as_str());
            assert_eq!(
                x.as_str(),
                serde_json::to_string(&x).unwrap().trim_matches('"')
            );
        }
    }

    #[test]
    fn test_serde_round_trip() {
        for x in ALL_TYPES {
            let json = serde_json::to_string(&x).unwrap();
            assert_eq!(serde_json::from_str::<DataType>(&json).unwrap(), x);
        }
    }

    fn test_description<T>(typename: &str)
    where
        Type<T>: DispatchRule<DataType>,
    {
        assert_eq!(
            Description::<DataType, Type<T>>::new().to_string(),
            typename
        );
    }

    fn test_dispatch_fail<T>(datatype: DataType, typename: &str)
    where
        Type<T>: DispatchRule<DataType>,
    {
        assert_eq!(<Type<T>>::try_match(&datatype), Err(MATCH_FAIL));
        assert_eq!(
            Why::<DataType, Type<T>>::new(&datatype).to_string(),
            format!("expected \"{}\" but found \"{}\"", typename, datatype)
        );
    }

    fn test_dispatch_success<T>(datatype: DataType)
    where
        Type<T>: DispatchRule<DataType>,
    {
        assert_eq!(<Type<T>>::try_match(&datatype), Ok(MatchScore(0)));
        assert_eq!(
            Why::<DataType, Type<T>>::new(&datatype).to_string(),
            "successful match",
        );
    }

    macro_rules! type_test {
        ($test:ident, $T:ty, $var:ident, $($fails:ident),* $(,)?) => {
            #[test]
            fn $test() {
                // Derived independently of `as_str` so the two are cross-checked.
                let typename = stringify!($var).to_lowercase();

                test_description::<$T>(&typename);
                test_dispatch_success::<$T>(DataType::$var);
                $(test_dispatch_fail::<$T>(DataType::$fails, &typename);)*

                // The by-reference rules must agree with the by-value ones.
                assert_eq!(
                    <Type<$T> as DispatchRule<&DataType>>::try_match(&&DataType::$var),
                    Ok(MatchScore(0))
                );
                $(
                    assert_eq!(
                        <Type<$T> as DispatchRule<&DataType>>::try_match(&&DataType::$fails),
                        Err(MATCH_FAIL)
                    );
                )*
            }
        }
    }

    type_test!(test_f64, f64, Float64, Float16, UInt8);
    type_test!(test_f32, f32, Float32, Float16, UInt8);
    type_test!(test_f16, f16, Float16, UInt8, UInt16);
    type_test!(test_u8, u8, UInt8, UInt16, UInt32);
    type_test!(test_u16, u16, UInt16, UInt32, UInt64);
    type_test!(test_u32, u32, UInt32, UInt64, Int8);
    type_test!(test_u64, u64, UInt64, Int8, Int16);
    type_test!(test_i8, i8, Int8, Int16, Int32);
    type_test!(test_i16, i16, Int16, Int32, Int64);
    type_test!(test_i32, i32, Int32, Int64, Bool);
    type_test!(test_i64, i64, Int64, Bool, Float32);
    type_test!(test_bool, bool, Bool, Float32, Float16);
}