1use ark_ff::PrimeField;
2use ark_relations::gr1cs::SynthesisError;
3use ark_std::vec::Vec;
4
5use crate::{boolean::Boolean, eq::EqGadget};
6
7use super::*;
8
9impl<const N: usize, T: PrimUInt, ConstraintF: PrimeField> EqGadget<ConstraintF>
10 for UInt<N, T, ConstraintF>
11{
12 #[tracing::instrument(target = "gr1cs", skip(self, other))]
13 fn is_eq(&self, other: &Self) -> Result<Boolean<ConstraintF>, SynthesisError> {
14 let chunk_size = usize::try_from(ConstraintF::MODULUS_BIT_SIZE - 1).unwrap();
15 let chunks_are_eq = self
16 .bits
17 .chunks(chunk_size)
18 .zip(other.bits.chunks(chunk_size))
19 .map(|(a, b)| {
20 let a = Boolean::le_bits_to_fp(a)?;
21 let b = Boolean::le_bits_to_fp(b)?;
22 a.is_eq(&b)
23 })
24 .collect::<Result<Vec<_>, _>>()?;
25 Boolean::kary_and(&chunks_are_eq)
26 }
27
28 #[tracing::instrument(target = "gr1cs", skip(self, other))]
29 fn conditional_enforce_equal(
30 &self,
31 other: &Self,
32 condition: &Boolean<ConstraintF>,
33 ) -> Result<(), SynthesisError> {
34 let chunk_size = usize::try_from(ConstraintF::MODULUS_BIT_SIZE - 1).unwrap();
35 for (a, b) in self
36 .bits
37 .chunks(chunk_size)
38 .zip(other.bits.chunks(chunk_size))
39 {
40 let a = Boolean::le_bits_to_fp(a)?;
41 let b = Boolean::le_bits_to_fp(b)?;
42 a.conditional_enforce_equal(&b, condition)?;
43 }
44 Ok(())
45 }
46
47 #[tracing::instrument(target = "gr1cs", skip(self, other))]
48 fn conditional_enforce_not_equal(
49 &self,
50 other: &Self,
51 condition: &Boolean<ConstraintF>,
52 ) -> Result<(), SynthesisError> {
53 let chunk_size = usize::try_from(ConstraintF::MODULUS_BIT_SIZE - 1).unwrap();
54 for (a, b) in self
55 .bits
56 .chunks(chunk_size)
57 .zip(other.bits.chunks(chunk_size))
58 {
59 let a = Boolean::le_bits_to_fp(a)?;
60 let b = Boolean::le_bits_to_fp(b)?;
61 a.conditional_enforce_not_equal(&b, condition)?;
62 }
63 Ok(())
64 }
65}
66
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        alloc::{AllocVar, AllocationMode},
        prelude::EqGadget,
        uint::test_utils::{run_binary_exhaustive, run_binary_random},
        GR1CSVar,
    };
    use ark_ff::PrimeField;
    use ark_test_curves::bls12_381::Fr;

    /// Shared driver for the equality and inequality checks.
    ///
    /// Computes `lhs.is_eq(&rhs)` (or `lhs.is_neq(&rhs)` when `negated` is
    /// set), allocates the natively computed answer in the same constraint
    /// system, and checks that both agree in value and under constraints.
    fn check_against_native<T: PrimUInt, const N: usize, F: PrimeField>(
        lhs: UInt<N, T, F>,
        rhs: UInt<N, T, F>,
        negated: bool,
    ) -> Result<(), SynthesisError> {
        let cs = lhs.cs().or(rhs.cs());
        let all_constant = lhs.is_constant() && rhs.is_constant();

        let computed = if negated {
            lhs.is_neq(&rhs)?
        } else {
            lhs.is_eq(&rhs)?
        };

        // The reference value is a constant only when both inputs are.
        let mode = if all_constant {
            AllocationMode::Constant
        } else {
            AllocationMode::Witness
        };
        let expected = Boolean::new_variable(
            cs.clone(),
            // `(x == y) != negated` is `x == y` for the equality check and
            // `x != y` for the inequality check.
            || Ok((lhs.value()? == rhs.value()?) != negated),
            mode,
        )?;

        assert_eq!(expected.value(), computed.value());
        expected.enforce_equal(&computed)?;

        // Satisfiability is only checked when at least one input is non-constant.
        if !all_constant {
            assert!(cs.is_satisfied().unwrap());
        }
        Ok(())
    }

    /// Checks `is_eq` against native `==` for one pair of operands.
    fn uint_eq<T: PrimUInt, const N: usize, F: PrimeField>(
        a: UInt<N, T, F>,
        b: UInt<N, T, F>,
    ) -> Result<(), SynthesisError> {
        check_against_native(a, b, false)
    }

    /// Checks `is_neq` against native `!=` for one pair of operands.
    fn uint_neq<T: PrimUInt, const N: usize, F: PrimeField>(
        a: UInt<N, T, F>,
        b: UInt<N, T, F>,
    ) -> Result<(), SynthesisError> {
        check_against_native(a, b, true)
    }

    /// Exhaustive `is_eq` check over 8-bit operands.
    #[test]
    fn u8_eq() {
        run_binary_exhaustive(uint_eq::<u8, 8, Fr>).unwrap()
    }

    /// Randomized `is_eq` check over 16-bit operands.
    #[test]
    fn u16_eq() {
        run_binary_random::<1000, 16, _, _>(uint_eq::<u16, 16, Fr>).unwrap()
    }

    /// Randomized `is_eq` check over 32-bit operands.
    #[test]
    fn u32_eq() {
        run_binary_random::<1000, 32, _, _>(uint_eq::<u32, 32, Fr>).unwrap()
    }

    /// Randomized `is_eq` check over 64-bit operands.
    #[test]
    fn u64_eq() {
        run_binary_random::<1000, 64, _, _>(uint_eq::<u64, 64, Fr>).unwrap()
    }

    /// Randomized `is_eq` check over 128-bit operands.
    #[test]
    fn u128_eq() {
        run_binary_random::<1000, 128, _, _>(uint_eq::<u128, 128, Fr>).unwrap()
    }

    /// Exhaustive `is_neq` check over 8-bit operands.
    #[test]
    fn u8_neq() {
        run_binary_exhaustive(uint_neq::<u8, 8, Fr>).unwrap()
    }

    /// Randomized `is_neq` check over 16-bit operands.
    #[test]
    fn u16_neq() {
        run_binary_random::<1000, 16, _, _>(uint_neq::<u16, 16, Fr>).unwrap()
    }

    /// Randomized `is_neq` check over 32-bit operands.
    #[test]
    fn u32_neq() {
        run_binary_random::<1000, 32, _, _>(uint_neq::<u32, 32, Fr>).unwrap()
    }

    /// Randomized `is_neq` check over 64-bit operands.
    #[test]
    fn u64_neq() {
        run_binary_random::<1000, 64, _, _>(uint_neq::<u64, 64, Fr>).unwrap()
    }

    /// Randomized `is_neq` check over 128-bit operands.
    #[test]
    fn u128_neq() {
        run_binary_random::<1000, 128, _, _>(uint_neq::<u128, 128, Fr>).unwrap()
    }
}