use std::sync::Arc;
use either::Either;
use simplicity::jet::Elements;
use simplicity::node::{CoreConstructible as _, JetConstructible as _};
use simplicity::{Cmr, FailEntropy};
use crate::array::{BTreeSlice, Partition};
use crate::ast::{
Call, CallName, Expression, ExpressionInner, Match, Program, SingleExpression,
SingleExpressionInner, Statement,
};
use crate::debug::CallTracker;
use crate::error::{Error, RichError, Span, WithSpan};
use crate::named::{CoreExt, PairBuilder};
use crate::num::{NonZeroPow2Usize, Pow2Usize};
use crate::pattern::{BasePattern, Pattern};
use crate::str::WitnessName;
use crate::types::{StructuralType, TypeDeconstructible};
use crate::value::StructuralValue;
use crate::witness::Arguments;
use crate::{ProgNode, Value};
/// Compilation state for one function body (or for the program's `main`).
///
/// Tracks which values are bound in the current environment, the Simplicity
/// type-inference context, the call tracker consulted for debug symbols, and the
/// concrete values supplied for program parameters.
#[derive(Debug, Clone)]
struct Scope {
    /// Stack of lexical frames. Each frame lists the patterns bound in it, in insertion
    /// order; together they describe the shape of the current input (see
    /// `get_input_pattern`).
    variables: Vec<Vec<Pattern>>,
    /// Type-inference context; child scopes receive a `shallow_clone` of it.
    ctx: simplicity::types::Context,
    /// Looked up by span in `with_debug_symbol` to decide whether a call gets a debug CMR.
    call_tracker: Arc<CallTracker>,
    /// Values of program parameters, looked up by `get_argument`.
    arguments: Arguments,
}
impl Scope {
    /// Creates the root scope of a program.
    ///
    /// The stack starts with a single frame containing only [`Pattern::Ignore`], so the
    /// program's initial input is not bound to any name.
    pub fn new(call_tracker: Arc<CallTracker>, arguments: Arguments) -> Self {
        let root_frame = vec![Pattern::Ignore];
        Self {
            variables: vec![root_frame],
            ctx: simplicity::types::Context::new(),
            call_tracker,
            arguments,
        }
    }

    /// Creates a fresh scope for compiling a function body whose input matches `input`.
    ///
    /// The child shares the inference context (via `shallow_clone`) and the call
    /// tracker, gets a clone of the arguments, and sees none of the parent's variables.
    pub fn child(&self, input: Pattern) -> Self {
        let input_frame = vec![input];
        Self {
            variables: vec![input_frame],
            ctx: self.ctx.shallow_clone(),
            call_tracker: Arc::clone(&self.call_tracker),
            arguments: self.arguments.clone(),
        }
    }

    /// Opens a new, empty lexical frame.
    pub fn push_scope(&mut self) {
        let empty_frame = Vec::new();
        self.variables.push(empty_frame);
    }

    /// Closes the innermost lexical frame, discarding its bindings.
    ///
    /// # Panics
    /// Panics with "Empty stack" if there is no frame to close.
    pub fn pop_scope(&mut self) {
        let _closed = self.variables.pop().expect("Empty stack");
    }

    /// Binds `pattern` in the innermost lexical frame.
    ///
    /// # Panics
    /// Panics with "Empty stack" if there is no open frame.
    pub fn insert(&mut self, pattern: Pattern) {
        let innermost = self.variables.last_mut().expect("Empty stack");
        innermost.push(pattern);
    }

    /// Describes the shape of the current input as one pattern.
    ///
    /// All bound patterns are visited oldest first; each newer pattern becomes the left
    /// component of a product whose right component is everything bound before it.
    ///
    /// # Panics
    /// Panics with "Empty stack" if no pattern is bound at all.
    fn get_input_pattern(&self) -> Pattern {
        self.variables
            .iter()
            .flatten()
            .cloned()
            .reduce(|older, newer| Pattern::product(newer, older))
            .expect("Empty stack")
    }

    /// Builds a node that extracts the value matching `target` from the current input,
    /// or `None` if the input pattern cannot be translated to `target`.
    pub fn get(&self, target: &BasePattern) -> Option<PairBuilder<ProgNode>> {
        let input_pattern = self.get_input_pattern();
        BasePattern::from(&input_pattern).translate(&self.ctx, target)
    }

    /// Returns the type-inference context of this scope.
    pub fn ctx(&self) -> &simplicity::types::Context {
        &self.ctx
    }

    /// Composes `args` with `body`, attaching a debug symbol if one is tracked for `span`.
    ///
    /// When the call tracker has a CMR for `span`, the arguments are tagged with a
    /// `false` bit and routed through `assertl_drop(body, cmr)`, which embeds `cmr` in
    /// the program. Otherwise this is plain `comp args body`.
    ///
    /// # Errors
    /// Returns a spanned error if the composition does not type-check.
    pub fn with_debug_symbol<S: AsRef<Span>>(
        &mut self,
        args: PairBuilder<ProgNode>,
        body: &ProgNode,
        span: &S,
    ) -> Result<PairBuilder<ProgNode>, RichError> {
        if let Some(cmr) = self.call_tracker.get_cmr(span.as_ref()) {
            let tagged_args = ProgNode::bit(self.ctx(), false).pair(args);
            let tagged_body = ProgNode::assertl_drop(body, cmr);
            return tagged_args.comp(&tagged_body).with_span(span);
        }
        args.comp(body).with_span(span)
    }

    /// Returns the value supplied for the parameter `name`.
    ///
    /// # Panics
    /// Panics if no value was supplied for `name`; callers must ensure the arguments
    /// are consistent with the program's parameters.
    pub fn get_argument(&self, name: &WitnessName) -> &Value {
        let supplied = self.arguments.get(name);
        supplied.expect("Precondition: Arguments are consistent with parameters")
    }
}
/// Compiles the statements `stmts[index..]` followed by the optional trailing
/// expression `last_expr` into a single node.
///
/// - An assignment pairs the value of its right-hand side with the current input
///   (`pair value iden`), binds its pattern in `scope`, and composes the result with
///   the compilation of the remaining statements.
/// - An expression statement is paired with the remaining statements, and only the
///   remainder's output is kept (`drop iden`).
/// - With no statements left, the trailing expression is compiled, or a unit node is
///   produced if there is none.
///
/// # Errors
/// Returns the first compilation or type error encountered.
fn compile_blk(
    stmts: &[Statement],
    scope: &mut Scope,
    index: usize,
    last_expr: Option<&Expression>,
) -> Result<PairBuilder<ProgNode>, RichError> {
    match stmts.get(index) {
        None => match last_expr {
            None => Ok(PairBuilder::unit(scope.ctx())),
            Some(tail) => tail.compile(scope),
        },
        Some(Statement::Assignment(assignment)) => {
            let value = assignment.expression().compile(scope)?;
            scope.insert(assignment.pattern().clone());
            let extended_env = value.pair(PairBuilder::iden(scope.ctx()));
            let rest = compile_blk(stmts, scope, index + 1, last_expr)?;
            extended_env.comp(&rest).with_span(assignment)
        }
        Some(Statement::Expression(expression)) => {
            let effect = expression.compile(scope)?;
            let rest = compile_blk(stmts, scope, index + 1, last_expr)?;
            let both = effect.pair(rest);
            let keep_rest = ProgNode::drop_(&ProgNode::iden(scope.ctx()));
            both.comp(&keep_rest).with_span(expression)
        }
    }
}
impl Program {
    /// Compiles the program's `main` expression into a Simplicity node, using
    /// `arguments` as the values of the program's parameters.
    ///
    /// # Errors
    /// Returns the first compilation error encountered while compiling `main`.
    pub fn compile(&self, arguments: Arguments) -> Result<ProgNode, RichError> {
        let tracker = Arc::clone(self.call_tracker());
        let mut root = Scope::new(tracker, arguments);
        let main = self.main().compile(&mut root)?;
        Ok(PairBuilder::build(main))
    }
}
impl Expression {
    /// Compiles this expression.
    ///
    /// A single expression is compiled directly. A block opens a new lexical frame,
    /// compiles its statements together with its optional trailing expression, and
    /// closes the frame again whether or not compilation succeeded.
    fn compile(&self, scope: &mut Scope) -> Result<PairBuilder<ProgNode>, RichError> {
        let (stmts, tail) = match self.inner() {
            ExpressionInner::Single(single) => return single.compile(scope),
            ExpressionInner::Block(stmts, tail) => (stmts, tail),
        };
        scope.push_scope();
        let compiled = compile_blk(stmts, scope, 0, tail.as_deref());
        scope.pop_scope();
        compiled
    }
}
impl SingleExpression {
fn compile(&self, scope: &mut Scope) -> Result<PairBuilder<ProgNode>, RichError> {
let expr = match self.inner() {
SingleExpressionInner::Constant(value) => {
let value = StructuralValue::from(value);
PairBuilder::unit_scribe(scope.ctx(), value.as_ref())
}
SingleExpressionInner::Witness(name) => PairBuilder::witness(scope.ctx(), name.clone()),
SingleExpressionInner::Parameter(name) => {
let value = StructuralValue::from(scope.get_argument(name));
PairBuilder::unit_scribe(scope.ctx(), value.as_ref())
}
SingleExpressionInner::Variable(identifier) => scope
.get(&BasePattern::Identifier(identifier.clone()))
.ok_or(Error::UndefinedVariable(identifier.clone()))
.with_span(self)?,
SingleExpressionInner::Expression(expr) => expr.compile(scope)?,
SingleExpressionInner::Tuple(elements) | SingleExpressionInner::Array(elements) => {
let compiled = elements
.iter()
.map(|e| e.compile(scope))
.collect::<Result<Vec<PairBuilder<ProgNode>>, RichError>>()?;
let tree = BTreeSlice::from_slice(&compiled);
tree.fold(PairBuilder::pair)
.unwrap_or_else(|| PairBuilder::unit(scope.ctx()))
}
SingleExpressionInner::List(elements) => {
let compiled = elements
.iter()
.map(|e| e.compile(scope))
.collect::<Result<Vec<PairBuilder<ProgNode>>, RichError>>()?;
let bound = self.ty().as_list().unwrap().1;
let partition = Partition::from_slice(&compiled, bound);
partition.fold(
|block, _size: usize| {
let tree = BTreeSlice::from_slice(block);
match tree.fold(PairBuilder::pair) {
None => PairBuilder::unit(scope.ctx()).injl(),
Some(pair) => pair.injr(),
}
},
PairBuilder::pair,
)
}
SingleExpressionInner::Option(None) => PairBuilder::unit(scope.ctx()).injl(),
SingleExpressionInner::Either(Either::Left(inner)) => {
inner.compile(scope).map(PairBuilder::injl)?
}
SingleExpressionInner::Either(Either::Right(inner))
| SingleExpressionInner::Option(Some(inner)) => {
inner.compile(scope).map(PairBuilder::injr)?
}
SingleExpressionInner::Call(call) => call.compile(scope)?,
SingleExpressionInner::Match(match_) => match_.compile(scope)?,
};
scope
.ctx()
.unify(
&expr.as_ref().cached_data().arrow().target,
&StructuralType::from(self.ty()).to_unfinalized(scope.ctx()),
"",
)
.with_span(self)?;
Ok(expr)
}
}
impl Call {
fn compile(&self, scope: &mut Scope) -> Result<PairBuilder<ProgNode>, RichError> {
let args_ast = SingleExpression::tuple(self.args().clone(), *self.as_ref());
let args = args_ast.compile(scope)?;
match self.name() {
CallName::Jet(name) => {
let jet = ProgNode::jet(scope.ctx(), *name);
scope.with_debug_symbol(args, &jet, self)
}
CallName::UnwrapLeft(..) => {
let input_and_unit =
PairBuilder::iden(scope.ctx()).pair(PairBuilder::unit(scope.ctx()));
let extract_inner = ProgNode::assertl_take(
&ProgNode::iden(scope.ctx()),
Cmr::fail(FailEntropy::ZERO),
);
let body = input_and_unit.comp(&extract_inner).with_span(self)?;
scope.with_debug_symbol(args, body.as_ref(), self)
}
CallName::UnwrapRight(..) | CallName::Unwrap => {
let input_and_unit =
PairBuilder::iden(scope.ctx()).pair(PairBuilder::unit(scope.ctx()));
let extract_inner = ProgNode::assertr_take(
Cmr::fail(FailEntropy::ZERO),
&ProgNode::iden(scope.ctx()),
);
let body = input_and_unit.comp(&extract_inner).with_span(self)?;
scope.with_debug_symbol(args, body.as_ref(), self)
}
CallName::IsNone(..) => {
let input_and_unit =
PairBuilder::iden(scope.ctx()).pair(PairBuilder::unit(scope.ctx()));
let is_right = ProgNode::case_true_false(scope.ctx());
let body = input_and_unit.comp(&is_right).with_span(self)?;
args.comp(&body).with_span(self)
}
CallName::Assert => {
let jet = ProgNode::jet(scope.ctx(), Elements::Verify);
scope.with_debug_symbol(args, &jet, self)
}
CallName::Panic => {
let fail = ProgNode::fail(scope.ctx(), FailEntropy::ZERO);
scope.with_debug_symbol(args, &fail, self)
}
CallName::Debug => {
let iden = ProgNode::iden(scope.ctx());
scope.with_debug_symbol(args, &iden, self)
}
CallName::TypeCast(..) => {
Ok(args)
}
CallName::Custom(function) => {
let mut function_scope = scope.child(function.params_pattern());
let body = function.body().compile(&mut function_scope)?;
args.comp(&body).with_span(self)
}
CallName::Fold(function, bound) => {
let mut function_scope = scope.child(function.params_pattern());
let body = function.body().compile(&mut function_scope)?;
let fold_body = list_fold(*bound, body.as_ref()).with_span(self)?;
args.comp(&fold_body).with_span(self)
}
CallName::ForWhile(function, bit_width) => {
let mut function_scope = scope.child(function.params_pattern());
let body = function.body().compile(&mut function_scope)?;
let fold_body = for_while(*bit_width, body).with_span(self)?;
args.comp(&fold_body).with_span(self)
}
}
}
}
fn list_fold(bound: NonZeroPow2Usize, f: &ProgNode) -> Result<ProgNode, simplicity::types::Error> {
let mut f_array = f.clone();
let ctx = f.inference_context();
let ioh = ProgNode::i().h(ctx);
let mut f_fold = ProgNode::case(ioh.as_ref(), &f_array)?;
let mut i = NonZeroPow2Usize::TWO;
fn next_f_array(f_array: &ProgNode) -> Result<ProgNode, simplicity::types::Error> {
let ctx = f_array.inference_context();
let half1_acc = ProgNode::o().o().h(ctx).pair(ProgNode::i().h(ctx));
let updated_acc = half1_acc.comp(f_array)?;
let half2_acc = ProgNode::o().i().h(ctx).pair(updated_acc);
half2_acc.comp(f_array).map(PairBuilder::build)
}
fn next_f_fold(
f_array: &ProgNode,
f_fold: &ProgNode,
) -> Result<ProgNode, simplicity::types::Error> {
let ctx = f_array.inference_context();
let case_input = ProgNode::o()
.o()
.h(ctx)
.pair(ProgNode::o().i().h(ctx).pair(ProgNode::i().h(ctx)));
let case_left = ProgNode::drop_(f_fold);
let f_n_input = ProgNode::o().h(ctx).pair(ProgNode::i().i().h(ctx));
let f_n_output = f_n_input.comp(f_array)?;
let fold_n_input = ProgNode::i().o().h(ctx).pair(f_n_output);
let case_right = fold_n_input.comp(f_fold)?;
case_input
.comp(&ProgNode::case(&case_left, case_right.as_ref())?)
.map(PairBuilder::build)
}
while i < bound {
f_array = next_f_array(&f_array)?;
f_fold = next_f_fold(&f_array, &f_fold)?;
i = i.mul2();
}
Ok(f_fold)
}
/// Builds a loop that calls `f` for successive values of a `bit_width`-bit counter and
/// stops as soon as `f` returns a left injection.
///
/// `f` receives `(accumulator, (context, counter))` and returns either a left value
/// (stop, return it) or a right value (the next accumulator). The returned node
/// receives `(accumulator, context)`.
///
/// Wider counters are handled by splitting the counter into two halves (`adapt_f`)
/// and nesting two loops of half the width. The loop over the second half is the inner
/// one. NOTE(review): this presumably counts upwards if the first half is the high
/// half; confirm.
///
/// # Errors
/// Returns a type-inference error if `f` cannot be composed into the loop.
fn for_while(
    bit_width: Pow2Usize,
    f: PairBuilder<ProgNode>,
) -> Result<PairBuilder<ProgNode>, simplicity::types::Error> {
    /// One-bit loop: calls `f` with counter bit `false`. If that returns left, re-inject
    /// it left; otherwise call `f` again on the right payload with counter bit `true`.
    fn for_while_0(f: &ProgNode) -> Result<PairBuilder<ProgNode>, simplicity::types::Error> {
        let ctx = f.inference_context();
        let first_call = ProgNode::o()
            .h(ctx)
            .pair(ProgNode::i().h(ctx).pair(ProgNode::bit(ctx, false)))
            .comp(f)?;
        let scrutinee = first_call.pair(ProgNode::i().h(ctx));
        let stop = ProgNode::injl(ProgNode::o().h(ctx).as_ref());
        let second_call = ProgNode::o()
            .h(ctx)
            .pair(ProgNode::i().h(ctx).pair(ProgNode::bit(ctx, true)))
            .comp(f)?;
        let branch = ProgNode::case(&stop, second_call.as_ref())?;
        scrutinee.comp(&branch)
    }

    /// Re-associates the input `(acc, ((ctx, first_half), second_half))` into
    /// `(acc, (ctx, (first_half, second_half)))` before calling `f`.
    fn adapt_f(f: &ProgNode) -> Result<PairBuilder<ProgNode>, simplicity::types::Error> {
        let ctx = f.inference_context();
        let accumulator = ProgNode::o().h(ctx);
        let context = ProgNode::i().o().o().h(ctx);
        let first_half = ProgNode::i().o().i().h(ctx);
        let second_half = ProgNode::i().i().h(ctx);
        accumulator
            .pair(context.pair(first_half.pair(second_half)))
            .comp(f)
    }

    #[derive(Debug, Copy, Clone)]
    enum Task {
        ForWhile0,
        Adapt,
    }

    // Tasks listed in the order they are applied to `f`. A one-bit loop is a single
    // `ForWhile0`. A loop over 2k bits is `Adapt` followed by the k-bit plan twice:
    // first the inner loop, then the outer loop.
    let mut plan = vec![Task::ForWhile0];
    let mut width = Pow2Usize::ONE;
    while width < bit_width {
        let mut doubled = Vec::with_capacity(2 * plan.len() + 1);
        doubled.push(Task::Adapt);
        doubled.extend_from_slice(&plan);
        doubled.extend_from_slice(&plan);
        plan = doubled;
        width = width.mul2();
    }

    plan.into_iter().try_fold(f, |current, task| match task {
        Task::ForWhile0 => for_while_0(current.as_ref()),
        Task::Adapt => adapt_f(current.as_ref()),
    })
}
impl Match {
fn compile(&self, scope: &mut Scope) -> Result<PairBuilder<ProgNode>, RichError> {
scope.push_scope();
scope.insert(
self.left()
.pattern()
.as_variable()
.cloned()
.map(Pattern::Identifier)
.unwrap_or(Pattern::Ignore),
);
let left = self.left().expression().compile(scope)?;
scope.pop_scope();
scope.push_scope();
scope.insert(
self.right()
.pattern()
.as_variable()
.cloned()
.map(Pattern::Identifier)
.unwrap_or(Pattern::Ignore),
);
let right = self.right().expression().compile(scope)?;
scope.pop_scope();
let scrutinee = self.scrutinee().compile(scope)?;
let input = scrutinee.pair(PairBuilder::iden(scope.ctx()));
let output = ProgNode::case(left.as_ref(), right.as_ref()).with_span(self)?;
input.comp(&output).with_span(self)
}
}