use crate::{Opcode, Operand, RegistersCircuit, RegistersTrait, StackTrait, register_types_equivalent};
use console::{
network::prelude::*,
program::{Register, RegisterType},
};
use snarkvm_synthesizer_error::*;
/// Internal selector for the comparison an `AssertInstruction` performs.
///
/// The discriminants are written out explicitly because the instruction logic
/// matches on the raw `u8` values (`0` for equality, `1` for inequality).
enum Variant {
    AssertEq = 0,
    AssertNeq = 1,
}

/// An assertion over two operands; the `VARIANT` const parameter selects the
/// comparison (`0` = equality, `1` = inequality, mirroring `Variant`).
#[derive(Clone, PartialEq, Eq, Hash)]
pub struct AssertInstruction<N: Network, const VARIANT: u8> {
    /// The operands being compared (the constructors in this file always produce two).
    operands: Vec<Operand<N>>,
}

/// `assert.eq`: succeeds only when both operands are equal.
pub type AssertEq<N> = AssertInstruction<N, { Variant::AssertEq as u8 }>;
/// `assert.neq`: succeeds only when the operands differ.
pub type AssertNeq<N> = AssertInstruction<N, { Variant::AssertNeq as u8 }>;
impl<N: Network, const VARIANT: u8> AssertInstruction<N, VARIANT> {
#[inline]
pub fn new(operands: Vec<Operand<N>>) -> Result<Self> {
ensure!(operands.len() == 2, "Assert instructions must have two operands");
Ok(Self { operands })
}
#[inline]
pub const fn opcode() -> Opcode {
match VARIANT {
0 => Opcode::Assert("assert.eq"),
1 => Opcode::Assert("assert.neq"),
_ => panic!("Invalid 'assert' instruction opcode"),
}
}
#[inline]
pub fn operands(&self) -> &[Operand<N>] {
debug_assert!(self.operands.len() == 2, "Assert operations must have two operands");
&self.operands
}
#[inline]
pub fn destinations(&self) -> Vec<Register<N>> {
vec![]
}
#[inline]
pub fn contains_external_struct(&self) -> bool {
false
}
}
impl<N: Network, const VARIANT: u8> AssertInstruction<N, VARIANT> {
pub fn evaluate(
&self,
stack: &impl StackTrait<N>,
registers: &mut impl RegistersTrait<N>,
) -> Result<(), EvalError> {
if self.operands.len() != 2 {
return Err(anyhow!(
"Instruction '{}' expects 2 operands, found {} operands",
Self::opcode(),
self.operands.len()
)
.into());
}
let input_a = registers.load(stack, &self.operands[0])?;
let input_b = registers.load(stack, &self.operands[1])?;
match VARIANT {
0 => {
if input_a != input_b {
return Err(AssertError::Eq { lhs: format!("{input_a}"), rhs: format!("{input_b}") }.into());
}
}
1 => {
if input_a == input_b {
return Err(AssertError::Neq { lhs: format!("{input_a}"), rhs: format!("{input_b}") }.into());
}
}
_ => return Err(AssertError::Invalid { variant: VARIANT }.into()),
}
Ok(())
}
pub fn execute<A: circuit::Aleo<Network = N>>(
&self,
stack: &impl StackTrait<N>,
registers: &mut impl RegistersCircuit<N, A>,
) -> Result<(), ExecError> {
if self.operands.len() != 2 {
return Err(anyhow!(
"Instruction '{}' expects 2 operands, found {} operands",
Self::opcode(),
self.operands.len()
)
.into());
}
let input_a = registers.load_circuit(stack, &self.operands[0])?;
let input_b = registers.load_circuit(stack, &self.operands[1])?;
match VARIANT {
0 => A::assert(input_a.is_equal(&input_b))?,
1 => A::assert(input_a.is_not_equal(&input_b))?,
_ => return Err(anyhow!("Invalid 'assert' variant: {VARIANT}").into()),
}
Ok(())
}
#[inline]
pub fn finalize(
&self,
stack: &impl StackTrait<N>,
registers: &mut impl RegistersTrait<N>,
) -> Result<(), FinalizeError> {
self.evaluate(stack, registers)?;
Ok(())
}
pub fn output_types(
&self,
stack: &impl StackTrait<N>,
input_types: &[RegisterType<N>],
) -> Result<Vec<RegisterType<N>>> {
if input_types.len() != 2 {
bail!("Instruction '{}' expects 2 inputs, found {} inputs", Self::opcode(), input_types.len())
}
if !register_types_equivalent(stack, &input_types[0], stack, &input_types[1])? {
bail!(
"Instruction '{}' expects inputs of equivalent types. Found inputs of type '{}' and '{}'",
Self::opcode(),
input_types[0],
input_types[1]
)
}
if self.operands.len() != 2 {
bail!("Instruction '{}' expects 2 operands, found {} operands", Self::opcode(), self.operands.len())
}
match VARIANT {
0 | 1 => Ok(vec![]),
_ => bail!("Invalid 'assert' variant: {VARIANT}"),
}
}
}
impl<N: Network, const VARIANT: u8> Parser for AssertInstruction<N, VARIANT> {
    /// Parses `<opcode> <first> <second>`, where each separator is whatever
    /// `Sanitizer::parse_whitespaces` accepts. Any unconsumed input is returned
    /// alongside the instruction.
    fn parse(string: &str) -> ParserResult<Self> {
        // The input must begin with this variant's opcode.
        let (rest, _) = tag(*Self::opcode())(string)?;
        // Separator, then the left-hand operand.
        let (rest, _) = Sanitizer::parse_whitespaces(rest)?;
        let (rest, lhs) = Operand::parse(rest)?;
        // Separator, then the right-hand operand.
        let (rest, _) = Sanitizer::parse_whitespaces(rest)?;
        let (rest, rhs) = Operand::parse(rest)?;
        Ok((rest, Self { operands: vec![lhs, rhs] }))
    }
}
impl<N: Network, const VARIANT: u8> FromStr for AssertInstruction<N, VARIANT> {
    type Err = Error;

    /// Parses a complete instruction string.
    ///
    /// # Errors
    /// Fails if parsing fails or if any input is left over after the instruction.
    fn from_str(string: &str) -> Result<Self> {
        let (remainder, object) = match Self::parse(string) {
            Ok(parsed) => parsed,
            Err(error) => bail!("Failed to parse string. {error}"),
        };
        ensure!(remainder.is_empty(), "Failed to parse string. Found invalid character in: \"{remainder}\"");
        Ok(object)
    }
}
impl<N: Network, const VARIANT: u8> Debug for AssertInstruction<N, VARIANT> {
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
Display::fmt(self, f)
}
}
impl<N: Network, const VARIANT: u8> Display for AssertInstruction<N, VARIANT> {
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
if self.operands.len() != 2 {
return Err(fmt::Error);
}
write!(f, "{}", Self::opcode())?;
self.operands.iter().try_for_each(|operand| write!(f, " {operand}"))
}
}
impl<N: Network, const VARIANT: u8> FromBytes for AssertInstruction<N, VARIANT> {
    /// Reads the two operands from `reader`, first then second.
    fn read_le<R: Read>(mut reader: R) -> IoResult<Self> {
        let first = Operand::read_le(&mut reader)?;
        let second = Operand::read_le(&mut reader)?;
        Ok(Self { operands: vec![first, second] })
    }
}
impl<N: Network, const VARIANT: u8> ToBytes for AssertInstruction<N, VARIANT> {
fn write_le<W: Write>(&self, mut writer: W) -> IoResult<()> {
if self.operands.len() != 2 {
return Err(error(format!("The number of operands must be 2, found {}", self.operands.len())));
}
self.operands.iter().try_for_each(|operand| operand.write_le(&mut writer))
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use console::network::MainnetV0;

    type CurrentNetwork = MainnetV0;

    /// Checks that `operands` is exactly `[r0, r1]`.
    fn check_r0_r1(operands: &[Operand<CurrentNetwork>]) {
        assert_eq!(operands.len(), 2, "The number of operands is incorrect");
        assert_eq!(operands[0], Operand::Register(Register::Locator(0)), "The first operand is incorrect");
        assert_eq!(operands[1], Operand::Register(Register::Locator(1)), "The second operand is incorrect");
    }

    #[test]
    fn test_parse() {
        let (string, assert) = AssertEq::<CurrentNetwork>::parse("assert.eq r0 r1").unwrap();
        assert!(string.is_empty(), "Parser did not consume all of the string: '{string}'");
        check_r0_r1(&assert.operands);

        let (string, assert) = AssertNeq::<CurrentNetwork>::parse("assert.neq r0 r1").unwrap();
        assert!(string.is_empty(), "Parser did not consume all of the string: '{string}'");
        check_r0_r1(&assert.operands);
    }

    #[test]
    fn test_parse_rejects_other_variant() {
        // Each alias only accepts its own opcode.
        assert!(AssertEq::<CurrentNetwork>::parse("assert.neq r0 r1").is_err());
        assert!(AssertNeq::<CurrentNetwork>::parse("assert.eq r0 r1").is_err());
    }

    #[test]
    fn test_display_round_trip() {
        let assert = AssertEq::<CurrentNetwork>::from_str("assert.eq r0 r1").unwrap();
        assert_eq!(assert.to_string(), "assert.eq r0 r1");

        let assert = AssertNeq::<CurrentNetwork>::from_str("assert.neq r0 r1").unwrap();
        assert_eq!(assert.to_string(), "assert.neq r0 r1");
    }

    #[test]
    fn test_bytes_round_trip() {
        let expected = AssertEq::<CurrentNetwork>::from_str("assert.eq r0 r1").unwrap();
        let bytes = expected.to_bytes_le().unwrap();
        let candidate = AssertEq::<CurrentNetwork>::read_le(&bytes[..]).unwrap();
        assert_eq!(expected, candidate);
    }

    #[test]
    fn test_new_requires_two_operands() {
        let r0 = Operand::Register(Register::Locator(0));
        let r1 = Operand::Register(Register::Locator(1));
        assert!(AssertEq::<CurrentNetwork>::new(vec![]).is_err());
        assert!(AssertEq::<CurrentNetwork>::new(vec![r0.clone()]).is_err());
        assert!(AssertEq::<CurrentNetwork>::new(vec![r0.clone(), r1.clone(), r0.clone()]).is_err());
        assert!(AssertEq::<CurrentNetwork>::new(vec![r0, r1]).is_ok());
    }
}