use depyler_hir::hir::Type;
use crate::type_system::constraint::{ConstraintKind, TypeConstraint};
use crate::type_system::subtyping::SubtypeChecker;
use std::collections::{HashMap, VecDeque};
/// Outcome of [`WorklistSolver::solve`]: the type bound to each unification
/// variable that the solver encountered while processing its constraints.
#[derive(Debug, Clone)]
pub struct Solution {
    /// Unification-variable id → the type bound to that variable.
    assignments: HashMap<usize, Type>,
    /// Consistency flag recorded when the solution was built.
    consistent: bool,
}

impl Solution {
    /// Returns the consistency flag stored in this solution.
    pub fn is_consistent(&self) -> bool {
        let Self { consistent, .. } = self;
        *consistent
    }

    /// Returns the type bound to unification variable `var_id`, or `None` if
    /// no binding was recorded for it.
    pub fn get(&self, var_id: usize) -> Option<&Type> {
        let bindings = &self.assignments;
        bindings.get(&var_id)
    }
}
/// Worklist-based solver for [`TypeConstraint`]s over unification variables.
///
/// Constraints are processed in FIFO order, each exactly once. The first
/// constraint that mentions an unbound variable decides its binding; later
/// constraints on that variable are checked against the recorded binding.
pub struct WorklistSolver {
    /// Pending constraints, processed front to back.
    constraints: VecDeque<TypeConstraint>,
    /// Unification-variable id → type bound to it so far.
    assignments: HashMap<usize, Type>,
    /// Decides subtype relations for `Subtype`/`Supertype` constraints.
    checker: SubtypeChecker,
    /// Number of constraints popped during the current `solve` call.
    iterations: usize,
    /// Maximum number of constraints processed in a single `solve` call.
    max_iterations: usize,
}

impl WorklistSolver {
    /// Iteration budget used by [`WorklistSolver::new`].
    const DEFAULT_MAX_ITERATIONS: usize = 1000;

    /// Creates an empty solver with the default iteration budget (1000).
    pub fn new() -> Self {
        Self::with_max_iterations(Self::DEFAULT_MAX_ITERATIONS)
    }

    /// Creates an empty solver that processes at most `max_iterations`
    /// constraints per call to [`solve`](Self::solve).
    ///
    /// Constraints are never re-queued, so this budget is effectively a cap on
    /// how many pending constraints one `solve` call will accept.
    pub fn with_max_iterations(max_iterations: usize) -> Self {
        Self {
            constraints: VecDeque::new(),
            assignments: HashMap::new(),
            checker: SubtypeChecker::new(),
            iterations: 0,
            max_iterations,
        }
    }

    /// Appends `constraint` to the back of the worklist.
    pub fn add_constraint(&mut self, constraint: TypeConstraint) {
        self.constraints.push_back(constraint);
    }

    /// Processes every queued constraint and returns a snapshot of the bindings.
    ///
    /// # Errors
    ///
    /// Returns `Err` with a human-readable message when:
    /// - a constraint is violated (that constraint is consumed), or
    /// - more than `max_iterations` constraints are pending in this call. The
    ///   unprocessed constraint is put back at the front of the queue, so a
    ///   retry sees the same worklist.
    ///
    /// In both cases, bindings recorded before the failure are kept.
    pub fn solve(&mut self) -> Result<Solution, String> {
        // The budget is per call. Without this reset, a solver reused for
        // incremental solving would time out based on earlier calls' work.
        self.iterations = 0;
        while let Some(constraint) = self.constraints.pop_front() {
            self.iterations += 1;
            if self.iterations > self.max_iterations {
                // Re-queue rather than drop, so the constraint is not silently lost.
                self.constraints.push_front(constraint);
                return Err(format!(
                    "Solver timeout after {} iterations",
                    self.max_iterations
                ));
            }
            self.process_constraint(&constraint)?;
        }
        Ok(Solution {
            assignments: self.assignments.clone(),
            consistent: true,
        })
    }

    /// Dispatches `constraint` to the handler for its kind.
    ///
    /// Only `Eq`, `Subtype` and `Supertype` are enforced. Every other kind is
    /// accepted without any check.
    fn process_constraint(&mut self, constraint: &TypeConstraint) -> Result<(), String> {
        match constraint.kind {
            ConstraintKind::Eq => {
                self.process_equality(&constraint.lhs, &constraint.rhs, &constraint.reason)
            }
            ConstraintKind::Subtype => {
                self.process_subtype(&constraint.lhs, &constraint.rhs, &constraint.reason)
            }
            // `lhs :> rhs` is the same relation as `rhs <: lhs`.
            ConstraintKind::Supertype => {
                self.process_subtype(&constraint.rhs, &constraint.lhs, &constraint.reason)
            }
            _ => Ok(()),
        }
    }

    /// Enforces `lhs == rhs`.
    ///
    /// An unbound unification variable on either side is bound to the other
    /// side. A bound variable must already be bound to exactly that type.
    /// Two non-variable types must compare equal.
    fn process_equality(&mut self, lhs: &Type, rhs: &Type, reason: &str) -> Result<(), String> {
        // Identical sides are trivially equal. Checking this before binding
        // keeps `?a == ?a` from binding `?a` to itself, which would make every
        // later constraint relating `?a` to a concrete type fail.
        if lhs == rhs {
            return Ok(());
        }
        match (lhs, rhs) {
            (Type::UnificationVar(var_id), ty) | (ty, Type::UnificationVar(var_id)) => {
                if let Some(existing) = self.assignments.get(var_id) {
                    if existing != ty {
                        return Err(format!(
                            "Type mismatch: variable {} has type {:?}, expected {:?} ({})",
                            var_id, existing, ty, reason
                        ));
                    }
                } else {
                    self.assignments.insert(*var_id, ty.clone());
                }
                Ok(())
            }
            _ => Err(format!(
                "Equality constraint failed: {:?} != {:?} ({})",
                lhs, rhs, reason
            )),
        }
    }

    /// Enforces `lhs <: rhs`.
    ///
    /// An unbound unification variable is bound to the opposite side. A bound
    /// variable's type is checked with the subtype checker. Two non-variable
    /// types are checked directly. Checker errors get `reason` appended.
    fn process_subtype(&mut self, lhs: &Type, rhs: &Type, reason: &str) -> Result<(), String> {
        match (lhs, rhs) {
            // `?a <: ?a` holds by reflexivity. Falling through would bind `?a`
            // to itself and poison every later constraint on `?a`.
            (Type::UnificationVar(a), Type::UnificationVar(b)) if a == b => Ok(()),
            (Type::UnificationVar(var_id), ty) => {
                if let Some(existing) = self.assignments.get(var_id) {
                    self.checker
                        .check_subtype(existing, ty)
                        .map_err(|e| format!("{} ({})", e, reason))?;
                } else {
                    self.assignments.insert(*var_id, ty.clone());
                }
                Ok(())
            }
            (ty, Type::UnificationVar(var_id)) => {
                if let Some(existing) = self.assignments.get(var_id) {
                    self.checker
                        .check_subtype(ty, existing)
                        .map_err(|e| format!("{} ({})", e, reason))?;
                } else {
                    self.assignments.insert(*var_id, ty.clone());
                }
                Ok(())
            }
            (t1, t2) => self
                .checker
                .check_subtype(t1, t2)
                .map_err(|e| format!("{} ({})", e, reason)),
        }
    }
}
impl Default for WorklistSolver {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a fresh solver, enqueues `constraints` in order and solves them.
    fn run(constraints: Vec<TypeConstraint>) -> Result<Solution, String> {
        let mut solver = WorklistSolver::new();
        constraints
            .into_iter()
            .for_each(|constraint| solver.add_constraint(constraint));
        solver.solve()
    }

    /// Shorthand for the unification variable with the given id.
    fn var(id: usize) -> Type {
        Type::UnificationVar(id)
    }

    #[test]
    fn test_solve_simple_equality() {
        let solution =
            run(vec![TypeConstraint::eq(var(0), Type::Int, "Assignment")]).expect("Should solve");
        assert_eq!(solution.get(0), Some(&Type::Int));
    }

    #[test]
    fn test_solve_subtype_constraint() {
        let solution = run(vec![TypeConstraint::subtype(
            Type::Int,
            Type::Float,
            "Function argument",
        )])
        .expect("Should solve");
        assert!(solution.is_consistent());
    }

    #[test]
    fn test_solve_transitive_subtyping() {
        let solution = run(vec![
            TypeConstraint::subtype(var(0), var(1), "Transitivity test"),
            TypeConstraint::subtype(var(1), var(2), "Transitivity test"),
        ])
        .expect("Should solve");
        assert!(solution.is_consistent());
    }

    #[test]
    fn test_default_solver() {
        let solver = WorklistSolver::default();
        assert_eq!(solver.iterations, 0);
    }

    #[test]
    fn test_equality_right_unification_var() {
        let solution = run(vec![TypeConstraint::eq(
            Type::Int,
            var(5),
            "Assignment from right",
        )])
        .expect("Should solve");
        assert_eq!(solution.get(5), Some(&Type::Int));
    }

    #[test]
    fn test_equality_var_already_assigned_consistent() {
        let solution = run(vec![
            TypeConstraint::eq(var(0), Type::Int, "First assignment"),
            TypeConstraint::eq(var(0), Type::Int, "Same assignment"),
        ])
        .expect("Should solve");
        assert_eq!(solution.get(0), Some(&Type::Int));
    }

    #[test]
    fn test_equality_var_conflict() {
        let result = run(vec![
            TypeConstraint::eq(var(0), Type::Int, "First assignment"),
            TypeConstraint::eq(var(0), Type::String, "Conflicting assignment"),
        ]);
        assert!(result.is_err());
    }

    #[test]
    fn test_equality_concrete_match() {
        let solution =
            run(vec![TypeConstraint::eq(Type::Int, Type::Int, "Same type")]).expect("Should solve");
        assert!(solution.is_consistent());
    }

    #[test]
    fn test_equality_concrete_mismatch() {
        let result = run(vec![TypeConstraint::eq(Type::Int, Type::String, "Mismatch")]);
        assert!(result.is_err());
    }

    #[test]
    fn test_subtype_unification_var_left() {
        let solution = run(vec![TypeConstraint::subtype(var(0), Type::Int, "Upper bound")])
            .expect("Should solve");
        assert_eq!(solution.get(0), Some(&Type::Int));
    }

    #[test]
    fn test_subtype_unification_var_right() {
        let solution = run(vec![TypeConstraint::subtype(Type::Int, var(0), "Lower bound")])
            .expect("Should solve");
        assert_eq!(solution.get(0), Some(&Type::Int));
    }

    #[test]
    fn test_subtype_var_left_existing_consistent() {
        let solution = run(vec![
            TypeConstraint::eq(var(0), Type::Int, "Assign"),
            TypeConstraint::subtype(var(0), Type::Float, "Check bound"),
        ])
        .expect("Should solve");
        assert_eq!(solution.get(0), Some(&Type::Int));
    }

    #[test]
    fn test_subtype_var_left_existing_inconsistent() {
        let result = run(vec![
            TypeConstraint::eq(var(0), Type::Float, "Assign"),
            TypeConstraint::subtype(var(0), Type::Int, "Check bound"),
        ]);
        assert!(result.is_err());
    }

    #[test]
    fn test_subtype_var_right_existing_consistent() {
        let solution = run(vec![
            TypeConstraint::eq(var(0), Type::Float, "Assign"),
            TypeConstraint::subtype(Type::Int, var(0), "Check bound"),
        ])
        .expect("Should solve");
        assert_eq!(solution.get(0), Some(&Type::Float));
    }

    #[test]
    fn test_subtype_var_right_existing_inconsistent() {
        let result = run(vec![
            TypeConstraint::eq(var(0), Type::Int, "Assign"),
            TypeConstraint::subtype(Type::Float, var(0), "Check bound"),
        ]);
        assert!(result.is_err());
    }

    #[test]
    fn test_supertype_constraint() {
        let solution = run(vec![TypeConstraint::supertype(
            Type::Float,
            Type::Int,
            "Supertype check",
        )])
        .expect("Should solve");
        assert!(solution.is_consistent());
    }

    #[test]
    fn test_other_constraint_ignored() {
        let solution = run(vec![TypeConstraint {
            lhs: Type::Int,
            rhs: Type::Int,
            kind: ConstraintKind::Callable,
            reason: "Ignored".to_string(),
        }])
        .expect("Should solve");
        assert!(solution.is_consistent());
    }

    #[test]
    fn test_solution_get_missing() {
        let empty = Solution {
            assignments: HashMap::new(),
            consistent: true,
        };
        assert!(empty.get(999).is_none());
    }

    #[test]
    fn test_empty_constraints() {
        let solution = run(Vec::new()).expect("Should solve empty constraints");
        assert!(solution.is_consistent());
    }
}