#include "deserialize.h"
#include <limits.h>
#include "limitations.h"
#include "simplicity_alloc.h"
#include "simplicity_assert.h"
/* Reads 'len' 32-bit words from 'stream' into 'result[0]' .. 'result[len - 1]', in index order.
 * Each word is built from two 16-bit reads: the first supplies the high half, the second the low half.
 * Two reads are used because 'rustsimplicity_0_6_readNBits' reports failure through a negative return value,
 * so a full 32-bit value could not be distinguished from an error code.
 *
 * Returns the negative value produced by the failing read, cast to 'simplicity_err', as soon as a read fails.
 * In that case earlier words (and possibly the high half of the current word) have already been written.
 * Returns 'SIMPLICITY_NO_ERROR' once all 'len' words have been read.
 *
 * Precondition: 'result' has room for at least 'len' elements.
 */
static simplicity_err getWord32Array(uint32_t* result, const size_t len, bitstream* stream) {
  size_t k = 0;
  while (k < len) {
    uint32_t* const word = &result[k];

    const int32_t high = rustsimplicity_0_6_readNBits(16, stream);
    if (high < 0) return (simplicity_err)high;
    *word = (uint32_t)high << 16;

    const int32_t low = rustsimplicity_0_6_readNBits(16, stream);
    if (low < 0) return (simplicity_err)low;
    *word |= (uint32_t)low;

    ++k;
  }
  return SIMPLICITY_NO_ERROR;
}
/* Reads a 256-bit value from 'stream' into 'result->s' as eight 32-bit words (high half of each word first).
 * Returns whatever 'getWord32Array' returns; on failure 'result->s' may be partially written.
 */
static simplicity_err getHash(sha256_midstate* result, bitstream* stream) {
  const size_t wordCount = 8;
  return getWord32Array(result->s, wordCount, stream);
}
/* Decodes one node from 'stream' into 'dag[i]' (which is zero-initialized first).
 *
 * Encoding, as read below:
 *   bit 1, bit 1 : a jet; decoding is delegated to the 'decodeJet' callback and its result is returned as-is.
 *   bit 1, bit 0 : a WORD node: a depth (rejected if greater than 32) followed by a bitstring of length
 *                  2^(depth - 1) read by 'rustsimplicity_0_6_readBitstring' into 'dag[i].compactValue'.
 *   bit 0        : a combinator: a 2-bit 'code', then a 'subcode' (2 bits when code < 3, 1 bit when code == 3),
 *                  then (2 - code) relative child offsets (none for code 2 or 3).
 *
 * Child offsets are relative: 'child[j] = i - offset'. An offset larger than 'i' yields
 * SIMPLICITY_ERR_DATA_OUT_OF_RANGE.
 *
 * Returns:
 *   - a negative bitstream/reader result cast to 'simplicity_err' if any read fails;
 *   - SIMPLICITY_ERR_FAIL_CODE for code 2 / subcode 2, SIMPLICITY_ERR_RESERVED_CODE for code 2 / subcode 3;
 *   - SIMPLICITY_ERR_HIDDEN if a HIDDEN child appears anywhere other than as child 0 of ASSERTR or child 1 of ASSERTL;
 *   - for HIDDEN nodes (code 3 / subcode 0), the result of reading the node's hash directly into 'dag[i].cmr'
 *     (no CMR computation is performed for them);
 *   - otherwise SIMPLICITY_NO_ERROR after calling 'rustsimplicity_0_6_computeCommitmentMerkleRoot(dag, i)'
 *     for combinator nodes, or after computing the word CMR for WORD nodes.
 *
 * NOTE(review): assumes 'dag' has at least i + 1 entries and that 'dag[0]' .. 'dag[i - 1]' are already decoded,
 * since child tags are read. Also assumes 'rustsimplicity_0_6_decodeUptoMaxInt' never returns 0, so that
 * 'depth - 1' is non-negative and every child index is strictly less than 'i' — confirm against its definition.
 */
static simplicity_err decodeNode(dag_node* dag, rustsimplicity_0_6_callback_decodeJet decodeJet, uint_fast32_t i, bitstream* stream) {
int32_t bit = read1Bit(stream);
if (bit < 0) return (simplicity_err)bit;
dag[i] = (dag_node){0};
if (bit) {
bit = read1Bit(stream);
if (bit < 0) return (simplicity_err)bit;
if (bit) {
/* Jet: the callback fills in 'dag[i]'; nothing else is computed here. */
return decodeJet(&dag[i], stream);
} else {
/* WORD node: depth first, then 2^(depth - 1) worth of bitstring data. */
int32_t depth = rustsimplicity_0_6_decodeUptoMaxInt(stream);
if (depth < 0) return (simplicity_err)depth;
/* Bounding depth keeps the shift below within range (at most 1 << 31). */
if (32 < depth) return SIMPLICITY_ERR_DATA_OUT_OF_RANGE;
{
simplicity_err error = rustsimplicity_0_6_readBitstring(&dag[i].compactValue, (size_t)1 << (depth - 1), stream);
if (!IS_OK(error)) return error;
}
dag[i].tag = WORD;
/* The depth is stored in 'targetIx'. */
dag[i].targetIx = (size_t)depth;
dag[i].cmr = rustsimplicity_0_6_computeWordCMR(&dag[i].compactValue, (size_t)(depth - 1));
}
} else {
int32_t code = rustsimplicity_0_6_readNBits(2, stream);
if (code < 0) return (simplicity_err)code;
/* Codes 0..2 have four subcodes each; code 3 has only two. */
int32_t subcode = rustsimplicity_0_6_readNBits(code < 3 ? 2 : 1, stream);
if (subcode < 0) return (simplicity_err)subcode;
/* Code 0 nodes have two children, code 1 nodes one, codes 2 and 3 none (loop does not run). */
for (int32_t j = 0; j < 2 - code; ++j) {
int32_t ix = rustsimplicity_0_6_decodeUptoMaxInt(stream);
if (ix < 0) return (simplicity_err)ix;
/* Children must refer to earlier nodes; reject offsets reaching before 'dag[0]'. */
if (i < (uint_fast32_t)ix) return SIMPLICITY_ERR_DATA_OUT_OF_RANGE;
dag[i].child[j] = i - (uint_fast32_t)ix;
}
switch (code) {
case 0:
switch (subcode) {
case 0: dag[i].tag = COMP; break;
/* A CASE with a HIDDEN branch becomes an assertion: hidden child 0 -> ASSERTR, hidden child 1 -> ASSERTL.
 * If both children are HIDDEN this picks ASSERTR, and the check after the switch then rejects child 1. */
case 1:
dag[i].tag = (HIDDEN == dag[dag[i].child[0]].tag) ? ASSERTR
: (HIDDEN == dag[dag[i].child[1]].tag) ? ASSERTL
: CASE;
break;
case 2: dag[i].tag = PAIR; break;
case 3: dag[i].tag = DISCONNECT; break;
}
break;
case 1:
switch (subcode) {
case 0: dag[i].tag = INJL; break;
case 1: dag[i].tag = INJR; break;
case 2: dag[i].tag = TAKE; break;
case 3: dag[i].tag = DROP; break;
}
break;
case 2:
switch (subcode) {
case 0: dag[i].tag = IDEN; break;
case 1: dag[i].tag = UNIT; break;
case 2: return SIMPLICITY_ERR_FAIL_CODE;
case 3: return SIMPLICITY_ERR_RESERVED_CODE;
}
break;
case 3:
switch (subcode) {
case 0:
/* HIDDEN: the node's CMR is taken verbatim from the stream. */
dag[i].tag = HIDDEN;
return getHash(&(dag[i].cmr), stream);
case 1:
dag[i].tag = WITNESS;
break;
}
break;
}
/* HIDDEN children are only permitted as the pruned branch of an assertion. */
for (int32_t j = 0; j < 2 - code; ++j) {
if (HIDDEN == dag[dag[i].child[j]].tag && dag[i].tag != (j ? ASSERTL : ASSERTR)) return SIMPLICITY_ERR_HIDDEN;
}
rustsimplicity_0_6_computeCommitmentMerkleRoot(dag, i);
}
return SIMPLICITY_NO_ERROR;
}
/* Decodes 'len' consecutive nodes from 'stream' into 'dag[0]' .. 'dag[len - 1]' using 'decodeNode'.
 * After each successfully decoded node, its tag is passed to 'enumerator' together with 'census'.
 * Stops at the first node that fails to decode and returns that error; entries up to and including the
 * failing index may have been written.
 * Returns SIMPLICITY_NO_ERROR once all 'len' nodes are decoded.
 *
 * NOTE(review): 'census' may be NULL (the caller only zeroes it when non-NULL); this relies on 'enumerator'
 * tolerating a NULL counter pointer — confirm.
 */
static simplicity_err decodeDag(dag_node* dag, rustsimplicity_0_6_callback_decodeJet decodeJet, const uint_fast32_t len, combinator_counters* census, bitstream* stream) {
  uint_fast32_t ix = 0;
  while (ix < len) {
    const simplicity_err status = decodeNode(dag, decodeJet, ix, stream);
    if (!IS_OK(status)) return status;

    enumerator(census, dag[ix].tag);
    ++ix;
  }
  return SIMPLICITY_NO_ERROR;
}
/* Decodes a complete DAG from 'stream' into a freshly allocated node array.
 *
 * The stream is expected to start with the node count, followed by that many encoded nodes.
 * If 'census' is non-NULL it is zeroed and then filled in while the nodes are decoded.
 *
 * On success, '*dag' points to an array of the decoded nodes and the (positive) node count is returned.
 * NOTE(review): the array comes from 'rustsimplicity_0_6_malloc'; presumably the caller releases it with
 * 'rustsimplicity_0_6_free' — confirm against the public header.
 *
 * On failure, '*dag' is NULL and the return value is:
 *   - the non-positive result of reading the node count, unchanged, if that result is <= 0;
 *   - SIMPLICITY_ERR_DATA_OUT_OF_RANGE if the count exceeds DAG_LEN_MAX;
 *   - SIMPLICITY_ERR_MALLOC if allocation fails;
 *   - any error from decoding the nodes;
 *   - SIMPLICITY_ERR_HIDDEN_ROOT if the last (root) node is HIDDEN;
 *   - any error from 'rustsimplicity_0_6_verifyCanonicalOrder'.
 * Any array allocated before the failure is freed.
 */
int_fast32_t rustsimplicity_0_6_decodeMallocDag(dag_node** dag, rustsimplicity_0_6_callback_decodeJet decodeJet, combinator_counters* census, bitstream* stream) {
  *dag = NULL;

  const int32_t dagLen = rustsimplicity_0_6_decodeUptoMaxInt(stream);
  if (dagLen <= 0) return dagLen;

  static_assert(DAG_LEN_MAX <= (uint32_t)INT32_MAX, "DAG_LEN_MAX exceeds supported parsing range.");
  if ((uint32_t)dagLen > DAG_LEN_MAX) return SIMPLICITY_ERR_DATA_OUT_OF_RANGE;

  /* These guarantee the allocation size below cannot overflow and every index fits in uint32_t. */
  static_assert(DAG_LEN_MAX <= SIZE_MAX / sizeof(dag_node), "dag array too large.");
  static_assert(1 <= DAG_LEN_MAX, "DAG_LEN_MAX is zero.");
  static_assert(DAG_LEN_MAX - 1 <= UINT32_MAX, "dag array index does not fit in uint32_t.");

  dag_node* const nodes = rustsimplicity_0_6_malloc((size_t)dagLen * sizeof(dag_node));
  if (NULL == nodes) return SIMPLICITY_ERR_MALLOC;
  *dag = nodes;

  if (census) *census = (combinator_counters){0};

  const uint_fast32_t count = (uint_fast32_t)dagLen;
  simplicity_err error = decodeDag(nodes, decodeJet, count, census, stream);
  if (IS_OK(error)) {
    /* The root is the last node; it may not be a pruned (HIDDEN) node. */
    if (HIDDEN == nodes[count - 1].tag) {
      error = SIMPLICITY_ERR_HIDDEN_ROOT;
    } else {
      error = rustsimplicity_0_6_verifyCanonicalOrder(nodes, count);
    }
  }

  if (!IS_OK(error)) {
    rustsimplicity_0_6_free(nodes);
    *dag = NULL;
    return (int_fast32_t)error;
  }
  return dagLen;
}