1use std::fmt::Formatter;
7
8use super::{
9 ArgumentMismatch, DispatchRule, FailureScore, Map, MatchScore, Signature, TaggedFailureScore,
10};
11
12fn coalesce<const N: usize>(
16 input: &[Result<MatchScore, FailureScore>; N],
17) -> Option<[MatchScore; N]> {
18 let mut output = [MatchScore(0); N];
19 for i in 0..N {
20 output[i] = match input[i] {
21 Ok(score) => score,
22 Err(_) => return None,
23 }
24 }
25 Some(output)
26}
27
28fn all_match<const N: usize, T>(input: &[Result<MatchScore, T>; N]) -> bool {
30 input.iter().all(|i| matches!(i, Ok(MatchScore(_))))
31}
32
/// The verbose per-argument match results of one registered method, paired with the
/// method's name.
struct TaggedMatch<'a, const N: usize> {
    /// Name of the method the scores belong to.
    method: &'a str,
    /// One entry per argument: `Ok` with the match score, or `Err` with the tagged failure.
    score: [Result<MatchScore, TaggedFailureScore<'a>>; N],
}
40
41impl<'a, const N: usize> From<TaggedMatch<'a, N>> for ArgumentMismatch<'a, N> {
42 fn from(value: TaggedMatch<'a, N>) -> Self {
43 ArgumentMismatch {
44 method: value.method,
45 mismatches: value.score.map(|r| match r {
46 Ok(_) => None,
47 Err(tagged) => Some(tagged.why),
48 }),
49 }
50 }
51}
52
/// Bounded, sorted collection of non-matching methods used to build mismatch diagnostics.
struct Queue<'a, const N: usize> {
    /// Retained candidates, kept in the sorted order established by `push`.
    buffer: Vec<TaggedMatch<'a, N>>,
    /// Upper bound on the number of candidates kept in `buffer`.
    max_methods: usize,
}
58
59impl<'a, const N: usize> Queue<'a, N> {
60 fn new(max_methods: usize) -> Self {
61 Self {
62 buffer: Vec::with_capacity(max_methods),
63 max_methods,
64 }
65 }
66
67 fn finish(self) -> Vec<ArgumentMismatch<'a, N>> {
68 self.buffer.into_iter().map(|m| m.into()).collect()
69 }
70
71 fn push(&mut self, y: TaggedMatch<'a, N>) -> Result<(), ()> {
77 use std::cmp::Ordering;
78
79 if all_match(&y.score) {
80 return Err(());
81 }
82
83 let lt = |x: &TaggedMatch<'a, N>| {
86 for i in 0..N {
87 let xi = &x.score[i];
88 let yi = &y.score[i];
89 match xi {
90 Ok(MatchScore(x_score)) => match yi {
91 Ok(MatchScore(y_score)) => match x_score.cmp(y_score) {
92 Ordering::Equal => {}
93 strict => return strict,
94 },
95 Err(_) => {
96 return Ordering::Less;
97 }
98 },
99 Err(TaggedFailureScore { score: x_score, .. }) => match yi {
100 Ok(_) => {
101 return Ordering::Greater;
102 }
103 Err(TaggedFailureScore { score: y_score, .. }) => {
104 match x_score.cmp(y_score) {
105 Ordering::Equal => {}
106 strict => return strict,
107 }
108 }
109 },
110 }
111 }
112 Ordering::Equal
113 };
114
115 let i = match self.buffer.binary_search_by(lt) {
120 Ok(i) => i,
121 Err(i) => i,
122 };
123
124 if self.buffer.len() == self.max_methods {
125 if i > self.buffer.len() {
127 return Ok(());
128 }
129 self.buffer.insert(i, y);
130 self.buffer.truncate(self.max_methods);
131 } else {
132 self.buffer.insert(i, y);
133 }
134 Ok(())
135 }
136}
137
/// Marker supertrait required by every dispatch trait generated by `implement_dispatch!`.
///
/// In this file it is implemented only for the generated method wrapper types.
/// NOTE(review): presumably this exists to keep the dispatch traits from being implemented
/// elsewhere; whether it is truly sealed depends on this module's visibility. Confirm.
pub trait Sealed {}
139
/// Generates the dispatch machinery for one fixed number of arguments.
///
/// Macro arguments, in order:
/// * `$trait`: name of the generated trait that type-erased methods implement.
/// * `$method`: name of the generated wrapper struct around a registered closure.
/// * `$dispatcher`: name of the generated dispatcher struct holding named methods.
/// * `$N`: the arity as a literal; must equal the length of each list below.
/// * `{ $T ... }`: generic parameter names for the dispatcher's argument types.
/// * `{ $x ... }`: identifiers used as the argument names.
/// * `{ $A ... }`: generic parameter names for a registered method's argument types.
/// * `{ $lf ... }`: one lifetime name per argument.
macro_rules! implement_dispatch {
    ($trait:ident,
     $method:ident,
     $dispatcher:ident,
     $N:literal,
     { $($T:ident )+ },
     { $($x:ident )+ },
     { $($A:ident )+ },
     { $($lf:lifetime )+ }
    ) => {
        /// A type-erased method that can be scored against, and invoked with, a fixed
        /// list of dispatcher argument types.
        pub trait $trait<R, $($T,)*>: Sealed
        where
            $($T: Map,)*
        {
            /// Scores each argument against this method, one result per argument.
            fn try_match(&self, $($x: &$T::Type<'_>,)*) -> [Result<MatchScore, FailureScore>; $N];

            /// Invokes the method with the given arguments.
            fn call(&self, $($x: $T::Type<'_>,)*) -> R;

            /// Returns one signature entry per argument of this method.
            fn signatures(&self) -> [Signature; $N];

            /// Like `try_match`, but each failure is reported as a `TaggedFailureScore`.
            fn try_match_verbose<'a, $($lf,)*>(
                &'a self,
                $($x: &'a $T::Type<$lf>,)*
            ) -> [Result<MatchScore, TaggedFailureScore<'a>>; $N]
            where
                $($lf: 'a,)*;
        }

        /// Wrapper that stores a registered closure together with its argument types.
        pub struct $method<R, $($A,)*>
        where
            $($A: Map,)*
        {
            /// The boxed closure to run once the arguments have been converted.
            f: Box<dyn for<$($lf,)*> Fn($($A::Type<$lf>,)*) -> R>,
            /// Records the argument type parameters, which appear in no other field type.
            _types: std::marker::PhantomData<($($A,)*)>,
        }

        impl<R, $($A,)*> $method<R, $($A,)*>
        where
            $($A: Map,)*
        {
            /// Boxes `f` into a new method wrapper.
            fn new<F>(f: F) -> Self
            where
                F: for<$($lf,)*> Fn($($A::Type<$lf>,)*) -> R + 'static,
            {
                Self {
                    f: Box::new(f),
                    _types: std::marker::PhantomData,
                }
            }
        }

        impl<R, $($A,)*> Sealed for $method<R, $($A,)*>
        where
            $($A: Map,)*
        {}

        impl<R, $($T,)* $($A,)*> $trait<R, $($T,)*> for $method<R, $($A,)*>
        where
            $($T: Map,)*
            $($A: Map,)*
            $(for<'a> $A::Type<'a>: DispatchRule<$T::Type<'a>>,)*
        {
            /// Delegates to `DispatchRule::try_match` for every argument.
            fn try_match(&self, $($x: &$T::Type<'_>,)*) -> [Result<MatchScore, FailureScore>; $N] {
                [$($A::Type::try_match($x),)*]
            }

            /// Converts every argument with `DispatchRule::convert` and runs the closure.
            ///
            /// # Panics
            ///
            /// Panics if any conversion returns `Err`.
            fn call(&self, $($x: $T::Type<'_>,)*) -> R {
                (self.f)($($A::Type::convert($x).unwrap(),)*)
            }

            /// Builds one `Signature` per argument; each writes the argument type's
            /// description with no concrete value supplied (`None`).
            fn signatures(&self) -> [Signature; $N] {
                [
                    $(Signature(|f: &mut Formatter<'_>| {
                        $A::Type::description(f, None::<&$T::Type<'_>>)
                    }),)*
                ]
            }

            /// Delegates to `DispatchRule::try_match_verbose` for every argument.
            fn try_match_verbose<'a, $($lf,)*>(
                &self,
                $($x: &'a $T::Type<$lf>,)*
            ) -> [Result<MatchScore, TaggedFailureScore<'a>>; $N]
            where
                $($lf: 'a,)*
            {
                [$($A::Type::try_match_verbose($x),)*]
            }
        }

        /// A collection of named methods that picks the best match for a set of arguments.
        pub struct $dispatcher<R, $($T,)*>
        where
            R: 'static,
            $($T: Map,)*
        {
            /// Registered `(name, method)` pairs in registration order.
            pub(super) methods: Vec<(String, Box<dyn $trait<R, $($T,)*>>)>,
        }

        impl<R, $($T,)*> Default for $dispatcher<R, $($T,)*>
        where
            R: 'static,
            $($T: Map,)*
        {
            fn default() -> Self {
                Self::new()
            }
        }

        impl<R, $($T,)*> $dispatcher<R, $($T,)*>
        where
            R: 'static,
            $($T: Map,)*
        {
            /// Creates a dispatcher with no registered methods.
            pub fn new() -> Self {
                Self { methods: Vec::new() }
            }

            /// Registers the closure `f` under `name`, appending it after all previously
            /// registered methods.
            pub fn register<F, $($A,)*>(&mut self, name: impl Into<String>, f: F)
            where
                $($A: Map,)*
                $(for<'a> $A::Type<'a>: DispatchRule<$T::Type<'a>>,)*
                F: for<$($lf,)*> Fn($($A::Type<$lf>,)*) -> R + 'static,
            {
                let method = $method::<R, $($A,)*>::new(f);
                self.methods.push((name.into(), Box::new(method)));
            }

            /// Invokes the best-matching method with the given arguments.
            ///
            /// Only methods for which every argument matches are considered. Among those,
            /// the one whose score array compares lowest wins; on a tie the method
            /// registered first is kept. Returns `None` when no method matches.
            pub fn call(&self, $($x: $T::Type<'_>,)*) -> Option<R> {
                let mut best: Option<(&_, [MatchScore; $N])> = None;
                for m in self.methods.iter() {
                    if let Some(score) = coalesce(&m.1.try_match($(&$x,)*)) {
                        // Strict comparison keeps the earliest registered method on ties.
                        let improves = best
                            .as_ref()
                            .map_or(true, |(_, current)| score < *current);
                        if improves {
                            best = Some((m, score));
                        }
                    }
                }

                best.map(|(m, _)| m.1.call($($x,)*))
            }

            /// Iterates over the registered `(name, method)` pairs in registration order.
            pub fn methods(
                &self
            ) -> impl ExactSizeIterator<Item = &(String, Box<dyn $trait<R, $($T,)*>>)> {
                self.methods.iter()
            }

            /// Returns `true` when at least one registered method matches every argument.
            pub fn has_match(&self, $($x: &$T::Type<'_>,)*) -> bool {
                self.methods
                    .iter()
                    .any(|m| all_match(&m.1.try_match($($x,)*)))
            }

            /// Explains why dispatch would fail for the given arguments.
            ///
            /// Returns `Ok(())` as soon as a registered method matches every argument.
            /// Otherwise returns `Err` with at most `max_methods` mismatch reports, in the
            /// sorted order maintained by `Queue`.
            pub fn debug<'a, $($lf,)*>(
                &'a self,
                max_methods: usize,
                $($x: &'a $T::Type<$lf>,)*
            ) -> Result<(), Vec<ArgumentMismatch<'a, $N>>>
            where
                $($lf: 'a,)*
            {
                let mut methods = Queue::new(max_methods);
                for m in self.methods.iter() {
                    let t = TaggedMatch {
                        method: &m.0,
                        score: m.1.try_match_verbose($($x,)*),
                    };
                    // `push` rejects a candidate only when all of its arguments matched.
                    if methods.push(t).is_err() {
                        return Ok(());
                    }
                }
                Err(methods.finish())
            }
        }
    }
}
436
// Instantiate the dispatch trait, method wrapper and dispatcher for arities 1, 2 and 3.
// Each brace-delimited list must contain exactly as many entries as the arity literal.
implement_dispatch!(Dispatch1, Method1, Dispatcher1, 1, { T0 }, { x0 }, { A0 }, { 'a0 });
implement_dispatch!(
    Dispatch2, Method2, Dispatcher2, 2,
    { T0 T1 }, { x0 x1 }, { A0 A1 }, { 'a0 'a1 }
);
implement_dispatch!(
    Dispatch3, Method3, Dispatcher3, 3,
    { T0 T1 T2 }, { x0 x1 x2 }, { A0 A1 A2 }, { 'a0 'a1 'a2 }
);
446
#[cfg(test)]
mod tests {
    use super::*;

    /// Test argument type: accepts any `usize` within 2 of `N`, scoring by the distance.
    struct Num<const N: usize>;

    impl<const N: usize> Map for Num<N> {
        type Type<'a> = Self;
    }

    impl<const N: usize> DispatchRule<usize> for Num<N> {
        type Error = std::convert::Infallible;

        fn try_match(from: &usize) -> Result<MatchScore, FailureScore> {
            // Distances 0..=2 are matches; anything further away is a failure.
            match from.abs_diff(N) {
                diff @ 0..=2 => Ok(MatchScore(diff as u32)),
                diff => Err(FailureScore(diff as u32)),
            }
        }

        fn convert(from: usize) -> Result<Self, Self::Error> {
            assert!(from.abs_diff(N) <= 2);
            Ok(Self)
        }

        fn description(f: &mut std::fmt::Formatter<'_>, from: Option<&usize>) -> std::fmt::Result {
            match from.map(|value| value.abs_diff(N)) {
                None => write!(f, "{}", N),
                Some(0) => write!(f, "success: exact match"),
                Some(1) => write!(f, "success: off by 1"),
                Some(2) => write!(f, "success: off by 2"),
                Some(x) => write!(f, "error: off by {}", x),
            }
        }
    }

    /// Exercises registration, calling, matching and diagnostics for one argument.
    #[test]
    fn test_dispatch_1() {
        let mut x = Dispatcher1::<usize, usize>::default();
        x.register::<_, Num<0>>("method 0", |_| 0);
        x.register::<_, Num<3>>("method 3", |_| 3);
        x.register::<_, Num<5>>("method 5", |_| 5);
        x.register::<_, Num<8>>("method 8", |_| 8);

        {
            let methods: Vec<_> = x.methods().collect();
            assert_eq!(methods.len(), 4);
            for (index, name, signature) in [(0, "method 0", "0"), (1, "method 3", "3")] {
                assert_eq!(methods[index].0, name);
                assert_eq!(methods[index].1.signatures()[0].to_string(), signature);
            }
        }

        // (argument, expected return value of the selected method)
        let expected_calls: [(usize, Option<usize>); 10] = [
            (0, Some(0)),
            (1, Some(0)),
            (2, Some(3)),
            (3, Some(3)),
            (4, Some(3)),
            (5, Some(5)),
            (6, Some(5)),
            (7, Some(8)),
            (8, Some(8)),
            (11, None),
        ];
        for (input, expected) in expected_calls {
            assert_eq!(x.call(input), expected);
        }

        for i in 0..11 {
            assert!(x.has_match(&i));
        }
        for i in 11..20 {
            assert!(!x.has_match(&i));
        }

        assert!(x.debug(3, &10).is_ok());

        let mismatches = x.debug(3, &11).unwrap_err();
        assert_eq!(mismatches.len(), 3);

        // Reports are expected closest-first.
        let expected_reports = [
            ("method 8", "error: off by 3"),
            ("method 5", "error: off by 6"),
            ("method 3", "error: off by 8"),
        ];
        for (report, (method, reason)) in mismatches.iter().zip(expected_reports) {
            assert_eq!(report.method(), method);
            assert_eq!(
                report.mismatches()[0].as_ref().unwrap().to_string(),
                reason
            );
        }

        assert_eq!(x.debug(10, &20).unwrap_err().len(), 4);
    }

    /// Exercises registration, calling and diagnostics for two arguments.
    #[test]
    fn test_dispatch_2() {
        let mut x = Dispatcher2::<usize, usize, usize>::default();

        x.register::<_, Num<10>, Num<10>>("method 0", |_, _| 0);
        x.register::<_, Num<10>, Num<13>>("method 1", |_, _| 1);
        x.register::<_, Num<13>, Num<12>>("method 3", |_, _| 3);
        x.register::<_, Num<12>, Num<10>>("method 2", |_, _| 2);

        {
            let methods: Vec<_> = x.methods().collect();
            assert_eq!(methods.len(), 4);
            for (index, name, first, second) in
                [(0, "method 0", "10", "10"), (1, "method 1", "10", "13")]
            {
                assert_eq!(methods[index].0, name);
                assert_eq!(methods[index].1.signatures()[0].to_string(), first);
                assert_eq!(methods[index].1.signatures()[1].to_string(), second);
            }
        }

        // ((first argument, second argument), expected return value)
        let expected_calls: [((usize, usize), Option<usize>); 6] = [
            ((10, 10), Some(0)),
            ((10, 11), Some(0)),
            ((10, 12), Some(1)),
            ((11, 12), Some(1)),
            ((12, 12), Some(2)),
            ((13, 12), Some(3)),
        ];
        for ((first, second), expected) in expected_calls {
            assert_eq!(x.call(first, second), expected);
        }

        {
            assert!(x.call(10, 7).is_none());
            let m = x.debug(3, &9, &7).unwrap_err();
            assert_eq!(m[0].method(), "method 0");
            assert_eq!(m[1].method(), "method 1");
            assert_eq!(m[2].method(), "method 2");

            let closest = m[0].mismatches();
            assert!(closest[0].is_none());
            assert_eq!(
                closest[1].as_ref().unwrap().to_string(),
                "error: off by 3"
            );

            let third = m[2].mismatches();
            assert_eq!(third[0].as_ref().unwrap().to_string(), "error: off by 3");
            assert_eq!(third[1].as_ref().unwrap().to_string(), "error: off by 3");
        }

        {
            let m = x.debug(4, &16, &12).unwrap_err();
            assert_eq!(m[0].method(), "method 3");
            assert_eq!(m[1].method(), "method 2");
            assert_eq!(m[2].method(), "method 1");
        }
    }
}