use std::collections::HashMap;
use std::fmt;
/// A type variable, identified by a numeric id.
///
/// Rendered as `τ` followed by the id (for example `τ0`) by its `Display` impl.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct TyVar(pub u32);
impl fmt::Display for TyVar {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "τ{}", self.0)
}
}
/// A type expression without quantifiers.
///
/// [`TypeScheme`] pairs a `MonoType` with a list of quantified variables.
/// The textual forms mentioned below are the ones produced by the `Display`
/// impl for this enum.
#[derive(Debug, Clone, PartialEq)]
pub enum MonoType {
/// A type variable; displayed like the wrapped [`TyVar`].
Var(TyVar),
/// Displayed as `i32`.
Int,
/// Displayed as `f64`.
Float,
/// Displayed as `bool`.
Bool,
/// Displayed as `String`.
String,
/// Displayed as `char`.
Char,
/// Displayed as `()`.
Unit,
/// A function from the first type (argument) to the second (return type);
/// displayed as `(arg -> ret)`.
Function(Box<MonoType>, Box<MonoType>),
/// A list with the given element type; displayed as `[elem]`.
List(Box<MonoType>),
/// A tuple of the given member types; displayed as `(a, b, ...)`.
Tuple(Vec<MonoType>),
/// An optional value of the inner type; displayed as `inner?`.
Optional(Box<MonoType>),
/// A result with an ok type and an error type; displayed as `Result<ok, err>`.
Result(Box<MonoType>, Box<MonoType>),
/// A type referred to by name; displayed as the name itself.
Named(String),
/// A reference to the inner type; displayed as `&inner`.
Reference(Box<MonoType>),
/// A data frame described by `(column name, column type)` pairs;
/// displayed as `DataFrame[name: ty, ...]`.
DataFrame(Vec<(String, MonoType)>),
/// A series with the given element type; displayed as `Series<dtype>`.
Series(Box<MonoType>),
}
impl fmt::Display for MonoType {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
MonoType::Var(v) => write!(f, "{v}"),
MonoType::Int => write!(f, "i32"),
MonoType::Float => write!(f, "f64"),
MonoType::Bool => write!(f, "bool"),
MonoType::String => write!(f, "String"),
MonoType::Char => write!(f, "char"),
MonoType::Unit => write!(f, "()"),
MonoType::Function(arg, ret) => write!(f, "({arg} -> {ret})"),
MonoType::List(elem) => write!(f, "[{elem}]"),
MonoType::Optional(inner) => write!(f, "{inner}?"),
MonoType::Result(ok, err) => write!(f, "Result<{ok}, {err}>"),
MonoType::Tuple(types) => {
write!(f, "(")?;
for (i, ty) in types.iter().enumerate() {
if i > 0 {
write!(f, ", ")?;
}
write!(f, "{ty}")?;
}
write!(f, ")")
}
MonoType::Named(name) => write!(f, "{name}"),
MonoType::Reference(inner) => write!(f, "&{inner}"),
MonoType::DataFrame(columns) => {
write!(f, "DataFrame[")?;
for (i, (name, ty)) in columns.iter().enumerate() {
if i > 0 {
write!(f, ", ")?;
}
write!(f, "{name}: {ty}")?;
}
write!(f, "]")
}
MonoType::Series(dtype) => write!(f, "Series<{dtype}>"),
}
}
}
/// A type together with the list of type variables it quantifies over.
///
/// Displayed as `∀v1,v2. ty`, or as just `ty` when `vars` is empty.
#[derive(Debug, Clone, PartialEq)]
pub struct TypeScheme {
/// Quantified variables; `instantiate` replaces each one with a fresh variable.
pub vars: Vec<TyVar>,
/// The underlying type, which may mention the variables listed in `vars`.
pub ty: MonoType,
}
impl TypeScheme {
    /// Wraps `ty` in a scheme that quantifies over no variables.
    #[must_use]
    pub fn mono(ty: MonoType) -> Self {
        TypeScheme {
            vars: Vec::new(),
            ty,
        }
    }

    /// Produces an instance of this scheme.
    ///
    /// Each variable in `self.vars` is mapped to a fresh variable drawn from
    /// `generator`, and that mapping is applied to `self.ty`. A scheme with no
    /// quantified variables is returned as a plain clone and leaves
    /// `generator` untouched.
    //
    // The parameter is deliberately not called `gen`: that identifier is a
    // reserved keyword starting with the Rust 2024 edition.
    pub fn instantiate(&self, generator: &mut TyVarGenerator) -> MonoType {
        if self.vars.is_empty() {
            return self.ty.clone();
        }

        let fresh_for_bound: HashMap<TyVar, MonoType> = self
            .vars
            .iter()
            .map(|bound| (bound.clone(), MonoType::Var(generator.fresh())))
            .collect();
        self.ty.substitute(&fresh_for_bound)
    }

    /// Builds a scheme for `ty` relative to `env`.
    ///
    /// The quantified variables are the free variables of `ty` that are not
    /// among the variables reported by `env.free_vars()`; they keep the order
    /// in which `ty.free_vars()` returned them.
    pub fn generalize(env: &crate::middleend::environment::TypeEnv, ty: &MonoType) -> Self {
        let env_vars = env.free_vars();
        let mut vars = ty.free_vars();
        vars.retain(|v| !env_vars.contains(v));
        TypeScheme {
            vars,
            ty: ty.clone(),
        }
    }
}
impl fmt::Display for TypeScheme {
    /// Writes `∀v1,v2. ty` when the scheme has quantified variables and just
    /// `ty` when it has none.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self.vars.split_first() {
            None => write!(f, "{}", self.ty),
            Some((head, tail)) => {
                write!(f, "∀{head}")?;
                for var in tail {
                    write!(f, ",{var}")?;
                }
                write!(f, ". {}", self.ty)
            }
        }
    }
}
/// Hands out type variables with consecutive ids.
pub struct TyVarGenerator {
// Id that the next call to `fresh` will return.
next: u32,
}
impl TyVarGenerator {
    /// Creates a generator whose first variable will be `TyVar(0)`.
    #[must_use]
    pub fn new() -> Self {
        Self { next: 0 }
    }

    /// Returns a variable carrying the current counter value and advances the
    /// counter by one.
    pub fn fresh(&mut self) -> TyVar {
        let id = self.next;
        self.next = id + 1;
        TyVar(id)
    }
}
impl Default for TyVarGenerator {
fn default() -> Self {
Self::new()
}
}
/// A mapping from type variables to the types that replace them; consumed by
/// `MonoType::substitute`.
pub type Substitution = HashMap<TyVar, MonoType>;
impl MonoType {
    /// Applies `subst` to this type and returns the resulting type.
    ///
    /// Every variable that has an entry in `subst` is replaced by a clone of
    /// the mapped type; variables without an entry and all other leaves are
    /// copied unchanged. Replacement is a single pass: the types inserted
    /// from `subst` are not themselves substituted again.
    #[must_use]
    pub fn substitute(&self, subst: &Substitution) -> MonoType {
        match self {
            MonoType::Var(v) => subst.get(v).cloned().unwrap_or_else(|| self.clone()),
            MonoType::Function(arg, ret) => MonoType::Function(
                Box::new(arg.substitute(subst)),
                Box::new(ret.substitute(subst)),
            ),
            MonoType::List(elem) => MonoType::List(Box::new(elem.substitute(subst))),
            MonoType::Optional(inner) => MonoType::Optional(Box::new(inner.substitute(subst))),
            MonoType::Result(ok, err) => MonoType::Result(
                Box::new(ok.substitute(subst)),
                Box::new(err.substitute(subst)),
            ),
            MonoType::DataFrame(columns) => MonoType::DataFrame(
                columns
                    .iter()
                    .map(|(name, ty)| (name.clone(), ty.substitute(subst)))
                    .collect(),
            ),
            MonoType::Series(dtype) => MonoType::Series(Box::new(dtype.substitute(subst))),
            MonoType::Reference(inner) => MonoType::Reference(Box::new(inner.substitute(subst))),
            MonoType::Tuple(types) => {
                MonoType::Tuple(types.iter().map(|ty| ty.substitute(subst)).collect())
            }
            // Leaves that contain no variables. They are listed explicitly
            // (no `_` arm) so that adding a variant forces this match to be
            // revisited instead of silently skipping the new variant.
            MonoType::Int
            | MonoType::Float
            | MonoType::Bool
            | MonoType::String
            | MonoType::Char
            | MonoType::Unit
            | MonoType::Named(_) => self.clone(),
        }
    }

    /// Returns the distinct type variables that occur in this type.
    ///
    /// The result contains each variable once, ordered by first occurrence in
    /// a left-to-right, depth-first walk of the type. The order is therefore
    /// the same on every run, which keeps anything derived from it (such as
    /// the variable list of a generalized scheme) reproducible.
    #[must_use]
    pub fn free_vars(&self) -> Vec<TyVar> {
        use std::collections::HashSet;

        /// Walks `ty`, appending each variable to `ordered` the first time it
        /// is seen; `seen` is the membership index used for deduplication.
        fn collect_vars(ty: &MonoType, seen: &mut HashSet<TyVar>, ordered: &mut Vec<TyVar>) {
            match ty {
                MonoType::Var(v) => {
                    // `insert` returns true only for a variable not seen before.
                    if seen.insert(v.clone()) {
                        ordered.push(v.clone());
                    }
                }
                MonoType::Function(first, second) | MonoType::Result(first, second) => {
                    collect_vars(first, seen, ordered);
                    collect_vars(second, seen, ordered);
                }
                MonoType::List(inner)
                | MonoType::Optional(inner)
                | MonoType::Series(inner)
                | MonoType::Reference(inner) => collect_vars(inner, seen, ordered),
                MonoType::DataFrame(columns) => {
                    for (_, column_ty) in columns {
                        collect_vars(column_ty, seen, ordered);
                    }
                }
                MonoType::Tuple(types) => {
                    for member in types {
                        collect_vars(member, seen, ordered);
                    }
                }
                // Leaves without variables; explicit for the same reason as
                // in `substitute`.
                MonoType::Int
                | MonoType::Float
                | MonoType::Bool
                | MonoType::String
                | MonoType::Char
                | MonoType::Unit
                | MonoType::Named(_) => {}
            }
        }

        let mut seen = HashSet::new();
        let mut ordered = Vec::new();
        collect_vars(self, &mut seen, &mut ordered);
        ordered
    }
}
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::panic, clippy::expect_used)]
mod tests {
    use super::*;

    /// Primitive, function and list types render in their textual form.
    #[test]
    fn test_type_display() {
        let function = MonoType::Function(Box::new(MonoType::Int), Box::new(MonoType::Bool));
        let list = MonoType::List(Box::new(MonoType::Int));

        assert_eq!(MonoType::Int.to_string(), "i32");
        assert_eq!(MonoType::Bool.to_string(), "bool");
        assert_eq!(function.to_string(), "(i32 -> bool)");
        assert_eq!(list.to_string(), "[i32]");
    }

    /// Instantiating `∀a. a -> a` must replace both occurrences of `a` with
    /// the *same* variable, and that variable must be a fresh one.
    #[test]
    fn test_type_scheme_instantiation() {
        // Not named `gen`: that identifier is reserved in the 2024 edition.
        let mut generator = TyVarGenerator::new();
        let bound = generator.fresh();
        let scheme = TypeScheme {
            vars: vec![bound.clone()],
            ty: MonoType::Function(
                Box::new(MonoType::Var(bound.clone())),
                Box::new(MonoType::Var(bound.clone())),
            ),
        };

        match scheme.instantiate(&mut generator) {
            MonoType::Function(arg, ret) => match (*arg, *ret) {
                (MonoType::Var(arg_var), MonoType::Var(ret_var)) => {
                    assert_eq!(arg_var, ret_var);
                    assert_ne!(arg_var, bound);
                }
                other => panic!("Expected type variables on both sides, got {other:?}"),
            },
            other => panic!("Expected function type, got {other}"),
        }
    }

    /// A mapped variable nested inside a list is replaced.
    #[test]
    fn test_substitution() {
        let target = TyVar(0);
        let mut subst = HashMap::new();
        subst.insert(target.clone(), MonoType::Int);

        let substituted = MonoType::List(Box::new(MonoType::Var(target))).substitute(&subst);

        assert_eq!(substituted, MonoType::List(Box::new(MonoType::Int)));
    }

    /// Free variables are found at any depth and reported once each.
    #[test]
    fn test_free_vars() {
        let first = TyVar(0);
        let second = TyVar(1);

        let distinct = MonoType::Function(
            Box::new(MonoType::Var(first.clone())),
            Box::new(MonoType::List(Box::new(MonoType::Var(second.clone())))),
        )
        .free_vars();
        assert_eq!(distinct.len(), 2);
        assert!(distinct.contains(&first));
        assert!(distinct.contains(&second));

        let repeated = MonoType::Function(
            Box::new(MonoType::Var(first.clone())),
            Box::new(MonoType::Var(first.clone())),
        )
        .free_vars();
        assert_eq!(repeated.len(), 1);
        assert!(repeated.contains(&first));
    }
}
#[cfg(test)]
mod property_tests_types {
    use super::{MonoType, Substitution};
    use proptest::proptest;

    proptest! {
        // Building, displaying and traversing a named type must not panic for
        // any input string, and must behave like a variable-free leaf.
        #[test]
        fn test_mono_never_panics(input: String) {
            // Bound the input by characters rather than slicing at a fixed
            // byte offset: a byte-range slice panics when the offset lands
            // inside a multi-byte code point.
            let name: String = input.chars().take(100).collect();
            let ty = MonoType::Named(name.clone());

            assert_eq!(ty.to_string(), name);
            assert!(ty.free_vars().is_empty());
            assert_eq!(ty.substitute(&Substitution::new()), ty);
        }
    }
}
#[cfg(test)]
mod coverage_tests {
    use super::*;

    /// Shorthand for a type-variable node with the given id.
    fn var(id: u32) -> MonoType {
        MonoType::Var(TyVar(id))
    }

    /// Builds a substitution that maps the variable with id `id` to `ty`.
    fn single_subst(id: u32, ty: MonoType) -> Substitution {
        let mut subst = Substitution::new();
        subst.insert(TyVar(id), ty);
        subst
    }

    // ----- Display: MonoType -------------------------------------------------

    #[test]
    fn test_display_float() {
        assert_eq!(MonoType::Float.to_string(), "f64");
    }

    #[test]
    fn test_display_string() {
        assert_eq!(MonoType::String.to_string(), "String");
    }

    #[test]
    fn test_display_char() {
        assert_eq!(MonoType::Char.to_string(), "char");
    }

    #[test]
    fn test_display_unit() {
        assert_eq!(MonoType::Unit.to_string(), "()");
    }

    #[test]
    fn test_display_optional() {
        let ty = MonoType::Optional(Box::new(MonoType::Int));
        assert_eq!(ty.to_string(), "i32?");
    }

    #[test]
    fn test_display_result() {
        let ty = MonoType::Result(Box::new(MonoType::Int), Box::new(MonoType::String));
        assert_eq!(ty.to_string(), "Result<i32, String>");
    }

    #[test]
    fn test_display_tuple_empty() {
        assert_eq!(MonoType::Tuple(Vec::new()).to_string(), "()");
    }

    #[test]
    fn test_display_tuple_single() {
        assert_eq!(MonoType::Tuple(vec![MonoType::Int]).to_string(), "(i32)");
    }

    #[test]
    fn test_display_tuple_multiple() {
        let ty = MonoType::Tuple(vec![MonoType::Int, MonoType::Bool, MonoType::String]);
        assert_eq!(ty.to_string(), "(i32, bool, String)");
    }

    #[test]
    fn test_display_named() {
        assert_eq!(MonoType::Named("MyStruct".into()).to_string(), "MyStruct");
    }

    #[test]
    fn test_display_reference() {
        let ty = MonoType::Reference(Box::new(MonoType::Int));
        assert_eq!(ty.to_string(), "&i32");
    }

    #[test]
    fn test_display_dataframe_empty() {
        assert_eq!(MonoType::DataFrame(Vec::new()).to_string(), "DataFrame[]");
    }

    #[test]
    fn test_display_dataframe_columns() {
        let df = MonoType::DataFrame(vec![
            ("name".into(), MonoType::String),
            ("age".into(), MonoType::Int),
        ]);
        assert_eq!(df.to_string(), "DataFrame[name: String, age: i32]");
    }

    #[test]
    fn test_display_series() {
        let ty = MonoType::Series(Box::new(MonoType::Float));
        assert_eq!(ty.to_string(), "Series<f64>");
    }

    #[test]
    fn test_display_tyvar() {
        assert_eq!(TyVar(0).to_string(), "τ0");
        assert_eq!(TyVar(42).to_string(), "τ42");
    }

    #[test]
    fn test_display_var() {
        assert_eq!(var(5).to_string(), "τ5");
    }

    // ----- TypeScheme --------------------------------------------------------

    #[test]
    fn test_type_scheme_mono() {
        let scheme = TypeScheme::mono(MonoType::Int);
        assert!(scheme.vars.is_empty());
        assert_eq!(scheme.ty, MonoType::Int);
    }

    #[test]
    fn test_type_scheme_display_mono() {
        assert_eq!(TypeScheme::mono(MonoType::Bool).to_string(), "bool");
    }

    /// The variable list is given explicitly here, so the full rendering is
    /// deterministic and can be pinned exactly.
    #[test]
    fn test_type_scheme_display_poly() {
        let scheme = TypeScheme {
            vars: vec![TyVar(0), TyVar(1)],
            ty: MonoType::Function(Box::new(var(0)), Box::new(var(1))),
        };
        assert_eq!(scheme.to_string(), "∀τ0,τ1. (τ0 -> τ1)");
    }

    #[test]
    fn test_type_scheme_instantiate_mono() {
        // Not named `gen`: that identifier is reserved in the 2024 edition.
        let mut generator = TyVarGenerator::new();
        let instance = TypeScheme::mono(MonoType::Int).instantiate(&mut generator);
        assert_eq!(instance, MonoType::Int);
    }

    // ----- TyVarGenerator ----------------------------------------------------

    #[test]
    fn test_tyvar_generator_default() {
        let mut generator = TyVarGenerator::default();
        assert_eq!(generator.fresh(), TyVar(0));
    }

    #[test]
    fn test_tyvar_generator_increments() {
        let mut generator = TyVarGenerator::new();
        for expected_id in 0..3 {
            assert_eq!(generator.fresh(), TyVar(expected_id));
        }
    }

    // ----- substitute --------------------------------------------------------

    #[test]
    fn test_substitute_optional() {
        let subst = single_subst(0, MonoType::String);
        let ty = MonoType::Optional(Box::new(var(0)));
        assert_eq!(
            ty.substitute(&subst),
            MonoType::Optional(Box::new(MonoType::String))
        );
    }

    #[test]
    fn test_substitute_result() {
        let mut subst = single_subst(0, MonoType::Int);
        subst.insert(TyVar(1), MonoType::String);
        let ty = MonoType::Result(Box::new(var(0)), Box::new(var(1)));
        assert_eq!(
            ty.substitute(&subst),
            MonoType::Result(Box::new(MonoType::Int), Box::new(MonoType::String))
        );
    }

    #[test]
    fn test_substitute_reference() {
        let subst = single_subst(0, MonoType::Bool);
        let ty = MonoType::Reference(Box::new(var(0)));
        assert_eq!(
            ty.substitute(&subst),
            MonoType::Reference(Box::new(MonoType::Bool))
        );
    }

    #[test]
    fn test_substitute_tuple() {
        let subst = single_subst(0, MonoType::Float);
        let ty = MonoType::Tuple(vec![MonoType::Int, var(0), MonoType::Bool]);
        assert_eq!(
            ty.substitute(&subst),
            MonoType::Tuple(vec![MonoType::Int, MonoType::Float, MonoType::Bool])
        );
    }

    #[test]
    fn test_substitute_dataframe() {
        let subst = single_subst(0, MonoType::Int);
        let ty = MonoType::DataFrame(vec![
            ("id".into(), var(0)),
            ("name".into(), MonoType::String),
        ]);
        assert_eq!(
            ty.substitute(&subst),
            MonoType::DataFrame(vec![
                ("id".into(), MonoType::Int),
                ("name".into(), MonoType::String),
            ])
        );
    }

    #[test]
    fn test_substitute_series() {
        let subst = single_subst(0, MonoType::Float);
        let ty = MonoType::Series(Box::new(var(0)));
        assert_eq!(
            ty.substitute(&subst),
            MonoType::Series(Box::new(MonoType::Float))
        );
    }

    #[test]
    fn test_substitute_no_match() {
        let empty = Substitution::new();
        assert_eq!(var(0).substitute(&empty), var(0));
    }

    #[test]
    fn test_substitute_primitive_unchanged() {
        let empty = Substitution::new();
        for primitive in [MonoType::Int, MonoType::Bool, MonoType::Char] {
            assert_eq!(primitive.substitute(&empty), primitive);
        }
    }

    // ----- free_vars ---------------------------------------------------------
    // Only membership and length are asserted; no particular order is assumed.

    #[test]
    fn test_free_vars_optional() {
        let vars = MonoType::Optional(Box::new(var(0))).free_vars();
        assert!(vars.contains(&TyVar(0)));
    }

    #[test]
    fn test_free_vars_result() {
        let vars = MonoType::Result(Box::new(var(0)), Box::new(var(1))).free_vars();
        assert_eq!(vars.len(), 2);
        assert!(vars.contains(&TyVar(0)));
        assert!(vars.contains(&TyVar(1)));
    }

    #[test]
    fn test_free_vars_reference() {
        let vars = MonoType::Reference(Box::new(var(0))).free_vars();
        assert!(vars.contains(&TyVar(0)));
    }

    #[test]
    fn test_free_vars_series() {
        let vars = MonoType::Series(Box::new(var(0))).free_vars();
        assert!(vars.contains(&TyVar(0)));
    }

    #[test]
    fn test_free_vars_dataframe() {
        let vars = MonoType::DataFrame(vec![("col".into(), var(0))]).free_vars();
        assert!(vars.contains(&TyVar(0)));
    }

    #[test]
    fn test_free_vars_tuple() {
        let vars = MonoType::Tuple(vec![var(0), MonoType::Int, var(1)]).free_vars();
        assert_eq!(vars.len(), 2);
        assert!(vars.contains(&TyVar(0)));
        assert!(vars.contains(&TyVar(1)));
    }

    #[test]
    fn test_free_vars_primitive_empty() {
        for primitive in [
            MonoType::Int,
            MonoType::Bool,
            MonoType::String,
            MonoType::Unit,
        ] {
            assert!(primitive.free_vars().is_empty());
        }
    }

    // ----- TyVar -------------------------------------------------------------

    #[test]
    fn test_tyvar_equality() {
        assert_eq!(TyVar(0), TyVar(0));
        assert_ne!(TyVar(0), TyVar(1));
    }

    #[test]
    fn test_tyvar_hash() {
        use std::collections::HashSet;

        // Inserting an equal variable twice must not grow the set.
        let set: HashSet<TyVar> = vec![TyVar(0), TyVar(1), TyVar(0)].into_iter().collect();
        assert_eq!(set.len(), 2);
    }
}