use rusttyc::{
types::Arity, Constructable, Partial, TcErr, TcKey, TcVar, TypeChecker, Variant as TcVariant, VarlessTypeChecker,
};
use std::cmp::max;
use std::convert::TryInto;
use std::hash::Hash;
#[derive(Debug, Clone, Copy, PartialOrd, PartialEq, Ord, Eq, Hash)]
enum Type {
Int128,
FixedPointI64F64,
Bool,
}
#[derive(Debug, Clone, Copy, PartialOrd, PartialEq, Ord, Eq, Hash)]
enum Variant {
Any,
Fixed(u8, u8),
Integer(u8),
Numeric,
Bool,
}
#[derive(Debug, Clone, Copy, PartialOrd, PartialEq, Ord, Eq, Hash)]
struct Variable(usize);
impl TcVar for Variable {}
impl TcVariant for Variant {
type Err = String;
fn top() -> Self {
Self::Any
}
fn meet(lhs: Partial<Self>, rhs: Partial<Self>) -> Result<Partial<Self>, Self::Err> {
use Variant::*;
assert_eq!(lhs.least_arity, 0, "spurious child");
assert_eq!(rhs.least_arity, 0, "spurious child");
let variant = match (lhs.variant, rhs.variant) {
(Any, other) | (other, Any) => Ok(other),
(Integer(l), Integer(r)) => Ok(Integer(max(r, l))),
(Fixed(li, lf), Fixed(ri, rf)) => Ok(Fixed(max(li, ri), max(lf, rf))),
(Fixed(i, f), Integer(u)) | (Integer(u), Fixed(i, f)) if f == 0 => Ok(Integer(max(i, u))),
(Fixed(i, f), Integer(u)) | (Integer(u), Fixed(i, f)) => Ok(Fixed(max(i, u), f)),
(Bool, Bool) => Ok(Bool),
(Bool, _) | (_, Bool) => Err("bool can only be combined with bool"),
(Numeric, Integer(w)) | (Integer(w), Numeric) => Ok(Integer(w)),
(Numeric, Fixed(i, f)) | (Fixed(i, f), Numeric) => Ok(Fixed(i, f)),
(Numeric, Numeric) => Ok(Numeric),
}?;
Ok(Partial { variant, least_arity: 0 })
}
fn arity(&self) -> Arity {
Arity::Fixed(0)
}
}
#[derive(Clone, Copy, Debug)]
enum ParamType {
ParamId(usize),
Abstract(Variant),
}
#[derive(Clone, Debug)]
enum Expression {
Conditional {
cond: Box<Expression>,
cons: Box<Expression>,
alt: Box<Expression>,
},
PolyFn {
name: &'static str,
param_constraints: Vec<Option<Variant>>,
args: Vec<(ParamType, Expression)>,
returns: ParamType,
},
ConstInt(i128),
ConstBool(bool),
ConstFixed(i64, u64),
}
impl Constructable for Variant {
type Type = Type;
fn construct(&self, children: &[Self::Type]) -> Result<Self::Type, Self::Err> {
assert!(children.is_empty(), "spurious children");
use Variant::*;
match self {
Any => Err("Cannot reify `Any`.".to_string()),
Integer(w) if *w <= 128 => Ok(Type::Int128),
Integer(w) => Err(format!("Integer too wide, {}-bit not supported.", w)),
Fixed(i, f) if *i <= 64 && *f <= 64 => Ok(Type::FixedPointI64F64),
Fixed(i, f) => Err(format!("Fixed point number too wide, I{}F{} not supported.", i, f)),
Numeric => {
Err("Cannot reify a numeric value. Either define a default (int/fixed) or restrict type.".to_string())
}
Bool => Ok(Type::Bool),
}
}
}
fn tc_expr(tc: &mut VarlessTypeChecker<Variant>, expr: &Expression) -> Result<TcKey, TcErr<Variant>> {
use Expression::*;
let key_result = tc.new_term_key(); match expr {
ConstInt(c) => {
let width = (128 - c.leading_zeros()).try_into().unwrap();
tc.impose(key_result.concretizes_explicit(Variant::Integer(width)))?;
}
ConstFixed(i, f) => {
let int_width = (64 - i.leading_zeros()).try_into().unwrap();
let frac_width = (64 - f.leading_zeros()).try_into().unwrap();
tc.impose(key_result.concretizes_explicit(Variant::Fixed(int_width, frac_width)))?;
}
ConstBool(_) => tc.impose(key_result.concretizes_explicit(Variant::Bool))?,
Conditional { cond, cons, alt } => {
let key_cond = tc_expr(tc, cond)?;
let key_cons = tc_expr(tc, cons)?;
let key_alt = tc_expr(tc, alt)?;
tc.impose(key_cond.concretizes_explicit(Variant::Bool))?;
tc.impose(key_result.is_meet_of(key_cons, key_alt))?;
}
PolyFn { name: _, param_constraints, args, returns } => {
let params: Vec<(Option<Variant>, TcKey)> =
param_constraints.iter().map(|p| (*p, tc.new_term_key())).collect();
¶ms;
for (arg_ty, arg_expr) in args {
let arg_key = tc_expr(tc, arg_expr)?;
match arg_ty {
ParamType::ParamId(id) => {
let (p_constr, p_key) = params[*id];
tc.impose(p_key.concretizes(arg_key))?;
if let Some(c) = p_constr {
tc.impose(arg_key.concretizes_explicit(c))?;
}
}
ParamType::Abstract(at) => tc.impose(arg_key.concretizes_explicit(*at))?,
};
}
match returns {
ParamType::Abstract(at) => tc.impose(key_result.concretizes_explicit(*at))?,
ParamType::ParamId(id) => {
let (constr, key) = params[*id];
if let Some(c) = constr {
tc.impose(key_result.concretizes_explicit(c))?;
}
tc.impose(key_result.equate_with(key))?;
}
}
}
}
Ok(key_result)
}
fn main() {
let expr = build_complex_expression_type_checks();
let mut tc: VarlessTypeChecker<Variant> = TypeChecker::without_vars();
let res = tc_expr(&mut tc, &expr).and_then(|key| tc.type_check().map(|tt| (key, tt)));
match res {
Ok((key, tt)) => {
let res_type = tt[&key];
assert_eq!(res_type, Type::FixedPointI64F64);
}
Err(_) => panic!("Unexpected type error!"),
}
}
fn build_complex_expression_type_checks() -> Expression {
use Expression::*;
let const27 = ConstFixed(2, 7);
let const3 = ConstInt(3);
let const43 = ConstFixed(4, 3);
let const_true = ConstBool(true);
let exponentiation = PolyFn {
name: "exponentiation", param_constraints: vec![Some(Variant::Numeric)],
args: vec![(ParamType::ParamId(0), const27), (ParamType::Abstract(Variant::Integer(1)), const3.clone())],
returns: ParamType::ParamId(0),
};
let addition = create_addition(exponentiation, const43);
Conditional { cond: Box::new(const_true), cons: Box::new(addition), alt: Box::new(const3) }
}
fn create_addition(lhs: Expression, rhs: Expression) -> Expression {
Expression::PolyFn {
name: "addition", param_constraints: vec![Some(Variant::Numeric)],
args: vec![(ParamType::ParamId(0), lhs), (ParamType::ParamId(0), rhs)],
returns: ParamType::ParamId(0),
}
}