use crate::type_checker::{TcVar, TypeChecker};
use crate::types::{Abstract, AbstractTypeTable, Generalizable, ReificationErr, TryReifiable, TypeTable};
use crate::{TcErr, TcKey};
use std::cmp::max;
use std::convert::TryInto;
use std::hash::Hash;
/// Abstract types used by the test type system.
///
/// Values of this enum are combined via the `Abstract::meet` implementation in this file.
#[derive(Debug, Clone, Copy, PartialOrd, PartialEq, Ord, Eq, Hash)]
enum AbstractType {
/// The unconstrained type; meeting it with any other type yields that other type.
Any,
/// Fixed-point number, described by its integer bit width and its fractional bit width (in that order).
Fixed(u8, u8),
/// Integer, described by its bit width.
Integer(u8),
/// Some numeric type; meeting it with `Integer` or `Fixed` yields that more specific type.
Numeric,
/// Boolean; only compatible with `Bool` and `Any`.
Bool,
}
/// Concrete types that an [`AbstractType`] can be reified into.
#[derive(Debug, Clone, Copy, PartialOrd, PartialEq, Ord, Eq, Hash)]
enum ConcreteType {
/// Reification target for integers of at most 128 bits; generalizes to `AbstractType::Integer(128)`.
Int128,
/// Reification target for fixed-point numbers with at most 64 integer and 64 fractional bits;
/// generalizes to `AbstractType::Fixed(64, 64)`.
FixedPointI64F64,
/// Reification target for `AbstractType::Bool`.
Bool,
}
/// Variable identifier used as the `Var` parameter of the type checker in these tests.
///
/// Wraps a plain `usize`; two variables are equal exactly when their wrapped values are equal.
#[derive(Debug, Clone, Eq, PartialEq, Hash)]
struct Variable(usize);
// Marker implementation: no methods are provided here.
impl TcVar for Variable {}
impl Abstract for AbstractType {
type Err = ();
fn unconstrained() -> Self {
AbstractType::Any
}
fn meet(&self, right: &Self) -> Result<Self, Self::Err> {
use crate::tests::AbstractType::*;
match (self, right) {
(Any, other) | (other, Any) => Ok(other.clone()),
(Integer(l), Integer(r)) => Ok(Integer(max(*r, *l))),
(Fixed(li, lf), Fixed(ri, rf)) => Ok(Fixed(max(*li, *ri), max(*lf, *rf))),
(Fixed(i, f), Integer(u)) | (Integer(u), Fixed(i, f)) if *f == 0 => Ok(Integer(max(*i, *u))),
(Fixed(i, f), Integer(u)) | (Integer(u), Fixed(i, f)) => Ok(Fixed(max(*i, *u), *f)),
(Bool, Bool) => Ok(Bool),
(Bool, _) | (_, Bool) => Err(()),
(Numeric, Integer(w)) | (Integer(w), Numeric) => Ok(Integer(*w)),
(Numeric, Fixed(i, f)) | (Fixed(i, f), Numeric) => Ok(Fixed(*i, *f)),
(Numeric, Numeric) => Ok(Numeric),
}
}
fn arity(&self) -> Option<usize> {
match self {
AbstractType::Any | AbstractType::Numeric => None,
AbstractType::Bool | AbstractType::Fixed(_, _) | AbstractType::Integer(_) => Some(0),
}
}
fn nth_child(&self, _n: usize) -> &Self {
panic!("cannot access non-extant children")
}
fn with_children<I>(&self, children: I) -> Self
where
I: IntoIterator<Item = Self>,
{
assert!(children.into_iter().next().is_none());
return self.clone();
}
}
/// Type annotation of an argument or return value of an `Expression::PolyFn`.
#[derive(Clone, Copy, Debug)]
enum ParamType {
/// Refers to a type parameter of the function; the value indexes into the function's `param_constraints`.
ParamId(usize),
/// A fixed abstract type, independent of the function's type parameters.
Abstract(AbstractType),
}
/// Minimal expression language used to exercise the type checker.
#[derive(Clone, Debug)]
enum Expression {
/// Conditional with a condition, a consequence and an alternative.
Conditional {
cond: Box<Expression>,
cons: Box<Expression>,
alt: Box<Expression>,
},
/// Call of a (possibly polymorphic) function.
PolyFn {
// Function name; the constraint generation in this file ignores it.
name: &'static str,
// One entry per type parameter: an optional bound on that parameter.
param_constraints: Vec<Option<AbstractType>>,
// Each argument expression paired with its declared parameter type.
args: Vec<(ParamType, Expression)>,
// Declared return type.
returns: ParamType,
},
/// Integer constant.
ConstInt(i128),
/// Boolean constant.
ConstBool(bool),
/// Fixed-point constant: the first field determines the integer bit width, the second the
/// fractional bit width.
ConstFixed(i64, u64),
}
/// Reification of abstract types into [`ConcreteType`]s.
impl TryReifiable for AbstractType {
    type Reified = ConcreteType;

    /// Attempts to turn this abstract type into a concrete one.
    ///
    /// # Errors
    ///
    /// * `ReificationErr::TooGeneral` for `Any` and `Numeric`.
    /// * `ReificationErr::Conflicting` for integers wider than 128 bits and for fixed-point
    ///   numbers with more than 64 integer or 64 fractional bits.
    fn try_reify(&self) -> Result<Self::Reified, ReificationErr> {
        match *self {
            Self::Bool => Ok(ConcreteType::Bool),
            Self::Integer(w) => {
                if w <= 128 {
                    Ok(ConcreteType::Int128)
                } else {
                    Err(ReificationErr::Conflicting(format!("Integer too wide, {}-bit not supported.", w)))
                }
            }
            Self::Fixed(i, f) => {
                if i <= 64 && f <= 64 {
                    Ok(ConcreteType::FixedPointI64F64)
                } else {
                    Err(ReificationErr::Conflicting(format!(
                        "Fixed point number too wide, I{}F{} not supported.",
                        i, f
                    )))
                }
            }
            Self::Any => Err(ReificationErr::TooGeneral(String::from("Cannot reify `Any`."))),
            Self::Numeric => Err(ReificationErr::TooGeneral(String::from(
                "Cannot reify a numeric value. Either define a default (int/fixed) or restrict type.",
            ))),
        }
    }
}
/// Maps each [`ConcreteType`] back to its corresponding [`AbstractType`].
impl Generalizable for ConcreteType {
    type Generalized = AbstractType;

    /// Returns the abstract counterpart of this concrete type, using the maximal supported widths.
    fn generalize(&self) -> Self::Generalized {
        match *self {
            Self::Bool => AbstractType::Bool,
            Self::Int128 => AbstractType::Integer(128),
            Self::FixedPointI64F64 => AbstractType::Fixed(64, 64),
        }
    }
}
fn build_type_error() -> Expression {
use Expression::*;
let const_false = ConstBool(false);
let const43 = ConstFixed(4, 3);
PolyFn {
name: "addition", param_constraints: vec![Some(AbstractType::Numeric)],
args: vec![(ParamType::ParamId(0), const43), (ParamType::ParamId(0), const_false)],
returns: ParamType::ParamId(0),
}
}
/// Builds an "addition" call whose operands and result share a single `Numeric`-constrained
/// type parameter.
fn create_addition(lhs: Expression, rhs: Expression) -> Expression {
    let operand_ty = ParamType::ParamId(0);
    let args = vec![(operand_ty, lhs), (operand_ty, rhs)];
    Expression::PolyFn {
        name: "addition",
        param_constraints: vec![Some(AbstractType::Numeric)],
        args,
        returns: operand_ty,
    }
}
fn build_complex_expression_type_checks() -> Expression {
use Expression::*;
let const27 = ConstFixed(2, 7);
let const3 = ConstInt(3);
let const43 = ConstFixed(4, 3);
let const_true = ConstBool(true);
let exponentiation = PolyFn {
name: "exponentiation", param_constraints: vec![Some(AbstractType::Numeric)],
args: vec![(ParamType::ParamId(0), const27), (ParamType::Abstract(AbstractType::Integer(1)), const3.clone())],
returns: ParamType::ParamId(0),
};
let addition = create_addition(exponentiation, const43);
Conditional { cond: Box::new(const_true), cons: Box::new(addition), alt: Box::new(const3) }
}
/// Generates the constraints for `expr`, runs the type check, and returns the key of the
/// expression's result together with the resulting type table.
///
/// # Errors
///
/// Propagates any error raised while imposing constraints or during the final type check.
fn tc_expr<Var: TcVar>(
    mut tc: TypeChecker<AbstractType, Var>,
    expr: &Expression,
) -> Result<(TcKey, AbstractTypeTable<AbstractType>), TcErr<AbstractType>> {
    // Constraint generation must finish before the checker is consumed by `type_check`.
    let root_key = _tc_expr(&mut tc, expr)?;
    Ok((root_key, tc.type_check()?))
}
fn _tc_expr<Var: TcVar>(
tc: &mut TypeChecker<AbstractType, Var>,
expr: &Expression,
) -> Result<TcKey, TcErr<AbstractType>> {
use Expression::*;
let key_result = tc.new_term_key(); match expr {
ConstInt(c) => {
let width = (128 - c.leading_zeros()).try_into().unwrap();
tc.impose(key_result.concretizes_explicit(AbstractType::Integer(width)))?;
}
ConstFixed(i, f) => {
let int_width = (64 - i.leading_zeros()).try_into().unwrap();
let frac_width = (64 - f.leading_zeros()).try_into().unwrap();
tc.impose(key_result.concretizes_explicit(AbstractType::Fixed(int_width, frac_width)))?;
}
ConstBool(_) => tc.impose(key_result.concretizes_concrete(ConcreteType::Bool))?,
Conditional { cond, cons, alt } => {
let key_cond = _tc_expr(tc, cond)?;
let key_cons = _tc_expr(tc, cons)?;
let key_alt = _tc_expr(tc, alt)?;
tc.impose(key_cond.concretizes_explicit(AbstractType::Bool))?;
tc.impose(key_result.is_meet_of(key_cons, key_alt))?;
}
PolyFn { name: _, param_constraints, args, returns } => {
let params: Vec<(Option<AbstractType>, TcKey)> =
param_constraints.iter().map(|p| (*p, tc.new_term_key())).collect();
¶ms;
for (arg_ty, arg_expr) in args {
let arg_key = _tc_expr(tc, arg_expr)?;
match arg_ty {
ParamType::ParamId(id) => {
let (p_constr, p_key) = params[*id];
tc.impose(p_key.concretizes(arg_key))?;
if let Some(c) = p_constr {
tc.impose(arg_key.concretizes_explicit(c))?;
}
}
ParamType::Abstract(at) => tc.impose(arg_key.concretizes_explicit(*at))?,
};
}
match returns {
ParamType::Abstract(at) => tc.impose(key_result.concretizes_explicit(*at))?,
ParamType::ParamId(id) => {
let (constr, key) = params[*id];
if let Some(c) = constr {
tc.impose(key_result.concretizes_explicit(c))?;
}
tc.impose(key_result.equate_with(key))?;
}
}
}
}
Ok(key_result)
}
/// Two freshly created term keys must be distinct.
#[test]
fn create_different_types() {
    let mut checker: TypeChecker<AbstractType, Variable> = TypeChecker::new();
    let key_one = checker.new_term_key();
    let key_two = checker.new_term_key();
    assert_ne!(key_one, key_two);
}
/// A concrete bound on one key carries over to a key that is equated with it.
#[test]
fn bound_by_concrete_transitive() {
    let mut tc: TypeChecker<AbstractType, Variable> = TypeChecker::new();
    let first = tc.new_term_key();
    let second = tc.new_term_key();
    // Bind `second` to a concrete type, then tie `first` to `second`.
    assert!(tc.impose(second.concretizes_concrete(ConcreteType::Int128)).is_ok());
    assert!(tc.impose(first.equate_with(second)).is_ok());
    let table = tc.type_check().expect("unexpected type error");
    let types = table.as_hashmap();
    assert_eq!(types[&first], types[&second]);
}
/// The complex, well-typed expression resolves to `Fixed(3, 3)`.
#[test]
fn complex_type_check() {
    let checker: TypeChecker<AbstractType, Variable> = TypeChecker::new();
    let well_typed = build_complex_expression_type_checks();
    let (root, table) = tc_expr(checker, &well_typed).expect("unexpected type error");
    let types = table.as_hashmap();
    assert_eq!(types[&root], AbstractType::Fixed(3, 3));
}
/// The ill-typed expression must be rejected by the type checker.
#[test]
fn failing_type_check() {
    let checker: TypeChecker<AbstractType, Variable> = TypeChecker::new();
    let ill_typed = build_type_error();
    let outcome = tc_expr(checker, &ill_typed);
    // Only a successful check is a failure of this test; any error is the expected result.
    if let Ok((root, table)) = outcome {
        panic!("unexpectedly got result type {:?}", table.as_hashmap()[&root]);
    }
}
/// Requesting the key of the same variable twice yields the same key, while distinct
/// variables and fresh terms receive distinct keys.
#[test]
fn test_variable_dedup() {
    let mut checker: TypeChecker<AbstractType, Variable> = TypeChecker::new();
    // Keys are requested in this exact order: variable 0, a term, variable 1, variable 0 again.
    let first_var = checker.get_var_key(&Variable(0));
    let fresh_term = checker.new_term_key();
    let second_var = checker.get_var_key(&Variable(1));
    let first_var_again = checker.get_var_key(&Variable(0));
    assert_eq!(first_var, first_var_again);
    assert_ne!(first_var, fresh_term);
    assert_ne!(first_var, second_var);
    assert_ne!(fresh_term, second_var);
}
/// With the explicit bound imposed first, a key that concretizes another ends up with the
/// same type.
#[test]
fn test_asym_simple() {
    let mut checker: TypeChecker<AbstractType, Variable> = TypeChecker::new();
    let bounded = checker.new_term_key();
    let dependent = checker.new_term_key();
    checker.impose(bounded.concretizes_explicit(AbstractType::Integer(3))).unwrap();
    checker.impose(dependent.concretizes(bounded)).unwrap();
    let table = checker.type_check().expect("Unexpected type error.");
    let types = table.as_hashmap();
    assert_eq!(types[&bounded], types[&dependent]);
}
/// Same as `test_asym_simple`, but the asymmetric relation is imposed before the explicit bound.
#[test]
fn test_asym_order() {
    let mut checker: TypeChecker<AbstractType, Variable> = TypeChecker::new();
    let bounded = checker.new_term_key();
    let dependent = checker.new_term_key();
    checker.impose(dependent.concretizes(bounded)).unwrap();
    checker.impose(bounded.concretizes_explicit(AbstractType::Integer(3))).unwrap();
    let table = checker.type_check().expect("Unexpected type error.");
    let types = table.as_hashmap();
    assert_eq!(types[&bounded], types[&dependent]);
}
/// An asymmetric relation must not push the wider type back onto the narrower key.
#[test]
fn test_asym_separation() {
    let mut checker: TypeChecker<AbstractType, Variable> = TypeChecker::new();
    let source = checker.new_term_key();
    let middle = checker.new_term_key();
    let sink = checker.new_term_key();
    let narrow = AbstractType::Integer(3);
    let wide = AbstractType::Integer(12);
    checker.impose(source.concretizes_explicit(narrow)).unwrap();
    checker.impose(middle.concretizes(source)).unwrap();
    checker.impose(middle.equate_with(sink)).unwrap();
    checker.impose(sink.concretizes_explicit(wide)).unwrap();
    let table = checker.type_check().expect("unexpected type error.");
    let types = table.as_hashmap();
    // The equated keys share the wide type, whereas the source keeps its narrow one.
    assert_eq!(types[&middle], types[&sink]);
    assert_eq!(types[&source], narrow);
    assert_eq!(types[&middle], wide);
    assert_eq!(types[&sink], wide);
}
/// A key declared as the meet of two bounded keys receives the wider type; its operands
/// keep their own types.
#[test]
fn test_meet() {
    let mut checker: TypeChecker<AbstractType, Variable> = TypeChecker::new();
    let left = checker.new_term_key();
    let met = checker.new_term_key();
    let right = checker.new_term_key();
    let narrow = AbstractType::Integer(3);
    let wide = AbstractType::Integer(12);
    checker.impose(left.concretizes_explicit(narrow)).unwrap();
    checker.impose(right.concretizes_explicit(wide)).unwrap();
    checker.impose(met.is_meet_of(left, right)).unwrap();
    let table = checker.type_check().expect("unexpected type error.");
    let types = table.as_hashmap();
    assert_eq!(types[&met], types[&right]);
    assert_eq!(types[&left], narrow);
    assert_eq!(types[&met], wide);
    assert_eq!(types[&right], wide);
}