use crate::{Opcode, Operand, RegistersCircuit, RegistersTrait, StackTrait};
use console::{
network::prelude::*,
program::{Literal, LiteralType, Plaintext, PlaintextType, Register, RegisterType, Value},
types::Boolean,
};
use snarkvm_algorithms::snark::varuna::VarunaVersion;
use snarkvm_synthesizer_snark::{Proof, VerifyingKey};
/// The `snark.verify` instruction: a [`SnarkVerification`] specialised to the
/// [`SnarkVerifyVariant::Varuna`] variant.
pub type SnarkVerify<N> = SnarkVerification<N, { SnarkVerifyVariant::Varuna as u8 }>;
/// The `snark.verify.batch` instruction: a [`SnarkVerification`] specialised to the
/// [`SnarkVerifyVariant::VarunaBatch`] variant.
pub type SnarkVerifyBatch<N> = SnarkVerification<N, { SnarkVerifyVariant::VarunaBatch as u8 }>;
/// The maximum number of circuits that a single batched verification may declare (32).
pub const MAX_SNARK_VERIFY_CIRCUITS: u32 = 1 << 5;

/// The maximum total number of instances that a single batched verification may declare (128).
pub const MAX_SNARK_VERIFY_INSTANCES: u32 = 1 << 7;
/// The members of the `snark.verify` instruction family.
///
/// The numeric discriminant of each variant is what gets passed around as the `VARIANT`
/// const generic, and [`SnarkVerifyVariant::new`] maps it back to the variant.
#[derive(Debug, Clone, Eq, PartialEq)]
pub enum SnarkVerifyVariant {
    /// The single-proof flavour, spelled `snark.verify`.
    Varuna = 0,
    /// The batched flavour, spelled `snark.verify.batch`.
    VarunaBatch = 1,
}

impl SnarkVerifyVariant {
    /// Recovers the variant from its raw discriminant.
    ///
    /// # Panics
    ///
    /// Panics when `variant` does not correspond to any variant. Since this is a `const fn`,
    /// an invalid value used in a constant context is rejected at compile time.
    pub const fn new(variant: u8) -> Self {
        if variant == Self::Varuna as u8 {
            Self::Varuna
        } else if variant == Self::VarunaBatch as u8 {
            Self::VarunaBatch
        } else {
            panic!("Invalid 'snark.verify' instruction opcode")
        }
    }

    /// Returns the textual opcode of this variant.
    pub const fn opcode(&self) -> &'static str {
        match *self {
            Self::VarunaBatch => "snark.verify.batch",
            Self::Varuna => "snark.verify",
        }
    }
}
/// A SNARK verification instruction, parameterised by the network `N` and by `VARIANT`,
/// the `u8` discriminant of a [`SnarkVerifyVariant`].
#[derive(Clone, PartialEq, Eq, Hash)]
pub struct SnarkVerification<N: Network, const VARIANT: u8> {
// The operands, in the order they are read by `finalize`: verifying key(s), Varuna version,
// public inputs, proof. The constructor and the parser always produce exactly four.
operands: Vec<Operand<N>>,
// The register that receives the boolean verification result.
destination: Register<N>,
}
impl<N: Network, const VARIANT: u8> SnarkVerification<N, VARIANT> {
#[inline]
pub fn new(operands: Vec<Operand<N>>, destination: Register<N>) -> Result<Self> {
ensure!(operands.len() == 4, "Instruction '{}' must have four operands", Self::opcode());
Ok(Self { operands, destination })
}
#[inline]
pub const fn opcode() -> Opcode {
Opcode::Snark(SnarkVerifyVariant::new(VARIANT).opcode())
}
#[inline]
pub fn operands(&self) -> &[Operand<N>] {
debug_assert!(self.operands.len() == 4, "Instruction '{}' must have four operands", Self::opcode());
&self.operands
}
#[inline]
pub fn destinations(&self) -> Vec<Register<N>> {
vec![self.destination.clone()]
}
#[inline]
pub fn contains_external_struct(&self) -> bool {
false
}
}
/// Decodes the operands of a `snark.verify` / `snark.verify.batch` instruction and runs the
/// verification selected by `$variant`. The expansion is an expression of type `bool`.
///
/// Requirements on the expansion site:
/// - a network type parameter named `N` must be in scope (the body names `VerifyingKey::<N>`,
///   `Proof::<N>` and `N::Field` directly);
/// - the enclosing function must return a `Result`, because the body uses `?`, `bail!` and
///   `ensure!` to report malformed operands.
///
/// Arguments, in instruction operand order:
/// - `$variant`: the [`SnarkVerifyVariant`] to run;
/// - `$function_name`: a name forwarded verbatim to the verifier;
/// - `$verifying_key`: first operand, a byte array (single) or an array of byte arrays (batch);
/// - `$varuna_version`: second operand, a `u8` literal;
/// - `$inputs`: third operand, an array of fields (single) or a three-dimensional array of
///   fields indexed as `[circuit][instance][input]` (batch);
/// - `$proof`: fourth operand, a byte array.
///
/// Every decoder is a lazily-invoked closure, so only the operands needed by the selected
/// variant are actually decoded.
#[rustfmt::skip]
macro_rules! do_snark_verification {
    ($variant: expr, $function_name: expr, $verifying_key: expr, $varuna_version: expr, $inputs: expr, $proof: expr) => {{
        // Decodes the first operand as a single verifying key.
        let verifying_key = || match $verifying_key {
            Value::Plaintext(plaintext) => VerifyingKey::<N>::from_bytes_le(&plaintext.as_byte_array()?),
            _ => bail!("Expected the first operand to be a byte array."),
        };
        // Decodes the first operand as one verifying key per circuit.
        let verifying_keys = || match $verifying_key {
            Value::Plaintext(Plaintext::Array(array, _)) => {
                array
                    .into_iter()
                    .map(|plaintext| {
                        VerifyingKey::<N>::from_bytes_le(&plaintext.as_byte_array()?)
                    })
                    .collect::<Result<Vec<VerifyingKey<N>>, _>>()
            }
            _ => bail!("Expected the first operand to be a two-dimensional byte array."),
        };
        // Decodes the second operand as the Varuna version.
        let varuna_version = || match $varuna_version {
            Value::Plaintext(Plaintext::Literal(Literal::U8(version), _)) => VarunaVersion::from_bytes_le(&[*version]),
            _ => bail!("Expected the Varuna version to be a U8 literal."),
        };
        // Decodes the third operand as the public inputs of a single instance.
        let inputs = || match $inputs {
            Value::Plaintext(plaintext) => Ok(plaintext.as_field_array()?.into_iter().map(|f| *f).collect::<Vec<_>>()),
            _ => bail!("Expected the third operand to be an array of fields."),
        };
        // Decodes the third operand as `[circuit][instance][input]` public inputs.
        let batch_inputs = || match $inputs {
            Value::Plaintext(Plaintext::Array(outer, _)) => {
                outer
                    .into_iter()
                    .map(|mid| {
                        match mid {
                            Plaintext::Array(inner, _) => {
                                inner
                                    .into_iter()
                                    .map(|row| {
                                        let fs = row.as_field_array()?;
                                        Ok(fs.into_iter().map(|f| *f).collect::<Vec<N::Field>>())
                                    })
                                    .collect::<Result<Vec<Vec<N::Field>>>>()
                            }
                            _ => bail!("Expected an inner array (second dimension) of fields."),
                        }
                    })
                    .collect::<Result<Vec<Vec<Vec<N::Field>>>>>()
            }
            _ => bail!("Expected the third operand to be a three-dimensional array of fields."),
        };
        // Decodes the fourth operand as the proof.
        let varuna_proof = || match $proof {
            Value::Plaintext(plaintext) => {
                let bytes = plaintext.as_byte_array()?;
                Proof::<N>::from_bytes_le(&bytes)
            }
            _ => bail!("Expected the fourth operand to be a byte array."),
        };
        // Cuts the supplied inputs down to the number of public inputs recorded in the
        // verifying key. Supplying more inputs than that is tolerated only if every surplus
        // entry is zero; supplying fewer is an error.
        // NOTE(review): presumably this exists so callers can zero-pad inputs up to a fixed
        // declared array length - confirm against the instruction's specification.
        let trimmed_inputs = |vk: &VerifyingKey<N>, inputs: &[N::Field]| -> Result<Vec<N::Field>> {
            let num_public_inputs = vk.circuit_info.num_public_inputs as usize;
            ensure!(
                inputs.len() >= num_public_inputs,
                "The number of public inputs ({}) is less than the expected number of public inputs ({}).",
                inputs.len(),
                num_public_inputs
            );
            for input in &inputs[num_public_inputs..] {
                ensure!(input.is_zero(), "Excess public inputs must be zero.");
            }
            Ok(inputs[..num_public_inputs].to_vec())
        };
        match $variant {
            SnarkVerifyVariant::Varuna => {
                let vk = verifying_key()?;
                let inputs_vec = inputs()?;
                let trimmed = trimmed_inputs(&vk, &inputs_vec)?;
                vk.verify($function_name, varuna_version()?, &trimmed, &varuna_proof()?)
            }
            SnarkVerifyVariant::VarunaBatch => {
                let vks = verifying_keys()?;
                let batch_inputs_vec = batch_inputs()?;
                // Each verifying key must be paired with exactly one group of instances.
                ensure!(
                    vks.len() == batch_inputs_vec.len(),
                    "The number of verifying keys ({}) does not match the number of input batches ({})",
                    vks.len(),
                    batch_inputs_vec.len()
                );
                // Trim every instance against the verifying key of the circuit it belongs to.
                let trimmed_batch: Vec<Vec<Vec<N::Field>>> = vks
                    .iter()
                    .zip_eq(batch_inputs_vec.iter())
                    .map(|(vk, instances)| {
                        instances
                            .iter()
                            .map(|instance_inputs| trimmed_inputs(vk, instance_inputs))
                            .collect::<Result<Vec<Vec<N::Field>>>>()
                    })
                    .collect::<Result<Vec<Vec<Vec<N::Field>>>>>()?;
                // A batch verification that returns an error is reported as `false`, not
                // propagated as an error.
                VerifyingKey::verify_batch(
                    $function_name,
                    varuna_version()?,
                    vks.into_iter().zip_eq(trimmed_batch).collect(),
                    &varuna_proof()?
                ).is_ok()
            }
        }
    }};
}
/// Verifies a Varuna proof (or a batch of them) from already-loaded operand values.
///
/// This is the public entry point; it forwards every argument unchanged to
/// `evaluate_varuna_proof_internal`.
///
/// # Arguments
///
/// * `variant` - selects single or batch verification.
/// * `_function_name` - a name forwarded verbatim to the verifier.
/// * `verifying_key` - the verifying key operand (one key, or an array of keys for a batch).
/// * `varuna_version` - the Varuna version operand, expected to be a `u8` literal.
/// * `inputs` - the public inputs operand.
/// * `proof` - the proof operand, expected to be a byte array.
///
/// # Errors
///
/// Returns an error if an operand does not have the expected shape or cannot be decoded.
pub fn evaluate_varuna_proof<N: Network>(
variant: SnarkVerifyVariant,
_function_name: &str,
verifying_key: &Value<N>,
varuna_version: Value<N>,
inputs: &Value<N>,
proof: &Value<N>,
) -> Result<bool> {
evaluate_varuna_proof_internal(variant, _function_name, verifying_key, varuna_version, inputs, proof)
}
/// Runs the verification by expanding `do_snark_verification!` and wrapping its boolean
/// result in `Ok`.
///
/// The generic parameter must be named `N` because the macro body refers to it by that name.
/// Operand decoding failures inside the macro are returned from this function as errors.
fn evaluate_varuna_proof_internal<N: Network>(
variant: SnarkVerifyVariant,
_function_name: &str,
verifying_key: &Value<N>,
varuna_version: Value<N>,
inputs: &Value<N>,
proof: &Value<N>,
) -> Result<bool> {
Ok(do_snark_verification!(variant, _function_name, verifying_key, varuna_version, inputs, proof))
}
fn check_nd_array_type<N: Network>(
register_type: &RegisterType<N>,
base_literal_type: LiteralType,
dimensions: usize,
) -> bool {
if dimensions == 0 {
return matches!(register_type, RegisterType::Plaintext(PlaintextType::Literal(lit)) if *lit == base_literal_type);
}
let mut arr = match register_type {
RegisterType::Plaintext(PlaintextType::Array(a)) => a,
_ => return false,
};
for _ in 1..dimensions {
match arr.next_element_type() {
PlaintextType::Array(next) => arr = next,
_ => return false,
}
}
matches!(arr.next_element_type(), PlaintextType::Literal(lit) if *lit == base_literal_type)
&& matches!(arr.base_element_type(), PlaintextType::Literal(lit) if *lit == base_literal_type)
}
impl<N: Network, const VARIANT: u8> SnarkVerification<N, VARIANT> {
    /// Evaluation is not supported for this instruction.
    ///
    /// # Errors
    ///
    /// Always returns an error stating that the instruction is only supported in finalize.
    #[inline]
    pub fn evaluate(&self, _stack: &impl StackTrait<N>, _registers: &mut impl RegistersTrait<N>) -> Result<()> {
        bail!("Instruction '{}' is currently only supported in finalize", Self::opcode());
    }

    /// Circuit execution is not supported for this instruction.
    ///
    /// # Errors
    ///
    /// Always returns an error stating that the instruction is only supported in finalize.
    #[inline]
    pub fn execute<A: circuit::Aleo<Network = N>>(
        &self,
        _stack: &impl StackTrait<N>,
        _registers: &mut impl RegistersCircuit<N, A>,
    ) -> Result<()> {
        bail!("Instruction '{}' is currently only supported in finalize", Self::opcode());
    }

    /// Loads the four operands, runs the verification and stores the boolean outcome in the
    /// destination register.
    ///
    /// # Errors
    ///
    /// Returns an error if the operand count is not four, if an operand cannot be loaded or
    /// decoded, or if the result cannot be stored.
    #[inline]
    pub fn finalize(&self, stack: &impl StackTrait<N>, registers: &mut impl RegistersTrait<N>) -> Result<()> {
        if self.operands.len() != 4 {
            bail!("Instruction '{}' expects 4 operands, found {} operands", Self::opcode(), self.operands.len())
        }
        // Operand order: verifying key(s), Varuna version, public inputs, proof.
        let verifying_key = registers.load(stack, &self.operands[0])?;
        let varuna_version = registers.load(stack, &self.operands[1])?;
        let inputs = registers.load(stack, &self.operands[2])?;
        let proof = registers.load(stack, &self.operands[3])?;
        // NOTE(review): the same name is passed for both variants, including
        // `snark.verify.batch` - confirm whether the batch opcode should be used there.
        let _function_name = "snark.verify";
        let output = evaluate_varuna_proof_internal(
            SnarkVerifyVariant::new(VARIANT),
            _function_name,
            &verifying_key,
            varuna_version,
            &inputs,
            &proof,
        )?;
        let output = Literal::Boolean(Boolean::new(output));
        registers.store_literal(stack, &self.destination, output)
    }

    /// Type-checks the four input types and returns the output type, a single boolean.
    ///
    /// The checks, in order:
    /// 1. the first input is a byte array (`snark.verify`) or a 2-dimensional byte array
    ///    (`snark.verify.batch`);
    /// 2. the second input is a `u8` literal;
    /// 3. the third input is an array of fields (`snark.verify`) or a 3-dimensional array of
    ///    fields (`snark.verify.batch`);
    /// 4. the number of circuits equals the number of verifying keys and stays within
    ///    `MAX_SNARK_VERIFY_CIRCUITS`, and the total number of instances stays within
    ///    `MAX_SNARK_VERIFY_INSTANCES`;
    /// 5. the fourth input is an array whose base element type is `u8`.
    ///
    /// # Errors
    ///
    /// Returns an error if the number of inputs is not four or if any check above fails.
    #[inline]
    pub fn output_types(
        &self,
        _stack: &impl StackTrait<N>,
        input_types: &[RegisterType<N>],
    ) -> Result<Vec<RegisterType<N>>> {
        if input_types.len() != 4 {
            bail!("Instruction '{}' expects 4 inputs, found {} inputs", Self::opcode(), input_types.len())
        }
        let variant = SnarkVerifyVariant::new(VARIANT);

        // First input: the verifying key(s). For a batch, the outer array length is the
        // number of verifying keys.
        let (result, expected_type, num_vks) = match variant {
            SnarkVerifyVariant::Varuna => (check_nd_array_type(&input_types[0], LiteralType::U8, 1), "a byte array", 1),
            SnarkVerifyVariant::VarunaBatch => {
                let num_vks = match &input_types[0] {
                    RegisterType::Plaintext(PlaintextType::Array(array_type)) => **array_type.length(),
                    _ => 0,
                };
                (check_nd_array_type(&input_types[0], LiteralType::U8, 2), "a 2-dimensional byte array", num_vks)
            }
        };
        if !result {
            bail!(
                "Instruction '{}' expects the first input to be {}. Found input of type '{}'",
                Self::opcode(),
                expected_type,
                &input_types[0]
            );
        }

        // Second input: the Varuna version.
        ensure!(
            matches!(input_types[1], RegisterType::Plaintext(PlaintextType::Literal(LiteralType::U8))),
            "Instruction '{}' expects the second input to be a U8 literal. Found input of type '{}'",
            Self::opcode(),
            &input_types[1]
        );

        // Third input: the public inputs. For a batch, the outer length is the number of
        // circuits and the total instance count is (instances per circuit) * (circuits).
        let (result, expected_type, num_circuits, num_instances) = match variant {
            SnarkVerifyVariant::Varuna => {
                (check_nd_array_type(&input_types[2], LiteralType::Field, 1), "an array of fields", 1, 1)
            }
            SnarkVerifyVariant::VarunaBatch => {
                let (num_circuits, num_instances) = match &input_types[2] {
                    RegisterType::Plaintext(PlaintextType::Array(array_type)) => {
                        let num_circuits = **array_type.length();
                        let num_instances = match array_type.next_element_type() {
                            // Saturate instead of overflowing: a wrapped product could
                            // otherwise slip under the instance limit checked below.
                            PlaintextType::Array(inner_array_type) => {
                                (**inner_array_type.length()).saturating_mul(num_circuits)
                            }
                            _ => bail!(
                                "Instruction '{}' expects the third input to be a 3-dimensional array of fields. Found input of type '{}'",
                                Self::opcode(),
                                &input_types[2]
                            ),
                        };
                        (num_circuits, num_instances)
                    }
                    _ => (0, 0),
                };
                (
                    check_nd_array_type(&input_types[2], LiteralType::Field, 3),
                    "a 3-dimensional array of fields",
                    num_circuits,
                    num_instances,
                )
            }
        };
        if !result {
            bail!(
                "Instruction '{}' expects the third input to be {}. Found input of type '{}'",
                Self::opcode(),
                expected_type,
                &input_types[2]
            );
        }

        // Batch size limits and key/circuit agreement.
        ensure!(
            num_circuits == num_vks,
            "Instruction '{}' expects the number of circuits ({num_circuits}) to match the number of verifying keys ({num_vks}).",
            Self::opcode()
        );
        ensure!(
            num_circuits <= MAX_SNARK_VERIFY_CIRCUITS,
            "Instruction '{}' supports a maximum of {MAX_SNARK_VERIFY_CIRCUITS} batched circuits, found {num_circuits} circuits.",
            Self::opcode()
        );
        ensure!(
            num_instances <= MAX_SNARK_VERIFY_INSTANCES,
            "Instruction '{}' supports a maximum of {MAX_SNARK_VERIFY_INSTANCES} batched instances, found {num_instances} instances.",
            Self::opcode()
        );

        // Fourth input: the proof.
        // NOTE(review): unlike the first input, this only inspects the base element type and
        // not the nesting depth; presumably a nested byte array would also pass - confirm
        // whether `check_nd_array_type(.., 1)` was intended.
        match &input_types[3] {
            RegisterType::Plaintext(PlaintextType::Array(array_type))
                if array_type.base_element_type() == &PlaintextType::Literal(LiteralType::U8) =>
            {
            }
            _ => bail!(
                "Instruction '{}' expects the fourth input to be a byte array. Found input of type '{}'",
                Self::opcode(),
                input_types[3]
            ),
        }
        Ok(vec![RegisterType::Plaintext(PlaintextType::Literal(LiteralType::Boolean))])
    }
}
impl<N: Network, const VARIANT: u8> Parser for SnarkVerification<N, VARIANT> {
#[inline]
fn parse(string: &str) -> ParserResult<Self> {
let (string, _) = tag(*Self::opcode())(string)?;
let (string, _) = Sanitizer::parse_whitespaces(string)?;
let (string, first) = Operand::parse(string)?;
let (string, _) = Sanitizer::parse_whitespaces(string)?;
let (string, second) = Operand::parse(string)?;
let (string, _) = Sanitizer::parse_whitespaces(string)?;
let (string, third) = Operand::parse(string)?;
let (string, _) = Sanitizer::parse_whitespaces(string)?;
let (string, fourth) = Operand::parse(string)?;
let (string, _) = Sanitizer::parse_whitespaces(string)?;
let (string, _) = tag("into")(string)?;
let (string, _) = Sanitizer::parse_whitespaces(string)?;
let (string, destination) = Register::parse(string)?;
Ok((string, Self { operands: vec![first, second, third, fourth], destination }))
}
}
impl<N: Network, const VARIANT: u8> FromStr for SnarkVerification<N, VARIANT> {
    type Err = Error;

    /// Parses the instruction from a string, requiring the whole string to be consumed.
    ///
    /// # Errors
    ///
    /// Returns an error if parsing fails or if any input is left over after the instruction.
    #[inline]
    fn from_str(string: &str) -> Result<Self> {
        let (remainder, object) = match Self::parse(string) {
            Ok(parsed) => parsed,
            Err(error) => bail!("Failed to parse string. {error}"),
        };
        // Trailing input means the string was not a single well-formed instruction.
        ensure!(remainder.is_empty(), "Failed to parse string. Found invalid character in: \"{remainder}\"");
        Ok(object)
    }
}
impl<N: Network, const VARIANT: u8> Debug for SnarkVerification<N, VARIANT> {
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
Display::fmt(self, f)
}
}
impl<N: Network, const VARIANT: u8> Display for SnarkVerification<N, VARIANT> {
    /// Writes the instruction as `<opcode> <op> <op> <op> <op> into <register>`.
    ///
    /// Fails with `fmt::Error` if the instruction does not hold exactly four operands.
    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
        if self.operands.len() != 4 {
            return Err(fmt::Error);
        }
        write!(f, "{} ", Self::opcode())?;
        // Each operand is followed by a single space.
        for operand in &self.operands {
            write!(f, "{operand} ")?;
        }
        write!(f, "into {}", self.destination)
    }
}
impl<N: Network, const VARIANT: u8> FromBytes for SnarkVerification<N, VARIANT> {
    /// Reads the instruction from a buffer: four operands followed by the destination register.
    ///
    /// The first read failure aborts the decoding and is returned to the caller.
    fn read_le<R: Read>(mut reader: R) -> IoResult<Self> {
        // Read the operands in order, stopping at the first failure.
        let operands = (0..4).map(|_| Operand::<N>::read_le(&mut reader)).collect::<IoResult<Vec<_>>>()?;
        // Read the destination register.
        let destination = Register::read_le(&mut reader)?;
        Ok(Self { operands, destination })
    }
}
impl<N: Network, const VARIANT: u8> ToBytes for SnarkVerification<N, VARIANT> {
fn write_le<W: Write>(&self, mut writer: W) -> IoResult<()> {
if self.operands.len() != 4 {
return Err(error(format!("The number of operands must be 4, found {}", self.operands.len())));
}
self.operands.iter().try_for_each(|operand| operand.write_le(&mut writer))?;
self.destination.write_le(&mut writer)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use console::network::MainnetV0;

    type CurrentNetwork = MainnetV0;

    /// Parses `instruction` as the given variant and checks that it is fully consumed, that
    /// the operands are `r0..r3` in order, and that the destination is `r4`.
    fn check_parse<const VARIANT: u8>(instruction: &str) {
        let (string, is) = SnarkVerification::<CurrentNetwork, VARIANT>::parse(instruction).unwrap();
        assert!(string.is_empty(), "Parser did not consume all of the string: '{string}'");
        assert_eq!(is.operands.len(), 4, "The number of operands is incorrect");

        let ordinals = ["first", "second", "third", "fourth"];
        for ((operand, locator), ordinal) in is.operands.iter().zip(0u64..).zip(ordinals.iter()) {
            assert_eq!(*operand, Operand::Register(Register::Locator(locator)), "The {ordinal} operand is incorrect");
        }

        assert_eq!(is.destination, Register::Locator(4), "The destination register is incorrect");
    }

    #[test]
    fn test_parse() {
        check_parse::<{ SnarkVerifyVariant::Varuna as u8 }>("snark.verify r0 r1 r2 r3 into r4");
        check_parse::<{ SnarkVerifyVariant::VarunaBatch as u8 }>("snark.verify.batch r0 r1 r2 r3 into r4");
    }
}