use std::collections::{HashMap, HashSet};
use bock_ast::BinOp;
use bock_interp::TypeTag;
/// Name of the method a type must provide for a trait-dispatched operation
/// (for example `"compare"`, `"equals"`, `"iter"`).
pub type MethodName = &'static str;
/// Lookup tables that map operators and prelude traits to method names, and
/// record which runtime type tags provide each method.
pub struct TraitDispatch {
    /// Prelude trait names accepted by [`TraitDispatch::is_known_trait`].
    known_traits: HashSet<&'static str>,
    /// Method names associated with each known trait (possibly none).
    trait_methods: HashMap<&'static str, Vec<MethodName>>,
    /// Method invoked for each binary operator that is trait-dispatched.
    binop_methods: HashMap<BinOp, MethodName>,
    /// For each method name, the set of type tags that implement it.
    trait_impls: HashMap<MethodName, HashSet<TypeTag>>,
}
impl Default for TraitDispatch {
fn default() -> Self {
Self::new()
}
}
impl TraitDispatch {
#[must_use]
pub fn new() -> Self {
let mut binop_methods = HashMap::new();
binop_methods.insert(BinOp::Lt, "compare");
binop_methods.insert(BinOp::Le, "compare");
binop_methods.insert(BinOp::Gt, "compare");
binop_methods.insert(BinOp::Ge, "compare");
binop_methods.insert(BinOp::Eq, "equals");
binop_methods.insert(BinOp::Ne, "equals");
binop_methods.insert(BinOp::Add, "add");
binop_methods.insert(BinOp::Sub, "sub");
binop_methods.insert(BinOp::Mul, "mul");
binop_methods.insert(BinOp::Div, "div");
binop_methods.insert(BinOp::Rem, "rem");
let mut dispatch = Self {
binop_methods,
trait_impls: HashMap::new(),
known_traits: HashSet::new(),
trait_methods: HashMap::new(),
};
dispatch.register_builtins();
dispatch.register_prelude_traits();
dispatch
}
fn register_builtins(&mut self) {
for ty in [
TypeTag::Int,
TypeTag::Float,
TypeTag::Bool,
TypeTag::String,
TypeTag::Char,
TypeTag::List,
TypeTag::Map,
TypeTag::Set,
] {
self.register_trait(ty, "compare");
}
for ty in [
TypeTag::Int,
TypeTag::Float,
TypeTag::Bool,
TypeTag::String,
TypeTag::Char,
TypeTag::List,
TypeTag::Map,
TypeTag::Set,
] {
self.register_trait(ty, "equals");
}
for ty in [
TypeTag::Int,
TypeTag::Float,
TypeTag::Bool,
TypeTag::String,
TypeTag::Char,
TypeTag::List,
TypeTag::Map,
TypeTag::Set,
TypeTag::Optional,
TypeTag::Result,
] {
self.register_trait(ty, "display");
}
for ty in [TypeTag::List, TypeTag::Set, TypeTag::Map, TypeTag::Range] {
self.register_trait(ty, "iter");
}
for ty in [TypeTag::Int, TypeTag::Float, TypeTag::String] {
self.register_trait(ty, "add");
}
for ty in [TypeTag::Int, TypeTag::Float] {
self.register_trait(ty, "sub");
}
for ty in [TypeTag::Int, TypeTag::Float] {
self.register_trait(ty, "mul");
}
for ty in [TypeTag::Int, TypeTag::Float] {
self.register_trait(ty, "div");
}
for ty in [TypeTag::Int, TypeTag::Float] {
self.register_trait(ty, "rem");
}
for ty in [
TypeTag::Int,
TypeTag::Float,
TypeTag::Bool,
TypeTag::String,
TypeTag::Char,
TypeTag::List,
TypeTag::Map,
TypeTag::Set,
] {
self.register_trait(ty, "hash_code");
}
self.register_trait(TypeTag::Int, "into");
self.register_trait(TypeTag::Int, "from");
self.register_trait(TypeTag::Float, "from");
self.register_trait(TypeTag::String, "from");
for ty in [
TypeTag::Int,
TypeTag::Float,
TypeTag::Bool,
TypeTag::String,
TypeTag::Char,
] {
self.register_trait(ty, "default");
}
}
fn register_prelude_traits(&mut self) {
self.register_known_trait("Comparable", &["compare"]);
self.register_known_trait("Equatable", &["equals"]);
self.register_known_trait("Hashable", &["hash_code"]);
self.register_known_trait("Displayable", &["display"]);
self.register_known_trait("Iterable", &["iter"]);
self.register_known_trait("Add", &["add"]);
self.register_known_trait("Sub", &["sub"]);
self.register_known_trait("Mul", &["mul"]);
self.register_known_trait("Div", &["div"]);
self.register_known_trait("Rem", &["rem"]);
self.register_known_trait("Into", &["into"]);
self.register_known_trait("From", &["from"]);
self.register_known_trait("Default", &["default"]);
self.register_known_trait("Serializable", &[]);
self.register_known_trait("Cloneable", &[]);
self.register_known_trait("TryFrom", &[]);
self.register_known_trait("Collectable", &[]);
}
fn register_known_trait(&mut self, name: &'static str, methods: &[MethodName]) {
self.known_traits.insert(name);
self.trait_methods.insert(name, methods.to_vec());
}
pub fn register_trait(&mut self, type_tag: TypeTag, method: MethodName) {
self.trait_impls.entry(method).or_default().insert(type_tag);
}
#[must_use]
pub fn is_known_trait(&self, name: &str) -> bool {
self.known_traits.contains(name)
}
#[must_use]
pub fn trait_method_names(&self, trait_name: &str) -> &[MethodName] {
self.trait_methods
.get(trait_name)
.map(Vec::as_slice)
.unwrap_or(&[])
}
#[must_use]
pub fn known_trait_names(&self) -> Vec<&'static str> {
let mut names: Vec<&'static str> = self.known_traits.iter().copied().collect();
names.sort_unstable();
names
}
#[must_use]
pub fn resolve_binop(&self, op: BinOp, lhs_type: TypeTag) -> Option<MethodName> {
let method = self.binop_methods.get(&op)?;
if self.has_trait(lhs_type, method) {
Some(method)
} else {
None
}
}
#[must_use]
pub fn resolve_for_in(&self, collection_type: TypeTag) -> Option<MethodName> {
if self.has_trait(collection_type, "iter") {
Some("iter")
} else {
None
}
}
#[must_use]
pub fn resolve_display(&self, type_tag: TypeTag) -> Option<MethodName> {
if self.has_trait(type_tag, "display") {
Some("display")
} else {
None
}
}
#[must_use]
pub fn resolve_conversion(
&self,
type_tag: TypeTag,
direction: ConversionDirection,
) -> Option<MethodName> {
let method = match direction {
ConversionDirection::Into => "into",
ConversionDirection::From => "from",
};
if self.has_trait(type_tag, method) {
Some(method)
} else {
None
}
}
#[must_use]
pub fn has_trait(&self, type_tag: TypeTag, method: &str) -> bool {
self.trait_impls
.get(method)
.is_some_and(|types| types.contains(&type_tag))
}
#[must_use]
pub fn types_implementing(&self, method: &str) -> Vec<TypeTag> {
self.trait_impls
.get(method)
.map(|types| types.iter().copied().collect())
.unwrap_or_default()
}
}
/// Selects which conversion method `TraitDispatch::resolve_conversion` looks up.
///
/// Also derives `Hash` so directions can be used as set or map keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ConversionDirection {
    /// Resolves to the `into` method.
    Into,
    /// Resolves to the `from` method.
    From,
}
/// Unit tests for [`TraitDispatch`]: built-in registrations, operator
/// resolution, custom registrations and prelude trait metadata.
#[cfg(test)]
mod tests {
    use super::*;

    /// Fresh dispatcher with only the built-in registrations.
    fn dispatch() -> TraitDispatch {
        TraitDispatch::new()
    }

    #[test]
    fn comparable_resolves_lt_for_int() {
        let d = dispatch();
        assert_eq!(d.resolve_binop(BinOp::Lt, TypeTag::Int), Some("compare"));
    }

    #[test]
    fn comparable_resolves_ge_for_string() {
        let d = dispatch();
        assert_eq!(d.resolve_binop(BinOp::Ge, TypeTag::String), Some("compare"));
    }

    #[test]
    fn comparable_resolves_gt_for_float() {
        let d = dispatch();
        assert_eq!(d.resolve_binop(BinOp::Gt, TypeTag::Float), Some("compare"));
    }

    #[test]
    fn comparable_resolves_le_for_list() {
        let d = dispatch();
        assert_eq!(d.resolve_binop(BinOp::Le, TypeTag::List), Some("compare"));
    }

    #[test]
    fn comparable_none_for_function() {
        let d = dispatch();
        assert_eq!(d.resolve_binop(BinOp::Lt, TypeTag::Function), None);
    }

    #[test]
    fn equatable_resolves_eq_for_int() {
        let d = dispatch();
        assert_eq!(d.resolve_binop(BinOp::Eq, TypeTag::Int), Some("equals"));
    }

    #[test]
    fn equatable_resolves_ne_for_bool() {
        let d = dispatch();
        assert_eq!(d.resolve_binop(BinOp::Ne, TypeTag::Bool), Some("equals"));
    }

    #[test]
    fn equatable_none_for_iterator() {
        let d = dispatch();
        assert_eq!(d.resolve_binop(BinOp::Eq, TypeTag::Iterator), None);
    }

    #[test]
    fn iterable_resolves_for_list() {
        let d = dispatch();
        assert_eq!(d.resolve_for_in(TypeTag::List), Some("iter"));
    }

    #[test]
    fn iterable_resolves_for_set() {
        let d = dispatch();
        assert_eq!(d.resolve_for_in(TypeTag::Set), Some("iter"));
    }

    #[test]
    fn iterable_resolves_for_map() {
        let d = dispatch();
        assert_eq!(d.resolve_for_in(TypeTag::Map), Some("iter"));
    }

    #[test]
    fn iterable_resolves_for_range() {
        let d = dispatch();
        assert_eq!(d.resolve_for_in(TypeTag::Range), Some("iter"));
    }

    #[test]
    fn iterable_none_for_int() {
        let d = dispatch();
        assert_eq!(d.resolve_for_in(TypeTag::Int), None);
    }

    #[test]
    fn displayable_resolves_for_int() {
        let d = dispatch();
        assert_eq!(d.resolve_display(TypeTag::Int), Some("display"));
    }

    #[test]
    fn displayable_resolves_for_string() {
        let d = dispatch();
        assert_eq!(d.resolve_display(TypeTag::String), Some("display"));
    }

    #[test]
    fn displayable_resolves_for_optional() {
        let d = dispatch();
        assert_eq!(d.resolve_display(TypeTag::Optional), Some("display"));
    }

    #[test]
    fn displayable_none_for_function() {
        let d = dispatch();
        assert_eq!(d.resolve_display(TypeTag::Function), None);
    }

    #[test]
    fn add_resolves_for_int() {
        let d = dispatch();
        assert_eq!(d.resolve_binop(BinOp::Add, TypeTag::Int), Some("add"));
    }

    #[test]
    fn add_resolves_for_string() {
        let d = dispatch();
        assert_eq!(d.resolve_binop(BinOp::Add, TypeTag::String), Some("add"));
    }

    #[test]
    fn sub_resolves_for_float() {
        let d = dispatch();
        assert_eq!(d.resolve_binop(BinOp::Sub, TypeTag::Float), Some("sub"));
    }

    #[test]
    fn mul_resolves_for_int() {
        let d = dispatch();
        assert_eq!(d.resolve_binop(BinOp::Mul, TypeTag::Int), Some("mul"));
    }

    #[test]
    fn add_none_for_bool() {
        let d = dispatch();
        assert_eq!(d.resolve_binop(BinOp::Add, TypeTag::Bool), None);
    }

    #[test]
    fn into_resolves_for_int() {
        let d = dispatch();
        assert_eq!(
            d.resolve_conversion(TypeTag::Int, ConversionDirection::Into),
            Some("into")
        );
    }

    #[test]
    fn from_resolves_for_float() {
        let d = dispatch();
        assert_eq!(
            d.resolve_conversion(TypeTag::Float, ConversionDirection::From),
            Some("from")
        );
    }

    #[test]
    fn from_resolves_for_string() {
        let d = dispatch();
        assert_eq!(
            d.resolve_conversion(TypeTag::String, ConversionDirection::From),
            Some("from")
        );
    }

    #[test]
    fn into_none_for_void() {
        let d = dispatch();
        assert_eq!(
            d.resolve_conversion(TypeTag::Void, ConversionDirection::Into),
            None
        );
    }

    #[test]
    fn register_custom_comparable() {
        let mut d = dispatch();
        d.register_trait(TypeTag::Record, "compare");
        assert_eq!(d.resolve_binop(BinOp::Lt, TypeTag::Record), Some("compare"));
        assert_eq!(d.resolve_binop(BinOp::Ge, TypeTag::Record), Some("compare"));
    }

    #[test]
    fn register_custom_equatable() {
        let mut d = dispatch();
        d.register_trait(TypeTag::Record, "equals");
        assert_eq!(d.resolve_binop(BinOp::Eq, TypeTag::Record), Some("equals"));
        assert_eq!(d.resolve_binop(BinOp::Ne, TypeTag::Record), Some("equals"));
    }

    #[test]
    fn register_custom_iterable() {
        let mut d = dispatch();
        d.register_trait(TypeTag::Record, "iter");
        assert_eq!(d.resolve_for_in(TypeTag::Record), Some("iter"));
    }

    #[test]
    fn register_custom_displayable() {
        let mut d = dispatch();
        d.register_trait(TypeTag::Record, "display");
        assert_eq!(d.resolve_display(TypeTag::Record), Some("display"));
    }

    #[test]
    fn register_custom_add() {
        let mut d = dispatch();
        d.register_trait(TypeTag::Record, "add");
        assert_eq!(d.resolve_binop(BinOp::Add, TypeTag::Record), Some("add"));
    }

    #[test]
    fn register_custom_from_into() {
        let mut d = dispatch();
        d.register_trait(TypeTag::Record, "into");
        d.register_trait(TypeTag::Record, "from");
        assert_eq!(
            d.resolve_conversion(TypeTag::Record, ConversionDirection::Into),
            Some("into")
        );
        assert_eq!(
            d.resolve_conversion(TypeTag::Record, ConversionDirection::From),
            Some("from")
        );
    }

    #[test]
    fn has_trait_positive() {
        let d = dispatch();
        assert!(d.has_trait(TypeTag::Int, "compare"));
        assert!(d.has_trait(TypeTag::List, "iter"));
    }

    #[test]
    fn has_trait_negative() {
        let d = dispatch();
        assert!(!d.has_trait(TypeTag::Int, "iter"));
        assert!(!d.has_trait(TypeTag::Function, "compare"));
    }

    #[test]
    fn types_implementing_compare() {
        let d = dispatch();
        let types = d.types_implementing("compare");
        assert!(types.contains(&TypeTag::Int));
        assert!(types.contains(&TypeTag::Float));
        assert!(types.contains(&TypeTag::String));
        assert!(!types.contains(&TypeTag::Function));
    }

    #[test]
    fn types_implementing_unknown_method() {
        let d = dispatch();
        assert!(d.types_implementing("nonexistent").is_empty());
    }

    #[test]
    fn logical_ops_not_trait_dispatched() {
        let d = dispatch();
        assert_eq!(d.resolve_binop(BinOp::And, TypeTag::Bool), None);
        assert_eq!(d.resolve_binop(BinOp::Or, TypeTag::Bool), None);
    }

    #[test]
    fn bitwise_ops_not_trait_dispatched() {
        let d = dispatch();
        assert_eq!(d.resolve_binop(BinOp::BitAnd, TypeTag::Int), None);
        assert_eq!(d.resolve_binop(BinOp::BitOr, TypeTag::Int), None);
    }

    #[test]
    fn all_prelude_traits_recognized() {
        let d = dispatch();
        for name in [
            "Comparable",
            "Equatable",
            "Hashable",
            "Displayable",
            "Iterable",
            "Add",
            "Sub",
            "Mul",
            "Div",
            "Rem",
            "Into",
            "From",
            "Default",
            "Serializable",
            "Cloneable",
            "TryFrom",
            "Collectable",
        ] {
            assert!(
                d.is_known_trait(name),
                "trait `{name}` should be recognized"
            );
        }
    }

    #[test]
    fn unknown_trait_not_recognized() {
        let d = dispatch();
        assert!(!d.is_known_trait("NonExistentTrait"));
    }

    #[test]
    fn known_trait_names_includes_new_traits() {
        let d = dispatch();
        let names = d.known_trait_names();
        assert!(names.contains(&"Serializable"));
        assert!(names.contains(&"Cloneable"));
        assert!(names.contains(&"Default"));
        assert!(names.contains(&"TryFrom"));
        assert!(names.contains(&"Collectable"));
    }

    #[test]
    fn trait_method_names_for_default() {
        let d = dispatch();
        assert_eq!(d.trait_method_names("Default"), &["default"]);
    }

    #[test]
    fn trait_method_names_for_stub_traits() {
        let d = dispatch();
        assert!(d.trait_method_names("Serializable").is_empty());
        assert!(d.trait_method_names("Cloneable").is_empty());
        assert!(d.trait_method_names("TryFrom").is_empty());
        assert!(d.trait_method_names("Collectable").is_empty());
    }

    #[test]
    fn hashable_recognized() {
        let d = dispatch();
        assert!(d.is_known_trait("Hashable"));
    }

    #[test]
    fn hashable_method_is_hash_code() {
        let d = dispatch();
        assert_eq!(d.trait_method_names("Hashable"), &["hash_code"]);
    }

    #[test]
    fn hashable_registered_for_int() {
        let d = dispatch();
        assert!(d.has_trait(TypeTag::Int, "hash_code"));
    }

    #[test]
    fn hashable_registered_for_string() {
        let d = dispatch();
        assert!(d.has_trait(TypeTag::String, "hash_code"));
    }

    #[test]
    fn hashable_registered_for_list() {
        let d = dispatch();
        assert!(d.has_trait(TypeTag::List, "hash_code"));
    }

    #[test]
    fn hashable_not_registered_for_function() {
        let d = dispatch();
        assert!(!d.has_trait(TypeTag::Function, "hash_code"));
    }

    #[test]
    fn default_registered_for_int() {
        let d = dispatch();
        assert!(d.has_trait(TypeTag::Int, "default"));
    }

    #[test]
    fn default_registered_for_string() {
        let d = dispatch();
        assert!(d.has_trait(TypeTag::String, "default"));
    }

    #[test]
    fn default_registered_for_bool() {
        let d = dispatch();
        assert!(d.has_trait(TypeTag::Bool, "default"));
    }

    #[test]
    fn default_registered_for_float() {
        let d = dispatch();
        assert!(d.has_trait(TypeTag::Float, "default"));
    }

    #[test]
    fn default_registered_for_char() {
        let d = dispatch();
        assert!(d.has_trait(TypeTag::Char, "default"));
    }

    #[test]
    fn default_not_registered_for_function() {
        let d = dispatch();
        assert!(!d.has_trait(TypeTag::Function, "default"));
    }

    #[test]
    fn derive_serializable_recognized() {
        let d = dispatch();
        assert!(d.is_known_trait("Serializable"));
    }

    #[test]
    fn derive_cloneable_recognized() {
        let d = dispatch();
        assert!(d.is_known_trait("Cloneable"));
    }

    #[test]
    fn derive_default_recognized() {
        let d = dispatch();
        assert!(d.is_known_trait("Default"));
    }

    // --- Additional coverage ---

    #[test]
    fn div_and_rem_resolve_for_numeric() {
        let d = dispatch();
        assert_eq!(d.resolve_binop(BinOp::Div, TypeTag::Int), Some("div"));
        assert_eq!(d.resolve_binop(BinOp::Div, TypeTag::Float), Some("div"));
        assert_eq!(d.resolve_binop(BinOp::Rem, TypeTag::Int), Some("rem"));
        assert_eq!(d.resolve_binop(BinOp::Rem, TypeTag::Float), Some("rem"));
        assert_eq!(d.resolve_binop(BinOp::Div, TypeTag::String), None);
    }

    #[test]
    fn sub_none_for_string() {
        // Only `add` is registered for strings; other arithmetic is numeric-only.
        let d = dispatch();
        assert_eq!(d.resolve_binop(BinOp::Sub, TypeTag::String), None);
        assert_eq!(d.resolve_binop(BinOp::Mul, TypeTag::String), None);
    }

    #[test]
    fn displayable_resolves_for_result() {
        let d = dispatch();
        assert_eq!(d.resolve_display(TypeTag::Result), Some("display"));
    }

    #[test]
    fn types_implementing_iter_is_exact() {
        let d = dispatch();
        let types = d.types_implementing("iter");
        assert_eq!(types.len(), 4);
        for ty in [TypeTag::List, TypeTag::Set, TypeTag::Map, TypeTag::Range] {
            assert!(types.contains(&ty));
        }
    }

    #[test]
    fn register_trait_is_idempotent() {
        let mut d = dispatch();
        d.register_trait(TypeTag::Record, "add");
        d.register_trait(TypeTag::Record, "add");
        let records = d
            .types_implementing("add")
            .into_iter()
            .filter(|ty| *ty == TypeTag::Record)
            .count();
        assert_eq!(records, 1);
    }

    #[test]
    fn known_trait_names_sorted_and_unique() {
        let d = dispatch();
        let names = d.known_trait_names();
        assert_eq!(names.len(), 17);
        assert!(names.windows(2).all(|w| w[0] < w[1]));
    }

    #[test]
    fn trait_method_names_for_unknown_trait_is_empty() {
        let d = dispatch();
        assert!(d.trait_method_names("NonExistentTrait").is_empty());
    }

    #[test]
    fn default_matches_new() {
        let from_default = TraitDispatch::default();
        let from_new = TraitDispatch::new();
        assert_eq!(from_default.known_trait_names(), from_new.known_trait_names());
        assert_eq!(
            from_default.resolve_binop(BinOp::Lt, TypeTag::Int),
            Some("compare")
        );
    }
}