use std::collections::HashMap;
use std::fmt;
use std::sync::atomic::{AtomicU32, Ordering};
use std::sync::Arc;
use super::{DefId, Ty, TyKind};
/// Identifier of a kind inference variable, displayed as `?K<n>`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct KindVarId(pub u32);

impl KindVarId {
    /// Allocates the next identifier from a process-wide atomic counter.
    ///
    /// Ids are handed out in increasing order; the counter is a `u32`
    /// (`fetch_add` wraps on overflow).
    pub fn fresh() -> Self {
        static NEXT: AtomicU32 = AtomicU32::new(0);
        let id = NEXT.fetch_add(1, Ordering::SeqCst);
        KindVarId(id)
    }
}

impl fmt::Display for KindVarId {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let KindVarId(raw) = *self;
        write!(f, "?K{}", raw)
    }
}
/// The kind ("type of a type") of a type-level term.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum Kind {
    /// `*`: the kind of ordinary, fully applied types.
    Type,
    /// `k1 -> k2`: a constructor that, given an argument of kind `k1`,
    /// produces something of kind `k2`.
    Arrow(Box<Kind>, Box<Kind>),
    /// The kind of lifetimes.
    Lifetime,
    /// The kind of a const parameter whose value has the given type.
    Const(Box<Ty>),
    /// An as-yet-unresolved kind inference variable.
    Var(KindVarId),
    /// The `Row` kind.
    Row,
    /// The `Effect` kind.
    Effect,
    /// Placeholder used once a kind error has occurred.
    Error,
}

impl Kind {
    /// Returns `*`, the kind of ordinary types.
    pub fn ty() -> Self {
        Kind::Type
    }

    /// Returns `* -> *`.
    pub fn unary() -> Self {
        Kind::nary(1)
    }

    /// Returns `* -> * -> *`.
    pub fn binary() -> Self {
        Kind::nary(2)
    }

    /// Returns the right-nested arrow `* -> ... -> *` taking `n` arguments
    /// (`nary(0)` is `*`).
    pub fn nary(n: usize) -> Self {
        // Built bottom-up so that large `n` does not recurse.
        (0..n).fold(Kind::Type, |acc, _| {
            Kind::Arrow(Box::new(Kind::Type), Box::new(acc))
        })
    }

    /// Returns a kind variable with a newly allocated id.
    pub fn fresh_var() -> Self {
        Kind::Var(KindVarId::fresh())
    }

    /// True if this is exactly `*`.
    pub fn is_type(&self) -> bool {
        matches!(self, Kind::Type)
    }

    /// True if this is an arrow kind.
    pub fn is_constructor(&self) -> bool {
        matches!(self, Kind::Arrow(_, _))
    }

    /// Number of arrows along the result spine, e.g. `* -> * -> *` has arity 2
    /// and `(* -> *) -> *` has arity 1.
    pub fn arity(&self) -> usize {
        let mut count = 0;
        let mut current = self;
        while let Kind::Arrow(_, result) = current {
            count += 1;
            current = &**result;
        }
        count
    }

    /// Applies a constructor of this kind to an argument of kind `arg`.
    ///
    /// Returns the result kind when `self` is an arrow whose parameter kind is
    /// compatible with `arg`, and `None` when `self` is not an arrow or the
    /// argument kind clearly does not fit. Kind variables and `Kind::Error`
    /// are accepted in either position (see [`Kind::is_compatible_with`]);
    /// full checking with variable binding is done by unification.
    pub fn apply(&self, arg: &Kind) -> Option<Kind> {
        match self {
            Kind::Arrow(param, result) if param.is_compatible_with(arg) => {
                Some((**result).clone())
            }
            _ => None,
        }
    }

    /// Approximate structural kind comparison used by [`Kind::apply`].
    ///
    /// Variables and `Error` match anything (no bindings are tracked);
    /// arrows are compared component-wise; everything else by equality.
    fn is_compatible_with(&self, other: &Kind) -> bool {
        match (self, other) {
            (Kind::Var(_), _) | (_, Kind::Var(_)) => true,
            (Kind::Error, _) | (_, Kind::Error) => true,
            (Kind::Arrow(a1, r1), Kind::Arrow(a2, r2)) => {
                a1.is_compatible_with(a2) && r1.is_compatible_with(r2)
            }
            _ => self == other,
        }
    }

    /// True if this kind mentions a kind variable (or, for `Const`, if the
    /// carried type reports variables via `Ty::has_vars`).
    pub fn has_vars(&self) -> bool {
        match self {
            Kind::Var(_) => true,
            Kind::Arrow(k1, k2) => k1.has_vars() || k2.has_vars(),
            Kind::Const(ty) => ty.has_vars(),
            _ => false,
        }
    }

    /// Replaces bound kind variables using `subst`, following chains of
    /// bindings. `Const` payloads are not rewritten.
    ///
    /// Assumes `subst` is acyclic (unification's occurs check maintains this);
    /// a cyclic substitution would recurse forever.
    pub fn substitute(&self, subst: &KindSubstitution) -> Kind {
        match self {
            Kind::Var(id) => {
                if let Some(kind) = subst.get(*id) {
                    kind.substitute(subst)
                } else {
                    self.clone()
                }
            }
            Kind::Arrow(k1, k2) => Kind::Arrow(
                Box::new(k1.substitute(subst)),
                Box::new(k2.substitute(subst)),
            ),
            _ => self.clone(),
        }
    }
}
impl fmt::Display for Kind {
    /// Renders kinds in arrow notation: `*`, `* -> *`, `(* -> *) -> *`, ...
    ///
    /// Arrows associate to the right, so only an arrow in parameter position
    /// is parenthesised.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Kind::Arrow(param, result) => {
                let nested_param = matches!(&**param, Kind::Arrow(..));
                if nested_param {
                    write!(f, "({}) -> {}", param, result)
                } else {
                    write!(f, "{} -> {}", param, result)
                }
            }
            Kind::Type => write!(f, "*"),
            Kind::Var(id) => write!(f, "{}", id),
            Kind::Const(ty) => write!(f, "Const {}", ty),
            Kind::Lifetime => write!(f, "Lifetime"),
            Kind::Row => write!(f, "Row"),
            Kind::Effect => write!(f, "Effect"),
            Kind::Error => write!(f, "{{error}}"),
        }
    }
}
/// A mapping from kind variables to the kinds they have been bound to.
#[derive(Debug, Clone, Default)]
pub struct KindSubstitution {
    map: HashMap<KindVarId, Kind>,
}

impl KindSubstitution {
    /// Creates an empty substitution.
    pub fn new() -> Self {
        KindSubstitution {
            map: HashMap::new(),
        }
    }

    /// Binds `var` to `kind`, replacing any previous binding.
    pub fn insert(&mut self, var: KindVarId, kind: Kind) {
        self.map.insert(var, kind);
    }

    /// Returns the kind `var` is directly bound to, if any (bindings are not
    /// chased here; see `Kind::substitute`).
    pub fn get(&self, var: KindVarId) -> Option<&Kind> {
        let key = var;
        self.map.get(&key)
    }

    /// True if `var` has a binding.
    pub fn contains(&self, var: KindVarId) -> bool {
        self.get(var).is_some()
    }
}
/// A named, positionally indexed type parameter together with its kind.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct HKTParam {
    pub name: Arc<str>,
    pub index: u32,
    pub kind: Kind,
}

impl HKTParam {
    /// Creates a parameter with an explicit kind.
    pub fn new(name: impl Into<Arc<str>>, index: u32, kind: Kind) -> Self {
        let name: Arc<str> = name.into();
        HKTParam { name, index, kind }
    }

    /// Creates an ordinary parameter of kind `*`.
    pub fn type_param(name: impl Into<Arc<str>>, index: u32) -> Self {
        HKTParam::new(name, index, Kind::Type)
    }

    /// Creates a parameter of kind `* -> *`.
    pub fn constructor_param(name: impl Into<Arc<str>>, index: u32) -> Self {
        HKTParam::new(name, index, Kind::unary())
    }
}

impl fmt::Display for HKTParam {
    /// Prints just the name for `*` parameters, otherwise `name: kind`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match &self.kind {
            Kind::Type => write!(f, "{}", self.name),
            other => write!(f, "{}: {}", self.name, other),
        }
    }
}
/// A named type constructor together with its kinded parameters.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct TypeConstructor {
    pub def_id: DefId,
    pub name: Arc<str>,
    pub params: Vec<HKTParam>,
    /// Kind produced once every parameter is supplied; `new` sets it to `*`.
    pub result_kind: Kind,
}

impl TypeConstructor {
    /// Creates a constructor whose full application has kind `*`.
    pub fn new(def_id: DefId, name: impl Into<Arc<str>>, params: Vec<HKTParam>) -> Self {
        let name: Arc<str> = name.into();
        TypeConstructor {
            result_kind: Kind::Type,
            def_id,
            name,
            params,
        }
    }

    /// Full kind `p1 -> p2 -> ... -> result_kind`, built from the parameter
    /// kinds in declaration order.
    pub fn kind(&self) -> Kind {
        let mut kind = self.result_kind.clone();
        for param in self.params.iter().rev() {
            kind = Kind::Arrow(Box::new(param.kind.clone()), Box::new(kind));
        }
        kind
    }

    /// True when no parameter has a kind other than `*`.
    pub fn is_simple(&self) -> bool {
        !self.params.iter().any(|param| !param.kind.is_type())
    }

    /// Number of declared parameters.
    pub fn arity(&self) -> usize {
        self.params.len()
    }
}
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct PartialApp {
pub constructor: Arc<TypeConstructor>,
pub args: Vec<Ty>,
}
impl PartialApp {
pub fn new(constructor: Arc<TypeConstructor>, args: Vec<Ty>) -> Self {
Self { constructor, args }
}
pub fn remaining_kind(&self) -> Kind {
let full_kind = self.constructor.kind();
let mut current = full_kind;
for _ in &self.args {
if let Kind::Arrow(_, result) = current {
current = *result;
} else {
return Kind::Error;
}
}
current
}
pub fn is_fully_applied(&self) -> bool {
self.args.len() == self.constructor.arity()
}
pub fn apply(&self, arg: Ty) -> Self {
let mut args = self.args.clone();
args.push(arg);
Self {
constructor: self.constructor.clone(),
args,
}
}
}
#[derive(Debug, Default)]
pub struct KindContext {
type_params: HashMap<Arc<str>, Kind>,
constructors: HashMap<DefId, Arc<TypeConstructor>>,
subst: KindSubstitution,
}
impl KindContext {
pub fn new() -> Self {
Self::default()
}
pub fn register_param(&mut self, name: Arc<str>, kind: Kind) {
self.type_params.insert(name, kind);
}
pub fn register_constructor(&mut self, tc: TypeConstructor) {
self.constructors.insert(tc.def_id, Arc::new(tc));
}
pub fn param_kind(&self, name: &str) -> Option<&Kind> {
self.type_params.get(name)
}
pub fn get_constructor(&self, def_id: DefId) -> Option<&Arc<TypeConstructor>> {
self.constructors.get(&def_id)
}
pub fn infer_kind(&mut self, ty: &Ty) -> Result<Kind, KindError> {
match &ty.kind {
TyKind::Int(_)
| TyKind::Float(_)
| TyKind::Bool
| TyKind::Char
| TyKind::Str
| TyKind::Never => Ok(Kind::Type),
TyKind::Tuple(_)
| TyKind::Array(_, _)
| TyKind::Slice(_)
| TyKind::Ref(_, _, _)
| TyKind::Ptr(_, _)
| TyKind::Fn(_) => Ok(Kind::Type),
TyKind::Var(_) | TyKind::Infer(_) => Ok(Kind::Type),
TyKind::Param(name, _) => self
.param_kind(name)
.cloned()
.ok_or_else(|| KindError::UnboundTypeParam(name.to_string())),
TyKind::Adt(def_id, args) => {
let tc_info = self
.get_constructor(*def_id)
.map(|tc| (tc.arity(), tc.name.to_string(), tc.params.clone()));
if let Some((arity, name, params)) = tc_info {
if args.len() != arity {
return Err(KindError::ArityMismatch {
expected: arity,
found: args.len(),
name,
});
}
for (arg, param) in args.iter().zip(¶ms) {
let arg_kind = self.infer_kind(arg)?;
self.unify_kinds(&arg_kind, ¶m.kind)?;
}
Ok(Kind::Type)
} else {
Ok(Kind::Type)
}
}
TyKind::Projection { .. } => Ok(Kind::Type),
TyKind::TraitObject(_) => Ok(Kind::Type),
TyKind::Error => Ok(Kind::Error),
}
}
pub fn unify_kinds(&mut self, k1: &Kind, k2: &Kind) -> Result<(), KindError> {
let k1 = k1.substitute(&self.subst);
let k2 = k2.substitute(&self.subst);
match (&k1, &k2) {
(Kind::Type, Kind::Type) => Ok(()),
(Kind::Lifetime, Kind::Lifetime) => Ok(()),
(Kind::Row, Kind::Row) => Ok(()),
(Kind::Effect, Kind::Effect) => Ok(()),
(Kind::Error, _) | (_, Kind::Error) => Ok(()),
(Kind::Arrow(a1, r1), Kind::Arrow(a2, r2)) => {
self.unify_kinds(a1, a2)?;
self.unify_kinds(r1, r2)
}
(Kind::Const(t1), Kind::Const(t2)) if t1 == t2 => Ok(()),
(Kind::Var(v), k) | (k, Kind::Var(v)) => {
if self.occurs_check(*v, k) {
Err(KindError::InfiniteKind(*v))
} else {
self.subst.insert(*v, k.clone());
Ok(())
}
}
_ => Err(KindError::Mismatch {
expected: k1.clone(),
found: k2.clone(),
}),
}
}
fn occurs_check(&self, var: KindVarId, kind: &Kind) -> bool {
match kind {
Kind::Var(v) => *v == var,
Kind::Arrow(k1, k2) => self.occurs_check(var, k1) || self.occurs_check(var, k2),
_ => false,
}
}
pub fn apply_subst(&self, kind: &Kind) -> Kind {
kind.substitute(&self.subst)
}
}
/// Errors reported by kind inference and unification.
#[derive(Debug, Clone)]
pub enum KindError {
    /// Two kinds that had to agree could not be unified.
    Mismatch { expected: Kind, found: Kind },
    /// A registered constructor was given the wrong number of arguments.
    ArityMismatch {
        expected: usize,
        found: usize,
        name: String,
    },
    /// A type parameter with no registered kind.
    UnboundTypeParam(String),
    /// Binding this variable would make a kind contain itself.
    InfiniteKind(KindVarId),
    /// A type constructor was expected, but a term of this kind was found.
    NotAConstructor(Kind),
}

impl fmt::Display for KindError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::Mismatch { expected, found } => write!(
                f,
                "kind mismatch: expected {}, found {}",
                expected, found
            ),
            Self::ArityMismatch {
                expected,
                found,
                name,
            } => write!(
                f,
                "wrong number of type arguments for '{}': expected {}, found {}",
                name, expected, found
            ),
            Self::UnboundTypeParam(name) => write!(f, "unbound type parameter: {}", name),
            Self::InfiniteKind(var) => write!(
                f,
                "infinite kind: {} occurs in its own definition",
                var
            ),
            Self::NotAConstructor(kind) => write!(
                f,
                "expected a type constructor, found kind {}",
                kind
            ),
        }
    }
}

impl std::error::Error for KindError {}
/// Returns the built-in generic type constructors.
///
/// Each gets `DefId::new(0, slot)` with `slot` numbered 0..=6 in the order
/// listed; every parameter has kind `*` and is indexed in declaration order.
pub fn builtin_constructors() -> Vec<TypeConstructor> {
    let make = |slot, name: &str, param_names: &[&str]| {
        let params = param_names
            .iter()
            .zip(0u32..)
            .map(|(&param, index)| HKTParam::type_param(param, index))
            .collect();
        TypeConstructor::new(DefId::new(0, slot), name, params)
    };
    vec![
        make(0, "Option", &["T"]),
        make(1, "Result", &["T", "E"]),
        make(2, "Vec", &["T"]),
        make(3, "HashMap", &["K", "V"]),
        make(4, "Box", &["T"]),
        make(5, "Rc", &["T"]),
        make(6, "Arc", &["T"]),
    ]
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Arrow kinds print right-associatively without redundant parentheses.
    #[test]
    fn test_kind_display() {
        let expectations = [
            (Kind::Type, "*"),
            (Kind::unary(), "* -> *"),
            (Kind::binary(), "* -> * -> *"),
            (Kind::nary(3), "* -> * -> * -> *"),
        ];
        for (kind, rendered) in expectations.iter() {
            assert_eq!(kind.to_string(), *rendered);
        }
    }

    /// Arity counts the arrows along the result spine.
    #[test]
    fn test_kind_arity() {
        let expectations = [
            (Kind::Type, 0),
            (Kind::unary(), 1),
            (Kind::binary(), 2),
            (Kind::nary(5), 5),
        ];
        for (kind, arity) in expectations.iter() {
            assert_eq!(kind.arity(), *arity);
        }
    }

    /// A two-parameter constructor has kind `* -> * -> *`.
    #[test]
    fn test_type_constructor() {
        let result_tc = TypeConstructor::new(
            DefId::new(0, 0),
            "Result",
            vec![HKTParam::type_param("T", 0), HKTParam::type_param("E", 1)],
        );
        assert_eq!(result_tc.arity(), 2);
        assert_eq!(result_tc.kind().to_string(), "* -> * -> *");
    }

    #[test]
    fn test_kind_unification() {
        let star = Kind::Type;
        let star_to_star = Kind::unary();

        let mut ctx = KindContext::new();
        assert!(ctx.unify_kinds(&star, &star).is_ok());
        assert!(ctx.unify_kinds(&star_to_star, &star_to_star).is_ok());
        let var = Kind::fresh_var();
        assert!(ctx.unify_kinds(&var, &Kind::Type).is_ok());

        // In a fresh context, `*` can never unify with `* -> *`.
        let mut fresh_ctx = KindContext::new();
        assert!(fresh_ctx.unify_kinds(&star, &star_to_star).is_err());
    }

    /// Each supplied argument peels one arrow off the remaining kind.
    #[test]
    fn test_partial_application() {
        let result_tc = Arc::new(TypeConstructor::new(
            DefId::new(0, 0),
            "Result",
            vec![HKTParam::type_param("T", 0), HKTParam::type_param("E", 1)],
        ));

        let unapplied = PartialApp::new(Arc::clone(&result_tc), vec![]);
        assert!(!unapplied.is_fully_applied());
        assert_eq!(unapplied.remaining_kind().to_string(), "* -> * -> *");

        let half = unapplied.apply(Ty::int(super::super::IntTy::I32));
        assert!(!half.is_fully_applied());
        assert_eq!(half.remaining_kind().to_string(), "* -> *");

        let full = half.apply(Ty::str());
        assert!(full.is_fully_applied());
        assert_eq!(full.remaining_kind().to_string(), "*");
    }
}