Skip to main content

diskann_benchmark_runner/dispatcher/
api.rs

1/*
2 * Copyright (c) Microsoft Corporation.
3 * Licensed under the MIT license.
4 */
5
6use std::fmt::{self, Display, Formatter};
7
/// Score carried in the `Ok` variant of [`DispatchRule::try_match`].
///
/// Scores compare by their wrapped integer; a numerically lower score is a better match
/// for purposes of overload resolution.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub struct MatchScore(pub u32);

impl Display for MatchScore {
    /// Renders as `success (<score>)`.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        let Self(value) = self;
        f.write_fmt(format_args!("success ({})", value))
    }
}
19
/// Score carried in the `Err` variant of [`DispatchRule::try_match`], i.e. the result of a
/// failed match.
///
/// Scores compare by their wrapped integer; a numerically lower score indicates a closer
/// (better) candidate, which can help when compiling a list of considered and rejected
/// candidates.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub struct FailureScore(pub u32);

impl Display for FailureScore {
    /// Renders as `fail (<score>)`.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        let Self(value) = self;
        f.write_fmt(format_args!("fail ({})", value))
    }
}
32
33/// A version of [`FailureScore`] that contains the score as well as the reason for the
34/// failure.
35pub struct TaggedFailureScore<'a> {
36    pub(crate) score: u32,
37    pub(crate) why: Box<dyn std::fmt::Display + 'a>,
38}
39
40impl TaggedFailureScore<'_> {
41    /// Return the failure score for `Self`.
42    pub fn score(&self) -> FailureScore {
43        FailureScore(self.score)
44    }
45}
46
47impl fmt::Debug for TaggedFailureScore<'_> {
48    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
49        f.debug_struct("TaggedFailureScore")
50            .field("score", &self.score)
51            .field("why", &self.why.to_string())
52            .finish()
53    }
54}
55
56impl Display for TaggedFailureScore<'_> {
57    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
58        write!(f, "{}", self.why)
59    }
60}
61
62/// The primary trait for conducting dispatch matches from the type `From`.
63pub trait DispatchRule<From>: Sized {
64    /// Errors that can occur during `convert`.
65    type Error: std::fmt::Debug + std::fmt::Display + 'static;
66
67    /// Attempt to match the value `From` to the type represented by `Self`.
68    ///
69    /// If `from` has a compatible value, return `Ok(score)` where `score` attempts to
70    /// describe how good of a fit `from` is to allow for overload resolution.
71    ///
72    /// If `from` is incompatible, return `Err(score)`
73    fn try_match(from: &From) -> Result<MatchScore, FailureScore>;
74
75    /// Perform the actual conversion.
76    ///
77    /// It is expected that this method will only be called when `try_match(&from)` returns
78    /// success. An error type can be returned due to either:
79    ///
80    /// 1. `try_match` returning `Ok()` erroneously due to an incorrect implementation.
81    /// 2. `from` originally looked like a match, but broke some invariant of `Self`s
82    ///    constructor.
83    fn convert(from: From) -> Result<Self, Self::Error>;
84
85    //////////////////////
86    // Provided Methods //
87    //////////////////////
88
89    /// Write a description of the dispatch rule and outcome to the formatter.
90    ///
91    /// If `from.is_none()`, then a description of `Self` should be provided.
92    ///
93    /// Otherwise, the implementation should provide a description of the dispatching logic
94    /// (success or failure) for the argument.
95    fn description(f: &mut Formatter<'_>, _from: Option<&From>) -> fmt::Result {
96        write!(f, "<no description>")
97    }
98
99    /// The equivalent of `try_match` but returns a reason for a failed score.
100    ///
101    /// This allows emission of diagnostics for method mismatches.
102    ///
103    /// The provided implementation of this method calls [`Self::description(_, Some(from))`].
104    fn try_match_verbose<'a>(from: &'a From) -> Result<MatchScore, TaggedFailureScore<'a>>
105    where
106        Self: 'a,
107    {
108        match Self::try_match(from) {
109            Ok(score) => Ok(score),
110            Err(score) => Err(TaggedFailureScore {
111                score: score.0,
112                why: Box::new(Why::<From, Self>::new(from)),
113            }),
114        }
115    }
116}
117
/// A helper struct to help describe the reason for a match failure.
///
/// It borrows the dispatch argument (`From`) and records the target type (`To`) purely at
/// the type level; no value of `To` is stored.
pub struct Why<'a, From, To> {
    from: &'a From,
    _to: std::marker::PhantomData<To>,
}

// `Clone`, `Copy` and `Debug` are written by hand rather than derived: a derive would add
// `From: Clone + Copy` and `To: Clone + Copy + Debug` bounds, even though only a shared
// reference and a `PhantomData` marker are stored.
impl<From, To> Clone for Why<'_, From, To> {
    fn clone(&self) -> Self {
        *self
    }
}

impl<From, To> Copy for Why<'_, From, To> {}

impl<From, To> fmt::Debug for Why<'_, From, To>
where
    From: fmt::Debug,
{
    /// Produces the same text a derived `Debug` would, without requiring `To: Debug`.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        f.debug_struct("Why")
            .field("from", &self.from)
            .field("_to", &self._to)
            .finish()
    }
}

impl<'a, From, To> Why<'a, From, To> {
    /// Create a `Why` that borrows `from` for the lifetime `'a`.
    pub fn new(from: &'a From) -> Self {
        Self {
            from,
            _to: std::marker::PhantomData,
        }
    }
}
133
134impl<From, To> std::fmt::Display for Why<'_, From, To>
135where
136    To: DispatchRule<From>,
137{
138    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
139        To::description(f, Some(self.from))
140    }
141}
142
/// A helper struct to retrieve the empty description from a [`DispatchRule`].
///
/// Both type parameters are carried only as `PhantomData` markers, so the struct stores
/// no data.
pub struct Description<From, To> {
    _from: std::marker::PhantomData<From>,
    _to: std::marker::PhantomData<To>,
}

// `Clone`, `Copy` and `Debug` are written by hand (like `Default` below) rather than
// derived: a derive would require `From` and `To` to implement each trait, even though no
// value of either type is stored.
impl<From, To> Clone for Description<From, To> {
    fn clone(&self) -> Self {
        *self
    }
}

impl<From, To> Copy for Description<From, To> {}

impl<From, To> fmt::Debug for Description<From, To> {
    /// Produces the same text a derived `Debug` would, without bounds on `From`/`To`.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        f.debug_struct("Description")
            .field("_from", &self._from)
            .field("_to", &self._to)
            .finish()
    }
}

impl<From, To> Description<From, To> {
    /// Create a new `Description` marker for the `From` -> `To` dispatch pair.
    pub fn new() -> Self {
        Self {
            _from: std::marker::PhantomData,
            _to: std::marker::PhantomData,
        }
    }
}

impl<From, To> Default for Description<From, To> {
    /// Equivalent to [`Description::new`].
    fn default() -> Self {
        Self::new()
    }
}
164
165impl<From, To> std::fmt::Display for Description<From, To>
166where
167    To: DispatchRule<From>,
168{
169    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
170        To::description(f, None)
171    }
172}
173
174/////////////////////////////
175// Blanket Implementations //
176/////////////////////////////
177
178/// A score assigned to implicit matches either via the identity transformation or through
179/// mut-ref to ref conversion.
180pub const IMPLICIT_MATCH_SCORE: MatchScore = MatchScore(100000);
181
182impl<T: Sized> DispatchRule<T> for T {
183    type Error = std::convert::Infallible;
184
185    fn try_match(_from: &T) -> Result<MatchScore, FailureScore> {
186        Ok(IMPLICIT_MATCH_SCORE)
187    }
188
189    fn convert(from: T) -> Result<T, Self::Error> {
190        Ok(from)
191    }
192
193    fn description(f: &mut Formatter<'_>, from: Option<&T>) -> fmt::Result {
194        match from {
195            None => write!(f, "{}", std::any::type_name::<T>()),
196            Some(_) => write!(f, "identity match"),
197        }
198    }
199}
200
201// Allow mutable references to be forwarded to const-references.
202impl<'a, T: Sized> DispatchRule<&'a mut T> for &'a T {
203    type Error = std::convert::Infallible;
204
205    fn try_match(_from: &&'a mut T) -> Result<MatchScore, FailureScore> {
206        Ok(IMPLICIT_MATCH_SCORE)
207    }
208
209    fn convert(from: &'a mut T) -> Result<&'a T, Self::Error> {
210        Ok(from)
211    }
212
213    fn description(f: &mut Formatter<'_>, from: Option<&&'a mut T>) -> fmt::Result {
214        match from {
215            None => write!(f, "&{}", std::any::type_name::<T>()),
216            Some(_) => write!(f, "identity match"),
217        }
218    }
219}
220
221///////////
222// Tests //
223///////////
224
#[cfg(test)]
mod tests {
    use super::*;

    // `MatchScore` compares by its wrapped value and displays as `success (N)`.
    #[test]
    fn test_match_score() {
        let x = MatchScore(10);
        let y = MatchScore(20);
        assert!(x < y);
        assert!(x <= y);
        assert!(x <= x);
        assert!(x == x);
        assert!(x != y);

        assert!(y == y);
        assert!(y != x);
        assert!(y > x);
        assert!(y >= x);

        assert_eq!(x.to_string(), "success (10)");
    }

    // `FailureScore` compares by its wrapped value and displays as `fail (N)`.
    #[test]
    fn test_fail_score() {
        let x = FailureScore(10);
        let y = FailureScore(20);
        assert!(x < y);
        assert!(x <= y);
        assert!(x <= x);
        assert!(x == x);
        assert!(x != y);

        assert!(y == y);
        assert!(y != x);
        assert!(y > x);
        assert!(y >= x);

        assert_eq!(x.to_string(), "fail (10)");
    }

    // `TaggedFailureScore` exposes its score, displays only the reason, and debug-prints
    // the reason as a quoted string.
    #[test]
    fn test_tagged_failure() {
        let tagged = TaggedFailureScore {
            score: 10,
            why: Box::new(20),
        };

        assert_eq!(tagged.score(), FailureScore(10));

        // Formatted goes through the inner formatter.
        assert_eq!(tagged.to_string(), "20");

        assert_eq!(
            format!("{:?}", tagged),
            "TaggedFailureScore { score: 10, why: \"20\" }"
        );
    }

    // Dispatch argument used by the fixture rule below.
    enum TestEnum {
        A,
        B,
    }

    // Dispatch target used by the fixture rule below.
    struct TestType;

    // Fixture rule: `TestEnum::A` matches with score 10, `TestEnum::B` fails with score 20.
    impl DispatchRule<TestEnum> for TestType {
        type Error = std::convert::Infallible;
        fn try_match(x: &TestEnum) -> Result<MatchScore, FailureScore> {
            match x {
                TestEnum::A => Ok(MatchScore(10)),
                TestEnum::B => Err(FailureScore(20)),
            }
        }

        fn convert(x: TestEnum) -> Result<Self, Self::Error> {
            // Panics (rather than returning `Err`) on anything but `A`; `convert_panics`
            // relies on this.
            assert!(matches!(x, TestEnum::A));
            Ok(TestType)
        }

        fn description(f: &mut Formatter<'_>, from: Option<&TestEnum>) -> fmt::Result {
            match from {
                None => write!(f, "TestEnum::A"),
                Some(value) => match value {
                    TestEnum::A => write!(f, "success"),
                    // NOTE(review): this is the message for a rejected `B`; the wording
                    // may have been intended as "expected TestEnum::A" -- confirm. The
                    // assertions in `test_dispatch_helpers` pin the current text.
                    TestEnum::B => write!(f, "expected TestEnum::B"),
                },
            }
        }
    }

    // Exercises `Description`, `Why`, and the provided `try_match_verbose` through the
    // fixture rule.
    #[test]
    fn test_dispatch_helpers() {
        let desc = Description::<TestEnum, TestType>::default().to_string();
        assert_eq!(desc, "TestEnum::A");

        let a = TestEnum::A;
        let why = Why::<_, TestType>::new(&a).to_string();
        assert_eq!(why, "success");

        let b = TestEnum::B;
        let why = Why::<_, TestType>::new(&b).to_string();
        assert_eq!(why, "expected TestEnum::B");

        let result = TestType::try_match_verbose(&a).unwrap();
        assert_eq!(result, MatchScore(10));

        // On failure the tagged score keeps the numeric score and displays the
        // description for the rejected argument.
        let result = TestType::try_match_verbose(&b).unwrap_err();
        assert_eq!(result.score(), FailureScore(20));
        assert_eq!(result.to_string(), "expected TestEnum::B");

        TestType::convert(TestEnum::A).unwrap();
    }

    // Covers the two blanket implementations: `T -> T` and `&mut T -> &T`.
    #[test]
    fn test_implicit_conversions() {
        // Identity
        let x = f32::try_match(&0.0f32).unwrap();
        assert_eq!(x, IMPLICIT_MATCH_SCORE);

        let x = f32::convert(0.0f32).unwrap();
        assert_eq!(x, 0.0f32);

        // Mutable reference to shared reference
        let x = <&f32>::try_match(&&mut 0.0f32).unwrap();
        assert_eq!(x, IMPLICIT_MATCH_SCORE);

        let mut x: f32 = 10.0;
        let x = <&f32>::convert(&mut x).unwrap();
        assert_eq!(*x, 10.0);

        // Descriptions: the type name without an argument, "identity match" with one.
        assert_eq!(Description::<f32, f32>::new().to_string(), "f32");
        assert_eq!(Why::<f32, f32>::new(&0.0f32).to_string(), "identity match");

        assert_eq!(Description::<&mut f32, &f32>::new().to_string(), "&f32");
        assert_eq!(
            Why::<&mut f32, &f32>::new(&&mut 0.0f32).to_string(),
            "identity match"
        );
    }

    // The fixture's `convert` asserts its input is `TestEnum::A`, so `B` must panic.
    #[test]
    #[should_panic]
    fn convert_panics() {
        let _ = TestType::convert(TestEnum::B);
    }
}
369}