use crate::cdsl::ast::{Def, DefIndex, DefPool, Var, VarIndex, VarPool};
use crate::cdsl::typevar::{DerivedFunc, TypeSet, TypeVar};
use std::collections::{HashMap, HashSet};
use std::iter::FromIterator;
/// A typing requirement recorded during inference that unification alone did
/// not discharge.
///
/// Variant order is significant for the derived `Hash`; do not reorder.
#[derive(PartialEq, Eq, Hash, Debug)]
pub(crate) enum Constraint {
    /// Width relation between two type variables (first wider than or equal
    /// to the second, per the variant name).
    WiderOrEq(TypeVar, TypeVar),
    /// The two type variables must denote the same type.
    Eq(TypeVar, TypeVar),
    /// The type variable must be drawn from the given type set.
    InTypeset(TypeVar, TypeSet),
}
impl Constraint {
    /// Builds a copy of this constraint in which every type variable has been
    /// mapped through `func`. Type-set payloads are cloned unchanged.
    fn translate_with<F: Fn(&TypeVar) -> TypeVar>(&self, func: F) -> Constraint {
        match self {
            Constraint::WiderOrEq(left, right) => Constraint::WiderOrEq(func(left), func(right)),
            Constraint::Eq(left, right) => Constraint::Eq(func(left), func(right)),
            Constraint::InTypeset(var, set) => Constraint::InTypeset(func(var), set.clone()),
        }
    }

    /// Rewrites the constraint's type variables through `substitute`, using
    /// the instruction-formal -> fresh-copy mapping built during inference.
    fn translate_with_map(
        &self,
        original_to_own_typevar: &HashMap<&TypeVar, TypeVar>,
    ) -> Constraint {
        let mapper = |var: &TypeVar| substitute(original_to_own_typevar, var);
        self.translate_with(mapper)
    }

    /// Rewrites the constraint's type variables to their current
    /// representatives in `type_env`.
    fn translate_with_env(&self, type_env: &TypeEnvironment) -> Constraint {
        let mapper = |var: &TypeVar| type_env.get_equivalent(var);
        self.translate_with(mapper)
    }

    /// Returns true when this constraint can be dropped because its outcome
    /// is already decided.
    ///
    /// - `WiderOrEq`: identical sides; or, comparing the two type sets,
    ///   `is_wider_or_equal` or `is_narrower` holds, or the lane sets are
    ///   disjoint; or every type variable is concrete.
    /// - `Eq`: identical sides, or every type variable is concrete.
    /// - `InTypeset`: the type variable is concrete.
    fn is_trivial(&self) -> bool {
        match self {
            Constraint::Eq(left, right) => left == right || self.is_concrete(),
            Constraint::InTypeset(..) => self.is_concrete(),
            Constraint::WiderOrEq(left, right) => {
                if left == right {
                    return true;
                }
                let (left_set, right_set) = (left.get_typeset(), right.get_typeset());
                left_set.is_wider_or_equal(&right_set)
                    || left_set.is_narrower(&right_set)
                    || (&left_set.lanes & &right_set.lanes).is_empty()
                    || self.is_concrete()
            }
        }
    }

    /// True when every type variable mentioned by the constraint denotes a
    /// single type.
    fn is_concrete(&self) -> bool {
        self.typevar_args()
            .into_iter()
            .all(|var| var.singleton_type().is_some())
    }

    /// The type variables this constraint mentions, in declaration order.
    fn typevar_args(&self) -> Vec<&TypeVar> {
        match self {
            Constraint::WiderOrEq(left, right) | Constraint::Eq(left, right) => vec![left, right],
            Constraint::InTypeset(var, _) => vec![var],
        }
    }
}
/// Priority of a type variable during unification. `unify` maps the
/// lower-ranked side onto the higher-ranked one, so higher ranks are kept as
/// representatives. The explicit discriminants are the values compared via
/// `as u8`.
#[derive(Copy, Clone)]
enum TypeEnvRank {
    Internal = 0,
    Temp = 1,
    Output = 2,
    Intermediate = 3,
    Input = 4,
    Singleton = 5,
}
/// State threaded through type inference of a transformation.
pub(crate) struct TypeEnvironment {
    /// Constraints collected during inference.
    pub constraints: Vec<Constraint>,
    /// Maps a free type variable to the type variable it was unified with;
    /// chains are followed by `get_equivalent`.
    equivalency_map: HashMap<TypeVar, TypeVar>,
    /// Rank assigned to each registered variable's type variable.
    ranks: HashMap<TypeVar, TypeEnvRank>,
    /// Every variable registered with this environment.
    vars: HashSet<VarIndex>,
}
impl TypeEnvironment {
    /// Creates an empty environment: no variables, ranks, equivalences or
    /// constraints.
    fn new() -> Self {
        TypeEnvironment {
            vars: HashSet::new(),
            ranks: HashMap::new(),
            equivalency_map: HashMap::new(),
            constraints: Vec::new(),
        }
    }

    /// Registers `var_index` with this environment and ranks its type
    /// variable according to the variable's role.
    ///
    /// # Panics
    /// Panics if the variable is not input, intermediate, output or temp.
    fn register(&mut self, var_index: VarIndex, var: &mut Var) {
        self.vars.insert(var_index);
        let rank = if var.is_input() {
            TypeEnvRank::Input
        } else if var.is_intermediate() {
            TypeEnvRank::Intermediate
        } else if var.is_output() {
            TypeEnvRank::Output
        } else {
            assert!(var.is_temp());
            TypeEnvRank::Temp
        };
        self.ranks.insert(var.get_or_create_typevar(), rank);
    }

    /// Appends `constraint` unless an equal constraint is already recorded.
    ///
    /// # Panics
    /// An `InTypeset` constraint must target a non-derived type variable whose
    /// name starts with `typeof_`.
    fn add_constraint(&mut self, constraint: Constraint) {
        if self.constraints.contains(&constraint) {
            return;
        }
        if let Constraint::InTypeset(tv, _) = &constraint {
            assert!(
                tv.base.is_none(),
                "type variable is {:?}, while expecting none",
                tv
            );
            assert!(
                tv.name.starts_with("typeof_"),
                "Name \"{}\" should start with \"typeof_\"",
                tv.name
            );
        }
        self.constraints.push(constraint);
    }

    /// Returns the representative of `tv`: follows `equivalency_map` to the
    /// end of the chain and, for a derived type variable, resolves its base
    /// recursively before re-applying the derivation.
    pub fn get_equivalent(&self, tv: &TypeVar) -> TypeVar {
        let mut tv = tv;
        while let Some(found) = self.equivalency_map.get(tv) {
            tv = found;
        }
        match &tv.base {
            Some(parent) => self
                .get_equivalent(&parent.type_var)
                .derived(parent.derived_func),
            None => tv.clone(),
        }
    }

    /// Numeric rank of `tv` (higher wins in `unify`).
    ///
    /// A derived type variable is ranked by its free type variable. Unranked
    /// type variables fall back to `Singleton` if `tv` denotes a single type,
    /// else `Internal`.
    ///
    /// # Panics
    /// Panics if an unranked free type variable is named `typeof_*`; those
    /// belong to real variables and must have been registered.
    fn rank(&self, tv: &TypeVar) -> u8 {
        let actual_tv = match tv.base {
            Some(_) => tv.free_typevar(),
            None => Some(tv.clone()),
        };
        let explicit_rank = actual_tv.and_then(|actual_tv| match self.ranks.get(&actual_tv) {
            Some(rank) => Some(*rank),
            None => {
                // Format-args form: a non-literal `format!` message is
                // rejected by the `non_fmt_panics` rules (error in 2021).
                assert!(
                    !actual_tv.name.starts_with("typeof_"),
                    "variable {} should be explicitly ranked",
                    actual_tv.name
                );
                None
            }
        });
        let rank = match explicit_rank {
            Some(rank) => rank,
            None if tv.singleton_type().is_some() => TypeEnvRank::Singleton,
            None => TypeEnvRank::Internal,
        };
        rank as u8
    }

    /// Records that the free, currently-unmapped type variable `tv1` is
    /// equivalent to `tv2`.
    ///
    /// # Panics
    /// Panics if `tv1` is derived, already mapped, or if `tv2` is derived
    /// from `tv1` (which would create a cycle).
    fn record_equivalent(&mut self, tv1: TypeVar, tv2: TypeVar) {
        assert!(tv1.base.is_none());
        assert!(self.get_equivalent(&tv1) == tv1);
        if let Some(tv2_base) = &tv2.base {
            assert!(self.get_equivalent(&tv2_base.type_var) != tv1);
        }
        self.equivalency_map.insert(tv1, tv2);
    }

    /// Distinct free type variables reachable from the representatives of
    /// all mapped type variables and all registered variables. Order is
    /// unspecified.
    pub fn free_typevars(&self, var_pool: &mut VarPool) -> Vec<TypeVar> {
        let mut typevars: Vec<TypeVar> = self.equivalency_map.keys().cloned().collect();
        typevars.extend(
            self.vars
                .iter()
                .map(|&var_index| var_pool.get_mut(var_index).get_or_create_typevar()),
        );
        let set: HashSet<TypeVar> = typevars
            .iter()
            .filter_map(|tv| self.get_equivalent(tv).free_typevar())
            .collect();
        Vec::from_iter(set)
    }

    /// Collapses linear chains that start at a free type variable not owned
    /// by any registered variable.
    ///
    /// Each type variable's "children" are the mapped values derived from
    /// it, plus the keys mapped onto it. While a root has exactly one child,
    /// the child's mapping is dropped and the walk continues from the child.
    fn normalize(&mut self, var_pool: &mut VarPool) {
        let source_tvs: HashSet<TypeVar> = self
            .vars
            .iter()
            .map(|&var_index| var_pool.get_mut(var_index).get_or_create_typevar())
            .collect();

        let mut children: HashMap<TypeVar, HashSet<TypeVar>> = HashMap::new();
        for type_var in self.equivalency_map.values() {
            if type_var.base.is_none() {
                continue;
            }
            if let Some(parent_tv) = type_var.free_typevar() {
                children
                    .entry(parent_tv)
                    .or_insert_with(HashSet::new)
                    .insert(type_var.clone());
            }
        }
        for (equivalent_tv, canon_tv) in self.equivalency_map.iter() {
            children
                .entry(canon_tv.clone())
                .or_insert_with(HashSet::new)
                .insert(equivalent_tv.clone());
        }

        for free_root in self.free_typevars(var_pool) {
            let mut root = &free_root;
            while !source_tvs.contains(root) {
                let child = match children.get(root) {
                    Some(kids) if kids.len() == 1 => kids.iter().next().unwrap(),
                    _ => break,
                };
                // A child found via the "derived from root" edges is a map
                // value, not a key, so it may have no mapping to drop.
                // Indexing unconditionally would panic here.
                if let Some(mapped) = self.equivalency_map.remove(child) {
                    assert_eq!(mapped, *root);
                }
                root = child;
            }
        }
    }

    /// Consumes the environment and returns a clean copy that only mentions
    /// type variables of registered variables.
    ///
    /// Each mapping points straight to a representative. Constraints are
    /// translated to representatives; trivial ones and duplicates are
    /// dropped. The survivors are kept in first-seen order, so the result
    /// does not depend on hash iteration order.
    ///
    /// # Panics
    /// Panics if a representative or surviving constraint refers to a free
    /// type variable outside the registered variables.
    fn extract(self, var_pool: &mut VarPool) -> TypeEnvironment {
        let vars_tv: HashSet<TypeVar> = self
            .vars
            .iter()
            .map(|&var_index| var_pool.get_mut(var_index).get_or_create_typevar())
            .collect();

        let mut new_equivalency_map: HashMap<TypeVar, TypeVar> = HashMap::new();
        for tv in &vars_tv {
            let canon_tv = self.get_equivalent(tv);
            if let Some(canon_free_tv) = canon_tv.free_typevar() {
                assert!(vars_tv.contains(&canon_free_tv));
            }
            if *tv != canon_tv {
                new_equivalency_map.insert(tv.clone(), canon_tv);
            }
        }

        let mut new_constraints: Vec<Constraint> = Vec::new();
        for constraint in &self.constraints {
            let constraint = constraint.translate_with_env(&self);
            if constraint.is_trivial() || new_constraints.contains(&constraint) {
                continue;
            }
            for arg in constraint.typevar_args() {
                if let Some(arg_free_tv) = arg.free_typevar() {
                    assert!(vars_tv.contains(&arg_free_tv));
                }
            }
            new_constraints.push(constraint);
        }

        TypeEnvironment {
            vars: self.vars,
            ranks: self.ranks,
            equivalency_map: new_equivalency_map,
            constraints: new_constraints,
        }
    }
}
fn substitute(map: &HashMap<&TypeVar, TypeVar>, external_type_var: &TypeVar) -> TypeVar {
match map.get(&external_type_var) {
Some(own_type_var) => own_type_var.clone(),
None => match &external_type_var.base {
Some(parent) => {
let parent_substitute = substitute(map, &parent.type_var);
TypeVar::derived(&parent_substitute, parent.derived_func)
}
None => external_type_var.clone(),
},
}
}
/// Rewrites a derived type variable into a canonical chain of derivations.
///
/// Considering the two outermost derivations `outer(inner(x))`:
/// - inverse pairs (half/double width, half/double vector) cancel to `x`;
/// - a width derivation over a vector derivation is swapped to
///   `inner(outer(x))`, moving the vector derivation outward.
///
/// Otherwise the inner part is canonicalized recursively and `outer` is
/// re-applied. Free type variables are returned unchanged.
fn canonicalize_derivations(tv: TypeVar) -> TypeVar {
    let (outer_func, inner_tv) = match &tv.base {
        Some(base) => (base.derived_func, &base.type_var),
        None => return tv,
    };

    if let Some(inner_base) = &inner_tv.base {
        let inner_func = inner_base.derived_func;
        let innermost = &inner_base.type_var;
        match (outer_func, inner_func) {
            (DerivedFunc::HalfWidth, DerivedFunc::DoubleWidth)
            | (DerivedFunc::DoubleWidth, DerivedFunc::HalfWidth)
            | (DerivedFunc::HalfVector, DerivedFunc::DoubleVector)
            | (DerivedFunc::DoubleVector, DerivedFunc::HalfVector) => {
                return canonicalize_derivations(innermost.clone());
            }
            (DerivedFunc::HalfWidth, DerivedFunc::HalfVector)
            | (DerivedFunc::HalfWidth, DerivedFunc::DoubleVector)
            | (DerivedFunc::DoubleWidth, DerivedFunc::DoubleVector)
            | (DerivedFunc::DoubleWidth, DerivedFunc::HalfVector) => {
                let swapped = innermost.derived(outer_func).derived(inner_func);
                return canonicalize_derivations(swapped);
            }
            _ => {}
        }
    }

    canonicalize_derivations(inner_tv.clone()).derived(outer_func)
}
/// Repeatedly applies `tv2.constrain_types(tv1)` until `tv1`'s type set stops
/// changing. Then applies `tv1.constrain_types(tv2)` once and asserts that
/// this leaves `tv2`'s type set untouched.
fn constrain_fixpoint(tv1: &TypeVar, tv2: &TypeVar) {
    let mut previous = tv1.get_typeset();
    loop {
        tv2.constrain_types(tv1.clone());
        let current = tv1.get_typeset();
        if current == previous {
            break;
        }
        previous = current;
    }

    let tv2_before = tv2.get_typeset();
    tv1.constrain_types(tv2.clone());
    assert!(tv2_before == tv2.get_typeset());
}
/// Unifies `tv1` with `tv2` inside `type_env`.
///
/// Both sides are first resolved to canonical representatives. The
/// lower-ranked side is always processed as the first argument. After both
/// type sets are narrowed, a free first side is recorded as equivalent to the
/// second. A derived first side with an invertible derivation is unified
/// through the inverse. Anything else becomes an `Eq` constraint.
///
/// # Errors
/// Returns an error message if narrowing leaves either type set empty.
fn unify(tv1: &TypeVar, tv2: &TypeVar, type_env: &mut TypeEnvironment) -> Result<(), String> {
    let low = canonicalize_derivations(type_env.get_equivalent(tv1));
    let high = canonicalize_derivations(type_env.get_equivalent(tv2));
    if low == high {
        return Ok(());
    }
    if type_env.rank(&high) < type_env.rank(&low) {
        // Recurse rather than swap locally so both sides are re-resolved.
        return unify(&high, &low, type_env);
    }

    constrain_fixpoint(&low, &high);
    let emptied = low.get_typeset().size() == 0 || high.get_typeset().size() == 0;
    if emptied {
        return Err(format!(
            "Error: empty type created when unifying {} and {}",
            low.name, high.name
        ));
    }

    match &low.base {
        None => {
            type_env.record_equivalent(low, high);
            Ok(())
        }
        Some(base) => match base.derived_func.inverse() {
            Some(inverse) => unify(&base.type_var, &high.derived(inverse), type_env),
            None => {
                type_env.add_constraint(Constraint::Eq(low, high));
                Ok(())
            }
        },
    }
}
/// Runs type inference for one definition on top of `type_env`.
///
/// Each free type variable of the instruction signature gets a fresh copy
/// named `own_N`, with `N` taken from `*last_type_index` and then bumped.
/// Explicitly bound value types replace the first copies with singletons.
/// The actual variables are registered, then unified pairwise with the
/// formals, control type variable first. Finally the instruction's own
/// constraints are added.
///
/// # Errors
/// Returns a message naming the pair that failed to unify.
fn infer_definition(
    def: &Def,
    var_pool: &mut VarPool,
    type_env: TypeEnvironment,
    last_type_index: &mut usize,
) -> Result<TypeEnvironment, String> {
    let mut env = type_env;
    let apply = &def.apply;
    let inst = &apply.inst;

    let free_formal_tvs = inst.all_typevars();
    let mut own_tvs: HashMap<&TypeVar, TypeVar> = HashMap::new();
    for &formal in free_formal_tvs.iter() {
        let fresh = TypeVar::copy_from(formal, format!("own_{}", last_type_index));
        let previous = own_tvs.insert(formal, fresh);
        assert!(previous.is_none());
        *last_type_index += 1;
    }

    for (index, value_type) in apply.value_types.iter().enumerate() {
        let bound = TypeVar::new_singleton(value_type.clone());
        let replaced = own_tvs.insert(free_formal_tvs[index], bound);
        assert!(replaced.is_some());
    }

    let mut formal_tvs: Vec<_> = inst
        .value_results
        .iter()
        .map(|&i| substitute(&own_tvs, inst.operands_out[i].type_var().unwrap()))
        .chain(
            inst.value_opnums
                .iter()
                .map(|&i| substitute(&own_tvs, inst.operands_in[i].type_var().unwrap())),
        )
        .collect();

    let actual_vars: Vec<_> = inst
        .value_results
        .iter()
        .map(|&i| def.defined_vars[i])
        .chain(inst.value_opnums.iter().map(|&i| apply.args[i].unwrap_var()))
        .collect();

    let mut actual_tvs: Vec<_> = actual_vars
        .into_iter()
        .map(|var_index| {
            let var = var_pool.get_mut(var_index);
            env.register(var_index, var);
            var.get_or_create_typevar()
        })
        .collect();

    // The control type variable is unified first.
    if let Some(poly) = &inst.polymorphic_info {
        let own_ctrl_tv = &own_tvs[&poly.ctrl_typevar];
        let ctrl_index = formal_tvs
            .iter()
            .position(|candidate| candidate == own_ctrl_tv)
            .unwrap();
        formal_tvs.swap(0, ctrl_index);
        actual_tvs.swap(0, ctrl_index);
    }

    for (actual_tv, formal_tv) in actual_tvs.iter().zip(formal_tvs.iter()) {
        unify(actual_tv, formal_tv, &mut env).map_err(|msg| {
            format!(
                "fail ti on {} <: {}: {}",
                actual_tv.name, formal_tv.name, msg
            )
        })?;
    }

    for constraint in inst.constraints.iter() {
        env.add_constraint(constraint.translate_with_map(&own_tvs));
    }
    Ok(env)
}
/// Infers types for a transformation from `src` to the `dst` definitions.
///
/// The source definition is inferred first and its variables' type sets are
/// snapshotted. After every destination definition is inferred, each source
/// variable whose type set shrank gets an `InTypeset` constraint. The
/// environment is then normalized and extracted.
///
/// # Errors
/// Failures are prefixed with `In src pattern: ` or with the destination
/// line index.
///
/// # Panics
/// Panics if a source variable's type set grew during destination inference.
pub(crate) fn infer_transform(
    src: DefIndex,
    dst: &[DefIndex],
    def_pool: &DefPool,
    var_pool: &mut VarPool,
) -> Result<TypeEnvironment, String> {
    let mut last_type_index = 0;
    let mut env = infer_definition(
        def_pool.get(src),
        var_pool,
        TypeEnvironment::new(),
        &mut last_type_index,
    )
    .map_err(|err| format!("In src pattern: {}", err))?;

    let mut src_typesets = Vec::new();
    for &var_index in env.vars.iter() {
        let var = var_pool.get_mut(var_index);
        let canon = env.get_equivalent(&var.get_or_create_typevar());
        src_typesets.push((var_index, canon.get_typeset()));
    }

    for (line, &def_index) in dst.iter().enumerate() {
        env = infer_definition(def_pool.get(def_index), var_pool, env, &mut last_type_index)
            .map_err(|err| format!("line {}: {}", line, err))?;
    }

    for (var_index, src_typeset) in src_typesets {
        let var = var_pool.get(var_index);
        if var.has_free_typevar() {
            let tv = env.get_equivalent(&var.get_typevar().unwrap());
            let new_typeset = tv.get_typeset();
            assert!(
                new_typeset.is_subset(&src_typeset),
                "type sets can only get narrower"
            );
            if new_typeset != src_typeset {
                env.add_constraint(Constraint::InTypeset(tv, new_typeset));
            }
        }
    }

    env.normalize(var_pool);
    Ok(env.extract(var_pool))
}