use std::collections::hash_map::Entry;
use std::collections::HashMap;
use std::str::FromStr;
use std::sync::Arc;
use either::Either;
use miniscript::iter::{Tree, TreeLike};
use simplicity::jet::Elements;
use crate::debug::{CallTracker, DebugSymbols, TrackedCallName};
use crate::error::{Error, RichError, Span, WithSpan};
use crate::num::{NonZeroPow2Usize, Pow2Usize};
use crate::parse::MatchPattern;
use crate::pattern::Pattern;
use crate::str::{AliasName, FunctionName, Identifier, ModuleName, WitnessName};
use crate::types::{
AliasedType, ResolvedType, StructuralType, TypeConstructible, TypeDeconstructible, UIntType,
};
use crate::value::{UIntValue, Value};
use crate::witness::{Parameters, WitnessTypes, WitnessValues};
use crate::{impl_eq_hash, parse};
/// Output of the analysis pass over a parsed program.
///
/// Bundles the analyzed body of `main` with the parameters, witness types
/// and call tracker that were collected in the scope while analyzing it.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct Program {
/// Analyzed body of the `main` function.
main: Expression,
/// Parameters collected from the analysis scope.
parameters: Parameters,
/// Witness types collected from the analysis scope.
witness_types: WitnessTypes,
/// Calls that were tracked during analysis, shared behind an `Arc`.
call_tracker: Arc<CallTracker>,
}
impl Program {
/// Returns the analyzed body of the `main` function.
pub fn main(&self) -> &Expression {
&self.main
}
/// Returns the parameters collected during analysis.
pub fn parameters(&self) -> &Parameters {
&self.parameters
}
/// Returns the witness types collected during analysis.
pub fn witness_types(&self) -> &WitnessTypes {
&self.witness_types
}
/// Builds debug symbols by handing `file` to the call tracker's `with_file`.
pub fn debug_symbols(&self, file: &str) -> DebugSymbols {
self.call_tracker.with_file(file)
}
/// Returns the shared call tracker (crate-internal).
pub(crate) fn call_tracker(&self) -> &Arc<CallTracker> {
&self.call_tracker
}
}
/// Outcome of analyzing a single top-level item of a program.
#[derive(Clone, Debug, Eq, PartialEq, Hash)]
pub enum Item {
/// A type alias. Analysis registers the alias in the scope and keeps no payload.
TypeAlias,
/// A function definition.
Function(Function),
/// A module item. Analysis keeps no payload for it.
Module,
}
/// Outcome of analyzing a function definition.
#[derive(Clone, Debug, Eq, PartialEq, Hash)]
pub enum Function {
/// A function other than `main`. Its analyzed form is stored in the scope,
/// so no payload is kept here.
Custom,
/// The `main` function together with its analyzed body.
Main(Expression),
}
/// An analyzed statement inside a block.
#[derive(Clone, Debug, Eq, PartialEq, Hash)]
pub enum Statement {
/// An assignment of an expression to a pattern.
Assignment(Assignment),
/// An expression used as a statement; analysis requires it to be of unit type.
Expression(Expression),
}
/// An analyzed assignment: a pattern that is bound to the value of an expression.
#[derive(Clone, Debug)]
pub struct Assignment {
/// Pattern whose variables are bound by the assignment.
pattern: Pattern,
/// Expression whose value is assigned.
expression: Expression,
/// Source span of the assignment.
span: Span,
}
impl Assignment {
/// Returns the pattern on the left-hand side of the assignment.
pub fn pattern(&self) -> &Pattern {
&self.pattern
}
/// Returns the expression on the right-hand side of the assignment.
pub fn expression(&self) -> &Expression {
&self.expression
}
}
// NOTE(review): presumably implements equality and hashing over the listed
// fields only, leaving `span` out — confirm against the macro definition.
impl_eq_hash!(Assignment; pattern, expression);
/// An analyzed expression together with its resolved type.
#[derive(Clone, Debug)]
pub struct Expression {
/// Shape of the expression (single expression or block).
inner: ExpressionInner,
/// Resolved type of the expression.
ty: ResolvedType,
/// Source span of the expression.
span: Span,
}
// NOTE(review): presumably implements equality and hashing over the listed
// fields only, leaving `span` out — confirm against the macro definition.
impl_eq_hash!(Expression; inner, ty);
impl Expression {
/// Returns the shape of the expression.
pub fn inner(&self) -> &ExpressionInner {
&self.inner
}
/// Returns the resolved type of the expression.
pub fn ty(&self) -> &ResolvedType {
&self.ty
}
}
/// The two shapes an analyzed expression can take.
#[derive(Clone, Debug, Eq, PartialEq, Hash)]
pub enum ExpressionInner {
/// A single expression.
Single(SingleExpression),
/// A block: a sequence of statements followed by an optional final expression.
Block(Arc<[Statement]>, Option<Arc<Expression>>),
}
/// An analyzed single (non-block) expression together with its resolved type.
#[derive(Clone, Debug)]
pub struct SingleExpression {
/// Kind of the single expression and its children.
inner: SingleExpressionInner,
/// Resolved type of the expression.
ty: ResolvedType,
/// Source span of the expression.
span: Span,
}
impl SingleExpression {
    /// Builds a tuple expression from already analyzed elements.
    ///
    /// The resulting type is the tuple of the element types, in the order
    /// in which the elements appear in `args`.
    pub fn tuple(args: Arc<[Expression]>, span: Span) -> Self {
        let element_types: Vec<ResolvedType> =
            args.iter().map(|element| element.ty().clone()).collect();
        Self {
            inner: SingleExpressionInner::Tuple(args),
            ty: ResolvedType::tuple(element_types),
            span,
        }
    }

    /// Returns the kind of the single expression.
    pub fn inner(&self) -> &SingleExpressionInner {
        &self.inner
    }

    /// Returns the resolved type of the single expression.
    pub fn ty(&self) -> &ResolvedType {
        &self.ty
    }
}
// NOTE(review): presumably implements equality and hashing over the listed
// fields only, leaving `span` out — confirm against the macro definition.
impl_eq_hash!(SingleExpression; inner, ty);
/// Kinds of analyzed single expressions.
#[derive(Clone, Debug, Eq, PartialEq, Hash)]
pub enum SingleExpressionInner {
/// A constant value (boolean, decimal, binary and hexadecimal literals end up here).
Constant(Value),
/// A witness, referenced by name.
Witness(WitnessName),
/// A parameter, referenced by name.
Parameter(WitnessName),
/// A variable, referenced by identifier.
Variable(Identifier),
/// A nested expression.
Expression(Arc<Expression>),
/// A tuple of expressions.
Tuple(Arc<[Expression]>),
/// An array of expressions.
Array(Arc<[Expression]>),
/// A list of expressions.
List(Arc<[Expression]>),
/// A left or right value of an either type.
Either(Either<Arc<Expression>, Arc<Expression>>),
/// An optional value: `Some` expression or `None`.
Option(Option<Arc<Expression>>),
/// A function call.
Call(Call),
/// A match expression.
Match(Match),
}
/// An analyzed function call.
#[derive(Clone, Debug)]
pub struct Call {
/// Resolved name of the function that is called.
name: CallName,
/// Analyzed call arguments, in call order.
args: Arc<[Expression]>,
/// Source span of the call.
span: Span,
}
impl Call {
/// Returns the resolved name of the called function.
pub fn name(&self) -> &CallName {
&self.name
}
/// Returns the analyzed call arguments.
pub fn args(&self) -> &Arc<[Expression]> {
&self.args
}
}
// NOTE(review): presumably implements equality and hashing over the listed
// fields only, leaving `span` out — confirm against the macro definition.
impl_eq_hash!(Call; name, args);
/// Resolved name of a called function.
#[derive(Clone, Debug, Eq, PartialEq, Hash)]
pub enum CallName {
/// An Elements jet.
Jet(Elements),
/// Left unwrap of an either value; carries the type of the right variant.
UnwrapLeft(ResolvedType),
/// Right unwrap of an either value; carries the type of the left variant.
UnwrapRight(ResolvedType),
/// `None` check on an option; carries the type wrapped by the option.
IsNone(ResolvedType),
/// Unwrap of an option value.
Unwrap,
/// Assertion on a boolean argument.
Assert,
/// Panic; takes no arguments.
Panic,
/// Debug call; the argument type equals the output type.
Debug,
/// Type cast; carries the source type of the cast.
TypeCast(ResolvedType),
/// Call of a user-defined function.
Custom(CustomFunction),
/// Fold of a user-defined function over a list with the given bound.
Fold(CustomFunction, NonZeroPow2Usize),
/// Loop over a user-defined function; carries the bit width of the
/// function's integer (third) parameter.
ForWhile(CustomFunction, Pow2Usize),
}
/// An analyzed user-defined function: its typed parameters and its body.
#[derive(Clone, Debug, Eq, PartialEq, Hash)]
pub struct CustomFunction {
/// Parameters of the function, in declaration order.
params: Arc<[FunctionParam]>,
/// Analyzed body of the function; its type is the function's return type.
body: Arc<Expression>,
}
impl CustomFunction {
    /// Returns the parameters of the function, in declaration order.
    pub fn params(&self) -> &[FunctionParam] {
        &self.params
    }

    /// Returns the analyzed body of the function.
    pub fn body(&self) -> &Expression {
        &self.body
    }

    /// Returns a tuple pattern that binds every parameter identifier,
    /// in declaration order.
    pub fn params_pattern(&self) -> Pattern {
        let identifiers = self
            .params
            .iter()
            .map(|param| Pattern::Identifier(param.identifier().clone()));
        Pattern::tuple(identifiers)
    }
}
/// A parameter of a user-defined function: an identifier with a resolved type.
#[derive(Clone, Debug, Eq, PartialEq, Hash)]
pub struct FunctionParam {
/// Name under which the parameter is bound inside the function body.
identifier: Identifier,
/// Resolved type of the parameter.
ty: ResolvedType,
}
impl FunctionParam {
/// Returns the identifier of the parameter.
pub fn identifier(&self) -> &Identifier {
&self.identifier
}
/// Returns the resolved type of the parameter.
pub fn ty(&self) -> &ResolvedType {
&self.ty
}
}
/// An analyzed match expression with exactly two arms.
#[derive(Clone, Debug)]
pub struct Match {
/// Expression that is matched on.
scrutinee: Arc<Expression>,
/// First ("left") match arm.
left: MatchArm,
/// Second ("right") match arm.
right: MatchArm,
/// Source span of the match expression.
span: Span,
}
impl Match {
/// Returns the expression that is matched on.
pub fn scrutinee(&self) -> &Expression {
&self.scrutinee
}
/// Returns the left match arm.
pub fn left(&self) -> &MatchArm {
&self.left
}
/// Returns the right match arm.
pub fn right(&self) -> &MatchArm {
&self.right
}
}
// NOTE(review): presumably implements equality and hashing over the listed
// fields only, leaving `span` out — confirm against the macro definition.
impl_eq_hash!(Match; scrutinee, left, right);
/// One arm of an analyzed match expression.
#[derive(Clone, Debug, Eq, PartialEq, Hash)]
pub struct MatchArm {
/// Pattern of the arm, taken over unchanged from the parse tree.
pattern: MatchPattern,
/// Analyzed expression that the arm evaluates to.
expression: Arc<Expression>,
}
impl MatchArm {
/// Returns the pattern of the arm.
pub fn pattern(&self) -> &MatchPattern {
&self.pattern
}
/// Returns the expression of the arm.
pub fn expression(&self) -> &Expression {
&self.expression
}
}
/// Outcome of analyzing a single item of a module program.
#[derive(Clone, Debug, Eq, PartialEq, Hash)]
pub enum ModuleItem {
/// An item that analysis passes through without a payload.
Ignored,
/// An analyzed module.
Module(Module),
}
/// An analyzed module: a named list of constant assignments.
#[derive(Clone, Debug, Eq, PartialEq, Hash)]
pub struct Module {
/// Name of the module.
name: ModuleName,
/// Assignments of the module, in source order.
assignments: Arc<[ModuleAssignment]>,
/// Source span of the module.
span: Span,
}
impl Module {
/// Returns the assignments of the module, in source order.
pub fn assignments(&self) -> &[ModuleAssignment] {
&self.assignments
}
}
/// An analyzed assignment inside a module: a name bound to a constant value.
#[derive(Clone, Debug, Eq, PartialEq, Hash)]
pub struct ModuleAssignment {
/// Name that the value is assigned to.
name: WitnessName,
/// Constant value obtained from the assigned expression.
value: Value,
/// Source span of the assignment.
span: Span,
}
impl ModuleAssignment {
/// Returns the name that the value is assigned to.
pub fn name(&self) -> &WitnessName {
&self.name
}
/// Returns the assigned constant value.
pub fn value(&self) -> &Value {
&self.value
}
}
/// Borrowed view of a node in the analyzed expression tree.
///
/// Each variant wraps a reference to one kind of AST node so that the
/// different node types can be traversed uniformly via `TreeLike`.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash)]
pub enum ExprTree<'a> {
/// An expression node.
Expression(&'a Expression),
/// A block: its statements and its optional final expression.
Block(&'a [Statement], &'a Option<Arc<Expression>>),
/// A statement node.
Statement(&'a Statement),
/// An assignment node.
Assignment(&'a Assignment),
/// A single-expression node.
Single(&'a SingleExpression),
/// A call node.
Call(&'a Call),
/// A match node.
Match(&'a Match),
}
impl TreeLike for ExprTree<'_> {
    /// Describes the direct children of this node.
    ///
    /// Leaves (constants, witnesses, parameters, variables and `None`) have
    /// no children. Wrapper nodes have exactly one child. Blocks, tuples,
    /// arrays, lists, calls and matches list their children in source order.
    fn as_node(&self) -> Tree<Self> {
        use SingleExpressionInner as S;

        // `ExprTree` is `Copy`, so the match works on a copy of the node and
        // the bindings are plain references into the AST.
        match *self {
            Self::Expression(expr) => Tree::Unary(match expr.inner() {
                ExpressionInner::Single(single) => Self::Single(single),
                ExpressionInner::Block(statements, maybe_expr) => {
                    Self::Block(statements, maybe_expr)
                }
            }),
            Self::Block(statements, maybe_expr) => {
                // Statements first, then the optional final expression.
                let children: Arc<[Self]> = statements
                    .iter()
                    .map(Self::Statement)
                    .chain(maybe_expr.as_deref().map(Self::Expression))
                    .collect();
                Tree::Nary(children)
            }
            Self::Statement(statement) => Tree::Unary(match statement {
                Statement::Assignment(assignment) => Self::Assignment(assignment),
                Statement::Expression(expression) => Self::Expression(expression),
            }),
            Self::Assignment(assignment) => Tree::Unary(Self::Expression(assignment.expression())),
            Self::Single(single) => match single.inner() {
                S::Constant(_)
                | S::Witness(_)
                | S::Parameter(_)
                | S::Variable(_)
                | S::Option(None) => Tree::Nullary,
                S::Expression(child)
                | S::Either(Either::Left(child))
                | S::Either(Either::Right(child))
                | S::Option(Some(child)) => Tree::Unary(Self::Expression(child)),
                S::Tuple(elements) | S::Array(elements) | S::List(elements) => {
                    let children: Arc<[Self]> = elements.iter().map(Self::Expression).collect();
                    Tree::Nary(children)
                }
                S::Call(call) => Tree::Unary(Self::Call(call)),
                S::Match(match_) => Tree::Unary(Self::Match(match_)),
            },
            Self::Call(call) => {
                let children: Arc<[Self]> = call.args().iter().map(Self::Expression).collect();
                Tree::Nary(children)
            }
            Self::Match(match_) => {
                // Scrutinee first, then the left and the right arm.
                let scrutinee = Self::Expression(match_.scrutinee());
                let left = Self::Expression(match_.left().expression());
                let right = Self::Expression(match_.right().expression());
                Tree::Nary(Arc::new([scrutinee, left, right]))
            }
        }
    }
}
/// Mutable state that is threaded through the analysis of a program.
#[derive(Clone, Debug, Eq, PartialEq, Default)]
struct Scope {
/// Stack of variable scopes; the last entry is the innermost scope.
variables: Vec<HashMap<Identifier, ResolvedType>>,
/// Type aliases, stored with their resolved types.
aliases: HashMap<AliasName, ResolvedType>,
/// Types of the parameters seen so far.
parameters: HashMap<WitnessName, ResolvedType>,
/// Types of the witnesses seen so far.
witnesses: HashMap<WitnessName, ResolvedType>,
/// User-defined functions that were analyzed so far.
functions: HashMap<FunctionName, CustomFunction>,
/// Whether analysis is currently inside the `main` function.
is_main: bool,
/// Tracker for calls encountered during analysis.
call_tracker: CallTracker,
}
impl Scope {
pub fn is_topmost(&self) -> bool {
self.variables.is_empty()
}
pub fn push_scope(&mut self) {
self.variables.push(HashMap::new());
}
pub fn push_main_scope(&mut self) {
assert!(!self.is_main, "Already inside main function");
assert!(self.is_topmost(), "Current scope is not topmost");
self.push_scope();
self.is_main = true;
}
pub fn pop_scope(&mut self) {
self.variables.pop().expect("Stack is empty");
}
pub fn pop_main_scope(&mut self) {
assert!(self.is_main, "Current scope is not inside main function");
self.pop_scope();
self.is_main = false;
assert!(
self.is_topmost(),
"Current scope is not nested in topmost scope"
)
}
pub fn insert_variable(&mut self, identifier: Identifier, ty: ResolvedType) {
self.variables
.last_mut()
.expect("Stack is empty")
.insert(identifier, ty);
}
pub fn get_variable(&self, identifier: &Identifier) -> Option<&ResolvedType> {
self.variables
.iter()
.rev()
.find_map(|scope| scope.get(identifier))
}
pub fn resolve(&self, ty: &AliasedType) -> Result<ResolvedType, Error> {
let get_alias =
|name: &AliasName| -> Option<ResolvedType> { self.aliases.get(name).cloned() };
ty.resolve(get_alias).map_err(Error::UndefinedAlias)
}
pub fn insert_alias(&mut self, name: AliasName, ty: AliasedType) -> Result<(), Error> {
let resolved_ty = self.resolve(&ty)?;
self.aliases.insert(name, resolved_ty);
Ok(())
}
pub fn insert_parameter(&mut self, name: WitnessName, ty: ResolvedType) -> Result<(), Error> {
match self.parameters.entry(name.clone()) {
Entry::Occupied(entry) if entry.get() == &ty => Ok(()),
Entry::Occupied(entry) => Err(Error::ExpressionTypeMismatch(entry.get().clone(), ty)),
Entry::Vacant(entry) => {
entry.insert(ty);
Ok(())
}
}
}
pub fn insert_witness(&mut self, name: WitnessName, ty: ResolvedType) -> Result<(), Error> {
if !self.is_main {
return Err(Error::WitnessOutsideMain);
}
match self.witnesses.entry(name.clone()) {
Entry::Occupied(_) => Err(Error::WitnessReused(name)),
Entry::Vacant(entry) => {
entry.insert(ty);
Ok(())
}
}
}
pub fn destruct(self) -> (Parameters, WitnessTypes, CallTracker) {
(
Parameters::from(self.parameters),
WitnessTypes::from(self.witnesses),
self.call_tracker,
)
}
pub fn insert_function(
&mut self,
name: FunctionName,
function: CustomFunction,
) -> Result<(), Error> {
match self.functions.entry(name.clone()) {
Entry::Occupied(_) => Err(Error::FunctionRedefined(name)),
Entry::Vacant(entry) => {
entry.insert(function);
Ok(())
}
}
}
pub fn get_function(&self, name: &FunctionName) -> Option<&CustomFunction> {
self.functions.get(name)
}
pub fn track_call<S: AsRef<Span>>(&mut self, span: &S, name: TrackedCallName) {
self.call_tracker.track_call(*span.as_ref(), name);
}
}
/// Conversion of a parse-tree node into its analyzed counterpart.
trait AbstractSyntaxTree: Sized {
/// Type of the parse-tree node that is analyzed.
type From;
/// Analyzes `from`, which is expected to be of type `ty`, using and
/// updating the given `scope`.
fn analyze(from: &Self::From, ty: &ResolvedType, scope: &mut Scope) -> Result<Self, RichError>;
}
impl Program {
pub fn analyze(from: &parse::Program) -> Result<Self, RichError> {
let unit = ResolvedType::unit();
let mut scope = Scope::default();
let items = from
.items()
.iter()
.map(|s| Item::analyze(s, &unit, &mut scope))
.collect::<Result<Vec<Item>, RichError>>()?;
debug_assert!(scope.is_topmost());
let (parameters, witness_types, call_tracker) = scope.destruct();
let mut iter = items.into_iter().filter_map(|item| match item {
Item::Function(Function::Main(expr)) => Some(expr),
_ => None,
});
let main = iter.next().ok_or(Error::MainRequired).with_span(from)?;
if iter.next().is_some() {
return Err(Error::FunctionRedefined(FunctionName::main())).with_span(from);
}
Ok(Self {
main,
parameters,
witness_types,
call_tracker: Arc::new(call_tracker),
})
}
}
impl AbstractSyntaxTree for Item {
    type From = parse::Item;

    /// Analyzes a top-level item.
    ///
    /// Type aliases are registered in the scope, functions are delegated to
    /// [`Function::analyze`], and module items are passed through.
    ///
    /// # Panics
    ///
    /// Panics if `ty` is not the unit type or if `scope` is not topmost.
    fn analyze(from: &Self::From, ty: &ResolvedType, scope: &mut Scope) -> Result<Self, RichError> {
        assert!(ty.is_unit(), "Items cannot return anything");
        assert!(scope.is_topmost(), "Items live in the topmost scope only");

        let item = match from {
            parse::Item::TypeAlias(alias) => {
                let name = alias.name().clone();
                let aliased_ty = alias.ty().clone();
                scope.insert_alias(name, aliased_ty).with_span(alias)?;
                Self::TypeAlias
            }
            parse::Item::Function(function) => {
                Self::Function(Function::analyze(function, ty, scope)?)
            }
            parse::Item::Module => Self::Module,
        };
        Ok(item)
    }
}
impl AbstractSyntaxTree for Function {
    type From = parse::Function;

    /// Analyzes a function definition.
    ///
    /// The `main` function must take no parameters and, if a return type is
    /// given, it must resolve to unit; its body is analyzed inside the main
    /// scope and returned. Every other function is analyzed inside a fresh
    /// variable scope that binds its parameters, and is then stored in
    /// `scope` under its name.
    ///
    /// # Panics
    ///
    /// Panics if `ty` is not the unit type or if `scope` is not topmost.
    fn analyze(from: &Self::From, ty: &ResolvedType, scope: &mut Scope) -> Result<Self, RichError> {
        assert!(ty.is_unit(), "Function definitions cannot return anything");
        assert!(scope.is_topmost(), "Items live in the topmost scope only");

        if from.name().as_inner() == "main" {
            if !from.params().is_empty() {
                return Err(Error::MainNoInputs).with_span(from);
            }
            if let Some(aliased) = from.ret() {
                let declared_ret = scope.resolve(aliased).with_span(from)?;
                if !declared_ret.is_unit() {
                    return Err(Error::MainNoOutput).with_span(from);
                }
            }
            scope.push_main_scope();
            let body = Expression::analyze(from.body(), ty, scope)?;
            scope.pop_main_scope();
            return Ok(Self::Main(body));
        }

        // Custom function: resolve the signature first.
        let mut resolved_params: Vec<FunctionParam> = Vec::new();
        for param in from.params().iter() {
            let identifier = param.identifier().clone();
            let ty = scope.resolve(param.ty()).with_span(from)?;
            resolved_params.push(FunctionParam { identifier, ty });
        }
        let params: Arc<[FunctionParam]> = resolved_params.into();
        let ret = match from.ret() {
            Some(aliased) => scope.resolve(aliased).with_span(from)?,
            None => ResolvedType::unit(),
        };

        // Analyze the body with the parameters bound in their own scope.
        scope.push_scope();
        for param in params.iter() {
            scope.insert_variable(param.identifier().clone(), param.ty().clone());
        }
        let body = Arc::new(Expression::analyze(from.body(), &ret, scope)?);
        scope.pop_scope();
        debug_assert!(scope.is_topmost());

        scope
            .insert_function(from.name().clone(), CustomFunction { params, body })
            .with_span(from)?;
        Ok(Self::Custom)
    }
}
impl AbstractSyntaxTree for Statement {
    type From = parse::Statement;

    /// Analyzes a statement by delegating to the analysis of the assignment
    /// or expression that it wraps.
    ///
    /// # Panics
    ///
    /// Panics if `ty` is not the unit type.
    fn analyze(from: &Self::From, ty: &ResolvedType, scope: &mut Scope) -> Result<Self, RichError> {
        assert!(ty.is_unit(), "Statements cannot return anything");
        let statement = match from {
            parse::Statement::Assignment(assignment) => {
                Self::Assignment(Assignment::analyze(assignment, ty, scope)?)
            }
            parse::Statement::Expression(expression) => {
                Self::Expression(Expression::analyze(expression, ty, scope)?)
            }
        };
        Ok(statement)
    }
}
impl AbstractSyntaxTree for Assignment {
    type From = parse::Assignment;

    /// Analyzes an assignment.
    ///
    /// The assigned expression is analyzed against the declared type first.
    /// Afterwards the variables of the pattern are bound in the innermost
    /// scope, so the expression cannot see the variables that it defines.
    ///
    /// # Panics
    ///
    /// Panics if `ty` is not the unit type.
    fn analyze(from: &Self::From, ty: &ResolvedType, scope: &mut Scope) -> Result<Self, RichError> {
        assert!(ty.is_unit(), "Assignments cannot return anything");

        let declared_ty = scope.resolve(from.ty()).with_span(from)?;
        let expression = Expression::analyze(from.expression(), &declared_ty, scope)?;

        let pattern = from.pattern();
        for (identifier, variable_ty) in pattern.is_of_type(&declared_ty).with_span(from)? {
            scope.insert_variable(identifier, variable_ty);
        }

        Ok(Self {
            pattern: pattern.clone(),
            expression,
            span: *from.as_ref(),
        })
    }
}
impl Expression {
    /// Analyzes `from` against the type `ty` in a fresh, empty scope.
    ///
    /// Nothing that was defined elsewhere (variables, aliases, functions)
    /// is visible to the expression.
    pub fn analyze_const(from: &parse::Expression, ty: &ResolvedType) -> Result<Self, RichError> {
        Self::analyze(from, ty, &mut Scope::default())
    }
}
impl AbstractSyntaxTree for Expression {
type From = parse::Expression;
fn analyze(from: &Self::From, ty: &ResolvedType, scope: &mut Scope) -> Result<Self, RichError> {
match from.inner() {
parse::ExpressionInner::Single(single) => {
let ast_single = SingleExpression::analyze(single, ty, scope)?;
Ok(Self {
ty: ty.clone(),
inner: ExpressionInner::Single(ast_single),
span: *from.as_ref(),
})
}
parse::ExpressionInner::Block(statements, expression) => {
scope.push_scope();
let ast_statements = statements
.iter()
.map(|s| Statement::analyze(s, &ResolvedType::unit(), scope))
.collect::<Result<Arc<[Statement]>, RichError>>()?;
let ast_expression = match expression {
Some(expression) => Expression::analyze(expression, ty, scope)
.map(Arc::new)
.map(Some),
None if ty.is_unit() => Ok(None),
None => Err(Error::ExpressionTypeMismatch(
ty.clone(),
ResolvedType::unit(),
))
.with_span(from),
}?;
scope.pop_scope();
Ok(Self {
ty: ty.clone(),
inner: ExpressionInner::Block(ast_statements, ast_expression),
span: *from.as_ref(),
})
}
}
}
}
impl AbstractSyntaxTree for SingleExpression {
type From = parse::SingleExpression;
fn analyze(from: &Self::From, ty: &ResolvedType, scope: &mut Scope) -> Result<Self, RichError> {
let inner = match from.inner() {
parse::SingleExpressionInner::Boolean(bit) => {
if !ty.is_boolean() {
return Err(Error::ExpressionTypeMismatch(
ty.clone(),
ResolvedType::boolean(),
))
.with_span(from);
}
SingleExpressionInner::Constant(Value::from(*bit))
}
parse::SingleExpressionInner::Decimal(decimal) => {
let ty = ty
.as_integer()
.ok_or(Error::ExpressionUnexpectedType(ty.clone()))
.with_span(from)?;
UIntValue::parse_decimal(decimal, ty)
.with_span(from)
.map(Value::from)
.map(SingleExpressionInner::Constant)?
}
parse::SingleExpressionInner::Binary(bits) => {
let ty = ty
.as_integer()
.ok_or(Error::ExpressionUnexpectedType(ty.clone()))
.with_span(from)?;
let value = UIntValue::parse_binary(bits, ty).with_span(from)?;
SingleExpressionInner::Constant(Value::from(value))
}
parse::SingleExpressionInner::Hexadecimal(bytes) => {
let value = Value::parse_hexadecimal(bytes, ty).with_span(from)?;
SingleExpressionInner::Constant(value)
}
parse::SingleExpressionInner::Witness(name) => {
scope
.insert_witness(name.clone(), ty.clone())
.with_span(from)?;
SingleExpressionInner::Witness(name.clone())
}
parse::SingleExpressionInner::Parameter(name) => {
scope
.insert_parameter(name.shallow_clone(), ty.clone())
.with_span(from)?;
SingleExpressionInner::Parameter(name.shallow_clone())
}
parse::SingleExpressionInner::Variable(identifier) => {
let bound_ty = scope
.get_variable(identifier)
.ok_or(Error::UndefinedVariable(identifier.clone()))
.with_span(from)?;
if ty != bound_ty {
return Err(Error::ExpressionTypeMismatch(ty.clone(), bound_ty.clone()))
.with_span(from);
}
scope.insert_variable(identifier.clone(), ty.clone());
SingleExpressionInner::Variable(identifier.clone())
}
parse::SingleExpressionInner::Expression(parse) => {
Expression::analyze(parse, ty, scope)
.map(Arc::new)
.map(SingleExpressionInner::Expression)?
}
parse::SingleExpressionInner::Tuple(tuple) => {
let types = ty
.as_tuple()
.ok_or(Error::ExpressionUnexpectedType(ty.clone()))
.with_span(from)?;
if tuple.len() != types.len() {
return Err(Error::ExpressionUnexpectedType(ty.clone())).with_span(from);
}
tuple
.iter()
.zip(types.iter())
.map(|(el_parse, el_ty)| Expression::analyze(el_parse, el_ty, scope))
.collect::<Result<Arc<[Expression]>, RichError>>()
.map(SingleExpressionInner::Tuple)?
}
parse::SingleExpressionInner::Array(array) => {
let (el_ty, size) = ty
.as_array()
.ok_or(Error::ExpressionUnexpectedType(ty.clone()))
.with_span(from)?;
if array.len() != size {
return Err(Error::ExpressionUnexpectedType(ty.clone())).with_span(from);
}
array
.iter()
.map(|el_parse| Expression::analyze(el_parse, el_ty, scope))
.collect::<Result<Arc<[Expression]>, RichError>>()
.map(SingleExpressionInner::Array)?
}
parse::SingleExpressionInner::List(list) => {
let (el_ty, bound) = ty
.as_list()
.ok_or(Error::ExpressionUnexpectedType(ty.clone()))
.with_span(from)?;
if bound.get() <= list.len() {
return Err(Error::ExpressionUnexpectedType(ty.clone())).with_span(from);
}
list.iter()
.map(|e| Expression::analyze(e, el_ty, scope))
.collect::<Result<Arc<[Expression]>, RichError>>()
.map(SingleExpressionInner::List)?
}
parse::SingleExpressionInner::Either(either) => {
let (ty_l, ty_r) = ty
.as_either()
.ok_or(Error::ExpressionUnexpectedType(ty.clone()))
.with_span(from)?;
match either {
Either::Left(parse_l) => Expression::analyze(parse_l, ty_l, scope)
.map(Arc::new)
.map(Either::Left),
Either::Right(parse_r) => Expression::analyze(parse_r, ty_r, scope)
.map(Arc::new)
.map(Either::Right),
}
.map(SingleExpressionInner::Either)?
}
parse::SingleExpressionInner::Option(maybe_parse) => {
let ty = ty
.as_option()
.ok_or(Error::ExpressionUnexpectedType(ty.clone()))
.with_span(from)?;
match maybe_parse {
Some(parse) => {
Some(Expression::analyze(parse, ty, scope).map(Arc::new)).transpose()
}
None => Ok(None),
}
.map(SingleExpressionInner::Option)?
}
parse::SingleExpressionInner::Call(call) => {
Call::analyze(call, ty, scope).map(SingleExpressionInner::Call)?
}
parse::SingleExpressionInner::Match(match_) => {
Match::analyze(match_, ty, scope).map(SingleExpressionInner::Match)?
}
};
Ok(Self {
inner,
ty: ty.clone(),
span: *from.as_ref(),
})
}
}
impl AbstractSyntaxTree for Call {
type From = parse::Call;
/// Analyzes a function call against the expected output type `ty`.
///
/// The call name is resolved first. Depending on the name, the expected
/// argument types are derived, the number of arguments is checked, the
/// output type is checked where the function fixes it, the call is
/// registered with the call tracker where applicable, and finally the
/// arguments are analyzed against their expected types.
fn analyze(from: &Self::From, ty: &ResolvedType, scope: &mut Scope) -> Result<Self, RichError> {
// Checks only that the NUMBER of arguments matches the number of expected
// types. The argument types themselves are checked by `analyze_arguments`.
fn check_argument_types(
parse_args: &[parse::Expression],
expected_tys: &[ResolvedType],
) -> Result<(), Error> {
if parse_args.len() != expected_tys.len() {
Err(Error::InvalidNumberOfArguments(
expected_tys.len(),
parse_args.len(),
))
} else {
Ok(())
}
}
// Checks that the output type of the function equals the type that the
// surrounding context expects.
fn check_output_type(
observed_ty: &ResolvedType,
expected_ty: &ResolvedType,
) -> Result<(), Error> {
if observed_ty != expected_ty {
Err(Error::ExpressionTypeMismatch(
expected_ty.clone(),
observed_ty.clone(),
))
} else {
Ok(())
}
}
// Analyzes each argument against the expected type at the same position.
// `zip` stops at the shorter input, so callers check the lengths beforehand.
fn analyze_arguments(
parse_args: &[parse::Expression],
args_tys: &[ResolvedType],
scope: &mut Scope,
) -> Result<Arc<[Expression]>, RichError> {
let args = parse_args
.iter()
.zip(args_tys.iter())
.map(|(arg_parse, arg_ty)| Expression::analyze(arg_parse, arg_ty, scope))
.collect::<Result<Arc<[Expression]>, RichError>>()?;
Ok(args)
}
let name = CallName::analyze(from, ty, scope)?;
// The name is cloned because the match consumes its payload while `name`
// itself is stored in the result below.
let args = match name.clone() {
CallName::Jet(jet) => {
// Argument and output types of a jet come from the jet's signature.
let args_tys = crate::jet::source_type(jet)
.iter()
.map(AliasedType::resolve_builtin)
.collect::<Result<Vec<ResolvedType>, AliasName>>()
.map_err(Error::UndefinedAlias)
.with_span(from)?;
check_argument_types(from.args(), &args_tys).with_span(from)?;
let out_ty = crate::jet::target_type(jet)
.resolve_builtin()
.map_err(Error::UndefinedAlias)
.with_span(from)?;
check_output_type(&out_ty, ty).with_span(from)?;
scope.track_call(from, TrackedCallName::Jet);
analyze_arguments(from.args(), &args_tys, scope)?
}
CallName::UnwrapLeft(right_ty) => {
// The argument is an either whose left type is the expected output type.
let args_tys = [ResolvedType::either(ty.clone(), right_ty)];
check_argument_types(from.args(), &args_tys).with_span(from)?;
let args = analyze_arguments(from.args(), &args_tys, scope)?;
let [arg_ty] = args_tys;
scope.track_call(from, TrackedCallName::UnwrapLeft(arg_ty));
args
}
CallName::UnwrapRight(left_ty) => {
// The argument is an either whose right type is the expected output type.
let args_tys = [ResolvedType::either(left_ty, ty.clone())];
check_argument_types(from.args(), &args_tys).with_span(from)?;
let args = analyze_arguments(from.args(), &args_tys, scope)?;
let [arg_ty] = args_tys;
scope.track_call(from, TrackedCallName::UnwrapRight(arg_ty));
args
}
CallName::IsNone(some_ty) => {
// Takes an option of the given type and returns a boolean.
let args_tys = [ResolvedType::option(some_ty)];
check_argument_types(from.args(), &args_tys).with_span(from)?;
let out_ty = ResolvedType::boolean();
check_output_type(&out_ty, ty).with_span(from)?;
analyze_arguments(from.args(), &args_tys, scope)?
}
CallName::Unwrap => {
// The argument is an option that wraps the expected output type.
let args_tys = [ResolvedType::option(ty.clone())];
check_argument_types(from.args(), &args_tys).with_span(from)?;
scope.track_call(from, TrackedCallName::Unwrap);
analyze_arguments(from.args(), &args_tys, scope)?
}
CallName::Assert => {
// Takes a boolean and returns unit.
let args_tys = [ResolvedType::boolean()];
check_argument_types(from.args(), &args_tys).with_span(from)?;
let out_ty = ResolvedType::unit();
check_output_type(&out_ty, ty).with_span(from)?;
scope.track_call(from, TrackedCallName::Assert);
analyze_arguments(from.args(), &args_tys, scope)?
}
CallName::Panic => {
// Takes no arguments. The output type is deliberately not checked,
// so a panic is accepted wherever any type is expected.
let args_tys = [];
check_argument_types(from.args(), &args_tys).with_span(from)?;
scope.track_call(from, TrackedCallName::Panic);
analyze_arguments(from.args(), &args_tys, scope)?
}
CallName::Debug => {
// The single argument has the expected output type.
let args_tys = [ty.clone()];
check_argument_types(from.args(), &args_tys).with_span(from)?;
let args = analyze_arguments(from.args(), &args_tys, scope)?;
let [arg_ty] = args_tys;
scope.track_call(from, TrackedCallName::Debug(arg_ty));
args
}
CallName::TypeCast(source) => {
// A cast is valid only if source and target have the same structural type.
if StructuralType::from(&source) != StructuralType::from(ty) {
return Err(Error::InvalidCast(source, ty.clone())).with_span(from);
}
let args_tys = [source];
check_argument_types(from.args(), &args_tys).with_span(from)?;
analyze_arguments(from.args(), &args_tys, scope)?
}
CallName::Custom(function) => {
// Argument types are the parameter types; the output type is the body type.
let args_ty = function
.params()
.iter()
.map(FunctionParam::ty)
.cloned()
.collect::<Vec<ResolvedType>>();
check_argument_types(from.args(), &args_ty).with_span(from)?;
let out_ty = function.body().ty();
check_output_type(out_ty, ty).with_span(from)?;
analyze_arguments(from.args(), &args_ty, scope)?
}
CallName::Fold(function, bound) => {
// The first parameter is the list element, the second the accumulator.
// `CallName::analyze` has verified that there are exactly two parameters.
let element_ty = function.params().first().expect("foldable function").ty();
let list_ty = ResolvedType::list(element_ty.clone(), bound);
let accumulator_ty = function
.params()
.get(1)
.expect("foldable function")
.ty()
.clone();
let args_ty = [list_ty, accumulator_ty];
check_argument_types(from.args(), &args_ty).with_span(from)?;
let out_ty = function.body().ty();
check_output_type(out_ty, ty).with_span(from)?;
analyze_arguments(from.args(), &args_ty, scope)?
}
CallName::ForWhile(function, _bit_width) => {
// The call passes the accumulator and the context (first and second
// parameter). The third parameter of the function is not passed by the caller.
let accumulator_ty = function
.params()
.first()
.expect("loopable function")
.ty()
.clone();
let context_ty = function
.params()
.get(1)
.expect("loopable function")
.ty()
.clone();
let args_ty = [accumulator_ty, context_ty];
check_argument_types(from.args(), &args_ty).with_span(from)?;
let out_ty = function.body().ty();
check_output_type(out_ty, ty).with_span(from)?;
analyze_arguments(from.args(), &args_ty, scope)?
}
};
Ok(Self {
name,
args,
span: *from.as_ref(),
})
}
}
impl AbstractSyntaxTree for CallName {
type From = parse::Call;
fn analyze(
from: &Self::From,
_ty: &ResolvedType,
scope: &mut Scope,
) -> Result<Self, RichError> {
match from.name() {
parse::CallName::Jet(name) => match Elements::from_str(name.as_inner()) {
Ok(Elements::CheckSigVerify | Elements::Verify) | Err(_) => {
Err(Error::JetDoesNotExist(name.clone())).with_span(from)
}
Ok(jet) => Ok(Self::Jet(jet)),
},
parse::CallName::UnwrapLeft(right_ty) => scope
.resolve(right_ty)
.map(Self::UnwrapLeft)
.with_span(from),
parse::CallName::UnwrapRight(left_ty) => scope
.resolve(left_ty)
.map(Self::UnwrapRight)
.with_span(from),
parse::CallName::IsNone(some_ty) => {
scope.resolve(some_ty).map(Self::IsNone).with_span(from)
}
parse::CallName::Unwrap => Ok(Self::Unwrap),
parse::CallName::Assert => Ok(Self::Assert),
parse::CallName::Panic => Ok(Self::Panic),
parse::CallName::Debug => Ok(Self::Debug),
parse::CallName::TypeCast(target) => {
scope.resolve(target).map(Self::TypeCast).with_span(from)
}
parse::CallName::Custom(name) => scope
.get_function(name)
.cloned()
.map(Self::Custom)
.ok_or(Error::FunctionUndefined(name.clone()))
.with_span(from),
parse::CallName::Fold(name, bound) => {
let function = scope
.get_function(name)
.cloned()
.ok_or(Error::FunctionUndefined(name.clone()))
.with_span(from)?;
if function.params().len() != 2 || function.params()[1].ty() != function.body().ty()
{
Err(Error::FunctionNotFoldable(name.clone())).with_span(from)
} else {
Ok(Self::Fold(function, *bound))
}
}
parse::CallName::ForWhile(name) => {
let function = scope
.get_function(name)
.cloned()
.ok_or(Error::FunctionUndefined(name.clone()))
.with_span(from)?;
if function.params().len() != 3 {
return Err(Error::FunctionNotLoopable(name.clone())).with_span(from);
}
match function.body().ty().as_either() {
Some((_, out_r)) if out_r == function.params().first().unwrap().ty() => {}
_ => {
return Err(Error::FunctionNotLoopable(name.clone())).with_span(from);
}
}
match function.params().get(2).unwrap().ty().as_integer() {
Some(
int_ty @ (UIntType::U1
| UIntType::U2
| UIntType::U4
| UIntType::U8
| UIntType::U16),
) => Ok(Self::ForWhile(function, int_ty.bit_width())),
_ => Err(Error::FunctionNotLoopable(name.clone())).with_span(from),
}
}
}
}
}
impl AbstractSyntaxTree for Match {
    type From = parse::Match;

    /// Analyzes a match expression against the expected type `ty`.
    ///
    /// The scrutinee is analyzed against the type that the parse tree reports
    /// for it. Each arm is then analyzed against `ty` inside its own variable
    /// scope, in which the typed variable of the arm's pattern (if any) is
    /// bound.
    fn analyze(from: &Self::From, ty: &ResolvedType, scope: &mut Scope) -> Result<Self, RichError> {
        let scrutinee_ty = scope.resolve(&from.scrutinee_type()).with_span(from)?;
        let scrutinee = Arc::new(Expression::analyze(from.scrutinee(), &scrutinee_ty, scope)?);

        let parse_left = from.left();
        let parse_right = from.right();

        // Left arm, analyzed in its own scope.
        scope.push_scope();
        if let Some((identifier, aliased)) = parse_left.pattern().as_typed_variable() {
            let variable_ty = scope.resolve(aliased).with_span(from)?;
            scope.insert_variable(identifier.clone(), variable_ty);
        }
        let left_body = Arc::new(Expression::analyze(parse_left.expression(), ty, scope)?);
        scope.pop_scope();

        // Right arm, analyzed in its own scope.
        scope.push_scope();
        if let Some((identifier, aliased)) = parse_right.pattern().as_typed_variable() {
            let variable_ty = scope.resolve(aliased).with_span(from)?;
            scope.insert_variable(identifier.clone(), variable_ty);
        }
        let right_body = Arc::new(Expression::analyze(parse_right.expression(), ty, scope)?);
        scope.pop_scope();

        let left = MatchArm {
            pattern: parse_left.pattern().clone(),
            expression: left_body,
        };
        let right = MatchArm {
            pattern: parse_right.pattern().clone(),
            expression: right_body,
        };
        Ok(Self {
            scrutinee,
            left,
            right,
            span: *from.as_ref(),
        })
    }
}
fn analyze_named_module(
name: ModuleName,
from: &parse::ModuleProgram,
) -> Result<HashMap<WitnessName, Value>, RichError> {
let unit = ResolvedType::unit();
let mut scope = Scope::default();
let items = from
.items()
.iter()
.map(|s| ModuleItem::analyze(s, &unit, &mut scope))
.collect::<Result<Vec<ModuleItem>, RichError>>()?;
debug_assert!(scope.is_topmost());
let mut iter = items.into_iter().filter_map(|item| match item {
ModuleItem::Module(module) if module.name == name => Some(module),
_ => None,
});
let witness_module = iter
.next()
.ok_or(Error::ModuleRequired(name.shallow_clone()))
.with_span(from)?;
if iter.next().is_some() {
return Err(Error::ModuleRedefined(name)).with_span(from);
}
let mut map = HashMap::new();
for assignment in witness_module.assignments() {
if map.contains_key(assignment.name()) {
return Err(Error::WitnessReassigned(assignment.name().shallow_clone()))
.with_span(assignment);
}
map.insert(
assignment.name().shallow_clone(),
assignment.value().clone(),
);
}
Ok(map)
}
impl WitnessValues {
    /// Extracts the witness values from a module program, i.e. the
    /// assignments of the module named by `ModuleName::witness()`.
    pub fn analyze(from: &parse::ModuleProgram) -> Result<Self, RichError> {
        let assignments = analyze_named_module(ModuleName::witness(), from)?;
        Ok(Self::from(assignments))
    }
}
impl crate::witness::Arguments {
    /// Extracts the arguments from a module program, i.e. the assignments
    /// of the module named by `ModuleName::param()`.
    pub fn analyze(from: &parse::ModuleProgram) -> Result<Self, RichError> {
        let assignments = analyze_named_module(ModuleName::param(), from)?;
        Ok(Self::from(assignments))
    }
}
impl AbstractSyntaxTree for ModuleItem {
    type From = parse::ModuleItem;

    /// Analyzes an item of a module program: ignored items are passed
    /// through, modules are delegated to [`Module::analyze`].
    ///
    /// # Panics
    ///
    /// Panics if `ty` is not the unit type or if `scope` is not topmost.
    fn analyze(from: &Self::From, ty: &ResolvedType, scope: &mut Scope) -> Result<Self, RichError> {
        assert!(ty.is_unit(), "Items cannot return anything");
        assert!(scope.is_topmost(), "Items live in the topmost scope only");
        let item = match from {
            parse::ModuleItem::Ignored => Self::Ignored,
            parse::ModuleItem::Module(witness_module) => {
                Self::Module(Module::analyze(witness_module, ty, scope)?)
            }
        };
        Ok(item)
    }
}
impl AbstractSyntaxTree for Module {
type From = parse::Module;
fn analyze(from: &Self::From, ty: &ResolvedType, scope: &mut Scope) -> Result<Self, RichError> {
assert!(ty.is_unit(), "Modules cannot return anything");
assert!(scope.is_topmost(), "Modules live in the topmost scope only");
let assignments = from
.assignments()
.iter()
.map(|s| ModuleAssignment::analyze(s, ty, scope))
.collect::<Result<Arc<[ModuleAssignment]>, RichError>>()?;
debug_assert!(scope.is_topmost());
Ok(Self {
name: from.name().shallow_clone(),
span: *from.as_ref(),
assignments,
})
}
}
impl AbstractSyntaxTree for ModuleAssignment {
type From = parse::ModuleAssignment;
fn analyze(from: &Self::From, ty: &ResolvedType, scope: &mut Scope) -> Result<Self, RichError> {
assert!(ty.is_unit(), "Assignments cannot return anything");
let ty_expr = scope.resolve(from.ty()).with_span(from)?;
let expression = Expression::analyze(from.expression(), &ty_expr, scope)?;
let value = Value::from_const_expr(&expression)
.ok_or(Error::ExpressionUnexpectedType(ty_expr.clone()))
.with_span(from.expression())?;
Ok(Self {
name: from.name().clone(),
value,
span: *from.as_ref(),
})
}
}
// Span access for the AST nodes that store their source span.
// Each implementation returns a reference to the node's `span` field.
impl AsRef<Span> for Assignment {
/// Returns the source span of the assignment.
fn as_ref(&self) -> &Span {
&self.span
}
}
impl AsRef<Span> for Expression {
/// Returns the source span of the expression.
fn as_ref(&self) -> &Span {
&self.span
}
}
impl AsRef<Span> for SingleExpression {
/// Returns the source span of the single expression.
fn as_ref(&self) -> &Span {
&self.span
}
}
impl AsRef<Span> for Call {
/// Returns the source span of the call.
fn as_ref(&self) -> &Span {
&self.span
}
}
impl AsRef<Span> for Match {
/// Returns the source span of the match expression.
fn as_ref(&self) -> &Span {
&self.span
}
}
impl AsRef<Span> for Module {
/// Returns the source span of the module.
fn as_ref(&self) -> &Span {
&self.span
}
}
impl AsRef<Span> for ModuleAssignment {
/// Returns the source span of the module assignment.
fn as_ref(&self) -> &Span {
&self.span
}
}