Skip to main content

diskann_benchmark_runner/utils/
datatype.rs

1/*
2 * Copyright (c) Microsoft Corporation.
3 * Licensed under the MIT license.
4 */
5
6use half::f16;
7use serde::{Deserialize, Serialize};
8
9/// An enum representation for common DiskANN data types.
10///
11/// See also: [`AsDataType`].
12#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
13#[serde(rename_all = "lowercase")]
14pub enum DataType {
15    Float64,
16    Float32,
17    Float16,
18    UInt8,
19    UInt16,
20    UInt32,
21    UInt64,
22    Int8,
23    Int16,
24    Int32,
25    Int64,
26    Bool,
27}
28
29impl DataType {
30    /// Return the string representation of the enum.
31    ///
32    /// This is more efficient than using `serde` directly.
33    pub const fn as_str(self) -> &'static str {
34        match self {
35            Self::Float64 => "float64",
36            Self::Float32 => "float32",
37            Self::Float16 => "float16",
38            Self::UInt8 => "uint8",
39            Self::UInt16 => "uint16",
40            Self::UInt32 => "uint32",
41            Self::UInt64 => "uint64",
42            Self::Int8 => "int8",
43            Self::Int16 => "int16",
44            Self::Int32 => "int32",
45            Self::Int64 => "int64",
46            Self::Bool => "bool",
47        }
48    }
49}
50
51impl std::fmt::Display for DataType {
52    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
53        write!(f, "{}", self.as_str())
54    }
55}
56
57/// Associate a primitive type `T` with a [`DataType`] enum variant.
58pub trait AsDataType: 'static {
59    /// The [`DataType`] this type is associated with.
60    const DATA_TYPE: DataType;
61
62    /// Return `true` only if `data_type == Self::DATA_TYPE`.
63    fn is_match(data_type: DataType) -> bool {
64        data_type == Self::DATA_TYPE
65    }
66
67    /// Return a [`std::fmt::Display`] compatible struct describing the match with `data_type`.
68    /// ```
69    /// use diskann_benchmark_runner::utils::datatype::{DataType, AsDataType};
70    ///
71    /// // Matched data type.
72    /// let desc = f32::describe(DataType::Float32);
73    /// assert!(desc.is_match());
74    /// assert_eq!(desc.to_string(), "successful match");
75    ///
76    /// // Mismatched data type.
77    /// let desc = f32::describe(DataType::Float16);
78    /// assert!(!desc.is_match());
79    /// assert_eq!(desc.to_string(), "expected \"float32\" but found \"float16\"");
80    /// ```
81    fn describe(data_type: DataType) -> Describe {
82        if data_type == Self::DATA_TYPE {
83            Describe(DescribeInner::Match)
84        } else {
85            Describe(DescribeInner::Mismatch {
86                expected: Self::DATA_TYPE,
87                got: data_type,
88            })
89        }
90    }
91}
92
93/// A [`std::fmt::Display`] compatible result for [`AsDataType::describe`].
94#[derive(Debug, Clone, Copy)]
95pub struct Describe(DescribeInner);
96
97impl Describe {
98    /// Return `true` if the data type match was successful.
99    pub fn is_match(&self) -> bool {
100        matches!(self.0, DescribeInner::Match)
101    }
102}
103
104impl std::fmt::Display for Describe {
105    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
106        self.0.fmt(f)
107    }
108}
109
110#[derive(Debug, Clone, Copy)]
111enum DescribeInner {
112    Match,
113    Mismatch { expected: DataType, got: DataType },
114}
115
116impl std::fmt::Display for DescribeInner {
117    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
118        match self {
119            Self::Match => write!(f, "successful match"),
120            Self::Mismatch { expected, got } => {
121                write!(f, "expected \"{}\" but found \"{}\"", expected, got)
122            }
123        }
124    }
125}
126
127macro_rules! as_data_type {
128    ($type:ty, $var:ident) => {
129        impl AsDataType for $type {
130            const DATA_TYPE: DataType = DataType::$var;
131        }
132    };
133}
134
135as_data_type!(f64, Float64);
136as_data_type!(f32, Float32);
137as_data_type!(f16, Float16);
138as_data_type!(u8, UInt8);
139as_data_type!(u16, UInt16);
140as_data_type!(u32, UInt32);
141as_data_type!(u64, UInt64);
142as_data_type!(i8, Int8);
143as_data_type!(i16, Int16);
144as_data_type!(i32, Int32);
145as_data_type!(i64, Int64);
146as_data_type!(bool, Bool);
147
148///////////
149// Tests //
150///////////
151
#[cfg(test)]
mod tests {
    use super::*;

    /// `as_str`, `Display` and the `serde` JSON spelling must agree for every variant.
    #[test]
    fn test_as_str() {
        let test = |x: DataType| {
            assert_eq!(x.to_string(), x.as_str());
            assert_eq!(
                x.as_str(),
                // JSON encodes the variant as a quoted string; strip the quotes to compare.
                serde_json::to_string(&x).unwrap().trim_matches('"')
            );
        };

        test(DataType::Float64);
        test(DataType::Float32);
        test(DataType::Float16);
        test(DataType::UInt8);
        test(DataType::UInt16);
        test(DataType::UInt32);
        test(DataType::UInt64);
        test(DataType::Int8);
        test(DataType::Int16);
        test(DataType::Int32);
        test(DataType::Int64);
        test(DataType::Bool);
    }

    /// Check that the variant associated with `T` is spelled `typename`.
    fn test_description<T>(typename: &str)
    where
        T: AsDataType,
    {
        assert_eq!(T::DATA_TYPE.as_str(), typename);
    }

    /// Check that `datatype` is rejected by `T` and that the mismatch message names both
    /// the expected `typename` and the rejected `datatype`.
    fn test_dispatch_fail<T>(datatype: DataType, typename: &str)
    where
        T: AsDataType,
    {
        assert!(!T::is_match(datatype));

        let description = T::describe(datatype);
        assert!(!description.is_match());
        assert_eq!(
            description.to_string(),
            format!("expected \"{}\" but found \"{}\"", typename, datatype)
        );
    }

    /// Check that `datatype` is accepted by `T` and described as a successful match.
    fn test_dispatch_success<T>(datatype: DataType)
    where
        T: AsDataType,
    {
        assert!(T::is_match(datatype));

        let description = T::describe(datatype);
        assert!(description.is_match());
        assert_eq!(description.to_string(), "successful match");
    }

    // Generate one `#[test]` per primitive: `$var` is the variant `$T` must map to, and
    // every variant listed in `$fails` must be rejected by `$T`.
    macro_rules! type_test {
        ($test:ident, $T:ty, $var:ident, $($fails:ident),* $(,)?) => {
            #[test]
            fn $test() {
                // The expected spelling is the lowercase variant name.
                let typename = stringify!($var).to_lowercase();

                test_description::<$T>(&typename);
                test_dispatch_success::<$T>(DataType::$var);
                $(test_dispatch_fail::<$T>(DataType::$fails, &typename);)*
            }
        }
    }

    type_test!(test_f64, f64, Float64, Float16, UInt8);
    type_test!(test_f32, f32, Float32, Float16, UInt8);
    type_test!(test_f16, f16, Float16, UInt8, UInt16);
    type_test!(test_u8, u8, UInt8, UInt16, UInt32);
    type_test!(test_u16, u16, UInt16, UInt32, UInt64);
    type_test!(test_u32, u32, UInt32, UInt64, Int8);
    type_test!(test_u64, u64, UInt64, Int8, Int16);
    type_test!(test_i8, i8, Int8, Int16, Int32);
    type_test!(test_i16, i16, Int16, Int32, Int64);
    type_test!(test_i32, i32, Int32, Int64, Bool);
    type_test!(test_i64, i64, Int64, Bool, Float32);
    type_test!(test_bool, bool, Bool, Float32, Float16);
}