Skip to main content

ark_r1cs_std/uint/eq.rs

1use ark_ff::PrimeField;
2use ark_relations::gr1cs::SynthesisError;
3use ark_std::vec::Vec;
4
5use crate::{boolean::Boolean, eq::EqGadget};
6
7use super::*;
8
9impl<const N: usize, T: PrimUInt, ConstraintF: PrimeField> EqGadget<ConstraintF>
10    for UInt<N, T, ConstraintF>
11{
12    #[tracing::instrument(target = "gr1cs", skip(self, other))]
13    fn is_eq(&self, other: &Self) -> Result<Boolean<ConstraintF>, SynthesisError> {
14        let chunk_size = usize::try_from(ConstraintF::MODULUS_BIT_SIZE - 1).unwrap();
15        let chunks_are_eq = self
16            .bits
17            .chunks(chunk_size)
18            .zip(other.bits.chunks(chunk_size))
19            .map(|(a, b)| {
20                let a = Boolean::le_bits_to_fp(a)?;
21                let b = Boolean::le_bits_to_fp(b)?;
22                a.is_eq(&b)
23            })
24            .collect::<Result<Vec<_>, _>>()?;
25        Boolean::kary_and(&chunks_are_eq)
26    }
27
28    #[tracing::instrument(target = "gr1cs", skip(self, other))]
29    fn conditional_enforce_equal(
30        &self,
31        other: &Self,
32        condition: &Boolean<ConstraintF>,
33    ) -> Result<(), SynthesisError> {
34        let chunk_size = usize::try_from(ConstraintF::MODULUS_BIT_SIZE - 1).unwrap();
35        for (a, b) in self
36            .bits
37            .chunks(chunk_size)
38            .zip(other.bits.chunks(chunk_size))
39        {
40            let a = Boolean::le_bits_to_fp(a)?;
41            let b = Boolean::le_bits_to_fp(b)?;
42            a.conditional_enforce_equal(&b, condition)?;
43        }
44        Ok(())
45    }
46
47    #[tracing::instrument(target = "gr1cs", skip(self, other))]
48    fn conditional_enforce_not_equal(
49        &self,
50        other: &Self,
51        condition: &Boolean<ConstraintF>,
52    ) -> Result<(), SynthesisError> {
53        let chunk_size = usize::try_from(ConstraintF::MODULUS_BIT_SIZE - 1).unwrap();
54        for (a, b) in self
55            .bits
56            .chunks(chunk_size)
57            .zip(other.bits.chunks(chunk_size))
58        {
59            let a = Boolean::le_bits_to_fp(a)?;
60            let b = Boolean::le_bits_to_fp(b)?;
61            a.conditional_enforce_not_equal(&b, condition)?;
62        }
63        Ok(())
64    }
65}
66
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        alloc::{AllocVar, AllocationMode},
        prelude::EqGadget,
        uint::test_utils::{run_binary_exhaustive, run_binary_random},
        GR1CSVar,
    };
    use ark_ff::PrimeField;
    use ark_test_curves::bls12_381::Fr;

    /// Shared driver comparing a boolean-valued gadget against its native counterpart.
    ///
    /// The driver proceeds in four steps:
    /// 1. It evaluates `gadget` on the two in-circuit integers.
    /// 2. It allocates the bit predicted by `native` on their values. The bit is a constant
    ///    when both operands are constants and a witness otherwise.
    /// 3. It checks that the two bits agree, both natively and through an equality
    ///    constraint.
    /// 4. When any operand is non-constant, it checks that the constraint system is
    ///    satisfied.
    fn check_against_native<T, const N: usize, F, G, P>(
        a: UInt<N, T, F>,
        b: UInt<N, T, F>,
        gadget: G,
        native: P,
    ) -> Result<(), SynthesisError>
    where
        T: PrimUInt,
        F: PrimeField,
        G: FnOnce(&UInt<N, T, F>, &UInt<N, T, F>) -> Result<Boolean<F>, SynthesisError>,
        P: FnOnce(T, T) -> bool,
    {
        let cs = a.cs().or(b.cs());
        let all_constant = a.is_constant() && b.is_constant();
        let in_circuit = gadget(&a, &b)?;

        let mode = if all_constant {
            AllocationMode::Constant
        } else {
            AllocationMode::Witness
        };
        let predicted =
            Boolean::new_variable(cs.clone(), || Ok(native(a.value()?, b.value()?)), mode)?;

        assert_eq!(predicted.value(), in_circuit.value());
        predicted.enforce_equal(&in_circuit)?;

        if !all_constant {
            assert!(cs.is_satisfied().unwrap());
        }
        Ok(())
    }

    /// Checks `is_eq` against native `==`.
    fn uint_eq<T: PrimUInt, const N: usize, F: PrimeField>(
        a: UInt<N, T, F>,
        b: UInt<N, T, F>,
    ) -> Result<(), SynthesisError> {
        check_against_native(a, b, |lhs, rhs| lhs.is_eq(rhs), |lhs, rhs| lhs == rhs)
    }

    /// Checks `is_neq` against native `!=`.
    fn uint_neq<T: PrimUInt, const N: usize, F: PrimeField>(
        a: UInt<N, T, F>,
        b: UInt<N, T, F>,
    ) -> Result<(), SynthesisError> {
        check_against_native(a, b, |lhs, rhs| lhs.is_neq(rhs), |lhs, rhs| lhs != rhs)
    }

    #[test]
    fn u8_eq() {
        run_binary_exhaustive(uint_eq::<u8, 8, Fr>).unwrap()
    }

    #[test]
    fn u16_eq() {
        run_binary_random::<1000, 16, _, _>(uint_eq::<u16, 16, Fr>).unwrap()
    }

    #[test]
    fn u32_eq() {
        run_binary_random::<1000, 32, _, _>(uint_eq::<u32, 32, Fr>).unwrap()
    }

    #[test]
    fn u64_eq() {
        run_binary_random::<1000, 64, _, _>(uint_eq::<u64, 64, Fr>).unwrap()
    }

    #[test]
    fn u128_eq() {
        run_binary_random::<1000, 128, _, _>(uint_eq::<u128, 128, Fr>).unwrap()
    }

    #[test]
    fn u8_neq() {
        run_binary_exhaustive(uint_neq::<u8, 8, Fr>).unwrap()
    }

    #[test]
    fn u16_neq() {
        run_binary_random::<1000, 16, _, _>(uint_neq::<u16, 16, Fr>).unwrap()
    }

    #[test]
    fn u32_neq() {
        run_binary_random::<1000, 32, _, _>(uint_neq::<u32, 32, Fr>).unwrap()
    }

    #[test]
    fn u64_neq() {
        run_binary_random::<1000, 64, _, _>(uint_neq::<u64, 64, Fr>).unwrap()
    }

    #[test]
    fn u128_neq() {
        run_binary_random::<1000, 128, _, _>(uint_neq::<u128, 128, Fr>).unwrap()
    }
}