#include "typeInference.h"
#include <stdbool.h>
#include "bounded.h"
#include "limitations.h"
#include "simplicity_alloc.h"
#include "simplicity_assert.h"
/* TYPE_DAG_LEN_MAX bounds the length of the 'type_dag' array allocated by
 * 'rustsimplicity_0_6_mallocTypeInference'.  That array has 1 + bindings_used
 * entries, where the initial bindings are at most NUMBER_OF_TYPENAMES_MAX - 1
 * and at most 4 more bindings are added per DAG node.
 * The static_assert guarantees that computing the macro cannot overflow size_t.
 */
static_assert(DAG_LEN_MAX <= (SIZE_MAX - NUMBER_OF_TYPENAMES_MAX) / 4, "TYPE_DAG_LEN_MAX doesn't fit in size_t.");
#define TYPE_DAG_LEN_MAX (NUMBER_OF_TYPENAMES_MAX + 4*DAG_LEN_MAX)
/* The unification variables for one DAG node's arrow type:
 * 'source' stands for the node's source type, 'target' for its target type.
 */
typedef struct unification_arrow {
  unification_var source;
  unification_var target;
} unification_arrow;
/* Return the root (representative) of alpha's union-find class.
 *
 * Each visited node's parent link is redirected to its grandparent
 * (path halving), which shortens later lookups on the same chain.
 */
static unification_var* findRoot(unification_var* alpha) {
  for (; alpha->parent; alpha = alpha->parent) {
    unification_var* grandparent = alpha->parent->parent;
    if (grandparent) alpha->parent = grandparent;
  }
  return alpha;
}
/* Merge the type constructor described by 'bound' into the root variable 'alpha'.
 *
 * 'alpha' must be the root of its union-find class; both callers pass the
 * result of 'findRoot'.  '*cont' is the current, non-NULL node of the
 * pending unification list.
 *
 * - If 'alpha' is free, '*bound' is copied into it and '*cont' advances.
 * - If both are bound but with different kinds, returns false (unification
 *   failure).
 * - For ONE there are no arguments to check; '*cont' advances.
 * - Otherwise the argument pairs still need unifying.  The current node is
 *   rewritten in place to hold the arg[0] pair, and 'bound->cont' is spliced
 *   in right after it to hold the arg[1] pair.  '*cont' is deliberately NOT
 *   advanced, so the rewritten node is processed next.  'alpha' keeps its own
 *   binding, so '*bindings_used' is decremented.  The caller must keep 'bound'
 *   alive until the list is processed, because 'bound->cont' is now linked
 *   into it.
 *
 * Returns true unless a kind mismatch was found.
 */
static bool applyBinding_cont(unification_var* alpha, binding* bound, unification_cont** cont, size_t* bindings_used) {
if (!alpha->isBound) {
/* 'alpha' is free: adopt the binding and move on to the next constraint. */
alpha->isBound = true;
alpha->bound = *bound;
*cont = (*cont)->next;
return true;
}
if (&alpha->bound == bound) {
/* Not expected: callers never pass alpha's own binding (asserted). */
rustsimplicity_0_6_assert(false);
*cont = (*cont)->next;
return true;
}
/* Both sides are bound; their type constructors must agree. */
if (alpha->bound.kind != bound->kind) return false;
if (ONE == bound->kind) {
/* ONE has no arguments, so there is nothing further to unify. */
*cont = (*cont)->next;
return true;
} else {
/* Reuse the current node for the arg[0] pair ... */
(*cont)->alpha = alpha->bound.arg[0];
(*cont)->beta = bound->arg[0];
/* ... and splice 'bound->cont' in after it for the arg[1] pair. */
bound->cont = (unification_cont){ .alpha = alpha->bound.arg[1]
, .beta = bound->arg[1]
, .next = (*cont)->next
};
(*cont)->next = &(bound->cont);
/* 'alpha' keeps its own binding, so 'bound' is no longer counted as live. */
rustsimplicity_0_6_assert(0 < *bindings_used);
(*bindings_used)--;
}
return true;
}
/* Solve a linked list of pending equality constraints, starting at 'cont'.
 *
 * Each node asks for 'cont->alpha' and 'cont->beta' to denote the same type.
 * Their union-find roots are merged by rank: the lower-rank root is attached
 * under the other, and a tie bumps the surviving root's rank.  If the
 * attached root was bound, its binding is merged into the surviving root by
 * 'applyBinding_cont', which may push further constraints onto the list.
 *
 * Returns the surviving root chosen while processing the first node.  Returns
 * NULL if a kind mismatch makes the constraints unsatisfiable, or if 'cont'
 * is NULL.  No occurs check is done here; cyclic types are rejected later by
 * 'freeze'.
 */
static unification_var* unify_cont(unification_cont* cont, size_t* bindings_used) {
unification_var* result = NULL;
while (cont) {
unification_var* alpha = findRoot(cont->alpha);
unification_var* beta = findRoot(cont->beta);
if (alpha == beta) {
/* Already in the same class: nothing to do for this constraint. */
cont = cont->next;
} else {
/* Make 'alpha' the root with the larger (or equal) rank. */
if (alpha->rank < beta->rank) {
unification_var* tmp = beta; beta = alpha; alpha = tmp;
}
beta->parent = alpha;
if (beta->isBound) {
/* Merge beta's binding into alpha; 'cont' may be rewritten in place instead of advanced. */
if (!applyBinding_cont(alpha, &beta->bound, &cont, bindings_used)) return NULL;
} else {
cont = cont->next;
}
if (alpha->rank == beta->rank) alpha->rank++;
}
/* Remember the root produced by the first constraint. */
if (!result) result = alpha;
}
return result;
}
/* Constrain alpha's class to have the shape described by 'bound'.
 *
 * If the class is free, 'bound' is copied into its root.  Otherwise the
 * constructors must match, and their arguments are unified pairwise.
 * During this call 'bound->cont' may be used as list storage, so 'bound' may
 * be a temporary that lives at least until the call returns.
 * Returns false if the constraint cannot be satisfied.
 */
static bool applyBinding(unification_var* alpha, binding* bound, size_t* bindings_used) {
  unification_cont head = {0};
  unification_cont* pending = &head;
  if (!applyBinding_cont(findRoot(alpha), bound, &pending, bindings_used)) {
    return false;
  }
  /* A NULL 'pending' means no argument constraints were generated. */
  if (!pending) return true;
  return NULL != unify_cont(pending, bindings_used);
}
/* Unify the types of 'alpha' and 'beta'.
 *
 * Returns a variable of the merged class, or NULL on failure.  A NULL
 * argument yields NULL without doing any work, so the result of a failed
 * inner 'unify' can be passed straight into an outer one.
 */
static unification_var* unify(unification_var* alpha, unification_var* beta, size_t* bindings_used) {
  if (NULL == alpha || NULL == beta) return NULL;
  unification_cont first = { .alpha = alpha, .beta = beta, .next = NULL };
  return unify_cont(&first, bindings_used);
}
/* Upper bound on the number of fresh unification variables that
 * 'typeInference' takes from its 'extra_var' array.  Each disconnect, injl,
 * injr, take and drop node uses one; each node counted in 'case_cnt' uses four.
 */
static size_t max_extra_vars(const combinator_counters* census) {
  return census->disconnect_cnt
       + census->injl_cnt
       + census->injr_cnt
       + census->take_cnt
       + census->drop_cnt
       + 4*census->case_cnt;
}
/* Generate and solve the typing constraints for every node of 'dag'.
 *
 * Nodes are visited in index order.  arrow[i] gets unification variables for
 * node i's source and target types, which are then constrained against the
 * children's arrows according to the node's tag.  Because
 * 'arrow[dag[i].child[k]]' is read, children are presumably stored at lower
 * indices than their parents (topologically sorted dag) -- TODO confirm
 * against the dag decoder.
 *
 * 'extra_var' supplies fresh variables: 4 per CASE/ASSERTL/ASSERTR node and
 * 1 per DISCONNECT/INJL/INJR/TAKE/DROP node, matching 'max_extra_vars'.  This
 * assumes census->case_cnt also counts ASSERTL and ASSERTR nodes -- verify.
 * 'bound_var' holds pre-built type variables, indexed by a JET node's
 * sourceIx/targetIx, a WORD node's targetIx, and 'word256_ix' (used by
 * DISCONNECT).
 * '*bindings_used' is increased by the number of non-ONE bindings created
 * here.  It is decreased (via 'applyBinding_cont') whenever a binding is
 * merged away.
 *
 * Returns SIMPLICITY_ERR_TYPE_INFERENCE_UNIFICATION on the first unification
 * failure, otherwise SIMPLICITY_NO_ERROR.
 */
static simplicity_err typeInference( unification_arrow* arrow, const dag_node* dag, const uint_fast32_t len,
unification_var* extra_var, unification_var* bound_var, size_t word256_ix, size_t* bindings_used
) {
for (uint_fast32_t i = 0; i < len; ++i) {
switch (dag[i].tag) {
/* UNIFY / APPLY_BINDING return the unification error from this function when the constraint fails. */
#define UNIFY(a, b) { if (!unify((a), (b), bindings_used)) return SIMPLICITY_ERR_TYPE_INFERENCE_UNIFICATION; }
#define APPLY_BINDING(a, b) { if (!applyBinding((a), (b), bindings_used)) return SIMPLICITY_ERR_TYPE_INFERENCE_UNIFICATION; }
/* comp s t : A |- C, where s : A |- B and t : B |- C. */
case COMP:
arrow[i] = (unification_arrow){0};
UNIFY(&(arrow[dag[i].child[0]].source), &(arrow[i].source));
UNIFY(&(arrow[dag[i].child[1]].target), &(arrow[i].target));
UNIFY(&(arrow[dag[i].child[0]].target), &(arrow[dag[i].child[1]].source));
break;
/* case s t : (A + B) * C |- D, where s : A * C |- D and t : B * C |- D.
 * extra_var[0..3] stand for A, B, C and A + B.  ASSERTL constrains only
 * child 0 (s); ASSERTR constrains only child 1 (t).
 */
case ASSERTL:
case ASSERTR:
case CASE:
*bindings_used += 2;
extra_var[0] = extra_var[1] = extra_var[2] = (unification_var){0};
extra_var[3] = (unification_var)
{ .isBound = true
, .bound = { .kind = SUM
, .arg = { &extra_var[0], &extra_var[1] }
} };
arrow[i] = (unification_arrow){ .source =
{ .isBound = true
, .bound = { .kind = PRODUCT
, .arg = { &extra_var[3], &extra_var[2] }
} } };
if (ASSERTR != dag[i].tag) {
*bindings_used += 1;
APPLY_BINDING(&(arrow[dag[i].child[0]].source), &((binding)
{ .kind = PRODUCT
, .arg = { &extra_var[0], &extra_var[2] }
}));
UNIFY(&(arrow[dag[i].child[0]].target), &(arrow[i].target));
}
if (ASSERTL != dag[i].tag) {
*bindings_used += 1;
APPLY_BINDING(&(arrow[dag[i].child[1]].source), &((binding)
{ .kind = PRODUCT
, .arg = { &extra_var[1], &extra_var[2] }
}));
UNIFY(&(arrow[dag[i].child[1]].target), &(arrow[i].target));
}
extra_var += 4;
break;
/* pair s t : A |- B * C, where s : A |- B and t : A |- C.
 * If the inner 'unify' fails it yields NULL, which makes UNIFY fail as well.
 */
case PAIR:
*bindings_used += 1;
arrow[i] = (unification_arrow){ .target =
{ .isBound = true
, .bound = { .kind = PRODUCT
, .arg = { &(arrow[dag[i].child[0]].target), &(arrow[dag[i].child[1]].target) }
} } };
UNIFY(unify(&(arrow[dag[i].child[0]].source), &(arrow[dag[i].child[1]].source), bindings_used), &(arrow[i].source));
break;
/* disconnect s t : A |- X * C, where s : W * A |- X * B and t : B |- C.
 * W is bound_var[word256_ix] (the name suggests a 256-bit word type);
 * X is a fresh extra variable.
 */
case DISCONNECT:
*bindings_used += 3;
*extra_var = (unification_var){0};
arrow[i] = (unification_arrow){ .target =
{ .isBound = true
, .bound = { .kind = PRODUCT
, .arg = { extra_var, &(arrow[dag[i].child[1]].target) }
} } };
APPLY_BINDING(&(arrow[dag[i].child[0]].source), &((binding)
{ .kind = PRODUCT
, .arg = { &(bound_var[word256_ix]), &(arrow[i].source) }
}));
APPLY_BINDING(&(arrow[dag[i].child[0]].target), &((binding)
{ .kind = PRODUCT
, .arg = { extra_var, &(arrow[dag[i].child[1]].source) }
}));
extra_var++;
break;
/* injl t : A |- B + X and injr t : A |- X + B, where t : A |- B; X is fresh. */
case INJL:
case INJR:
*bindings_used += 1;
*extra_var = (unification_var){0};
arrow[i] = (unification_arrow){ .target =
{ .isBound = true
, .bound = { .kind = SUM
, .arg = { INJL == dag[i].tag ? &(arrow[dag[i].child[0]].target) : extra_var
, INJL == dag[i].tag ? extra_var : &(arrow[dag[i].child[0]].target)
} } } };
UNIFY(&(arrow[dag[i].child[0]].source), &(arrow[i].source));
extra_var++;
break;
/* take t : A * X |- C and drop t : X * A |- C, where t : A |- C; X is fresh. */
case TAKE:
case DROP:
*bindings_used += 1;
*extra_var = (unification_var){0};
arrow[i] = (unification_arrow){ .source =
{ .isBound = true
, .bound = { .kind = PRODUCT
, .arg = { TAKE == dag[i].tag ? &(arrow[dag[i].child[0]].source) : extra_var
, TAKE == dag[i].tag ? extra_var : &(arrow[dag[i].child[0]].source)
} } } };
UNIFY(&(arrow[dag[i].child[0]].target), &(arrow[i].target));
extra_var++;
break;
/* iden : A |- A */
case IDEN:
arrow[i] = (unification_arrow){0};
UNIFY(&(arrow[i].source), &(arrow[i].target));
break;
/* unit : A |- ONE, with the source left unconstrained. */
case UNIT:
arrow[i] = (unification_arrow){ .target = { .isBound = true, .bound = { .kind = ONE } } };
break;
/* No constraints are generated here for these nodes. */
case HIDDEN:
case WITNESS:
arrow[i] = (unification_arrow){0};
break;
/* Jets take their source and target types from 'bound_var'. */
case JET:
arrow[i] = (unification_arrow){0};
UNIFY(&(bound_var[dag[i].sourceIx]),&arrow[i].source);
UNIFY(&(bound_var[dag[i].targetIx]),&arrow[i].target);
break;
/* word : ONE |- bound_var[targetIx].  This is the last case, so no 'break' is needed. */
case WORD:
arrow[i] = (unification_arrow){ .source = { .isBound = true, .bound = { .kind = ONE } } };
UNIFY(&(bound_var[dag[i].targetIx]),&arrow[i].target);
#undef APPLY_BINDING
#undef UNIFY
/* NOTE(review): there is no 'default'.  A tag not listed above would leave
 * arrow[i] unwritten.  Presumably the cases cover every tag -- verify.
 */
}
}
return SIMPLICITY_NO_ERROR;
}
static bool isFrozen(unification_var* var) {
rustsimplicity_0_6_assert(!var->isBound || ONE != var->bound.kind || 0 == var->bound.frozen_ix);
return !var->isBound || ONE == var->bound.kind || var->bound.frozen_ix;
}
static size_t getFrozenIx(unification_var* var) {
return var->isBound ? var->bound.frozen_ix : 0;
}
/* Freeze the type of 'var' into 'type_dag' and store its index in '*result'.
 *
 * Free variables and ONE bindings are already frozen at index 0 (see
 * 'isFrozen' / 'getFrozenIx').  Otherwise, the not-yet-frozen bound variables
 * reachable from 'var' are visited depth-first, using an explicit stack
 * linked through each variable's 'next' field.  A variable is appended at
 * index '*type_dag_used' (which is then incremented) only after both of its
 * arguments are frozen, so every entry's 'typeArg' indices refer to earlier
 * entries or to 0.
 *
 * 'bound.occursCheck' is set on a variable when it is pushed.  Meeting a
 * non-frozen variable whose flag is already set means it is still on the
 * stack, i.e. the type is cyclic.  In that case the function returns false
 * (occurs-check failure).  Returns true on success.  The caller must size
 * 'type_dag' to hold every appended entry.
 */
static bool freeze(size_t* result, type* type_dag, size_t* type_dag_used, unification_var* var) {
var = findRoot(var);
if (isFrozen(var)) {
*result = getFrozenIx(var);
return true;
}
/* 'var' becomes the bottom of the stack. */
var->next = NULL;
rustsimplicity_0_6_assert(!var->bound.occursCheck);
var->bound.occursCheck = true;
while (var) {
unification_var* typeArg[2] = { findRoot(var->bound.arg[0]), findRoot(var->bound.arg[1]) };
if (!isFrozen(typeArg[0])) {
/* Push the left argument, unless it is already on the stack (a cycle). */
if (typeArg[0]->bound.occursCheck) return false;
typeArg[0]->bound.occursCheck = true;
typeArg[0]->next = var;
var = typeArg[0];
} else if (!isFrozen(typeArg[1])) {
/* Push the right argument, unless it is already on the stack (a cycle). */
if (typeArg[1]->bound.occursCheck) return false;
typeArg[1]->bound.occursCheck = true;
typeArg[1]->next = var;
var = typeArg[1];
} else {
/* Both arguments are frozen: append this node and pop it.  The original
 * root is popped last, so its index is the final value of '*result'.
 */
*result = var->bound.frozen_ix = (*type_dag_used)++;
type_dag[var->bound.frozen_ix] = (type)
{ .kind = var->bound.kind
, .typeArg = { getFrozenIx(typeArg[0]), getFrozenIx(typeArg[1]) }
};
var = var->next;
}
}
return true;
}
/* Freeze the inferred source and target types of every node in 'dag'.
 *
 * Entry 0 of 'type_dag' is reserved for ONE.  Each node's 'sourceType' and
 * 'targetType' receive indices into 'type_dag'.  Returns
 * SIMPLICITY_ERR_TYPE_INFERENCE_OCCURS_CHECK as soon as a cyclic type is
 * found.  Otherwise the filled prefix of 'type_dag' is passed to
 * rustsimplicity_0_6_computeTypeAnalyses and SIMPLICITY_NO_ERROR is returned.
 */
static simplicity_err freezeTypes(type* type_dag, dag_node* dag, unification_arrow* arrow, const size_t len) {
  size_t used = 1;
  type_dag[0] = (type){ .kind = ONE };

  for (size_t ix = 0; ix < len; ++ix) {
    dag_node* node = &dag[ix];
    unification_arrow* node_arrow = &arrow[ix];
    if (!freeze(&node->sourceType, type_dag, &used, &node_arrow->source)) {
      return SIMPLICITY_ERR_TYPE_INFERENCE_OCCURS_CHECK;
    }
    if (!freeze(&node->targetType, type_dag, &used, &node_arrow->target)) {
      return SIMPLICITY_ERR_TYPE_INFERENCE_OCCURS_CHECK;
    }
  }

  rustsimplicity_0_6_computeTypeAnalyses(type_dag, used);
  return SIMPLICITY_NO_ERROR;
}
/* Infer and freeze the types of every node of 'dag' (length 'len').
 *
 * On success: returns SIMPLICITY_NO_ERROR, sets '*type_dag' to a newly
 * allocated array of types (which this function does not free), and stores
 * each node's type indices in dag[i].sourceType / dag[i].targetType.
 * On failure: '*type_dag' is NULL and the result is SIMPLICITY_ERR_MALLOC,
 * SIMPLICITY_ERR_TYPE_INFERENCE_UNIFICATION or
 * SIMPLICITY_ERR_TYPE_INFERENCE_OCCURS_CHECK.
 *
 * 'mallocBoundVars' allocates 'bound_var', reports 'word256_ix' and
 * 'extra_var_start', and returns the number of bindings already in use.
 * 'typeInference' writes up to max_extra_vars(census) fresh variables
 * starting at bound_var + extra_var_start, so the callback presumably
 * reserves that much room -- confirm against the callback's contract.
 * 'census' is presumably the combinator counts of 'dag' -- verify at callers.
 * The intermediate arrays 'arrow' and 'bound_var' are freed before returning.
 */
simplicity_err rustsimplicity_0_6_mallocTypeInference(type** type_dag, rustsimplicity_0_6_callback_mallocBoundVars mallocBoundVars, dag_node* dag, const uint_fast32_t len, const combinator_counters* census) {
*type_dag = NULL;
static_assert(DAG_LEN_MAX <= SIZE_MAX / sizeof(unification_arrow), "arrow array too large.");
static_assert(1 <= DAG_LEN_MAX, "DAG_LEN_MAX is zero.");
static_assert(DAG_LEN_MAX - 1 <= UINT32_MAX, "arrow array index does not fit in uint32_t.");
rustsimplicity_0_6_assert(1 <= len);
rustsimplicity_0_6_assert(len <= DAG_LEN_MAX);
unification_arrow* arrow = rustsimplicity_0_6_malloc(len * sizeof(unification_arrow));
unification_var* bound_var = NULL;
size_t word256_ix, extra_var_start;
const size_t orig_bindings_used = mallocBoundVars(&bound_var, &word256_ix, &extra_var_start, max_extra_vars(census));
size_t bindings_used = orig_bindings_used;
static_assert(1 <= NUMBER_OF_TYPENAMES_MAX, "NUMBER_OF_TYPENAMES_MAX is zero.");
/* NOTE(review): this assert reads the callback's return value before checking whether 'bound_var' was allocated. */
rustsimplicity_0_6_assert(orig_bindings_used <= NUMBER_OF_TYPENAMES_MAX - 1);
simplicity_err result = arrow && bound_var ? SIMPLICITY_NO_ERROR : SIMPLICITY_ERR_MALLOC;
if (IS_OK(result)) {
result = typeInference(arrow, dag, len, bound_var + extra_var_start, bound_var, word256_ix, &bindings_used);
}
if (IS_OK(result)) {
static_assert(TYPE_DAG_LEN_MAX <= SIZE_MAX / sizeof(type), "type_dag array too large.");
static_assert(1 <= TYPE_DAG_LEN_MAX, "TYPE_DAG_LEN_MAX is zero.");
static_assert(TYPE_DAG_LEN_MAX - 1 <= UINT32_MAX, "type_dag array index does not fit in uint32_t.");
/* typeInference adds at most 4 bindings per node, so 1 + bindings_used <= TYPE_DAG_LEN_MAX. */
rustsimplicity_0_6_assert(bindings_used <= orig_bindings_used + 4*len);
/* Room for ONE at index 0 plus one entry per binding counted in 'bindings_used'. */
*type_dag = rustsimplicity_0_6_malloc((1 + bindings_used) * sizeof(type));
result = *type_dag ? SIMPLICITY_NO_ERROR : SIMPLICITY_ERR_MALLOC;
if (IS_OK(result)) {
result = freezeTypes(*type_dag, dag, arrow, len);
}
if (!IS_OK(result)) {
/* Never hand back a partially frozen type DAG. */
rustsimplicity_0_6_free(*type_dag);
*type_dag = NULL;
}
}
rustsimplicity_0_6_free(arrow);
rustsimplicity_0_6_free(bound_var);
return result;
}