use std::collections::HashMap;
use std::sync::Arc;
use super::ty::*;
/// Type-checking context: a stack of lexical scopes plus global tables of
/// types, traits, impls, aliases, function signatures and method information.
#[derive(Debug)]
pub struct TypeContext {
/// Lexical scope stack, innermost last. `new` seeds it with one `Module`
/// scope, which `pop_scope` never removes.
scopes: Vec<Scope>,
/// Struct/enum definitions keyed by their `DefId`.
types: HashMap<DefId, TypeDef>,
/// Trait definitions keyed by their `DefId`.
traits: HashMap<DefId, TraitDef>,
/// Every registered trait impl; searched linearly.
impls: Vec<TraitImpl>,
/// Type aliases keyed by their `DefId`.
aliases: HashMap<DefId, TypeAlias>,
/// Function signatures keyed by the function's `DefId`.
functions: HashMap<DefId, FnSig>,
/// Next index handed out by `fresh_def_id`.
next_def_id: u32,
/// The type `Self` currently refers to, if any (set via `set_self_ty`).
current_self_ty: Option<Ty>,
/// Inherent (non-trait) methods keyed by (owning type's `DefId`, method name).
inherent_methods: HashMap<(DefId, Arc<str>), TraitMethod>,
/// Generic parameter name -> names of the traits bounding it; consulted by
/// `lookup_param_method` and wiped by `clear_param_bounds`.
param_trait_bounds: HashMap<Arc<str>, Vec<Arc<str>>>,
/// Per-module binding tables (module name -> binding name -> scheme),
/// registered via `register_module_bindings`.
module_bindings: HashMap<Arc<str>, HashMap<Arc<str>, TypeScheme>>,
}
/// One lexical scope on the `TypeContext` scope stack.
#[derive(Debug, Clone)]
struct Scope {
/// Value bindings introduced in this scope, stored as (possibly polymorphic) schemes.
bindings: HashMap<Arc<str>, TypeScheme>,
/// Generic type parameters visible in this scope, mapped to the `Ty` standing for each.
type_params: HashMap<Arc<str>, Ty>,
/// The construct that opened this scope.
kind: ScopeKind,
}
/// Tag describing what kind of construct opened a scope. `TypeContext::in_loop`
/// and `TypeContext::in_function` inspect these tags on the scope stack.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ScopeKind {
/// The root scope created by `TypeContext::new`.
Module,
/// Presumably a function body — the caller chooses the kind in `push_scope`; TODO confirm.
Function,
/// A plain block, as tagged by the caller.
Block,
/// A loop body, as tagged by the caller.
Loop,
/// Presumably a `match` expression or arm — confirm against callers.
Match,
}
impl Scope {
fn new(kind: ScopeKind) -> Self {
Self {
bindings: HashMap::new(),
type_params: HashMap::new(),
kind,
}
}
}
/// A nominal type definition (struct or enum) with its generic parameters.
#[derive(Debug, Clone)]
pub struct TypeDef {
/// Key under which the definition is stored in `TypeContext::types`.
pub def_id: DefId,
/// Source name of the type; used by `lookup_type_by_name`.
pub name: Arc<str>,
/// Generic parameters declared on the type.
pub generics: Vec<GenericParam>,
/// Whether this is a struct or an enum, plus its fields/variants.
pub kind: TypeDefKind,
}
/// Distinguishes struct definitions from enum definitions.
#[derive(Debug, Clone)]
pub enum TypeDefKind {
Struct(StructDef),
Enum(EnumDef),
}
/// Field layout of a struct.
#[derive(Debug, Clone)]
pub struct StructDef {
/// Fields as (name, type), presumably in declaration order.
/// NOTE(review): how tuple-struct fields are named here is not visible — confirm.
pub fields: Vec<(Arc<str>, Ty)>,
/// Presumably true for tuple structs such as `struct S(A, B)`.
pub is_tuple: bool,
}
/// Variant list of an enum.
#[derive(Debug, Clone)]
pub struct EnumDef {
pub variants: Vec<EnumVariant>,
}
/// A single enum variant.
#[derive(Debug, Clone)]
pub struct EnumVariant {
pub name: Arc<str>,
/// Variant fields; the name is presumably `None` for positional (tuple-like) fields.
pub fields: Vec<(Option<Arc<str>>, Ty)>,
/// Presumably the explicit `= value` discriminant, when one is written.
pub discriminant: Option<i128>,
}
/// A generic parameter as stored on types, traits, impls, aliases and function signatures.
#[derive(Debug, Clone)]
pub struct GenericParam {
pub name: Arc<str>,
/// Presumably the parameter's position in its declaring list — TODO confirm.
pub index: u32,
pub kind: GenericParamKind,
}
/// The three kinds of generic parameter.
#[derive(Debug, Clone)]
pub enum GenericParamKind {
/// A type parameter together with its trait bounds.
Type { bounds: Vec<TraitBound> },
/// A lifetime parameter.
Lifetime,
/// A const generic parameter and its declared type.
Const { ty: Ty },
}
/// A reference to a trait plus the type arguments applied to it.
#[derive(Debug, Clone)]
pub struct TraitBound {
pub trait_id: DefId,
/// Presumably the trait's own generic arguments (e.g. `T` in `From<T>`).
pub args: Vec<Ty>,
}
/// A trait definition.
#[derive(Debug, Clone)]
pub struct TraitDef {
/// Key under which the trait is stored in `TypeContext::traits`.
pub def_id: DefId,
/// Source name; used by `lookup_trait_by_name`.
pub name: Arc<str>,
pub generics: Vec<GenericParam>,
/// Traits this trait requires (its supertraits).
pub supertraits: Vec<TraitBound>,
/// Associated types declared by the trait.
pub assoc_types: Vec<AssocType>,
/// Methods declared by the trait.
pub methods: Vec<TraitMethod>,
}
/// An associated type declared in a trait.
#[derive(Debug, Clone)]
pub struct AssocType {
pub name: Arc<str>,
pub bounds: Vec<TraitBound>,
/// Presumably the `= Ty` default, when given.
pub default: Option<Ty>,
}
/// A method signature. Used both for trait methods and for inherent methods
/// (`register_inherent_method` stores the latter with `has_default: false`).
#[derive(Debug, Clone)]
pub struct TraitMethod {
pub name: Arc<str>,
pub sig: FnSig,
/// Presumably true when the trait supplies a default body.
pub has_default: bool,
}
/// An `impl Trait for Type` block.
#[derive(Debug, Clone)]
pub struct TraitImpl {
pub trait_id: DefId,
/// The implementing type.
pub self_ty: Ty,
pub generics: Vec<GenericParam>,
/// Associated type bindings: name -> concrete type.
pub assoc_types: HashMap<Arc<str>, Ty>,
/// Methods defined in this impl: name -> DefId (presumably of the function) — confirm.
pub methods: HashMap<Arc<str>, DefId>,
pub where_clauses: Vec<WhereClause>,
}
/// A `where` predicate: `ty` must satisfy every bound in `bounds`.
#[derive(Debug, Clone)]
pub struct WhereClause {
pub ty: Ty,
pub bounds: Vec<TraitBound>,
}
/// A `type Name<..> = Ty;` alias.
#[derive(Debug, Clone)]
pub struct TypeAlias {
pub def_id: DefId,
pub name: Arc<str>,
pub generics: Vec<GenericParam>,
/// The aliased type.
pub ty: Ty,
}
/// A function or method signature.
#[derive(Debug, Clone)]
pub struct FnSig {
pub generics: Vec<GenericParam>,
/// Lifetime parameter names. NOTE(review): `GenericParamKind::Lifetime` also exists —
/// confirm which of the two representations callers populate.
pub lifetime_params: Vec<Arc<str>>,
/// Parameters as (name, type).
pub params: Vec<(Arc<str>, Ty)>,
/// Return type.
pub ret: Ty,
/// Qualifier flags (`unsafe fn`, `async fn`, `const fn`).
pub is_unsafe: bool,
pub is_async: bool,
pub is_const: bool,
pub where_clauses: Vec<WhereClause>,
}
impl TypeContext {
    /// Creates a context whose scope stack holds only the root `Module` scope.
    ///
    /// Builtin traits are NOT registered here (`init_builtins` is currently a
    /// no-op); call [`TypeContext::register_builtin_traits`] for that.
    pub fn new() -> Self {
        let mut ctx = Self {
            scopes: vec![Scope::new(ScopeKind::Module)],
            types: HashMap::new(),
            traits: HashMap::new(),
            impls: Vec::new(),
            aliases: HashMap::new(),
            functions: HashMap::new(),
            next_def_id: 0,
            current_self_ty: None,
            inherent_methods: HashMap::new(),
            param_trait_bounds: HashMap::new(),
            module_bindings: HashMap::new(),
        };
        ctx.init_builtins();
        ctx
    }

    /// Setup hook run by [`TypeContext::new`]; intentionally empty for now.
    fn init_builtins(&mut self) {}

    /// Registers placeholder trait definitions (no generics, supertraits,
    /// methods or associated types) for common std/serde trait names.
    ///
    /// A name that already resolves via [`TypeContext::lookup_trait_by_name`] is
    /// skipped, so repeated calls never create duplicates.
    pub fn register_builtin_traits(&mut self) {
        const COMMON_TRAITS: &[&str] = &[
            "Default",
            "Display",
            "Debug",
            "Clone",
            "Copy",
            "PartialEq",
            "Eq",
            "PartialOrd",
            "Ord",
            "Hash",
            "Send",
            "Sync",
            "Sized",
            "Drop",
            "Iterator",
            "IntoIterator",
            "FromIterator",
            "From",
            "Into",
            "TryFrom",
            "TryInto",
            "AsRef",
            "AsMut",
            "Deref",
            "DerefMut",
            "Add",
            "Sub",
            "Mul",
            "Div",
            "Rem",
            "Neg",
            "Serialize",
            "Deserialize",
        ];
        for &name in COMMON_TRAITS {
            if self.lookup_trait_by_name(name).is_some() {
                continue;
            }
            let def_id = self.fresh_def_id();
            self.register_trait(TraitDef {
                def_id,
                name: Arc::from(name),
                generics: Vec::new(),
                supertraits: Vec::new(),
                methods: Vec::new(),
                assoc_types: Vec::new(),
            });
        }
    }

    /// Returns a new, never-before-issued `DefId` and advances the counter.
    ///
    /// NOTE(review): the first `DefId::new` argument is always 0 — presumably
    /// the local crate index; confirm.
    pub fn fresh_def_id(&mut self) -> DefId {
        let id = DefId::new(0, self.next_def_id);
        self.next_def_id += 1;
        id
    }

    /// Opens a new innermost scope of the given kind.
    pub fn push_scope(&mut self, kind: ScopeKind) {
        self.scopes.push(Scope::new(kind));
    }

    /// Closes the innermost scope. The root module scope is never popped, so an
    /// unbalanced extra pop is silently ignored.
    pub fn pop_scope(&mut self) {
        if self.scopes.len() > 1 {
            self.scopes.pop();
        }
    }

    /// Kind of the innermost scope (`Module` if the stack were somehow empty).
    pub fn current_scope_kind(&self) -> ScopeKind {
        self.scopes
            .last()
            .map(|s| s.kind)
            .unwrap_or(ScopeKind::Module)
    }

    /// Whether a `Loop` scope encloses the current position *within the current
    /// function*.
    ///
    /// The search walks outward and stops at the nearest `Function` scope:
    /// `break`/`continue` cannot target a loop across a function (or closure)
    /// boundary, so a loop outside that boundary must not count.
    pub fn in_loop(&self) -> bool {
        for scope in self.scopes.iter().rev() {
            match scope.kind {
                ScopeKind::Loop => return true,
                ScopeKind::Function => return false,
                ScopeKind::Module | ScopeKind::Block | ScopeKind::Match => {}
            }
        }
        false
    }

    /// Whether any scope on the stack is a `Function` scope.
    pub fn in_function(&self) -> bool {
        self.scopes.iter().any(|s| s.kind == ScopeKind::Function)
    }

    /// Sets (or clears, with `None`) the type that `Self` refers to.
    pub fn set_self_ty(&mut self, ty: Option<Ty>) {
        self.current_self_ty = ty;
    }

    /// The type `Self` currently refers to, if any.
    pub fn get_self_ty(&self) -> Option<&Ty> {
        self.current_self_ty.as_ref()
    }

    /// Binds `name` to a monomorphic type in the innermost scope, shadowing any
    /// outer binding of the same name.
    pub fn define_var(&mut self, name: impl Into<Arc<str>>, ty: Ty) {
        self.define_var_scheme(name, TypeScheme::mono(ty));
    }

    /// Binds `name` to a (possibly polymorphic) scheme in the innermost scope.
    pub fn define_var_scheme(&mut self, name: impl Into<Arc<str>>, scheme: TypeScheme) {
        let name = name.into();
        if let Some(scope) = self.scopes.last_mut() {
            scope.bindings.insert(name, scheme);
        }
    }

    /// Resolves `name` in the innermost scope that binds it and returns a fresh
    /// instantiation of its scheme.
    pub fn lookup_var(&self, name: &str) -> Option<Ty> {
        self.lookup_var_scheme(name).map(|scheme| scheme.instantiate())
    }

    /// Resolves `name` in the innermost scope that binds it, without instantiating.
    pub fn lookup_var_scheme(&self, name: &str) -> Option<&TypeScheme> {
        self.scopes
            .iter()
            .rev()
            .find_map(|scope| scope.bindings.get(name))
    }

    /// Declares a generic type parameter in the innermost scope.
    pub fn define_type_param(&mut self, name: impl Into<Arc<str>>, ty: Ty) {
        let name = name.into();
        if let Some(scope) = self.scopes.last_mut() {
            scope.type_params.insert(name, ty);
        }
    }

    /// Resolves a generic type parameter from the innermost scope outward.
    pub fn lookup_type_param(&self, name: &str) -> Option<&Ty> {
        self.scopes
            .iter()
            .rev()
            .find_map(|scope| scope.type_params.get(name))
    }

    /// Registers (or replaces) a type definition under its `def_id`.
    pub fn register_type(&mut self, def: TypeDef) {
        self.types.insert(def.def_id, def);
    }

    /// Looks up a type definition by id.
    pub fn lookup_type(&self, def_id: DefId) -> Option<&TypeDef> {
        self.types.get(&def_id)
    }

    /// Looks up a type definition by name with a linear scan. If several types
    /// share the name, the one with the smallest `DefId` wins, so the result
    /// does not depend on `HashMap` iteration order.
    pub fn lookup_type_by_name(&self, name: &str) -> Option<&TypeDef> {
        self.types
            .iter()
            .filter(|(_, t)| t.name.as_ref() == name)
            .min_by_key(|(def_id, _)| *def_id)
            .map(|(_, t)| t)
    }

    /// Registers (or replaces) a trait definition under its `def_id`.
    pub fn register_trait(&mut self, def: TraitDef) {
        self.traits.insert(def.def_id, def);
    }

    /// Looks up a trait definition by id.
    pub fn lookup_trait(&self, def_id: DefId) -> Option<&TraitDef> {
        self.traits.get(&def_id)
    }

    /// Looks up a trait by name with a linear scan; ties go to the smallest `DefId`.
    pub fn lookup_trait_by_name(&self, name: &str) -> Option<&TraitDef> {
        self.traits
            .iter()
            .filter(|(_, t)| t.name.as_ref() == name)
            .min_by_key(|(def_id, _)| *def_id)
            .map(|(_, t)| t)
    }

    /// Records a trait impl.
    pub fn register_impl(&mut self, impl_: TraitImpl) {
        self.impls.push(impl_);
    }

    /// Returns every registered impl of `trait_id`.
    ///
    /// NOTE(review): `_self_ty` is currently ignored — impls for all self types
    /// are returned and callers must filter themselves.
    pub fn find_impls(&self, trait_id: DefId, _self_ty: &Ty) -> Vec<&TraitImpl> {
        self.impls
            .iter()
            .filter(|impl_| impl_.trait_id == trait_id)
            .collect()
    }

    /// Finds a trait method callable on `self_ty` through some registered impl.
    ///
    /// Only impls whose self type is the same ADT (by `DefId`) as `self_ty`
    /// are considered. A method matches if the impl defines it, or if the
    /// trait marks it `has_default` (default bodies are inherited by every
    /// impl that does not override them). Returns the trait's declaration.
    pub fn lookup_trait_method(&self, self_ty: &Ty, method_name: &str) -> Option<&TraitMethod> {
        self.impls
            .iter()
            .filter(|impl_| {
                matches!(
                    (&impl_.self_ty.kind, &self_ty.kind),
                    (TyKind::Adt(d1, _), TyKind::Adt(d2, _)) if d1 == d2
                )
            })
            .find_map(|impl_| {
                let trait_def = self.traits.get(&impl_.trait_id)?;
                let method = Self::trait_method_named(trait_def, method_name)?;
                if method.has_default || impl_.methods.contains_key(method_name) {
                    Some(method)
                } else {
                    None
                }
            })
    }

    /// Records an inherent method of the type `type_def_id`, replacing any
    /// earlier method with the same name on that type.
    pub fn register_inherent_method(
        &mut self,
        type_def_id: DefId,
        method_name: Arc<str>,
        sig: FnSig,
    ) {
        self.inherent_methods.insert(
            (type_def_id, method_name.clone()),
            TraitMethod {
                name: method_name,
                sig,
                has_default: false,
            },
        );
    }

    /// Looks up an inherent method of the type `type_def_id`.
    ///
    /// NOTE(review): the map's owned `(DefId, Arc<str>)` key forces one `Arc`
    /// allocation per call; a nested `DefId -> name -> method` map would allow
    /// borrowed lookups if this becomes hot.
    pub fn lookup_inherent_method(
        &self,
        type_def_id: DefId,
        method_name: &str,
    ) -> Option<&TraitMethod> {
        self.inherent_methods
            .get(&(type_def_id, Arc::from(method_name)))
    }

    /// Resolves `type_name` via [`TypeContext::lookup_type_by_name`], then looks up
    /// its inherent method `method_name`.
    pub fn lookup_inherent_method_by_name(
        &self,
        type_name: &str,
        method_name: &str,
    ) -> Option<&TraitMethod> {
        let type_def = self.lookup_type_by_name(type_name)?;
        self.lookup_inherent_method(type_def.def_id, method_name)
    }

    /// Records the trait names bounding generic parameter `param_name`,
    /// replacing any earlier bounds for it.
    pub fn register_param_bounds(&mut self, param_name: Arc<str>, trait_names: Vec<Arc<str>>) {
        self.param_trait_bounds.insert(param_name, trait_names);
    }

    /// Forgets all recorded generic-parameter bounds.
    pub fn clear_param_bounds(&mut self) {
        self.param_trait_bounds.clear();
    }

    /// Finds a method callable on a value of generic type `param_name` through
    /// its recorded trait bounds.
    ///
    /// Direct bounds are checked first, in the order they were registered.
    /// Their supertraits are then searched breadth-first; e.g. `T: Ord`
    /// exposes `PartialOrd` methods when `Ord` lists `PartialOrd` as a
    /// supertrait. A visited set guards against cyclic supertrait graphs.
    pub fn lookup_param_method(&self, param_name: &str, method_name: &str) -> Option<&TraitMethod> {
        let bounds = self.param_trait_bounds.get(param_name)?;
        let direct: Vec<&TraitDef> = bounds
            .iter()
            .filter_map(|trait_name| self.lookup_trait_by_name(trait_name))
            .collect();

        if let Some(method) = direct
            .iter()
            .copied()
            .find_map(|trait_def| Self::trait_method_named(trait_def, method_name))
        {
            return Some(method);
        }

        let mut visited: std::collections::HashSet<DefId> =
            direct.iter().map(|trait_def| trait_def.def_id).collect();
        let mut queue: std::collections::VecDeque<&TraitDef> = direct.into_iter().collect();
        while let Some(trait_def) = queue.pop_front() {
            for bound in &trait_def.supertraits {
                if !visited.insert(bound.trait_id) {
                    continue;
                }
                if let Some(super_def) = self.traits.get(&bound.trait_id) {
                    if let Some(method) = Self::trait_method_named(super_def, method_name) {
                        return Some(method);
                    }
                    queue.push_back(super_def);
                }
            }
        }
        None
    }

    /// Stores the binding table for module `name`, replacing any earlier one.
    pub fn register_module_bindings(
        &mut self,
        name: Arc<str>,
        bindings: HashMap<Arc<str>, TypeScheme>,
    ) {
        self.module_bindings.insert(name, bindings);
    }

    /// Looks up `name` in module `module`'s bindings and returns a fresh
    /// instantiation of its scheme.
    pub fn lookup_module_binding(&self, module: &str, name: &str) -> Option<Ty> {
        self.module_bindings
            .get(module)?
            .get(name)
            .map(|scheme| scheme.instantiate())
    }

    /// Returns a copy of the innermost scope's value bindings.
    pub fn current_scope_bindings(&self) -> HashMap<Arc<str>, TypeScheme> {
        self.scopes
            .last()
            .map(|s| s.bindings.clone())
            .unwrap_or_default()
    }

    /// Returns a copy of module `module`'s bindings, if registered.
    pub fn clone_module_bindings(&self, module: &str) -> Option<HashMap<Arc<str>, TypeScheme>> {
        self.module_bindings.get(module).cloned()
    }

    /// Registers (or replaces) a type alias under its `def_id`.
    pub fn register_alias(&mut self, alias: TypeAlias) {
        self.aliases.insert(alias.def_id, alias);
    }

    /// Looks up a type alias by id.
    pub fn lookup_alias(&self, def_id: DefId) -> Option<&TypeAlias> {
        self.aliases.get(&def_id)
    }

    /// Registers (or replaces) a function signature.
    pub fn register_function(&mut self, def_id: DefId, sig: FnSig) {
        self.functions.insert(def_id, sig);
    }

    /// Looks up a function signature by id.
    pub fn lookup_function(&self, def_id: DefId) -> Option<&FnSig> {
        self.functions.get(&def_id)
    }

    /// Generalizes `ty` into a scheme by quantifying every type variable that is
    /// free in `ty` but not free in the scope stack (let-polymorphism).
    ///
    /// Only `scopes` form the environment; `module_bindings` are not consulted.
    pub fn generalize(&self, ty: &Ty) -> TypeScheme {
        let free_in_ty = ty.free_vars();
        let free_in_env = self.free_vars_in_env();
        let vars: Vec<_> = free_in_ty.difference(&free_in_env).cloned().collect();
        TypeScheme::poly(vars, ty.clone())
    }

    /// Type variables free in any scope binding (those not quantified by the
    /// binding's own scheme).
    fn free_vars_in_env(&self) -> std::collections::HashSet<TyVarId> {
        self.scopes
            .iter()
            .flat_map(|scope| scope.bindings.values())
            .flat_map(|scheme| {
                scheme
                    .ty
                    .free_vars()
                    .into_iter()
                    .filter(move |var| !scheme.vars.contains(var))
            })
            .collect()
    }

    /// The first method named `method_name` declared directly in `trait_def`.
    fn trait_method_named<'a>(trait_def: &'a TraitDef, method_name: &str) -> Option<&'a TraitMethod> {
        trait_def
            .methods
            .iter()
            .find(|m| m.name.as_ref() == method_name)
    }
}
impl Default for TypeContext {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Bindings from outer scopes stay visible in inner scopes, and inner
    /// bindings disappear once their scope is popped.
    #[test]
    fn test_scope_management() {
        let mut cx = TypeContext::new();
        cx.define_var("x", Ty::int(IntTy::I32));
        assert!(cx.lookup_var("x").is_some());

        cx.push_scope(ScopeKind::Block);
        cx.define_var("y", Ty::bool());
        for name in &["x", "y"] {
            assert!(cx.lookup_var(name).is_some(), "{} should be visible", name);
        }

        cx.pop_scope();
        assert!(cx.lookup_var("x").is_some());
        assert!(cx.lookup_var("y").is_none());
    }

    /// With an empty environment, the single variable of `fn(a) -> a` is quantified.
    #[test]
    fn test_generalization() {
        let cx = TypeContext::new();
        let tv = TyVarId::fresh();
        let identity = Ty::function(vec![Ty::var(tv)], Ty::var(tv));
        assert_eq!(cx.generalize(&identity).vars.len(), 1);
    }
}