use depyler_hir::hir::Type;
use crate::type_system::constraint::{ConstraintKind, TypeConstraint};
/// Subtype checker for HIR [`Type`]s with memoised results.
///
/// The relation implemented by [`SubtypeChecker::check_subtype`] is:
/// - reflexivity: every type is a subtype of itself (`==`);
/// - numeric widening: `Int <: Float`;
/// - covariance: `Optional<A> <: Optional<B>` and `List<A> <: List<B>` when `A <: B`;
/// - option lifting: `T <: Optional<U>` when `T <: U` (so `Int <: Optional<Float>`);
/// - unification variables on either side are accepted (deferred to unification).
///
/// Every verdict, including its error message, is cached per `(lhs, rhs)` pair.
/// A repeated query therefore returns exactly the same `Result` as the first one.
pub struct SubtypeChecker {
    // Interior mutability lets the `&self` query methods memoise results.
    // The value is the full `Result`, not a `bool`, so cached failures keep
    // the same message the first (uncached) evaluation produced.
    cache: std::cell::RefCell<std::collections::HashMap<(Type, Type), Result<(), String>>>,
}

impl SubtypeChecker {
    /// Creates a checker with an empty cache.
    pub fn new() -> Self {
        Self {
            cache: std::cell::RefCell::new(std::collections::HashMap::new()),
        }
    }

    /// Checks whether `lhs` is a subtype of `rhs`, consulting the cache first.
    ///
    /// # Errors
    ///
    /// Returns a human-readable message when `lhs` is not a subtype of `rhs`.
    /// The message is cached along with the verdict.
    pub fn check_subtype(&self, lhs: &Type, rhs: &Type) -> Result<(), String> {
        let key = (lhs.clone(), rhs.clone());
        // The shared `Ref` is a temporary of this `if let`, so it is released
        // before the recursive computation below re-enters the cache.
        if let Some(cached) = self.cache.borrow().get(&key) {
            return cached.clone();
        }
        let result = self.check_subtype_uncached(lhs, rhs);
        self.cache.borrow_mut().insert(key, result.clone());
        result
    }

    /// Evaluates the subtype relation for one pair without looking it up first.
    /// Nested component checks go back through [`Self::check_subtype`], so they
    /// are cached as well.
    fn check_subtype_uncached(&self, lhs: &Type, rhs: &Type) -> Result<(), String> {
        if lhs == rhs {
            return Ok(());
        }
        match (lhs, rhs) {
            (Type::Int, Type::Float) => Ok(()),
            // Covariance must be tried before lifting. Otherwise
            // `Optional<Int> <: Optional<Float>` would be evaluated as the lift
            // `Optional<Int> <: Float`.
            (Type::Optional(t1), Type::Optional(t2)) => self.check_subtype(t1, t2),
            // Lifting composes with the rest of the relation (e.g. `Int <:
            // Optional<Float>`). A failure is reported against the outer pair,
            // which is the type the caller actually asked about.
            (ty, Type::Optional(inner)) => self
                .check_subtype(ty, inner)
                .map_err(|_| Self::not_subtype(lhs, rhs)),
            (Type::List(t1), Type::List(t2)) => self.check_subtype(t1, t2),
            // Unresolved unification variables are accepted here; resolving
            // them is left to unification.
            (Type::UnificationVar(_), _) | (_, Type::UnificationVar(_)) => Ok(()),
            _ => Err(Self::not_subtype(lhs, rhs)),
        }
    }

    /// Builds the standard "not a subtype" diagnostic for a pair of types.
    fn not_subtype(lhs: &Type, rhs: &Type) -> String {
        format!("{:?} is not a subtype of {:?}", lhs, rhs)
    }

    /// Checks a single [`TypeConstraint`].
    ///
    /// - `Eq` requires `lhs == rhs`. Unlike the subtype check, unification
    ///   variables are not accepted as wildcards here.
    /// - `Subtype` requires `lhs <: rhs`.
    /// - `Supertype` requires `rhs <: lhs`.
    ///
    /// # Errors
    ///
    /// On failure, returns the mismatch message suffixed with the constraint's
    /// `reason`. Every other constraint kind yields an "Unsupported constraint
    /// kind" error.
    pub fn check_constraint(&self, constraint: &TypeConstraint) -> Result<(), String> {
        match constraint.kind {
            ConstraintKind::Eq => {
                if constraint.lhs == constraint.rhs {
                    Ok(())
                } else {
                    Err(format!(
                        "Type mismatch: {:?} != {:?} ({})",
                        constraint.lhs, constraint.rhs, constraint.reason
                    ))
                }
            }
            ConstraintKind::Subtype => self
                .check_subtype(&constraint.lhs, &constraint.rhs)
                .map_err(|e| format!("{} ({})", e, constraint.reason)),
            ConstraintKind::Supertype => self
                .check_subtype(&constraint.rhs, &constraint.lhs)
                .map_err(|e| format!("{} ({})", e, constraint.reason)),
            _ => Err(format!(
                "Unsupported constraint kind: {:?}",
                constraint.kind
            )),
        }
    }
}
impl Default for SubtypeChecker {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for `Optional<inner>`.
    fn optional(inner: Type) -> Type {
        Type::Optional(Box::new(inner))
    }

    /// Shorthand for `List<inner>`.
    fn list(inner: Type) -> Type {
        Type::List(Box::new(inner))
    }

    /// Builds a constraint from its parts; `reason` is copied into an owned string.
    fn constraint(lhs: Type, rhs: Type, kind: ConstraintKind, reason: &str) -> TypeConstraint {
        TypeConstraint {
            lhs,
            rhs,
            kind,
            reason: reason.to_string(),
        }
    }

    /// Asks a freshly constructed checker whether `lhs <: rhs`.
    fn holds(lhs: &Type, rhs: &Type) -> bool {
        SubtypeChecker::new().check_subtype(lhs, rhs).is_ok()
    }

    /// Asks a freshly constructed checker whether `c` is satisfied.
    fn accepts(c: &TypeConstraint) -> bool {
        SubtypeChecker::new().check_constraint(c).is_ok()
    }

    #[test]
    fn test_reflexivity() {
        assert!(holds(&Type::Int, &Type::Int));
    }

    #[test]
    fn test_reflexivity_string() {
        assert!(holds(&Type::String, &Type::String));
    }

    #[test]
    fn test_reflexivity_bool() {
        assert!(holds(&Type::Bool, &Type::Bool));
    }

    #[test]
    fn test_numeric_tower() {
        assert!(holds(&Type::Int, &Type::Float));
    }

    #[test]
    fn test_no_narrowing() {
        assert!(!holds(&Type::Float, &Type::Int));
    }

    #[test]
    fn test_option_lift() {
        assert!(holds(&Type::Int, &optional(Type::Int)));
    }

    #[test]
    fn test_option_covariance() {
        assert!(holds(&optional(Type::Int), &optional(Type::Float)));
    }

    #[test]
    fn test_option_no_contravariance() {
        assert!(!holds(&optional(Type::Float), &optional(Type::Int)));
    }

    #[test]
    fn test_list_covariance() {
        assert!(holds(&list(Type::Int), &list(Type::Float)));
    }

    #[test]
    fn test_list_no_contravariance() {
        assert!(!holds(&list(Type::Float), &list(Type::Int)));
    }

    #[test]
    fn test_unification_var_deferred() {
        // Both queries share one checker, with the variable on either side.
        let checker = SubtypeChecker::new();
        let pairs = [
            (Type::UnificationVar(42), Type::Int),
            (Type::Int, Type::UnificationVar(99)),
        ];
        for (lhs, rhs) in pairs {
            assert!(checker.check_subtype(&lhs, &rhs).is_ok());
        }
    }

    #[test]
    fn test_unrelated_types() {
        let checker = SubtypeChecker::new();
        let unrelated = [(Type::String, Type::Int), (Type::Bool, Type::String)];
        for (lhs, rhs) in &unrelated {
            assert!(checker.check_subtype(lhs, rhs).is_err());
        }
    }

    #[test]
    fn test_cache_hit() {
        let checker = SubtypeChecker::new();
        // The first iteration populates the cache; the second is served from it.
        for _ in 0..2 {
            assert!(checker.check_subtype(&Type::Int, &Type::Float).is_ok());
        }
        for _ in 0..2 {
            assert!(checker.check_subtype(&Type::String, &Type::Int).is_err());
        }
    }

    #[test]
    fn test_default_impl() {
        let checker = SubtypeChecker::default();
        assert!(checker.check_subtype(&Type::Int, &Type::Int).is_ok());
    }

    #[test]
    fn test_check_constraint_eq_success() {
        let c = constraint(Type::Int, Type::Int, ConstraintKind::Eq, "test");
        assert!(accepts(&c));
    }

    #[test]
    fn test_check_constraint_eq_failure() {
        let c = constraint(Type::Int, Type::String, ConstraintKind::Eq, "test");
        assert!(!accepts(&c));
    }

    #[test]
    fn test_check_constraint_subtype() {
        let c = constraint(
            Type::Int,
            Type::Float,
            ConstraintKind::Subtype,
            "numeric coercion",
        );
        assert!(accepts(&c));
    }

    #[test]
    fn test_check_constraint_supertype() {
        let c = constraint(
            Type::Float,
            Type::Int,
            ConstraintKind::Supertype,
            "reverse numeric coercion",
        );
        assert!(accepts(&c));
    }

    #[test]
    fn test_check_constraint_unsupported_callable() {
        let c = constraint(Type::Int, Type::Int, ConstraintKind::Callable, "test");
        assert!(!accepts(&c));
    }

    #[test]
    fn test_check_constraint_unsupported_hasfield() {
        let c = constraint(
            Type::Int,
            Type::Int,
            ConstraintKind::HasField("foo".to_string()),
            "test",
        );
        assert!(!accepts(&c));
    }

    #[test]
    fn test_check_constraint_unsupported_arithmetic() {
        let c = constraint(Type::Int, Type::Int, ConstraintKind::Arithmetic, "test");
        assert!(!accepts(&c));
    }
}