Skip to main content

diskann_benchmark_runner/utils/
datatype.rs

1/*
2 * Copyright (c) Microsoft Corporation.
3 * Licensed under the MIT license.
4 */
5
6use half::f16;
7use serde::{Deserialize, Serialize};
8
9use crate::dispatcher::{DispatchRule, FailureScore, MatchScore};
10
11/// An enum representation for common DiskANN data types.
12///
13/// [`DispatchRule]`s are defined for each type here and it's corresponding [`Type`].
14#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
15#[serde(rename_all = "lowercase")]
16pub enum DataType {
17    Float64,
18    Float32,
19    Float16,
20    UInt8,
21    UInt16,
22    UInt32,
23    UInt64,
24    Int8,
25    Int16,
26    Int32,
27    Int64,
28    Bool,
29}
30
31impl DataType {
32    /// Return the string representation of the enum.
33    ///
34    /// This is more efficient than using `serde` directly.
35    pub const fn as_str(self) -> &'static str {
36        match self {
37            Self::Float64 => "float64",
38            Self::Float32 => "float32",
39            Self::Float16 => "float16",
40            Self::UInt8 => "uint8",
41            Self::UInt16 => "uint16",
42            Self::UInt32 => "uint32",
43            Self::UInt64 => "uint64",
44            Self::Int8 => "int8",
45            Self::Int16 => "int16",
46            Self::Int32 => "int32",
47            Self::Int64 => "int64",
48            Self::Bool => "bool",
49        }
50    }
51}
52
53impl std::fmt::Display for DataType {
54    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
55        write!(f, "{}", self.as_str())
56    }
57}
58
/// Lifting the enum `DataType` into the Rust type domain.
///
/// `Type<T>` stores only a [`std::marker::PhantomData<T>`]; it carries `T` at the type
/// level without holding a value of it.
///
/// `Debug`, `Default`, `Clone` and `Copy` are implemented by hand rather than derived:
/// the derives would add a matching bound on `T` (e.g. `T: Default`), even though the
/// marker never contains a `T`. The manual impls are available for every `T`.
pub struct Type<T>(std::marker::PhantomData<T>);

impl<T> std::fmt::Debug for Type<T> {
    // Produces the same text a derived `Debug` would: `Type(<PhantomData debug>)`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_tuple("Type").field(&self.0).finish()
    }
}

impl<T> Default for Type<T> {
    fn default() -> Self {
        Self(std::marker::PhantomData)
    }
}

impl<T> Clone for Type<T> {
    fn clone(&self) -> Self {
        *self
    }
}

impl<T> Copy for Type<T> {}
62
63pub const MATCH_FAIL: FailureScore = FailureScore(1000);
64
// Implement `DispatchRule<DataType>` and `DispatchRule<&DataType>` for `Type<$type>`.
//
// * `$type`: the Rust type carried by the `Type` marker (e.g. `f32`).
// * `$var`: the single `DataType` variant that `Type<$type>` matches (e.g. `Float32`).
//
// The by-reference implementation forwards to the by-value one, so both always agree.
macro_rules! dispatch_rule {
    ($type:ty, $var:ident) => {
        impl DispatchRule<DataType> for Type<$type> {
            type Error = std::convert::Infallible;

            // `DataType::$var` is a match with score 0; every other variant fails with
            // `MATCH_FAIL`.
            fn try_match(from: &DataType) -> Result<MatchScore, FailureScore> {
                if matches!(from, DataType::$var) {
                    Ok(MatchScore(0))
                } else {
                    Err(MATCH_FAIL)
                }
            }

            // Panics with "invalid dispatch" when `from` is not `DataType::$var`; the
            // error type is `Infallible`, so a mismatch cannot be reported as an `Err`.
            fn convert(from: DataType) -> Result<Self, Self::Error> {
                assert!(matches!(from, DataType::$var), "invalid dispatch");
                Ok(Self::default())
            }

            // With `None`, write the name of the expected variant. With `Some`, write
            // either "successful match" or an explanation of the mismatch.
            fn description(
                f: &mut std::fmt::Formatter<'_>,
                v: Option<&DataType>,
            ) -> std::fmt::Result {
                // Same text as `stringify!($var).to_lowercase()` for the variants of
                // `DataType`, but without allocating a `String` on every call.
                let expected = DataType::$var.as_str();
                match v {
                    None => f.write_str(expected),
                    Some(found) => match <Self as DispatchRule<DataType>>::try_match(found) {
                        Ok(_) => f.write_str("successful match"),
                        Err(_) => write!(
                            f,
                            "expected \"{}\" but found {:?}",
                            expected,
                            found.as_str()
                        ),
                    },
                }
            }
        }

        impl DispatchRule<&DataType> for Type<$type> {
            type Error = std::convert::Infallible;

            fn try_match(from: &&DataType) -> Result<MatchScore, FailureScore> {
                <Self as DispatchRule<DataType>>::try_match(*from)
            }

            // `DataType` is `Copy`, so the reference is dereferenced and forwarded.
            fn convert(from: &DataType) -> Result<Self, Self::Error> {
                <Self as DispatchRule<DataType>>::convert(*from)
            }

            fn description(
                f: &mut std::fmt::Formatter<'_>,
                v: Option<&&DataType>,
            ) -> std::fmt::Result {
                <Self as DispatchRule<DataType>>::description(f, v.copied())
            }
        }
    };
}
118
119dispatch_rule!(f64, Float64);
120dispatch_rule!(f32, Float32);
121dispatch_rule!(f16, Float16);
122dispatch_rule!(u8, UInt8);
123dispatch_rule!(u16, UInt16);
124dispatch_rule!(u32, UInt32);
125dispatch_rule!(u64, UInt64);
126dispatch_rule!(i8, Int8);
127dispatch_rule!(i16, Int16);
128dispatch_rule!(i32, Int32);
129dispatch_rule!(i64, Int64);
130dispatch_rule!(bool, Bool);
131
132///////////
133// Tests //
134///////////
135
#[cfg(test)]
mod tests {
    use super::*;

    use crate::dispatcher::{Description, Why};

    /// Every variant of [`DataType`], so table-driven tests cannot silently skip one.
    const ALL_DATATYPES: [DataType; 12] = [
        DataType::Float64,
        DataType::Float32,
        DataType::Float16,
        DataType::UInt8,
        DataType::UInt16,
        DataType::UInt32,
        DataType::UInt64,
        DataType::Int8,
        DataType::Int16,
        DataType::Int32,
        DataType::Int64,
        DataType::Bool,
    ];

    /// `Display`, `as_str` and the `serde` representation must agree for every variant
    /// (including `Float64`).
    #[test]
    fn test_as_str() {
        for x in ALL_DATATYPES {
            assert_eq!(x.to_string(), x.as_str());

            // `serde_json` wraps the name in double quotes; strip them to compare.
            let json = serde_json::to_string(&x).unwrap();
            assert_eq!(x.as_str(), json.trim_matches('"'));
        }
    }

    /// The description with no value present is the expected type name.
    fn test_description<T>(typename: &str)
    where
        Type<T>: DispatchRule<DataType>,
    {
        assert_eq!(
            Description::<DataType, Type<T>>::new().to_string(),
            typename
        );
    }

    /// A mismatching `datatype` fails with `MATCH_FAIL` and explains the mismatch.
    fn test_dispatch_fail<T>(datatype: DataType, typename: &str)
    where
        Type<T>: DispatchRule<DataType>,
    {
        assert_eq!(<Type<T>>::try_match(&datatype), Err(MATCH_FAIL));
        assert_eq!(
            Why::<DataType, Type<T>>::new(&datatype).to_string(),
            format!("expected \"{}\" but found \"{}\"", typename, datatype)
        );
    }

    /// A matching `datatype` succeeds with score 0 and reports a successful match.
    fn test_dispatch_success<T>(datatype: DataType)
    where
        Type<T>: DispatchRule<DataType>,
    {
        assert_eq!(<Type<T>>::try_match(&datatype), Ok(MatchScore(0)));
        assert_eq!(
            Why::<DataType, Type<T>>::new(&datatype).to_string(),
            "successful match",
        );
    }

    // Generate one test per Rust type: `$var` is the variant that must match and
    // `$fails` are variants that must be rejected.
    macro_rules! type_test {
        ($test:ident, $T:ty, $var:ident, $($fails:ident),* $(,)?) => {
            #[test]
            fn $test() {
                let typename = stringify!($var).to_lowercase();

                test_description::<$T>(&typename);
                test_dispatch_success::<$T>(DataType::$var);
                $(test_dispatch_fail::<$T>(DataType::$fails, &typename);)*
            }
        }
    }

    type_test!(test_f64, f64, Float64, Float16, UInt8);
    type_test!(test_f32, f32, Float32, Float16, UInt8);
    type_test!(test_f16, f16, Float16, UInt8, UInt16);
    type_test!(test_u8, u8, UInt8, UInt16, UInt32);
    type_test!(test_u16, u16, UInt16, UInt32, UInt64);
    type_test!(test_u32, u32, UInt32, UInt64, Int8);
    type_test!(test_u64, u64, UInt64, Int8, Int16);
    type_test!(test_i8, i8, Int8, Int16, Int32);
    type_test!(test_i16, i16, Int16, Int32, Int64);
    type_test!(test_i32, i32, Int32, Int64, Bool);
    type_test!(test_i64, i64, Int64, Bool, Float32);
    type_test!(test_bool, bool, Bool, Float32, Float16);
}