// ark_r1cs_std/uint/select.rs
use super::*;
2
3impl<const N: usize, T: PrimUInt, ConstraintF: PrimeField> CondSelectGadget<ConstraintF>
4 for UInt<N, T, ConstraintF>
5{
6 #[tracing::instrument(target = "gr1cs", skip(cond, true_value, false_value))]
7 fn conditionally_select(
8 cond: &Boolean<ConstraintF>,
9 true_value: &Self,
10 false_value: &Self,
11 ) -> Result<Self, SynthesisError> {
12 let selected_bits = true_value
13 .bits
14 .iter()
15 .zip(&false_value.bits)
16 .map(|(t, f)| cond.select(t, f));
17 let mut bits = [Boolean::FALSE; N];
18 for (result, new) in bits.iter_mut().zip(selected_bits) {
19 *result = new?;
20 }
21
22 let value = cond.value().ok().and_then(|cond| {
23 if cond {
24 true_value.value().ok()
25 } else {
26 false_value.value().ok()
27 }
28 });
29 Ok(Self { bits, value })
30 }
31}
32
#[cfg(test)]
mod tests {
    use super::*;
    use crate::uint::test_utils::{run_binary_exhaustive, run_binary_random};
    use ark_test_curves::bls12_381::Fr;

    /// Checks `Boolean::select` on a pair of `UInt`s for both condition values.
    ///
    /// For each condition, the expected result is allocated from the value of
    /// `a` (condition true) or `b` (condition false), compared against the
    /// computed selection by value, and then enforced equal in the constraint
    /// system.
    fn uint_select<T: PrimUInt, const N: usize, F: PrimeField>(
        a: UInt<N, T, F>,
        b: UInt<N, T, F>,
    ) -> Result<(), SynthesisError> {
        let both_constant = a.is_constant() && b.is_constant();
        // The expected value and the condition share the operands' mode:
        // constant when both inputs are constants, witness otherwise.
        let mode = match both_constant {
            true => AllocationMode::Constant,
            false => AllocationMode::Witness,
        };
        let cs = a.cs().or(b.cs());

        for pick_a in [true, false] {
            let expected = UInt::new_variable(
                cs.clone(),
                || if pick_a { a.value() } else { b.value() },
                mode,
            )?;
            let selector = Boolean::new_variable(cs.clone(), || Ok(pick_a), mode)?;
            let computed = selector.select(&a, &b)?;

            assert_eq!(expected.value(), computed.value());
            expected.enforce_equal(&computed)?;
            // NOTE(review): satisfiability is skipped for all-constant inputs,
            // presumably because no real constraint system is attached then —
            // confirm.
            if !both_constant {
                assert!(cs.is_satisfied().unwrap());
            }
        }

        Ok(())
    }

    /// Exhaustive check over 8-bit operands.
    #[test]
    fn u8_select() {
        run_binary_exhaustive(uint_select::<u8, 8, Fr>).unwrap();
    }

    /// Randomized check over 16-bit operands.
    #[test]
    fn u16_select() {
        run_binary_random::<1000, 16, _, _>(uint_select::<u16, 16, Fr>).unwrap();
    }

    /// Randomized check over 32-bit operands.
    #[test]
    fn u32_select() {
        run_binary_random::<1000, 32, _, _>(uint_select::<u32, 32, Fr>).unwrap();
    }

    /// Randomized check over 64-bit operands.
    #[test]
    fn u64_select() {
        run_binary_random::<1000, 64, _, _>(uint_select::<u64, 64, Fr>).unwrap();
    }

    /// Randomized check over 128-bit operands.
    #[test]
    fn u128_select() {
        run_binary_random::<1000, 128, _, _>(uint_select::<u128, 128, Fr>).unwrap();
    }
}