extern crate im;
#[macro_use]
extern crate moniker;
use im::HashMap;
use moniker::{Binder, BoundTerm, Embed, FreeVar, Scope, Var};
use std::rc::Rc;
/// Types that the checker in this file can assign to expressions.
#[derive(Debug, Clone, BoundTerm)]
pub enum Type {
    /// Type inferred for integer literals.
    Int,
    /// Type inferred for floating point literals.
    Float,
    /// Type inferred for string literals.
    String,
    /// Function type: the parameter type followed by the return type.
    Arrow(RcType, RcType),
}
/// Reference-counted wrapper around a `Type`, so that cloning a type only
/// clones the `Rc` handle.
#[derive(Debug, Clone, BoundTerm)]
pub struct RcType {
    /// The shared type node.
    pub inner: Rc<Type>,
}
impl From<Type> for RcType {
    /// Moves `src` into a freshly allocated `Rc` and wraps it in an `RcType`.
    fn from(src: Type) -> RcType {
        let inner = Rc::new(src);
        RcType { inner }
    }
}
/// Literal values that can appear in an expression.
#[derive(Debug, Clone, BoundTerm)]
pub enum Literal {
    /// A 32-bit signed integer literal.
    Int(i32),
    /// A 32-bit floating point literal.
    Float(f32),
    /// An owned string literal.
    String(String),
}
/// Expressions of the language handled by `eval`, `check` and `infer`.
#[derive(Debug, Clone, BoundTerm)]
pub enum Expr {
    /// An expression annotated with the type it is expected to have.
    Ann(RcExpr, RcType),
    /// A literal value.
    Literal(Literal),
    /// A variable, either free or bound.
    Var(Var<String>),
    /// A lambda: a scope that binds one parameter (with an optional type
    /// annotation embedded in the pattern) over a body expression.
    Lam(Scope<(Binder<String>, Embed<Option<RcType>>), RcExpr>),
    /// Application of a function expression to an argument expression.
    App(RcExpr, RcExpr),
}
/// Reference-counted wrapper around an `Expr`, so that cloning an expression
/// only clones the `Rc` handle.
#[derive(Debug, Clone, BoundTerm)]
pub struct RcExpr {
    /// The shared expression node.
    pub inner: Rc<Expr>,
}
impl From<Expr> for RcExpr {
    /// Moves `src` into a freshly allocated `Rc` and wraps it in an `RcExpr`.
    fn from(src: Expr) -> RcExpr {
        let inner = Rc::new(src);
        RcExpr { inner }
    }
}
impl RcExpr {
    /// Returns a copy of `self` in which every variable that compares equal
    /// to `name` has been replaced by `replacement`.
    ///
    /// Literals and non-matching variables are returned as a cheap clone of
    /// `self`. Type annotations are cloned untouched.
    ///
    /// Lambdas are rebuilt directly from the scope's `unsafe_pattern` and
    /// `unsafe_body` fields rather than through `unbind`, so the pattern is
    /// kept as-is and only the body is rewritten.
    // NOTE(review): bypassing `unbind` presumably relies on `name` never
    // matching a variable bound by an inner scope - confirm with callers.
    fn subst<N: PartialEq<Var<String>>>(&self, name: &N, replacement: &RcExpr) -> RcExpr {
        let rewritten = match *self.inner {
            // Leaves: either swap in the replacement or share the node.
            Expr::Var(ref var) if name == var => return replacement.clone(),
            Expr::Var(_) | Expr::Literal(_) => return self.clone(),
            Expr::Ann(ref annotated, ref ann_ty) => {
                Expr::Ann(annotated.subst(name, replacement), ann_ty.clone())
            },
            Expr::Lam(ref scope) => {
                let pattern = scope.unsafe_pattern.clone();
                let body = scope.unsafe_body.subst(name, replacement);
                Expr::Lam(Scope {
                    unsafe_pattern: pattern,
                    unsafe_body: body,
                })
            },
            Expr::App(ref callee, ref operand) => {
                let new_callee = callee.subst(name, replacement);
                let new_operand = operand.subst(name, replacement);
                Expr::App(new_callee, new_operand)
            },
        };
        RcExpr::from(rewritten)
    }
}
/// Evaluates `expr` and returns the resulting expression.
///
/// * An annotated expression evaluates to whatever its inner expression
///   evaluates to; the annotation is dropped.
/// * Literals, variables and lambdas are returned unchanged.
/// * For an application the function is evaluated first. If that yields a
///   lambda, the lambda is unbound, the argument is evaluated and substituted
///   for the parameter, and the resulting body is evaluated. Otherwise the
///   original application is returned unchanged.
pub fn eval(expr: &RcExpr) -> RcExpr {
    // Handle every non-application case up front.
    let (fun, arg) = match *expr.inner {
        Expr::App(ref fun, ref arg) => (fun, arg),
        Expr::Ann(ref annotated, _) => return eval(annotated),
        Expr::Literal(_) | Expr::Var(_) | Expr::Lam(_) => return expr.clone(),
    };

    let fun_value = eval(fun);
    if let Expr::Lam(ref scope) = *fun_value.inner {
        // Unbind before evaluating the argument, then substitute it in.
        let ((binder, _), body) = scope.clone().unbind();
        let arg_value = eval(arg);
        eval(&body.subst(&binder, &arg_value))
    } else {
        // The head is not a lambda: leave the application as it was.
        expr.clone()
    }
}
/// Typing context used by `check` and `infer`: maps free variables to their types.
type Context = HashMap<FreeVar<String>, RcType>;
/// Checks that `expr` has type `expected_ty` under `context`.
///
/// A lambda whose parameter has no type annotation, checked against an arrow
/// type, is handled specially: the parameter is given the arrow's parameter
/// type and the body is checked against the arrow's return type. Every other
/// combination falls back to inferring the type of `expr` and comparing it
/// with `expected_ty` using `term_eq`.
///
/// # Errors
///
/// Returns a message describing the mismatch when the inferred type differs
/// from `expected_ty`, and propagates any error from `infer` or from the
/// recursive `check`.
pub fn check(context: &Context, expr: &RcExpr, expected_ty: &RcType) -> Result<(), String> {
    if let Expr::Lam(ref scope) = *expr.inner {
        if let Type::Arrow(ref param_ty, ref ret_ty) = *expected_ty.inner {
            if let ((Binder(free_var), Embed(None)), body) = scope.clone().unbind() {
                let extended = context.update(free_var, param_ty.clone());
                return check(&extended, &body, ret_ty);
            }
        }
    }

    let inferred_ty = infer(context, expr)?;
    if !RcType::term_eq(&inferred_ty, expected_ty) {
        return Err(format!(
            "type mismatch - found `{:?}` but expected `{:?}`",
            inferred_ty, expected_ty
        ));
    }
    Ok(())
}
/// Infers the type of `expr` under `context`.
///
/// * Annotated expressions are checked against their annotation, which then
///   becomes the result.
/// * Literals map to the corresponding `Type` variant.
/// * Free variables are looked up in `context`.
/// * Lambdas need a parameter annotation; the body is inferred with the
///   parameter added to the context and the result is an arrow type.
/// * Applications infer the function's type, require it to be an arrow, and
///   require the inferred argument type to equal the parameter type.
///
/// # Errors
///
/// Returns a message when a free variable is missing from `context`, when a
/// lambda parameter has no annotation, when the applied expression does not
/// have an arrow type, or when an argument type does not match; errors from
/// `check` and nested `infer` calls are propagated.
///
/// # Panics
///
/// Panics if a bound variable is encountered.
pub fn infer(context: &Context, expr: &RcExpr) -> Result<RcType, String> {
    match *expr.inner {
        Expr::Ann(ref annotated, ref ty) => check(context, annotated, ty).map(|_| ty.clone()),
        Expr::Literal(ref literal) => {
            let literal_ty = match *literal {
                Literal::Int(_) => Type::Int,
                Literal::Float(_) => Type::Float,
                Literal::String(_) => Type::String,
            };
            Ok(RcType::from(literal_ty))
        },
        Expr::Var(ref var) => match *var {
            Var::Free(ref free_var) => match context.get(free_var) {
                None => Err(format!("`{}` not found in `{:?}`", free_var, context)),
                Some(found) => Ok((*found).clone()),
            },
            Var::Bound(ref bound_var) => {
                panic!("encountered a bound variable: {}", bound_var)
            },
        },
        Expr::Lam(ref scope) => {
            let ((binder, Embed(annotation)), body) = scope.clone().unbind();
            match annotation {
                Some(param_ty) => {
                    let Binder(free_var) = binder;
                    let extended = context.update(free_var, param_ty.clone());
                    let body_ty = infer(&extended, &body)?;
                    Ok(RcType::from(Type::Arrow(param_ty, body_ty)))
                },
                None => Err(format!(
                    "type annotation needed for parameter `{:?}`",
                    binder
                )),
            }
        },
        Expr::App(ref fun, ref arg) => {
            let fun_ty = infer(context, fun)?;
            match *fun_ty.inner {
                Type::Arrow(ref param_ty, ref ret_ty) => {
                    let arg_ty = infer(context, arg)?;
                    if !RcType::term_eq(param_ty, &arg_ty) {
                        return Err(format!(
                            "argument type mismatch - found `{:?}` but expected `{:?}`",
                            arg_ty, param_ty,
                        ));
                    }
                    Ok(ret_ty.clone())
                },
                _ => Err(format!("`{:?}` is not a function", fun)),
            }
        },
    }
}
/// Inferring the type of a lambda with an `Int`-annotated parameter whose
/// body is that parameter must yield `Int -> Int`.
#[test]
fn test_infer() {
    use moniker::FreeVar;

    let x = FreeVar::fresh_named("x");

    // Build the lambda: parameter `x` annotated with `Int`, body `x`.
    let pattern = (Binder(x.clone()), Embed(Some(RcType::from(Type::Int))));
    let body = RcExpr::from(Expr::Var(Var::Free(x.clone())));
    let expr = RcExpr::from(Expr::Lam(Scope::new(pattern, body)));

    let inferred = infer(&Context::new(), &expr).unwrap();
    let expected = RcType::from(Type::Arrow(
        RcType::from(Type::Int),
        RcType::from(Type::Int),
    ));

    assert_term_eq!(inferred, expected);
}
// Empty entry point: running the resulting binary does nothing.
fn main() {}