Skip to main content

diskann_benchmark_runner/dispatcher/
api.rs

1/*
2 * Copyright (c) Microsoft Corporation.
3 * Licensed under the MIT license.
4 */
5
6use std::fmt::{self, Display, Formatter};
7
/// The score reported by a [`DispatchRule`] when a match succeeds.
///
/// A lower numerical value indicates a better match for purposes of overload resolution.
#[derive(Debug, Clone, Copy, PartialEq, PartialOrd)]
pub struct MatchScore(pub u32);

impl Display for MatchScore {
    /// Render the score as `success (<value>)`.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        f.write_fmt(format_args!("success ({})", self.0))
    }
}
19
/// The score reported by a [`DispatchRule`] when a match fails.
///
/// A lower numerical value indicates a better match, which can help when compiling a
/// list of considered and rejected candidates.
#[derive(Debug, Clone, Copy, PartialEq, PartialOrd)]
pub struct FailureScore(pub u32);

impl Display for FailureScore {
    /// Render the score as `fail (<value>)`.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        f.write_fmt(format_args!("fail ({})", self.0))
    }
}
32
33/// A version of [`FailureScore`] that contains the score as well as the reason for the
34/// failure.
35pub struct TaggedFailureScore<'a> {
36    pub(crate) score: u32,
37    pub(crate) why: Box<dyn std::fmt::Display + 'a>,
38}
39
40impl TaggedFailureScore<'_> {
41    /// Return the failure score for `Self`.
42    pub fn score(&self) -> FailureScore {
43        FailureScore(self.score)
44    }
45}
46
47impl fmt::Debug for TaggedFailureScore<'_> {
48    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
49        f.debug_struct("TaggedFailureScore")
50            .field("score", &self.score)
51            .field("why", &self.why.to_string())
52            .finish()
53    }
54}
55
56impl Display for TaggedFailureScore<'_> {
57    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
58        write!(f, "{}", self.why)
59    }
60}
61
62/// The primary trait for conducting dispatch matches from the type `From`.
63pub trait DispatchRule<From>: Sized {
64    /// Errors that can occur during `convert`.
65    type Error: std::fmt::Debug + std::fmt::Display + 'static;
66
67    /// Attempt to match the value `From` to the type represented by `Self`.
68    ///
69    /// If `from` has a compatible value, return `Ok(score)` where `score` attempts to
70    /// describe how good of a fit `from` is to allow for overload resolution.
71    ///
72    /// If `from` is incompatible, return `Err(score)`
73    fn try_match(from: &From) -> Result<MatchScore, FailureScore>;
74
75    /// Perform the actual conversion.
76    ///
77    /// It is expected that this method will only be called when `try_match(&from)` returns
78    /// success. An error type can be returned due to either:
79    ///
80    /// 1. `try_match` returning `Ok()` erroneously due to an incorrect implementation.
81    /// 2. `from` originally looked like a match, but broke some invariant of `Self`s
82    ///    constructor.
83    fn convert(from: From) -> Result<Self, Self::Error>;
84
85    //////////////////////
86    // Provided Methods //
87    //////////////////////
88
89    /// Write a description of the dispatch rule and outcome to the formatter.
90    ///
91    /// If `from.is_none()`, then a description of `Self` should be provided.
92    ///
93    /// Otherwise, the implementation should provide a description of the dispatching logic
94    /// (success or failure) for the argument.
95    fn description(f: &mut Formatter<'_>, _from: Option<&From>) -> fmt::Result {
96        write!(f, "<no description>")
97    }
98
99    /// The equivalent of `try_match` but returns a reason for a failed score.
100    ///
101    /// This allows emission of diagnostics for method mismatches.
102    ///
103    /// The provided implementation of this method calls [`Self::description(_, Some(from))`].
104    fn try_match_verbose<'a>(from: &'a From) -> Result<MatchScore, TaggedFailureScore<'a>>
105    where
106        Self: 'a,
107    {
108        match Self::try_match(from) {
109            Ok(score) => Ok(score),
110            Err(score) => Err(TaggedFailureScore {
111                score: score.0,
112                why: Box::new(Why::<From, Self>::new(from)),
113            }),
114        }
115    }
116}
117
118/// A helper struct to help dscribe the reason for a match failure.
119#[derive(Debug, Clone, Copy)]
120pub struct Why<'a, From, To> {
121    from: &'a From,
122    _to: std::marker::PhantomData<To>,
123}
124
125impl<'a, From, To> Why<'a, From, To> {
126    pub fn new(from: &'a From) -> Self {
127        Self {
128            from,
129            _to: std::marker::PhantomData,
130        }
131    }
132}
133
134impl<From, To> std::fmt::Display for Why<'_, From, To>
135where
136    To: DispatchRule<From>,
137{
138    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
139        To::description(f, Some(self.from))
140    }
141}
142
143/// A helper struct to retrieve the empty description from a [`DispatchRule`].
144#[derive(Debug, Clone, Copy)]
145pub struct Description<From, To> {
146    _from: std::marker::PhantomData<From>,
147    _to: std::marker::PhantomData<To>,
148}
149
150impl<From, To> Description<From, To> {
151    pub fn new() -> Self {
152        Self {
153            _from: std::marker::PhantomData,
154            _to: std::marker::PhantomData,
155        }
156    }
157}
158
159impl<From, To> Default for Description<From, To> {
160    fn default() -> Self {
161        Self::new()
162    }
163}
164
165impl<From, To> std::fmt::Display for Description<From, To>
166where
167    To: DispatchRule<From>,
168{
169    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
170        To::description(f, None)
171    }
172}
173
174/////////////////////////////
175// Blanket Implementations //
176/////////////////////////////
177
178/// A score assigned to implicit matches either via the identity transformation or through
179/// mut-ref to ref conversion.
180pub const IMPLICIT_MATCH_SCORE: MatchScore = MatchScore(100000);
181
182impl<T: Sized> DispatchRule<T> for T {
183    type Error = std::convert::Infallible;
184
185    fn try_match(_from: &T) -> Result<MatchScore, FailureScore> {
186        Ok(IMPLICIT_MATCH_SCORE)
187    }
188
189    fn convert(from: T) -> Result<T, Self::Error> {
190        Ok(from)
191    }
192
193    fn description(f: &mut Formatter<'_>, from: Option<&T>) -> fmt::Result {
194        match from {
195            None => write!(f, "{}", std::any::type_name::<T>()),
196            Some(_) => write!(f, "identity match"),
197        }
198    }
199}
200
201// Allow mutable references to be forwarded to const-references.
202impl<'a, T: Sized> DispatchRule<&'a mut T> for &'a T {
203    type Error = std::convert::Infallible;
204
205    fn try_match(_from: &&'a mut T) -> Result<MatchScore, FailureScore> {
206        Ok(IMPLICIT_MATCH_SCORE)
207    }
208
209    fn convert(from: &'a mut T) -> Result<&'a T, Self::Error> {
210        Ok(from)
211    }
212
213    fn description(f: &mut Formatter<'_>, from: Option<&&'a mut T>) -> fmt::Result {
214        match from {
215            None => write!(f, "&{}", std::any::type_name::<T>()),
216            Some(_) => write!(f, "identity match"),
217        }
218    }
219}
220
/// # Lifetime Mapping
///
/// The types in signatures for dispatches need to be `'static` due to Rust.
/// However, it is nice to allow objects with lifetimes to cross the dispatcher boundary.
///
/// The `Map` trait facilitates this by allowing `'static` types to have an optional
/// lifetime attached as a generic associated type.
///
/// This associated type is what is actually given to dispatcher methods.
///
/// ## Example
///
/// To pass a `Vec` across a dispatcher boundary, we can use the [`Type`] helper:
///
/// ```
/// use diskann_benchmark_runner::dispatcher::{Dispatcher1, Type};
///
/// let mut d = Dispatcher1::<&'static str, Type<Vec<f32>>>::new();
/// d.register::<_, Type<Vec<f32>>>("method",  |_: Vec<f32>| "called");
/// assert_eq!(d.call(vec![1.0]), Some("called"));
/// ```
///
/// This is a bit tedious to write every time, so instead types can implement [`Map`] for
/// themselves:
///
/// ```
/// use diskann_benchmark_runner::{self_map, dispatcher::{Dispatcher1}};
///
/// struct MyNum(f32);
/// self_map!(MyNum);
///
/// // Now, `MyNum` can be used directly in dispatcher signatures.
/// let mut d = Dispatcher1::<f32, MyNum>::new();
/// d.register::<_, MyNum>("method", |n: MyNum| n.0);
/// assert_eq!(d.call(MyNum(0.0)), Some(0.0));
/// ```
///
/// ## See Also:
///
/// * [`Ref`]: Mapping References
/// * [`MutRef`]: Mapping Mutable References
/// * [`Type`]: Mapper for generic types
/// * [`crate::self_map!`]: Allow types to represent themselves in dispatcher signatures.
///
pub trait Map: 'static {
    /// The actual type provided to the dispatcher, with an optional additional lifetime.
    type Type<'a>;
}
269
/// Allow references to cross dispatcher boundaries as shown in the following example:
///
/// ```
/// use diskann_benchmark_runner::dispatcher::{Dispatcher1, Ref};
///
/// let mut d = Dispatcher1::<*const f32, Ref<[f32]>>::new();
/// d.register::<_, Ref<[f32]>>("method", |data: &[f32]| data.as_ptr());
///
/// let v = vec![1.0, 2.0];
/// assert_eq!(d.call(&v), Some(v.as_ptr()));
/// ```
///
/// `Ref` is a marker that only carries `T` in a `PhantomData`; it stores no value.
pub struct Ref<T: ?Sized + 'static>(std::marker::PhantomData<T>);

impl<T: ?Sized> Map for Ref<T> {
    // `Ref<T>` stands for a shared reference to `T` with the attached lifetime.
    type Type<'a> = &'a T;
}
286
/// Allow mutable references to cross dispatcher boundaries as shown below.
///
/// ```
/// use diskann_benchmark_runner::dispatcher::{Dispatcher1, MutRef};
///
/// let mut d = Dispatcher1::<(), MutRef<Vec<f32>>>::new();
/// d.register::<_, MutRef<Vec<f32>>>("method", |v: &mut Vec<f32>| v.push(0.0));
///
/// let mut v = Vec::new();
/// d.call(&mut v).unwrap();
/// assert_eq!(&v, &[0.0]);
/// ```
///
/// `MutRef` is a marker that only carries `T` in a `PhantomData`; it stores no value.
pub struct MutRef<T: ?Sized + 'static>(std::marker::PhantomData<T>);
impl<T: ?Sized> Map for MutRef<T> {
    // `MutRef<T>` stands for a mutable reference to `T` with the attached lifetime.
    type Type<'a> = &'a mut T;
}
303
/// Allow owned values of type `T` to cross dispatcher boundaries.
///
/// `Type` is a marker that only carries `T` in a `PhantomData`; its [`Map`]
/// implementation maps to `T` itself and ignores the attached lifetime.
pub struct Type<T: 'static>(std::marker::PhantomData<T>);
impl<T> Map for Type<T> {
    // The mapped type is `T` regardless of `'a`.
    type Type<'a> = T;
}
308
/// Implement `$crate::dispatcher::Map` for the given type so that it maps to itself.
///
/// All tokens passed to the macro are used verbatim both as the implementing type and
/// as the associated `Type<'a>`; the lifetime `'a` is not otherwise used by the
/// generated implementation.
#[macro_export]
macro_rules! self_map {
    ($($type:tt)*) => {
        impl $crate::dispatcher::Map for $($type)* {
            type Type<'a> = $($type)*;
        }
    }
}
317
// Self-mapping `Map` implementations for `bool`, the listed integer and
// floating-point primitives, and `String`.
self_map!(bool);
self_map!(usize);
self_map!(u8);
self_map!(u16);
self_map!(u32);
self_map!(u64);
self_map!(u128);
self_map!(i8);
self_map!(i16);
self_map!(i32);
self_map!(i64);
self_map!(i128);
self_map!(String);
self_map!(f32);
self_map!(f64);
333
/// Reasons for a method call mismatch.
///
/// The name of the associated method can be queried using `self.method()` and reasons
/// are obtained in `self.mismatches()`.
pub struct ArgumentMismatch<'a, const N: usize> {
    /// Name of the method that failed to match.
    pub(crate) method: &'a str,
    /// One slot per argument: `None` for a match, `Some(reason)` for a mismatch.
    pub(crate) mismatches: [Option<Box<dyn std::fmt::Display + 'a>>; N],
}

impl<'a, const N: usize> ArgumentMismatch<'a, N> {
    /// Return the name of the associated method.
    ///
    /// The name is handed back with the full lifetime `'a` of the stored string slice
    /// rather than the lifetime of the borrow of `self`, so it can be kept after the
    /// `ArgumentMismatch` itself has been dropped.
    pub fn method(&self) -> &'a str {
        self.method
    }

    /// Return the per-argument reasons for method match failure.
    ///
    /// The returned array contains one entry per argument. An entry is `None` if that
    /// argument matched the input value.
    ///
    /// If the argument did not match the input value, then the corresponding
    /// [`std::fmt::Display`] object can be used to retrieve the reason.
    pub fn mismatches(&self) -> &[Option<Box<dyn std::fmt::Display + 'a>>; N] {
        &self.mismatches
    }
}
360
/// A displayable signature for an argument type.
///
/// Wraps a plain function pointer that writes the signature text to a formatter.
pub struct Signature(pub(crate) fn(&mut Formatter<'_>) -> std::fmt::Result);

impl std::fmt::Display for Signature {
    /// Invoke the wrapped function with the supplied formatter.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let render = self.0;
        render(f)
    }
}
369
370///////////
371// Tests //
372///////////
373
#[cfg(test)]
mod tests {
    use super::*;

    /// `MatchScore` orders by its numeric value and displays as `success (<value>)`.
    #[test]
    fn test_match_score() {
        let x = MatchScore(10);
        let y = MatchScore(20);
        assert!(x < y);
        assert!(x <= y);
        assert!(x <= x);
        assert!(x == x);
        assert!(x != y);

        assert!(y == y);
        assert!(y != x);
        assert!(y > x);
        assert!(y >= x);

        assert_eq!(x.to_string(), "success (10)");
    }

    /// `FailureScore` orders by its numeric value and displays as `fail (<value>)`.
    #[test]
    fn test_fail_score() {
        let x = FailureScore(10);
        let y = FailureScore(20);
        assert!(x < y);
        assert!(x <= y);
        assert!(x <= x);
        assert!(x == x);
        assert!(x != y);

        assert!(y == y);
        assert!(y != x);
        assert!(y > x);
        assert!(y >= x);

        assert_eq!(x.to_string(), "fail (10)");
    }

    /// `TaggedFailureScore` exposes its score, displays its reason, and debug-prints both.
    #[test]
    fn test_tagged_failure() {
        let tagged = TaggedFailureScore {
            score: 10,
            why: Box::new(20),
        };

        assert_eq!(tagged.score(), FailureScore(10));

        // Formatted goes through the inner formatter.
        assert_eq!(tagged.to_string(), "20");

        assert_eq!(
            format!("{:?}", tagged),
            "TaggedFailureScore { score: 10, why: \"20\" }"
        );
    }

    /// Source type for the test rule: only variant `A` is accepted by `TestType`.
    enum TestEnum {
        A,
        B,
    }

    /// Target type for the test rule.
    struct TestType;

    impl DispatchRule<TestEnum> for TestType {
        type Error = std::convert::Infallible;
        fn try_match(x: &TestEnum) -> Result<MatchScore, FailureScore> {
            match x {
                TestEnum::A => Ok(MatchScore(10)),
                TestEnum::B => Err(FailureScore(20)),
            }
        }

        fn convert(x: TestEnum) -> Result<Self, Self::Error> {
            // Converting a value that `try_match` rejects is a caller bug.
            assert!(matches!(x, TestEnum::A));
            Ok(TestType)
        }

        fn description(f: &mut Formatter<'_>, from: Option<&TestEnum>) -> fmt::Result {
            match from {
                None => write!(f, "TestEnum::A"),
                Some(value) => match value {
                    TestEnum::A => write!(f, "success"),
                    // The rule only accepts `A`, so a `B` input failed because `A` was
                    // expected.
                    TestEnum::B => write!(f, "expected TestEnum::A"),
                },
            }
        }
    }

    /// `Description`, `Why` and `try_match_verbose` all route through `description`.
    #[test]
    fn test_dispatch_helpers() {
        let desc = Description::<TestEnum, TestType>::default().to_string();
        assert_eq!(desc, "TestEnum::A");

        let a = TestEnum::A;
        let why = Why::<_, TestType>::new(&a).to_string();
        assert_eq!(why, "success");

        let b = TestEnum::B;
        let why = Why::<_, TestType>::new(&b).to_string();
        assert_eq!(why, "expected TestEnum::A");

        let result = TestType::try_match_verbose(&a).unwrap();
        assert_eq!(result, MatchScore(10));

        let result = TestType::try_match_verbose(&b).unwrap_err();
        assert_eq!(result.score(), FailureScore(20));
        assert_eq!(result.to_string(), "expected TestEnum::A");

        TestType::convert(TestEnum::A).unwrap();
    }

    /// The blanket identity and `&mut T` -> `&T` rules always match implicitly.
    #[test]
    fn test_implicit_conversions() {
        // Identity
        let x = f32::try_match(&0.0f32).unwrap();
        assert_eq!(x, IMPLICIT_MATCH_SCORE);

        let x = f32::convert(0.0f32).unwrap();
        assert_eq!(x, 0.0f32);

        // Mutable reference to shared reference
        let x = <&f32>::try_match(&&mut 0.0f32).unwrap();
        assert_eq!(x, IMPLICIT_MATCH_SCORE);

        let mut x: f32 = 10.0;
        let x = <&f32>::convert(&mut x).unwrap();
        assert_eq!(*x, 10.0);

        assert_eq!(Description::<f32, f32>::new().to_string(), "f32");
        assert_eq!(Why::<f32, f32>::new(&0.0f32).to_string(), "identity match");

        assert_eq!(Description::<&mut f32, &f32>::new().to_string(), "&f32");
        assert_eq!(
            Why::<&mut f32, &f32>::new(&&mut 0.0f32).to_string(),
            "identity match"
        );
    }

    /// `TestType::convert` asserts on its input, so a rejected variant panics.
    #[test]
    #[should_panic]
    fn convert_panics() {
        let _ = TestType::convert(TestEnum::B);
    }
}