ark-r1cs-std 0.6.0

A standard library for constraint system gadgets
Documentation
use ark_ff::PrimeField;
use ark_relations::gr1cs::SynthesisError;
use ark_std::vec::Vec;

use crate::{boolean::Boolean, eq::EqGadget};

use super::*;

impl<const N: usize, T: PrimUInt, ConstraintF: PrimeField> EqGadget<ConstraintF>
    for UInt<N, T, ConstraintF>
{
    #[tracing::instrument(target = "gr1cs", skip(self, other))]
    fn is_eq(&self, other: &Self) -> Result<Boolean<ConstraintF>, SynthesisError> {
        let chunk_size = usize::try_from(ConstraintF::MODULUS_BIT_SIZE - 1).unwrap();
        let chunks_are_eq = self
            .bits
            .chunks(chunk_size)
            .zip(other.bits.chunks(chunk_size))
            .map(|(a, b)| {
                let a = Boolean::le_bits_to_fp(a)?;
                let b = Boolean::le_bits_to_fp(b)?;
                a.is_eq(&b)
            })
            .collect::<Result<Vec<_>, _>>()?;
        Boolean::kary_and(&chunks_are_eq)
    }

    #[tracing::instrument(target = "gr1cs", skip(self, other))]
    fn conditional_enforce_equal(
        &self,
        other: &Self,
        condition: &Boolean<ConstraintF>,
    ) -> Result<(), SynthesisError> {
        let chunk_size = usize::try_from(ConstraintF::MODULUS_BIT_SIZE - 1).unwrap();
        for (a, b) in self
            .bits
            .chunks(chunk_size)
            .zip(other.bits.chunks(chunk_size))
        {
            let a = Boolean::le_bits_to_fp(a)?;
            let b = Boolean::le_bits_to_fp(b)?;
            a.conditional_enforce_equal(&b, condition)?;
        }
        Ok(())
    }

    #[tracing::instrument(target = "gr1cs", skip(self, other))]
    fn conditional_enforce_not_equal(
        &self,
        other: &Self,
        condition: &Boolean<ConstraintF>,
    ) -> Result<(), SynthesisError> {
        let chunk_size = usize::try_from(ConstraintF::MODULUS_BIT_SIZE - 1).unwrap();
        for (a, b) in self
            .bits
            .chunks(chunk_size)
            .zip(other.bits.chunks(chunk_size))
        {
            let a = Boolean::le_bits_to_fp(a)?;
            let b = Boolean::le_bits_to_fp(b)?;
            a.conditional_enforce_not_equal(&b, condition)?;
        }
        Ok(())
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        alloc::{AllocVar, AllocationMode},
        prelude::EqGadget,
        uint::test_utils::{run_binary_exhaustive, run_binary_random},
        GR1CSVar,
    };
    use ark_ff::PrimeField;
    use ark_test_curves::bls12_381::Fr;

    fn uint_eq<T: PrimUInt, const N: usize, F: PrimeField>(
        a: UInt<N, T, F>,
        b: UInt<N, T, F>,
    ) -> Result<(), SynthesisError> {
        let cs = a.cs().or(b.cs());
        let both_constant = a.is_constant() && b.is_constant();
        let computed = a.is_eq(&b)?;
        let expected_mode = if both_constant {
            AllocationMode::Constant
        } else {
            AllocationMode::Witness
        };
        let expected =
            Boolean::new_variable(cs.clone(), || Ok(a.value()? == b.value()?), expected_mode)?;
        assert_eq!(expected.value(), computed.value());
        expected.enforce_equal(&computed)?;
        if !both_constant {
            assert!(cs.is_satisfied().unwrap());
        }
        Ok(())
    }

    fn uint_neq<T: PrimUInt, const N: usize, F: PrimeField>(
        a: UInt<N, T, F>,
        b: UInt<N, T, F>,
    ) -> Result<(), SynthesisError> {
        let cs = a.cs().or(b.cs());
        let both_constant = a.is_constant() && b.is_constant();
        let computed = a.is_neq(&b)?;
        let expected_mode = if both_constant {
            AllocationMode::Constant
        } else {
            AllocationMode::Witness
        };
        let expected =
            Boolean::new_variable(cs.clone(), || Ok(a.value()? != b.value()?), expected_mode)?;
        assert_eq!(expected.value(), computed.value());
        expected.enforce_equal(&computed)?;
        if !both_constant {
            assert!(cs.is_satisfied().unwrap());
        }
        Ok(())
    }

    #[test]
    fn u8_eq() {
        run_binary_exhaustive(uint_eq::<u8, 8, Fr>).unwrap()
    }

    #[test]
    fn u16_eq() {
        run_binary_random::<1000, 16, _, _>(uint_eq::<u16, 16, Fr>).unwrap()
    }

    #[test]
    fn u32_eq() {
        run_binary_random::<1000, 32, _, _>(uint_eq::<u32, 32, Fr>).unwrap()
    }

    #[test]
    fn u64_eq() {
        run_binary_random::<1000, 64, _, _>(uint_eq::<u64, 64, Fr>).unwrap()
    }

    #[test]
    fn u128_eq() {
        run_binary_random::<1000, 128, _, _>(uint_eq::<u128, 128, Fr>).unwrap()
    }

    #[test]
    fn u8_neq() {
        run_binary_exhaustive(uint_neq::<u8, 8, Fr>).unwrap()
    }

    #[test]
    fn u16_neq() {
        run_binary_random::<1000, 16, _, _>(uint_neq::<u16, 16, Fr>).unwrap()
    }

    #[test]
    fn u32_neq() {
        run_binary_random::<1000, 32, _, _>(uint_neq::<u32, 32, Fr>).unwrap()
    }

    #[test]
    fn u64_neq() {
        run_binary_random::<1000, 64, _, _>(uint_neq::<u64, 64, Fr>).unwrap()
    }

    #[test]
    fn u128_neq() {
        run_binary_random::<1000, 128, _, _>(uint_neq::<u128, 128, Fr>).unwrap()
    }
}