Skip to main content

diskann_benchmark_runner/dispatcher/
dispatch.rs

1/*
2 * Copyright (c) Microsoft Corporation.
3 * Licensed under the MIT license.
4 */
5
6use std::fmt::Formatter;
7
8use super::{
9    ArgumentMismatch, DispatchRule, FailureScore, Map, MatchScore, Signature, TaggedFailureScore,
10};
11
12/// Return `Some` if all the entries in `Input` are `Ok(MatchScore)`.
13///
14/// Otherwise, return None
15fn coalesce<const N: usize>(
16    input: &[Result<MatchScore, FailureScore>; N],
17) -> Option<[MatchScore; N]> {
18    let mut output = [MatchScore(0); N];
19    for i in 0..N {
20        output[i] = match input[i] {
21            Ok(score) => score,
22            Err(_) => return None,
23        }
24    }
25    Some(output)
26}
27
28/// Return `true` if all values in `inpu` are `Ok`.
29fn all_match<const N: usize, T>(input: &[Result<MatchScore, T>; N]) -> bool {
30    input.iter().all(|i| matches!(i, Ok(MatchScore(_))))
31}
32
33/// A method match along with a tagged failure score.
34///
35/// This is used as part of the match failure debugging process.
36struct TaggedMatch<'a, const N: usize> {
37    method: &'a str,
38    score: [Result<MatchScore, TaggedFailureScore<'a>>; N],
39}
40
41impl<'a, const N: usize> From<TaggedMatch<'a, N>> for ArgumentMismatch<'a, N> {
42    fn from(value: TaggedMatch<'a, N>) -> Self {
43        ArgumentMismatch {
44            method: value.method,
45            mismatches: value.score.map(|r| match r {
46                Ok(_) => None,
47                Err(tagged) => Some(tagged.why),
48            }),
49        }
50    }
51}
52
53/// An ordered priority queue that keeps track of the "closest" mismatches.
54struct Queue<'a, const N: usize> {
55    buffer: Vec<TaggedMatch<'a, N>>,
56    max_methods: usize,
57}
58
59impl<'a, const N: usize> Queue<'a, N> {
60    fn new(max_methods: usize) -> Self {
61        Self {
62            buffer: Vec::with_capacity(max_methods),
63            max_methods,
64        }
65    }
66
67    fn finish(self) -> Vec<ArgumentMismatch<'a, N>> {
68        self.buffer.into_iter().map(|m| m.into()).collect()
69    }
70
71    /// Insert `r` into the queue in sorted order.
72    ///
73    /// Returns `Err(())` if all entries in `r` are matches. This provies a means for
74    /// algorithms reporting errors to detect if in fact the collection of arguments
75    /// are dispatchable and debugging is not actually needed.
76    fn push(&mut self, y: TaggedMatch<'a, N>) -> Result<(), ()> {
77        use std::cmp::Ordering;
78
79        if all_match(&y.score) {
80            return Err(());
81        }
82
83        // Now we get the fun part of ranking methods.
84        // We rank first on `MatchScore`, then on `FailureScore`.
85        let lt = |x: &TaggedMatch<'a, N>| {
86            for i in 0..N {
87                let xi = &x.score[i];
88                let yi = &y.score[i];
89                match xi {
90                    Ok(MatchScore(x_score)) => match yi {
91                        Ok(MatchScore(y_score)) => match x_score.cmp(y_score) {
92                            Ordering::Equal => {}
93                            strict => return strict,
94                        },
95                        Err(_) => {
96                            return Ordering::Less;
97                        }
98                    },
99                    Err(TaggedFailureScore { score: x_score, .. }) => match yi {
100                        Ok(_) => {
101                            return Ordering::Greater;
102                        }
103                        Err(TaggedFailureScore { score: y_score, .. }) => {
104                            match x_score.cmp(y_score) {
105                                Ordering::Equal => {}
106                                strict => return strict,
107                            }
108                        }
109                    },
110                }
111            }
112            Ordering::Equal
113        };
114
115        // `binary_search_by` will always return an index that will allow the key to be
116        // placed in sorted order.
117        //
118        // We do not care if the method is present or not, we just want the index.
119        let i = match self.buffer.binary_search_by(lt) {
120            Ok(i) => i,
121            Err(i) => i,
122        };
123
124        if self.buffer.len() == self.max_methods {
125            // No need to insert, it's greater than our worst match so far.
126            if i > self.buffer.len() {
127                return Ok(());
128            }
129            self.buffer.insert(i, y);
130            self.buffer.truncate(self.max_methods);
131        } else {
132            self.buffer.insert(i, y);
133        }
134        Ok(())
135    }
136}
137
/// Marker supertrait required by the dispatch traits generated by
/// `implement_dispatch!` (`Dispatch1`, `Dispatch2`, ...).
///
/// Within this file it is implemented only for the generated method wrappers
/// (`Method1`, `Method2`, ...).
// NOTE(review): presumably this follows the sealed-trait pattern so that code outside
// this module cannot implement the dispatch traits - confirm against the module's
// visibility and re-exports.
pub trait Sealed {}
139
// Generate the dispatch machinery for one fixed number of arguments.
//
// Parameters:
// * `$trait`: name of the generated object-safe trait implemented by each method.
// * `$method`: name of the generated wrapper struct holding a back-end closure.
// * `$dispatcher`: name of the generated dispatcher holding the registered methods.
// * `$N`: the number of arguments, as a literal.
// * `{ $T ... }`: one identifier per argument for the dispatcher-side `Map` types.
// * `{ $x ... }`: one identifier per argument used as the argument binding.
// * `{ $A ... }`: one identifier per argument for the back-end `Map` types.
// * `{ $lf ... }`: one lifetime per argument.
//
// The four brace-delimited lists are expanded in lock-step, so they must all have the
// same length, which should agree with `$N`.
macro_rules! implement_dispatch {
    ($trait:ident,
     $method:ident,
     $dispatcher:ident,
     $N:literal,
     { $($T:ident )+ },
     { $($x:ident )+ },
     { $($A:ident )+ },
     { $($lf:lifetime )+ }
    ) => {
        /// A dispatchable method.
        ///
        /// # Macro Expansion
        ///
        /// Generates the code below:
        /// ```text
        /// pub trait DispatcherN<R, T0, T1, ...>: Sealed
        /// where
        ///     T0: Map,
        ///     T1: Map,
        ///     ...,
        /// {
        ///     fn try_match(
        ///         &self,
        ///         x0: &T0::Type<'_>,
        ///         x1: &T1::Type<'_>,
        ///         ...
        ///     ) -> [Result<MatchScore, FailureScore>; N];
        ///
        ///     fn call(&self, x0: T0::Type<'_>, x1: T1::Type<'_>, ...) -> R;
        ///
        ///     fn signatures(&self) -> [Signature; N];
        ///
        ///     fn try_match_verbose<'a, 'a0, 'a1, ...>(
        ///         &'a self,
        ///         x0: &'a T0::Type<'a0>,
        ///         x1: &'a T1::Type<'a1>,
        ///         ...
        ///     ) -> [Result<MatchScore, TaggedFailureScore<'a>>; N]
        ///     where
        ///         'a0: 'a,
        ///         'a1: 'a,
        ///         ...;
        /// }
        /// ```
        pub trait $trait<R, $($T,)*>: Sealed
        where
            $($T: Map,)*
        {
            /// Invoke [`DispatchRule::try_match`] on each argument/type pair where the type
            /// comes from the backend method.
            ///
            /// Return all results.
            fn try_match(&self, $($x: &$T::Type<'_>,)*) -> [Result<MatchScore, FailureScore>; $N];

            /// Invoke this method with the given types, invoking [`DispatchRule::convert`]
            /// on each argument to the target types of the backend method.
            ///
            /// This function is only safe to call if [`Self::try_match`] returns a success.
            /// Calling this method incorrectly may panic.
            ///
            /// # Panics
            ///
            /// Panics if any call to [`DispatchRule::convert`] fails.
            fn call(&self, $($x: $T::Type<'_>,)*) -> R;

            /// Return the signatures for each back-end argument type.
            fn signatures(&self) -> [Signature; $N];

            /// The equivalent of [`Self::try_match`], but using the
            /// [`DispatchRule::try_match_verbose`] interface.
            ///
            /// This provides a method for inspecting the reason for match failures.
            fn try_match_verbose<'a, $($lf,)*>(
                &'a self,
                $($x: &'a $T::Type<$lf>,)*
            ) -> [Result<MatchScore, TaggedFailureScore<'a>>; $N]
            where
                $($lf: 'a,)*;
        }

        /// A back-end method: a boxed closure together with a marker for its argument
        /// types.
        ///
        /// # Macro Expansion
        ///
        /// ```text
        /// pub struct MethodN<R, A0, A1, ...>
        /// where
        ///     A0: Map,
        ///     A1: Map,
        ///     ...,
        /// {
        ///     f: Box<dyn for<'a0, 'a1, ...> Fn(A0::Type<'a0>, A1::Type<'a1>, ...) -> R>,
        ///     _types: std::marker::PhantomData<(A0, A1, ...)>,
        /// }
        /// ```
        pub struct $method<R, $($A,)*>
        where
            $($A: Map,)*
        {
            f: Box<dyn for<$($lf,)*> Fn($($A::Type<$lf>,)*) -> R>,
            _types: std::marker::PhantomData<($($A,)*)>,
        }

        /// # Macro Expansion
        ///
        /// ```text
        /// impl <R, A0, A1, ...> MethodN<R, A0, A1, ...>
        /// where
        ///     A0: Map,
        ///     A1: Map,
        ///     ...,
        /// {
        ///     fn new<F>(f: F) -> Self
        ///     where
        ///         F: for<'a0, 'a1, ...> Fn(A0::Type<'a0>, A1::Type<'a1>, ...) -> R + 'static,
        ///     {
        ///         Self {
        ///             f: Box::new(f),
        ///             _types: std::marker::PhantomData,
        ///         }
        ///     }
        /// }
        /// ```
        impl<R, $($A,)*> $method<R, $($A,)*>
        where
            $($A: Map,)*
        {
            /// Box the closure `f` as a back-end method.
            fn new<F>(f: F) -> Self
            where
                F: for<$($lf,)*> Fn($($A::Type<$lf>,)*) -> R + 'static,
            {
                Self {
                    f: Box::new(f),
                    _types: std::marker::PhantomData,
                }
            }
        }

        impl<R, $($A,)*> Sealed for $method<R, $($A,)*>
        where
            $($A: Map,)*
        {}

        impl<R, $($T,)* $($A,)*> $trait<R, $($T,)*> for $method<R, $($A,)*>
        where
            $($T: Map,)*
            $($A: Map,)*
            $(for<'a> $A::Type<'a>: DispatchRule<$T::Type<'a>>,)*
        {
            fn try_match(&self, $($x: &$T::Type<'_>,)*) -> [Result<MatchScore, FailureScore>; $N] {
                // Splat out all the pair-wise `try_match`es.
                [$($A::Type::try_match($x),)*]
            }

            fn call(&self, $($x: $T::Type<'_>,)*) -> R {
                // Convert and unwrap all pair-wise matches.
                (self.f)($($A::Type::convert($x).unwrap(),)*)
            }

            fn signatures(&self) -> [Signature; $N] {
                // The strategy here involves decaying a stateless lambda to a function
                // pointer, and generating one such lambda for each input type.
                //
                // Note that we need to couple it with its corresponding dispatch type
                // to ensure we get routed to the correct description.
                [
                    $(Signature(|f: &mut Formatter<'_>| {
                        $A::Type::description(f, None::<&$T::Type<'_>>)
                    }),)*
                ]
            }

            fn try_match_verbose<'a, $($lf,)*>(
                &self,
                $($x: &'a $T::Type<$lf>,)*
            ) -> [Result<MatchScore, TaggedFailureScore<'a>>; $N]
            where
                $($lf: 'a,)*
            {
                // Simply construct an array by calling `try_match_verbose` on each pair.
                [$($A::Type::try_match_verbose($x),)*]
            }
        }

        /// A central dispatcher for multi-method overloading.
        ///
        /// Methods are stored together with their names in registration order.
        pub struct $dispatcher<R, $($T,)*>
        where
            R: 'static,
            $($T: Map,)*
        {
            pub(super) methods: Vec<(String, Box<dyn $trait<R, $($T,)*>>)>,
        }

        impl<R, $($T,)*> Default for $dispatcher<R, $($T,)*>
        where
            R: 'static,
            $($T: Map,)*
        {
            /// Equivalent to [`Self::new`].
            fn default() -> Self {
                Self::new()
            }
        }

        impl<R, $($T,)*> $dispatcher<R, $($T,)*>
        where
            R: 'static,
            $($T: Map,)*
        {
            /// Construct a new, empty dispatcher.
            pub fn new() -> Self {
                Self { methods: Vec::new() }
            }

            /// Register the new named method with the dispatcher.
            ///
            /// The back-end argument types `$A` must each provide a [`DispatchRule`]
            /// from the corresponding dispatcher-side type.
            pub fn register<F, $($A,)*>(&mut self, name: impl Into<String>, f: F)
            where
                $($A: Map,)*
                $(for<'a> $A::Type<'a>: DispatchRule<$T::Type<'a>>,)*
                F: for<$($lf,)*> Fn($($A::Type<$lf>,)*) -> R + 'static,
            {
                let method = $method::<R, $($A,)*>::new(f);
                self.methods.push((name.into(), Box::new(method)))
            }

            /// Try to invoke the best fitting method with the given arguments.
            ///
            /// The best fit is the method whose array of match scores compares lowest;
            /// when several methods tie, the one registered first is used.
            ///
            /// If no such method can be found, returns `None`.
            pub fn call(&self, $($x: $T::Type<'_>,)*) -> Option<R> {
                // The best fully-matching method seen so far, along with its scores.
                let mut best: Option<(&_, [MatchScore; $N])> = None;
                for m in self.methods.iter() {
                    // Methods with any failing argument are not candidates.
                    if let Some(score) = coalesce(&m.1.try_match($(&$x,)*)) {
                        // Only a strictly better score replaces the current best, so
                        // ties keep the earlier registration.
                        let improves = match best.as_ref() {
                            Some((_, current)) => score < *current,
                            None => true,
                        };
                        if improves {
                            best = Some((m, score));
                        }
                    }
                }

                // Invoke the best method
                best.map(|(m, _)| m.1.call($($x,)*))
            }

            /// Return an iterator to the methods registered in this dispatcher.
            pub fn methods(
                &self
            ) -> impl ExactSizeIterator<Item = &(String, Box<dyn $trait<R, $($T,)*>>)> {
                self.methods.iter()
            }

            /// Query whether the combination of values has a valid matching method without
            /// trying to invoke that method.
            pub fn has_match(&self, $($x: &$T::Type<'_>,)*) -> bool {
                // The arguments are already references, so they are passed through as-is.
                self.methods
                    .iter()
                    .any(|m| all_match(&m.1.try_match($($x,)*)))
            }

            /// Check if a back-end method exists for the arguments.
            ///
            /// If so, returns `Ok(())`.
            ///
            /// Otherwise, returns a vector of `ArgumentMismatch` for the up-to
            /// `max_methods` closest methods.
            ///
            /// In this context, "closeness" is defined by first comparing match or failure
            /// scores for argument 0, followed by argument 1 if equal and so on.
            pub fn debug<'a, $($lf,)*>(
                &'a self,
                max_methods: usize,
                $($x: &'a $T::Type<$lf>,)*
            ) -> Result<(), Vec<ArgumentMismatch<'a, $N>>>
            where
                $($lf: 'a,)*
            {
                let mut methods = Queue::new(max_methods);
                for m in self.methods.iter() {
                    let t = TaggedMatch {
                        method: &m.0,
                        score: m.1.try_match_verbose($($x,)*),
                    };
                    // The queue reports `Err(())` when a method matches every argument,
                    // in which case there is nothing to debug.
                    match methods.push(t) {
                        Ok(()) => {},
                        Err(()) => return Ok(()),
                    }
                }
                Err(methods.finish())
            }
        }
    }
}
436
// Instantiate the dispatch trait, method wrapper, and dispatcher for one, two, and
// three arguments.
//
// Each invocation lists the generated names and the arity, followed by one entry per
// argument for the dispatcher-side types (`T*`), the argument bindings (`x*`), the
// back-end method types (`A*`), and the lifetimes (`'a*`).
implement_dispatch!(Dispatch1, Method1, Dispatcher1, 1, { T0 }, { x0 }, { A0 }, { 'a0 });
implement_dispatch!(
    Dispatch2, Method2, Dispatcher2, 2,
    { T0 T1 }, { x0 x1 }, { A0 A1 }, { 'a0 'a1 }
);
implement_dispatch!(
    Dispatch3, Method3, Dispatcher3, 3,
    { T0 T1 T2 }, { x0 x1 x2 }, { A0 A1 A2 }, { 'a0 'a1 'a2 }
);
446
447///////////
448// Tests //
449///////////
450
#[cfg(test)]
mod tests {
    use super::*;

    /// Test-only back-end type: `Num<N>` can be dispatched to from any `usize` that is
    /// within 2 of `N`.
    struct Num<const N: usize>;

    impl<const N: usize> Map for Num<N> {
        type Type<'a> = Self;
    }

    impl<const N: usize> DispatchRule<usize> for Num<N> {
        type Error = std::convert::Infallible;

        // For testing purposes, we accept values within 2 of `N`, but with decreasing
        // precedence.
        //
        // The distance from `N` is used as the score in both the match and the failure
        // case, so a lower score always means "closer to `N`".
        fn try_match(from: &usize) -> Result<MatchScore, FailureScore> {
            let diff = from.abs_diff(N);
            if diff <= 2 {
                Ok(MatchScore(diff as u32))
            } else {
                Err(FailureScore(diff as u32))
            }
        }

        // Conversion asserts the same "within 2" condition that `try_match` accepts.
        fn convert(from: usize) -> Result<Self, Self::Error> {
            assert!(from.abs_diff(N) <= 2);
            Ok(Self)
        }

        // Without a value, describe the type itself (just `N`). With a value, describe
        // how far that value is from `N`.
        fn description(f: &mut std::fmt::Formatter<'_>, from: Option<&usize>) -> std::fmt::Result {
            match from {
                None => write!(f, "{}", N),
                Some(value) => {
                    let diff = value.abs_diff(N);
                    match diff {
                        0 => write!(f, "success: exact match"),
                        1 => write!(f, "success: off by 1"),
                        2 => write!(f, "success: off by 2"),
                        x => write!(f, "error: off by {}", x),
                    }
                }
            }
        }
    }

    ////////////////
    // Dispatch 1 //
    ////////////////

    #[test]
    fn test_dispatch_1() {
        // Each method returns the `N` of its argument type so the chosen overload can be
        // identified from the return value.
        let mut x = Dispatcher1::<usize, usize>::default();
        x.register::<_, Num<0>>("method 0", |_| 0);
        x.register::<_, Num<3>>("method 3", |_| 3);
        x.register::<_, Num<5>>("method 5", |_| 5);
        x.register::<_, Num<8>>("method 8", |_| 8);

        // Methods are reported in registration order, with signatures describing `N`.
        {
            let methods: Vec<_> = x.methods().collect();
            assert_eq!(methods.len(), 4);
            assert_eq!(methods[0].0, "method 0");
            assert_eq!(methods[0].1.signatures()[0].to_string(), "0");

            assert_eq!(methods[1].0, "method 3");
            assert_eq!(methods[1].1.signatures()[0].to_string(), "3");
        }

        // Test that dispatching works properly.
        //
        // The closest `N` wins. For an input of 4, methods 3 and 5 are equally close and
        // the one registered first (method 3) is used.
        assert_eq!(x.call(0), Some(0));
        assert_eq!(x.call(1), Some(0));
        assert_eq!(x.call(2), Some(3));
        assert_eq!(x.call(3), Some(3));
        assert_eq!(x.call(4), Some(3));
        assert_eq!(x.call(5), Some(5));
        assert_eq!(x.call(6), Some(5));
        assert_eq!(x.call(7), Some(8));
        assert_eq!(x.call(8), Some(8));
        assert_eq!(x.call(11), None);

        // Every value up to 10 is within 2 of some registered `N`; nothing from 11 on is.
        for i in 0..11 {
            assert!(x.has_match(&i));
        }
        for i in 11..20 {
            assert!(!x.has_match(&i));
        }

        // Make sure `Debug` works.
        //
        // 10 is dispatchable (off by 2 from method 8), so there is nothing to report.
        assert!(x.debug(3, &10).is_ok());

        // 11 matches nothing; the three closest methods are reported, closest first.
        let mismatches = x.debug(3, &11).unwrap_err();
        assert_eq!(mismatches.len(), 3);

        // Method 8 is the closest.
        assert_eq!(mismatches[0].method(), "method 8");
        assert_eq!(
            mismatches[0].mismatches()[0].as_ref().unwrap().to_string(),
            "error: off by 3"
        );

        // Method 5 is next.
        assert_eq!(mismatches[1].method(), "method 5");
        assert_eq!(
            mismatches[1].mismatches()[0].as_ref().unwrap().to_string(),
            "error: off by 6"
        );

        // Method 3 is next.
        assert_eq!(mismatches[2].method(), "method 3");
        assert_eq!(
            mismatches[2].mismatches()[0].as_ref().unwrap().to_string(),
            "error: off by 8"
        );

        // Make sure that if we request more than the total number of methods that it is
        // capped.
        assert_eq!(x.debug(10, &20).unwrap_err().len(), 4);
    }

    ////////////////
    // Dispatch 2 //
    ////////////////

    #[test]
    fn test_dispatch_2() {
        let mut x = Dispatcher2::<usize, usize, usize>::default();

        // Note that "method 3" is deliberately registered before "method 2", so the
        // registration order differs from the numbering.
        x.register::<_, Num<10>, Num<10>>("method 0", |_, _| 0);
        x.register::<_, Num<10>, Num<13>>("method 1", |_, _| 1);
        x.register::<_, Num<13>, Num<12>>("method 3", |_, _| 3);
        x.register::<_, Num<12>, Num<10>>("method 2", |_, _| 2);

        {
            let methods: Vec<_> = x.methods().collect();
            assert_eq!(methods.len(), 4);
            assert_eq!(methods[0].0, "method 0");
            assert_eq!(methods[0].1.signatures()[0].to_string(), "10");
            assert_eq!(methods[0].1.signatures()[1].to_string(), "10");

            assert_eq!(methods[1].0, "method 1");
            assert_eq!(methods[1].1.signatures()[0].to_string(), "10");
            assert_eq!(methods[1].1.signatures()[1].to_string(), "13");
        }

        // This is where things get weird.
        //
        // Candidates are compared argument by argument: the first argument's score
        // decides, and the second argument only breaks ties.
        assert_eq!(x.call(10, 10), Some(0)); // Match method 0
        assert_eq!(x.call(10, 11), Some(0)); // Match method 0
        assert_eq!(x.call(10, 12), Some(1)); // Match method 1
        assert_eq!(x.call(11, 12), Some(1)); // Match method 1
        assert_eq!(x.call(12, 12), Some(2)); // Match method 2
        assert_eq!(x.call(13, 12), Some(3)); // Match method 3

        // Check error handling.
        {
            assert!(x.call(10, 7).is_none());
            let m = x.debug(3, &9, &7).unwrap_err();
            // The closest hit is method 0, followed by method 1 (both match the first
            // argument), and then method 2 (closest among those failing the first
            // argument).
            assert_eq!(m[0].method(), "method 0");
            assert_eq!(m[1].method(), "method 1");
            assert_eq!(m[2].method(), "method 2");

            let mismatches = m[0].mismatches();
            // The first argument is a match - the second argument is a mismatch.
            assert!(mismatches[0].is_none());
            assert_eq!(
                mismatches[1].as_ref().unwrap().to_string(),
                "error: off by 3"
            );

            // Method 2 fails on both arguments.
            let mismatches = m[2].mismatches();
            assert_eq!(
                mismatches[0].as_ref().unwrap().to_string(),
                "error: off by 3"
            );
            assert_eq!(
                mismatches[1].as_ref().unwrap().to_string(),
                "error: off by 3"
            );
        }

        // Try again, but this time from the other direction.
        //
        // Every method fails on the first argument here, so the ranking follows the
        // first argument's failure score; methods 1 and 0 tie on it and are separated
        // by their second-argument match score.
        {
            let m = x.debug(4, &16, &12).unwrap_err();
            assert_eq!(m[0].method(), "method 3");
            assert_eq!(m[1].method(), "method 2");
            assert_eq!(m[2].method(), "method 1");
        }
    }
}
638}