use crate::{Opcode, Operand, Registers, Stack};
use console::{
network::prelude::*,
program::{Register, RegisterType},
};
/// The `assert.eq` instruction: fails unless its two operands are equal (`VARIANT` = 0).
pub type AssertEq<N> = AssertInstruction<N, { Variant::AssertEq as u8 }>;
/// The `assert.neq` instruction: fails unless its two operands differ (`VARIANT` = 1).
pub type AssertNeq<N> = AssertInstruction<N, { Variant::AssertNeq as u8 }>;
/// The comparison performed by an `AssertInstruction`, encoded as its `VARIANT` const parameter.
///
/// The discriminants are spelled out (rather than left implicit) because the instruction's
/// methods match on the raw `u8` values `0` and `1`; reordering these variants must never
/// silently swap the meaning of `assert.eq` and `assert.neq`.
#[repr(u8)]
enum Variant {
    /// `assert.eq`: the two operands must be equal.
    AssertEq = 0,
    /// `assert.neq`: the two operands must not be equal.
    AssertNeq = 1,
}
/// An assertion over two operands. `VARIANT` selects the comparison:
/// `0` is `assert.eq` and `1` is `assert.neq` (see `opcode`).
#[derive(Clone, PartialEq, Eq, Hash)]
pub struct AssertInstruction<N: Network, const VARIANT: u8> {
    /// The two operands being compared. The `Vec` type does not enforce the count of 2;
    /// the methods below check the length where it matters.
    operands: Vec<Operand<N>>,
}
impl<N: Network, const VARIANT: u8> AssertInstruction<N, VARIANT> {
    /// Returns the opcode for this instruction's `VARIANT`.
    ///
    /// # Panics
    /// Panics if `VARIANT` is neither `0` (`assert.eq`) nor `1` (`assert.neq`).
    #[inline]
    pub const fn opcode() -> Opcode {
        if VARIANT == 0 {
            Opcode::Assert("assert.eq")
        } else if VARIANT == 1 {
            Opcode::Assert("assert.neq")
        } else {
            panic!("Invalid 'assert' instruction opcode")
        }
    }

    /// Returns the operands being compared (expected to be exactly two).
    #[inline]
    pub fn operands(&self) -> &[Operand<N>] {
        let operands = self.operands.as_slice();
        // Only checked in debug builds; release builds return whatever is stored.
        debug_assert!(operands.len() == 2, "Assert operations must have two operands");
        operands
    }

    /// Returns the destination registers, which is always empty: an assertion produces no output.
    #[inline]
    pub fn destinations(&self) -> Vec<Register<N>> {
        Vec::new()
    }
}
impl<N: Network, const VARIANT: u8> AssertInstruction<N, VARIANT> {
    /// Evaluates the assertion on console values loaded from `registers`.
    ///
    /// `assert.eq` (`VARIANT` 0) fails if the two inputs differ; `assert.neq` (`VARIANT` 1)
    /// fails if they are equal.
    ///
    /// # Errors
    /// Returns an error if there are not exactly 2 operands, if loading an operand fails,
    /// if the assertion does not hold, or if `VARIANT` is not `0` or `1`.
    #[inline]
    pub fn evaluate<A: circuit::Aleo<Network = N>>(
        &self,
        stack: &Stack<N>,
        registers: &mut Registers<N, A>,
    ) -> Result<()> {
        // Ensure the number of operands is correct.
        if self.operands.len() != 2 {
            bail!("Instruction '{}' expects 2 operands, found {} operands", Self::opcode(), self.operands.len())
        }
        // Load the operands in order. Indexing is in bounds after the check above, and loading
        // each value directly avoids an intermediate `Vec` and cloning both values back out of it.
        let input_a = registers.load(stack, &self.operands[0])?;
        let input_b = registers.load(stack, &self.operands[1])?;
        // Check the assertion.
        match VARIANT {
            0 => {
                if input_a != input_b {
                    bail!("'{}' failed: '{input_a}' is not equal to '{input_b}' (should be equal)", Self::opcode())
                }
            }
            1 => {
                if input_a == input_b {
                    bail!("'{}' failed: '{input_a}' is equal to '{input_b}' (should not be equal)", Self::opcode())
                }
            }
            _ => bail!("Invalid 'assert' variant: {VARIANT}"),
        }
        Ok(())
    }

    /// Enforces the assertion as a constraint on circuit values loaded from `registers`.
    ///
    /// A false assertion does not return an error here: it is passed to `A::assert`, which
    /// records it in the circuit (the tests expect the circuit to become unsatisfied).
    ///
    /// # Errors
    /// Returns an error if there are not exactly 2 operands, if loading an operand fails,
    /// or if `VARIANT` is not `0` or `1`.
    #[inline]
    pub fn execute<A: circuit::Aleo<Network = N>>(
        &self,
        stack: &Stack<N>,
        registers: &mut Registers<N, A>,
    ) -> Result<()> {
        // Ensure the number of operands is correct.
        if self.operands.len() != 2 {
            bail!("Instruction '{}' expects 2 operands, found {} operands", Self::opcode(), self.operands.len())
        }
        // Load the circuit operands in order (in bounds after the check above; no clones needed).
        let input_a = registers.load_circuit(stack, &self.operands[0])?;
        let input_b = registers.load_circuit(stack, &self.operands[1])?;
        // Enforce the assertion.
        match VARIANT {
            0 => A::assert(input_a.is_equal(&input_b)),
            1 => A::assert(input_a.is_not_equal(&input_b)),
            _ => bail!("Invalid 'assert' variant: {VARIANT}"),
        }
        Ok(())
    }

    /// Type-checks the instruction. It returns an empty list because an assertion writes no registers.
    ///
    /// # Errors
    /// Returns an error if there are not exactly 2 input types, if the two input types differ,
    /// if there are not exactly 2 operands, or if `VARIANT` is not `0` or `1`.
    #[inline]
    pub fn output_types(&self, _stack: &Stack<N>, input_types: &[RegisterType<N>]) -> Result<Vec<RegisterType<N>>> {
        // Ensure the number of input types is correct.
        if input_types.len() != 2 {
            bail!("Instruction '{}' expects 2 inputs, found {} inputs", Self::opcode(), input_types.len())
        }
        // Ensure both inputs have the same type, since only like types are compared.
        if input_types[0] != input_types[1] {
            bail!(
                "Instruction '{}' expects inputs of the same type. Found inputs of type '{}' and '{}'",
                Self::opcode(),
                input_types[0],
                input_types[1]
            )
        }
        // Ensure the number of operands is correct.
        if self.operands.len() != 2 {
            bail!("Instruction '{}' expects 2 operands, found {} operands", Self::opcode(), self.operands.len())
        }
        match VARIANT {
            0 | 1 => Ok(vec![]),
            _ => bail!("Invalid 'assert' variant: {VARIANT}"),
        }
    }
}
impl<N: Network, const VARIANT: u8> Parser for AssertInstruction<N, VARIANT> {
    /// Parses `<opcode> <operand> <operand>` and returns any unconsumed input as the remainder.
    #[inline]
    fn parse(string: &str) -> ParserResult<Self> {
        // The opcode must come first.
        let (rest, _) = tag(*Self::opcode())(string)?;
        // First operand, separated from the opcode by whitespace.
        let (rest, _) = Sanitizer::parse_whitespaces(rest)?;
        let (rest, operand_a) = Operand::parse(rest)?;
        // Second operand, separated from the first by whitespace.
        let (rest, _) = Sanitizer::parse_whitespaces(rest)?;
        let (rest, operand_b) = Operand::parse(rest)?;
        let operands = vec![operand_a, operand_b];
        Ok((rest, Self { operands }))
    }
}
impl<N: Network, const VARIANT: u8> FromStr for AssertInstruction<N, VARIANT> {
    type Err = Error;

    /// Parses the whole string as an assert instruction.
    ///
    /// # Errors
    /// Fails if parsing fails, or if any input is left over after the instruction.
    #[inline]
    fn from_str(string: &str) -> Result<Self> {
        let (remainder, object) = match Self::parse(string) {
            Ok(parsed) => parsed,
            Err(error) => bail!("Failed to parse string. {error}"),
        };
        // The entire input must be consumed.
        ensure!(remainder.is_empty(), "Failed to parse string. Found invalid character in: \"{remainder}\"");
        Ok(object)
    }
}
impl<N: Network, const VARIANT: u8> Debug for AssertInstruction<N, VARIANT> {
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
Display::fmt(self, f)
}
}
impl<N: Network, const VARIANT: u8> Display for AssertInstruction<N, VARIANT> {
    /// Prints the instruction as `<opcode> <operand_a> <operand_b>`, e.g. `assert.eq r0 r1`.
    ///
    /// No trailing whitespace is emitted, so the output parses back through `FromStr`, which
    /// rejects any unconsumed remainder. Returns `fmt::Error` if there are not exactly 2 operands.
    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
        // Ensure the number of operands is 2.
        if self.operands.len() != 2 {
            eprintln!("The number of operands must be 2, found {}", self.operands.len());
            return Err(fmt::Error);
        }
        // Opcode and both operands, separated by single spaces (indexing is in bounds after the check).
        write!(f, "{} {} {}", Self::opcode(), self.operands[0], self.operands[1])
    }
}
impl<N: Network, const VARIANT: u8> FromBytes for AssertInstruction<N, VARIANT> {
    /// Reads the two operands, in order, from their little-endian byte encoding.
    fn read_le<R: Read>(mut reader: R) -> IoResult<Self> {
        let operand_a = Operand::read_le(&mut reader)?;
        let operand_b = Operand::read_le(&mut reader)?;
        Ok(Self { operands: vec![operand_a, operand_b] })
    }
}
impl<N: Network, const VARIANT: u8> ToBytes for AssertInstruction<N, VARIANT> {
fn write_le<W: Write>(&self, mut writer: W) -> IoResult<()> {
if self.operands.len() != 2 {
return Err(error(format!("The number of operands must be 2, found {}", self.operands.len())));
}
self.operands.iter().try_for_each(|operand| operand.write_le(&mut writer))
}
}
// Unit tests for `AssertEq` / `AssertNeq`: console evaluation, circuit execution and
// satisfiability, rejection of operands with different types, and parsing.
#[cfg(test)]
mod tests {
use super::*;
use circuit::AleoV0;
use console::{
network::Testnet3,
program::{Literal, LiteralType},
};
type CurrentNetwork = Testnet3;
type CurrentAleo = AleoV0;
/// Builds a `Stack` for a one-function program `testing.aleo`. Its function `run` takes
/// `r0` as `{type_a}.{mode_a}` and `r1` as `{type_b}.{mode_b}`, then applies `{opcode} r0 r1`.
///
/// Returns the stack together with the operands `[r0, r1]`. Errors from parsing the program
/// or constructing the stack are propagated; `check_assert_fails` expects an error here
/// when the two input types differ.
fn sample_stack(
opcode: Opcode,
type_a: LiteralType,
type_b: LiteralType,
mode_a: circuit::Mode,
mode_b: circuit::Mode,
) -> Result<(Stack<CurrentNetwork>, Vec<Operand<CurrentNetwork>>)> {
use crate::{Process, Program};
use console::program::Identifier;
let opcode = opcode.to_string();
let function_name = Identifier::<CurrentNetwork>::from_str("run")?;
let r0 = Register::Locator(0);
let r1 = Register::Locator(1);
// The multi-line literal below is program source; its line breaks are part of the string.
let program = Program::from_str(&format!(
"program testing.aleo;
function {function_name}:
input {r0} as {type_a}.{mode_a};
input {r1} as {type_b}.{mode_b};
{opcode} {r0} {r1};
"
))?;
let operand_a = Operand::Register(r0);
let operand_b = Operand::Register(r1);
let operands = vec![operand_a, operand_b];
let stack = Stack::new(&Process::load()?, &program)?;
Ok((stack, operands))
}
/// Creates registers for the `run` function of `stack`, storing `literal_a` in `r0` and
/// `literal_b` in `r1` as console values.
///
/// If both `mode_a` and `mode_b` are `Some`, the same values are also injected as circuit
/// values with those modes and stored via `store_circuit`, so that `execute` can load them.
fn sample_registers(
stack: &Stack<CurrentNetwork>,
literal_a: &Literal<CurrentNetwork>,
literal_b: &Literal<CurrentNetwork>,
mode_a: Option<circuit::Mode>,
mode_b: Option<circuit::Mode>,
) -> Result<Registers<CurrentNetwork, CurrentAleo>> {
use crate::{Authorization, CallStack};
use console::program::{Identifier, Plaintext, Value};
let function_name = Identifier::from_str("run")?;
let mut registers = Registers::<CurrentNetwork, CurrentAleo>::new(
CallStack::evaluate(Authorization::new(&[]))?,
stack.get_register_types(&function_name)?.clone(),
);
let r0 = Register::Locator(0);
let r1 = Register::Locator(1);
let value_a = Value::Plaintext(Plaintext::from(literal_a));
let value_b = Value::Plaintext(Plaintext::from(literal_b));
registers.store(stack, &r0, value_a.clone())?;
registers.store(stack, &r1, value_b.clone())?;
if let (Some(mode_a), Some(mode_b)) = (mode_a, mode_b) {
// Brought into scope for `circuit::Value::new(mode, value)` below.
use circuit::Inject;
let circuit_a = circuit::Value::new(mode_a, value_a);
let circuit_b = circuit::Value::new(mode_b, value_b);
registers.store_circuit(stack, &r0, circuit_a)?;
registers.store_circuit(stack, &r1, circuit_b)?;
}
Ok(registers)
}
/// Checks the instruction built by `operation` against `literal_a` and `literal_b`, which must
/// share a type. Both console (`evaluate`) and circuit (`execute`) forms are checked.
///
/// - Inputs `(literal_a, literal_a)`: `assert.eq` (VARIANT 0) must succeed. `assert.neq`
///   (VARIANT 1) must return an error from `evaluate`. Its `execute` must still return `Ok`,
///   but leave the circuit unsatisfied.
/// - Inputs `(literal_a, literal_b)`, only when the two differ: the expectations are reversed.
///
/// The circuit environment is reset after each circuit check.
fn check_assert<const VARIANT: u8>(
operation: impl FnOnce(Vec<Operand<CurrentNetwork>>) -> AssertInstruction<CurrentNetwork, VARIANT>,
opcode: Opcode,
literal_a: &Literal<CurrentNetwork>,
literal_b: &Literal<CurrentNetwork>,
mode_a: &circuit::Mode,
mode_b: &circuit::Mode,
) {
println!("Checking '{opcode}' for '{literal_a}.{mode_a}' and '{literal_b}.{mode_b}'");
let type_a = literal_a.to_type();
let type_b = literal_b.to_type();
assert_eq!(type_a, type_b, "The two literals must be the *same* type for this test");
let (stack, operands) = sample_stack(opcode, type_a, type_b, *mode_a, *mode_b).unwrap();
let operation = operation(operands);
// Case 1: both registers hold the same value (`literal_a`).
{
// Console evaluation.
let mut registers = sample_registers(&stack, literal_a, literal_a, None, None).unwrap();
let result_a = operation.evaluate(&stack, &mut registers);
match VARIANT {
0 => assert!(result_a.is_ok(), "Instruction '{operation}' failed (console): {literal_a} {literal_a}"),
1 => assert!(
result_a.is_err(),
"Instruction '{operation}' should have failed (console): {literal_a} {literal_a}"
),
_ => panic!("Found an invalid 'assert' variant in the test"),
}
// Circuit execution: `execute` returns `Ok` for both variants; the assertion outcome
// is observed through circuit satisfiability instead.
let mut registers = sample_registers(&stack, literal_a, literal_a, Some(*mode_a), Some(*mode_a)).unwrap();
let result_b = operation.execute::<CurrentAleo>(&stack, &mut registers);
match VARIANT {
0 => assert!(
result_b.is_ok(),
"Instruction '{operation}' failed (circuit): {literal_a}.{mode_a} {literal_a}.{mode_a}"
),
1 => assert!(
result_b.is_ok(), "Instruction '{operation}' should not have panicked (circuit): {literal_a}.{mode_a} {literal_a}.{mode_a}"
),
_ => panic!("Found an invalid 'assert' variant in the test"),
}
match VARIANT {
0 => assert!(
<CurrentAleo as circuit::Environment>::is_satisfied(),
"Instruction '{operation}' should be satisfied (circuit): {literal_a}.{mode_a} {literal_a}.{mode_a}"
),
1 => assert!(
!<CurrentAleo as circuit::Environment>::is_satisfied(),
"Instruction '{operation}' should not be satisfied (circuit): {literal_a}.{mode_a} {literal_a}.{mode_a}"
),
_ => panic!("Found an invalid 'assert' variant in the test"),
}
<CurrentAleo as circuit::Environment>::reset();
}
// Case 2: the registers hold different values. This is skipped when the two sampled
// literals happen to be equal.
if literal_a != literal_b {
// Console evaluation.
let mut registers = sample_registers(&stack, literal_a, literal_b, None, None).unwrap();
let result_a = operation.evaluate(&stack, &mut registers);
match VARIANT {
0 => assert!(
result_a.is_err(),
"Instruction '{operation}' should have failed (console): {literal_a} {literal_b}"
),
1 => assert!(result_a.is_ok(), "Instruction '{operation}' failed (console): {literal_a} {literal_b}"),
_ => panic!("Found an invalid 'assert' variant in the test"),
}
// Circuit execution, then satisfiability.
let mut registers = sample_registers(&stack, literal_a, literal_b, Some(*mode_a), Some(*mode_b)).unwrap();
let result_b = operation.execute::<CurrentAleo>(&stack, &mut registers);
match VARIANT {
0 => assert!(
result_b.is_ok(), "Instruction '{operation}' should not have panicked (circuit): {literal_a}.{mode_a} {literal_b}.{mode_b}"
),
1 => assert!(
result_b.is_ok(),
"Instruction '{operation}' failed (circuit): {literal_a}.{mode_a} {literal_b}.{mode_b}"
),
_ => panic!("Found an invalid 'assert' variant in the test"),
}
match VARIANT {
0 => assert!(
!<CurrentAleo as circuit::Environment>::is_satisfied(),
"Instruction '{operation}' should not be satisfied (circuit): {literal_a}.{mode_a} {literal_b}.{mode_b}"
),
1 => assert!(
<CurrentAleo as circuit::Environment>::is_satisfied(),
"Instruction '{operation}' should be satisfied (circuit): {literal_a}.{mode_a} {literal_b}.{mode_b}"
),
_ => panic!("Found an invalid 'assert' variant in the test"),
}
<CurrentAleo as circuit::Environment>::reset();
}
}
/// Checks that building a stack for `opcode` fails when the two inputs have *different* types.
fn check_assert_fails(
opcode: Opcode,
literal_a: &Literal<CurrentNetwork>,
literal_b: &Literal<CurrentNetwork>,
mode_a: &circuit::Mode,
mode_b: &circuit::Mode,
) {
let type_a = literal_a.to_type();
let type_b = literal_b.to_type();
assert_ne!(type_a, type_b, "The two literals must be *different* types for this test");
let result = sample_stack(opcode, type_a, type_b, *mode_a, *mode_b);
assert!(
result.is_err(),
"Stack should have failed to initialize for: {opcode} {type_a}.{mode_a} {type_b}.{mode_b}"
);
}
// `assert.eq` over pairs of same-typed sampled literals, for every combination of modes.
#[test]
fn test_assert_eq_succeeds() {
let operation = |operands| AssertEq::<CurrentNetwork> { operands };
let opcode = AssertEq::<CurrentNetwork>::opcode();
let mut rng = test_rng();
let literals_a = crate::sample_literals!(CurrentNetwork, &mut rng);
let literals_b = crate::sample_literals!(CurrentNetwork, &mut rng);
let modes_a = [ circuit::Mode::Public, circuit::Mode::Private];
let modes_b = [ circuit::Mode::Public, circuit::Mode::Private];
for (literal_a, literal_b) in literals_a.iter().zip_eq(literals_b.iter()) {
for mode_a in &modes_a {
for mode_b in &modes_b {
check_assert(operation, opcode, literal_a, literal_b, mode_a, mode_b);
}
}
}
}
// `assert.eq` must be rejected for every pair of differently-typed literals and every mode pair.
#[test]
#[serial_test::serial]
fn test_assert_eq_fails() {
use rayon::prelude::*;
let opcode = AssertEq::<CurrentNetwork>::opcode();
let mut rng = test_rng();
let literals_a = crate::sample_literals!(CurrentNetwork, &mut rng);
let literals_b = crate::sample_literals!(CurrentNetwork, &mut rng);
let modes_a = [ circuit::Mode::Public, circuit::Mode::Private];
let modes_b = [ circuit::Mode::Public, circuit::Mode::Private];
literals_a.par_iter().for_each(|literal_a| {
for literal_b in &literals_b {
if literal_a.to_type() != literal_b.to_type() {
for mode_a in &modes_a {
for mode_b in &modes_b {
check_assert_fails(opcode, literal_a, literal_b, mode_a, mode_b);
}
}
}
}
});
}
// `assert.neq` over pairs of same-typed sampled literals, for every combination of modes.
#[test]
fn test_assert_neq_succeeds() {
let operation = |operands| AssertNeq::<CurrentNetwork> { operands };
let opcode = AssertNeq::<CurrentNetwork>::opcode();
let mut rng = test_rng();
let literals_a = crate::sample_literals!(CurrentNetwork, &mut rng);
let literals_b = crate::sample_literals!(CurrentNetwork, &mut rng);
let modes_a = [ circuit::Mode::Public, circuit::Mode::Private];
let modes_b = [ circuit::Mode::Public, circuit::Mode::Private];
for (literal_a, literal_b) in literals_a.iter().zip_eq(literals_b.iter()) {
for mode_a in &modes_a {
for mode_b in &modes_b {
check_assert(operation, opcode, literal_a, literal_b, mode_a, mode_b);
}
}
}
}
// `assert.neq` must be rejected for every pair of differently-typed literals and every mode pair.
#[test]
#[serial_test::serial]
fn test_assert_neq_fails() {
use rayon::prelude::*;
let opcode = AssertNeq::<CurrentNetwork>::opcode();
let mut rng = test_rng();
let literals_a = crate::sample_literals!(CurrentNetwork, &mut rng);
let literals_b = crate::sample_literals!(CurrentNetwork, &mut rng);
let modes_a = [ circuit::Mode::Public, circuit::Mode::Private];
let modes_b = [ circuit::Mode::Public, circuit::Mode::Private];
literals_a.par_iter().for_each(|literal_a| {
for literal_b in &literals_b {
if literal_a.to_type() != literal_b.to_type() {
for mode_a in &modes_a {
for mode_b in &modes_b {
check_assert_fails(opcode, literal_a, literal_b, mode_a, mode_b);
}
}
}
}
});
}
// Both opcodes parse two register operands and consume the whole input.
#[test]
fn test_parse() {
let (string, assert) = AssertEq::<CurrentNetwork>::parse("assert.eq r0 r1").unwrap();
assert!(string.is_empty(), "Parser did not consume all of the string: '{string}'");
assert_eq!(assert.operands.len(), 2, "The number of operands is incorrect");
assert_eq!(assert.operands[0], Operand::Register(Register::Locator(0)), "The first operand is incorrect");
assert_eq!(assert.operands[1], Operand::Register(Register::Locator(1)), "The second operand is incorrect");
let (string, assert) = AssertNeq::<CurrentNetwork>::parse("assert.neq r0 r1").unwrap();
assert!(string.is_empty(), "Parser did not consume all of the string: '{string}'");
assert_eq!(assert.operands.len(), 2, "The number of operands is incorrect");
assert_eq!(assert.operands[0], Operand::Register(Register::Locator(0)), "The first operand is incorrect");
assert_eq!(assert.operands[1], Operand::Register(Register::Locator(1)), "The second operand is incorrect");
}
}