extern crate im;
#[macro_use]
extern crate moniker;
use im::HashMap;
use moniker::{Binder, BoundTerm, Embed, FreeVar, Rec, Scope, Var};
use std::rc::Rc;
/// Types of the language.
///
/// `BoundTerm` is derived so that types can take part in `moniker`'s name
/// binding operations, such as the `term_eq` comparisons performed by the
/// type checker in this file.
#[derive(Debug, Clone, BoundTerm)]
pub enum Type {
    /// The type assigned to `Literal::Int`.
    Int,
    /// The type assigned to `Literal::Float`.
    Float,
    /// The type assigned to `Literal::String`.
    String,
    /// A type variable.
    Var(Var<String>),
    /// A function type: the parameter type followed by the return type.
    Arrow(RcType, RcType),
    /// A record type: a list of labelled field types.
    Record(Vec<(String, RcType)>),
    /// A variant type: a list of labelled alternatives.
    Variant(Vec<(String, RcType)>),
    /// A recursive type: a binder paired with an embedded type. `Fold` and
    /// `Unfold` expressions require their type argument to be of this form.
    Rec(Scope<Rec<(Binder<String>, Embed<RcType>)>, ()>),
}
/// Reference-counted handle to a [`Type`], so that types can be shared and
/// cloned cheaply.
#[derive(Debug, Clone, BoundTerm)]
pub struct RcType {
/// The shared type node.
pub inner: Rc<Type>,
}
impl From<Type> for RcType {
    /// Moves `src` behind a fresh reference count.
    fn from(src: Type) -> RcType {
        let inner = Rc::new(src);
        RcType { inner }
    }
}
impl RcType {
fn subst<N: PartialEq<Var<String>>>(&self, name: &N, replacement: &RcType) -> RcType {
match *self.inner {
Type::Var(ref var) if name == var => replacement.clone(),
Type::Var(_) | Type::Int | Type::Float | Type::String => self.clone(),
Type::Arrow(ref param, ref body) => RcType::from(Type::Arrow(
param.subst(name, replacement),
body.subst(name, replacement),
)),
Type::Record(ref fields) => {
let fields = fields
.iter()
.map(|&(ref label, ref elem)| (label.clone(), elem.subst(name, replacement)))
.collect();
RcType::from(Type::Record(fields))
},
Type::Variant(ref variants) => {
let variants = variants
.iter()
.map(|&(ref label, ref elem)| (label.clone(), elem.subst(name, replacement)))
.collect();
RcType::from(Type::Variant(variants))
},
Type::Rec(ref scope) => {
let (ref n, Embed(ref ann)) = scope.unsafe_pattern.unsafe_pattern;
RcType::from(Type::Rec(Scope {
unsafe_pattern: Rec {
unsafe_pattern: (n.clone(), Embed(ann.subst(name, replacement))),
},
unsafe_body: (),
}))
},
}
}
}
/// A literal constant. Used both by `Expr::Literal` and `Pattern::Literal`,
/// which is why both `BoundTerm` and `BoundPattern` are derived.
#[derive(Debug, Clone, PartialEq, BoundTerm, BoundPattern)]
pub enum Literal {
/// An `i32` literal; typed as `Type::Int`.
Int(i32),
/// An `f32` literal; typed as `Type::Float`.
Float(f32),
/// A string literal; typed as `Type::String`.
String(String),
}
/// Patterns, used as the binding position of lambdas and case clauses.
#[derive(Debug, Clone, BoundPattern)]
pub enum Pattern {
/// A wildcard pattern; it carries no binder.
Wildcard,
/// A pattern annotated with an embedded type.
Ann(RcPattern, Embed<RcType>),
/// A pattern that matches an equal literal expression.
Literal(Literal),
/// A pattern that binds the matched expression to a variable.
Binder(Binder<String>),
/// A record pattern: a list of labelled sub-patterns.
Record(Vec<(String, RcPattern)>),
/// A tag pattern: a label and the sub-pattern for its payload.
Tag(String, RcPattern),
}
/// Reference-counted handle to a [`Pattern`], so that patterns can be shared
/// and cloned cheaply.
#[derive(Debug, Clone, BoundPattern)]
pub struct RcPattern {
/// The shared pattern node.
pub inner: Rc<Pattern>,
}
impl From<Pattern> for RcPattern {
    /// Moves `src` behind a fresh reference count.
    fn from(src: Pattern) -> RcPattern {
        let inner = Rc::new(src);
        RcPattern { inner }
    }
}
/// Expressions of the language.
#[derive(Debug, Clone, BoundTerm)]
pub enum Expr {
    /// An expression annotated with a type.
    Ann(RcExpr, RcType),
    /// A literal constant.
    Literal(Literal),
    /// A variable occurrence.
    Var(Var<String>),
    /// A lambda: a pattern bound in a body expression.
    Lam(Scope<RcPattern, RcExpr>),
    /// An application: the function followed by its argument.
    App(RcExpr, RcExpr),
    /// A record: a list of labelled field expressions.
    Record(Vec<(String, RcExpr)>),
    /// Projection of the labelled field out of a record expression.
    Proj(RcExpr, String),
    /// A tagged expression: a label and its payload.
    Tag(String, RcExpr),
    /// Case analysis: the scrutinee and the clauses, each binding a pattern
    /// in a body expression.
    Case(RcExpr, Vec<Scope<RcPattern, RcExpr>>),
    /// Wraps an expression into the given recursive type.
    Fold(RcType, RcExpr),
    /// Unwraps an expression of the given recursive type.
    Unfold(RcType, RcExpr),
}
/// Reference-counted handle to an [`Expr`], so that expressions can be shared
/// and cloned cheaply.
#[derive(Debug, Clone, BoundTerm)]
pub struct RcExpr {
/// The shared expression node.
pub inner: Rc<Expr>,
}
impl From<Expr> for RcExpr {
    /// Moves `src` behind a fresh reference count.
    fn from(src: Expr) -> RcExpr {
        let inner = Rc::new(src);
        RcExpr { inner }
    }
}
impl RcExpr {
fn substs<N: PartialEq<Var<String>>>(&self, mappings: &[(N, RcExpr)]) -> RcExpr {
match *self.inner {
Expr::Ann(ref expr, ref ty) => {
RcExpr::from(Expr::Ann(expr.substs(mappings), ty.clone()))
},
Expr::Var(ref var) => match mappings.iter().find(|&(name, _)| name == var) {
Some((_, ref replacement)) => replacement.clone(),
None => self.clone(),
},
Expr::Literal(_) => self.clone(),
Expr::Lam(ref scope) => RcExpr::from(Expr::Lam(Scope {
unsafe_pattern: scope.unsafe_pattern.clone(),
unsafe_body: scope.unsafe_body.substs(mappings),
})),
Expr::App(ref fun, ref arg) => {
RcExpr::from(Expr::App(fun.substs(mappings), arg.substs(mappings)))
},
Expr::Record(ref fields) => {
let fields = fields
.iter()
.map(|&(ref label, ref elem)| (label.clone(), elem.substs(mappings)))
.collect();
RcExpr::from(Expr::Record(fields))
},
Expr::Proj(ref expr, ref label) => {
RcExpr::from(Expr::Proj(expr.substs(mappings), label.clone()))
},
Expr::Tag(ref label, ref expr) => {
RcExpr::from(Expr::Tag(label.clone(), expr.substs(mappings)))
},
Expr::Case(ref expr, ref clauses) => RcExpr::from(Expr::Case(
expr.substs(mappings),
clauses
.iter()
.map(|scope| Scope {
unsafe_pattern: scope.unsafe_pattern.clone(), unsafe_body: scope.unsafe_body.substs(mappings),
})
.collect(),
)),
Expr::Fold(ref ty, ref expr) => {
RcExpr::from(Expr::Fold(ty.clone(), expr.substs(mappings)))
},
Expr::Unfold(ref ty, ref expr) => {
RcExpr::from(Expr::Unfold(ty.clone(), expr.substs(mappings)))
},
}
}
}
/// Typing context: an `im::HashMap` from free variables to their types.
type Context = HashMap<FreeVar<String>, RcType>;
/// Evaluates an expression as far as possible.
///
/// - Annotations are erased and the annotated expression is evaluated.
/// - Literals, variables and lambdas are returned unchanged.
/// - An application is reduced when the function evaluates to a lambda whose
///   pattern matches the evaluated argument; otherwise the original
///   application is returned.
/// - Records, tags and folds evaluate their sub-expressions.
/// - A projection returns the selected field when its target evaluates to a
///   record containing the label; otherwise the projection is rebuilt around
///   the evaluated target.
/// - A case expression evaluates the body of the first clause whose pattern
///   matches the evaluated scrutinee; if none matches, the case is rebuilt
///   around the evaluated scrutinee.
/// - `Unfold` of a `Fold` cancels out; otherwise the `Unfold` is rebuilt.
pub fn eval(expr: &RcExpr) -> RcExpr {
    match *expr.inner {
        Expr::Ann(ref expr, _) => eval(expr),
        Expr::Literal(_) | Expr::Var(_) | Expr::Lam(_) => expr.clone(),
        Expr::App(ref fun, ref arg) => match *eval(fun).inner {
            Expr::Lam(ref scope) => {
                let (pattern, body) = scope.clone().unbind();
                match match_expr(&pattern, &eval(arg)) {
                    None => expr.clone(),
                    Some(mappings) => eval(&body.substs(&mappings)),
                }
            },
            _ => expr.clone(),
        },
        Expr::Record(ref fields) => {
            let fields = fields
                .iter()
                .map(|&(ref label, ref elem)| (label.clone(), eval(elem)))
                .collect();
            RcExpr::from(Expr::Record(fields))
        },
        Expr::Proj(ref target, ref label) => {
            let target = eval(target);
            if let Expr::Record(ref fields) = *target.inner {
                if let Some(&(_, ref value)) = fields.iter().find(|&(ref l, _)| l == label) {
                    return value.clone();
                }
            }
            // The projection is stuck (the target is not a record, or the
            // label is missing). Keep the `Proj` node, as the `Case` and
            // `Unfold` arms do, instead of silently dropping the projection.
            RcExpr::from(Expr::Proj(target, label.clone()))
        },
        Expr::Tag(ref label, ref expr) => RcExpr::from(Expr::Tag(label.clone(), eval(expr))),
        Expr::Case(ref arg, ref clauses) => {
            let arg = eval(arg);
            for clause in clauses {
                let (pattern, body) = clause.clone().unbind();
                if let Some(mappings) = match_expr(&pattern, &arg) {
                    return eval(&body.substs(&mappings));
                }
            }
            RcExpr::from(Expr::Case(arg, clauses.clone()))
        },
        Expr::Fold(ref ty, ref expr) => RcExpr::from(Expr::Fold(ty.clone(), eval(expr))),
        Expr::Unfold(ref ty, ref expr) => {
            let expr = eval(expr);
            if let Expr::Fold(_, ref expr) = *expr.inner {
                return expr.clone();
            }
            RcExpr::from(Expr::Unfold(ty.clone(), expr))
        },
    }
}
/// Matches `expr` against `pattern`.
///
/// Returns `Some` with the variable bindings introduced by the pattern when
/// the expression matches, and `None` otherwise:
///
/// - a wildcard matches any expression and binds nothing;
/// - an annotated pattern matches whatever its inner pattern matches;
/// - a literal pattern matches an equal literal expression;
/// - a binder matches any expression and binds it;
/// - a record pattern matches a record expression with the same number of
///   fields whose labels agree position by position and whose field
///   expressions match the corresponding sub-patterns;
/// - a tag pattern matches a tag expression with the same label whose
///   payload matches the sub-pattern.
pub fn match_expr(pattern: &RcPattern, expr: &RcExpr) -> Option<Vec<(FreeVar<String>, RcExpr)>> {
    match (&*pattern.inner, &*expr.inner) {
        // Previously wildcards fell through to the catch-all arm and could
        // never match anything.
        (&Pattern::Wildcard, _) => Some(vec![]),
        (&Pattern::Ann(ref pattern, _), _) => match_expr(pattern, expr),
        (&Pattern::Literal(ref pattern_lit), &Expr::Literal(ref expr_lit))
            if pattern_lit == expr_lit =>
        {
            Some(vec![])
        },
        (&Pattern::Binder(Binder(ref free_var)), _) => Some(vec![(free_var.clone(), expr.clone())]),
        (&Pattern::Record(ref pattern_fields), &Expr::Record(ref expr_fields))
            if pattern_fields.len() == expr_fields.len() =>
        {
            let mut mappings = Vec::new();
            for (pattern_field, expr_field) in pattern_fields.iter().zip(expr_fields.iter()) {
                // Labels are compared positionally.
                if pattern_field.0 != expr_field.0 {
                    return None;
                }
                mappings.extend(match_expr(&pattern_field.1, &expr_field.1)?);
            }
            Some(mappings)
        },
        (&Pattern::Tag(ref pattern_label, ref pattern), &Expr::Tag(ref expr_label, ref expr))
            if pattern_label == expr_label =>
        {
            match_expr(pattern, expr)
        },
        (_, _) => None,
    }
}
/// Checks that `expr` has the type `expected_ty` under `context`.
///
/// Lambdas checked against arrow types, tags checked against variant types,
/// and case expressions are handled directly, pushing the expected type
/// into their sub-expressions. Every other combination falls back to
/// inferring the type of `expr` and comparing it with `expected_ty` using
/// `term_eq`.
///
/// # Errors
///
/// Returns a message when a tag's label is absent from the variant type,
/// when the inferred type differs from the expected one, or when checking or
/// inferring a sub-expression or pattern fails.
pub fn check_expr(context: &Context, expr: &RcExpr, expected_ty: &RcType) -> Result<(), String> {
    match (&*expr.inner, &*expected_ty.inner) {
        (Expr::Lam(scope), Type::Arrow(param_ty, ret_ty)) => {
            let (pattern, body) = scope.clone().unbind();
            let bindings = check_pattern(context, &pattern, param_ty)?;
            check_expr(&(context + &bindings), &body, ret_ty)
        },
        (Expr::Tag(label, payload), Type::Variant(variants)) => {
            match variants.iter().find(|&(l, _)| l == label) {
                Some((_, ty)) => check_expr(context, payload, ty),
                None => Err(format!(
                    "variant type did not contain the label `{}`",
                    label
                )),
            }
        },
        (Expr::Case(scrutinee, clauses), _) => {
            let scrutinee_ty = infer_expr(context, scrutinee)?;
            for clause in clauses {
                let (pattern, body) = clause.clone().unbind();
                let bindings = check_pattern(context, &pattern, &scrutinee_ty)?;
                check_expr(&(context + &bindings), &body, expected_ty)?;
            }
            Ok(())
        },
        _ => {
            let inferred_ty = infer_expr(context, expr)?;
            if RcType::term_eq(&inferred_ty, expected_ty) {
                Ok(())
            } else {
                Err(format!(
                    "type mismatch - found `{:?}` but expected `{:?}`",
                    inferred_ty, expected_ty
                ))
            }
        },
    }
}
pub fn infer_expr(context: &Context, expr: &RcExpr) -> Result<RcType, String> {
match *expr.inner {
Expr::Ann(ref expr, ref ty) => {
check_expr(context, expr, ty)?;
Ok(ty.clone())
},
Expr::Literal(Literal::Int(_)) => Ok(RcType::from(Type::Int)),
Expr::Literal(Literal::Float(_)) => Ok(RcType::from(Type::Float)),
Expr::Literal(Literal::String(_)) => Ok(RcType::from(Type::String)),
Expr::Var(Var::Free(ref free_var)) => match context.get(free_var) {
Some(term) => Ok((*term).clone()),
None => Err(format!("`{}` not found in `{:?}`", free_var, context)),
},
Expr::Var(Var::Bound(ref bound_var)) => {
panic!("encountered a bound variable: {}", bound_var)
},
Expr::Lam(ref scope) => {
let (pattern, body) = scope.clone().unbind();
let (ann, bindings) = infer_pattern(context, &pattern)?;
let body_ty = infer_expr(&(context + &bindings), &body)?;
Ok(RcType::from(Type::Arrow(ann, body_ty)))
},
Expr::App(ref fun, ref arg) => match *infer_expr(context, fun)?.inner {
Type::Arrow(ref param_ty, ref ret_ty) => {
let arg_ty = infer_expr(context, arg)?;
if RcType::term_eq(param_ty, &arg_ty) {
Ok(ret_ty.clone())
} else {
Err(format!(
"argument type mismatch - found `{:?}` but expected `{:?}`",
arg_ty, param_ty,
))
}
},
_ => Err(format!("`{:?}` is not a function", fun)),
},
Expr::Record(ref fields) => Ok(RcType::from(Type::Record(
fields
.iter()
.map(|&(ref label, ref expr)| Ok((label.clone(), infer_expr(context, expr)?)))
.collect::<Result<_, String>>()?,
))),
Expr::Proj(ref expr, ref label) => match *infer_expr(context, expr)?.inner {
Type::Record(ref fields) => match fields.iter().find(|&(l, _)| l == label) {
Some(&(_, ref ty)) => Ok(ty.clone()),
None => Err(format!("field `{}` not found in type", label)),
},
_ => Err("record expected".to_string()),
},
Expr::Tag(_, _) => Err("type annotations needed".to_string()),
Expr::Case(_, _) => Err("type annotations needed".to_string()),
Expr::Fold(ref ty, ref expr) => match *ty.inner {
Type::Rec(ref scope) => {
let (binder, Embed(body_ty)) = scope.clone().unbind().0.unrec();
check_expr(context, expr, &body_ty.subst(&binder, ty))?;
Ok(ty.clone())
},
_ => Err(format!("found `{:?}` but expected a recursive type", ty)),
},
Expr::Unfold(ref ty, ref expr) => match *ty.inner {
Type::Rec(ref scope) => {
let (binder, Embed(body_ty)) = scope.clone().unbind().0.unrec();
check_expr(context, expr, ty)?;
Ok(body_ty.subst(&binder, ty))
},
_ => Err(format!("found `{:?}` but expected a recursive type", ty)),
},
}
}
/// Checks that `pattern` can match values of type `expected_ty`, returning
/// the variable bindings the pattern introduces.
///
/// - A wildcard is accepted at any type and introduces no bindings.
/// - A binder is accepted at any type and is bound to `expected_ty`.
/// - A tag pattern checked against a variant type checks its sub-pattern
///   against the type registered under the same label.
/// - Every other combination falls back to inferring the pattern's type and
///   comparing it with `expected_ty` using `term_eq`.
///
/// # Errors
///
/// Returns a message when a tag's label is absent from the variant type,
/// when the inferred type differs from the expected one, or when inference
/// of the pattern fails.
pub fn check_pattern(
    context: &Context,
    pattern: &RcPattern,
    expected_ty: &RcType,
) -> Result<Context, String> {
    match (&*pattern.inner, &*expected_ty.inner) {
        // Without this arm a wildcard fell through to `infer_pattern`, which
        // always rejects it, so wildcards could never type check.
        (&Pattern::Wildcard, _) => return Ok(Context::new()),
        (&Pattern::Binder(Binder(ref free_var)), _) => {
            return Ok(Context::singleton(free_var.clone(), expected_ty.clone()));
        },
        (&Pattern::Tag(ref label, ref pattern), &Type::Variant(ref variants)) => {
            return match variants.iter().find(|&(l, _)| l == label) {
                None => Err(format!(
                    "variant type did not contain the label `{}`",
                    label
                )),
                Some(&(_, ref ty)) => check_pattern(context, pattern, ty),
            };
        },
        (_, _) => {},
    }

    let (inferred_ty, telescope) = infer_pattern(context, pattern)?;
    if RcType::term_eq(&inferred_ty, expected_ty) {
        Ok(telescope)
    } else {
        Err(format!(
            "type mismatch - found `{:?}` but expected `{:?}`",
            inferred_ty, expected_ty
        ))
    }
}
pub fn infer_pattern(context: &Context, expr: &RcPattern) -> Result<(RcType, Context), String> {
match *expr.inner {
Pattern::Wildcard => Err("type annotations needed".to_string()),
Pattern::Ann(ref pattern, Embed(ref ty)) => {
let telescope = check_pattern(context, pattern, ty)?;
Ok((ty.clone(), telescope))
},
Pattern::Literal(Literal::Int(_)) => Ok((RcType::from(Type::Int), Context::new())),
Pattern::Literal(Literal::Float(_)) => Ok((RcType::from(Type::Float), Context::new())),
Pattern::Literal(Literal::String(_)) => Ok((RcType::from(Type::String), Context::new())),
Pattern::Binder(_) => Err("type annotations needed".to_string()),
Pattern::Record(ref fields) => {
let mut telescope = Context::new();
let fields = fields
.iter()
.map(|&(ref label, ref pattern)| {
let (pattern_ty, pattern_telescope) = infer_pattern(context, pattern)?;
telescope.extend(pattern_telescope);
Ok((label.clone(), pattern_ty))
})
.collect::<Result<_, String>>()?;
Ok((RcType::from(Type::Record(fields)), telescope))
},
Pattern::Tag(_, _) => Err("type annotations needed".to_string()),
}
}
/// An annotated identity lambda over `Int` infers as `Int -> Int`.
#[test]
fn test_infer_expr() {
    use moniker::FreeVar;

    let x = FreeVar::fresh_named("x");

    // Lambda with pattern `x : Int` and body `x`.
    let param = RcPattern::from(Pattern::Ann(
        RcPattern::from(Pattern::Binder(Binder(x.clone()))),
        Embed(RcType::from(Type::Int)),
    ));
    let body = RcExpr::from(Expr::Var(Var::Free(x.clone())));
    let expr = RcExpr::from(Expr::Lam(Scope::new(param, body)));

    let expected = RcType::from(Type::Arrow(
        RcType::from(Type::Int),
        RcType::from(Type::Int),
    ));

    assert_term_eq!(infer_expr(&Context::new(), &expr).unwrap(), expected);
}
/// Applying an `Int -> Int` annotated identity lambda to an integer literal
/// infers as `Int`.
#[test]
fn test_infer_app_expr() {
    use moniker::FreeVar;

    let x = FreeVar::fresh_named("x");

    // Un-annotated identity lambda; its type comes from the annotation below.
    let identity = RcExpr::from(Expr::Lam(Scope::new(
        RcPattern::from(Pattern::Binder(Binder(x.clone()))),
        RcExpr::from(Expr::Var(Var::Free(x.clone()))),
    )));
    let int_to_int = RcType::from(Type::Arrow(
        RcType::from(Type::Int),
        RcType::from(Type::Int),
    ));
    let fun = RcExpr::from(Expr::Ann(identity, int_to_int));
    let arg = RcExpr::from(Expr::Literal(Literal::Int(1)));
    let expr = RcExpr::from(Expr::App(fun, arg));

    let expected = RcType::from(Type::Int);

    assert_term_eq!(infer_expr(&Context::new(), &expr).unwrap(), expected);
}
/// A lambda over an annotated record pattern that returns one of the bound
/// fields infers as a function from the record type to that field's type.
#[test]
fn test_infer_expr_record1() {
    use moniker::FreeVar;

    let a = FreeVar::fresh_named("a");
    let b = FreeVar::fresh_named("b");

    // Field patterns: `x = (a : Int)` and `y = (b : String)`.
    let x_pattern = RcPattern::from(Pattern::Ann(
        RcPattern::from(Pattern::Binder(Binder(a.clone()))),
        Embed(RcType::from(Type::Int)),
    ));
    let y_pattern = RcPattern::from(Pattern::Ann(
        RcPattern::from(Pattern::Binder(Binder(b.clone()))),
        Embed(RcType::from(Type::String)),
    ));
    let param = RcPattern::from(Pattern::Record(vec![
        (String::from("x"), x_pattern),
        (String::from("y"), y_pattern),
    ]));

    // The body returns the variable bound by the `y` field.
    let body = RcExpr::from(Expr::Var(Var::Free(b.clone())));
    let expr = RcExpr::from(Expr::Lam(Scope::new(param, body)));

    let record_ty = RcType::from(Type::Record(vec![
        (String::from("x"), RcType::from(Type::Int)),
        (String::from("y"), RcType::from(Type::String)),
    ]));
    let expected = RcType::from(Type::Arrow(record_ty, RcType::from(Type::String)));

    assert_term_eq!(infer_expr(&Context::new(), &expr).unwrap(), expected);
}
/// A lambda that destructures an annotated three-field record and rebuilds
/// it from the bound variables infers as a function from the record type to
/// the same record type.
#[test]
fn test_infer_expr_record2() {
    use moniker::FreeVar;

    let a = FreeVar::fresh_named("a");
    let b = FreeVar::fresh_named("b");
    let c = FreeVar::fresh_named("c");

    // Field patterns: `x = (a : Int)`, `y = (b : String)`, `z = (c : Float)`.
    let x_pattern = RcPattern::from(Pattern::Ann(
        RcPattern::from(Pattern::Binder(Binder(a.clone()))),
        Embed(RcType::from(Type::Int)),
    ));
    let y_pattern = RcPattern::from(Pattern::Ann(
        RcPattern::from(Pattern::Binder(Binder(b.clone()))),
        Embed(RcType::from(Type::String)),
    ));
    let z_pattern = RcPattern::from(Pattern::Ann(
        RcPattern::from(Pattern::Binder(Binder(c.clone()))),
        Embed(RcType::from(Type::Float)),
    ));
    let param = RcPattern::from(Pattern::Record(vec![
        (String::from("x"), x_pattern),
        (String::from("y"), y_pattern),
        (String::from("z"), z_pattern),
    ]));

    // The body rebuilds the record from the bound variables.
    let body = RcExpr::from(Expr::Record(vec![
        (
            String::from("x"),
            RcExpr::from(Expr::Var(Var::Free(a.clone()))),
        ),
        (
            String::from("y"),
            RcExpr::from(Expr::Var(Var::Free(b.clone()))),
        ),
        (
            String::from("z"),
            RcExpr::from(Expr::Var(Var::Free(c.clone()))),
        ),
    ]));
    let expr = RcExpr::from(Expr::Lam(Scope::new(param, body)));

    let param_ty = RcType::from(Type::Record(vec![
        (String::from("x"), RcType::from(Type::Int)),
        (String::from("y"), RcType::from(Type::String)),
        (String::from("z"), RcType::from(Type::Float)),
    ]));
    let ret_ty = RcType::from(Type::Record(vec![
        (String::from("x"), RcType::from(Type::Int)),
        (String::from("y"), RcType::from(Type::String)),
        (String::from("z"), RcType::from(Type::Float)),
    ]));
    let expected = RcType::from(Type::Arrow(param_ty, ret_ty));

    assert_term_eq!(infer_expr(&Context::new(), &expr).unwrap(), expected);
}
// Empty entry point; the visible code in this file is exercised through its
// `#[test]` functions. Presumably present so the file can be built as a
// standalone binary or example - confirm against the build configuration.
fn main() {}