use dusk_curves::bls12_381::BlsScalar;
use dusk_plonk::prelude::*;
use dusk_safe::Safe;
use super::Hades;
use crate::hades::round_constants::ROUNDS;
use crate::hades::{MDS_MATRIX, ROUND_CONSTANTS, WIDTH};
/// Permutation gadget that works on [`Witness`]es of a PLONK [`Composer`].
///
/// The struct only holds a mutable borrow of the composer; all gates created
/// by the trait implementations below are appended to that composer.
pub(crate) struct GadgetPermutation<'a> {
    /// The constraint system every gate of this gadget is appended to.
    composer: &'a mut Composer,
}
impl<'a> GadgetPermutation<'a> {
pub fn new(composer: &'a mut Composer) -> Self {
Self { composer }
}
}
impl Safe<Witness, WIDTH> for GadgetPermutation<'_> {
    /// Permutes `state` in place by delegating to `perm` (presumably the
    /// provided method of the [`Hades`] trait implemented below).
    fn permute(&mut self, state: &mut [Witness; WIDTH]) {
        self.perm(state);
    }

    /// Hashes `input` to a scalar and appends that scalar to the circuit as
    /// a constant, returning the resulting witness.
    fn tag(&mut self, input: &[u8]) -> Witness {
        self.composer
            .append_constant(BlsScalar::hash_to_scalar(input))
    }

    /// Appends an addition gate computing `left + right` and returns the
    /// witness of the sum.
    fn add(&mut self, right: &Witness, left: &Witness) -> Witness {
        let (lhs, rhs) = (*left, *right);
        let sum_gate = Constraint::new().left(1).a(lhs).right(1).b(rhs);
        self.composer.gate_add(sum_gate)
    }
}
impl Hades<Witness> for GadgetPermutation<'_> {
    /// Adds the round constants of round 0 to every element of `state`.
    ///
    /// For every other round this is a no-op: `mul_matrix` already folds the
    /// constants of the following round into its second addition gate, which
    /// avoids spending separate gates on them here.
    fn add_round_constants(
        &mut self,
        round: usize,
        state: &mut [Witness; WIDTH],
    ) {
        if round != 0 {
            return;
        }

        for (idx, witness) in state.iter_mut().enumerate() {
            let gate = Constraint::new()
                .left(1)
                .a(*witness)
                .constant(ROUND_CONSTANTS[0][idx]);
            *witness = self.composer.gate_add(gate);
        }
    }

    /// Replaces `value` with `value^5`, using three multiplication gates:
    /// `value^2`, `value^4` and finally `value^4 * value`.
    fn quintic_s_box(&mut self, value: &mut Witness) {
        let base = *value;

        let squared = self
            .composer
            .gate_mul(Constraint::new().mult(1).a(base).b(base));
        let fourth_power = self
            .composer
            .gate_mul(Constraint::new().mult(1).a(squared).b(squared));

        *value = self
            .composer
            .gate_mul(Constraint::new().mult(1).a(fourth_power).b(base));
    }

    /// Multiplies `state` by the MDS matrix and adds the constants of the
    /// next round, writing the result back into `state`.
    ///
    /// The body hard-codes five state elements (indices `0..=4`). Each output
    /// element `j` costs two addition gates:
    ///
    /// 1. `MDS[j][0]·state[0] + MDS[j][1]·state[1] + MDS[j][2]·state[2]`
    /// 2. `MDS[j][3]·state[3] + MDS[j][4]·state[4] + partial + c`
    ///
    /// where `partial` is the output of the first gate and `c` is
    /// `ROUND_CONSTANTS[round + 1][j]`, or zero when `round + 1` is not
    /// smaller than `ROUNDS`.
    fn mul_matrix(&mut self, round: usize, state: &mut [Witness; WIDTH]) {
        let next_round = round + 1;
        let has_next_round = next_round < ROUNDS;

        let mut mixed = [Composer::ZERO; WIDTH];
        for (j, slot) in mixed.iter_mut().enumerate() {
            let next_constant = if has_next_round {
                ROUND_CONSTANTS[next_round][j]
            } else {
                BlsScalar::zero()
            };

            // First gate: the contribution of state[0..=2].
            let lower_terms = Constraint::new()
                .left(MDS_MATRIX[j][0])
                .a(state[0])
                .right(MDS_MATRIX[j][1])
                .b(state[1])
                .fourth(MDS_MATRIX[j][2])
                .d(state[2]);
            let partial = self.composer.gate_add(lower_terms);

            // Second gate: state[3..=4], the partial sum and the constant.
            let upper_terms = Constraint::new()
                .left(MDS_MATRIX[j][3])
                .a(state[3])
                .right(MDS_MATRIX[j][4])
                .b(state[4])
                .fourth(1)
                .d(partial)
                .constant(next_constant);
            *slot = self.composer.gate_add(upper_terms);
        }

        *state = mixed;
    }
}
#[cfg(feature = "encryption")]
impl dusk_safe::Encryption<Witness, WIDTH> for GadgetPermutation<'_> {
    /// Appends an addition gate computing `minuend - subtrahend` and returns
    /// the witness of the difference.
    fn subtract(&mut self, minuend: &Witness, subtrahend: &Witness) -> Witness {
        // The subtraction is expressed as `1·minuend + (-1)·subtrahend`.
        let minus_one = -BlsScalar::one();
        let difference_gate = Constraint::new()
            .left(1)
            .a(*minuend)
            .right(minus_one)
            .b(*subtrahend);

        self.composer.gate_add(difference_gate)
    }

    /// Constrains `lhs` and `rhs` to be equal inside the circuit.
    ///
    /// Always returns `true`: the comparison is enforced as a circuit
    /// constraint and is not evaluated here, so a mismatch is not reported
    /// through the return value.
    fn is_equal(&mut self, lhs: &Witness, rhs: &Witness) -> bool {
        let (a, b) = (*lhs, *rhs);
        self.composer.assert_equal(a, b);
        true
    }
}
#[cfg(test)]
mod tests {
    use core::result::Result;
    use ff::Field;
    use rand::SeedableRng;
    use rand::rngs::StdRng;

    use super::*;
    use crate::hades::ScalarPermutation;

    /// Circuit asserting that the gadget permutation of `i` equals `o`.
    #[derive(Default)]
    struct TestCircuit {
        /// Input state of the permutation.
        i: [BlsScalar; WIDTH],
        /// Expected state after the permutation.
        o: [BlsScalar; WIDTH],
    }

    impl Circuit for TestCircuit {
        fn circuit(&self, composer: &mut Composer) -> Result<(), Error> {
            let zero = Composer::ZERO;

            // Append the input state as witnesses.
            let mut i_wit: [Witness; WIDTH] = [zero; WIDTH];
            self.i.iter().zip(i_wit.iter_mut()).for_each(|(i, w)| {
                *w = composer.append_witness(*i);
            });

            // Append the expected output state as witnesses.
            let mut o_wit: [Witness; WIDTH] = [zero; WIDTH];
            self.o.iter().zip(o_wit.iter_mut()).for_each(|(o, w)| {
                *w = composer.append_witness(*o);
            });

            // Apply the gadget permutation in place on the input witnesses.
            GadgetPermutation::new(composer).permute(&mut i_wit);

            // Constrain every permuted element to match the expected output.
            i_wit.iter().zip(o_wit.iter()).for_each(|(p, o)| {
                composer.assert_equal(*p, *o);
            });

            Ok(())
        }
    }

    /// Returns a random input state (fixed seed) together with its
    /// permutation as computed by `ScalarPermutation`.
    fn hades() -> ([BlsScalar; WIDTH], [BlsScalar; WIDTH]) {
        let mut input = [BlsScalar::zero(); WIDTH];

        let mut rng = StdRng::seed_from_u64(0xbeef);
        input
            .iter_mut()
            .for_each(|s| *s = BlsScalar::random(&mut rng));

        // Arrays of scalars are `Copy`, so the output starts as a plain copy
        // of the input and is then permuted in place.
        let mut output = input;
        ScalarPermutation::new().permute(&mut output);

        (input, output)
    }

    /// Generates the public parameters with a fixed seed and compiles
    /// `TestCircuit` into a prover/verifier pair.
    fn setup() -> Result<(Prover, Verifier), Error> {
        const CAPACITY: usize = 1 << 10;

        let mut rng = StdRng::seed_from_u64(0xbeef);
        let pp = PublicParameters::setup(CAPACITY, &mut rng)?;
        let label = b"hades_gadget_tester";

        Compiler::compile::<TestCircuit>(&pp, label)
    }

    /// A random input and its scalar permutation must prove and verify.
    #[test]
    fn preimage() -> Result<(), Error> {
        let (prover, verifier) = setup()?;

        let (i, o) = hades();
        let circuit = TestCircuit { i, o };
        let mut rng = StdRng::seed_from_u64(0xbeef);

        let (proof, public_inputs) = prover.prove(&mut rng, &circuit)?;
        verifier.verify(&proof, &public_inputs)?;

        Ok(())
    }

    /// A constant input and its scalar permutation must prove and verify.
    #[test]
    fn preimage_constant() -> Result<(), Error> {
        let (prover, verifier) = setup()?;

        let i = [BlsScalar::from(5000u64); WIDTH];
        let mut o = [BlsScalar::from(5000u64); WIDTH];
        ScalarPermutation::new().permute(&mut o);

        let circuit = TestCircuit { i, o };
        let mut rng = StdRng::seed_from_u64(0xbeef);

        let (proof, public_inputs) = prover.prove(&mut rng, &circuit)?;
        verifier.verify(&proof, &public_inputs)?;

        Ok(())
    }

    /// Proving must fail when `o` is the permutation of a different input
    /// than `i`.
    #[test]
    fn preimage_fails() -> Result<(), Error> {
        let (prover, _) = setup()?;

        let x_scalar = BlsScalar::from(31u64);

        // The input only carries the scalar in position 1 ...
        let mut i = [BlsScalar::zero(); WIDTH];
        i[1] = x_scalar;

        // ... while the output is derived from a state filled with it.
        let mut o = [BlsScalar::from(31u64); WIDTH];
        ScalarPermutation::new().permute(&mut o);

        let circuit = TestCircuit { i, o };
        let mut rng = StdRng::seed_from_u64(0xbeef);

        assert!(
            prover.prove(&mut rng, &circuit).is_err(),
            "proving should fail since the circuit is invalid"
        );

        Ok(())
    }
}