use rusttyc::types::Abstract;
use rusttyc::{
types::{Generalizable, ReificationErr, TryReifiable},
TcErr, TcKey, TypeChecker,
};
use std::cmp::max;
use std::convert::TryInto;
/// A plain `u32` newtype identifier.
///
/// NOTE(review): nothing in this file references `Key`; it may be a leftover.
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
struct Key(u32);
/// Abstract (lattice) types handled by the type checker.
///
/// The variant order is significant: the derived `Ord`/`PartialOrd`
/// compare variants in declaration order.
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
enum AbstractType {
    /// Nothing known yet; the unconstrained starting point.
    Any,
    /// Fixed-point number: (integer bit width, fractional bit width).
    Fixed(u8, u8),
    /// Integer with the given bit width.
    Integer(u8),
    /// Some numeric type, not yet decided between integer and fixed point.
    Numeric,
    /// Boolean.
    Bool,
}
/// Concrete machine types that abstract types are reified into.
///
/// Declaration order determines the derived ordering.
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
enum ConcreteType {
    /// 128-bit integer.
    Int128,
    /// Fixed-point number with 64 integer and 64 fractional bits.
    FixedPointI64F64,
    /// Boolean.
    Bool,
}
/// Coarse classification of value kinds.
///
/// NOTE(review): not referenced anywhere in this file.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
enum Variant {
    Fixed,
    Integer,
    Bool,
}
/// Placeholder variable type for the type checker; this example declares no variables.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
struct Variable();

// Marker impl so `Variable` can be used as the checker's variable type.
impl rusttyc::TcVar for Variable {}
impl Abstract for AbstractType {
type Err = String;
fn unconstrained() -> Self {
AbstractType::Any
}
fn meet(&self, other: &Self) -> Result<Self, Self::Err> {
use crate::AbstractType::*;
match (self, other) {
(Any, other) | (other, Any) => Ok(other.clone()),
(Integer(l), Integer(r)) => Ok(Integer(max(*r, *l))),
(Fixed(li, lf), Fixed(ri, rf)) => Ok(Fixed(max(*li, *ri), max(*lf, *rf))),
(Fixed(i, f), Integer(u)) | (Integer(u), Fixed(i, f)) if *f == 0 => Ok(Integer(max(*i, *u))),
(Fixed(i, f), Integer(u)) | (Integer(u), Fixed(i, f)) => Ok(Fixed(max(*i, *u), *f)),
(Bool, Bool) => Ok(Bool),
(Bool, _) | (_, Bool) => Err("bools cannot be met with other types.".to_string()),
(Numeric, Integer(w)) | (Integer(w), Numeric) => Ok(Integer(*w)),
(Numeric, Fixed(i, f)) | (Fixed(i, f), Numeric) => Ok(Fixed(*i, *f)),
(Numeric, Numeric) => Ok(Numeric),
}
}
fn arity(&self) -> Option<usize> {
match self {
Self::Bool | Self::Fixed(_, _) | Self::Integer(_) => Some(0),
_ => None,
}
}
fn with_children<I>(&self, children: I) -> Self
where
I: IntoIterator<Item = Self>,
{
assert!(children.into_iter().collect::<Vec<Self>>().is_empty());
self.clone()
}
fn nth_child(&self, _: usize) -> &Self {
panic!("will not be called")
}
}
/// How a function's argument or return type is declared.
#[derive(Clone, Copy, Debug)]
enum ParamType {
    /// Refers to the generic parameter at this index in the function's `param_constraints`.
    ParamId(usize),
    /// A fixed abstract type bound.
    Abstract(AbstractType),
}
/// A small expression language used to exercise the type checker.
#[derive(Clone, Debug)]
enum Expression {
    /// `if cond then cons else alt`; the result type is the meet of both branches.
    Conditional {
        cond: Box<Expression>,
        cons: Box<Expression>,
        alt: Box<Expression>,
    },
    /// Call of a (possibly generic) function.
    PolyFn {
        /// Function name; not used during type checking.
        name: &'static str,
        /// One entry per generic parameter, with an optional abstract-type bound.
        param_constraints: Vec<Option<AbstractType>>,
        /// Arguments, each paired with its declared parameter type.
        args: Vec<(ParamType, Expression)>,
        /// Declared return type.
        returns: ParamType,
    },
    /// Integer literal.
    ConstInt(i128),
    /// Boolean literal.
    ConstBool(bool),
    /// Fixed-point literal given as (integer part, fractional part).
    ConstFixed(i64, u64),
}
impl TryReifiable for AbstractType {
    type Reified = ConcreteType;

    /// Picks the concrete machine type for a fully resolved abstract type.
    ///
    /// # Errors
    /// - `TooGeneral` for `Any` and `Numeric`, which do not determine a representation.
    /// - `Conflicting` when an integer exceeds 128 bits, or when a fixed-point
    ///   part exceeds 64 bits.
    fn try_reify(&self) -> Result<Self::Reified, ReificationErr> {
        match *self {
            AbstractType::Bool => Ok(ConcreteType::Bool),
            AbstractType::Integer(width) => {
                if width <= 128 {
                    Ok(ConcreteType::Int128)
                } else {
                    Err(ReificationErr::Conflicting(format!(
                        "Integer too wide, {}-bit not supported.",
                        width
                    )))
                }
            }
            AbstractType::Fixed(int_bits, frac_bits) => {
                let fits = int_bits <= 64 && frac_bits <= 64;
                if fits {
                    Ok(ConcreteType::FixedPointI64F64)
                } else {
                    Err(ReificationErr::Conflicting(format!(
                        "Fixed point number too wide, I{}F{} not supported.",
                        int_bits, frac_bits
                    )))
                }
            }
            AbstractType::Any => Err(ReificationErr::TooGeneral("Cannot reify `Any`.".to_string())),
            AbstractType::Numeric => Err(ReificationErr::TooGeneral(
                "Cannot reify a numeric value. Either define a default (int/fixed) or restrict type."
                    .to_string(),
            )),
        }
    }
}
impl Generalizable for ConcreteType {
    type Generalized = AbstractType;

    /// Maps each concrete type to the abstract type describing its full capacity.
    fn generalize(&self) -> Self::Generalized {
        match *self {
            Self::Bool => AbstractType::Bool,
            Self::Int128 => AbstractType::Integer(128),
            Self::FixedPointI64F64 => AbstractType::Fixed(64, 64),
        }
    }
}
fn build_complex_expression_type_checks() -> Expression {
use Expression::*;
let const27 = ConstFixed(2, 7);
let const3 = ConstInt(3);
let const43 = ConstFixed(4, 3);
let const_true = ConstBool(true);
let exponentiation = PolyFn {
name: "exponentiation", param_constraints: vec![Some(AbstractType::Numeric)],
args: vec![(ParamType::ParamId(0), const27), (ParamType::Abstract(AbstractType::Integer(1)), const3.clone())],
returns: ParamType::ParamId(0),
};
let addition = Expression::PolyFn {
name: "addition", param_constraints: vec![Some(AbstractType::Numeric)],
args: vec![(ParamType::ParamId(0), exponentiation), (ParamType::ParamId(0), const43)],
returns: ParamType::ParamId(0),
};
Conditional { cond: Box::new(const_true), cons: Box::new(addition), alt: Box::new(const3) }
}
fn tc_expr(tc: &mut TypeChecker<AbstractType, Variable>, expr: &Expression) -> Result<TcKey, TcErr<AbstractType>> {
use Expression::*;
let key_result = tc.new_term_key(); match expr {
ConstInt(c) => {
let width = (128 - c.leading_zeros()).try_into().unwrap();
tc.impose(key_result.concretizes_explicit(AbstractType::Integer(width)))?;
}
ConstFixed(i, f) => {
let int_width = (64 - i.leading_zeros()).try_into().unwrap();
let frac_width = (64 - f.leading_zeros()).try_into().unwrap();
tc.impose(key_result.concretizes_explicit(AbstractType::Fixed(int_width, frac_width)))?;
}
ConstBool(_) => tc.impose(key_result.concretizes_concrete(ConcreteType::Bool))?,
Conditional { cond, cons, alt } => {
let key_cond = tc_expr(tc, cond)?;
let key_cons = tc_expr(tc, cons)?;
let key_alt = tc_expr(tc, alt)?;
tc.impose(key_cond.concretizes_explicit(AbstractType::Bool))?;
tc.impose(key_result.is_meet_of(key_cons, key_alt))?;
}
PolyFn { name: _, param_constraints, args, returns } => {
let params: Vec<(Option<AbstractType>, TcKey)> =
param_constraints.iter().map(|p| (*p, tc.new_term_key())).collect();
for (arg_ty, arg_expr) in args {
let arg_key = tc_expr(tc, arg_expr)?;
match arg_ty {
ParamType::ParamId(id) => {
let (p_constr, p_key) = params[*id];
tc.impose(p_key.concretizes(arg_key))?;
if let Some(c) = p_constr {
tc.impose(arg_key.concretizes_explicit(c))?;
}
}
ParamType::Abstract(at) => tc.impose(arg_key.concretizes_explicit(*at))?,
};
}
match returns {
ParamType::Abstract(at) => tc.impose(key_result.concretizes_explicit(*at))?,
ParamType::ParamId(id) => {
let (constr, key) = params[*id];
if let Some(c) = constr {
tc.impose(key_result.concretizes_explicit(c))?;
}
tc.impose(key_result.concretizes(key))?;
}
}
}
}
Ok(key_result)
}
/// Type-checks the sample expression and asserts that its result type is `Fixed(3, 3)`.
///
/// # Panics
/// Panics with the type checker's error if checking fails. Also panics if the
/// inferred type differs from the expected one.
fn main() {
    let expr = build_complex_expression_type_checks();
    let mut tc: TypeChecker<AbstractType, Variable> = TypeChecker::new();
    let res = tc_expr(&mut tc, &expr).and_then(|key| tc.type_check().map(|tt| (key, tt)));
    match res {
        Ok((key, tt)) => {
            let res_type = tt[key];
            // Expected: the widest integer (3) and fractional (3) widths among the
            // fixed-point literals in the expression.
            assert_eq!(res_type, AbstractType::Fixed(3, 3));
        }
        // Report the actual error instead of discarding it, so failures are diagnosable.
        Err(err) => panic!("Unexpected type error: {:?}", err),
    }
}