1use ark_ff::PrimeField;
2use ark_relations::r1cs::SynthesisError;
3use ark_std::vec::Vec;
4
5use crate::boolean::Boolean;
6use crate::eq::EqGadget;
7
8use super::*;
9
10impl<const N: usize, T: PrimUInt, ConstraintF: PrimeField> EqGadget<ConstraintF>
11 for UInt<N, T, ConstraintF>
12{
13 #[tracing::instrument(target = "r1cs", skip(self, other))]
14 fn is_eq(&self, other: &Self) -> Result<Boolean<ConstraintF>, SynthesisError> {
15 let chunk_size = usize::try_from(ConstraintF::MODULUS_BIT_SIZE - 1).unwrap();
16 let chunks_are_eq = self
17 .bits
18 .chunks(chunk_size)
19 .zip(other.bits.chunks(chunk_size))
20 .map(|(a, b)| {
21 let a = Boolean::le_bits_to_fp(a)?;
22 let b = Boolean::le_bits_to_fp(b)?;
23 a.is_eq(&b)
24 })
25 .collect::<Result<Vec<_>, _>>()?;
26 Boolean::kary_and(&chunks_are_eq)
27 }
28
29 #[tracing::instrument(target = "r1cs", skip(self, other))]
30 fn conditional_enforce_equal(
31 &self,
32 other: &Self,
33 condition: &Boolean<ConstraintF>,
34 ) -> Result<(), SynthesisError> {
35 let chunk_size = usize::try_from(ConstraintF::MODULUS_BIT_SIZE - 1).unwrap();
36 for (a, b) in self
37 .bits
38 .chunks(chunk_size)
39 .zip(other.bits.chunks(chunk_size))
40 {
41 let a = Boolean::le_bits_to_fp(a)?;
42 let b = Boolean::le_bits_to_fp(b)?;
43 a.conditional_enforce_equal(&b, condition)?;
44 }
45 Ok(())
46 }
47
48 #[tracing::instrument(target = "r1cs", skip(self, other))]
49 fn conditional_enforce_not_equal(
50 &self,
51 other: &Self,
52 condition: &Boolean<ConstraintF>,
53 ) -> Result<(), SynthesisError> {
54 let chunk_size = usize::try_from(ConstraintF::MODULUS_BIT_SIZE - 1).unwrap();
55 for (a, b) in self
56 .bits
57 .chunks(chunk_size)
58 .zip(other.bits.chunks(chunk_size))
59 {
60 let a = Boolean::le_bits_to_fp(a)?;
61 let b = Boolean::le_bits_to_fp(b)?;
62 a.conditional_enforce_not_equal(&b, condition)?;
63 }
64 Ok(())
65 }
66}
67
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        alloc::{AllocVar, AllocationMode},
        prelude::EqGadget,
        uint::test_utils::{run_binary_exhaustive, run_binary_random},
        R1CSVar,
    };
    use ark_ff::PrimeField;
    use ark_test_curves::bls12_381::Fr;

    /// Enforces whichever of `a == b` / `a != b` actually holds for the
    /// operands' concrete values. This exercises the `enforce_equal` and
    /// `enforce_not_equal` paths, not just `is_eq` / `is_neq`.
    fn enforce_actual_relation<T: PrimUInt, const N: usize, F: PrimeField>(
        a: &UInt<N, T, F>,
        b: &UInt<N, T, F>,
    ) -> Result<(), SynthesisError> {
        if a.value()? == b.value()? {
            a.enforce_equal(b)
        } else {
            a.enforce_not_equal(b)
        }
    }

    /// Checks `is_eq` against native equality and checks that enforcing the
    /// true relation keeps the constraint system satisfied.
    fn uint_eq<T: PrimUInt, const N: usize, F: PrimeField>(
        a: UInt<N, T, F>,
        b: UInt<N, T, F>,
    ) -> Result<(), SynthesisError> {
        let cs = a.cs().or(b.cs());
        let both_constant = a.is_constant() && b.is_constant();
        let computed = a.is_eq(&b)?;
        let expected_mode = if both_constant {
            AllocationMode::Constant
        } else {
            AllocationMode::Witness
        };
        let expected =
            Boolean::new_variable(cs.clone(), || Ok(a.value()? == b.value()?), expected_mode)?;
        assert_eq!(expected.value(), computed.value());
        expected.enforce_equal(&computed)?;
        enforce_actual_relation(&a, &b)?;
        // A purely constant computation has no constraint system to check.
        if !both_constant {
            assert!(cs.is_satisfied().unwrap());
        }
        Ok(())
    }

    /// Checks `is_neq` against native inequality and checks that enforcing
    /// the true relation keeps the constraint system satisfied.
    fn uint_neq<T: PrimUInt, const N: usize, F: PrimeField>(
        a: UInt<N, T, F>,
        b: UInt<N, T, F>,
    ) -> Result<(), SynthesisError> {
        let cs = a.cs().or(b.cs());
        let both_constant = a.is_constant() && b.is_constant();
        let computed = a.is_neq(&b)?;
        let expected_mode = if both_constant {
            AllocationMode::Constant
        } else {
            AllocationMode::Witness
        };
        let expected =
            Boolean::new_variable(cs.clone(), || Ok(a.value()? != b.value()?), expected_mode)?;
        assert_eq!(expected.value(), computed.value());
        expected.enforce_equal(&computed)?;
        enforce_actual_relation(&a, &b)?;
        // A purely constant computation has no constraint system to check.
        if !both_constant {
            assert!(cs.is_satisfied().unwrap());
        }
        Ok(())
    }

    #[test]
    fn u8_eq() {
        run_binary_exhaustive(uint_eq::<u8, 8, Fr>).unwrap()
    }

    #[test]
    fn u16_eq() {
        run_binary_random::<1000, 16, _, _>(uint_eq::<u16, 16, Fr>).unwrap()
    }

    #[test]
    fn u32_eq() {
        run_binary_random::<1000, 32, _, _>(uint_eq::<u32, 32, Fr>).unwrap()
    }

    #[test]
    fn u64_eq() {
        run_binary_random::<1000, 64, _, _>(uint_eq::<u64, 64, Fr>).unwrap()
    }

    #[test]
    fn u128_eq() {
        run_binary_random::<1000, 128, _, _>(uint_eq::<u128, 128, Fr>).unwrap()
    }

    #[test]
    fn u8_neq() {
        run_binary_exhaustive(uint_neq::<u8, 8, Fr>).unwrap()
    }

    #[test]
    fn u16_neq() {
        run_binary_random::<1000, 16, _, _>(uint_neq::<u16, 16, Fr>).unwrap()
    }

    #[test]
    fn u32_neq() {
        run_binary_random::<1000, 32, _, _>(uint_neq::<u32, 32, Fr>).unwrap()
    }

    #[test]
    fn u64_neq() {
        run_binary_random::<1000, 64, _, _>(uint_neq::<u64, 64, Fr>).unwrap()
    }

    #[test]
    fn u128_neq() {
        run_binary_random::<1000, 128, _, _>(uint_neq::<u128, 128, Fr>).unwrap()
    }
}