1use std::collections::{HashMap, HashSet};
23
24use bock_ast::BinOp;
25use bock_interp::TypeTag;
26
/// Name of a trait method, used as the dispatch key (e.g. `"compare"`, `"iter"`).
pub type MethodName = &'static str;
29
30pub struct TraitDispatch {
37 binop_methods: HashMap<BinOp, MethodName>,
39 trait_impls: HashMap<MethodName, HashSet<TypeTag>>,
43 known_traits: HashSet<&'static str>,
48 trait_methods: HashMap<&'static str, Vec<MethodName>>,
50}
51
52impl Default for TraitDispatch {
53 fn default() -> Self {
54 Self::new()
55 }
56}
57
58impl TraitDispatch {
59 #[must_use]
62 pub fn new() -> Self {
63 let mut binop_methods = HashMap::new();
64
65 binop_methods.insert(BinOp::Lt, "compare");
67 binop_methods.insert(BinOp::Le, "compare");
68 binop_methods.insert(BinOp::Gt, "compare");
69 binop_methods.insert(BinOp::Ge, "compare");
70
71 binop_methods.insert(BinOp::Eq, "equals");
73 binop_methods.insert(BinOp::Ne, "equals");
74
75 binop_methods.insert(BinOp::Add, "add");
77 binop_methods.insert(BinOp::Sub, "sub");
78 binop_methods.insert(BinOp::Mul, "mul");
79 binop_methods.insert(BinOp::Div, "div");
80 binop_methods.insert(BinOp::Rem, "rem");
81
82 let mut dispatch = Self {
83 binop_methods,
84 trait_impls: HashMap::new(),
85 known_traits: HashSet::new(),
86 trait_methods: HashMap::new(),
87 };
88
89 dispatch.register_builtins();
91
92 dispatch.register_prelude_traits();
94
95 dispatch
96 }
97
98 fn register_builtins(&mut self) {
100 for ty in [
102 TypeTag::Int,
103 TypeTag::Float,
104 TypeTag::Bool,
105 TypeTag::String,
106 TypeTag::Char,
107 TypeTag::List,
108 TypeTag::Map,
109 TypeTag::Set,
110 ] {
111 self.register_trait(ty, "compare");
112 }
113
114 for ty in [
116 TypeTag::Int,
117 TypeTag::Float,
118 TypeTag::Bool,
119 TypeTag::String,
120 TypeTag::Char,
121 TypeTag::List,
122 TypeTag::Map,
123 TypeTag::Set,
124 ] {
125 self.register_trait(ty, "equals");
126 }
127
128 for ty in [
130 TypeTag::Int,
131 TypeTag::Float,
132 TypeTag::Bool,
133 TypeTag::String,
134 TypeTag::Char,
135 TypeTag::List,
136 TypeTag::Map,
137 TypeTag::Set,
138 TypeTag::Optional,
139 TypeTag::Result,
140 ] {
141 self.register_trait(ty, "display");
142 }
143
144 for ty in [TypeTag::List, TypeTag::Set, TypeTag::Map, TypeTag::Range] {
146 self.register_trait(ty, "iter");
147 }
148
149 for ty in [TypeTag::Int, TypeTag::Float, TypeTag::String] {
151 self.register_trait(ty, "add");
152 }
153
154 for ty in [TypeTag::Int, TypeTag::Float] {
156 self.register_trait(ty, "sub");
157 }
158
159 for ty in [TypeTag::Int, TypeTag::Float] {
161 self.register_trait(ty, "mul");
162 }
163
164 for ty in [TypeTag::Int, TypeTag::Float] {
166 self.register_trait(ty, "div");
167 }
168
169 for ty in [TypeTag::Int, TypeTag::Float] {
171 self.register_trait(ty, "rem");
172 }
173
174 for ty in [
176 TypeTag::Int,
177 TypeTag::Float,
178 TypeTag::Bool,
179 TypeTag::String,
180 TypeTag::Char,
181 TypeTag::List,
182 TypeTag::Map,
183 TypeTag::Set,
184 ] {
185 self.register_trait(ty, "hash_code");
186 }
187
188 self.register_trait(TypeTag::Int, "into");
191 self.register_trait(TypeTag::Int, "from");
192 self.register_trait(TypeTag::Float, "from");
193 self.register_trait(TypeTag::String, "from");
194
195 for ty in [
197 TypeTag::Int,
198 TypeTag::Float,
199 TypeTag::Bool,
200 TypeTag::String,
201 TypeTag::Char,
202 ] {
203 self.register_trait(ty, "default");
204 }
205 }
206
207 fn register_prelude_traits(&mut self) {
210 self.register_known_trait("Comparable", &["compare"]);
212 self.register_known_trait("Equatable", &["equals"]);
213 self.register_known_trait("Hashable", &["hash_code"]);
214 self.register_known_trait("Displayable", &["display"]);
215 self.register_known_trait("Iterable", &["iter"]);
216 self.register_known_trait("Add", &["add"]);
217 self.register_known_trait("Sub", &["sub"]);
218 self.register_known_trait("Mul", &["mul"]);
219 self.register_known_trait("Div", &["div"]);
220 self.register_known_trait("Rem", &["rem"]);
221 self.register_known_trait("Into", &["into"]);
222 self.register_known_trait("From", &["from"]);
223
224 self.register_known_trait("Default", &["default"]);
226 self.register_known_trait("Serializable", &[]);
227 self.register_known_trait("Cloneable", &[]);
228 self.register_known_trait("TryFrom", &[]);
229 self.register_known_trait("Collectable", &[]);
230 }
231
232 fn register_known_trait(&mut self, name: &'static str, methods: &[MethodName]) {
235 self.known_traits.insert(name);
236 self.trait_methods.insert(name, methods.to_vec());
237 }
238
239 pub fn register_trait(&mut self, type_tag: TypeTag, method: MethodName) {
244 self.trait_impls.entry(method).or_default().insert(type_tag);
245 }
246
247 #[must_use]
249 pub fn is_known_trait(&self, name: &str) -> bool {
250 self.known_traits.contains(name)
251 }
252
253 #[must_use]
257 pub fn trait_method_names(&self, trait_name: &str) -> &[MethodName] {
258 self.trait_methods
259 .get(trait_name)
260 .map(Vec::as_slice)
261 .unwrap_or(&[])
262 }
263
264 #[must_use]
266 pub fn known_trait_names(&self) -> Vec<&'static str> {
267 let mut names: Vec<&'static str> = self.known_traits.iter().copied().collect();
268 names.sort_unstable();
269 names
270 }
271
272 #[must_use]
278 pub fn resolve_binop(&self, op: BinOp, lhs_type: TypeTag) -> Option<MethodName> {
279 let method = self.binop_methods.get(&op)?;
280 if self.has_trait(lhs_type, method) {
281 Some(method)
282 } else {
283 None
284 }
285 }
286
287 #[must_use]
289 pub fn resolve_for_in(&self, collection_type: TypeTag) -> Option<MethodName> {
290 if self.has_trait(collection_type, "iter") {
291 Some("iter")
292 } else {
293 None
294 }
295 }
296
297 #[must_use]
300 pub fn resolve_display(&self, type_tag: TypeTag) -> Option<MethodName> {
301 if self.has_trait(type_tag, "display") {
302 Some("display")
303 } else {
304 None
305 }
306 }
307
308 #[must_use]
311 pub fn resolve_conversion(
312 &self,
313 type_tag: TypeTag,
314 direction: ConversionDirection,
315 ) -> Option<MethodName> {
316 let method = match direction {
317 ConversionDirection::Into => "into",
318 ConversionDirection::From => "from",
319 };
320 if self.has_trait(type_tag, method) {
321 Some(method)
322 } else {
323 None
324 }
325 }
326
327 #[must_use]
329 pub fn has_trait(&self, type_tag: TypeTag, method: &str) -> bool {
330 self.trait_impls
331 .get(method)
332 .is_some_and(|types| types.contains(&type_tag))
333 }
334
335 #[must_use]
337 pub fn types_implementing(&self, method: &str) -> Vec<TypeTag> {
338 self.trait_impls
339 .get(method)
340 .map(|types| types.iter().copied().collect())
341 .unwrap_or_default()
342 }
343}
344
/// Direction of a conversion resolved by `TraitDispatch::resolve_conversion`.
///
/// `Into` selects the `into` method and `From` selects the `from` method.
/// `Hash` is derived alongside `Eq` so the direction can be used as a map or
/// set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ConversionDirection {
    /// Dispatches to the `into` method.
    Into,
    /// Dispatches to the `from` method.
    From,
}
353
#[cfg(test)]
mod tests {
    use super::*;

    fn dispatch() -> TraitDispatch {
        TraitDispatch::new()
    }

    // --- Comparable -------------------------------------------------------

    #[test]
    fn comparable_resolves_lt_for_int() {
        let d = dispatch();
        assert_eq!(d.resolve_binop(BinOp::Lt, TypeTag::Int), Some("compare"));
    }

    #[test]
    fn comparable_resolves_ge_for_string() {
        let d = dispatch();
        assert_eq!(d.resolve_binop(BinOp::Ge, TypeTag::String), Some("compare"));
    }

    #[test]
    fn comparable_resolves_gt_for_float() {
        let d = dispatch();
        assert_eq!(d.resolve_binop(BinOp::Gt, TypeTag::Float), Some("compare"));
    }

    #[test]
    fn comparable_resolves_le_for_list() {
        let d = dispatch();
        assert_eq!(d.resolve_binop(BinOp::Le, TypeTag::List), Some("compare"));
    }

    #[test]
    fn comparable_none_for_function() {
        let d = dispatch();
        assert_eq!(d.resolve_binop(BinOp::Lt, TypeTag::Function), None);
    }

    // --- Equatable --------------------------------------------------------

    #[test]
    fn equatable_resolves_eq_for_int() {
        let d = dispatch();
        assert_eq!(d.resolve_binop(BinOp::Eq, TypeTag::Int), Some("equals"));
    }

    #[test]
    fn equatable_resolves_ne_for_bool() {
        let d = dispatch();
        assert_eq!(d.resolve_binop(BinOp::Ne, TypeTag::Bool), Some("equals"));
    }

    #[test]
    fn equatable_none_for_iterator() {
        let d = dispatch();
        assert_eq!(d.resolve_binop(BinOp::Eq, TypeTag::Iterator), None);
    }

    // --- Iterable ---------------------------------------------------------

    #[test]
    fn iterable_resolves_for_list() {
        let d = dispatch();
        assert_eq!(d.resolve_for_in(TypeTag::List), Some("iter"));
    }

    #[test]
    fn iterable_resolves_for_set() {
        let d = dispatch();
        assert_eq!(d.resolve_for_in(TypeTag::Set), Some("iter"));
    }

    #[test]
    fn iterable_resolves_for_map() {
        let d = dispatch();
        assert_eq!(d.resolve_for_in(TypeTag::Map), Some("iter"));
    }

    #[test]
    fn iterable_resolves_for_range() {
        let d = dispatch();
        assert_eq!(d.resolve_for_in(TypeTag::Range), Some("iter"));
    }

    #[test]
    fn iterable_none_for_int() {
        let d = dispatch();
        assert_eq!(d.resolve_for_in(TypeTag::Int), None);
    }

    // --- Displayable ------------------------------------------------------

    #[test]
    fn displayable_resolves_for_int() {
        let d = dispatch();
        assert_eq!(d.resolve_display(TypeTag::Int), Some("display"));
    }

    #[test]
    fn displayable_resolves_for_string() {
        let d = dispatch();
        assert_eq!(d.resolve_display(TypeTag::String), Some("display"));
    }

    #[test]
    fn displayable_resolves_for_optional() {
        let d = dispatch();
        assert_eq!(d.resolve_display(TypeTag::Optional), Some("display"));
    }

    #[test]
    fn displayable_none_for_function() {
        let d = dispatch();
        assert_eq!(d.resolve_display(TypeTag::Function), None);
    }

    #[test]
    fn optional_and_result_are_display_only() {
        let d = dispatch();
        for ty in [TypeTag::Optional, TypeTag::Result] {
            assert!(d.has_trait(ty, "display"));
            assert!(!d.has_trait(ty, "compare"));
            assert!(!d.has_trait(ty, "equals"));
            assert!(!d.has_trait(ty, "hash_code"));
        }
    }

    // --- Arithmetic -------------------------------------------------------

    #[test]
    fn add_resolves_for_int() {
        let d = dispatch();
        assert_eq!(d.resolve_binop(BinOp::Add, TypeTag::Int), Some("add"));
    }

    #[test]
    fn add_resolves_for_string() {
        let d = dispatch();
        assert_eq!(d.resolve_binop(BinOp::Add, TypeTag::String), Some("add"));
    }

    #[test]
    fn sub_resolves_for_float() {
        let d = dispatch();
        assert_eq!(d.resolve_binop(BinOp::Sub, TypeTag::Float), Some("sub"));
    }

    #[test]
    fn mul_resolves_for_int() {
        let d = dispatch();
        assert_eq!(d.resolve_binop(BinOp::Mul, TypeTag::Int), Some("mul"));
    }

    #[test]
    fn add_none_for_bool() {
        let d = dispatch();
        assert_eq!(d.resolve_binop(BinOp::Add, TypeTag::Bool), None);
    }

    #[test]
    fn arithmetic_ops_resolve_for_numeric_types() {
        let d = dispatch();
        for ty in [TypeTag::Int, TypeTag::Float] {
            assert_eq!(d.resolve_binop(BinOp::Sub, ty), Some("sub"));
            assert_eq!(d.resolve_binop(BinOp::Mul, ty), Some("mul"));
            assert_eq!(d.resolve_binop(BinOp::Div, ty), Some("div"));
            assert_eq!(d.resolve_binop(BinOp::Rem, ty), Some("rem"));
        }
    }

    #[test]
    fn non_add_arithmetic_none_for_string() {
        let d = dispatch();
        for op in [BinOp::Sub, BinOp::Mul, BinOp::Div, BinOp::Rem] {
            assert_eq!(d.resolve_binop(op, TypeTag::String), None);
        }
    }

    // --- Conversions ------------------------------------------------------

    #[test]
    fn into_resolves_for_int() {
        let d = dispatch();
        assert_eq!(
            d.resolve_conversion(TypeTag::Int, ConversionDirection::Into),
            Some("into")
        );
    }

    #[test]
    fn from_resolves_for_float() {
        let d = dispatch();
        assert_eq!(
            d.resolve_conversion(TypeTag::Float, ConversionDirection::From),
            Some("from")
        );
    }

    #[test]
    fn from_resolves_for_string() {
        let d = dispatch();
        assert_eq!(
            d.resolve_conversion(TypeTag::String, ConversionDirection::From),
            Some("from")
        );
    }

    #[test]
    fn into_none_for_void() {
        let d = dispatch();
        assert_eq!(
            d.resolve_conversion(TypeTag::Void, ConversionDirection::Into),
            None
        );
    }

    // --- Custom registrations ---------------------------------------------

    #[test]
    fn register_custom_comparable() {
        let mut d = dispatch();
        d.register_trait(TypeTag::Record, "compare");
        assert_eq!(d.resolve_binop(BinOp::Lt, TypeTag::Record), Some("compare"));
        assert_eq!(d.resolve_binop(BinOp::Ge, TypeTag::Record), Some("compare"));
    }

    #[test]
    fn register_custom_equatable() {
        let mut d = dispatch();
        d.register_trait(TypeTag::Record, "equals");
        assert_eq!(d.resolve_binop(BinOp::Eq, TypeTag::Record), Some("equals"));
        assert_eq!(d.resolve_binop(BinOp::Ne, TypeTag::Record), Some("equals"));
    }

    #[test]
    fn register_custom_iterable() {
        let mut d = dispatch();
        d.register_trait(TypeTag::Record, "iter");
        assert_eq!(d.resolve_for_in(TypeTag::Record), Some("iter"));
    }

    #[test]
    fn register_custom_displayable() {
        let mut d = dispatch();
        d.register_trait(TypeTag::Record, "display");
        assert_eq!(d.resolve_display(TypeTag::Record), Some("display"));
    }

    #[test]
    fn register_custom_add() {
        let mut d = dispatch();
        d.register_trait(TypeTag::Record, "add");
        assert_eq!(d.resolve_binop(BinOp::Add, TypeTag::Record), Some("add"));
    }

    #[test]
    fn register_custom_from_into() {
        let mut d = dispatch();
        d.register_trait(TypeTag::Record, "into");
        d.register_trait(TypeTag::Record, "from");
        assert_eq!(
            d.resolve_conversion(TypeTag::Record, ConversionDirection::Into),
            Some("into")
        );
        assert_eq!(
            d.resolve_conversion(TypeTag::Record, ConversionDirection::From),
            Some("from")
        );
    }

    #[test]
    fn register_trait_is_idempotent() {
        let mut d = dispatch();
        d.register_trait(TypeTag::Record, "compare");
        d.register_trait(TypeTag::Record, "compare");
        let records = d
            .types_implementing("compare")
            .into_iter()
            .filter(|t| *t == TypeTag::Record)
            .count();
        assert_eq!(records, 1);
    }

    // --- Queries ----------------------------------------------------------

    #[test]
    fn has_trait_positive() {
        let d = dispatch();
        assert!(d.has_trait(TypeTag::Int, "compare"));
        assert!(d.has_trait(TypeTag::List, "iter"));
    }

    #[test]
    fn has_trait_negative() {
        let d = dispatch();
        assert!(!d.has_trait(TypeTag::Int, "iter"));
        assert!(!d.has_trait(TypeTag::Function, "compare"));
    }

    #[test]
    fn types_implementing_compare() {
        let d = dispatch();
        let types = d.types_implementing("compare");
        assert!(types.contains(&TypeTag::Int));
        assert!(types.contains(&TypeTag::Float));
        assert!(types.contains(&TypeTag::String));
        assert!(!types.contains(&TypeTag::Function));
    }

    #[test]
    fn types_implementing_unknown_method() {
        let d = dispatch();
        assert!(d.types_implementing("nonexistent").is_empty());
    }

    // --- Operators outside trait dispatch ---------------------------------

    #[test]
    fn logical_ops_not_trait_dispatched() {
        let d = dispatch();
        assert_eq!(d.resolve_binop(BinOp::And, TypeTag::Bool), None);
        assert_eq!(d.resolve_binop(BinOp::Or, TypeTag::Bool), None);
    }

    #[test]
    fn bitwise_ops_not_trait_dispatched() {
        let d = dispatch();
        assert_eq!(d.resolve_binop(BinOp::BitAnd, TypeTag::Int), None);
        assert_eq!(d.resolve_binop(BinOp::BitOr, TypeTag::Int), None);
    }

    // --- Prelude trait catalogue ------------------------------------------

    #[test]
    fn all_prelude_traits_recognized() {
        let d = dispatch();
        for name in [
            "Comparable",
            "Equatable",
            "Hashable",
            "Displayable",
            "Iterable",
            "Add",
            "Sub",
            "Mul",
            "Div",
            "Rem",
            "Into",
            "From",
            "Default",
            "Serializable",
            "Cloneable",
            "TryFrom",
            "Collectable",
        ] {
            assert!(
                d.is_known_trait(name),
                "trait `{name}` should be recognized"
            );
        }
    }

    #[test]
    fn unknown_trait_not_recognized() {
        let d = dispatch();
        assert!(!d.is_known_trait("NonExistentTrait"));
    }

    #[test]
    fn known_trait_names_includes_new_traits() {
        let d = dispatch();
        let names = d.known_trait_names();
        assert!(names.contains(&"Serializable"));
        assert!(names.contains(&"Cloneable"));
        assert!(names.contains(&"Default"));
        assert!(names.contains(&"TryFrom"));
        assert!(names.contains(&"Collectable"));
    }

    #[test]
    fn known_trait_names_sorted_and_consistent() {
        let d = dispatch();
        let names = d.known_trait_names();
        let mut sorted = names.clone();
        sorted.sort_unstable();
        assert_eq!(names, sorted);
        assert_eq!(names.len(), 17);
        for name in &names {
            assert!(d.is_known_trait(name), "`{name}` listed but not known");
        }
    }

    #[test]
    fn trait_method_names_for_default() {
        let d = dispatch();
        assert_eq!(d.trait_method_names("Default"), &["default"]);
    }

    #[test]
    fn trait_method_names_for_stub_traits() {
        let d = dispatch();
        assert!(d.trait_method_names("Serializable").is_empty());
        assert!(d.trait_method_names("Cloneable").is_empty());
        assert!(d.trait_method_names("TryFrom").is_empty());
        assert!(d.trait_method_names("Collectable").is_empty());
    }

    #[test]
    fn trait_method_names_unknown_trait_is_empty() {
        let d = dispatch();
        assert!(d.trait_method_names("NonExistentTrait").is_empty());
    }

    // --- Hashable ---------------------------------------------------------

    #[test]
    fn hashable_recognized() {
        let d = dispatch();
        assert!(d.is_known_trait("Hashable"));
    }

    #[test]
    fn hashable_method_is_hash_code() {
        let d = dispatch();
        assert_eq!(d.trait_method_names("Hashable"), &["hash_code"]);
    }

    #[test]
    fn hashable_registered_for_int() {
        let d = dispatch();
        assert!(d.has_trait(TypeTag::Int, "hash_code"));
    }

    #[test]
    fn hashable_registered_for_string() {
        let d = dispatch();
        assert!(d.has_trait(TypeTag::String, "hash_code"));
    }

    #[test]
    fn hashable_registered_for_list() {
        let d = dispatch();
        assert!(d.has_trait(TypeTag::List, "hash_code"));
    }

    #[test]
    fn hashable_not_registered_for_function() {
        let d = dispatch();
        assert!(!d.has_trait(TypeTag::Function, "hash_code"));
    }

    // --- Default ----------------------------------------------------------

    #[test]
    fn default_registered_for_int() {
        let d = dispatch();
        assert!(d.has_trait(TypeTag::Int, "default"));
    }

    #[test]
    fn default_registered_for_string() {
        let d = dispatch();
        assert!(d.has_trait(TypeTag::String, "default"));
    }

    #[test]
    fn default_registered_for_bool() {
        let d = dispatch();
        assert!(d.has_trait(TypeTag::Bool, "default"));
    }

    #[test]
    fn default_registered_for_float() {
        let d = dispatch();
        assert!(d.has_trait(TypeTag::Float, "default"));
    }

    #[test]
    fn default_registered_for_char() {
        let d = dispatch();
        assert!(d.has_trait(TypeTag::Char, "default"));
    }

    #[test]
    fn default_not_registered_for_function() {
        let d = dispatch();
        assert!(!d.has_trait(TypeTag::Function, "default"));
    }

    // --- Derive-style traits ----------------------------------------------

    #[test]
    fn derive_serializable_recognized() {
        let d = dispatch();
        assert!(d.is_known_trait("Serializable"));
    }

    #[test]
    fn derive_cloneable_recognized() {
        let d = dispatch();
        assert!(d.is_known_trait("Cloneable"));
    }

    #[test]
    fn derive_default_recognized() {
        let d = dispatch();
        assert!(d.is_known_trait("Default"));
    }
}