use std::collections::BTreeMap;
use std::fmt;
/// A type as manipulated by this module's substitution-based machinery
/// (see [`Substitution`] and [`TypeVarGen`]).
///
/// Compound variants own their component types (boxed or in a `Vec`);
/// [`Type::Var`] is a placeholder that [`Type::substitute`] can replace.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum Type {
    /// Displayed as `Bool`.
    Bool,
    /// Displayed as `Nat`; counted as numeric by [`Type::is_numeric`].
    /// Presumably the natural numbers — TODO confirm.
    Nat,
    /// Displayed as `Int`; counted as numeric by [`Type::is_numeric`].
    Int,
    /// Displayed as `String`.
    String,
    /// Set whose elements have the boxed type; displayed as `Set[T]`.
    Set(Box<Type>),
    /// Sequence whose elements have the boxed type; displayed as `Seq[T]`.
    Seq(Box<Type>),
    /// Function from the first type to the second. Displayed as `dict[K, V]`
    /// and counted as a collection by [`Type::is_collection`], so it appears
    /// to model a finite map — NOTE(review): confirm intended semantics.
    Fn(Box<Type>, Box<Type>),
    /// Optional value of the boxed type; displayed as `Option[T]`.
    Option(Box<Type>),
    /// Record with named, typed fields (see [`RecordType`]).
    Record(RecordType),
    /// Tuple of the listed element types; displayed as `(A, B, ...)`.
    Tuple(Vec<Type>),
    /// Integer range `lo..hi`; counted as numeric by [`Type::is_numeric`].
    /// NOTE(review): whether `hi` is inclusive is not determined here.
    Range(i64, i64),
    /// A type referred to by name; displayed verbatim. Presumably resolved
    /// against declarations elsewhere — TODO confirm.
    Named(String),
    /// Type variable; displayed as `?N` where `N` is its id.
    Var(TypeVar),
    /// Displayed as `<error>`. Presumably marks a type that could not be
    /// determined — TODO confirm.
    Error,
}
impl Type {
pub fn is_numeric(&self) -> bool {
matches!(self, Type::Nat | Type::Int | Type::Range(_, _))
}
pub fn is_collection(&self) -> bool {
matches!(self, Type::Set(_) | Type::Seq(_) | Type::Fn(_, _))
}
pub fn has_vars(&self) -> bool {
match self {
Type::Var(_) => true,
Type::Set(t) | Type::Seq(t) | Type::Option(t) => t.has_vars(),
Type::Fn(k, v) => k.has_vars() || v.has_vars(),
Type::Record(r) => r.fields.values().any(|t| t.has_vars()),
Type::Tuple(elems) => elems.iter().any(|t| t.has_vars()),
_ => false,
}
}
pub fn substitute(&self, subst: &Substitution) -> Type {
match self {
Type::Var(v) => subst.get(v).cloned().unwrap_or_else(|| self.clone()),
Type::Set(t) => Type::Set(Box::new(t.substitute(subst))),
Type::Seq(t) => Type::Seq(Box::new(t.substitute(subst))),
Type::Option(t) => Type::Option(Box::new(t.substitute(subst))),
Type::Fn(k, v) => {
Type::Fn(Box::new(k.substitute(subst)), Box::new(v.substitute(subst)))
}
Type::Record(r) => Type::Record(RecordType {
fields: r
.fields
.iter()
.map(|(k, v)| (k.clone(), v.substitute(subst)))
.collect(),
}),
Type::Tuple(elems) => Type::Tuple(elems.iter().map(|t| t.substitute(subst)).collect()),
_ => self.clone(),
}
}
}
impl fmt::Display for Type {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Type::Bool => write!(f, "Bool"),
Type::Nat => write!(f, "Nat"),
Type::Int => write!(f, "Int"),
Type::String => write!(f, "String"),
Type::Set(t) => write!(f, "Set[{}]", t),
Type::Seq(t) => write!(f, "Seq[{}]", t),
Type::Fn(k, v) => write!(f, "dict[{}, {}]", k, v),
Type::Option(t) => write!(f, "Option[{}]", t),
Type::Record(r) => {
write!(f, "Record {{ ")?;
for (i, (name, ty)) in r.fields.iter().enumerate() {
if i > 0 {
write!(f, ", ")?;
}
write!(f, "{}: {}", name, ty)?;
}
write!(f, " }}")
}
Type::Tuple(elems) => {
write!(f, "(")?;
for (i, ty) in elems.iter().enumerate() {
if i > 0 {
write!(f, ", ")?;
}
write!(f, "{}", ty)?;
}
write!(f, ")")
}
Type::Range(lo, hi) => write!(f, "{}..{}", lo, hi),
Type::Named(name) => write!(f, "{}", name),
Type::Var(v) => write!(f, "?{}", v.0),
Type::Error => write!(f, "<error>"),
}
}
}
/// The field layout of a record type: field names mapped to their types.
///
/// Fields live in a `BTreeMap`, so they iterate in sorted name order
/// regardless of insertion order.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct RecordType {
    pub fields: BTreeMap<String, Type>,
}

impl RecordType {
    /// Creates a record type with no fields.
    pub fn new() -> Self {
        Self::default()
    }

    /// Builds a record type from `(name, type)` pairs.
    pub fn from_fields(fields: impl IntoIterator<Item = (String, Type)>) -> Self {
        let fields: BTreeMap<String, Type> = fields.into_iter().collect();
        RecordType { fields }
    }

    /// Returns the type of the field called `name`, or `None` if there is none.
    pub fn get_field(&self, name: &str) -> Option<&Type> {
        self.fields.get(name)
    }
}

impl Default for RecordType {
    /// An empty record type (no fields).
    fn default() -> Self {
        RecordType {
            fields: BTreeMap::new(),
        }
    }
}
/// A type variable, identified by a numeric id.
///
/// Variables compare and order by their id.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct TypeVar(pub u32);

impl TypeVar {
    /// Wraps `id` as a type variable; equivalent to `TypeVar(id)`.
    pub fn new(id: u32) -> Self {
        TypeVar(id)
    }
}
/// A finite mapping from type variables to the types they stand for.
///
/// Apply it to a type with [`Type::substitute`].
#[derive(Debug, Clone, Default)]
pub struct Substitution {
    mappings: BTreeMap<TypeVar, Type>,
}

impl Substitution {
    /// Creates a substitution with no bindings.
    pub fn new() -> Self {
        Substitution {
            mappings: BTreeMap::default(),
        }
    }

    /// Returns the type bound to `var`, if any.
    pub fn get(&self, var: &TypeVar) -> Option<&Type> {
        self.mappings.get(var)
    }

    /// Binds `var` to `ty`, replacing any earlier binding for `var`.
    pub fn insert(&mut self, var: TypeVar, ty: Type) {
        self.mappings.insert(var, ty);
    }

    /// Combines two substitutions so that applying the result to a type is
    /// the same as applying `self` and then `other`.
    ///
    /// Every binding from `self` gets `other` applied to its type. A binding
    /// from `other` is added only when `self` leaves that variable unbound.
    pub fn compose(&self, other: &Substitution) -> Substitution {
        let mut combined: BTreeMap<TypeVar, Type> = self
            .mappings
            .iter()
            .map(|(var, ty)| (*var, ty.substitute(other)))
            .collect();
        for (var, ty) in &other.mappings {
            combined.entry(*var).or_insert_with(|| ty.clone());
        }
        Substitution { mappings: combined }
    }

    /// Returns `true` if the substitution binds no variables.
    pub fn is_empty(&self) -> bool {
        self.mappings.is_empty()
    }
}
/// Hands out type variables with sequential ids, starting at 0.
#[derive(Debug, Clone, Default)]
pub struct TypeVarGen {
    next_id: u32,
}

impl TypeVarGen {
    /// Creates a generator whose first variable has id 0.
    pub fn new() -> Self {
        Self { next_id: 0 }
    }

    /// Returns a type variable this generator instance has not returned before.
    ///
    /// # Panics
    ///
    /// Panics once the `u32` id space is exhausted. A plain `+= 1` would
    /// panic in debug builds but silently wrap to 0 in release builds,
    /// handing out duplicate variables and corrupting inference.
    pub fn fresh(&mut self) -> TypeVar {
        let var = TypeVar(self.next_id);
        self.next_id = self
            .next_id
            .checked_add(1)
            .expect("TypeVarGen exhausted: no fresh type variable ids left");
        var
    }

    /// Convenience wrapper: a fresh variable wrapped as [`Type::Var`].
    pub fn fresh_type(&mut self) -> Type {
        Type::Var(self.fresh())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // Locals are named `vars` rather than `gen`: `gen` is a reserved keyword
    // in the Rust 2024 edition.

    #[test]
    fn test_type_display() {
        assert_eq!(Type::Bool.to_string(), "Bool");
        assert_eq!(Type::Set(Box::new(Type::Nat)).to_string(), "Set[Nat]");
        assert_eq!(
            Type::Fn(Box::new(Type::String), Box::new(Type::Int)).to_string(),
            "dict[String, Int]"
        );
    }

    #[test]
    fn test_composite_type_display() {
        let rec = Type::Record(RecordType::from_fields([
            ("b".to_string(), Type::Int),
            ("a".to_string(), Type::Bool),
        ]));
        // Fields live in a BTreeMap, so they print in sorted name order.
        assert_eq!(rec.to_string(), "Record { a: Bool, b: Int }");
        assert_eq!(
            Type::Tuple(vec![Type::Nat, Type::Var(TypeVar::new(3))]).to_string(),
            "(Nat, ?3)"
        );
        assert_eq!(Type::Option(Box::new(Type::String)).to_string(), "Option[String]");
        assert_eq!(Type::Range(-1, 5).to_string(), "-1..5");
        assert_eq!(Type::Named("Account".to_string()).to_string(), "Account");
        assert_eq!(Type::Error.to_string(), "<error>");
    }

    #[test]
    fn test_type_predicates() {
        assert!(Type::Nat.is_numeric());
        assert!(Type::Range(0, 10).is_numeric());
        assert!(!Type::String.is_numeric());
        assert!(Type::Fn(Box::new(Type::Nat), Box::new(Type::Bool)).is_collection());
        assert!(!Type::Option(Box::new(Type::Nat)).is_collection());
    }

    #[test]
    fn test_type_has_vars() {
        let mut vars = TypeVarGen::new();
        assert!(!Type::Bool.has_vars());
        assert!(Type::Var(vars.fresh()).has_vars());
        assert!(Type::Set(Box::new(Type::Var(vars.fresh()))).has_vars());

        let v = vars.fresh();
        assert!(Type::Fn(Box::new(Type::Nat), Box::new(Type::Var(v))).has_vars());
        let rec = RecordType::from_fields([("x".to_string(), Type::Var(v))]);
        assert!(Type::Record(rec).has_vars());
        assert!(!Type::Tuple(vec![Type::Int, Type::Range(0, 1)]).has_vars());
    }

    #[test]
    fn test_type_var_gen_is_sequential() {
        let mut vars = TypeVarGen::new();
        assert_eq!(vars.fresh(), TypeVar(0));
        assert_eq!(vars.fresh(), TypeVar(1));
        assert_eq!(vars.fresh_type(), Type::Var(TypeVar(2)));
    }

    #[test]
    fn test_substitution() {
        let mut vars = TypeVarGen::new();
        let v1 = vars.fresh();
        let v2 = vars.fresh();
        let mut subst = Substitution::new();
        subst.insert(v1, Type::Nat);
        assert_eq!(Type::Var(v1).substitute(&subst), Type::Nat);
        assert_eq!(Type::Var(v2).substitute(&subst), Type::Var(v2));
    }

    #[test]
    fn test_substitution_compose() {
        let mut vars = TypeVarGen::new();
        let a = vars.fresh();
        let b = vars.fresh();

        let mut first = Substitution::new();
        first.insert(a, Type::Seq(Box::new(Type::Var(b))));
        let mut second = Substitution::new();
        second.insert(b, Type::Bool);

        assert!(Substitution::new().is_empty());
        let composed = first.compose(&second);
        assert!(!composed.is_empty());

        // Applying the composition equals applying `first`, then `second`.
        let ty = Type::Tuple(vec![Type::Var(a), Type::Var(b)]);
        assert_eq!(
            ty.substitute(&composed),
            ty.substitute(&first).substitute(&second)
        );
        assert_eq!(
            ty.substitute(&composed),
            Type::Tuple(vec![Type::Seq(Box::new(Type::Bool)), Type::Bool])
        );
    }

    #[test]
    fn test_record_type() {
        let rec =
            RecordType::from_fields([("x".to_string(), Type::Nat), ("y".to_string(), Type::Bool)]);
        assert_eq!(rec.get_field("x"), Some(&Type::Nat));
        assert_eq!(rec.get_field("z"), None);
    }
}