from __future__ import print_function
import copy
class Lambda(object):
    """Syntax node for a one-argument function, printed as ``(fn v => body)``."""

    def __init__(self, v, body):
        self.v = v  # name of the parameter bound in the body
        self.body = body  # expression node for the function body

    def __str__(self):
        return "(fn %s => %s)" % (self.v, self.body)
class Identifier(object):
    """Syntax node for a name reference (a variable or an integer literal)."""

    def __init__(self, name):
        self.name = name  # the identifier text exactly as written

    def __str__(self):
        text = self.name
        return text
class Apply(object):
    """Syntax node for applying function expression ``fn`` to argument ``arg``."""

    def __init__(self, fn, arg):
        self.fn = fn  # expression in function position
        self.arg = arg  # expression supplied as the single argument

    def __str__(self):
        return "(%s %s)" % (self.fn, self.arg)
class Let(object):
    """Syntax node ``let v = defn in body`` (non-recursive local binding)."""

    def __init__(self, v, defn, body):
        self.v = v  # name being bound
        self.defn = defn  # expression whose value is bound to v
        self.body = body  # expression in which v is visible

    def __str__(self):
        return "(let %s = %s in %s)" % (self.v, self.defn, self.body)
class Letrec(object):
    """Syntax node ``letrec v = defn in body``; ``v`` is in scope inside ``defn`` too."""

    def __init__(self, v, defn, body):
        self.v = v  # name being bound (may be referenced recursively)
        self.defn = defn  # definition expression
        self.body = body  # expression in which v is visible

    def __str__(self):
        return "(letrec %s = %s in %s)" % (self.v, self.defn, self.body)
class InferenceError(Exception):
    """Raised when an expression cannot be typed (e.g. a type mismatch).

    Attributes:
        message: Human-readable description of the failure (read-only).
    """

    def __init__(self, message):
        # Forward the message to Exception so ``args``, ``repr`` and pickling
        # behave like any other exception (unpickling re-calls cls(*args)).
        super(InferenceError, self).__init__(message)
        self.__message = message

    message = property(lambda self: self.__message)

    def __str__(self):
        return str(self.message)
class ParseError(Exception):
    """Raised when an expression refers to a symbol that cannot be resolved.

    Attributes:
        message: Human-readable description of the failure (read-only).
    """

    def __init__(self, message):
        # Forward the message to Exception so ``args``, ``repr`` and pickling
        # behave like any other exception (unpickling re-calls cls(*args)).
        super(ParseError, self).__init__(message)
        self.__message = message

    message = property(lambda self: self.__message)

    def __str__(self):
        return str(self.message)
class TypeVariable(object):
    """A type variable that unification may bind to another type.

    ``instance`` holds the bound type, or None while the variable is free.
    The display ``name`` is assigned lazily, on first access, from a
    class-wide sequence 'a', 'b', ..., 'z', 'a1', 'b1', ..., so names follow
    the order in which variables are first printed.
    """

    # Class-wide counter giving every variable a unique integer id.
    next_variable_id = 0

    # Next display name to hand out (shared by all instances).
    next_variable_name = 'a'

    def __init__(self):
        self.id = TypeVariable.next_variable_id
        TypeVariable.next_variable_id += 1
        self.instance = None  # type bound by unification, or None if free
        self.__name = None  # display name, assigned lazily by `name`

    @staticmethod
    def _successor_name(name):
        """Return the display name that follows *name* in the naming sequence.

        The letter advances from 'a' to 'z'. After 'z' it wraps back to 'a'
        and a numeric suffix is added or incremented ('z' -> 'a1',
        'z1' -> 'a2'), so generated names never leave the alphanumeric range.
        """
        letter, suffix = name[0], name[1:]
        if letter < 'z':
            return chr(ord(letter) + 1) + suffix
        return 'a' + str(int(suffix or '0') + 1)

    @property
    def name(self):
        """Display name of this variable, allocated on first access."""
        if self.__name is None:
            self.__name = TypeVariable.next_variable_name
            TypeVariable.next_variable_name = TypeVariable._successor_name(self.__name)
        return self.__name

    def __str__(self):
        # A bound variable prints as the type it is bound to.
        if self.instance is not None:
            return str(self.instance)
        else:
            return self.name

    def __repr__(self):
        return "TypeVariable(id = {0})".format(self.id)

    def copy(self):
        """Return a shallow copy sharing id, binding and (if assigned) name."""
        b = TypeVariable.__new__(TypeVariable)  # bypass __init__: keep the id
        b.id = self.id
        b.instance = self.instance
        b.__name = self.__name
        return b
class TypeOperator(object):
    """A type constructor applied to a sequence of argument types.

    A nullary operator is a base type such as ``int``. A binary operator
    prints infix, e.g. ``(a -> b)``. Any other arity prints as the name
    followed by the space-separated argument types.
    """

    def __init__(self, name, types):
        self.name = name  # constructor name, e.g. "int", "->", "*"
        self.types = types  # sequence of argument types (list or tuple)

    def __str__(self):
        num_types = len(self.types)
        if num_types == 0:
            return self.name
        elif num_types == 2:
            return "({0} {1} {2})".format(str(self.types[0]), self.name, str(self.types[1]))
        else:
            # The arguments are type objects, not strings, so each must be
            # stringified before joining.
            return "{0} {1}".format(self.name, ' '.join(str(t) for t in self.types))

    def copy(self):
        """Return a shallow copy; the ``types`` sequence is shared, not copied."""
        b = TypeOperator.__new__(TypeOperator)
        b.types = self.types
        b.name = self.name
        return b
class Function(TypeOperator):
    """The function type ``from_type -> to_type`` (a binary "->" operator)."""

    def __init__(self, from_type, to_type):
        argument_types = [from_type, to_type]
        TypeOperator.__init__(self, "->", argument_types)
# Built-in nullary base types shared module-wide. get_type assigns Integer
# to integer literals.
Integer = TypeOperator("int", [])
Bool = TypeOperator("bool", [])
def analyse(node, env, non_generic=None):
    """Infer the type of the expression *node* by Hindley-Milner inference.

    Args:
        node: Root syntax node: an Identifier, Apply, Lambda, Let or Letrec.
        env: Dict mapping identifier names to their types.
        non_generic: Set of type variables that must not be generalised
            (those bound by an enclosing lambda parameter or letrec name).
            Defaults to an empty set.

    Returns:
        The inferred type. This may be a TypeVariable whose binding was
        filled in by unification.

    Raises:
        ParseError: An identifier is neither bound nor an integer literal.
        InferenceError: Unification fails (type mismatch or occurs check).
        AssertionError: *node* is not one of the supported node classes.
    """
    if non_generic is None:
        non_generic = set()
    if isinstance(node, Identifier):
        return get_type(node.name, env, non_generic)
    elif isinstance(node, Apply):
        # Require the function's type to be arg_type -> result_type.
        fun_type = analyse(node.fn, env, non_generic)
        arg_type = analyse(node.arg, env, non_generic)
        result_type = TypeVariable()
        unify(Function(arg_type, result_type), fun_type)
        return result_type
    elif isinstance(node, Lambda):
        # The parameter is monomorphic inside the body, so its type variable
        # is marked non-generic (lookups will not freshen it).
        arg_type = TypeVariable()
        new_env = env.copy()
        new_env[node.v] = arg_type
        new_non_generic = non_generic.copy()
        new_non_generic.add(arg_type)
        result_type = analyse(node.body, new_env, new_non_generic)
        return Function(arg_type, result_type)
    elif isinstance(node, Let):
        # The definition is typed with the outer non_generic set, so its free
        # variables stay generic and are freshened at each use in the body.
        defn_type = analyse(node.defn, env, non_generic)
        new_env = env.copy()
        new_env[node.v] = defn_type
        return analyse(node.body, new_env, non_generic)
    elif isinstance(node, Letrec):
        # Recursive uses inside the definition see a single non-generic
        # variable. The body is analysed with the outer set, so the binding
        # is generic there.
        new_type = TypeVariable()
        new_env = env.copy()
        new_env[node.v] = new_type
        new_non_generic = non_generic.copy()
        new_non_generic.add(new_type)
        defn_type = analyse(node.defn, new_env, new_non_generic)
        unify(new_type, defn_type)
        return analyse(node.body, new_env, non_generic)
    # Explicit raise rather than `assert` so the check survives `python -O`.
    raise AssertionError("Unhandled syntax node {0}".format(type(node)))
def get_type(name, env, non_generic):
    """Return the type of identifier *name*.

    A name bound in *env* gets a freshened copy of its type, so generic type
    variables are renamed apart per use. An unbound name that parses as an
    integer gets Integer.

    Raises:
        ParseError: *name* is neither bound in *env* nor an integer literal.
    """
    if name not in env:
        if not is_integer_literal(name):
            raise ParseError("Undefined symbol {0}".format(name))
        return Integer
    return fresh(env[name], non_generic)
def fresh(t, non_generic):
    """Return a copy of type *t* with its generic type variables renamed apart.

    Each generic variable (one occurring in none of the *non_generic* types)
    is replaced by a new TypeVariable, consistently: repeated occurrences
    map to the same replacement. Non-generic variables are shared with *t*,
    and type operators are rebuilt as plain TypeOperator instances.
    """
    substitution = {}

    def copy_type(node):
        pruned = prune(node)
        if isinstance(pruned, TypeVariable):
            if not is_generic(pruned, non_generic):
                return pruned
            if pruned not in substitution:
                substitution[pruned] = TypeVariable()
            return substitution[pruned]
        if isinstance(pruned, TypeOperator):
            copied_args = [copy_type(arg) for arg in pruned.types]
            return TypeOperator(pruned.name, copied_args)
        return None

    return copy_type(t)
def unify(t1, t2):
    """Make *t1* and *t2* equal by binding type variables in place.

    Raises:
        InferenceError: The types have different operator names or arities,
            or a variable would have to contain itself (occurs check).
        AssertionError: An argument is neither a TypeVariable nor a
            TypeOperator.
    """
    a = prune(t1)
    b = prune(t2)
    if isinstance(a, TypeVariable):
        if a != b:
            # Binding a to a type containing a would create an infinite type.
            if occurs_in_type(a, b):
                raise InferenceError("recursive unification")
            a.instance = b
    elif isinstance(a, TypeOperator) and isinstance(b, TypeVariable):
        unify(b, a)
    elif isinstance(a, TypeOperator) and isinstance(b, TypeOperator):
        if a.name != b.name or len(a.types) != len(b.types):
            raise InferenceError("Type mismatch: {0} != {1}".format(str(a), str(b)))
        for p, q in zip(a.types, b.types):
            unify(p, q)
    else:
        # Explicit raise rather than `assert` so the check survives `python -O`.
        raise AssertionError("Not unified")
def prune(t):
    """Return the type at the end of *t*'s chain of variable bindings.

    Every bound TypeVariable along the chain is re-pointed directly at the
    result (path compression). Unbound variables and non-variables are
    returned as-is.
    """
    if not isinstance(t, TypeVariable) or t.instance is None:
        return t
    t.instance = prune(t.instance)
    return t.instance
def is_generic(v, non_generic):
    """Return True if type variable *v* occurs in none of the *non_generic* types."""
    found = occurs_in(v, non_generic)
    return not found
def occurs_in_type(v, type2):
    """Return True if type variable *v* appears anywhere inside *type2*."""
    target = prune(type2)
    if target == v:
        return True
    if not isinstance(target, TypeOperator):
        return False
    return occurs_in(v, target.types)
def occurs_in(t, types):
    """Return True if type variable *t* occurs in any type of the iterable *types*."""
    for candidate in types:
        if occurs_in_type(t, candidate):
            return True
    return False
def is_integer_literal(name):
    """Return True if ``int(name)`` succeeds, i.e. *name* is an integer literal."""
    try:
        int(name)
    except ValueError:
        return False
    return True
def try_exp(env, node):
    """Print *node*, then either its inferred type or the inference error."""
    print("%s : " % (node,), end=' ')
    try:
        inferred = analyse(node, env)
    except (ParseError, InferenceError) as err:
        print(err)
    else:
        print(str(inferred))
def main():
    """Type-check a fixed list of example expressions and print the results.

    The environment supplies pair, true, cond, zero, pred and times. Each
    example is printed with its inferred type or with the error raised.
    """
    # Created in this order so the variables' ids are assigned in sequence.
    alpha = TypeVariable()
    beta = TypeVariable()
    product = TypeOperator("*", (alpha, beta))
    gamma = TypeVariable()
    builtins_env = {
        "pair": Function(alpha, Function(beta, product)),
        "true": Bool,
        "cond": Function(Bool, Function(gamma, Function(gamma, gamma))),
        "zero": Function(Integer, Bool),
        "pred": Function(Integer, Integer),
        "times": Function(Integer, Function(Integer, Integer)),
    }

    def ident(name):
        return Identifier(name)

    def call(fn, *args):
        # Curried application: call(f, a, b) == Apply(Apply(f, a), b).
        expr = fn
        for arg in args:
            expr = Apply(expr, arg)
        return expr

    # Uses f at both int and bool; only typeable if f is let-bound generically.
    pair_of_f_results = call(ident("pair"),
                             call(ident("f"), ident("4")),
                             call(ident("f"), ident("true")))

    factorial_defn = Lambda("n", call(
        ident("cond"),
        call(ident("zero"), ident("n")),
        ident("1"),
        call(ident("times"), ident("n"),
             call(ident("factorial"), call(ident("pred"), ident("n")))),
    ))

    examples = [
        Letrec("factorial", factorial_defn,
               call(ident("factorial"), ident("5"))),
        Lambda("x", call(ident("pair"),
                         call(ident("x"), ident("3")),
                         call(ident("x"), ident("true")))),
        call(ident("pair"),
             call(ident("f"), ident("4")),
             call(ident("f"), ident("true"))),
        Let("f", Lambda("x", ident("x")), pair_of_f_results),
        Lambda("f", call(ident("f"), ident("f"))),
        Let("g", Lambda("f", ident("5")), call(ident("g"), ident("g"))),
        Lambda("g", Let("f", Lambda("x", ident("g")),
                        call(ident("pair"),
                             call(ident("f"), ident("3")),
                             call(ident("f"), ident("true"))))),
        Lambda("f", Lambda("g", Lambda("arg",
            call(ident("g"), call(ident("f"), ident("arg")))))),
    ]
    for expr in examples:
        try_exp(builtins_env, expr)
# Run the demo only when executed as a script, not when imported.
if __name__ == '__main__':
    main()