ark_r1cs_std/uint/eq.rs

1use ark_ff::PrimeField;
2use ark_relations::r1cs::SynthesisError;
3use ark_std::vec::Vec;
4
5use crate::boolean::Boolean;
6use crate::eq::EqGadget;
7
8use super::*;
9
10impl<const N: usize, T: PrimUInt, ConstraintF: PrimeField> EqGadget<ConstraintF>
11    for UInt<N, T, ConstraintF>
12{
13    #[tracing::instrument(target = "r1cs", skip(self, other))]
14    fn is_eq(&self, other: &Self) -> Result<Boolean<ConstraintF>, SynthesisError> {
15        let chunk_size = usize::try_from(ConstraintF::MODULUS_BIT_SIZE - 1).unwrap();
16        let chunks_are_eq = self
17            .bits
18            .chunks(chunk_size)
19            .zip(other.bits.chunks(chunk_size))
20            .map(|(a, b)| {
21                let a = Boolean::le_bits_to_fp(a)?;
22                let b = Boolean::le_bits_to_fp(b)?;
23                a.is_eq(&b)
24            })
25            .collect::<Result<Vec<_>, _>>()?;
26        Boolean::kary_and(&chunks_are_eq)
27    }
28
29    #[tracing::instrument(target = "r1cs", skip(self, other))]
30    fn conditional_enforce_equal(
31        &self,
32        other: &Self,
33        condition: &Boolean<ConstraintF>,
34    ) -> Result<(), SynthesisError> {
35        let chunk_size = usize::try_from(ConstraintF::MODULUS_BIT_SIZE - 1).unwrap();
36        for (a, b) in self
37            .bits
38            .chunks(chunk_size)
39            .zip(other.bits.chunks(chunk_size))
40        {
41            let a = Boolean::le_bits_to_fp(a)?;
42            let b = Boolean::le_bits_to_fp(b)?;
43            a.conditional_enforce_equal(&b, condition)?;
44        }
45        Ok(())
46    }
47
48    #[tracing::instrument(target = "r1cs", skip(self, other))]
49    fn conditional_enforce_not_equal(
50        &self,
51        other: &Self,
52        condition: &Boolean<ConstraintF>,
53    ) -> Result<(), SynthesisError> {
54        let chunk_size = usize::try_from(ConstraintF::MODULUS_BIT_SIZE - 1).unwrap();
55        for (a, b) in self
56            .bits
57            .chunks(chunk_size)
58            .zip(other.bits.chunks(chunk_size))
59        {
60            let a = Boolean::le_bits_to_fp(a)?;
61            let b = Boolean::le_bits_to_fp(b)?;
62            a.conditional_enforce_not_equal(&b, condition)?;
63        }
64        Ok(())
65    }
66}
67
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        alloc::{AllocVar, AllocationMode},
        prelude::EqGadget,
        uint::test_utils::{run_binary_exhaustive, run_binary_random},
        R1CSVar,
    };
    use ark_ff::PrimeField;
    use ark_test_curves::bls12_381::Fr;

    /// Cross-checks a `Boolean`-valued comparison gadget against native Rust.
    ///
    /// Steps:
    /// - `gadget` produces the in-circuit result for `(a, b)`.
    /// - `native` produces the reference answer from the operands' native values.
    /// - The reference is allocated as a constant when both operands are
    ///   constants, and as a witness otherwise.
    /// - The two results' values are asserted equal and then tied together
    ///   with `enforce_equal`.
    /// - Unless everything is constant, the constraint system must be satisfied.
    fn check_against_native<T: PrimUInt, const N: usize, F: PrimeField>(
        a: UInt<N, T, F>,
        b: UInt<N, T, F>,
        gadget: impl FnOnce(&UInt<N, T, F>, &UInt<N, T, F>) -> Result<Boolean<F>, SynthesisError>,
        native: impl FnOnce(T, T) -> bool,
    ) -> Result<(), SynthesisError> {
        let cs = a.cs().or(b.cs());
        let all_constant = a.is_constant() && b.is_constant();
        let in_circuit = gadget(&a, &b)?;

        let mode = if all_constant {
            AllocationMode::Constant
        } else {
            AllocationMode::Witness
        };
        let reference =
            Boolean::new_variable(cs.clone(), || Ok(native(a.value()?, b.value()?)), mode)?;

        assert_eq!(reference.value(), in_circuit.value());
        reference.enforce_equal(&in_circuit)?;
        if !all_constant {
            assert!(cs.is_satisfied().unwrap());
        }
        Ok(())
    }

    /// Checks `UInt::is_eq` against native `==`.
    fn uint_eq<T: PrimUInt, const N: usize, F: PrimeField>(
        a: UInt<N, T, F>,
        b: UInt<N, T, F>,
    ) -> Result<(), SynthesisError> {
        check_against_native(a, b, |x, y| x.is_eq(y), |x, y| x == y)
    }

    /// Checks `UInt::is_neq` against native `!=`.
    fn uint_neq<T: PrimUInt, const N: usize, F: PrimeField>(
        a: UInt<N, T, F>,
        b: UInt<N, T, F>,
    ) -> Result<(), SynthesisError> {
        check_against_native(a, b, |x, y| x.is_neq(y), |x, y| x != y)
    }

    #[test]
    fn u8_eq() {
        run_binary_exhaustive(uint_eq::<u8, 8, Fr>).unwrap()
    }

    #[test]
    fn u16_eq() {
        run_binary_random::<1000, 16, _, _>(uint_eq::<u16, 16, Fr>).unwrap()
    }

    #[test]
    fn u32_eq() {
        run_binary_random::<1000, 32, _, _>(uint_eq::<u32, 32, Fr>).unwrap()
    }

    #[test]
    fn u64_eq() {
        run_binary_random::<1000, 64, _, _>(uint_eq::<u64, 64, Fr>).unwrap()
    }

    #[test]
    fn u128_eq() {
        run_binary_random::<1000, 128, _, _>(uint_eq::<u128, 128, Fr>).unwrap()
    }

    #[test]
    fn u8_neq() {
        run_binary_exhaustive(uint_neq::<u8, 8, Fr>).unwrap()
    }

    #[test]
    fn u16_neq() {
        run_binary_random::<1000, 16, _, _>(uint_neq::<u16, 16, Fr>).unwrap()
    }

    #[test]
    fn u32_neq() {
        run_binary_random::<1000, 32, _, _>(uint_neq::<u32, 32, Fr>).unwrap()
    }

    #[test]
    fn u64_neq() {
        run_binary_random::<1000, 64, _, _>(uint_neq::<u64, 64, Fr>).unwrap()
    }

    #[test]
    fn u128_neq() {
        run_binary_random::<1000, 128, _, _>(uint_neq::<u128, 128, Fr>).unwrap()
    }
}