// ark_r1cs_std/uint/select.rs
use super::*;
2use crate::select::CondSelectGadget;
3
4impl<const N: usize, T: PrimUInt, ConstraintF: PrimeField> CondSelectGadget<ConstraintF>
5 for UInt<N, T, ConstraintF>
6{
7 #[tracing::instrument(target = "r1cs", skip(cond, true_value, false_value))]
8 fn conditionally_select(
9 cond: &Boolean<ConstraintF>,
10 true_value: &Self,
11 false_value: &Self,
12 ) -> Result<Self, SynthesisError> {
13 let selected_bits = true_value
14 .bits
15 .iter()
16 .zip(&false_value.bits)
17 .map(|(t, f)| cond.select(t, f));
18 let mut bits = [Boolean::FALSE; N];
19 for (result, new) in bits.iter_mut().zip(selected_bits) {
20 *result = new?;
21 }
22
23 let value = cond.value().ok().and_then(|cond| {
24 if cond {
25 true_value.value().ok()
26 } else {
27 false_value.value().ok()
28 }
29 });
30 Ok(Self { bits, value })
31 }
32}
33
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        alloc::{AllocVar, AllocationMode},
        prelude::EqGadget,
        uint::test_utils::{run_binary_exhaustive, run_binary_random},
    };
    use ark_ff::PrimeField;
    use ark_test_curves::bls12_381::Fr;

    /// Checks `Boolean::select` on the operands `(a, b)` for both condition values.
    ///
    /// For each condition value, it allocates a `UInt` holding the operand that
    /// should be picked. It then compares that `UInt`'s value to the selected
    /// result and enforces equality between the two.
    ///
    /// If both operands are constants, the condition and the expected value are
    /// allocated as constants. Otherwise they are witnesses, and the constraint
    /// system must be satisfied after each round.
    fn uint_select<T: PrimUInt, const N: usize, F: PrimeField>(
        a: UInt<N, T, F>,
        b: UInt<N, T, F>,
    ) -> Result<(), SynthesisError> {
        let cs = a.cs().or(b.cs());
        let operands_constant = a.is_constant() && b.is_constant();
        let mode = if operands_constant {
            AllocationMode::Constant
        } else {
            AllocationMode::Witness
        };

        for flag in [true, false] {
            let want = UInt::new_variable(
                cs.clone(),
                || {
                    let picked = if flag { a.value()? } else { b.value()? };
                    Ok(picked)
                },
                mode,
            )?;
            let cond = Boolean::new_variable(cs.clone(), || Ok(flag), mode)?;
            let got = cond.select(&a, &b)?;

            assert_eq!(want.value(), got.value());
            want.enforce_equal(&got)?;
            // Skipped for all-constant inputs; presumably `cs` has no backing
            // constraint system to check in that case.
            if !operands_constant {
                assert!(cs.is_satisfied().unwrap());
            }
        }
        Ok(())
    }

    #[test]
    fn u8_select() {
        // Exhaustive over operand pairs (per the helper's name).
        run_binary_exhaustive(uint_select::<u8, 8, Fr>).unwrap()
    }

    // The wider widths below use the randomized driver; the leading const
    // parameter (1000) is presumably the number of sampled operand pairs.

    #[test]
    fn u16_select() {
        run_binary_random::<1000, 16, _, _>(uint_select::<u16, 16, Fr>).unwrap()
    }

    #[test]
    fn u32_select() {
        run_binary_random::<1000, 32, _, _>(uint_select::<u32, 32, Fr>).unwrap()
    }

    #[test]
    fn u64_select() {
        run_binary_random::<1000, 64, _, _>(uint_select::<u64, 64, Fr>).unwrap()
    }

    #[test]
    fn u128_select() {
        run_binary_random::<1000, 128, _, _>(uint_select::<u128, 128, Fr>).unwrap()
    }
}